# RTE (Recognizing Textual Entailment) with transformers
## Using a pretrained transformer fine-tuned on MNLI for fine-tuning on SNLI
Inspired by Keras code example [Semantic Similarity with BERT](https://keras.io/examples/nlp/semantic_similarity_with_bert/)

Executed on AWS SageMaker `ml.g4dn.2xlarge` GPU instance

## Setup

In [2]:
# !pip install transformers[torch] accelerate datasets wandb

In [2]:
import os
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.checkpoint import checkpoint
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModel, 
    AdamW, get_scheduler,
#     EarlyStoppingCallback,
#     Trainer, TrainingArguments
    )
from accelerate import Accelerator
from accelerate.utils import set_seed
from tqdm import tqdm
import wandb

## Custom dataset

In [20]:
NUM_LABELS = 3
MAX_LENGTH = 128
HUB_MODEL_CHECKPOINT = 'bert-base-uncased'
MODEL_NAME = HUB_MODEL_CHECKPOINT.split("/")[-1]
PROJECT_NAME = f'{MODEL_NAME}-finetuned-snli'
LOCAL_MODEL_CHECKPOINT = f'./models/{PROJECT_NAME}/checkpoint.pt'

In [54]:
LABELS  = ["entailment", "neutral", "contradiction"]

In [4]:
dataset = load_dataset('snli')
dataset = dataset.filter(lambda example: example['label'] != -1) 
dataset = dataset.rename_column('label', 'labels')
dataset

Reusing dataset snli (/home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-21d54e6470652178.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-b746e1998966e2f4.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-89fb34b79586ce05.arrow


DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9824
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 549367
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9842
    })
})

In [5]:
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_CHECKPOINT)

example = dataset['train'][0]
tokenizer(example['premise'], example['hypothesis'], return_token_type_ids=True)

{'input_ids': [101, 1037, 2711, 2006, 1037, 3586, 14523, 2058, 1037, 3714, 2091, 13297, 1012, 102, 1037, 2711, 2003, 2731, 2010, 3586, 2005, 1037, 2971, 1012, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [6]:
def tokenization(example):
    return tokenizer(example['premise'], 
                     example['hypothesis'],
                     padding='max_length',
                     max_length=MAX_LENGTH, 
                     return_token_type_ids=True,
                     return_attention_mask=True,
                     truncation=True)

dataset = dataset.map(tokenization, batched=True)

for key in dataset.keys():
    dataset[key].set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])

print(dataset['train'][0].keys())

Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-3e1d4c5a625bb5e4.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-15aa4b713d7ef0ac.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-5ec5b04d410c1bcb.arrow


dict_keys(['labels', 'input_ids', 'token_type_ids', 'attention_mask'])


In [7]:
examples = dataset['train'][0:2]
examples

