# RTE (Recognizing Textual Entailment) with DeBERTa
## Using a pretrained DeBERTa model fine-tuned on MNLI for zero-shot text classification on SNLI
Inspired by Keras code example [Semantic Similarity with BERT](https://keras.io/examples/nlp/semantic_similarity_with_bert/)

## Setup

In [1]:
# !pip install pandas pytorch-lightning transformers wandb

In [2]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchmetrics
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, DebertaModel
import wandb

  from .autonotebook import tqdm as notebook_tqdm


## Custom dataset

In [3]:
MAX_LENGTH = 128*2
HUB_MODEL_CHECKPOINT = 'microsoft/deberta-base-mnli'
MODEL_NAME = HUB_MODEL_CHECKPOINT.split("/")[-1]
PROJECT_NAME = f'{MODEL_NAME}-finetuned-snli'

In [4]:
# tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_CHECKPOINT)
# print(tokenizer.cls_token_id)
# print(tokenizer.sep_token_id)
# tokenizer('my name is thierry', 'my name is thierry')

In [5]:
def _construct_data_path(mode):
    mode = mode if mode != 'valid' else 'dev'
    return f'SNLI_Corpus/snli_1.0_{mode}.csv'


def _preprocess(df):
    df.dropna(axis=0, inplace=True) 
    df = df[df.similarity != "-"]
    df['label'] = df["similarity"].apply(
        lambda x: 0 if x == "contradiction" else 1 if x == "entailment" else 2
        )
    for key in ['sentence1', 'sentence2']:
        df[key] = df[key].astype(str)
    return df


class SNLIDataset(Dataset):
    def __init__(self, mode, tokenizer_name, nrows=None) -> None:
        self.df = pd.read_csv(_construct_data_path(mode), nrows=nrows)
        self.df = _preprocess(self.df)
        self.sentence_pairs = self.df[['sentence1', 'sentence2']].values
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        sentence_pair = self.sentence_pairs[idx]
        encoded = self.tokenizer(sentence_pair[0],
                                 sentence_pair[1],
                                 padding='max_length',
                                 max_length=MAX_LENGTH, 
                                 return_tensors='pt', 
                                 truncation=True)
        labels = self.df.label.values[idx]
        features = {feature: encoded[feature].to(torch.int32).squeeze() for feature in ['input_ids', 'attention_mask', 'token_type_ids']}
        features.update({'labels': labels})
        return features

In [6]:
# train_ds = SNLIDataset('train', tokenizer_name=HUB_MODEL_CHECKPOINT, nrows=1000)
# inputs = train_ds.__getitem__(0)
# inputs

In [7]:
# print(inputs['input_ids'].shape)
# inputs.keys()

## Build model

In [8]:
# # LOCAL_MODEL_CHECKPOINT = f'./{PROJECT_NAME}/checkpoint-189'

# bert = DebertaModel.from_pretrained(HUB_MODEL_CHECKPOINT)
# bert_output = bert(
#     input_ids=inputs['input_ids'].unsqueeze(0),
#     attention_mask=inputs['attention_mask'].unsqueeze(0),
#     token_type_ids=inputs['token_type_ids'].unsqueeze(0)
#     )
# bert_output.last_hidden_state.shape

In [9]:
# _loader = DataLoader(train_ds, batch_size=3, shuffle=False)
# _batch = next(iter(_loader))
# _batch.pop('labels')
# # _sequence_embeddings = bert(**_batch).pooler_output
# _sequence_embeddings = bert(**_batch).last_hidden_state[..., 0, :]
# print(_sequence_embeddings.shape)
# _clf = torch.nn.Linear(768, 3)
# _clf(_sequence_embeddings)

In [10]:
class BertNLIModel(LightningModule):
        
    def __init__(self, 
                 model_checkpoint,
                 num_labels=3,
                 freeze_bert=True,
                 learning_rate=1e-3,
                 warmup_steps=0,
                 weight_decay=0.0
                 ):
        super().__init__()
        self.save_hyperparameters()
        self.num_labels = num_labels
        self.bert = DebertaModel.from_pretrained(model_checkpoint)
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, self.num_labels)
        self.loss = torch.nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.functional.accuracy
        
    def forward(self, features):
        x = self.bert(**features).last_hidden_state[..., 0, :]
        # Could include Linear + Tanh after as in BERT
        return self.classifier(x)
    
    def _get_preds_loss_accuracy(self, batch):
        y = batch.pop('labels')
        y_hat = self(batch)
        preds = torch.argmax(y_hat, dim=1)
        loss = self.loss(y_hat, y)
        acc = self.accuracy(preds, y)
        return preds, loss, acc, y

    def training_step(self, batch, batch_idx):
        _, loss, acc, _ = self._get_preds_loss_accuracy(batch)
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)
        return loss
    
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        preds, loss, acc, labels = self._get_preds_loss_accuracy(batch)
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)
        return {"loss": loss, "preds": preds, "labels": labels}
    
    def test_step(self, batch, batch_idx):
        _, loss, acc, _ = self._get_preds_loss_accuracy(batch)
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)
    
    def validation_epoch_end(self, outputs):
        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu()
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu()
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        acc = self.accuracy(preds, labels)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_accuracy", acc, prog_bar=True)
        # self.log_dict(acc, prog_bar=True)
        
    def configure_optimizers(self):
        optimizer = AdamW(self.classifier.parameters(), 
                          lr=self.hparams.learning_rate, 
                          weight_decay=self.hparams.weight_decay, 
                          correct_bias=False)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]

