<a href="https://colab.research.google.com/github/sarvadutt/T5-Summarization-Reddit/blob/main/T5_Summarization_Reddit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

!pip install transformers datasets evaluate rouge_score accelerate sentencepiece

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer,T5ForConditionalGeneration
import evaluate
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.preprocessing import LabelEncoder
from torch import optim
import torch.nn.functional as F

In [None]:
# Load the 'long' configuration of the 'reddit_tifu' dataset; later cells read
# the "documents" and "tldr" fields of its "train" split.
reddit = load_dataset('reddit_tifu', 'long')

In [None]:
# Show the first training record to inspect which fields are available.
reddit["train"][0]

{'ups': 115.0,
 'num_comments': 23.0,
 'upvote_ratio': 0.8799999952316284,
 'score': 115.0,
 'documents': 'this actually happened a couple of years ago. i grew up in germany where i went to a german secondary school that went from 5th to 13th grade (we still had 13 grades then, they have since changed that). my school was named after anne frank and we had a club that i was very active in from 9th grade on, which was dedicated to teaching incoming 5th graders about anne franks life, discrimination, anti-semitism, hitler, the third reich and that whole spiel. basically a day where the students\' classes are cancelled and instead we give them an interactive history and social studies class with lots of activities and games. \n\nthis was my last year at school and i already had a lot of experience doing these project days with the kids. i was running the thing with a friend, so it was just the two of us and 30-something 5th graders. we start off with a brief introduction and brainstorming:

In [None]:
# Load T5 tokenizer
# `checkpoint` is the single source of truth for the pretrained weights: the
# tokenizer here and the model loaded further down both use it.
checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Task prefix that preprocess_function prepends to every document.
prefix = "summarize: "

In [None]:
# Uses the 'documents' (source post) and 'tldr' (reference summary) columns.
def preprocess_function(examples):
    """Tokenize a batch of posts and their reference summaries.

    Args:
        examples: Batch mapping with a "documents" list (source texts) and a
            "tldr" list (target summaries).

    Returns:
        The tokenizer output for the prefixed documents, with the token ids of
        the summaries stored under the "labels" key.
    """
    source_texts = [prefix + document for document in examples["documents"]]

    # Sources are truncated to 1024 tokens, targets to 128 tokens.
    encoded = tokenizer(source_texts, max_length=1024, truncation=True)
    target_encoding = tokenizer(
        text_target=examples["tldr"],
        max_length=128,
        truncation=True,
    )

    encoded["labels"] = target_encoding["input_ids"]
    return encoded


In [None]:
#  Tokenize the Reddit TIFU dataset
# batched=True hands preprocess_function whole batches (each column as a list),
# which is the shape that function iterates over.
tokenized_reddit = reddit.map(preprocess_function, batched=True)


In [None]:
#  Data collator
# NOTE(review): `model` receives the checkpoint *name* (a string), not a loaded
# model object -- confirm this is what the collator is meant to be given.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

In [None]:
#  Load Rouge metric
# The loaded metric object is used by compute_metrics via rouge.compute(...).
rouge = evaluate.load("rouge")

In [None]:
#  Compute metrics function
def compute_metrics(eval_pred):
    """Decode predicted and reference token ids and score them with ROUGE.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of token-id arrays. Positions
            equal to -100 are treated as filler in both arrays. ``predictions``
            may also be a tuple, in which case its first element is used.

    Returns:
        dict: The ROUGE scores plus ``gen_len`` (mean number of non-padding
        tokens per prediction), every value rounded to 4 decimal places.
    """
    predictions, labels = eval_pred
    # Some evaluation setups return a tuple; the token ids are its first element.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # -100 is not a vocabulary id and cannot be decoded, so swap it for the pad
    # id in BOTH arrays (predictions may be padded with -100 as well).
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    # Length is measured after the -100 replacement so filler is not counted.
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}


In [None]:
#  Load the pretrained T5 model named by `checkpoint` ("t5-small", set above)
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

## Preparing and loading the data

In [None]:
import gc

# Dataset columns holding the source post and its reference summary.
input_key = 'documents'
labels_key = 'tldr'

# Training split of the tokenized dataset; MyDataset reads the text columns
# (`input_key`, `labels_key`) of each record from it.
raw_data = tokenized_reddit["train"]

# Reference summaries, parallel to raw_data. They are free text, not class
# labels, so they are passed through as-is instead of being integer-encoded.
labels = raw_data[labels_key]


