# Main Script

In [None]:
import fire
import torch
import numpy as np
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed,
)

import evaluate
import os
from datetime import date

# Integer class id -> label word that the seq2seq model is trained to generate.
id2label = dict(enumerate(('entailment', 'neutral', 'contradiction')))
# Inverse lookup: label word -> integer class id.
label2id = {word: idx for idx, word in id2label.items()}
num_labels = len(id2label)
# Maximum number of target tokens kept when a label word is tokenized.
max_target_length = 5

def preprocess_snli_batch(examples):
    """Turn a batch of NLI rows into text-to-text (input, target) pairs.

    Args:
        examples: Batch with parallel ``'premise'``, ``'hypothesis'`` and
            ``'label'`` sequences.

    Returns:
        Tuple ``(inputs, targets)``. ``inputs`` holds one
        ``"premise: ... hypothesis: ..."`` prompt per premise/hypothesis pair;
        ``targets`` holds the ``id2label`` word for each label, or an empty
        string when the label is not in ``range(num_labels)``.
    """
    premises, hypotheses, labels = (
        examples[key] for key in ('premise', 'hypothesis', 'label')
    )

    inputs = [
        " ".join(("premise:", premise, "hypothesis:", hypothesis))
        for premise, hypothesis in zip(premises, hypotheses)
    ]

    valid_ids = range(num_labels)
    targets = []
    for label in labels:
        targets.append(id2label[label] if label in valid_ids else "")
    return inputs, targets

def convertlabels2ids(example):
    """Replace the string label of ``example`` with its integer id, in place.

    The current ``example['label']`` is looked up in ``label2id`` (a
    ``KeyError`` propagates for unknown label strings) and the same object is
    returned.
    """
    label_word = example['label']
    example['label'] = label2id[label_word]
    return example
    

def log_and_save_results(
    res,
    results_dir = "../res",
    outfile_name = "snli_model_performances.csv"
):
    """Append evaluation results to a semicolon-separated log and print them.

    Args:
        res: Iterable of ``(model_name, dataset_str, accuracy)`` triples.
        results_dir: Directory holding the log file. It is created, including
            any missing parent directories, when it does not exist.
        outfile_name: Name of the log file inside ``results_dir``.

    A header row is written first when the log file does not exist yet or is
    empty. Every result is appended as one row stamped with today's date and
    echoed to stdout.
    """
    outfile_path = os.path.join(results_dir, outfile_name)

    # makedirs (unlike mkdir) handles nested paths and does not fail when the
    # directory already exists or is created concurrently by another run.
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    # A missing or empty log file still needs its header row.
    needs_header = (
        not os.path.exists(outfile_path) or os.path.getsize(outfile_path) == 0
    )
    today = date.today()

    # Open the log once for the whole batch instead of once per row.
    with open(outfile_path, 'a', newline='\n', encoding='utf-8') as f:
        if needs_header:
            f.write("date; model_name; dataset; accuracy\n")
        for model_name, dataset_str, accuracy in res:
            f.write(f"{today};{model_name}; {dataset_str}; {accuracy}\n")
            print(f"Accuracy of {model_name} on {dataset_str} dataset: {accuracy}")


