# RTE (Recognizing Textual Entailment) with DeBERTa
## Using a pretrained DeBERTa model fine-tuned on MNLI for zero-shot text classification on SNLI
Inspired by Keras code example [Semantic Similarity with BERT](https://keras.io/examples/nlp/semantic_similarity_with_bert/)

Executed on AWS SageMaker `ml.g4dn.2xlarge` GPU instance

## Setup

In [8]:
# !pip install torch transformers wandb torchmetrics datasets
# !pip install accelerate nvidia-ml-py3

In [24]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification, 
    TrainingArguments, Trainer, EarlyStoppingCallback
    )
import torchmetrics
import wandb

## Custom dataset

In [11]:
MAX_LENGTH = 128*2
HUB_MODEL_CHECKPOINT = 'microsoft/deberta-base-mnli'
MODEL_NAME = HUB_MODEL_CHECKPOINT.split("/")[-1]

In [12]:
dataset = load_dataset('snli')
dataset = dataset.filter(lambda example: example['label'] != -1) 
dataset = dataset.rename_column('label', 'labels')
dataset

Downloading builder script:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/938 [00:00<?, ?B/s]

Downloading and preparing dataset snli/plain_text (download: 90.17 MiB, generated: 65.51 MiB, post-processed: Unknown size, total: 155.68 MiB) to /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b...


Downloading:   0%|          | 0.00/1.93k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/65.9M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Dataset snli downloaded and prepared to /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/551 [00:00<?, ?ba/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9824
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 549367
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9842
    })
})

In [13]:
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_CHECKPOINT)

example = dataset['train'][0]
tokenizer(example['premise'], example['hypothesis'])

Downloading:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/728 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

{'input_ids': [1, 250, 621, 15, 10, 5253, 13855, 81, 10, 3187, 159, 16847, 4, 2, 250, 621, 16, 1058, 39, 5253, 13, 10, 1465, 4, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [14]:
def tokenization(example):
    return tokenizer(example['premise'], 
                     example['hypothesis'],
                     padding='max_length',
                     max_length=MAX_LENGTH, 
                     truncation=True)

dataset = dataset.map(tokenization, batched=True)

for key in dataset.keys():
    dataset[key].set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])

print(dataset['train'][0].keys())

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/550 [00:00<?, ?ba/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

dict_keys(['labels', 'input_ids', 'token_type_ids', 'attention_mask'])


In [15]:
example = dataset['train'][0]
example

{'labels': tensor(1),
 'input_ids': tensor([    1,   250,   621,    15,    10,  5253, 13855,    81,    10,  3187,
           159, 16847,     4,     2,   250,   621,    16,  1058,    39,  5253,
            13,    10,  1465,     4,     2,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,  

## Build model

In [16]:
def get_number_of_trainable_params(model):
    return np.sum(np.array([p.numel() for p in model.parameters() if p.requires_grad]))

In [23]:
# LOCAL_MODEL_CHECKPOINT = './deberta-base-mnli-finetuned-snli/checkpoint-189'

model = AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_CHECKPOINT)
assert model.num_labels == 3, 'The number of labels should be 3 for a RTE task'
print(f'Original number of trainable params: {get_number_of_trainable_params(model)}')

for name, param in model.named_parameters():
    if not name.startswith('classifier'):
        param.requires_grad = False

print(f'Actual number of trainable params: {get_number_of_trainable_params(model)}')

loading configuration file https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/config.json from cache at /home/ec2-user/.cache/huggingface/transformers/f7b31c39c192044791f5fdbf3d688249c69e527477a29901c7b1c3529cfd2d2b.486b7fcfb74d817138771852e9a12ae2309a5895b952e819a646622b0a75ecc0
Model config DebertaConfig {
  "_name_or_path": "microsoft/deberta-base-mnli",
  "architectures": [
    "DebertaForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "CONTRADICTION",
    "1": "NEUTRAL",
    "2": "ENTAILMENT"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "CONTRADICTION": 0,
    "ENTAILMENT": 2,
    "NEUTRAL": 1
  },
  "layer_norm_eps": 1e-07,
  "max_position_embeddings": 512,
  "max_relative_positions": -1,
  "model_type": "deberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_dropout": 

Original number of trainable params: 139194627
Actual number of trainable params: 2307


## Experiments

In [69]:
PROJECT_NAME = f'{MODEL_NAME}-finetuned-snli'

wandb.init(project=PROJECT_NAME)

In [84]:
TRAIN_SAMPLES = 100000
EVAL_SAMPLES = 2000
BATCH_SIZE = 32
PER_DEVICE_TRAIN_BATCH_SIZE = 2
TRAIN_STEPS_PER_EPOCH = TRAIN_SAMPLES//BATCH_SIZE
print(f'Number of training steps per epoch: {TRAIN_STEPS_PER_EPOCH}')
MAX_EPOCHS = 3
LR = 1e-3
WEIGHT_DECAY = 0.01

train_ds = dataset['train'].select(np.random.randint(0, dataset['train'].num_rows, TRAIN_SAMPLES).tolist())
eval_ds = dataset['validation'].select(np.random.randint(0, dataset['validation'].num_rows, TRAIN_SAMPLES).tolist())


train_args = TrainingArguments(
    output_dir=PROJECT_NAME,
    evaluation_strategy = "steps",
#     eval_steps=(TRAIN_STEPS_PER_EPOCH//EVAL_STEPS_PER_EPOCH),
    save_strategy = "steps",
#     save_total_limit=TRAIN_STEPS_PER_EPOCH//1000,
    learning_rate=LR,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=MAX_EPOCHS,
    weight_decay=WEIGHT_DECAY,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    report_to='wandb',
    gradient_accumulation_steps=(BATCH_SIZE/PER_DEVICE_TRAIN_BATCH_SIZE),
    fp16=True
)

def compute_metrics(eval_pred):
    metric = torchmetrics.functional.accuracy
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    accuracy = metric(torch.tensor(predictions).to(torch.int32), torch.tensor(labels).to(torch.int32))
    return {'accuracy': accuracy}

trainer = Trainer(
    model,
    train_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    compute_metrics=compute_metrics,
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=3)
    ]
)

Number of training steps per epoch: 3125


using `logging_steps` to initialize `eval_steps` to 500
PyTorch: setting up devices
Using amp half precision backend


In [None]:
trainer.train()

***** Running training *****
  Num examples = 100000
  Num Epochs = 3
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 32.0
  Gradient Accumulation steps = 16.0
  Total optimization steps = 9375
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss,Accuracy
500,0.4277,0.352182,0.87063
1000,0.4334,0.349204,0.87113
1500,0.4269,0.36665,0.8696
2000,0.4367,0.347999,0.87297
2500,0.4326,0.347488,0.87383
3000,0.4141,0.35391,0.87167
3500,0.4247,0.346398,0.87264


***** Running Evaluation *****
  Num examples = 100000
  Batch size = 32


In [None]:
trainer.evaluate(
    dataset['test']
    )

In [None]:
wandb.finish()