In [2]:
import os
import sys
import pandas as pd
import tqdm
import torch
import torch.nn as nn
from typing import Optional
from torch.utils.tensorboard import SummaryWriter
import transformers
from transformers import BartForConditionalGeneration, BartTokenizer,BartConfig,BartModel
import math
from typing import Union,Tuple
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers.models.bart.modeling_bart import BartScaledWordEmbedding,BartLearnedPositionalEmbedding,BartPreTrainedModel,BartEncoderLayer,BartAttention,BartEncoder
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa,_prepare_4d_attention_mask

In [1]:
#Hyperparameters
bsz=8
epoch=10
adam_eps=1e-5
weight_decay=0.01
learning_rate=1e-5
warmup=50
save_step=1000
save_pth='bart_distillition'
save_model_name='cnn_bart_encoder_'
dataset_pth='./cnn_dataset/'
val_length=100

In [4]:
class LinformerAttention(BartAttention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[BartConfig] = None,
        linformer_dim:Optional[int]=512,
    ):
        super().__init__(embed_dim=embed_dim,num_heads=num_heads,dropout=dropout,is_decoder=is_decoder,bias=bias,is_causal=is_causal,config=config)

        self.kv_former = nn.Linear(self.config.max_position_embeddings,linformer_dim,bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning


        hidden_states = self.kv_former(hidden_states.transpose(-1,-2)).transpose(-1,-2)

        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        
        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
class LinformerEncoderLayer(BartEncoderLayer):
    def __init__(self, config: BartConfig):
        super().__init__(config=config)
        self.embed_dim = config.d_model

        self.self_attn = LinformerAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
            linformer_dim=512,
        )
