# RTE (Recognizing Textual Entailment) with transformers
## Using a pretrained transformer fine-tuned on MNLI for zero-shot text classification on SNLI

## Setup

In [26]:
# !pip install transformers
# !pip install torch
# !pip install pandas
# !pip install datasets

In [132]:
# Class names, positioned so that an integer dataset label can be used
# directly as an index into this list (see evaluate()).
LABELS = [
    'ENTAILMENT',     # index 0
    'NEUTRAL',        # index 1
    'CONTRADICTION',  # index 2
]
print(LABELS)

['ENTAILMENT', 'NEUTRAL', 'CONTRADICTION']


## Load and preprocess SNLI dataset

In [64]:
from datasets import load_dataset
from torch.utils.data import DataLoader


# Load every split of SNLI; the bare `dataset` expression below displays the
# available splits with their features and row counts.
dataset = load_dataset('snli')
dataset


Reusing dataset snli (/Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)
100%|██████████| 3/3 [00:00<00:00, 489.40it/s]


DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 550152
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
})

In [65]:
# Print, for each split, how many examples carry each label value.
for key in dataset:
    df = dataset[key].to_pandas()
    label_counts = df['label'].value_counts()
    print(f'{key.capitalize()} label distribution:')
    print(label_counts)

Test label distribution:
 0    3368
 2    3237
 1    3219
-1     176
Name: label, dtype: int64
Train label distribution:
 0    183416
 2    183187
 1    182764
-1       785
Name: label, dtype: int64
Validation label distribution:
 0    3329
 2    3278
 1    3235
-1     158
Name: label, dtype: int64


In [66]:
# Drop examples labelled -1 (present in every split, see the counts above).
# LABELS only has entries for 0-2, and a -1 would index it from the end and
# be scored as 'CONTRADICTION' during evaluation.
dataset = dataset.filter(lambda example: example['label'] != -1)

Loading cached processed dataset at /Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-c209352940780cb9.arrow
Loading cached processed dataset at /Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-bc6047459630d9e8.arrow
Loading cached processed dataset at /Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-bcc43a57925b85f8.arrow


## Build inference pipeline

In [125]:
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# Checkpoint identifier handed to both from_pretrained() calls in model_fn().
HUB_MODEL_CKPT = 'microsoft/deberta-base-mnli'

def _pack_item(pair):
    return f"{pair[0]} [SEP] {pair[1]}"

def model_fn():
    """Build the text-classification pipeline for ``HUB_MODEL_CKPT``.

    Returns:
        The object returned by ``pipeline(task='text-classification', ...)``,
        wired to the tokenizer and sequence-classification model loaded from
        the checkpoint.
    """
    # Keyword arguments are evaluated left to right, so the tokenizer is
    # loaded before the model.
    return pipeline(
        task='text-classification',
        tokenizer=AutoTokenizer.from_pretrained(HUB_MODEL_CKPT),
        model=AutoModelForSequenceClassification.from_pretrained(HUB_MODEL_CKPT),
    )

def predict_fn(items, model):
    """Run ``model`` on the premise/hypothesis pairs contained in ``items``.

    Args:
        items: Mapping with parallel 'premise' and 'hypothesis' sequences.
        model: Callable that accepts a list of packed sentence strings.

    Returns:
        Whatever ``model`` returns for the list of packed sentences.
    """
    packed = []
    # zip() stops at the shorter of the two sequences.
    for premise, hypothesis in zip(items['premise'], items['hypothesis']):
        packed.append(_pack_item((premise, hypothesis)))
    return model(packed)

def output_fn(predictions):
    """Collect the 'label' entry of every prediction into a NumPy array.

    Args:
        predictions: Iterable of mappings that each contain a 'label' key.

    Returns:
        numpy.ndarray: The labels, in the order the predictions were given.
    """
    labels = []
    for prediction in predictions:
        labels.append(prediction['label'])
    return np.array(labels)

In [146]:
# model = model_fn()

In [147]:
# import random


# item = dataset['test'].__getitem__(random.randint(0, 100))
# item['label'] = LABELS[item['label']].upper()
# print(item)

# model(f"{item['premise']} [SEP] {item['hypothesis']}")


In [148]:
# tmp_dataloader = DataLoader(dataset['test'], shuffle=False, batch_size=3)

# batch = next(iter(tmp_dataloader))
# print(batch)

# preds = predict_fn(batch, model)
# print(preds)

# output_fn(preds)


## Evaluate pipeline on test dataset

In [139]:
def evaluate(dataloader, model):
    """Score the pipeline on every batch of ``dataloader`` and print the accuracy.

    Gold integer labels are mapped to their names through ``LABELS`` and
    compared with the label strings produced by
    ``output_fn(predict_fn(batch, model))``.

    Args:
        dataloader: Iterable of batches; each batch must provide 'premise',
            'hypothesis' and 'label' entries of equal length.
        model: Callable handed to ``predict_fn`` for inference.

    Raises:
        ValueError: If ``dataloader`` yields no examples.
    """
    n_correct = 0
    n_total = 0
    # Accumulate counts over all batches so the reported accuracy covers the
    # whole dataloader, whatever batch size it was built with.
    for batch in dataloader:
        labels = np.array([LABELS[int(x)] for x in batch['label']])
        preds = output_fn(predict_fn(batch, model))
        n_correct += int(np.sum(labels == preds))
        n_total += len(labels)
    if n_total == 0:
        raise ValueError('dataloader yielded no examples to evaluate')
    test_acc = n_correct / n_total
    print(f'Test accuracy: {test_acc:.3f}')


In [145]:
%%time

# Use a single batch spanning the whole test split so that the accuracy
# printed by evaluate() is computed over every example.
EVAL_BATCH_SIZE = dataset['test'].num_rows

# shuffle=False keeps the examples in dataset order.
test_dataloader = DataLoader(dataset['test'], shuffle=False, batch_size=EVAL_BATCH_SIZE)

model = model_fn()

evaluate(test_dataloader, model)


  query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
  p2c_att = torch.matmul(key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype))


Test accuracy: 0.858
CPU times: user 45min, sys: 32.1 s, total: 45min 32s
Wall time: 7min 42s
