# RTE (Recognizing Textual Entailment) with transformers
## Evaluating a pretrained transformer fine-tuned on MNLI directly on SNLI, without any SNLI fine-tuning (zero-shot transfer)

## Setup

In [21]:
# !pip install transformers
# !pip install torch
# !pip install sentence_transformers

In [12]:
# Class names indexed by SNLI integer label id: LABELS[i] is the name for label i.
# Order matters — gold ids from the dataset are mapped to names via LABELS[id].
LABELS = 'entailment neutral contradiction'.split()
print(LABELS)

['entailment', 'neutral', 'contradiction']


## Load and preprocess the SNLI dataset

In [2]:
from datasets import load_dataset
from torch.utils.data import DataLoader


# Fetch SNLI (reusing the local Hugging Face cache when present) as a DatasetDict
# of splits, then display its structure as the cell output.
dataset = load_dataset(path='snli')
dataset


Reusing dataset snli (/home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 550152
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
})

In [3]:
# Inspect class balance per split; note the -1 labels present in every split.
for key, split in dataset.items():
    df = split.to_pandas()
    print(f'{key.capitalize()} label distribution:')
    print(df['label'].value_counts())

Test label distribution:
 0    3368
 2    3237
 1    3219
-1     176
Name: label, dtype: int64
Train label distribution:
 0    183416
 2    183187
 1    182764
-1       785
Name: label, dtype: int64
Validation label distribution:
 0    3329
 2    3278
 1    3235
-1     158
Name: label, dtype: int64


In [5]:
# Drop pairs labelled -1 (present in every split, per the counts above). Only ids
# 0-2 correspond to entries of LABELS; LABELS[-1] would silently read as 'contradiction'.
dataset = dataset.filter(lambda row: row['label'] != -1)

Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-69cfe6f5138230bc.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-cdd6af5a251d99d5.arrow
Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-1cdec3ecad3c0af4.arrow


## Build inference pipeline

In [59]:
# import numpy as np
# from sentence_transformers import CrossEncoder


# def model_fn():
#     return CrossEncoder('cross-encoder/nli-distilroberta-base')

# def predict_fn(batch, model):
#     tuples = [(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]
#     model = CrossEncoder('cross-encoder/nli-distilroberta-base')
#     return model.predict(tuples).argmax(axis=1)

# def output_fn(preds):
#     label_mapping = ['contradiction', 'entailment', 'neutral']
#     labels = [label_mapping[score_max].upper() for score_max in preds]
#     return np.array(labels)

In [6]:
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# Hugging Face Hub checkpoint (fine-tuned on MNLI) that model_fn() loads.
# Alternative checkpoint kept for comparison. NOTE(review): other checkpoints may use
# different label names/casing in config.id2label than LABELS — verify before switching.
# HUB_MODEL_CKPT = 'microsoft/deberta-base-mnli'
HUB_MODEL_CKPT = 'huggingface/distilbert-base-uncased-finetuned-mnli'

def _pack_item(pair):
    return f"{pair[0]} [SEP] {pair[1]}"

def model_fn():
    """Load HUB_MODEL_CKPT from the Hugging Face Hub and wrap it in a text-classification pipeline.

    Returns:
        A transformers pipeline. Called on a list of strings, it returns one
        {'label': ..., 'score': ...} dict per input.
    """
    seq_clf = AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_CKPT)
    seq_tok = AutoTokenizer.from_pretrained(HUB_MODEL_CKPT)
    return pipeline('text-classification', model=seq_clf, tokenizer=seq_tok)

def predict_fn(batch, model):
    """Run `model` on every (premise, hypothesis) pair of a batch.

    Args:
        batch: mapping with parallel 'premise' and 'hypothesis' sequences of strings
            (e.g. one item yielded by a DataLoader over an SNLI split).
        model: callable (the pipeline from model_fn()) applied once to the list of
            packed "premise [SEP] hypothesis" strings.

    Returns:
        Whatever `model` returns for that list — for the text-classification
        pipeline, a list of {'label', 'score'} dicts in input order.
    """
    pairs = zip(batch['premise'], batch['hypothesis'])
    packed_inputs = list(map(_pack_item, pairs))
    return model(packed_inputs)

