# RTE (Recognizing Textual Entailment) with DeBERTa
## Using a pretrained DeBERTa model fine-tuned on MNLI for zero-shot text classification on SNLI
Inspired by Keras code example [Semantic Similarity with BERT](https://keras.io/examples/nlp/semantic_similarity_with_bert/)

Executed on AWS SageMaker `ml.g4dn.2xlarge` GPU instance

## Setup

In [70]:
# !pip install torch transformers wandb torchmetrics datasets

In [71]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.checkpoint import checkpoint
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification, 
    TrainingArguments, Trainer, EarlyStoppingCallback
    )
import torchmetrics
import wandb

## Custom dataset

In [72]:
NUM_LABELS = 3
MAX_LENGTH = 128
#HUB_MODEL_CHECKPOINT = 'microsoft/deberta-base-mnli'
HUB_MODEL_CHECKPOINT = 'huggingface/distilbert-base-uncased-finetuned-mnli'
MODEL_NAME = HUB_MODEL_CHECKPOINT.split("/")[-1]
LOCAL_MODEL_CHECKPOINT = f'./{MODEL_NAME}-finetuned-snli/checkpoint-1250'

In [73]:
dataset = load_dataset('snli')
dataset = dataset.filter(lambda example: example['label'] != -1) 
dataset = dataset.rename_column('label', 'labels')
dataset

Reusing dataset snli (/home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-21d54e6470652178.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-b746e1998966e2f4.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-89fb34b79586ce05.arrow


DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9824
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 549367
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9842
    })
})

In [74]:
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_CHECKPOINT)

example = dataset['train'][0]
tokenizer(example['premise'], example['hypothesis'], return_token_type_ids=True)

