# RTE (Recognizing Textual Entailment) with DeBERTa
## Using a pretrained DeBERTa model fine-tuned on MNLI for zero-shot text classification on SNLI
Inspired by Keras code example [Semantic Similarity with BERT](https://keras.io/examples/nlp/semantic_similarity_with_bert/)

## Setup

In [1]:
# !pip install pytorch-lightning transformers wandb

In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchmetrics
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, DebertaModel
import wandb

  from .autonotebook import tqdm as notebook_tqdm


## Custom dataset

In [2]:
MAX_LENGTH = 128*2
HUB_MODEL_CHECKPOINT = 'microsoft/deberta-base-mnli'
MODEL_NAME = HUB_MODEL_CHECKPOINT.split("/")[-1]
PROJECT_NAME = f'{MODEL_NAME}-finetuned-snli'

In [3]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer


dataset = load_dataset('snli')
dataset = dataset.filter(lambda example: example['label'] != -1) 
dataset = dataset.rename_column('label', 'labels')
dataset

Reusing dataset snli (/Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)
100%|██████████| 3/3 [00:00<00:00, 433.97it/s]
Loading cached processed dataset at /Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-c209352940780cb9.arrow
Loading cached processed dataset at /Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-bc6047459630d9e8.arrow
Loading cached processed dataset at /Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-bcc43a57925b85f8.arrow


DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9824
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 549367
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'labels'],
        num_rows: 9842
    })
})

In [4]:
tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_CHECKPOINT)

example = dataset['train'][0]
tokenizer(example['premise'], example['hypothesis'])

{'input_ids': [1, 250, 621, 15, 10, 5253, 13855, 81, 10, 3187, 159, 16847, 4, 2, 250, 621, 16, 1058, 39, 5253, 13, 10, 1465, 4, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [5]:
def tokenization(example):
    return tokenizer(example['premise'], 
                     example['hypothesis'],
                     padding='max_length',
                     max_length=MAX_LENGTH, 
                     truncation=True)

dataset = dataset.map(tokenization, batched=True)

for key in dataset.keys():
    dataset[key].set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])

print(dataset['train'][0].keys())

100%|██████████| 10/10 [00:01<00:00,  6.23ba/s]
100%|██████████| 550/550 [01:25<00:00,  6.45ba/s]
100%|██████████| 10/10 [00:01<00:00,  6.51ba/s]

dict_keys(['labels', 'input_ids', 'token_type_ids', 'attention_mask'])





In [7]:
example = dataset['train'][0]
example

{'labels': tensor(1),
 'input_ids': tensor([    1,   250,   621,    15,    10,  5253, 13855,    81,    10,  3187,
           159, 16847,     4,     2,   250,   621,    16,  1058,    39,  5253,
            13,    10,  1465,     4,     2,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,  

## Build model

In [9]:
# # LOCAL_MODEL_CHECKPOINT = f'./{PROJECT_NAME}/checkpoint-189'

# bert = DebertaModel.from_pretrained(HUB_MODEL_CHECKPOINT)
# bert_output = bert(
#     input_ids=example['input_ids'].unsqueeze(0),
#     attention_mask=example['attention_mask'].unsqueeze(0),
#     token_type_ids=example['token_type_ids'].unsqueeze(0)
#     )
# bert_output.last_hidden_state.shape

In [10]:
_loader = DataLoader(dataset['train'], batch_size=3, shuffle=False)
_batch = next(iter(_loader))
_batch.pop('labels')
# _sequence_embeddings = bert(**_batch).pooler_output
_sequence_embeddings = bert(**_batch).last_hidden_state[..., 0, :]
print(_sequence_embeddings.shape)
_clf = torch.nn.Linear(768, 3)
_clf(_sequence_embeddings)

  query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
  p2c_att = torch.matmul(key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype))


torch.Size([3, 768])


tensor([[-1.1097, -0.3840, -1.3507],
        [-0.2425, -0.0358, -0.6693],
        [-0.4267,  0.3751, -1.9390]], grad_fn=<AddmmBackward0>)

In [11]:
class BertNLIModel(LightningModule):
        
    def __init__(self, 
                 model_checkpoint,
                 num_labels=3,
                 freeze_bert=True,
                 learning_rate=1e-3,
                 warmup_steps=0,
                 weight_decay=0.0
                 ):
        super().__init__()
        self.save_hyperparameters()
        self.num_labels = num_labels
        self.bert = DebertaModel.from_pretrained(model_checkpoint)
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, self.num_labels)
        self.loss = torch.nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.functional.accuracy
        
    def forward(self, features):
        x = self.bert(**features).last_hidden_state[..., 0, :]
        # Could include Linear + Tanh after as in BERT
        return self.classifier(x)
    
    def _get_preds_loss_accuracy(self, batch):
        y = batch.pop('labels')
        y_hat = self(batch)
        preds = torch.argmax(y_hat, dim=1)
        loss = self.loss(y_hat, y)
        acc = self.accuracy(preds, y)
        return preds, loss, acc, y

    def training_step(self, batch, batch_idx):
        _, loss, acc, _ = self._get_preds_loss_accuracy(batch)
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)
        return loss
    
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        preds, loss, acc, labels = self._get_preds_loss_accuracy(batch)
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)
        return {"loss": loss, "preds": preds, "labels": labels}
    
    def test_step(self, batch, batch_idx):
        _, loss, acc, _ = self._get_preds_loss_accuracy(batch)
        self.log('test_loss', loss)
        self.log('test_accuracy', acc)
    
    def validation_epoch_end(self, outputs):
        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu()
        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu()
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        acc = self.accuracy(preds, labels)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_accuracy", acc, prog_bar=True)
        # self.log_dict(acc, prog_bar=True)
        
    def configure_optimizers(self):
        optimizer = AdamW(self.classifier.parameters(), 
                          lr=self.hparams.learning_rate, 
                          weight_decay=self.hparams.weight_decay, 
                          correct_bias=False)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]