def main(
    model_checkpoint,
    seed: int=42,
    batch_size: int=64,
    num_train_epochs: int= 3,
    num_proc: int=4,
    max_train_samples=None,
    max_eval_samples=None,
    output_dir: str="../res",
    use_peft: bool = False,
    do_train: bool = True,
    do_eval: bool=True,
    do_log: bool=True,
    save_path: str="/nfs/turbo/umms-vgvinodv/models/finetuned-checkpoints/nlp-gen"
):
    """Fine-tune a seq2seq checkpoint on SNLI and report accuracy on NLI test sets.

    SNLI is cast as text-to-text: every premise/hypothesis pair becomes a
    prompt and the model is trained to generate the label word from
    ``id2label``. Accuracy is computed by decoding generated tokens and mapping
    the decoded words back to class ids.

    Args:
        model_checkpoint: Hub id or local path passed to ``from_pretrained``.
        seed: Random seed, applied through ``set_seed`` and forwarded to the
            training arguments.
        batch_size: Per-device batch size for training and evaluation.
        num_train_epochs: Number of training epochs.
        num_proc: Worker processes used by ``Dataset.map`` while tokenizing.
        max_train_samples: If set, use only the first N training rows.
        max_eval_samples: If set, use only the first N validation rows.
        output_dir: Directory that receives the results log.
        use_peft: Accepted for interface compatibility; currently unused.
        do_train: Run fine-tuning when true; otherwise the loaded checkpoint
            is evaluated as-is.
        do_eval: Evaluate on the held-out test sets when true.
        do_log: Append the test accuracies to the results log when true (only
            has an effect when ``do_eval`` is true).
        save_path: Root directory for checkpoints; ``"<model_name>-snli"`` is
            appended to it.
    """
    # Seed python, numpy and torch (CPU and CUDA) in one call.
    set_seed(seed)

    checkpoint = model_checkpoint
    metric_name = "accuracy"
    model_name = checkpoint.split("/")[-1]
    save_path = f"{save_path}/{model_name}-snli"

    # Load Dataset; rows whose label is not a key of id2label are dropped.
    raw_dataset = load_dataset("snli")
    raw_dataset = raw_dataset.filter(lambda sample: sample['label'] in id2label)

    # Load Model and Tokenizer
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Load the metric once instead of on every evaluation call.
    metric = evaluate.load(metric_name)

    def preprocess_function(examples):
        """Tokenize prompts and label words; label token ids go under ``labels``."""
        inputs, targets = preprocess_snli_batch(examples)
        model_inputs = tokenizer(inputs)
        labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    def compute_metrics(eval_pred):
        """Decode generated and reference token ids and score label accuracy."""
        predictions, labels = eval_pred
        if isinstance(predictions, tuple):
            predictions = predictions[0]

        # -100 marks ignored positions and cannot be decoded; swap in the pad id.
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Anything that does not decode to a known label word is scored as -1.
        pred_ids = [label2id.get(p.strip(), -1) for p in decoded_preds]
        label_ids = [label2id.get(l.strip(), -1) for l in decoded_labels]

        return metric.compute(predictions=pred_ids, references=label_ids)

    def take_first(dataset, max_samples):
        """Return the first ``max_samples`` rows, or everything when it is None."""
        if max_samples is None:
            return dataset
        return dataset.select(range(min(max_samples, len(dataset))))

    # Tokenize raw dataset. Sub-sampling happens before tokenization so that
    # only the rows that are actually used get processed.
    column_names = raw_dataset['train'].column_names
    train_dataset = None
    if do_train:
        train_dataset = take_first(raw_dataset["train"], max_train_samples).map(
            preprocess_function,
            batched=True,
            num_proc=num_proc,
            remove_columns=column_names,
        )

    eval_dataset = take_first(raw_dataset["validation"], max_eval_samples).map(
        preprocess_function,
        batched=True,
        num_proc=num_proc,
        remove_columns=column_names,
    )

    # Data collator
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    # Training Args
    args = Seq2SeqTrainingArguments(
        save_path,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        weight_decay=0.01,
        save_total_limit=1,
        num_train_epochs=num_train_epochs,
        predict_with_generate=True,
        load_best_model_at_end=True,
        metric_for_best_model=metric_name,
        overwrite_output_dir=True,
        # Forward the seed so the trainer does not fall back to its own default.
        seed=seed,
        #push_to_hub=True,
    )

    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    # Training
    if do_train:
        trainer.train()

    def evaluate_test_data():
        """Evaluate the trainer's model on each test set; return result rows."""
        dataset2split = {
            'snli': "test",
            'multi_nli': "validation_mismatched",
            'sagnikrayc/snli-bt': "test",
            'sagnikrayc/snli-cf-kaushik': "test",
        }
        # These datasets are passed through convertlabels2ids, which expects
        # string labels and maps them to integer ids.
        string_label_datasets = ('sagnikrayc/snli-bt', 'sagnikrayc/snli-cf-kaushik')
        res = []

        for dataset_str, target_split in dataset2split.items():
            dataset = load_dataset(dataset_str, split=target_split)

            if dataset_str in string_label_datasets:
                dataset = dataset.map(convertlabels2ids)
            dataset = dataset.filter(lambda sample: sample['label'] in range(num_labels))

            tokenized_test_dataset = dataset.map(
                preprocess_function,
                batched=True,
                num_proc=num_proc,
                remove_columns=dataset.column_names,
            )

            results = trainer.evaluate(tokenized_test_dataset)
            res.append([model_name, dataset_str, results['eval_accuracy']])
        return res

    if not do_eval:
        return

    # Compute performance on test data
    res = evaluate_test_data()

    # Save results to CSV file
    if do_log:
        log_and_save_results(res, results_dir = output_dir, outfile_name = 'snli_model_performances.csv')


