# RTE (Recognizing Textual Entailment) with transformers
## Using a pretrained transformer fine-tuned on MNLI for zero-shot text classification on SNLI

## Setup

In [8]:
# !pip install transformers datasets
# !pip install sentence_transformers

In [9]:
LABELS = ['entailment', 'neutral', 'contradiction']
print(LABELS)

['entailment', 'neutral', 'contradiction']


## Load and preprocess SNLI dataset

In [10]:
from datasets import load_dataset
from torch.utils.data import DataLoader


dataset = load_dataset('snli')
dataset


  from .autonotebook import tqdm as notebook_tqdm
Downloading builder script: 100%|██████████| 3.82k/3.82k [00:00<00:00, 2.24MB/s]
Downloading metadata: 100%|██████████| 1.90k/1.90k [00:00<00:00, 850kB/s]
Downloading readme: 100%|██████████| 13.6k/13.6k [00:00<00:00, 6.28MB/s]


Downloading and preparing dataset snli/plain_text (download: 90.17 MiB, generated: 65.51 MiB, post-processed: Unknown size, total: 155.68 MiB) to /home/jupyter/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b...


Downloading: 100%|██████████| 1.93k/1.93k [00:00<00:00, 1.17MB/s]
Downloading: 100%|██████████| 1.26M/1.26M [00:00<00:00, 35.6MB/s]
Downloading: 100%|██████████| 65.9M/65.9M [00:01<00:00, 40.3MB/s]
Downloading: 100%|██████████| 1.26M/1.26M [00:00<00:00, 35.0MB/s]


Dataset snli downloaded and prepared to /home/jupyter/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b. Subsequent calls will reuse this data.


100%|██████████| 3/3 [00:00<00:00, 22.91it/s]


DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 550152
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
})

In [12]:
for key in dataset.keys():
    df = dataset[key].to_pandas()
    print(f'{key.capitalize()} label distribution:')
    print(df.label.value_counts())

Test label distribution:
 0    3368
 2    3237
 1    3219
-1     176
Name: label, dtype: int64
Train label distribution:
 0    183416
 2    183187
 1    182764
-1       785
Name: label, dtype: int64
Validation label distribution:
 0    3329
 2    3278
 1    3235
-1     158
Name: label, dtype: int64


In [13]:
dataset = dataset.filter(lambda example: example['label'] != -1)

 90%|█████████ | 9/10 [00:00<00:00, 26.81ba/s]
100%|█████████▉| 550/551 [00:04<00:00, 125.10ba/s]
 90%|█████████ | 9/10 [00:00<00:00, 111.08ba/s]


## Build inference pipeline

In [59]:
# import numpy as np
# from sentence_transformers import CrossEncoder


# def model_fn():
#     return CrossEncoder('cross-encoder/nli-distilroberta-base')

# def predict_fn(batch, model):
#     tuples = [(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]
#     model = CrossEncoder('cross-encoder/nli-distilroberta-base')
#     return model.predict(tuples).argmax(axis=1)

# def output_fn(preds):
#     label_mapping = ['contradiction', 'entailment', 'neutral']
#     labels = [label_mapping[score_max].upper() for score_max in preds]
#     return np.array(labels)

In [30]:
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

HUB_MODEL_CKPT = 'microsoft/deberta-base-mnli'
LABELS = list(map(lambda x: x.upper(), LABELS))
# HUB_MODEL_CKPT = 'huggingface/distilbert-base-uncased-finetuned-mnli'
              
print(LABELS)

def _pack_item(pair):
    return f"{pair[0]} [SEP] {pair[1]}"

def model_fn():
    tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_CKPT)
    model = AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_CKPT)
    return pipeline(task='text-classification', model=model, tokenizer=tokenizer)

def predict_fn(batch, model):
    premise_list = batch['premise']
    hypothesis_list = batch['hypothesis']
    sentences = [_pack_item(pair) for pair in zip(premise_list, hypothesis_list)]
    return model(sentences)

def output_fn(preds):
    return np.array([d['label'] for d in preds])

['ENTAILMENT', 'NEUTRAL', 'CONTRADICTION']


In [31]:
model = model_fn()

Some weights of the model checkpoint at microsoft/deberta-base-mnli were not used when initializing DebertaForSequenceClassification: ['config']
- This IS expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [32]:
import random


item = dataset['test'].__getitem__(random.randint(0, 100))
item['label'] = LABELS[item['label']].upper()
print(item)

model(f"{item['premise']} [SEP] {item['hypothesis']}")


{'premise': 'A woman wearing a ball cap squats down to touch the cracked earth.', 'hypothesis': 'A squatting woman wearing a hat touching the ground.', 'label': 'ENTAILMENT'}


[{'label': 'CONTRADICTION', 'score': 0.7868318557739258}]

In [33]:
tmp_dataloader = DataLoader(dataset['test'], shuffle=False, batch_size=3)

batch = next(iter(tmp_dataloader))
print(batch)

preds = predict_fn(batch, model)
print(preds)

output_fn(preds)


{'premise': ['This church choir sings to the masses as they sing joyous songs from the book at a church.', 'This church choir sings to the masses as they sing joyous songs from the book at a church.', 'This church choir sings to the masses as they sing joyous songs from the book at a church.'], 'hypothesis': ['The church has cracks in the ceiling.', 'The church is filled with song.', 'A choir singing at a baseball game.'], 'label': tensor([1, 0, 2])}
[{'label': 'NEUTRAL', 'score': 0.8230648040771484}, {'label': 'ENTAILMENT', 'score': 0.618152916431427}, {'label': 'CONTRADICTION', 'score': 0.997092604637146}]


array(['NEUTRAL', 'ENTAILMENT', 'CONTRADICTION'], dtype='<U13')

## Evaluate pipeline on test dataset

In [34]:
def evaluate(dataloader, model):
    batch = next(iter(dataloader))
    labels = np.array(list(map(lambda x: LABELS[x], batch['label'])))
    preds = output_fn(predict_fn(batch, model))
    test_acc = np.mean(np.float32(labels == preds))
    print(f'Test accuracy: {test_acc:.3f}')


In [36]:
%%time

EVAL_BATCH_SIZE = 1000
# EVAL_BATCH_SIZE = dataset['test'].num_rows

test_dataloader = DataLoader(dataset['test'], shuffle=False, batch_size=EVAL_BATCH_SIZE)

model = model_fn()

evaluate(test_dataloader, model)


Some weights of the model checkpoint at microsoft/deberta-base-mnli were not used when initializing DebertaForSequenceClassification: ['config']
- This IS expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Test accuracy: 0.838
CPU times: user 3min 5s, sys: 4.04 s, total: 3min 9s
Wall time: 1min 37s
