# RTE (Recognizing Textual Entailment) with DeBERTa
## Using a pretrained DeBERTa model fine-tuned on MNLI for zero-shot text classification on SNLI
Inspired by Keras code example [Semantic Similarity with BERT](https://keras.io/examples/nlp/semantic_similarity_with_bert/)

Executed on AWS SageMaker `ml.g4dn.2xlarge` GPU instance

## Setup

In [117]:
# !pip install transformers[torch] accelerate datasets wandb

In [125]:
import os
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.checkpoint import checkpoint
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModel, 
    AdamW, get_scheduler
#     EarlyStoppingCallback,
#     Trainer, TrainingArguments
    )
from accelerate import Accelerator
from tqdm import tqdm
import wandb

## Custom dataset

In [5]:
NUM_LABELS = 3
MAX_LENGTH = 128
HUB_MODEL_CHECKPOINT = 'bert-base-uncased'
MODEL_NAME = HUB_MODEL_CHECKPOINT.split("/")[-1]
# LOCAL_MODEL_CHECKPOINT = f'./{MODEL_NAME}-finetuned-snli/checkpoint-XXX'

In [6]:
dataset = load_dataset('snli')
dataset = dataset.filter(lambda example: example['label'] != -1) 
dataset = dataset.rename_column('label', 'labels')
dataset

Downloading builder script:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/938 [00:00<?, ?B/s]

Downloading and preparing dataset snli/plain_text (download: 90.17 MiB, generated: 65.51 MiB, post-processed: Unknown size, total: 155.68 MiB) to /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b...


Downloading:   0%|          | 0.00/1.93k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/65.9M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Dataset snli downloaded and prepared to /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/551 [00:00<?, ?ba/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9824
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 549367
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9842
    })
})

In [7]:
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_CHECKPOINT)

example = dataset['train'][0]
tokenizer(example['premise'], example['hypothesis'], return_token_type_ids=True)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

{'input_ids': [101, 1037, 2711, 2006, 1037, 3586, 14523, 2058, 1037, 3714, 2091, 13297, 1012, 102, 1037, 2711, 2003, 2731, 2010, 3586, 2005, 1037, 2971, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [8]:
def tokenization(example):
    return tokenizer(example['premise'], 
                     example['hypothesis'],
                     padding='max_length',
                     max_length=MAX_LENGTH, 
                     return_token_type_ids=True,
                     return_attention_mask=True,
                     truncation=True)

dataset = dataset.map(tokenization, batched=True)

for key in dataset.keys():
    dataset[key].set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])