In [None]:
# Run the full pipeline for the "t5-base" checkpoint with main()'s default settings.
if __name__ == "__main__":
    main(model_checkpoint="t5-base")

In [None]:
# Run the full pipeline for the "t5-large" checkpoint with main()'s default settings.
if __name__ == "__main__":
    main(model_checkpoint="t5-large")

In [None]:
# Run the full pipeline for the "google/flan-t5-base" checkpoint with main()'s default settings.
if __name__ == "__main__":
    main(model_checkpoint="google/flan-t5-base")

In [None]:
# Run the full pipeline for the "google/flan-t5-large" checkpoint with main()'s default settings.
if __name__ == "__main__":
    main(model_checkpoint="google/flan-t5-large")

# Testing

In [None]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login from within the notebook
# (presumably required for the trainer.push_to_hub call further down; confirm).
notebook_login()

In [None]:
import os

from huggingface_hub import login

# Access tokens must never be hard-coded in source. The token is read from the
# HF_TOKEN environment variable; when it is unset, token=None makes login()
# fall back to its interactive prompt.
# NOTE(review): the token that used to be committed here remains visible in
# version history and should be revoked on the Hub.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
import torch
import numpy as np
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed,
)

import evaluate

In [None]:
import datasets

#def convertid2label(example):
#    example['label_str'] = id2label[example['label']]
#    return example

# Integer class id -> label word for the three NLI classes.
id2label = {0:'entailment', 1:'neutral', 2:'contradiction'}

# Load every split of SNLI, then keep only rows whose label is a key of
# id2label (0, 1 or 2); rows with any other label value are dropped.
raw_dataset = load_dataset("snli")
raw_dataset = raw_dataset.filter(lambda sample: sample['label'] in id2label)#.map(convertid2label)

# Preprocess Data

In [None]:
# Checkpoint used for the scratch experiments in the cells below.
checkpoint = "google/flan-t5-base"
num_labels = 3
# Small subset size so the test run finishes quickly; the cells below apply it
# to both the train and the validation split. Set to None to use all rows.
max_train_samples = 100
#max_train_samples = None

# Load the seq2seq model and its matching tokenizer from the same checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Maximum number of target tokens kept per label word; passed as max_length
# (with truncation) when preprocess_function tokenizes the targets.
max_target_length = 5
#padding = "max_length"
#ignore_pad_token_for_loss = True

