# RTE (Recognizing Textual Entailment) with transformers
## Using a pretrained transformer fine-tuned on MNLI to classify SNLI premise–hypothesis pairs zero-shot (no training on SNLI)

## Setup

In [21]:
# Dependencies — uncomment to install (in Jupyter, `%pip` targets the running kernel's env).
# `datasets` is needed by the SNLI loading cell; `sentence_transformers` is only
# needed by the commented-out CrossEncoder alternative further down.
# !pip install transformers
# !pip install torch
# !pip install sentence_transformers
# !pip install datasets

In [3]:
# Position i is the name for SNLI gold label id i; the names are the same
# upper-case strings the MNLI text-classification pipeline emits.
LABELS = 'ENTAILMENT NEUTRAL CONTRADICTION'.split()
print(LABELS)

['ENTAILMENT', 'NEUTRAL', 'CONTRADICTION']


## Load and preprocess SNLI dataset

In [4]:
from datasets import load_dataset
from torch.utils.data import DataLoader


# Fetch SNLI (or reuse the local cache) as a DatasetDict with test/train/validation splits.
dataset = load_dataset(path='snli')
dataset


  from .autonotebook import tqdm as notebook_tqdm
Reusing dataset snli (/Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)
100%|██████████| 3/3 [00:00<00:00, 375.80it/s]


DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 550152
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
})

In [5]:
# Per-split label counts; -1 marks pairs without a gold label.
for split_name, split_ds in dataset.items():
    label_counts = split_ds.to_pandas()['label'].value_counts()
    print(f'{split_name.capitalize()} label distribution:')
    print(label_counts)

Test label distribution:
 0    3368
 2    3237
 1    3219
-1     176
Name: label, dtype: int64
Train label distribution:
 0    183416
 2    183187
 1    182764
-1       785
Name: label, dtype: int64
Validation label distribution:
 0    3329
 2    3278
 1    3235
-1     158
Name: label, dtype: int64


In [6]:
def _has_gold_label(example):
    """Keep an SNLI pair only if it has a gold label (label id other than -1)."""
    return example['label'] != -1


dataset = dataset.filter(_has_gold_label)

Loading cached processed dataset at /Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-c209352940780cb9.arrow
Loading cached processed dataset at /Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-bc6047459630d9e8.arrow
Loading cached processed dataset at /Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-bcc43a57925b85f8.arrow


## Build inference pipeline

In [24]:
# Peek at the first three test examples to see the batch structure the pipeline
# receives. Built here so this cell does not depend on `batch` from a later cell.
batch = next(iter(DataLoader(dataset['test'], shuffle=False, batch_size=3)))
batch

{'premise': ['This church choir sings to the masses as they sing joyous songs from the book at a church.',
  'This church choir sings to the masses as they sing joyous songs from the book at a church.',
  'This church choir sings to the masses as they sing joyous songs from the book at a church.'],
 'hypothesis': ['The church has cracks in the ceiling.',
  'The church is filled with song.',
  'A choir singing at a baseball game.'],
 'label': tensor([1, 0, 2])}

In [25]:
# (premise, hypothesis) pairs from the preview batch, in batch order.
tuples = list(zip(batch['premise'], batch['hypothesis']))
tuples

[('This church choir sings to the masses as they sing joyous songs from the book at a church.',
  'The church has cracks in the ceiling.'),
 ('This church choir sings to the masses as they sing joyous songs from the book at a church.',
  'The church is filled with song.'),
 ('This church choir sings to the masses as they sing joyous songs from the book at a church.',
  'A choir singing at a baseball game.')]

In [59]:
# Alternative backend (disabled): a sentence-transformers CrossEncoder NLI model.
# Uncomment to use it instead of the transformers-pipeline functions in the next
# cell (requires `sentence_transformers`).
# NOTE(review): predict_fn below re-creates the CrossEncoder on every call, which
# shadows its `model` argument — drop that line if this path is revived.
# import numpy as np
# from sentence_transformers import CrossEncoder


# def model_fn():
#     return CrossEncoder('cross-encoder/nli-distilroberta-base')

# def predict_fn(batch, model):
#     tuples = [(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]
#     model = CrossEncoder('cross-encoder/nli-distilroberta-base')
#     return model.predict(tuples).argmax(axis=1)

# def output_fn(preds):
#     label_mapping = ['contradiction', 'entailment', 'neutral']
#     labels = [label_mapping[score_max].upper() for score_max in preds]
#     return np.array(labels)

In [61]:
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# Hugging Face Hub checkpoint: a DeBERTa-base classifier fine-tuned on MNLI,
# applied to SNLI here without any further training.
HUB_MODEL_CKPT = 'microsoft/deberta-base-mnli'

def _pack_item(pair):
    return f"{pair[0]} [SEP] {pair[1]}"