def output_fn(preds):
    """Extract predicted label names from pipeline output as a lowercase string array.

    Args:
        preds: iterable of dicts with a 'label' key, as returned by predict_fn.

    Returns:
        np.ndarray of str with one lowercase label per prediction. Lowercasing keeps
        the values comparable with the lowercase names in LABELS even for checkpoints
        whose config emits upper-case labels (e.g. 'ENTAILMENT').
    """
    # dtype=str keeps an empty batch a string array instead of a float64 array.
    return np.array([d['label'].lower() for d in preds], dtype=str)

In [7]:
# Build the text-classification pipeline once; the smoke-test cell below uses it.
model = model_fn()

In [52]:
# import random


# item = dataset['test'].__getitem__(random.randint(0, 100))
# item['label'] = LABELS[item['label']].upper()
# print(item)

# model(f"{item['premise']} [SEP] {item['hypothesis']}")


In [8]:
# Smoke test: push the first three test pairs through the pipeline and compare the
# predicted labels with the gold ids printed alongside the batch.
preview_dataloader = DataLoader(dataset['test'], batch_size=3, shuffle=False)
batch = next(iter(preview_dataloader))
print(batch)

preds = predict_fn(batch, model)
print(preds)

output_fn(preds)


{'premise': ['This church choir sings to the masses as they sing joyous songs from the book at a church.', 'This church choir sings to the masses as they sing joyous songs from the book at a church.', 'This church choir sings to the masses as they sing joyous songs from the book at a church.'], 'hypothesis': ['The church has cracks in the ceiling.', 'The church is filled with song.', 'A choir singing at a baseball game.'], 'label': tensor([1, 0, 2])}
[{'label': 'contradiction', 'score': 0.790027916431427}, {'label': 'entailment', 'score': 0.9256806969642639}, {'label': 'contradiction', 'score': 0.9379742741584778}]


array(['contradiction', 'entailment', 'contradiction'], dtype='<U13')

## Evaluate the pipeline on the SNLI test split

In [9]:
def evaluate(dataloader, model):
    """Print and return the accuracy of `model` over every batch in `dataloader`.

    Each batch must be a mapping with 'premise', 'hypothesis' and 'label' entries,
    where 'label' holds integer class ids that index into LABELS. Predictions come
    from predict_fn/output_fn and are compared with the gold names as strings.

    Args:
        dataloader: iterable of batches (e.g. a torch DataLoader over an SNLI split).
        model: the inference pipeline returned by model_fn().

    Returns:
        float: fraction of examples whose predicted label equals the gold label.

    Raises:
        ValueError: if the dataloader yields no examples.
    """
    n_correct = 0
    n_seen = 0
    # Score every batch, not just the first, so the result does not depend on
    # the dataloader's batch_size.
    for batch in dataloader:
        gold = np.array([LABELS[int(label_id)] for label_id in batch['label']])
        preds = output_fn(predict_fn(batch, model))
        n_correct += int(np.sum(gold == preds))
        n_seen += len(gold)
    if n_seen == 0:
        raise ValueError('evaluate() received a dataloader with no examples')
    test_acc = n_correct / n_seen
    print(f'Test accuracy: {test_acc:.3f}')
    return test_acc


In [15]:
%%time

EVAL_BATCH_SIZE = dataset['test'].num_rows

test_dataloader = DataLoader(dataset['test'], shuffle=False, batch_size=EVAL_BATCH_SIZE)

model = model_fn()

evaluate(test_dataloader, model)


Test accuracy: 0.769
CPU times: user 19min 28s, sys: 1.78 s, total: 19min 30s
Wall time: 5min 3s