def preprocess_snli_batch(examples):
    """Build text-to-text (input, target) pairs from a batch of NLI rows.

    Args:
        examples: Batch with parallel ``'premise'``, ``'hypothesis'`` and
            ``'label'`` sequences.

    Returns:
        Tuple ``(inputs, targets)``: one ``"premise: ... hypothesis: ..."``
        prompt per premise/hypothesis pair, and for each label its
        ``id2label`` word (an empty string when the label is not in
        ``range(num_labels)``).
    """
    premises, hypotheses, labels = (
        examples[key] for key in ('premise', 'hypothesis', 'label')
    )

    prompts = []
    for premise, hypothesis in zip(premises, hypotheses):
        prompts.append(" ".join(("premise:", premise, "hypothesis:", hypothesis)))

    known_ids = range(num_labels)
    label_words = [id2label[label] if label in known_ids else "" for label in labels]
    return prompts, label_words

def preprocess_function(examples):
    """Tokenize a batch of NLI rows into model inputs plus ``labels``.

    Prompts and label words come from ``preprocess_snli_batch``. Prompts are
    tokenized as-is; label words are tokenized through ``text_target`` and
    truncated to ``max_target_length`` tokens. Nothing is padded here, so the
    returned sequences have varying lengths.

    Returns:
        The tokenizer output for the prompts with the label token ids stored
        under ``"labels"``.
    """
    prompts, label_words = preprocess_snli_batch(examples)

    encoded = tokenizer(prompts)
    target_encoding = tokenizer(
        text_target=label_words,
        max_length=max_target_length,
        truncation=True,
    )
    encoded["labels"] = target_encoding["input_ids"]
    return encoded


def compute_metrics(eval_pred):
    """Score generated label words against the references with accuracy.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of token-id arrays;
            ``predictions`` may also be a tuple whose first item holds the
            generated token ids.

    Returns:
        The dict produced by the ``accuracy`` metric.
    """
    label2id = {'entailment':0,'neutral':1,'contradiction':2}
    metric = evaluate.load("accuracy")
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        # Unpack from `predictions` itself; no other name is bound here.
        predictions = predictions[0]

    # -100 marks ignored positions and cannot be decoded; swap in the pad id.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Anything that does not decode to a known label word is scored as -1.
    pred_ids = [label2id.get(p.strip(), -1) for p in decoded_preds]
    label_ids = [label2id.get(l.strip(), -1) for l in decoded_labels]

    return metric.compute(predictions=pred_ids, references=label_ids)

In [None]:
# --- Training split ---
train_dataset = raw_dataset["train"]

# Raw columns to drop once tokenization has produced the model inputs.
column_names = raw_dataset["train"].column_names

if max_train_samples is not None:
    # Keep only the first max_train_samples rows when the argument is specified.
    train_dataset = train_dataset.select(range(max_train_samples))

train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=column_names,
    #load_from_cache_file=not data_args.overwrite_cache,
    desc="Running tokenizer on train dataset",
)


# --- Validation split ---
eval_dataset = raw_dataset["validation"]
# NOTE: this rebinds the global column_names to the validation split's columns.
column_names = raw_dataset["validation"].column_names
# NOTE(review): the validation subset is also sized by max_train_samples;
# confirm this is intended rather than a separate evaluation limit.
if max_train_samples is not None:
    # Keep only the first max_train_samples rows when the argument is specified.
    eval_dataset = eval_dataset.select(range(max_train_samples))
eval_dataset = eval_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=column_names,
    #load_from_cache_file=not data_args.overwrite_cache,
    desc="Running tokenizer on eval dataset",
)

In [None]:
# Data collator
# Built from the tokenizer and the model; label_pad_token_id is not passed, so
# the collator's library default applies.
#label_pad_token_id = -100 if ignore_pad_token_for_loss else tokenizer.pad_token_id
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    #label_pad_token_id=label_pad_token_id,
)

In [None]:
# Training Args
metric_name = "accuracy"
batch_size = 64
num_train_epochs = 1
# Last path component of the checkpoint id, e.g. "flan-t5-base".
model_name = checkpoint.split("/")[-1]