{'labels': tensor([1, 2]),
 'input_ids': tensor([[  101,  1037,  2711,  2006,  1037,  3586, 14523,  2058,  1037,  3714,
           2091, 13297,  1012,   102,  1037,  2711,  2003,  2731,  2010,  3586,
           2005,  1037,  2971,  1012,   102,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,


## Build model

In [10]:
def get_number_of_trainable_params(model):
    return np.sum(np.array([p.numel() for p in model.parameters() if p.requires_grad]))

In [11]:
class BERTClassifier(torch.nn.Module):
    
    def __init__(self, model_checkpoint, num_labels=3):
        super(BERTClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained(model_checkpoint)
        self.num_labels = num_labels
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, self.num_labels)
        
    def forward(self, features):
        features = {k: v for k, v in features.items() if k in ['input_ids', 'token_type_ids', 'attention_mask']}
        embeddings = self.bert(**features).pooler_output ### CLS pooling
        return self.classifier(embeddings)

In [12]:
model = BERTClassifier(model_checkpoint=HUB_MODEL_CHECKPOINT)
model(examples)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor([[ 0.4356,  0.1097, -0.0216],
        [ 0.4466,  0.0991, -0.0077]], grad_fn=<AddmmBackward0>)

In [13]:
FREEZE_ENCODER = False


assert model.num_labels == NUM_LABELS, f'The number of labels should be {NUM_LABELS}'
print(f'Original number of trainable params: {round(get_number_of_trainable_params(model)/1_000_000)}M')

if FREEZE_ENCODER:
    for name, param in model.named_parameters():
        if not name.startswith('classifier'):
            param.requires_grad = False

print(f'Actual number of trainable params: {get_number_of_trainable_params(model)}')

Original number of trainable params: 109M
Actual number of trainable params: 109484547


## Trainer

In [14]:
def configure_optimizer(model, lr, eps, warmup_steps, weight_decay, num_training_steps):
    optimizer = AdamW(model.parameters(), lr=lr, eps=eps, correct_bias=False, weight_decay=weight_decay)
    scheduler = get_scheduler('linear', optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
#     scheduler = get_scheduler('constant', optimizer, num_warmup_steps=warmup_steps)
    return optimizer, scheduler

def compute_metrics(predictions, labels):
    predictions = predictions.argmax(axis=1)
    correct = (predictions == labels).float()
    accuracy = correct.mean()
    return {'accuracy': accuracy, 'correct': correct}

In [15]:
def train_loop(accelerator, model, dataloader, optimizer, scheduler, loss_fn):
    epoch_size = 0
    epoch_loss = 0
    epoch_correct = 0
    model.train()
    for batch in tqdm(dataloader):        
        optimizer.zero_grad() # clear gradients first
        labels = batch['labels']
        batch_size = labels.shape[0]
        predictions = model(batch)
        loss = loss_fn(predictions, labels)
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        
        with torch.no_grad():
            metrics = compute_metrics(accelerator.gather(predictions), accelerator.gather(labels))
            correct = metrics['correct'].cpu().numpy().sum()
            epoch_loss += loss.item()
            epoch_correct += correct
            epoch_size += batch_size
            
    return {
        'train_loss': round(epoch_loss/epoch_size, 5), 
        'train_accuracy': epoch_correct/epoch_size, 
        'learning_rate': scheduler.get_last_lr()[0]
    }

In [16]:
def eval_loop(accelerator, model, dataloader, loss_fn):
    epoch_loss = 0
    epoch_correct = 0
    epoch_size = 0
    model.eval()
    for batch in tqdm(dataloader):   
        labels = batch['labels']
        batch_size = labels.shape[0]
        with torch.no_grad():
            predictions = model(batch)
        loss = loss_fn(predictions, labels)
        metrics = compute_metrics(accelerator.gather(predictions), accelerator.gather(labels))
        correct = metrics['correct'].cpu().numpy().sum()
        epoch_loss += loss.item()
        epoch_correct += correct
        epoch_size += batch_size
    return {
        'eval_loss': round(epoch_loss/epoch_size, 5), 
        'eval_accuracy': epoch_correct/epoch_size
    }

In [17]:
def test_loop(accelerator, model, dataloader):
    epoch_correct = 0
    epoch_size = 0
    model.eval()
    for batch in tqdm(dataloader):   
        labels = batch['labels']
        batch_size = labels.shape[0]
        with torch.no_grad():
            predictions = model(batch)
        metrics = compute_metrics(accelerator.gather(predictions), accelerator.gather(labels))
        correct = metrics['correct'].cpu().numpy().sum()
        epoch_correct += correct
        epoch_size += batch_size
    return {
        'test_accuracy': epoch_correct/epoch_size
    }

In [21]:
class EarlyStoppingCallback:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    
    def __init__(self, patience=3, verbose=False, delta=0.0, path=f'models/{PROJECT_NAME}/checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        
    def __call__(self, val_loss, model, optimizer):

        score = val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, optimizer)
        elif score > self.best_score - self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, optimizer)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, optimizer):
        '''Saves model when validation loss decrease.'''
        save_dir = self.path.split('checkpoint.pt')[0][:-1]
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': {'eval_loss': val_loss},
        }, self.path)
        self.val_loss_min = val_loss

In [22]:
# callback = EarlyStoppingCallback(verbose=True, patience=3, delta=0.0001)
# callback(0.02733, model, optimizer)
# print(callback.best_score)
# print(callback.counter)
# callback.early_stop

In [23]:
# checkpoint = torch.load(f'./models/{PROJECT_NAME}/checkpoint.pt')
# print(checkpoint.keys())
# checkpoint['loss']

## Experiments

In [24]:
FREEZE_ENCODER = False
MIXED_PRECISION = 'fp16'
TRAIN_SAMPLES = 90000
EVAL_SAMPLES = 9000
PER_DEVICE_TRAIN_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 16
TRAIN_BATCH_SIZE = GRADIENT_ACCUMULATION_STEPS * PER_DEVICE_TRAIN_BATCH_SIZE
print(f'Effective training batch size: {TRAIN_BATCH_SIZE}')
EVAL_BATCH_SIZE = 90
TRAIN_STEPS_PER_EPOCH = math.ceil(TRAIN_SAMPLES/TRAIN_BATCH_SIZE)
print(f'Number of training steps per epoch: {TRAIN_STEPS_PER_EPOCH}')
MAX_EPOCHS = 10
print(f'Max number of epochs: {MAX_EPOCHS}')
LR = 2e-5
EPS = 1e-6
WEIGHT_DECAY = 1e-4
WARMUP_PERCENT = 0.1
TOTAL_STEPS = MAX_EPOCHS * TRAIN_STEPS_PER_EPOCH
print(f'Total training steps: {TOTAL_STEPS}')
WARMUP_STEPS = int(TOTAL_STEPS*WARMUP_PERCENT)
print(f'Warmup steps: {WARMUP_STEPS}')

Effective training batch size: 32
Number of training steps per epoch: 2813
Max number of epochs: 10
Total training steps: 28130
Warmup steps: 2813


In [25]:
def epoch_loop(seed=42, mixed_precision='fp16'):
    set_seed(seed)
    accelerator = Accelerator(mixed_precision=mixed_precision)

    train_ds = dataset['train'].shuffle(seed=seed).select(range(TRAIN_SAMPLES))
    eval_ds = dataset['validation'].shuffle(seed=seed).select(range(EVAL_SAMPLES))
    train_dataloader = DataLoader(train_ds, num_workers=os.cpu_count(), batch_size=TRAIN_BATCH_SIZE, shuffle=True)
    eval_dataloader = DataLoader(eval_ds, num_workers=os.cpu_count(), batch_size=EVAL_BATCH_SIZE, shuffle=False)

    model = BERTClassifier(model_checkpoint=HUB_MODEL_CHECKPOINT)
    if FREEZE_ENCODER:
        for name, param in model.named_parameters():
            if not name.startswith('classifier'):
                param.requires_grad = False
    
    optimizer, lr_scheduler = configure_optimizer(model, LR, EPS, WARMUP_STEPS, WEIGHT_DECAY, TOTAL_STEPS)
    loss_fn = torch.nn.CrossEntropyLoss()
    early_stopping_callback = EarlyStoppingCallback(patience=3, delta=1e-5, verbose=True)

    model, optimizer, train_dataloader, lr_scheduler, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler, eval_dataloader
    )

    for epoch in range(MAX_EPOCHS):
        train_metrics = train_loop(accelerator, model, train_dataloader, optimizer, lr_scheduler, loss_fn)
        eval_metrics = eval_loop(accelerator, model, eval_dataloader, loss_fn)
        epoch_metrics = {**train_metrics, **eval_metrics}
        accelerator.print(f"epoch {epoch}: {epoch_metrics}")
        wandb.log(epoch_metrics)
        early_stopping_callback(eval_metrics['eval_loss'], model, optimizer)
        if early_stopping_callback.early_stop:
            print("Early stopping at epoch:", epoch)
            break
            
    test_ds = dataset['test'].shuffle(seed=42).select(range(EVAL_SAMPLES))
    test_dataloader = DataLoader(test_ds, num_workers=os.cpu_count(), batch_size=EVAL_BATCH_SIZE, shuffle=False)
    test_metrics = test_loop(accelerator, model, eval_dataloader)
    accelerator.print(f"Test metrics: {test_metrics}")
    wandb.log(test_metrics)

