# RTE (Recognizing Textual Entailment) with DeBERTa
## Using a pretrained DeBERTa model fine-tuned on MNLI for zero-shot text classification on SNLI
Inspired by the Keras code example [Semantic Similarity with BERT](https://keras.io/examples/nlp/semantic_similarity_with_bert/).

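Executed on an AWS SageMaker `ml.g4dn.2xlarge` GPU instance. The outputs saved in this notebook were produced with the `bert-base-uncased` checkpoint currently selected by `HUB_MODEL_CHECKPOINT`.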

## Setup

In [1]:
# Third-party packages this notebook imports; uncomment the line below to install them.
# !pip install transformers[torch] wandb torchmetrics datasets

In [2]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.checkpoint import checkpoint
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    TrainingArguments, Trainer, EarlyStoppingCallback
    )
import torchmetrics
import wandb

In [3]:
# Ask PyTorch to release any cached CUDA memory before the model is loaded;
# presumably only matters when the notebook is re-run in an already-used kernel.
torch.cuda.empty_cache()

## Custom dataset

In [4]:
# --- Task / tokenization settings ---
NUM_LABELS = 3    # passed as `num_labels` when the classification model is built
MAX_LENGTH = 128  # every encoded premise/hypothesis pair is padded/truncated to this length

# --- Checkpoint selection: keep exactly one HUB_MODEL_CHECKPOINT line active ---
# HUB_MODEL_CHECKPOINT = 'microsoft/deberta-base-mnli'
# HUB_MODEL_CHECKPOINT = 'huggingface/distilbert-base-uncased-finetuned-mnli'
HUB_MODEL_CHECKPOINT = 'bert-base-uncased'

# Bare model name: whatever follows the last "/" of the hub identifier.
MODEL_NAME = HUB_MODEL_CHECKPOINT.rsplit("/", 1)[-1]
# Local checkpoint directory derived from the model name
# (NOTE(review): looks like a Trainer checkpoint saved at step 1250 — confirm).
LOCAL_MODEL_CHECKPOINT = f'./{MODEL_NAME}-finetuned-snli/checkpoint-1250'

In [5]:
# Load SNLI, drop every row whose label is -1 (presumably pairs without a
# usable label — confirm), and rename the target column to `labels`, the key
# later passed to the model together with the encoded inputs.
dataset = (
    load_dataset('snli')
    .filter(lambda row: row['label'] != -1)
    .rename_column('label', 'labels')
)
dataset

Reusing dataset snli (/home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-21d54e6470652178.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-b746e1998966e2f4.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-89fb34b79586ce05.arrow


DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9824
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 549367
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9842
    })
})

In [6]:
# Tokenizer that belongs to the selected checkpoint.
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_CHECKPOINT)

# Sanity check on the first training pair: premise and hypothesis are encoded
# together as one sequence pair, with token type ids requested explicitly.
example = dataset['train'][0]
tokenizer(
    text=example['premise'],
    text_pair=example['hypothesis'],
    return_token_type_ids=True,
)

{'input_ids': [101, 1037, 2711, 2006, 1037, 3586, 14523, 2058, 1037, 3714, 2091, 13297, 1012, 102, 1037, 2711, 2003, 2731, 2010, 3586, 2005, 1037, 2971, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [7]:
def tokenization(example):
    """Encode premise/hypothesis pairs into fixed-length model inputs.

    Args:
        example: Mapping with ``'premise'`` and ``'hypothesis'`` entries. When
            used with ``dataset.map(..., batched=True)`` each entry holds a
            batch of values rather than a single string.

    Returns:
        The tokenizer output for the pair(s), padded and truncated to
        ``MAX_LENGTH``, including token type ids and the attention mask.
    """
    encode_options = dict(
        padding='max_length',
        max_length=MAX_LENGTH,
        truncation=True,
        return_token_type_ids=True,
        return_attention_mask=True,
    )
    return tokenizer(example['premise'], example['hypothesis'], **encode_options)

# Tokenize every split; with batched=True `tokenization` is called on batches.
dataset = dataset.map(tokenization, batched=True)

# Expose only the encoded inputs and the target, returned as torch tensors.
tensor_columns = ["input_ids", "token_type_ids", "attention_mask", "labels"]
for split in dataset.values():
    split.set_format(type="torch", columns=tensor_columns)

print(dataset['train'][0].keys())

Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-3e1d4c5a625bb5e4.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-15aa4b713d7ef0ac.arrow


  0%|          | 0/10 [00:00<?, ?ba/s]

dict_keys(['labels', 'input_ids', 'token_type_ids', 'attention_mask'])


In [8]:
# Two-row batch of the formatted training split, reused for a quick model check.
examples = dataset['train'][:2]
examples

{'labels': tensor([1, 2]),
 'input_ids': tensor([[  101,  1037,  2711,  2006,  1037,  3586, 14523,  2058,  1037,  3714,
           2091, 13297,  1012,   102,  1037,  2711,  2003,  2731,  2010,  3586,
           2005,  1037,  2971,  1012,   102,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,


## Build model

In [9]:
def get_number_of_trainable_params(model):
    """Count the scalar parameters of ``model`` that require gradients.

    Args:
        model: Any object exposing ``parameters()`` whose items provide
            ``numel()`` and a ``requires_grad`` flag (e.g. a ``torch.nn.Module``).

    Returns:
        int: Total element count over all parameters whose ``requires_grad``
        is true; ``0`` when nothing is trainable.
    """
    # Built-in ``sum`` over Python ints is exact for any model size and yields
    # the integer 0 for a fully frozen model, whereas ``np.sum`` of an empty
    # array returns the float 0.0.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [10]:
# When True, only parameters whose name starts with 'classifier' stay trainable.
FREEZE_ENCODER = False

model = AutoModelForSequenceClassification.from_pretrained(
    HUB_MODEL_CHECKPOINT,
    num_labels=NUM_LABELS,
)
assert model.num_labels == NUM_LABELS, f'The number of labels should be {NUM_LABELS}'

# Report the trainable size before any freezing, rounded to millions.
trainable_before_freeze = get_number_of_trainable_params(model)
print(f'Original number of trainable params: {round(trainable_before_freeze/1_000_000)}M')

if FREEZE_ENCODER:
    # Everything outside the 'classifier' prefix stops receiving gradients.
    non_head_params = (
        weight
        for weight_name, weight in model.named_parameters()
        if not weight_name.startswith('classifier')
    )
    for weight in non_head_params:
        weight.requires_grad = False

# Exact trainable count after the optional freeze.
print(f'Actual number of trainable params: {get_number_of_trainable_params(model)}')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Original number of trainable params: 109M
Actual number of trainable params: 109484547


In [11]:
# Smoke test: run the two-example batch through the model. `examples` is
# unpacked into keyword arguments, so its keys (labels, input_ids,
# token_type_ids, attention_mask) are what the model receives.
# NOTE(review): gradients are tracked here (the displayed logits carry a
# grad_fn); consider torch.no_grad() if this is only an inference check.
outputs = model(**examples)
# Raw class scores: one row per example, NUM_LABELS columns.
outputs.logits

tensor([[ 0.0547, -0.3566, -0.7725],
        [ 0.0798, -0.3619, -0.8001]], grad_fn=<AddmmBackward0>)

## Experiments