In [None]:
from typing import Dict, Tuple
import nltk

from tqdm import tqdm

import pandas as pd
from sklearn.model_selection import train_test_split

import numpy as np
from datasets import Dataset
import evaluate
import torch
import torch.nn as nn
from torch.optim import AdamW
from transformers import Trainer, T5Tokenizer, T5ForConditionalGeneration, TrainingArguments
from transformers import StoppingCriteria, StoppingCriteriaList

In [None]:
# Pick the compute device once for the whole notebook: CUDA when PyTorch
# reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Using device {DEVICE}")

In [None]:
# Weights are loaded from the path below; the tokenizer is loaded separately
# from the 't5-base' identifier rather than from that path.
model_path = "t5-wikismall-mimic-dir/checkpoint-11500"
t5_tokenizer = T5Tokenizer.from_pretrained('t5-base')
t5_model = T5ForConditionalGeneration.from_pretrained(model_path)

In [None]:
# Upper bound on the number of input tokens; longer inputs are truncated.
MAX_LENGTH = 512

def summarize_text_t5(text, model=None, tokenizer=None):
    """Generate a summary string for ``text``.

    The text is prefixed with ``"summarize: "``, tokenized (truncated to
    ``MAX_LENGTH`` tokens), passed to ``generate`` with ``num_beams=5`` and
    the first returned sequence is decoded.

    Args:
        text: Input string to summarize.
        model: Optional object exposing ``generate(...)`` and a ``device``
            attribute. Defaults to the notebook-level ``t5_model``.
        tokenizer: Optional object exposing ``encode(...)`` and
            ``decode(...)``. Defaults to the notebook-level ``t5_tokenizer``.

    Returns:
        The decoded first generated sequence with special tokens skipped.
    """
    if model is None:
        model = t5_model
    if tokenizer is None:
        tokenizer = t5_tokenizer

    input_ids = tokenizer.encode(
        "summarize: " + text,
        return_tensors='pt',
        max_length=MAX_LENGTH,
        truncation=True
    )
    # Follow the model's own device instead of the global DEVICE so the call
    # works whether or not the model has been moved to DEVICE yet.
    input_ids = input_ids.to(model.device)
    input_token_count = len(input_ids[0])

    # NOTE(review): the decay start index is derived from the *input* length
    # while max_length caps the output at 50 tokens, so for inputs longer than
    # roughly 62 tokens the length penalty may never kick in -- confirm intent.
    # NOTE(review): a negative decay factor is unusual -- confirm it is intended.
    # NOTE(review): temperature is normally only applied when sampling
    # (do_sample=True), which is not requested here -- confirm it has any effect.
    summary_ids = model.generate(
        input_ids,
        exponential_decay_length_penalty=(int(input_token_count * 0.8), -1.05),
        encoder_repetition_penalty=0.3,
        no_repeat_ngram_size=4,
        max_length=50,
        num_beams=5,
        temperature=0.9,
    )

    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

In [None]:
# Move the model onto the selected device before generating.
t5_model = t5_model.to(DEVICE)

# Read one input text per line. The encoding is pinned so the result does not
# depend on the platform's default locale.
with open("metrics/sample.txt", 'r', encoding="utf-8") as f:
    sample = [line.strip("\n") for line in f]

# Summarize every line in order, so outputs stay aligned with the input lines.
summary = [summarize_text_t5(line) for line in sample]
for s in summary:
    print(s)