# RTE (Recognizing Textual Entailment) with DeBERTa
## Using a pretrained DeBERTa model fine-tuned on MNLI for zero-shot text classification on SNLI
Inspired by Keras code example [Semantic Similarity with BERT](https://keras.io/examples/nlp/semantic_similarity_with_bert/)

## Setup

In [1]:
# !pip install torch
# !pip install pytorch-lightning
# !pip install transformers
# !pip install sklearn
# !pip install evaluate
# !pip install pandas
# !pip install wandb

In [2]:
# !wandb login

In [1]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
from transformers import AutoTokenizer, BertModel, AdamW, get_constant_schedule_with_warmup
import evaluate
import wandb

wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mthierry-wendling-research[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Custom dataset

In [2]:
MAX_LENGTH = 128*2
HUB_MODEL_CHECKPOINT = 'bert-base-uncased'
MODEL_NAME = HUB_MODEL_CHECKPOINT.split("/")[-1]
PROJECT_NAME = f'{MODEL_NAME}-finetuned-snli'

wandb_logger = WandbLogger(project=PROJECT_NAME, log_model='all')

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.033371996879577634, max=1.0…

In [5]:
# tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_CHECKPOINT)
# print(tokenizer.cls_token_id)
# print(tokenizer.sep_token_id)
# tokenizer('my name is thierry', 'my name is thierry')

In [4]:
def _construct_data_path(mode):
    mode = mode if mode != 'valid' else 'dev'
    return f'SNLI_Corpus/snli_1.0_{mode}.csv'


def _preprocess(df):
    df.dropna(axis=0, inplace=True) 
    df = df[df.similarity != "-"]
    df['label'] = df["similarity"].apply(
        lambda x: 0 if x == "contradiction" else 1 if x == "entailment" else 2
        )
    for key in ['sentence1', 'sentence2']:
        df[key] = df[key].astype(str)
    return df


class SNLIDataset(Dataset):
    def __init__(self, mode, tokenizer_name, nrows=None) -> None:
        self.df = pd.read_csv(_construct_data_path(mode), nrows=nrows)
        self.df = _preprocess(self.df)
        self.sentence_pairs = self.df[['sentence1', 'sentence2']].values
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        sentence_pair = self.sentence_pairs[idx]
        encoded = self.tokenizer(sentence_pair[0],
                                 sentence_pair[1],
                                 padding='max_length',
                                 max_length=MAX_LENGTH, 
                                 return_tensors='pt', 
                                 truncation=True)
        labels = self.df.label.values[idx]
        features = {feature: encoded[feature].to(torch.int32).squeeze() for feature in ['input_ids', 'attention_mask', 'token_type_ids']}
        features.update({'labels': labels})
        return features

In [7]:
# train_ds = SNLIDataset('train', tokenizer_name=HUB_MODEL_CHECKPOINT, nrows=1000)
# inputs = train_ds.__getitem__(0)
# inputs

In [8]:
# print(inputs['input_ids'].shape)
# inputs.keys()

## Build model

In [9]:
# # LOCAL_MODEL_CHECKPOINT = f'./{PROJECT_NAME}/checkpoint-189'

# bert = BertModel.from_pretrained(HUB_MODEL_CHECKPOINT)
# bert_output = bert(
#     input_ids=inputs['input_ids'].unsqueeze(0),
#     attention_mask=inputs['attention_mask'].unsqueeze(0),
#     token_type_ids=inputs['token_type_ids'].unsqueeze(0)
#     )
# bert_output.last_hidden_state.shape

In [10]:
# _loader = DataLoader(train_ds, batch_size=3, shuffle=False)
# _batch = next(iter(_loader))
# _batch.pop('labels')
# _sequence_embeddings = bert(**_batch).pooler_output
# print(_sequence_embeddings.shape)
# _clf = torch.nn.Linear(768, 3)
# _clf(_sequence_embeddings)

In [5]:
class BertNLIModel(LightningModule):
        
    def __init__(self, 
                 model_checkpoint,
                 num_labels=3,
                 metric_name='accuracy',
                 freeze_bert=True,
                 learning_rate=2e-5,
                 adam_epsilon=1e-6,
                 warmup_steps=0,
                 weight_decay=0.0
                 ):
        super().__init__()
        self.save_hyperparameters()
        self.num_labels = num_labels
        self.bert = BertModel.from_pretrained(model_checkpoint)
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, self.num_labels)
        self.loss = torch.nn.CrossEntropyLoss()
        self.metric = evaluate.load(metric_name)
        
    def forward(self, features):
        x = self.bert(**features).pooler_output
        return self.classifier(x)
    
    def _get_preds_loss_accuracy(self, batch):
        '''convenience function since train/valid/test steps are similar'''
        y = batch.pop('labels')
        y_hat = self(batch)
        preds = torch.argmax(y_hat, dim=1)
        loss = self.loss(y_hat, y)
        acc = self.metric.compute(predictions=preds, references=y)
        return preds, loss, acc, y

    def training_step(self, batch, batch_idx):
        _, loss, acc, _ = self._get_preds_loss_accuracy(batch)
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)
        return loss
    
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        preds, loss, acc, labels = self._get_preds_loss_accuracy(batch)
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)
        return {"loss": loss, "preds": preds, "labels": labels}
    
    def validation_epoch_end(self, outputs):
        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        acc = self.metric.compute(predictions=preds, references=labels)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_accuracy", acc, prog_bar=True)
        self.log_dict(acc, prog_bar=True)
        
    def configure_optimizers(self):
        optimizer = AdamW(self.classifier.parameters(), lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon, correct_bias=False)
        scheduler = get_constant_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]

## Experiments

In [7]:
TRAIN_SAMPLES = 100
EVAL_SAMPLES = 10
BATCH_SIZE = 32
EPOCHS = 1

train_ds = SNLIDataset('train', tokenizer_name=HUB_MODEL_CHECKPOINT, nrows=TRAIN_SAMPLES)
valid_ds = SNLIDataset('valid', tokenizer_name=HUB_MODEL_CHECKPOINT, nrows=EVAL_SAMPLES)

train_dataloader = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE, num_workers=3)
valid_dataloader = DataLoader(valid_ds, shuffle=False, batch_size=BATCH_SIZE, num_workers=3)

model = BertNLIModel(HUB_MODEL_CHECKPOINT)

trainer = Trainer(
    default_root_dir=PROJECT_NAME,
    logger=wandb_logger,
    callbacks=[TQDMProgressBar(refresh_rate=10), ModelCheckpoint(monitor='val_accuracy', mode='max')],
    max_epochs=EPOCHS,
    precision=16,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs   
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IP

In [8]:
trainer.fit(model, train_dataloader, valid_dataloader)


  | Name       | Type             | Params
------------------------------------------------
0 | bert       | BertModel        | 109 M 
1 | classifier | Linear           | 2.3 K 
2 | loss       | CrossEntropyLoss | 0     
------------------------------------------------
2.3 K     Trainable params
109 M     Non-trainable params
109 M     Total params
437.938   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

In [None]:
# test_ds = SNLIDataset('test', HUB_MODEL_CHECKPOINT, nrows=None)

# trainer.evaluate(test_ds)