In [239]:
wandb.init(project=PROJECT_NAME)

epoch_loop(seed=42, mixed_precision=MIXED_PRECISION)

Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-1c0f125df70dcab0.arrow
Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-a88183f3ca55bbac.arrow
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model fr

epoch 0: {'train_loss': 0.01866, 'train_accuracy': 0.7485444444444445, 'learning_rate': 1.9985780305723426e-05, 'eval_loss': 0.00438, 'eval_accuracy': 0.8541111111111112}
Validation loss decreased (inf --> 0.004380).  Saving model ...


100%|██████████| 2813/2813 [10:59<00:00,  4.27it/s]
100%|██████████| 100/100 [00:22<00:00,  4.36it/s]


epoch 1: {'train_loss': 0.01205, 'train_accuracy': 0.8550444444444445, 'learning_rate': 1.7780147726823876e-05, 'eval_loss': 0.00401, 'eval_accuracy': 0.8681111111111111}
Validation loss decreased (0.004380 --> 0.004010).  Saving model ...


100%|██████████| 2813/2813 [11:00<00:00,  4.26it/s]
100%|██████████| 100/100 [00:22<00:00,  4.36it/s]


epoch 2: {'train_loss': 0.00819, 'train_accuracy': 0.9061444444444444, 'learning_rate': 1.555871548761702e-05, 'eval_loss': 0.00403, 'eval_accuracy': 0.8717777777777778}
EarlyStopping counter: 1 out of 3


