<a href="https://colab.research.google.com/github/pramodith/llm_exploration/blob/bert_sparse_attention_training/bert_sparse_attention_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [45]:
%pip install transformers
%pip install datasets
%pip install accelerate -U
%pip install scikit-learn
%pip install overrides



In [46]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding
from datasets import load_dataset, load_metric
from transformers import TrainingArguments, Trainer
import torch

from transformers import BertModel, BertForSequenceClassification
from transformers.models.bert.modeling_bert import BertEncoder, logger
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
from transformers.modeling_utils import ModuleUtilsMixin, warnings
from typing import Optional, Tuple, Union, List
from overrides import overrides
from torch import Tensor


In [47]:
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
sample_text = "Neetu Neetu Neetu"

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [48]:
# Load the dataset and metric
dataset = load_dataset('glue', 'sst2')
metric = load_metric('glue', 'sst2')

# Split the dataset
train_dataset = dataset['train']
dev_dataset = dataset['validation']
test_dataset = dataset['test']

# Print a description of the dataset
print("Dataset Description: ", train_dataset.description)

# Print the label space
print("Label Space: ", train_dataset.features["label"].names)

Downloading builder script:   0%|          | 0.00/28.8k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/28.7k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/27.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.44M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

  metric = load_metric('glue', 'sst2')


Downloading builder script:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