loading configuration file https://huggingface.co/huggingface/distilbert-base-uncased-finetuned-mnli/resolve/main/config.json from cache at /home/ec2-user/.cache/huggingface/transformers/240bd330b0e7919215436efe944c4073bfcc0bac4b7ed0a3378ab3d1793beb1a.40731de3fff94b7bb8465819755f4978a9a082bcb78e7caa728178dab1b68f86
Model config DistilBertConfig {
  "_name_or_path": "huggingface/distilbert-base-uncased-finetuned-mnli",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "dim": 768,
  "dropout": 0.1,
  "eos_token_ids": 0,
  "finetuning_task": "mnli",
  "hidden_dim": 3072,
  "id2label": {
    "0": "contradiction",
    "1": "entailment",
    "2": "neutral"
  },
  "initializer_range": 0.02,
  "label2id": {
    "contradiction": "0",
    "entailment": "1",
    "neutral": "2"
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "output_past": true,
  "p

{'input_ids': [101, 1037, 2711, 2006, 1037, 3586, 14523, 2058, 1037, 3714, 2091, 13297, 1012, 102, 1037, 2711, 2003, 2731, 2010, 3586, 2005, 1037, 2971, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [75]:
def tokenization(example):
    return tokenizer(example['premise'], 
                     example['hypothesis'],
                     padding='max_length',
                     max_length=MAX_LENGTH, 
                     return_token_type_ids=True,
                     return_attention_mask=True,
                     truncation=True)

dataset = dataset.map(tokenization, batched=True)

for key in dataset.keys():
    dataset[key].set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])

print(dataset['train'][0].keys())

Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-72d845b338b48fbd.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-3353feacf67e3c85.arrow


  0%|          | 0/10 [00:00<?, ?ba/s]

dict_keys(['labels', 'input_ids', 'token_type_ids', 'attention_mask'])


In [76]:
example = dataset['train'][0]
example

{'labels': tensor(1),
 'input_ids': tensor([  101,  1037,  2711,  2006,  1037,  3586, 14523,  2058,  1037,  3714,
          2091, 13297,  1012,   102,  1037,  2711,  2003,  2731,  2010,  3586,
          2005,  1037,  2971,  1012,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,  

## Build model

In [77]:
def get_number_of_trainable_params(model):
    return np.sum(np.array([p.numel() for p in model.parameters() if p.requires_grad]))

In [78]:
FREEZE_ENCODER = False

model = AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_CHECKPOINT, num_labels=NUM_LABELS)
assert model.num_labels == NUM_LABELS, f'The number of labels should be {NUM_LABELS}'
print(f'Original number of trainable params: {round(get_number_of_trainable_params(model)/1_000_000)}M')

if FREEZE_ENCODER:
    for name, param in model.named_parameters():
        if not name.startswith('classifier'):
            param.requires_grad = False

print(f'Actual number of trainable params: {get_number_of_trainable_params(model)}')

loading configuration file https://huggingface.co/huggingface/distilbert-base-uncased-finetuned-mnli/resolve/main/config.json from cache at /home/ec2-user/.cache/huggingface/transformers/240bd330b0e7919215436efe944c4073bfcc0bac4b7ed0a3378ab3d1793beb1a.40731de3fff94b7bb8465819755f4978a9a082bcb78e7caa728178dab1b68f86
Model config DistilBertConfig {
  "_name_or_path": "huggingface/distilbert-base-uncased-finetuned-mnli",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "dim": 768,
  "dropout": 0.1,
  "eos_token_ids": 0,
  "finetuning_task": "mnli",
  "hidden_dim": 3072,
  "id2label": {
    "0": "contradiction",
    "1": "entailment",
    "2": "neutral"
  },
  "initializer_range": 0.02,
  "label2id": {
    "contradiction": "0",
    "entailment": "1",
    "neutral": "2"
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "output_past": true,
  "p

Original number of trainable params: 67M
Actual number of trainable params: 66955779


## Experiments

In [84]:
PROJECT_NAME = f'{MODEL_NAME}-finetuned-snli'

wandb.init(project=PROJECT_NAME)

In [85]:
TRAIN_SAMPLES = 20000
EVAL_SAMPLES = 2000
TRAIN_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 50
PER_DEVICE_TRAIN_BATCH_SIZE = 1
TRAIN_STEPS_PER_EPOCH = TRAIN_SAMPLES//TRAIN_BATCH_SIZE
print(f'Number of training steps per epoch: {TRAIN_STEPS_PER_EPOCH}')
MAX_EPOCHS = 6
LR = 5e-5
WEIGHT_DECAY = 0.01
SEED = 123

train_ds = dataset['train'].shuffle(seed=SEED).select(range(TRAIN_SAMPLES))
eval_ds = dataset['validation'].shuffle(seed=SEED).select(range(EVAL_SAMPLES))


train_args = TrainingArguments(
    output_dir=PROJECT_NAME,
    logging_dir='./models/',
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    save_total_limit=3,
    learning_rate=LR,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    num_train_epochs=MAX_EPOCHS,
    weight_decay=WEIGHT_DECAY,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    report_to='wandb',
    gradient_accumulation_steps=(TRAIN_BATCH_SIZE/PER_DEVICE_TRAIN_BATCH_SIZE),
#     gradient_checkpointing=True,
    fp16=True
)

def compute_metrics(eval_pred):
    metric = torchmetrics.functional.accuracy
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    accuracy = metric(torch.tensor(predictions).to(torch.int32), torch.tensor(labels).to(torch.int32))
    return {'accuracy': accuracy}

trainer = Trainer(
    model,
    train_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    compute_metrics=compute_metrics,
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=0.001)
    ]
)

Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-80118d143176e1dd.arrow
Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-65d09f29d5beb625.arrow
PyTorch: setting up devices
Using amp half precision backend


Number of training steps per epoch: 1250


In [86]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: hypothesis, premise, token_type_ids. If hypothesis, premise, token_type_ids are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 20000
  Num Epochs = 6
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 16.0
  Gradient Accumulation steps = 16.0
  Total optimization steps = 7500
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Epoch,Training Loss,Validation Loss,Accuracy
1,0.4668,0.450834,0.8275
2,0.2988,0.589267,0.815
3,0.1887,0.767445,0.819


The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: hypothesis, premise, token_type_ids. If hypothesis, premise, token_type_ids are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 50
Saving model checkpoint to distilbert-base-uncased-finetuned-mnli-finetuned-snli/checkpoint-1250
Configuration saved in distilbert-base-uncased-finetuned-mnli-finetuned-snli/checkpoint-1250/config.json
Model weights saved in distilbert-base-uncased-finetuned-mnli-finetuned-snli/checkpoint-1250/pytorch_model.bin
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: hypothesis, premise, token_type_ids. If hypothesis, premise, token_type_ids are not expected by `DistilBertForSequenceClassificati

TrainOutput(global_step=3750, training_loss=0.3213387013753255, metrics={'train_runtime': 1312.4725, 'train_samples_per_second': 91.43, 'train_steps_per_second': 5.714, 'total_flos': 1987046415360000.0, 'train_loss': 0.3213387013753255, 'epoch': 3.0})

In [87]:
trainer.evaluate(
    dataset['test'].shuffle(seed=SEED).select(range(EVAL_SAMPLES))
    )

Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-5c161f11fb2dc1f9.arrow
The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: hypothesis, premise, token_type_ids. If hypothesis, premise, token_type_ids are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 50


{'eval_loss': 0.4518689811229706,
 'eval_accuracy': 0.8360000252723694,
 'eval_runtime': 2.8388,
 'eval_samples_per_second': 704.532,
 'eval_steps_per_second': 14.091,
 'epoch': 3.0}

In [88]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/accuracy,▅▁▂█
eval/loss,▁▄█▁
eval/runtime,▃▃█▁
eval/samples_per_second,▆▆▁█
eval/steps_per_second,▆▆▁█
train/epoch,▁▂▃▃▄▅▅▆▇███
train/global_step,▁▂▃▃▄▅▅▆▇███
train/learning_rate,█▇▆▅▃▂▁
train/loss,█▇▅▃▃▁▁
train/total_flos,▁

0,1
eval/accuracy,0.836
eval/loss,0.45187
eval/runtime,2.8388
eval/samples_per_second,704.532
eval/steps_per_second,14.091
train/epoch,3.0
train/global_step,3750.0
train/learning_rate,3e-05
train/loss,0.1887
train/total_flos,1987046415360000.0
