<a href="https://colab.research.google.com/github/pramodith/llm_exploration/blob/bert_sparse_attention_training/bert_sparse_attention_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
%pip install transformers
%pip install datasets
%pip install accelerate -U
%pip install scikit-learn
%pip install overrides

Collecting overrides
  Downloading overrides-7.4.0-py3-none-any.whl (17 kB)
Installing collected packages: overrides
Successfully installed overrides-7.4.0


In [4]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding
from datasets import load_dataset, load_metric
from transformers import TrainingArguments, Trainer
import torch

from transformers import BertModel, BertForSequenceClassification
from transformers.models.bert.modeling_bert import BertEncoder, logger
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from typing import Optional, Tuple, Union, List
from overrides import overrides


In [17]:
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
sample_text = "Neetu Neetu Neetu"

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Load the dataset and metric
dataset = load_dataset('glue', 'sst2')
metric = load_metric('glue', 'sst2')

# Split the dataset
train_dataset = dataset['train']
dev_dataset = dataset['validation']
test_dataset = dataset['test']

# Print a description of the dataset
print("Dataset Description: ", train_dataset.description)

# Print the label space
print("Label Space: ", train_dataset.features["label"].names)

Downloading builder script:   0%|          | 0.00/28.8k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/28.7k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/27.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.44M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

  metric = load_metric('glue', 'sst2')


Downloading builder script:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

