# RTE (Recognizing Textual Entailment) with DeBERTa
## Using a pretrained DeBERTa model fine-tuned on MNLI for zero-shot text classification on SNLI
Inspired by Keras code example [Semantic Similarity with BERT](https://keras.io/examples/nlp/semantic_similarity_with_bert/)

Executed on AWS SageMaker `ml.g4dn.2xlarge` GPU instance

## Setup

In [5]:
# !pip install torch transformers torchmetrics datasets wandb

In [70]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.checkpoint import checkpoint
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModel, 
    EarlyStoppingCallback,
    Trainer, TrainingArguments
    )
import torchmetrics
import wandb

## Custom dataset

In [6]:
NUM_LABELS = 3
MAX_LENGTH = 128
HUB_MODEL_CHECKPOINT = 'bert-base-uncased'
MODEL_NAME = HUB_MODEL_CHECKPOINT.split("/")[-1]
# LOCAL_MODEL_CHECKPOINT = f'./{MODEL_NAME}-finetuned-snli/checkpoint-XXX'

In [7]:
dataset = load_dataset('snli')
dataset = dataset.filter(lambda example: example['label'] != -1) 
dataset = dataset.rename_column('label', 'labels')
dataset

Downloading builder script:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/938 [00:00<?, ?B/s]

Downloading and preparing dataset snli/plain_text (download: 90.17 MiB, generated: 65.51 MiB, post-processed: Unknown size, total: 155.68 MiB) to /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b...


Downloading:   0%|          | 0.00/1.93k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/65.9M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Dataset snli downloaded and prepared to /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/551 [00:00<?, ?ba/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9824
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 549367
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9842
    })
})

In [8]:
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_CHECKPOINT)

example = dataset['train'][0]
tokenizer(example['premise'], example['hypothesis'], return_token_type_ids=True)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

{'input_ids': [101, 1037, 2711, 2006, 1037, 3586, 14523, 2058, 1037, 3714, 2091, 13297, 1012, 102, 1037, 2711, 2003, 2731, 2010, 3586, 2005, 1037, 2971, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [9]:
def tokenization(example):
    return tokenizer(example['premise'], 
                     example['hypothesis'],
                     padding='max_length',
                     max_length=MAX_LENGTH, 
                     return_token_type_ids=True,
                     return_attention_mask=True,
                     truncation=True)

dataset = dataset.map(tokenization, batched=True)

for key in dataset.keys():
    dataset[key].set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])

