**Supported models**
* BERT
* RoBERTa

## Load Dependencies

In [None]:
import torch
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForSequenceClassification

## Helper Functions

In [None]:
# Label vocabulary for the three-way entailment task. The name -> id mapping is
# derived from the id -> name mapping so the two tables cannot drift apart.
id2label = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}
label2id = {name: idx for idx, name in id2label.items()}
num_labels = len(id2label)

def convertlabels2ids(example):
    """Replace the string label of ``example`` with its integer class id.

    The example is modified in place and also returned. A label that is not a
    key of ``label2id`` raises ``KeyError``.
    """
    class_id = label2id[example['label']]
    example['label'] = class_id
    return example

def tokenize_function(examples):
    """Tokenize premise/hypothesis pairs with the module-level ``tokenizer``.

    Args:
        examples: Mapping with ``'premise'`` and ``'hypothesis'`` entries
            (a batch of strings when used with ``map(..., batched=True)``).

    Returns:
        The tokenizer's encoding of the sentence pairs.
    """
    # Truncate pairs that exceed the model's maximum input length; without
    # this an over-long pair is passed to the model unshortened and can fail
    # at evaluation time.
    return tokenizer(examples['premise'], examples['hypothesis'], truncation=True)

# Lazily-loaded accuracy metric, shared by every call to `compute_metrics`.
_ACCURACY_METRIC = None


def _get_accuracy_metric():
    """Return the accuracy metric, loading it only on first use."""
    global _ACCURACY_METRIC
    if _ACCURACY_METRIC is None:
        _ACCURACY_METRIC = evaluate.load("accuracy")
    return _ACCURACY_METRIC


def compute_metrics(eval_pred):
    """Compute accuracy from a ``(logits, labels)`` evaluation pair.

    Args:
        eval_pred: Two-item iterable of model logits and reference labels.

    Returns:
        The result of the accuracy metric's ``compute`` call for the argmax
        predictions against the reference labels.
    """
    logits, labels = eval_pred
    # The predicted class is the index of the largest logit on the last axis.
    predictions = np.argmax(logits, axis=-1)
    # Reuse one metric object instead of reloading it on every evaluation pass.
    return _get_accuracy_metric().compute(predictions=predictions, references=labels)

## Load Model and Tokenizer

In [None]:
# Run configuration.
checkpoint = 'varun-v-rao/bert-base-cased-snli'
seed = 42
num_proc = 4  # number of CPU worker processes

# Seed the torch random number generators.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Load the tokenizer and the classification model from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

## Evaluate Performance

In [None]:
# Number of evaluation passes per test set; their accuracies are averaged.
num_test_epochs = 5

# Split to evaluate for each dataset. Insertion order is the evaluation order.
dataset2split = {
    'snli': "test",
    'multi_nli': "validation_mismatched",
    'sagnikrayc/snli-bt': "test",
    'sagnikrayc/snli-cf-kaushik': "test",
}
test_datasets = list(dataset2split)

# The trainer is only used for evaluation, so default arguments are kept.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Datasets that go through `convertlabels2ids` before evaluation.
string_label_datasets = ('sagnikrayc/snli-bt', 'sagnikrayc/snli-cf-kaushik')
# Valid class ids; rows carrying any other label are dropped before evaluation.
valid_labels = range(num_labels)

for dataset_str in test_datasets:
    target_split = dataset2split[dataset_str]
    dataset = load_dataset(dataset_str, split=target_split)

    if dataset_str in string_label_datasets:
        dataset = dataset.map(convertlabels2ids)

    # Filter on the label *before* tokenizing so no work is spent encoding
    # rows that would be discarded afterwards.
    labeled_dataset = dataset.filter(lambda sample: sample['label'] in valid_labels)
    tokenized_test_dataset = labeled_dataset.map(tokenize_function, batched=True, num_proc=num_proc)

    # Keep only the tokenizer outputs plus the label, renamed to 'labels'.
    col_names = [name for name in dataset.column_names if name != 'label']
    tokenized_test_dataset = tokenized_test_dataset.rename_column('label', 'labels').remove_columns(col_names)

    tmp_results = []
    print(f"--- Evaluating performance on {dataset_str} ---")
    for _ in range(num_test_epochs):
        results = trainer.evaluate(tokenized_test_dataset)
        tmp_results.append(results['eval_accuracy'])

    print(f"Results array: {tmp_results}")
    averaged_results = np.mean(tmp_results)
    print(f"Results averaged over {num_test_epochs} epochs: {averaged_results*100} %\n")