def model_fn():
    """Load the MNLI checkpoint and wrap it in a text-classification pipeline.

    Returns:
        A transformers ``pipeline`` that, called on a list of strings, yields one
        ``{'label': ..., 'score': ...}`` dict per string.
    """
    nli_tokenizer = AutoTokenizer.from_pretrained(HUB_MODEL_CKPT)
    nli_model = AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_CKPT)
    nli_pipeline = pipeline(task='text-classification', model=nli_model, tokenizer=nli_tokenizer)
    return nli_pipeline

def predict_fn(batch, model):
    """Run ``model`` on every (premise, hypothesis) pair in ``batch``.

    ``batch`` maps 'premise' and 'hypothesis' to lists of strings (presumably of
    equal length; zip stops at the shorter one). Each pair is packed into a
    single input string by ``_pack_item``, and the model's raw output for the
    list of packed strings is returned unchanged.
    """
    pairs = zip(batch['premise'], batch['hypothesis'])
    packed_inputs = [_pack_item(pair) for pair in pairs]
    return model(packed_inputs)

def output_fn(preds):
    """Pull the predicted label string out of each pipeline result.

    ``preds`` is the list of dicts returned by the pipeline, e.g.
    ``{'label': 'NEUTRAL', 'score': 0.82}``. Returns a NumPy array of the
    label strings, in input order.
    """
    predicted_labels = [prediction['label'] for prediction in preds]
    return np.array(predicted_labels)

In [62]:
# Build the NLI pipeline once; the smoke-test cell below passes it to predict_fn.
model = model_fn()

Some weights of the model checkpoint at microsoft/deberta-base-mnli were not used when initializing DebertaForSequenceClassification: ['config']
- This IS expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [52]:
# Debug helper (disabled): classify one test example drawn from the first 101 rows.
# NOTE(review): LABELS entries are already upper-case, so `.upper()` is a no-op, and
# `random` is unseeded, so the sampled example changes from run to run.
# import random


# item = dataset['test'].__getitem__(random.randint(0, 100))
# item['label'] = LABELS[item['label']].upper()
# print(item)

# model(f"{item['premise']} [SEP] {item['hypothesis']}")


In [63]:
# Smoke test: push the first three test examples through the pipeline and show
# both the raw pipeline output and the extracted label array.
tmp_dataloader = DataLoader(dataset['test'], batch_size=3, shuffle=False)
batch = next(iter(tmp_dataloader))
print(batch)

preds = predict_fn(batch, model)
print(preds)

output_fn(preds)


{'premise': ['This church choir sings to the masses as they sing joyous songs from the book at a church.', 'This church choir sings to the masses as they sing joyous songs from the book at a church.', 'This church choir sings to the masses as they sing joyous songs from the book at a church.'], 'hypothesis': ['The church has cracks in the ceiling.', 'The church is filled with song.', 'A choir singing at a baseball game.'], 'label': tensor([1, 0, 2])}
[{'label': 'NEUTRAL', 'score': 0.8230645656585693}, {'label': 'ENTAILMENT', 'score': 0.6181530356407166}, {'label': 'CONTRADICTION', 'score': 0.997092604637146}]


array(['NEUTRAL', 'ENTAILMENT', 'CONTRADICTION'], dtype='<U13')

## Evaluate the pipeline on the SNLI test split

In [64]:
def evaluate(dataloader, model):
    """Score the pipeline on every batch in ``dataloader`` and report accuracy.

    Args:
        dataloader: iterable of batch dicts with 'premise', 'hypothesis' and
            integer 'label' entries (e.g. a torch DataLoader over an SNLI split).
        model: classifier accepted by ``predict_fn`` (the pipeline from ``model_fn``).

    Returns:
        Fraction of examples whose predicted label equals the gold label
        (also printed), or NaN if the dataloader yields no examples.

    Raises:
        ValueError: if a batch produces a different number of predictions than labels.
    """
    n_correct = 0
    n_total = 0
    # Iterate over *all* batches; previously only the first batch was scored,
    # so any batch_size smaller than the split silently evaluated a subset.
    for batch in dataloader:
        # int() accepts plain ints as well as the 0-d tensors a DataLoader yields.
        labels = np.array([LABELS[int(label_id)] for label_id in batch['label']])
        preds = output_fn(predict_fn(batch, model))
        if len(preds) != len(labels):
            raise ValueError(
                f'Got {len(preds)} predictions for {len(labels)} labels in one batch')
        n_correct += int(np.sum(labels == preds))
        n_total += len(labels)
    test_acc = n_correct / n_total if n_total else float('nan')
    print(f'Test accuracy: {test_acc:.3f}')
    return test_acc


In [66]:
%%time

# EVAL_BATCH_SIZE = 100
EVAL_BATCH_SIZE = dataset['test'].num_rows

test_dataloader = DataLoader(dataset['test'], shuffle=False, batch_size=EVAL_BATCH_SIZE)

model = model_fn()

evaluate(test_dataloader, model)


Some weights of the model checkpoint at microsoft/deberta-base-mnli were not used when initializing DebertaForSequenceClassification: ['config']
- This IS expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Test accuracy: 0.858
CPU times: user 47min 21s, sys: 1min 19s, total: 48min 41s
Wall time: 8min 13s