Dataset Description:  GLUE, the General Language Understanding Evaluation benchmark
(https://gluebenchmark.com/) is a collection of resources for training,
evaluating, and analyzing natural language understanding systems.


Label Space:  ['negative', 'positive']


In [None]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = predictions.argmax(axis=1)
    return metric.compute(predictions=predictions, references=labels)

In [None]:
train_dataset

Dataset({
    features: ['sentence', 'label', 'idx'],
    num_rows: 67349
})

In [28]:
from typing import List
import torch
from transformers import AutoTokenizer

def custom_tokenize(tokenizer: AutoTokenizer, text: str, batch_mode=False):
    # Tokenize the texts
    result = tokenizer(text, truncation=True, padding=False)
    # Create attention mask with ones on the main diagonal
    attention_mask = torch.eye(len(result["input_ids"]))

    # Update attention mask for the specified neighborhood distance
    distance = 2
    attention_mask[abs(torch.arange(len(attention_mask))[:, None] - torch.arange(len(attention_mask))) <= distance] = 1

    # Set the first row to 1 corresponding to the CLS token
    attention_mask[0, :] = 1
    # Set the last row to 1 corresponding to the SEP token
    attention_mask[-1, :] = 1
    # Add the attention mask to the result

    if batch_mode:
      result["attention_mask"] = attention_mask.unsqueeze(0)
      result["input_ids"] = torch.LongTensor(result["input_ids"]).unsqueeze(0)
      result["token_type_ids"] = torch.LongTensor(result["token_type_ids"]).unsqueeze(0)
    else:
      result["attention_mask"] = attention_mask
      result["input_ids"] = torch.LongTensor(result["input_ids"])
      result["token_type_ids"] = torch.LongTensor(result["token_type_ids"])
    # Map the labels to the tokenized inputs
    return result

In [None]:
from torch.nn.functional import pad
from torch.nn.utils.rnn import pad_sequence

def custom_collate(batch, pad_token_id):
  input_ids = [torch.LongTensor(batch[i]["input_ids"]) for i in range(len(batch))]
  attention_mask = [torch.LongTensor(batch[i]["attention_mask"]) for i in range(len(batch))]
  token_type_ids = [torch.LongTensor(batch[i]["token_type_ids"]) for i in range(len(batch))]
  label = [batch[i]["label"] for i in range(len(batch))]
  #idx = [batch[i]["idx"] for i in range(len(batch))]
  max_len = max([len(inp) for inp in input_ids])
  padding_sizes = [max_len - len(inp) for inp in input_ids]
  input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
  token_type_ids = pad_sequence(token_type_ids, batch_first=True)
  attention_mask = [pad(attention_mask[i], (0, padding_sizes[i], 0, padding_sizes[i]), value=0) for i in range(len(batch))]
  attention_mask = torch.stack(attention_mask).squeeze(1)
  return {
      "input_ids": input_ids,
      "attention_mask": attention_mask,
      "token_type_ids": token_type_ids,
      "labels": torch.tensor(label),
  }

In [None]:
# Encode the datasets
train_dataset_dense_attention = train_dataset.map(lambda example: tokenizer(example['sentence'], truncation=True, padding=False), batched=True)
dev_dataset_dense_attention = dev_dataset.map(lambda example: tokenizer(example['sentence'], truncation=True, padding=False), batched=True)
test_dataset_dense_attention = test_dataset.map(lambda example: tokenizer(example['sentence'], truncation=True, padding=False), batched=True)



train_dataset_sparse_attention = train_dataset.map(lambda example: custom_tokenize(tokenizer, example["sentence"]), batched=False)
dev_dataset_sparse_attention = dev_dataset.map(lambda example: custom_tokenize(tokenizer, example["sentence"]), batched=False)
test_dataset_sparse_attention = test_dataset.map(lambda example: custom_tokenize(tokenizer, example["sentence"]), batched=False)

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [None]:
train_dataset_sparse_attention[0]

{'sentence': 'hide new secretions from the parental units ',
 'label': 0,
 'idx': 0,
 'input_ids': [101, 5342, 2047, 3595, 8496, 2013, 1996, 18643, 3197, 102],
 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 'attention_mask': [[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
   [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
   [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
   [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
   [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
   [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
   [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
   [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0],
   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
   [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]]}

In [None]:
# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,            # total number of training steps
    per_device_train_batch_size=64,  # batch size per device during training
    per_device_eval_batch_size=128,   # batch size for evaluation
    warmup_ratio=0.1,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    fp16=True,
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    eval_steps=100,
    load_best_model_at_end=True,
    logging_steps=100
)

# Initialize the trainer
trainer_dense_attention = Trainer(
    model=model,
    tokenizer=tokenizer,                 # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset_dense_attention,         # training dataset
    eval_dataset=dev_dataset_dense_attention,       # evaluation dataset
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer, padding=True),
    compute_metrics=compute_metrics
)


In [None]:
trainer_dense_attention.train()

Step,Training Loss,Validation Loss,Accuracy
100,0.1139,0.280868,0.915138
200,0.2527,0.227146,0.912844
300,0.2266,0.238781,0.909404
400,0.2216,0.237614,0.912844
500,0.1966,0.262964,0.885321
600,0.1837,0.217739,0.919725
700,0.1836,0.220342,0.917431
800,0.1809,0.201191,0.926606
900,0.1689,0.210163,0.923165
1000,0.1574,0.214784,0.920872




TrainOutput(global_step=3159, training_loss=0.11348400654842598, metrics={'train_runtime': 315.8455, 'train_samples_per_second': 639.702, 'train_steps_per_second': 10.002, 'total_flos': 4633893920893020.0, 'train_loss': 0.11348400654842598, 'epoch': 3.0})

In [None]:
trainer_dense_attention.evaluate(dev_dataset_dense_attention)

{'eval_loss': 0.21478354930877686,
 'eval_accuracy': 0.9208715596330275,
 'eval_runtime': 0.3051,
 'eval_samples_per_second': 2858.042,
 'eval_steps_per_second': 22.943,
 'epoch': 3.0}

In [None]:
from functools import partial
trainer_sparse_attention = Trainer(
    model=model,
    tokenizer=tokenizer,                 # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset_sparse_attention,         # training dataset
    eval_dataset=dev_dataset_sparse_attention,           # evaluation dataset
    data_collator=partial(custom_collate, pad_token_id=tokenizer.pad_token_id),
    compute_metrics=compute_metrics
  )

trainer_sparse_attention.train()



Step,Training Loss,Validation Loss,Accuracy
100,0.1953,0.356762,0.855505
200,0.203,0.373309,0.849771
300,0.1671,0.316258,0.873853
400,0.1373,0.418967,0.848624
500,0.1423,0.30723,0.870413
600,0.1307,0.457298,0.848624
700,0.1283,0.301373,0.892202
800,0.1241,0.302129,0.893349
900,0.1106,0.353116,0.887615
1000,0.128,0.306217,0.889908




TrainOutput(global_step=3159, training_loss=0.10449925590519966, metrics={'train_runtime': 349.7689, 'train_samples_per_second': 577.659, 'train_steps_per_second': 9.032, 'total_flos': 4633893920893020.0, 'train_loss': 0.10449925590519966, 'epoch': 3.0})

In [None]:
trainer_sparse_attention.evaluate(dev_dataset_sparse_attention)

{'eval_loss': 0.28912436962127686,
 'eval_accuracy': 0.893348623853211,
 'eval_runtime': 0.6541,
 'eval_samples_per_second': 1333.109,
 'eval_steps_per_second': 10.702,
 'epoch': 3.0}

In [None]:
metric

Metric(name: "glue", features: {'predictions': Value(dtype='int64', id=None), 'references': Value(dtype='int64', id=None)}, usage: """
Compute GLUE evaluation metric associated to each GLUE dataset.
Args:
    predictions: list of predictions to score.
        Each translation should be tokenized into a list of tokens.
    references: list of lists of references for each translation.
        Each reference should be tokenized into a list of tokens.
Returns: depending on the GLUE subset, one or several of:
    "accuracy": Accuracy
    "f1": F1 score
    "pearson": Pearson Correlation
    "spearmanr": Spearman Correlation
    "matthews_correlation": Matthew Correlation
Examples:

    >>> glue_metric = datasets.load_metric('glue', 'sst2')  # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(res

In [5]:
class CustomBertEncoder(BertEncoder):
    def __init__(self, config):
        super().__init__(config)

    @overrides
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        attention_mask_is_layerwise = False
        if attention_mask.shape[0] == 4:
            attention_mask_is_layerwise = True

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if attention_mask_is_layerwise:
               attention_mask_to_use = attention_mask[i]
            else:
                attention_mask_to_use = attention_mask

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask_to_use,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask_to_use,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class CustomBertModel(BertModel):
    def __init__(self, config):
        super().__init__(config)
        self.encoder = CustomBertEncoder(config)

class CustomBertForSequenceClassification(BertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.bert = CustomBertModel(config)



In [44]:
test_custom_model = CustomBertForSequenceClassification.from_pretrained(model_name)
inputs = custom_tokenize(tokenizer, sample_text, batch_mode=True)
output = test_custom_model(**inputs, output_attentions=True)
print(f"Attention score for Layer 0, Batch 0, Head 0 {output.attentions[0][0,0]}")
print(f"Attention score for Layer 1, Batch 0, Head 0 {output.attentions[1][0,0]}")

inputs["attention_mask"] = inputs["attention_mask"].unsqueeze(0).repeat(12,1,1,1)
output_layerwise_mask = test_custom_model(**inputs, output_attentions=True)
assert torch.allclose(output.attentions[0], output_layerwise_mask.attentions[0]), "Attention scores have to be the same for the same attention mask even if it is repeated layerwise"

Some weights of CustomBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Attention score for Layer 0, Batch 0, Head 0 tensor([[0.0883, 0.0511, 0.0975, 0.0484, 0.1103, 0.0534, 0.1214, 0.4296],
        [0.1546, 0.2441, 0.3183, 0.2830, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0794, 0.2734, 0.1752, 0.2829, 0.1892, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.1726, 0.2420, 0.1859, 0.2291, 0.1704, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.1666, 0.2510, 0.1681, 0.2556, 0.1587, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.1897, 0.2229, 0.1610, 0.2120, 0.2145],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.2667, 0.3852, 0.2432, 0.1049],
        [0.1801, 0.0651, 0.0971, 0.0677, 0.1166, 0.0731, 0.1155, 0.2849]],
       grad_fn=<SelectBackward0>)
Attention score for Layer 1, Batch 0, Head 0 tensor([[0.6098, 0.0170, 0.0526, 0.0216, 0.0703, 0.0349, 0.1029, 0.0908],
        [0.9450, 0.0092, 0.0428, 0.0030, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7919, 0.0171, 0.0759, 0.0160, 0.0990, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0250, 0.4966, 0.0107, 0.4445, 0.0233, 0.00

ValueError: ignored