In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset, load_metric
import torch

In [None]:
# Evaluation-loss history logged during training; shown as the cell's output.
eval_history = pd.read_csv("history_eval.csv")
eval_history

In [None]:
# Training vs. evaluation loss for the run logged in history_train.csv / history_eval.csv.
fig, ax = plt.subplots()
pd.read_csv("history_train.csv").plot(x="step", y="loss", ax=ax, label="training loss")
pd.read_csv("history_eval.csv").plot(x="step", y="eval_loss", ax=ax, label="evaluation loss")

# A figure should stand alone when skimmed: give it a title and axis labels.
ax.set(title="Training vs. evaluation loss (history_*.csv)", xlabel="step", ylabel="loss")
plt.show()

In [None]:
# Training vs. evaluation loss for the run logged in train_large.csv / eval_large.csv.
fig, ax = plt.subplots()
pd.read_csv("train_large.csv").plot(x="step", y="loss", ax=ax, label="training loss")
pd.read_csv("eval_large.csv").plot(x="step", y="eval_loss", ax=ax, label="evaluation loss")

# A figure should stand alone when skimmed: give it a title and axis labels.
ax.set(title="Training vs. evaluation loss (*_large.csv)", xlabel="step", ylabel="loss")
plt.show()

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Local directory holding the fine-tuned seq2seq checkpoint and its tokenizer files.
model_name = "checkpoint-9200/checkpoint-9200"
model_dir = model_name

# Tokenizer is loaded first, then the model, from the same directory.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_dir),
    AutoModelForSeq2SeqLM.from_pretrained(model_dir),
)

# Inputs are truncated to this many tokens when tokenized for the model.
max_input_length = 512

In [None]:
# Load the test CSV (a single-file CSV is exposed under the "train" split) and keep a
# random sample of at most N_TEST_SAMPLES rows so generation time stays bounded.
# The shuffle is seeded so every run evaluates the same subset.
N_TEST_SAMPLES = 1000
SEED = 42

dataset = load_dataset("csv", data_files="test_ds.csv")
n_rows = min(N_TEST_SAMPLES, len(dataset["train"]))  # don't fail on smaller files
dataset["train"] = dataset["train"].shuffle(seed=SEED).select(range(n_rows))

In [None]:
import numpy as np
import nltk
# nltk.sent_tokenize needs the Punkt sentence models. Recent NLTK releases look them up
# under "punkt_tab" rather than "punkt", so fetch both. On older NLTK versions the
# unknown "punkt_tab" id only reports an error and returns False; it does not raise.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

# NOTE(review): datasets.load_metric was removed in datasets 3.0; on newer versions
# switch to evaluate.load("rouge").
metric = load_metric("rouge")

def compute_metrics(eval_pred):
    """Score generated sequences against reference sequences with ROUGE.

    Args:
        eval_pred: A ``(predictions, labels)`` pair.
            ``predictions`` is an iterable of generated token-id sequences. The
            sequences may differ in length, e.g. when collected from several
            generation batches.
            ``labels`` holds the reference token ids as a rectangular array-like:
            a list of equal-length lists, an ndarray or a tensor. Positions equal
            to -100 are treated as padding.

    Returns:
        dict: One entry per ROUGE key returned by ``metric``, holding the F1 score
        multiplied by 100, plus ``"gen_len"``, the mean number of non-pad tokens
        per prediction. All values are Python floats rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    # Replace -100 in the labels as we can't decode them. Convert to an array first:
    # on a plain Python list, `labels != -100` is a single bool rather than an
    # element-wise mask, so the replacement would silently do nothing.
    labels = np.asarray(labels)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence.
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    # Compute ROUGE scores.
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    # Extract ROUGE F1 scores. datasets.load_metric("rouge") returns aggregate
    # (low/mid/high) objects, while the `evaluate` package's rouge returns plain
    # floats; accept either.
    result = {
        key: (value.mid.fmeasure if hasattr(value, "mid") else value) * 100
        for key, value in result.items()
    }

    # Add mean generated length (non-pad tokens per prediction) to metrics.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    # Plain floats render cleanly and serialize (e.g. to JSON) without numpy types.
    return {k: round(float(v), 4) for k, v in result.items()}

In [None]:
import torch
import tensorflow as tf

# Tokenization settings for inference: encoder inputs and reference answers.
max_input_length = 512
max_target_length = 64


def preprocess_test(examples):
    """Tokenize a batch of questions, prefixing each with "question: ".

    Args:
        examples: A batch as passed by ``Dataset.map(batched=True)``. Only its
            ``"question"`` column (a list of strings) is read.

    Returns:
        The tokenizer output (including ``input_ids`` and ``attention_mask``). Each
        sequence is truncated and padded to exactly ``max_input_length`` tokens.
    """
    prefix = "question: "
    prefixed_questions = list(map(lambda question: prefix + question, examples["question"]))
    return tokenizer(
        prefixed_questions,
        max_length=max_input_length,
        truncation=True,
        padding="max_length",
    )

test_tokenized_dataset = dataset.map(preprocess_test, batched=True)

# Expose only the model inputs as torch tensors, so each DataLoader batch can be
# passed straight to model.generate(**batch).
test_tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
dataloader = torch.utils.data.DataLoader(test_tokenized_dataset["train"], batch_size=32)

# Generate predictions batch by batch. max_length is passed explicitly so outputs are
# bounded by the same max_target_length used for the reference answers. Without it,
# generate() falls back to the checkpoint's generation config; transformers' default
# max_length there is 20 tokens, which can truncate predictions and depress ROUGE.
all_predictions = []
n_batches = len(dataloader)
for i, batch in enumerate(dataloader):
    predictions = model.generate(**batch, max_length=max_target_length)
    all_predictions.append(predictions)
    print(f"batch {i + 1}/{n_batches}", end="\r")

# Flatten the per-batch outputs into one list of per-example token-id sequences.
# Lengths may differ between batches, so they are not stacked into one tensor.
all_predictions_flattened = [pred for preds in all_predictions for pred in preds]

In [None]:
# Tokenize the reference answers, truncated/padded to max_target_length, to serve as labels.
reference_encoding = tokenizer(
    test_tokenized_dataset["train"]["answer"],
    max_length=max_target_length,
    truncation=True,
    padding="max_length",
)
all_titles = reference_encoding["input_ids"]

# Score the generated predictions against the reference answers.
predictions_labels = [all_predictions_flattened, all_titles]
compute_metrics(predictions_labels)

In [None]:
# Spot-check one hand-written question and show it next to the stored reference answer.
Answer = "holding that while witnesses enjoy absolute immunity for their actions in testifying they are not immune for extra-judicial actions such as an alleged conspiracy to present false testimony"
question = "question: What is the immunity of state and regional legislators?"
inputs = tokenizer(question, max_length=max_input_length, truncation=True, return_tensors="pt")

# do_sample=True combined with num_beams=8 samples within beam search, so the output
# is random. Seed torch's RNG so re-running the cell gives the same text on the same setup.
torch.manual_seed(0)
output = model.generate(**inputs, num_beams=8, do_sample=True, min_length=10, max_length=64)
decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]

# Keep only the first generated sentence; fall back to "" if nothing was decoded.
sentences = nltk.sent_tokenize(decoded_output.strip())
pred = sentences[0] if sentences else ""

print("prediction:", pred)
print("reference: ", Answer)