Dataset Description:  GLUE, the General Language Understanding Evaluation benchmark
(https://gluebenchmark.com/) is a collection of resources for training,
evaluating, and analyzing natural language understanding systems.


Label Space:  ['negative', 'positive']


In [49]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = predictions.argmax(axis=1)
    return metric.compute(predictions=predictions, references=labels)

In [50]:
train_dataset

Dataset({
    features: ['sentence', 'label', 'idx'],
    num_rows: 67349
})

In [101]:
class CustomBertEncoder(BertEncoder):
    def __init__(self, config):
        super().__init__(config)

    @overrides
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        attention_mask_is_layerwise = False
        if attention_mask.dim() == 5:
            attention_mask_is_layerwise = True

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if attention_mask_is_layerwise:
               attention_mask.shape
               attention_mask_to_use = attention_mask[i]
            else:
                attention_mask_to_use = attention_mask

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask_to_use,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask_to_use,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class CustomBertModel(BertModel):
    def __init__(self, config):
        super().__init__(config)
        self.encoder = CustomBertEncoder(config)

    @overrides
    def get_extended_attention_mask(
        self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        """
        if dtype is None:
            dtype = self.dtype

        if not (attention_mask.dim() == 2 and self.config.is_decoder):
            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
            if device is not None:
                warnings.warn(
                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
                )
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 4:
            extended_attention_mask = attention_mask[:, :, None, :, :]
        elif attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
                    input_shape, attention_mask, device
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and the dtype's smallest value for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask

    @overrides
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

class CustomBertForSequenceClassification(BertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.bert = CustomBertModel(config)



In [102]:
test_custom_model = CustomBertForSequenceClassification.from_pretrained(model_name)
inputs = custom_tokenize(tokenizer, sample_text, batch_mode=True)
output = test_custom_model(**inputs, output_attentions=True)
print(f"Attention score for Layer 0, Batch 0, Head 0 {output.attentions[0][0,0]}")
print(f"Attention score for Layer 1, Batch 0, Head 0 {output.attentions[1][0,0]}")

inputs["attention_mask"] = inputs["attention_mask"].unsqueeze(0).repeat(12,1,1,1)
output_layerwise_mask = test_custom_model(**inputs, output_attentions=True)
assert torch.allclose(output.attentions[0], output_layerwise_mask.attentions[0]), "Attention scores have to be the same for the same attention mask even if it is repeated layerwise"

Some weights of CustomBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Attention score for Layer 0, Batch 0, Head 0 tensor([[0.0883, 0.0511, 0.0975, 0.0484, 0.1103, 0.0534, 0.1214, 0.4296],
        [0.1546, 0.2441, 0.3183, 0.2830, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0794, 0.2734, 0.1752, 0.2829, 0.1892, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.1726, 0.2420, 0.1859, 0.2291, 0.1704, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.1666, 0.2510, 0.1681, 0.2556, 0.1587, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.1897, 0.2229, 0.1610, 0.2120, 0.2145],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.2667, 0.3852, 0.2432, 0.1049],
        [0.1801, 0.0651, 0.0971, 0.0677, 0.1166, 0.0731, 0.1155, 0.2849]],
       grad_fn=<SelectBackward0>)
Attention score for Layer 1, Batch 0, Head 0 tensor([[0.6098, 0.0170, 0.0526, 0.0216, 0.0703, 0.0349, 0.1029, 0.0908],
        [0.9450, 0.0092, 0.0428, 0.0030, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7919, 0.0171, 0.0759, 0.0160, 0.0990, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0250, 0.4966, 0.0107, 0.4445, 0.0233, 0.00

In [51]:
from typing import List
import torch
from transformers import AutoTokenizer

def custom_tokenize(tokenizer: AutoTokenizer, text: str, batch_mode=False):
    # Tokenize the texts
    result = tokenizer(text, truncation=True, padding=False)
    # Create attention mask with ones on the main diagonal
    attention_mask = torch.eye(len(result["input_ids"]))

    # Update attention mask for the specified neighborhood distance
    distance = 2
    attention_mask[abs(torch.arange(len(attention_mask))[:, None] - torch.arange(len(attention_mask))) <= distance] = 1

    # Set the first row to 1 corresponding to the CLS token
    attention_mask[0, :] = 1
    # Set the last row to 1 corresponding to the SEP token
    attention_mask[-1, :] = 1
    # Add the attention mask to the result

    if batch_mode:
      result["attention_mask"] = attention_mask.unsqueeze(0)
      result["input_ids"] = torch.LongTensor(result["input_ids"]).unsqueeze(0)
      result["token_type_ids"] = torch.LongTensor(result["token_type_ids"]).unsqueeze(0)
    else:
      result["attention_mask"] = attention_mask
      result["input_ids"] = torch.LongTensor(result["input_ids"])
      result["token_type_ids"] = torch.LongTensor(result["token_type_ids"])
    # Map the labels to the tokenized inputs
    return result

In [120]:
def get_layerwise_attention_mask(attention_mask: Tensor, pad_token_pos, total_num_layers=12, full_attention_layers=[]):
  attention_mask = attention_mask.unsqueeze(0)
  if full_attention_layers:
    attention_mask = attention_mask.repeat(total_num_layers, 1, 1, 1)
    for layer_num in full_attention_layers:
      attention_mask[layer_num, :, : pad_token_pos, : pad_token_pos] = 1
  return attention_mask

In [67]:
dummy_attn_mask = torch.zeros(1, 4, 4)
layerwise_attn_mask = get_layerwise_attention_mask(dummy_attn_mask, full_attention_layers=[2,4])
assert layerwise_attn_mask.shape == (12, 1, 4, 4), f"{layerwise_attn_mask.shape} is the wrong shape"
assert torch.allclose(layerwise_attn_mask[0], dummy_attn_mask), "Masks at layer 0 should all be 0s"
assert torch.allclose(layerwise_attn_mask[2], torch.ones(1,4,4)), "Masks at layer 1 should all be 1s"

In [115]:
from torch.nn.functional import pad
from torch.nn.utils.rnn import pad_sequence

def custom_collate(batch, pad_token_id, full_attention_layers=[]):
  input_ids = [torch.LongTensor(batch[i]["input_ids"]) for i in range(len(batch))]
  attention_mask = [torch.LongTensor(batch[i]["attention_mask"]) for i in range(len(batch))]
  token_type_ids = [torch.LongTensor(batch[i]["token_type_ids"]) for i in range(len(batch))]
  label = [batch[i]["label"] for i in range(len(batch))]
  #idx = [batch[i]["idx"] for i in range(len(batch))]
  max_len = max([len(inp) for inp in input_ids])
  padding_sizes = [max_len - len(inp) for inp in input_ids]
  input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
  token_type_ids = pad_sequence(token_type_ids, batch_first=True)

  attention_mask = [get_layerwise_attention_mask(pad(attention_mask[i], (0, padding_sizes[i], 0, padding_sizes[i]), value=0), padding_sizes[i], full_attention_layers=full_attention_layers).squeeze(1) for i in range(len(batch))]
  attention_mask = torch.stack(attention_mask)
  if full_attention_layers:
    attention_mask = attention_mask.permute(1,0,2,3)
  else:
    attention_mask = attention_mask.squeeze(1)
  return {
      "input_ids": input_ids,
      "attention_mask": attention_mask,
      "token_type_ids": token_type_ids,
      "labels": torch.tensor(label),
  }

In [77]:
a = torch.zeros(8,12,3,3)
a.permute(1,0,2,3).shape

torch.Size([12, 8, 3, 3])

In [70]:
# Encode the datasets
train_dataset_dense_attention = train_dataset.map(lambda example: tokenizer(example['sentence'], truncation=True, padding=False), batched=True)
dev_dataset_dense_attention = dev_dataset.map(lambda example: tokenizer(example['sentence'], truncation=True, padding=False), batched=True)
test_dataset_dense_attention = test_dataset.map(lambda example: tokenizer(example['sentence'], truncation=True, padding=False), batched=True)



train_dataset_sparse_attention = train_dataset.map(lambda example: custom_tokenize(tokenizer, example["sentence"]), batched=False)
dev_dataset_sparse_attention = dev_dataset.map(lambda example: custom_tokenize(tokenizer, example["sentence"]), batched=False)
test_dataset_sparse_attention = test_dataset.map(lambda example: custom_tokenize(tokenizer, example["sentence"]), batched=False)

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

In [71]:
train_dataset_sparse_attention[0]

{'sentence': 'hide new secretions from the parental units ',
 'label': 0,
 'idx': 0,
 'input_ids': [101, 5342, 2047, 3595, 8496, 2013, 1996, 18643, 3197, 102],
 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 'attention_mask': [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
  [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
  [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
  [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0],
  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
  [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]}

In [80]:
# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,            # total number of training steps
    per_device_train_batch_size=64,  # batch size per device during training
    per_device_eval_batch_size=128,   # batch size for evaluation
    warmup_ratio=0.1,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    fp16=True,
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    eval_steps=100,
    load_best_model_at_end=True,
    logging_steps=100
)

# Initialize the trainer
trainer_dense_attention = Trainer(
    model=model,
    tokenizer=tokenizer,                 # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset_dense_attention,         # training dataset
    eval_dataset=dev_dataset_dense_attention,       # evaluation dataset
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer, padding=True),
    compute_metrics=compute_metrics
)


In [None]:
trainer_dense_attention.train()

Step,Training Loss,Validation Loss,Accuracy
100,0.1139,0.280868,0.915138
200,0.2527,0.227146,0.912844
300,0.2266,0.238781,0.909404
400,0.2216,0.237614,0.912844
500,0.1966,0.262964,0.885321
600,0.1837,0.217739,0.919725
700,0.1836,0.220342,0.917431
800,0.1809,0.201191,0.926606
900,0.1689,0.210163,0.923165
1000,0.1574,0.214784,0.920872




TrainOutput(global_step=3159, training_loss=0.11348400654842598, metrics={'train_runtime': 315.8455, 'train_samples_per_second': 639.702, 'train_steps_per_second': 10.002, 'total_flos': 4633893920893020.0, 'train_loss': 0.11348400654842598, 'epoch': 3.0})

In [None]:
trainer_dense_attention.evaluate(dev_dataset_dense_attention)

{'eval_loss': 0.21478354930877686,
 'eval_accuracy': 0.9208715596330275,
 'eval_runtime': 0.3051,
 'eval_samples_per_second': 2858.042,
 'eval_steps_per_second': 22.943,
 'epoch': 3.0}

In [100]:
from functools import partial
trainer_sparse_attention = Trainer(
    model=model,
    tokenizer=tokenizer,                 # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset_sparse_attention,         # training dataset
    eval_dataset=dev_dataset_sparse_attention,           # evaluation dataset
    data_collator=partial(custom_collate, pad_token_id=tokenizer.pad_token_id),
    compute_metrics=compute_metrics
  )

trainer_sparse_attention.train()



Step,Training Loss,Validation Loss,Accuracy
100,0.6784,0.648368,0.610092
200,0.4876,0.434287,0.798165
300,0.3618,0.553588,0.794725
400,0.3034,0.468317,0.818807
500,0.2818,0.419138,0.824541
600,0.2567,0.308851,0.876147
700,0.2334,0.302835,0.881881
800,0.23,0.300463,0.87844
900,0.2151,0.328031,0.877294
1000,0.2107,0.310702,0.883028




TrainOutput(global_step=3159, training_loss=0.17626571685462558, metrics={'train_runtime': 354.9609, 'train_samples_per_second': 569.209, 'train_steps_per_second': 8.9, 'total_flos': 4633893920893020.0, 'train_loss': 0.17626571685462558, 'epoch': 3.0})

In [111]:
trainer_sparse_attention.evaluate(dev_dataset_sparse_attention)

Step,Training Loss,Validation Loss,Accuracy
100,0.1319,0.282347,0.893349
167,0.1319,0.298096,0.889908


{'eval_loss': 0.2980963885784149, 'eval_accuracy': 0.8899082568807339}

In [122]:
from functools import partial
test_custom_model = CustomBertForSequenceClassification.from_pretrained(model_name)

trainer_sparse_layerwise_attention = Trainer(
    model=test_custom_model,
    tokenizer=tokenizer,                 # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset_sparse_attention,         # training dataset
    eval_dataset=dev_dataset_sparse_attention,           # evaluation dataset
    data_collator=partial(custom_collate, pad_token_id=tokenizer.pad_token_id, full_attention_layers=[0,1,2,3]),
    compute_metrics=compute_metrics
  )

trainer_sparse_layerwise_attention.train()

Some weights of CustomBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Accuracy
100,0.6423,0.520197,0.75
200,0.3953,0.381379,0.823394
300,0.3292,0.438466,0.833716
400,0.2873,0.481485,0.826835
500,0.2598,0.34701,0.847477
600,0.2344,0.301658,0.880734
700,0.2321,0.251139,0.90367
800,0.2141,0.28462,0.884174
900,0.2075,0.291587,0.879587
1000,0.1953,0.34716,0.876147




TrainOutput(global_step=3159, training_loss=0.164511976909245, metrics={'train_runtime': 397.7391, 'train_samples_per_second': 507.989, 'train_steps_per_second': 7.942, 'total_flos': 4633893920893020.0, 'train_loss': 0.164511976909245, 'epoch': 3.0})

In [123]:
trainer_sparse_layerwise_attention.evaluate(dev_dataset_sparse_attention)

{'eval_loss': 0.2847082018852234,
 'eval_accuracy': 0.8990825688073395,
 'eval_runtime': 0.8041,
 'eval_samples_per_second': 1084.416,
 'eval_steps_per_second': 8.705,
 'epoch': 3.0}

In [None]:
metric

Metric(name: "glue", features: {'predictions': Value(dtype='int64', id=None), 'references': Value(dtype='int64', id=None)}, usage: """
Compute GLUE evaluation metric associated to each GLUE dataset.
Args:
    predictions: list of predictions to score.
        Each translation should be tokenized into a list of tokens.
    references: list of lists of references for each translation.
        Each reference should be tokenized into a list of tokens.
Returns: depending on the GLUE subset, one or several of:
    "accuracy": Accuracy
    "f1": F1 score
    "pearson": Pearson Correlation
    "spearmanr": Spearman Correlation
    "matthews_correlation": Matthew Correlation
Examples:

    >>> glue_metric = datasets.load_metric('glue', 'sst2')  # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(res