# Scratch-run configuration: evaluate once per epoch, never save checkpoints,
# and generate sequences during evaluation (compute_metrics decodes token ids).
args = Seq2SeqTrainingArguments(
    f"finetuned-checkpoints/{model_name}-snli",
    evaluation_strategy = "epoch",
    save_strategy = "no",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=1,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    #load_best_model_at_end=True,
    #metric_for_best_model=metric_name,
    #push_to_hub=True,
)

In [None]:
# Initialize our Trainer
# Wires together the model, training arguments, tokenized train/validation
# subsets, the collator and the accuracy-based compute_metrics callback.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning from the freshly loaded checkpoint (no resume).
#train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.train()

In [None]:
# Upload the trainer's model to the Hugging Face Hub; requires a prior Hub login.
# NOTE(review): the f-string is the first positional argument of push_to_hub;
# confirm whether it is meant as the repo name or the commit message.
trainer.push_to_hub(f"{model_name}-snli")

In [None]:
# Evaluate on the eval_dataset the trainer was constructed with; the notebook displays the metrics dict.
trainer.evaluate()

In [None]:
# Sanity check: generate the label for one hand-written premise/hypothesis prompt.
text = 'premise: A person on a horse jumps over a broken down airplane. hypothesis: A person is at a diner, ordering an omelette.'
# Move the encoded inputs to wherever the model currently lives instead of
# assuming a CUDA device, so the cell also works on CPU-only machines.
inputs = tokenizer.encode_plus(text, padding='max_length', max_length=512, return_tensors='pt').to(model.device)
outputs = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], max_new_tokens=100)
prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(prediction)

In [None]:
# Sanity check: generate the label for the first tokenized training example.
inputs = train_dataset[0]
# A dataset row holds plain sequences of token ids, while generate() needs
# batched tensors on the model's device: convert and add a leading batch
# dimension of size 1.
input_ids = torch.as_tensor(inputs['input_ids']).unsqueeze(0).to(model.device)
attention_mask = torch.as_tensor(inputs['attention_mask']).unsqueeze(0).to(model.device)
outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=100)
prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(prediction)

# Evaluate

In [None]:
def convertids2labels(example):
    """Replace ``example['label']`` with its ``ids2label`` lookup, in place.

    NOTE: despite the names, ``ids2label`` as defined in this cell maps label
    *strings* to integer ids, so this converts a string label into its id.
    The same ``example`` object is returned.
    """
    current_label = example['label']
    example['label'] = ids2label[current_label]
    return example

# NOTE: despite its name, this maps label strings to integer class ids.
ids2label = {'entailment':0, 'neutral':1, 'contradiction':2}

test_datasets = ['snli','multi_nli','sagnikrayc/snli-bt','sagnikrayc/snli-cf-kaushik']
# These datasets are passed through convertids2labels before filtering.
string_label_datasets = ['sagnikrayc/snli-bt','sagnikrayc/snli-cf-kaushik']
res = []

for dataset_str in test_datasets:
    # multi_nli is evaluated on its mismatched validation split; the others on "test".
    target_split = "validation_mismatched" if dataset_str == 'multi_nli' else "test"
    dataset = load_dataset(dataset_str, split=target_split)
    if dataset_str in string_label_datasets:
        dataset = dataset.map(convertids2labels)
    # Keep only rows with one of the num_labels valid class ids.
    dataset = dataset.filter(lambda sample: sample['label'] in range(num_labels))
    # Drop this dataset's own raw columns rather than the global column_names
    # left over from the SNLI validation cell, whose columns need not match.
    tokenized_test_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset.column_names,
    )
    results = trainer.evaluate(tokenized_test_dataset)
    res.append([model_name, dataset_str, results['eval_accuracy']])

In [None]:
# Show the collected [model_name, dataset, accuracy] rows.
print(res)

# Log and save results

In [None]:
import os
from datetime import date

# Append the rows collected in `res` to a semicolon-separated log file.
results_dir = 'res'
outfile_name = 'snli_model_performances.csv'

