# RTE (Recognizing Textual Entailment) with transformers
## Using a pretrained transformer fine-tuned on MNLI for fine-tuning on SNLI
Inspired by Keras code example [Semantic Similarity with BERT](https://keras.io/examples/nlp/semantic_similarity_with_bert/)

## Setup

In [3]:
# !pip install pytorch-lightning transformers wandb

In [19]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchmetrics
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint, EarlyStopping
from transformers import AutoTokenizer, AdamW, get_constant_schedule_with_warmup, DebertaModel
import wandb

## Custom dataset

In [5]:
MAX_LENGTH = 128*2
HUB_MODEL_CHECKPOINT = 'microsoft/deberta-base-mnli'
MODEL_NAME = HUB_MODEL_CHECKPOINT.split("/")[-1]
PROJECT_NAME = f'{MODEL_NAME}-finetuned-snli'

In [6]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer


dataset = load_dataset('snli')
dataset = dataset.filter(lambda example: example['label'] != -1) 
dataset = dataset.rename_column('label', 'labels')
dataset

Reusing dataset snli (/home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-21d54e6470652178.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-b746e1998966e2f4.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-89fb34b79586ce05.arrow


DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9824
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 549367
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9842
    })
})

In [7]:
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_CHECKPOINT)

example = dataset['train'][0]
tokenizer(example['premise'], example['hypothesis'])

{'input_ids': [1, 250, 621, 15, 10, 5253, 13855, 81, 10, 3187, 159, 16847, 4, 2, 250, 621, 16, 1058, 39, 5253, 13, 10, 1465, 4, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [8]:
def tokenization(example):
    return tokenizer(example['premise'], 
                     example['hypothesis'],
                     padding='max_length',
                     max_length=MAX_LENGTH, 
                     truncation=True)

dataset = dataset.map(tokenization, batched=True)

for key in dataset.keys():
    dataset[key].set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])

print(dataset['train'][0].keys())

Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-89e08e2169e07c1c.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-5fc1ed23f08797ae.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-4ac3457c935c3968.arrow


dict_keys(['labels', 'input_ids', 'token_type_ids', 'attention_mask'])


In [9]:
example = dataset['train'][0]
example

{'labels': tensor(1),
 'input_ids': tensor([    1,   250,   621,    15,    10,  5253, 13855,    81,    10,  3187,
           159, 16847,     4,     2,   250,   621,    16,  1058,    39,  5253,
            13,    10,  1465,     4,     2,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,  

## Build model

In [9]:
# # LOCAL_MODEL_CHECKPOINT = f'./{PROJECT_NAME}/checkpoint-189'

# bert = DebertaModel.from_pretrained(HUB_MODEL_CHECKPOINT)
# bert_output = bert(
#     input_ids=example['input_ids'].unsqueeze(0),
#     attention_mask=example['attention_mask'].unsqueeze(0),
#     token_type_ids=example['token_type_ids'].unsqueeze(0)
#     )
# bert_output.last_hidden_state.shape

In [10]:
# _loader = DataLoader(dataset['train'], batch_size=3, shuffle=False)
# _batch = next(iter(_loader))
# _batch.pop('labels')
# # _sequence_embeddings = bert(**_batch).pooler_output
# _sequence_embeddings = bert(**_batch).last_hidden_state[..., 0, :]
# print(_sequence_embeddings.shape)
# _clf = torch.nn.Linear(768, 3)
# _clf(_sequence_embeddings)

In [10]:
class BertNLIModel(LightningModule):
        
    def __init__(self, 
                 model_checkpoint,
                 num_labels=3,
                 freeze_bert=True,
                 learning_rate=1e-3,
                 warmup_steps=10,
                 weight_decay=0.01
                 ):
        super().__init__()
        self.save_hyperparameters()
        self.num_labels = num_labels
        self.bert = DebertaModel.from_pretrained(model_checkpoint)
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, self.num_labels)
        self.loss = torch.nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.functional.accuracy
        
    def forward(self, features):
        x = self.bert(**features).last_hidden_state[..., 0, :]
        # Could include Linear + Tanh after as in BERT
        return self.classifier(x)
    
    def _get_preds_loss_accuracy(self, batch):
        y = batch.pop('labels')
        y_hat = self(batch)
        preds = torch.argmax(y_hat, dim=1)
        loss = self.loss(y_hat, y)
        acc = self.accuracy(preds, y)
        return preds, loss, acc, y

    def training_step(self, batch, batch_idx):
        _, loss, acc, _ = self._get_preds_loss_accuracy(batch)
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)
        return {"loss": loss}
    
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        preds, loss, acc, labels = self._get_preds_loss_accuracy(batch)
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)
        return {"loss": loss, "preds": preds, "labels": labels, 'accuracy': acc}
    
    def test_step(self, batch, batch_idx):
        _, loss, acc, _ = self._get_preds_loss_accuracy(batch)
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)
        return {"loss": loss, 'accuracy': acc}
    
    def validation_epoch_end(self, outputs):
        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu()
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu()
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        acc = self.accuracy(preds, labels)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_accuracy", acc, prog_bar=True)
        self.log_dict({'accuracy': acc}, prog_bar=True)
        
    def configure_optimizers(self):
        optimizer = AdamW(self.classifier.parameters(), 
                          lr=self.hparams.learning_rate, 
                          weight_decay=self.hparams.weight_decay, 
                          correct_bias=False)
        scheduler = get_constant_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]

