In [6]:
# Importing required libraries
import torch
import pandas as pd
import numpy as np
from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seed for reproducibility
torch.manual_seed(42)

Using device: cuda


<torch._C.Generator at 0x1c25afbd2f0>

In [7]:
# Load the CNN/DailyMail dataset
dataset = load_dataset("cnn_dailymail", "3.0.0")

# Convert the dataset to a pandas DataFrame
df = pd.DataFrame(dataset["train"])

# Split the data into train and validation sets
train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)

print(f"Training set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")

# Create a custom dataset class
class SummarizationDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_input_length, max_target_length):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        article = self.dataframe.iloc[idx]["article"]
        highlights = self.dataframe.iloc[idx]["highlights"]

        inputs = self.tokenizer.encode_plus(
            "summarize: " + article,
            max_length=self.max_input_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        targets = self.tokenizer.encode_plus(
            highlights,
            max_length=self.max_target_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        return {
            "input_ids": inputs["input_ids"].squeeze(),
            "attention_mask": inputs["attention_mask"].squeeze(),
            "labels": targets["input_ids"].squeeze()
        }

# Initialize tokenizer and model
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)

# Create datasets and dataloaders
max_input_length = 512
max_target_length = 128

train_dataset = SummarizationDataset(train_df, tokenizer, max_input_length, max_target_length)
val_dataset = SummarizationDataset(val_df, tokenizer, max_input_length, max_target_length)

batch_size = 8
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Training set size: 258401
Validation set size: 28712


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [8]:
# Set up optimizer and learning rate scheduler
optimizer = AdamW(model.parameters(), lr=2e-5)
num_training_steps = len(train_dataloader) * 5  # 5 epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

# Training function
def train(model, dataloader, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

# Validation function
def validate(model, dataloader, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            total_loss += loss.item()

    return total_loss / len(dataloader)

In [12]:
import time
import os
from datetime import datetime

def save_checkpoint(iteration, epoch, model, optimizer, train_losses, val_losses):
    checkpoint = {
        "iteration": iteration,
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_losses": train_losses,
        "val_losses": val_losses
    }
    checkpoint_path = f"checkpoint_epoch_{epoch}_iter_{iteration}.pth"
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")

def load_latest_checkpoint():
    checkpoints = [f for f in os.listdir('.') if f.startswith('checkpoint_epoch_')]
    if not checkpoints:
        return None
    latest_checkpoint = max(checkpoints, key=os.path.getctime)
    return torch.load(latest_checkpoint)

# Training loop
num_epochs = 1
train_losses = []
val_losses = []

checkpoint = load_latest_checkpoint()
if checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
    train_losses = checkpoint['train_losses']
    val_losses = checkpoint['val_losses']
    print(f"Resuming training from epoch {start_epoch}, iteration {start_iteration}")
else:
    start_epoch = 0
    start_iteration = 0
    print("Starting training from scratch")

checkpoint_interval = 30 * 60  # 30 minutes in seconds
last_checkpoint_time = time.time()

for epoch in range(start_epoch, num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    model.train()
    epoch_loss = 0
    
    for i, batch in enumerate(tqdm(train_dataloader, desc="Training")):
        if epoch == start_epoch and i < start_iteration:
            continue  # Skip already processed iterations when resuming
        
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        epoch_loss += loss.item()

        # Check if it's time to save a checkpoint
        current_time = time.time()
        if current_time - last_checkpoint_time >= checkpoint_interval:
            avg_loss = epoch_loss / (i + 1)
            train_losses.append(avg_loss)
            save_checkpoint(i, epoch, model, optimizer, train_losses, val_losses)
            last_checkpoint_time = current_time

    # Compute average loss for the epoch
    avg_epoch_loss = epoch_loss / len(train_dataloader)
    train_losses.append(avg_epoch_loss)
    
    # Validation
    val_loss = validate(model, val_dataloader, device)
    val_losses.append(val_loss)
    
    print(f"Epoch {epoch + 1} - Train Loss: {avg_epoch_loss:.4f}, Validation Loss: {val_loss:.4f}")
    
    # Save checkpoint at the end of each epoch
    save_checkpoint(len(train_dataloader), epoch + 1, model, optimizer, train_losses, val_losses)



Starting training from scratch
Epoch 1/5


Training:   0%|          | 120/32301 [03:01<9:24:49,  1.05s/it] 

Checkpoint saved to checkpoint_epoch_0_iter_119.pth


Training:   0%|          | 130/32301 [03:10<13:04:38,  1.46s/it]


KeyboardInterrupt: 

In [None]:
# ... (plotting code)

# Plot learning curves
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(train_losses) + 1), train_losses, label="Train Loss")
plt.plot(range(1, len(val_losses) + 1), val_losses, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Learning Curves")
plt.legend()
plt.show()

In [None]:
def generate_summary(model, tokenizer, article, max_length=150):
    model.eval()
    input_ids = tokenizer.encode("summarize: " + article, return_tensors="pt", max_length=512, truncation=True).to(device)
    
    summary_ids = model.generate(input_ids, max_length=max_length, num_beams=4, length_penalty=2.0, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    
    return summary

# Test the model on a sample article
sample_article = """
NASA's Perseverance rover has successfully landed on Mars, marking the beginning of its mission to search for signs of ancient microbial life on the Red Planet. The rover, which is about the size of a car, touched down in the Jezero Crater on February 18, 2021, after a seven-month journey through space. Equipped with advanced scientific instruments and a small helicopter named Ingenuity, Perseverance will explore the Martian surface, collect rock and soil samples, and pave the way for future human missions to Mars. This historic landing represents a major milestone in space exploration and brings us one step closer to understanding the potential for life beyond Earth.
"""

generated_summary = generate_summary(model, tokenizer, sample_article)
print("Generated Summary:")
print(generated_summary)