print(dataset['train'][0].keys())

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/550 [00:00<?, ?ba/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

dict_keys(['labels', 'input_ids', 'token_type_ids', 'attention_mask'])


In [61]:
examples = dataset['train'][0:2]
examples

{'labels': tensor([1, 2]),
 'input_ids': tensor([[  101,  1037,  2711,  2006,  1037,  3586, 14523,  2058,  1037,  3714,
           2091, 13297,  1012,   102,  1037,  2711,  2003,  2731,  2010,  3586,
           2005,  1037,  2971,  1012,   102,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,


## Build model

In [12]:
def get_number_of_trainable_params(model):
    return np.sum(np.array([p.numel() for p in model.parameters() if p.requires_grad]))

In [91]:
class BERTClassifier(torch.nn.Module):
    
    def __init__(self, model_checkpoint, num_labels=3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_checkpoint)
        self.num_labels = num_labels
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, self.num_labels)
        self.apply(self._init_weights)
        
    def forward(self, features):
        features = {k: v for k, v in features.items() if k in ['input_ids', 'token_type_ids', 'attention_mask']}
        embeddings = self.bert(**features).pooler_output ### CLS pooling
        return self.classifier(embeddings)
    
#     def _init_weights(self, module):
#         if isinstance(module, torch.nn.Linear):
#             module.weight.data.xavier_uniform(module.weight)
#             module.bias.data.fill_(0.01)
#             if module.bias is not None:
#                 module.bias.data.fill_(0.01)

In [92]:
model = BERTClassifier(model_checkpoint=HUB_MODEL_CHECKPOINT, num_labels=3)
model(examples)

loading configuration file https://huggingface.co/bert-base-uncased/resolve/main/config.json from cache at /home/ec2-user/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.37395cee442ab11005bcd270f3c34464dc1704b715b5d7d52b1a461abe3b9e4e
Model config BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.18.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

loading weights file https://huggingface.co/bert-base-u

AttributeError: 'Tensor' object has no attribute 'xavier_uniform_'

In [83]:
FREEZE_ENCODER = False


assert model.num_labels == NUM_LABELS, f'The number of labels should be {NUM_LABELS}'
print(f'Original number of trainable params: {round(get_number_of_trainable_params(model)/1_000_000)}M')

if FREEZE_ENCODER:
    for name, param in model.named_parameters():
        if not name.startswith('classifier'):
            param.requires_grad = False

print(f'Actual number of trainable params: {get_number_of_trainable_params(model)}')

Original number of trainable params: 109M
Actual number of trainable params: 109484547


## Experiments

In [68]:
PROJECT_NAME = f'{MODEL_NAME}-finetuned-snli'

wandb.init(project=PROJECT_NAME)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 

 ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ec2-user/.netrc


In [87]:
TRAIN_SAMPLES = 1000
EVAL_SAMPLES = 100
TRAIN_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 20
PER_DEVICE_TRAIN_BATCH_SIZE = 1
TRAIN_STEPS_PER_EPOCH = TRAIN_SAMPLES//TRAIN_BATCH_SIZE
print(f'Number of training steps per epoch: {TRAIN_STEPS_PER_EPOCH}')
MAX_EPOCHS = 3
LR = 2e-5
WEIGHT_DECAY = 0.01
SEED = 135

train_ds = dataset['train'].shuffle(seed=SEED).select(range(TRAIN_SAMPLES))
eval_ds = dataset['validation'].shuffle(seed=SEED).select(range(EVAL_SAMPLES))


train_args = TrainingArguments(
    output_dir=PROJECT_NAME,
    logging_dir='./models/',
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    save_total_limit=3,
    learning_rate=LR,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    num_train_epochs=MAX_EPOCHS,
    weight_decay=WEIGHT_DECAY,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    report_to='wandb',
    gradient_accumulation_steps=(TRAIN_BATCH_SIZE/PER_DEVICE_TRAIN_BATCH_SIZE),
#     gradient_checkpointing=True,
    fp16=True
)

def compute_metrics(eval_pred):
    metric = torchmetrics.functional.accuracy
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    accuracy = metric(torch.tensor(predictions).to(torch.int32), torch.tensor(labels).to(torch.int32))
    return {'accuracy': accuracy}

trainer = Trainer(
    model,
    train_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    compute_metrics=compute_metrics,
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=0.001)
    ]
)

Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-29f948db4bb7851c.arrow
Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-9ba45445327e9c76.arrow
PyTorch: setting up devices
Using amp half precision backend


Number of training steps per epoch: 62


In [88]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BERTClassifier.forward` and have been ignored: attention_mask, labels, input_ids, token_type_ids, premise, hypothesis. If attention_mask, labels, input_ids, token_type_ids, premise, hypothesis are not expected by `BERTClassifier.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 1000
  Num Epochs = 3
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 16.0
  Gradient Accumulation steps = 16.0
  Total optimization steps = 186
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


ZeroDivisionError: integer division or modulo by zero

In [87]:
trainer.evaluate(
    dataset['test'].shuffle(seed=SEED).select(range(EVAL_SAMPLES))
    )

Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-5c161f11fb2dc1f9.arrow
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: hypothesis, premise, token_type_ids. If hypothesis, premise, token_type_ids are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 50


{'eval_loss': 0.4518689811229706,
 'eval_accuracy': 0.8360000252723694,
 'eval_runtime': 2.8388,
 'eval_samples_per_second': 704.532,
 'eval_steps_per_second': 14.091,
 'epoch': 3.0}

In [88]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/accuracy,▅▁▂█
eval/loss,▁▄█▁
eval/runtime,▃▃█▁
eval/samples_per_second,▆▆▁█
eval/steps_per_second,▆▆▁█
train/epoch,▁▂▃▃▄▅▅▆▇███
train/global_step,▁▂▃▃▄▅▅▆▇███
train/learning_rate,█▇▆▅▃▂▁
train/loss,█▇▅▃▃▁▁
train/total_flos,▁

0,1
eval/accuracy,0.836
eval/loss,0.45187
eval/runtime,2.8388
eval/samples_per_second,704.532
eval/steps_per_second,14.091
train/epoch,3.0
train/global_step,3750.0
train/learning_rate,3e-05
train/loss,0.1887
train/total_flos,1987046415360000.0
