# RTE (Recognizing Textual Entailment) with DeBERTa
## Using a pretrained DeBERTa model fine-tuned on MNLI for zero-shot text classification on SNLI
Inspired by Keras code example [Semantic Similarity with BERT](https://keras.io/examples/nlp/semantic_similarity_with_bert/)

Executed on AWS SageMaker `ml.g4dn.2xlarge` GPU instance

## Setup

In [15]:
# !pip install transformers[torch] accelerate datasets wandb

In [1]:
import os
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.checkpoint import checkpoint
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModel, 
    AdamW, get_scheduler
#     EarlyStoppingCallback,
#     Trainer, TrainingArguments
    )
from accelerate import Accelerator
from accelerate.utils import set_seed
from tqdm import tqdm
import wandb

## Custom dataset

In [2]:
NUM_LABELS = 3
MAX_LENGTH = 128
HUB_MODEL_CHECKPOINT = 'bert-base-uncased'
MODEL_NAME = HUB_MODEL_CHECKPOINT.split("/")[-1]
# LOCAL_MODEL_CHECKPOINT = f'./{MODEL_NAME}-finetuned-snli/checkpoint-XXX'

In [3]:
dataset = load_dataset('snli')
dataset = dataset.filter(lambda example: example['label'] != -1) 
dataset = dataset.rename_column('label', 'labels')
dataset

Reusing dataset snli (/home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-21d54e6470652178.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-b746e1998966e2f4.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-89fb34b79586ce05.arrow


DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9824
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 549367
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9842
    })
})

In [4]:
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_CHECKPOINT)

example = dataset['train'][0]
tokenizer(example['premise'], example['hypothesis'], return_token_type_ids=True)

{'input_ids': [101, 1037, 2711, 2006, 1037, 3586, 14523, 2058, 1037, 3714, 2091, 13297, 1012, 102, 1037, 2711, 2003, 2731, 2010, 3586, 2005, 1037, 2971, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [5]:
def tokenization(example):
    return tokenizer(example['premise'], 
                     example['hypothesis'],
                     padding='max_length',
                     max_length=MAX_LENGTH, 
                     return_token_type_ids=True,
                     return_attention_mask=True,
                     truncation=True)

dataset = dataset.map(tokenization, batched=True)

for key in dataset.keys():
    dataset[key].set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])

print(dataset['train'][0].keys())

Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-3e1d4c5a625bb5e4.arrow


  0%|          | 0/550 [00:00<?, ?ba/s]

Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-96104474684b9b7a.arrow


dict_keys(['labels', 'input_ids', 'token_type_ids', 'attention_mask'])


In [6]:
examples = dataset['train'][0:2]
examples

