# RTE (Recognizing Textual Entailment) with transformers
## Using a pretrained transformer fine-tuned on MNLI for zero-shot text classification on SNLI

## Setup

In [2]:
# !pip install flair torch

## Load and preprocess SNLI dataset

In [49]:
LABELS = ['entailment', 'neutral', 'contradiction']
print(LABELS)

['entailment', 'neutral', 'contradiction']


In [52]:
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader


dataset = load_dataset('snli')
dataset


Reusing dataset snli (/Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)
100%|██████████| 3/3 [00:00<00:00, 496.72it/s]


DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 550152
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
})

In [5]:
for key in dataset.keys():
    df = dataset[key].to_pandas()
    print(f'{key.capitalize()} label distribution:')
    print(df.label.value_counts())

Test label distribution:
 0    3368
 2    3237
 1    3219
-1     176
Name: label, dtype: int64
Train label distribution:
 0    183416
 2    183187
 1    182764
-1       785
Name: label, dtype: int64
Validation label distribution:
 0    3329
 2    3278
 1    3235
-1     158
Name: label, dtype: int64


In [6]:
dataset = dataset.filter(lambda example: example['label'] != -1)

Loading cached processed dataset at /Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-c209352940780cb9.arrow
Loading cached processed dataset at /Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-bc6047459630d9e8.arrow
Loading cached processed dataset at /Users/thierry.wendling/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-bcc43a57925b85f8.arrow


## Explore TARS

In [72]:
example = dataset['train'][np.random.randint(0, 1000)]
print(example)
# text = f"[CLS] {example['premise']} [SEP] {example['premise']} [SEP]"
text = f"{example['premise']} {example['premise']}"
print(text)
print(LABELS[example['label']])

{'premise': 'three bikers stop in town.', 'hypothesis': "The bikers didn't stop in the town.", 'label': 2}
three bikers stop in town. three bikers stop in town.
contradiction


In [73]:
from flair.models import TARSClassifier
from flair.data import Sentence

# 1. Load our pre-trained TARS model for English
# tars = TARSClassifier.load('tars-base')

# 2. Prepare a test sentence
sentence = Sentence(text)

# 3. Define some classes that you want to predict using descriptive names
classes = LABELS
# print(classes)

#4. Predict for these classes
tars.predict_zero_shot(sentence, classes)

# Print sentence with predicted labels
print(sentence)

Sentence: "three bikers stop in town . three bikers stop in town ." → neutral (0.9107)