## Experiments

In [11]:
TRAIN_SAMPLES = 1000
EVAL_SAMPLES = 100
BATCH_SIZE = 10
MAX_EPOCHS = 3
LR = 1e-3

wandb_logger = WandbLogger(project=PROJECT_NAME)

train_ds = SNLIDataset('train', tokenizer_name=HUB_MODEL_CHECKPOINT, nrows=TRAIN_SAMPLES)
valid_ds = SNLIDataset('valid', tokenizer_name=HUB_MODEL_CHECKPOINT, nrows=EVAL_SAMPLES)

train_dataloader = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE, num_workers=10)
valid_dataloader = DataLoader(valid_ds, shuffle=False, batch_size=BATCH_SIZE, num_workers=10)

model = BertNLIModel(HUB_MODEL_CHECKPOINT, 
                     freeze_bert=True,
                     learning_rate=LR
                     )

trainer = Trainer(
    default_root_dir=f'./models/{PROJECT_NAME}',
    logger=wandb_logger,
    callbacks=[
        TQDMProgressBar(refresh_rate=1), 
        ModelCheckpoint(monitor='val_accuracy', mode='max')
        ],
    max_epochs=MAX_EPOCHS,
    # precision=16,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs   
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mthierry-wendling-research[0m. Use [1m`wandb login --relogin`[0m to force relogin


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  # Remove the CWD from sys.path while we load stuff.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  del sys.path[0]
Some weights of the model checkpoint at microsoft/deberta-base-mnli were not used when initializing DebertaModel: ['classifier.weight', 'pooler.dense.bias', 'classifier.bias', 'pooler.dense.weight', 'config']
- This IS expected if you are initializing DebertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification mo

In [12]:
trainer.fit(model, train_dataloader, valid_dataloader)

Loading `train_dataloader` to estimate number of stepping batches.

  | Name       | Type             | Params
------------------------------------------------
0 | bert       | DebertaModel     | 138 M 
1 | classifier | Linear           | 2.3 K 
2 | loss       | CrossEntropyLoss | 0     
------------------------------------------------
2.3 K     Trainable params
138 M     Non-trainable params
138 M     Total params
554.416   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling para

  query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
  p2c_att = torch.matmul(key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype))


Epoch 2: 100%|██████████| 110/110 [05:29<00:00,  2.99s/it, loss=0.258, v_num=m62j, val_loss=0.277, val_accuracy=0.909]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 110/110 [05:29<00:00,  2.99s/it, loss=0.258, v_num=m62j, val_loss=0.277, val_accuracy=0.909]


In [13]:
test_ds = SNLIDataset('test', HUB_MODEL_CHECKPOINT, nrows=1000)
test_dataloader = DataLoader(test_ds, shuffle=False, batch_size=100, num_workers=10)

trainer.test(dataloaders=test_dataloader)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  # Remove the CWD from sys.path while we load stuff.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  del sys.path[0]
  + f" You can pass `.{fn}(ckpt_path='best')` to use the best model or"
Restoring states from the checkpoint path at ./deberta-base-mnli-finetuned-snli/tu92m62j/checkpoints/epoch=0-step=100.ckpt
Loaded model weights from checkpoint at ./deberta-base-mnli-finetuned-snli/tu92m62j/checkpoints/epoch=0-step=100.ckpt


Testing DataLoader 0: 100%|██████████| 10/10 [04:08<00:00, 24.88s/it]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.8441295623779297
        test_loss           0.4813233017921448
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.4813233017921448, 'test_accuracy': 0.8441295623779297}]

In [14]:
wandb.finish()

0,1
epoch,▁▁▁▅▅▅███▁
test_accuracy,▁
test_loss,▁
train_accuracy,█▁█▇▃▇
train_loss,▁█▃▁▆▃
trainer/global_step,▁▂▂▄▅▅▇███
val_accuracy,▁▁▁
val_loss,██▁

0,1
epoch,0.0
test_accuracy,0.84413
test_loss,0.48132
train_accuracy,0.875
train_loss,0.37176
trainer/global_step,300.0
val_accuracy,0.90909
val_loss,0.27661