{'labels': tensor([1, 2]),
 'input_ids': tensor([[  101,  1037,  2711,  2006,  1037,  3586, 14523,  2058,  1037,  3714,
           2091, 13297,  1012,   102,  1037,  2711,  2003,  2731,  2010,  3586,
           2005,  1037,  2971,  1012,   102,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,


## Build model

In [7]:
def get_number_of_trainable_params(model):
    return np.sum(np.array([p.numel() for p in model.parameters() if p.requires_grad]))

In [8]:
class BERTClassifier(torch.nn.Module):
    
    def __init__(self, model_checkpoint, num_labels=3):
        super(BERTClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained(model_checkpoint)
        self.num_labels = num_labels
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, self.num_labels)
#         self.apply(self._init_weights)
        
    def forward(self, features):
        features = {k: v for k, v in features.items() if k in ['input_ids', 'token_type_ids', 'attention_mask']}
        embeddings = self.bert(**features).pooler_output ### CLS pooling
        return self.classifier(embeddings)

In [9]:
model = BERTClassifier(model_checkpoint=HUB_MODEL_CHECKPOINT)
model(examples)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor([[ 0.5563,  0.0554, -0.6191],
        [ 0.5499,  0.0641, -0.6162]], grad_fn=<AddmmBackward0>)

In [10]:
FREEZE_ENCODER = False


assert model.num_labels == NUM_LABELS, f'The number of labels should be {NUM_LABELS}'
print(f'Original number of trainable params: {round(get_number_of_trainable_params(model)/1_000_000)}M')

if FREEZE_ENCODER:
    for name, param in model.named_parameters():
        if not name.startswith('classifier'):
            param.requires_grad = False

print(f'Actual number of trainable params: {get_number_of_trainable_params(model)}')

Original number of trainable params: 109M
Actual number of trainable params: 109484547


## Trainer

In [78]:
def configure_optimizer(model, lr, eps, warmup_steps, weight_decay, num_training_steps):
    optimizer = AdamW(model.parameters(), lr=lr, eps=eps, correct_bias=False, weight_decay=weight_decay)
    scheduler = get_scheduler('linear', optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
    return optimizer, scheduler

def compute_metrics(predictions, labels):
    predictions = predictions.argmax(axis=1)
    correct = (predictions == labels).float()
    accuracy = correct.mean()
    return {'accuracy': accuracy, 'correct': correct}

In [79]:
def train(accelerator, model, dataloader, optimizer, scheduler, loss_fn):
    epoch_size = 0
    epoch_loss = 0
    epoch_correct = 0
    model.train()
    for batch in tqdm(dataloader):        
        optimizer.zero_grad() # clear gradients first
        labels = batch['labels']
        batch_size = labels.shape[0]
        predictions = model(batch)
        loss = loss_fn(predictions, labels)
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        
        with torch.no_grad():
            metrics = compute_metrics(accelerator.gather(predictions), accelerator.gather(labels))
            correct = metrics['correct'].cpu().numpy().sum()
            epoch_loss += loss.item()
            epoch_correct += correct
            epoch_size += batch_size
            
    return {'train_loss': round(epoch_loss/epoch_size, 5), 'train_accuracy': epoch_correct/epoch_size, 'lr': scheduler.get_last_lr()}

In [80]:
def evaluate(accelerator, model, dataloader, loss_fn):
    epoch_loss = 0
    epoch_correct = 0
    epoch_size = 0
    model.eval()
    for batch in tqdm(dataloader):   
        labels = batch['labels']
        batch_size = labels.shape[0]
        with torch.no_grad():
            predictions = model(batch)
        loss = loss_fn(predictions, labels)
        metrics = compute_metrics(accelerator.gather(predictions), accelerator.gather(labels))
        correct = metrics['correct'].cpu().numpy().sum()
        epoch_loss += loss.item()
        epoch_correct += correct
        epoch_size += batch_size
    return {'eval_loss': round(epoch_loss/epoch_size, 5), 'eval_accuracy': epoch_correct/epoch_size}

## Experiments

In [87]:
FREEZE_ENCODER = True
TRAIN_SAMPLES = 10000
EVAL_SAMPLES = 1000
PER_DEVICE_TRAIN_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 16
TRAIN_BATCH_SIZE = GRADIENT_ACCUMULATION_STEPS * PER_DEVICE_TRAIN_BATCH_SIZE
print(f'Effective training batch size: {TRAIN_BATCH_SIZE}')
EVAL_BATCH_SIZE = 100
TRAIN_STEPS_PER_EPOCH = math.ceil(TRAIN_SAMPLES/TRAIN_BATCH_SIZE)
print(f'Number of training steps per epoch: {TRAIN_STEPS_PER_EPOCH}')
MAX_EPOCHS = 3
print(f'Max number of epochs: {MAX_EPOCHS}')
LR = 2e-5
EPS = 1e-6
WEIGHT_DECAY = 0.01
WARMUP_PERCENT = 0.2
TOTAL_STEPS = MAX_EPOCHS * TRAIN_STEPS_PER_EPOCH
print(f'Total training steps: {TOTAL_STEPS}')
WARMUP_STEPS = int(TOTAL_STEPS*WARMUP_PERCENT)
print(f'Warmup steps: {WARMUP_STEPS}')

Effective training batch size: 32
Number of training steps per epoch: 313
Max number of epochs: 3
Total training steps: 939
Warmup steps: 187


In [88]:
def train_loop(seed=42, mixed_precision='fp16'):
    set_seed(seed)
    accelerator = Accelerator(mixed_precision=mixed_precision)

    train_ds = dataset['train'].shuffle(seed=SEED).select(range(TRAIN_SAMPLES))
    eval_ds = dataset['validation'].shuffle(seed=SEED).select(range(EVAL_SAMPLES))
    train_dataloader = DataLoader(train_ds, num_workers=os.cpu_count(), batch_size=TRAIN_BATCH_SIZE, shuffle=True)
    eval_dataloader = DataLoader(eval_ds, num_workers=os.cpu_count(), batch_size=EVAL_BATCH_SIZE, shuffle=False)

    model = BERTClassifier(model_checkpoint=HUB_MODEL_CHECKPOINT)
    if FREEZE_ENCODER:
        for name, param in model.named_parameters():
            if not name.startswith('classifier'):
                param.requires_grad = False
    
    optimizer, lr_scheduler = configure_optimizer(model, LR, EPS, WARMUP_STEPS, WEIGHT_DECAY, TOTAL_STEPS)
    loss_fn = torch.nn.CrossEntropyLoss()

    model, optimizer, train_dataloader, lr_scheduler, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler, eval_dataloader
    )

    for epoch in range(MAX_EPOCHS):
        train_metrics = train(accelerator, model, train_dataloader, optimizer, lr_scheduler, loss_fn)
        eval_metrics = evaluate(accelerator, model, eval_dataloader, loss_fn)
        epoch_metrics = {**train_metrics, **eval_metrics}
        accelerator.print(f"epoch {epoch}: {epoch_metrics}")

In [89]:
train_loop()

Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-c87be39ba90012f8.arrow
Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-9ba45445327e9c76.arrow
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model fr

epoch 0: {'train_loss': 0.02406, 'train_accuracy': 0.6576, 'lr': [1.670212765957447e-05], 'eval_loss': 0.00537, 'eval_accuracy': 0.78}


100%|██████████| 313/313 [01:17<00:00,  4.02it/s]
100%|██████████| 10/10 [00:03<00:00,  3.33it/s]


epoch 1: {'train_loss': 0.01304, 'train_accuracy': 0.8472, 'lr': [8.377659574468086e-06], 'eval_loss': 0.00516, 'eval_accuracy': 0.812}


100%|██████████| 313/313 [01:16<00:00,  4.10it/s]
100%|██████████| 10/10 [00:02<00:00,  3.60it/s]

epoch 2: {'train_loss': 0.00696, 'train_accuracy': 0.9267, 'lr': [5.319148936170213e-08], 'eval_loss': 0.0059, 'eval_accuracy': 0.806}





In [152]:
# PROJECT_NAME = f'{MODEL_NAME}-finetuned-snli'

# wandb.init(project=PROJECT_NAME)

In [119]:
# wandb.finish()