outfile_path = os.path.join(results_dir, outfile_name)

# makedirs does not fail when the directory already exists.
os.makedirs(results_dir, exist_ok=True)

# A missing or empty log file still needs its header row.
needs_header = (
    not os.path.exists(outfile_path) or os.path.getsize(outfile_path) == 0
)
today = date.today()

# Open the log once for the whole batch. The loop variables carry a row_
# prefix so they do not overwrite the notebook-level model_name.
with open(outfile_path, 'a', newline='\n') as f:
    if needs_header:
        f.write("date; model_name; dataset; accuracy\n")
    for row_model, row_dataset, row_accuracy in res:
        f.write(f"{today};{row_model}; {row_dataset}; {row_accuracy}\n")
        print(f"Accuracy of {row_model} on {row_dataset} dataset: {row_accuracy}")

## Testing T5forSeqClass

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os
import re
import glob
from datasets import load_dataset
import datasets

In [None]:
# Hub id of the dataset used in this section.
dataset_id = "snli"

In [None]:
from datasets import load_dataset

# Load all splits of the dataset named by dataset_id from the hub.
dataset = load_dataset(dataset_id)

# Report the number of rows in the train and test splits.
print(f"Train dataset size: {len(dataset['train'])}")
print(f"Test dataset size: {len(dataset['test'])}")

# Output recorded from an earlier run:
# Train dataset size: 550152
# Test dataset size: 10000

In [None]:
# Notebook display of the freshly loaded dataset object.
dataset

In [None]:
# Notebook display of the raw training row at index 1.
dataset['train'][1]

In [None]:
def preprocess(sample):
    """Add a ``"text"`` prompt column to a batch of premise/hypothesis rows.

    Args:
        sample: Batch with parallel ``'premise'`` and ``'hypothesis'``
            sequences.

    Returns:
        The same ``sample`` object, with ``sample["text"]`` set to a list of
        ``"premise: ... hypothesis: ..."`` strings, one per pair.
    """
    prompts = []
    for premise, hypothesis in zip(sample['premise'], sample['hypothesis']):
        prompts.append(" ".join(("premise:", premise, "hypothesis:", hypothesis)))
    sample["text"] = prompts
    return sample

# Build the "text" column in batches using 4 worker processes, then drop the
# raw premise/hypothesis columns it was derived from.
dataset = dataset.map(preprocess, batched=True, num_proc=4, remove_columns=['premise','hypothesis'])

In [None]:
# Notebook display of the dataset after the "text" column was added.
dataset

In [None]:
# Notebook display of the training row at index 1 after preprocessing.
dataset['train'][1]

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Checkpoint whose tokenizer is used in this section.
model_id="google/flan-t5-base"

# Load tokenizer of FLAN-t5-base
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
def tokenize_function(samples):
    """Tokenize a batch of prompts and use the stringified labels as targets.

    Args:
        samples: Batch with a ``"text"`` list of prompts and a parallel
            ``"label"`` list.

    Returns:
        The tokenizer output for the prompts with the target token ids stored
        under ``"labels"``.
    """
    model_inputs = tokenizer(samples["text"])

    # The label column holds class ids rather than text, while the tokenizer
    # only accepts strings, so each label is converted to its string form.
    # NOTE(review): rows with labels outside 0-2 are not filtered in this
    # section; confirm whether they should be dropped before tokenizing.
    label_texts = [str(label) for label in samples["label"]]
    labels = tokenizer(text_target=label_texts, max_length=5, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Tokenize every split in batches using 4 worker processes, then drop the
# 'text' and 'label' columns that were consumed by tokenize_function.
dataset = dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=['text','label'])

In [None]:
# Inspect the token ids the tokenizer produces for the strings '0', '1' and '2'.
print(tokenizer.encode('0'))
print(tokenizer.encode('1'))
print(tokenizer.encode('2'))