class MyDataset(Dataset):
    """Map-style dataset yielding one padded (post, summary) pair per index.

    Args:
        text: Indexable collection of records; each record must provide the
            ``input_key`` (post) and ``labels_key`` (summary) fields.
        labels: Sequence parallel to ``text``; only its length is used.
        max_length: Number of tokens the encoded post is truncated/padded to.
        max_target_length: Number of tokens the encoded summary is
            truncated/padded to.
    """

    def __init__(self, text, labels, max_length=1024, max_target_length=128):
        self.text = text
        self.labels = labels
        self.max_length = max_length
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of examples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Tokenize example ``idx``.

        Returns:
            dict with 1-D tensors:
              * ``input_ids`` / ``attention_mask`` of length ``max_length`` for
                the prefixed post,
              * ``labels`` of length ``max_target_length`` holding the summary
                token ids, with padding positions set to -100.
        """
        record = self.text[idx]

        # Same task prefix as preprocess_function, so inputs look the same
        # everywhere in the notebook.
        source = tokenizer(
            prefix + record[input_key],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # The training target is the reference summary, not the post itself.
        target = tokenizer(
            text_target=record[labels_key],
            truncation=True,
            padding='max_length',
            max_length=self.max_target_length,
            return_tensors='pt'
        )

        # Mark padding with -100 so a loss using ignore_index=-100 skips it.
        target_ids = target['input_ids'].squeeze(0)
        target_ids = target_ids.masked_fill(target_ids == tokenizer.pad_token_id, -100)

        return {
            'input_ids': source['input_ids'].squeeze(0),
            'attention_mask': source['attention_mask'].squeeze(0),
            'labels': target_ids,
        }


# Number of examples per batch for every loader below.
batch_size = 8

# Create DataLoader for the full dataset
full_dataset = MyDataset(raw_data, labels)
full_data_loader = DataLoader(full_dataset, batch_size=batch_size, shuffle=True)

# Drop the notebook-level names; the dataset object keeps its own references.
del raw_data
del labels

# 80% of the examples go to training, the remainder to validation.
train_size = int(0.8 * len(full_dataset))
valid_size = len(full_dataset) - train_size

# A fixed seed makes the train/validation split reproducible across runs.
train_dataset, valid_dataset = random_split(
    full_dataset,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(42),
)

del full_dataset

# Create DataLoader for training and validation sets
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_data_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# The loaders hold the datasets, so these names are no longer needed.
del train_dataset
del valid_dataset
gc.collect()

0

## Training the model and saving it locally

In [None]:
# Define optimizer and loss function
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
# Label-smoothed cross-entropy over the vocabulary. ignore_index=-100 keeps
# padded target positions out of the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=0.1)

# Number of passes over the training data.
num_epochs = 4
# Number of batches whose gradients are summed before each optimizer step
# (effective batch size = loader batch size * accumulation_steps). Tunable.
accumulation_steps = 4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def _batch_loss(batch):
    """Run the model on one batch and return its label-smoothed loss (a scalar tensor)."""
    input_ids = batch["input_ids"].to(device)
    # Padding tokens in the source must not be attended to.
    attention_mask = (input_ids != tokenizer.pad_token_id).long()

    labels = batch["labels"].to(device)
    # Collapse any extra per-example dimension so labels are always
    # (batch, target_len), and make sure padding is excluded from the loss.
    labels = labels.reshape(labels.size(0), -1)
    labels = labels.masked_fill(labels == tokenizer.pad_token_id, -100)

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    logits = outputs.logits
    return criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))


# Training loop
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    total_train_loss = 0.0
    num_train_batches = len(train_data_loader)

    for i, batch in enumerate(train_data_loader):
        loss = _batch_loss(batch)
        total_train_loss += loss.item()

        # Scale so the accumulated gradient is an average, not a sum.
        (loss / accumulation_steps).backward()

        # Step every `accumulation_steps` batches, and once more on the final
        # batch so leftover gradients are not discarded.
        if (i + 1) % accumulation_steps == 0 or (i + 1) == num_train_batches:
            optimizer.step()
            optimizer.zero_grad()

    average_train_loss = total_train_loss / max(num_train_batches, 1)

    # Validation loop
    model.eval()
    total_valid_loss = 0.0
    with torch.no_grad():
        for batch in valid_data_loader:
            total_valid_loss += _batch_loss(batch).item()

    average_valid_loss = total_valid_loss / max(len(valid_data_loader), 1)

    # Report the epoch-average losses for both splits.
    print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_train_loss}, Validation Loss: {average_valid_loss}")


# Persist the fine-tuned weights and the tokenizer side by side.
model.save_pretrained("/content")
tokenizer.save_pretrained("/content")