## Experiments

In [14]:
TRAIN_SAMPLES = 1000
EVAL_SAMPLES = 100
BATCH_SIZE = 32
MAX_EPOCHS = 1
LR = 1e-3

wandb_logger = WandbLogger(project=PROJECT_NAME)

train_ds = dataset['train'].select(list(range(TRAIN_SAMPLES)))
valid_ds = dataset['train'].select(list(range(EVAL_SAMPLES)))

train_dataloader = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE, num_workers=10)
valid_dataloader = DataLoader(valid_ds, shuffle=False, batch_size=BATCH_SIZE, num_workers=10)

model = BertNLIModel(HUB_MODEL_CHECKPOINT, 
                     freeze_bert=True,
                     learning_rate=LR)

trainer = Trainer(
    default_root_dir=PROJECT_NAME,
    logger=wandb_logger,
    callbacks=[
        TQDMProgressBar(refresh_rate=1), 
        ModelCheckpoint(monitor='val_accuracy', mode='max')
        ],
    max_epochs=MAX_EPOCHS,
    # precision=16,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs   
)

  "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
Some weights of the model checkpoint at microsoft/deberta-base-mnli were not used when initializing DebertaModel: ['pooler.dense.weight', 'classifier.bias', 'config', 'classifier.weight', 'pooler.dense.bias']
- This IS expected if you are initializing DebertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [15]:
trainer.fit(model, train_dataloader, valid_dataloader)

Loading `train_dataloader` to estimate number of stepping batches.

  | Name       | Type             | Params
------------------------------------------------
0 | bert       | DebertaModel     | 138 M 
1 | classifier | Linear           | 2.3 K 
2 | loss       | CrossEntropyLoss | 0     
------------------------------------------------
2.3 K     Trainable params
138 M     Non-trainable params
138 M     Total params
554.416   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling para

  query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
  p2c_att = torch.matmul(key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype))


Epoch 0:   0%|          | 0/36 [00:00<?, ?it/s]                            huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after paralleli

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 36/36 [05:55<00:00,  9.89s/it, loss=0.567, v_num=iiny, val_loss=0.279, val_accuracy=0.910]


In [16]:
test_ds = dataset['test'].select(list(range(1000)))
test_dataloader = DataLoader(test_ds, shuffle=False, batch_size=100, num_workers=10)

trainer.test(dataloaders=test_dataloader)

  + f" You can pass `.{fn}(ckpt_path='best')` to use the best model or"
Restoring states from the checkpoint path at ./deberta-base-mnli-finetuned-snli/138liiny/checkpoints/epoch=0-step=32.ckpt
Loaded model weights from checkpoint at ./deberta-base-mnli-finetuned-snli/138liiny/checkpoints/epoch=0-step=32.ckpt


Testing: 0it [00:00, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism 

[{'test_loss': 0.41851457953453064, 'test_accuracy': 0.8510000109672546}]

In [17]:
wandb.finish()

0,1
epoch,▁▁
test_accuracy,▁
test_loss,▁
trainer/global_step,▁█
val_accuracy,▁
val_loss,▁

0,1
epoch,0.0
test_accuracy,0.851
test_loss,0.41851
trainer/global_step,32.0
val_accuracy,0.91
val_loss,0.27923