100%|██████████| 2813/2813 [11:04<00:00,  4.23it/s]
100%|██████████| 100/100 [00:21<00:00,  4.70it/s]


epoch 3: {'train_loss': 0.00546, 'train_accuracy': 0.9406111111111111, 'learning_rate': 1.3338073231425525e-05, 'eval_loss': 0.00493, 'eval_accuracy': 0.8733333333333333}
EarlyStopping counter: 2 out of 3


100%|██████████| 2813/2813 [11:00<00:00,  4.26it/s]
100%|██████████| 100/100 [00:23<00:00,  4.31it/s]
Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-7bb9322cbe3cec9f.arrow


epoch 4: {'train_loss': 0.00367, 'train_accuracy': 0.9609222222222222, 'learning_rate': 1.1117430975234033e-05, 'eval_loss': 0.00554, 'eval_accuracy': 0.8703333333333333}
EarlyStopping counter: 3 out of 3
Early stopping at epoch: 4


100%|██████████| 100/100 [00:23<00:00,  4.33it/s]

Test metrics: {'test_accuracy': 0.8703333333333333}





In [240]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval_accuracy,▁▆▇█▇
eval_loss,▃▁▁▅█
learning_rate,█▆▅▃▁
test_accuracy,▁
train_accuracy,▁▅▆▇█
train_loss,█▅▃▂▁

0,1
eval_accuracy,0.87033
eval_loss,0.00554
learning_rate,1e-05
test_accuracy,0.87033
train_accuracy,0.96092
train_loss,0.00367


## Test finetuned model on examples

In [92]:
checkpoint = torch.load(LOCAL_MODEL_CHECKPOINT)
checkpoint.keys()

dict_keys(['model_state_dict', 'optimizer_state_dict', 'loss'])

In [31]:
model = BERTClassifier(model_checkpoint=HUB_MODEL_CHECKPOINT)
model.load_state_dict(checkpoint['model_state_dict'])

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [93]:
def encode_example(example):
    return tokenizer(example['premise'], 
                     example['hypothesis'],
                     padding='max_length',
                     max_length=MAX_LENGTH, 
                     return_tensors='pt',
                     return_token_type_ids=True,
                     return_attention_mask=True,
                     truncation=True)

In [94]:
def predict_class(premise, hypothesis):
    example = {
        "premise": premise,
        "hypothesis": hypothesis,
    }
    encoded = encode_example(example)
    label_idx = model(encoded).argmax(dim=-1).item()
    print(LABELS[label_idx])

In [95]:
pairs = [ 
    ('I go to the office with my personal car.', 'I take the bus to reach the office'),
    ('He is techy', 'He has good knowledge of tech'),
    ('I use my personal car to go to work', 'I usually share a taxi to reach the office')
]
for pair in pairs:
    print(pair)
    predict_class(pair[0], pair[1])

('I go to the office with my personal car.', 'I take the bus to reach the office')
contradiction
('He is techy', 'He has good knowledge of tech')
entailment
('I use my personal car to go to work', 'I usually share a taxi to reach the office')
neutral