## Experiments

In [20]:
TRAIN_SAMPLES = 10000
EVAL_SAMPLES = 1000
BATCH_SIZE = 32
PER_DEVICE_TRAIN_BATCH_SIZE = 2
MAX_EPOCHS = 3
LR = 1e-3
WEIGHT_DECAY = 0.01
FREEZE_BERT = True

wandb_logger = WandbLogger(project=PROJECT_NAME)

train_ds = dataset['train'].select(np.random.randint(0, dataset['train'].num_rows, TRAIN_SAMPLES).tolist())
eval_ds = dataset['validation'].select(np.random.randint(0, dataset['validation'].num_rows, EVAL_SAMPLES).tolist())

train_dataloader = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE, num_workers=6)
eval_dataloader = DataLoader(eval_ds, shuffle=False, batch_size=BATCH_SIZE, num_workers=6)

model = BertNLIModel(HUB_MODEL_CHECKPOINT, 
                     freeze_bert=FREEZE_BERT,
                     learning_rate=LR,
                     weight_decay=WEIGHT_DECAY,
                    )

trainer = Trainer(
    default_root_dir=PROJECT_NAME,
    logger=wandb_logger,
    callbacks=[
        TQDMProgressBar(refresh_rate=1), 
        ModelCheckpoint(monitor='val_accuracy', mode='max'),
        EarlyStopping('val_loss', patience=1, min_delta=0.1, mode='min')
        ],
    max_epochs=MAX_EPOCHS,
    precision=16,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs   
    accumulate_grad_batches=PER_DEVICE_TRAIN_BATCH_SIZE
)

Some weights of the model checkpoint at microsoft/deberta-base-mnli were not used when initializing DebertaModel: ['pooler.dense.weight', 'classifier.weight', 'config', 'pooler.dense.bias', 'classifier.bias']
- This IS expected if you are initializing DebertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [21]:
trainer.fit(model, train_dataloader, eval_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"

  | Name       | Type             | Params
------------------------------------------------
0 | bert       | DebertaModel     | 138 M 
1 | classifier | Linear           | 2.3 K 
2 | loss       | CrossEntropyLoss | 0     
------------------------------------------------
2.3 K     Trainable params
138 M     Non-trainable params
138 M     Total params
277.208   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [22]:
test_ds = dataset['test']
test_dataloader = DataLoader(test_ds, shuffle=False, batch_size=BATCH_SIZE, num_workers=6)

trainer.test(dataloaders=test_dataloader)

  f"`.{fn}(ckpt_path=None)` was called without a model."
Restoring states from the checkpoint path at deberta-base-mnli-finetuned-snli/deberta-base-mnli-finetuned-snli/q4tyf11v/checkpoints/epoch=1-step=313.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at deberta-base-mnli-finetuned-snli/deberta-base-mnli-finetuned-snli/q4tyf11v/checkpoints/epoch=1-step=313.ckpt


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_accuracy': 0.8686889410018921, 'test_loss': 0.38150331377983093}
--------------------------------------------------------------------------------


[{'test_loss': 0.38150331377983093, 'test_accuracy': 0.8686889410018921}]

In [24]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁█
epoch,▁▁▁▁▁▁▅▅▅▅█
test_accuracy,▁
test_loss,▁
train_accuracy,▆▇▆▁▄▃█▃
train_loss,▄▁▂▅█▅▃▆
trainer/global_step,▁▂▁▂▄▄▅▆███
val_accuracy,▁█
val_loss,▁█

0,1
accuracy,0.869
epoch,2.0
test_accuracy,0.86869
test_loss,0.3815
train_accuracy,0.75
train_loss,0.54398
trainer/global_step,314.0
val_accuracy,0.869
val_loss,0.37473