print(dataset['train'][0].keys())

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/550 [00:00<?, ?ba/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

dict_keys(['labels', 'input_ids', 'token_type_ids', 'attention_mask'])


In [9]:
examples = dataset['train'][0:2]
examples

{'labels': tensor([1, 2]),
 'input_ids': tensor([[  101,  1037,  2711,  2006,  1037,  3586, 14523,  2058,  1037,  3714,
           2091, 13297,  1012,   102,  1037,  2711,  2003,  2731,  2010,  3586,
           2005,  1037,  2971,  1012,   102,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,


## Build model

In [10]:
def get_number_of_trainable_params(model):
    return np.sum(np.array([p.numel() for p in model.parameters() if p.requires_grad]))

In [153]:
class BERTClassifier(torch.nn.Module):
    
    def __init__(self, model_checkpoint, num_labels=3):
        super(BERTClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained(model_checkpoint)
        self.num_labels = num_labels
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, self.num_labels)
#         self.apply(self._init_weights)
        
    def forward(self, features):
        features = {k: v for k, v in features.items() if k in ['input_ids', 'token_type_ids', 'attention_mask']}
        embeddings = self.bert(**features).pooler_output ### CLS pooling
        return self.classifier(embeddings)

In [154]:
# model = BERTClassifier(model_checkpoint=HUB_MODEL_CHECKPOINT)
# model(examples)

In [13]:
FREEZE_ENCODER = False


assert model.num_labels == NUM_LABELS, f'The number of labels should be {NUM_LABELS}'
print(f'Original number of trainable params: {round(get_number_of_trainable_params(model)/1_000_000)}M')

if FREEZE_ENCODER:
    for name, param in model.named_parameters():
        if not name.startswith('classifier'):
            param.requires_grad = False

print(f'Actual number of trainable params: {get_number_of_trainable_params(model)}')

Original number of trainable params: 109M
Actual number of trainable params: 109484547


## Trainer

In [155]:
def configure_optimizer(model, lr, eps, warmup_steps):
    optimizer = AdamW(model.parameters(), lr=lr, eps=eps, correct_bias=False)
    scheduler = get_scheduler('constant', optimizer, num_warmup_steps=warmup_steps)
    return optimizer, scheduler

def compute_accuracy(predictions, labels):
    predictions = predictions.argmax(axis=1)
    correct = (predictions == labels).float()
    accuracy = correct.mean()
    return {'accuracy': accuracy, 'correct': correct}

In [174]:
def train(model, dataloader, optimizer, scheduler, loss_fn):
    size = len(dataloader.dataset)
    epoch_loss = 0
    epoch_correct = 0
    model.train()
    for index, batch in tqdm(enumerate(training_dataloader)):
        optimizer.zero_grad() # clear gradients first
        labels = batch['labels']
        predictions = model(batch)
        loss = loss_fn(predictions, labels)
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        
        with torch.no_grad():
            metrics = compute_accuracy(predictions, labels)
            correct = metrics['correct'].cpu().numpy().sum()
            epoch_loss += loss.item()
            epoch_correct += correct
            
    print({'train_loss': round(epoch_loss/size, 5), 'train_accuracy': epoch_correct/size})

## Experiments

In [175]:
TRAIN_SAMPLES = 20
EVAL_SAMPLES = 10
PER_DEVICE_TRAIN_BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 2
TRAIN_BATCH_SIZE = GRADIENT_ACCUMULATION_STEPS * PER_DEVICE_TRAIN_BATCH_SIZE
EVAL_BATCH_SIZE = 2
TRAIN_STEPS_PER_EPOCH = math.ceil(TRAIN_SAMPLES/TRAIN_BATCH_SIZE)
print(f'Number of training steps per epoch: {TRAIN_STEPS_PER_EPOCH}')
MAX_EPOCHS = 2
LR = 2e-5
EPS = 1e-6
WEIGHT_DECAY = 0.01
WARMUP_PERCENT = 0.2
TOTAL_STEPS = MAX_EPOCHS * TRAIN_STEPS_PER_EPOCH
print(f'Total training steps: {TOTAL_STEPS}')
WARMUP_STEPS = int(TOTAL_STEPS*WARMUP_PERCENT)
print(f'Warmup steps: {WARMUP_STEPS}')
SEED = 135

Number of training steps per epoch: 10
Total training steps: 20
Warmup steps: 4


In [176]:
def train_loop():
    accelerator = Accelerator()

    train_ds = dataset['train'].shuffle(seed=SEED).select(range(TRAIN_SAMPLES))
    eval_ds = dataset['validation'].shuffle(seed=SEED).select(range(EVAL_SAMPLES))
    train_dataloader = DataLoader(train_ds, num_workers=os.cpu_count(), batch_size=TRAIN_BATCH_SIZE, shuffle=True)
    eval_dataloader = DataLoader(eval_ds, num_workers=os.cpu_count(), batch_size=EVAL_BATCH_SIZE, shuffle=False)

    model = BERTClassifier(model_checkpoint=HUB_MODEL_CHECKPOINT)
    optimizer, lr_scheduler = configure_optimizer(model, LR, EPS, WARMUP_STEPS)
    loss_fn = torch.nn.CrossEntropyLoss()

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    
    for epoch in range(MAX_EPOCHS):

        train(model, train_dataloader, optimizer, lr_scheduler, loss_fn)

In [177]:
train_loop()

Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-29f948db4bb7851c.arrow
Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-9ba45445327e9c76.arrow
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model fr

{'train_loss': 0.55879, 'train_accuracy': 0.45}


10it [00:01,  7.87it/s]

{'train_loss': 0.43805, 'train_accuracy': 0.6}





In [152]:
# PROJECT_NAME = f'{MODEL_NAME}-finetuned-snli'

# wandb.init(project=PROJECT_NAME)

In [119]:
# wandb.finish()