class LinformerEncoder(BartPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`BartEncoderLayer`].

    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.embed_tokens = BartScaledWordEmbedding(
            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
        )

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        self.layers = nn.ModuleList([LinformerEncoderLayer(config) for _ in range(config.encoder_layers)])
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self._use_sdpa = config._attn_implementation == "sdpa"
        self.layernorm_embedding = nn.LayerNorm(embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        distil_layer:Optional[int]=-1,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input = input_ids
            input_ids = input_ids.view(-1, input_ids.shape[-1])
        elif inputs_embeds is not None:
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

            embed_pos = self.embed_positions(input)
            embed_pos = embed_pos.to(inputs_embeds.device)

            hidden_states = inputs_embeds + embed_pos
            hidden_states = self.layernorm_embedding(hidden_states)
        else: # Pankaj This is for distillition training, as input_embedding is already having positional encoding in it.
            hidden_states=inputs_embeds
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            if self._use_flash_attention_2:
                attention_mask = attention_mask if 0 in attention_mask else None
            elif self._use_sdpa and head_mask is None and not output_attentions:
                # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
                # the manual implementation that requires a 4D causal mask in all cases.
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    (head_mask[idx] if head_mask is not None else None),
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            if distil_layer==idx:
                return hidden_states
                
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
    
class BartDataset(torch.utils.data.Dataset):
    def __init__(self,dataset,tokenizer,config):
        self.dataset=dataset
        self.tokenizer=tokenizer
        self.config=config

    def __len__(self):
        return len(self.dataset)
    def add_prefix(self,text):
        prefix = "SUMMARIZE NEWS : "
        return prefix+text
    def __getitem__(self,idx):
        data = self.dataset.iloc[idx]
        text,summary = data.article,data.highlights
        text = self.add_prefix(text)
        inputs = self.tokenizer(text,max_length=self.config.max_position_embeddings,padding='max_length',truncation=True,return_tensors='pt')
        label = self.tokenizer.encode(summary,max_length=256,padding='max_length',truncation=True,return_tensors='pt')
        input_ids = inputs.input_ids
        attention_mask=inputs.attention_mask

        return dict(
            input_ids =input_ids.squeeze(0),
            attention_mask=attention_mask.squeeze(0),
            label=label.squeeze(0),
            text=text,
            summary=summary
        )

In [5]:
cnn_large_model = BartEncoder.from_pretrained("facebook/bart-large-cnn").to(device)
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

torch.save(cnn_large_model.state_dict(),'large_cnn.pt')
lin_conf = BartConfig.from_pretrained("facebook/bart-large-cnn")
lin_conf._attn_implementation="eager"
lin_model = LinformerEncoder(lin_conf).to(device)
torch.save(lin_model.state_dict(),'linformer_cnn.pt')

Some weights of BartEncoder were not initialized from the model checkpoint at facebook/bart-large-cnn and are newly initialized: ['embed_positions.weight', 'embed_tokens.weight', 'layernorm_embedding.bias', 'layernorm_embedding.weight', 'layers.0.fc1.bias', 'layers.0.fc1.weight', 'layers.0.fc2.bias', 'layers.0.fc2.weight', 'layers.0.final_layer_norm.bias', 'layers.0.final_layer_norm.weight', 'layers.0.self_attn.k_proj.bias', 'layers.0.self_attn.k_proj.weight', 'layers.0.self_attn.out_proj.bias', 'layers.0.self_attn.out_proj.weight', 'layers.0.self_attn.q_proj.bias', 'layers.0.self_attn.q_proj.weight', 'layers.0.self_attn.v_proj.bias', 'layers.0.self_attn.v_proj.weight', 'layers.0.self_attn_layer_norm.bias', 'layers.0.self_attn_layer_norm.weight', 'layers.1.fc1.bias', 'layers.1.fc1.weight', 'layers.1.fc2.bias', 'layers.1.fc2.weight', 'layers.1.final_layer_norm.bias', 'layers.1.final_layer_norm.weight', 'layers.1.self_attn.k_proj.bias', 'layers.1.self_attn.k_proj.weight', 'layers.1.self_

In [15]:
lin_model.load_state_dict(torch.load('linformer_cnn.pt'))

<All keys matched successfully>

In [7]:
train_df = pd.read_csv(os.path.join(dataset_pth,"train.csv")) 
train_ds = BartDataset(train_df,tokenizer,lin_conf)
train_loader = torch.utils.data.DataLoader(train_ds,batch_size=bsz,shuffle=True)
total_steps = epoch*len(train_loader)
print(f"Training for Total : {total_steps}")

del train_df


Training for Total : 358900


In [None]:
optimizer=torch.optim.Adam(lin_model.parameters(),lr=learning_rate,weight_decay=weight_decay,eps=adam_eps)
scheduler = transformers.get_cosine_schedule_with_warmup(optimizer,num_warmup_steps=warmup,num_training_steps=total_steps)
loss_fn = torch.nn.L1Loss()
# loss_fn = torch.nn.KLDivLoss()

In [None]:
layers=1
for layer in range(layers):
    total_loss=0
    for ep in range(epoch):
        lin_model.train(True)
        loader = tqdm.tqdm(train_loader)
        loader.set_description(f"Layer : {layer} | Epoch : {ep}")

        for idx,data in enumerate(loader):
            input_ids = data['input_ids'].to(device)
            with torch.no_grad():
                label_out = cnn_large_model(input_ids,attention_mask=None,output_hidden_states=True,output_attentions=True)
            dist_inp = label_out.hidden_states[layer]
            dist_out = label_out.hidden_states[layer + 1]
            
            optimizer.zero_grad()
            
            output = lin_model(inputs_embeds=dist_inp,distil_layer=layer)

            loss = loss_fn(dist_out,output)
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            total_loss+=loss.item()
            loader.set_postfix({'Loss',loss.item()})

        torch.save(dict(
            model=lin_model.state_dict(),
            loss=total_loss/len(train_loader),
            epoch=ep,
            layer=layer,

        ),f'lin_model_dist_l{layer}.pt')
            


0 0.5360766649246216
1 0.5356727838516235
2 0.5355427265167236
3 0.5350778102874756
4 0.5347800254821777
5 0.5346702337265015
6 0.5343059301376343
7 0.5340139865875244
8 0.5335469841957092
9 0.5337418913841248
10 0.53326416015625


In [16]:
lin_model.layers[0].self_attn.kv_former.state_dict()

OrderedDict([('weight',
              tensor([[ 1.8826e-02, -2.8632e-02, -1.4102e-02,  ..., -8.6234e-03,
                       -6.2434e-02, -4.5047e-03],
                      [ 2.2725e-03, -1.2984e-02, -1.0146e-02,  ..., -4.4377e-02,
                       -1.1882e-03,  8.9085e-03],
                      [-5.8181e-03, -3.8661e-02, -2.2369e-02,  ...,  5.7755e-03,
                       -2.5276e-02, -2.0668e-03],
                      ...,
                      [ 4.7489e-02, -4.0045e-03,  4.5967e-03,  ..., -7.3304e-03,
                       -1.7578e-03, -1.2105e-02],
                      [-3.0868e-02,  9.6113e-03,  2.7596e-02,  ...,  9.0568e-05,
                        6.7902e-03,  3.7463e-03],
                      [ 7.6572e-03,  9.9940e-03, -2.5892e-02,  ...,  1.2484e-02,
                       -1.8995e-02, -2.5772e-02]], device='cuda:0'))])

In [11]:
torch.load('linformer_cnn.pt')['layers.0.self_attn.kv_former.weight']

tensor([[ 1.9033e-02, -2.8631e-02, -1.4096e-02,  ..., -8.8318e-03,
         -6.2435e-02, -4.5005e-03],
        [ 2.4507e-03, -1.2979e-02, -1.0137e-02,  ..., -4.4594e-02,
         -1.1195e-03,  9.1001e-03],
        [-5.8138e-03, -3.8660e-02, -2.2366e-02,  ...,  5.7585e-03,
         -2.5275e-02, -2.0509e-03],
        ...,
        [ 4.7698e-02, -3.9829e-03,  4.7976e-03,  ..., -7.5373e-03,
         -1.7084e-03, -1.2113e-02],
        [-3.0878e-02,  9.8195e-03,  2.7811e-02,  ..., -5.9353e-05,
          6.9939e-03,  3.9239e-03],
        [ 7.8544e-03,  1.0202e-02, -2.5889e-02,  ...,  1.2477e-02,
         -1.8992e-02, -2.5794e-02]], device='cuda:0')