# Project 3: Fine-Tuning FLAN-T5 for Summarization & Measuring Forgetting

**Authors:** Shaunak Kapur & Pranav Krishnan

This notebook implements the Project 3 proposal: fine-tuning a small language model (`google/flan-t5-small`) on the Amazon Fine Food Reviews dataset to generate product review summaries. It also evaluates "forgetting" by checking the model's performance on a set of general knowledge questions before and after fine-tuning.


## 1. Setup and Installation

Installing required libraries: `transformers`, `datasets`, `evaluate`, `rouge_score`, `accelerate`, `sentencepiece`.


In [None]:
!pip install -q transformers datasets evaluate rouge_score accelerate sentencepiece


In [None]:
import torch
import pandas as pd
import numpy as np
from datasets import Dataset, load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
import evaluate

# Select the accelerator once; later cells move models and tensors onto `device`.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")


## 2. Load and Preprocess Data

We use the Amazon Fine Food Reviews dataset from Hugging Face. The dataset will be automatically downloaded using `load_dataset`.

We will:
1. Download the dataset from Hugging Face.
2. Convert to pandas DataFrame.
3. Keep only the `Summary` and `Text` columns, drop rows with missing values, and discard reviews longer than 512 characters.
4. Sample the data (e.g., 20,000 rows) to keep training time reasonable.
5. Split into Train (80%), Validation (10%), and Test (10%).


In [None]:
# Load dataset from Hugging Face
print("=" * 80)
print("STEP 1: Downloading dataset from Hugging Face...")
print("=" * 80)
ds = load_dataset("jhan21/amazon-food-reviews-dataset")
print(f"✓ Dataset loaded. Available splits: {list(ds.keys())}")

# Convert to pandas DataFrame (the dataset has a 'train' split)
print("\nConverting to pandas DataFrame...")
df = ds["train"].to_pandas()
print(f"✓ Original dataset size: {len(df)} rows, {len(df.columns)} columns")
print(f"  Columns: {list(df.columns)}")

# Keep relevant columns and drop NaNs
print("\nFiltering data...")
print(f"  Before filtering: {len(df)} rows")
df = df[["Summary", "Text"]].dropna()
print(f"  After dropping NaN: {len(df)} rows")

# Drop repeated review texts before splitting. If an identical text ended up in
# both the train and the test split, test ROUGE would partly measure memorized
# examples. Only the first occurrence (with its summary) is kept.
df = df.drop_duplicates(subset="Text", keep="first")
print(f"  After dropping duplicate review texts: {len(df)} rows")

# Filter out very long reviews to save memory/time
df = df[df["Text"].str.len() <= 512]
print(f"  After filtering long reviews (<=512 chars): {len(df)} rows")

# Sample data for faster training (adjust as needed)
SAMPLE_SIZE = 20000
if len(df) > SAMPLE_SIZE:
    print(f"\nSampling {SAMPLE_SIZE} rows from {len(df)} total rows...")
    df = df.sample(SAMPLE_SIZE, random_state=42)
    print(f"✓ Sampled dataset size: {len(df)} rows")
else:
    print(f"\nUsing full dataset: {len(df)} rows")

print("\nSample data preview:")
print(df.head(3))
print(f"\n✓ Dataset ready: {len(df)} rows")
print("=" * 80)


In [None]:
from sklearn.model_selection import train_test_split

print("=" * 80)
print("STEP 2: Splitting dataset into train/val/test...")
print("=" * 80)

# Two-stage split: hold out 20%, then halve the hold-out into validation and test.
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

print("✓ Split complete:")
for split_label, split_frame, split_share in (
    ("Train", train_df, 80),
    ("Validation", val_df, 10),
    ("Test", test_df, 10),
):
    print(f"  {split_label}: {len(split_frame)} rows ({split_share}%)")

# Convert each split to a Hugging Face Dataset, discarding the shuffled pandas index.
train_ds, val_ds, test_ds = (
    Dataset.from_pandas(split_frame.reset_index(drop=True))
    for split_frame in (train_df, val_df, test_df)
)

print("\n✓ Datasets created:")
for split_label, split_dataset in (("Train", train_ds), ("Val", val_ds), ("Test", test_ds)):
    print(f"  {split_label}: {len(split_dataset)} samples")
print("=" * 80)


## 3. Model and Tokenizer Setup

We use `google/flan-t5-small`. We load two copies:
1. `base_model`: Keeps original weights to measure baseline performance and forgetting.
2. `model`: Will be fine-tuned.


In [None]:
# Model selection based on AI recommendation; see [1]
MODEL_NAME = "google/flan-t5-small"

print("=" * 80)
print("STEP 3: Loading model and tokenizer...")
print("=" * 80)
print(f"Model: {MODEL_NAME}")
print(f"Device: {device}")

print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print(f"✓ Tokenizer loaded. Vocab size: {tokenizer.vocab_size}")
print(f"  Pad token: {tokenizer.pad_token_id}, EOS token: {tokenizer.eos_token_id}")


def _load_fp32_seq2seq(name):
    """Load a seq2seq checkpoint with FP32 weights and `device_map="auto"` placement."""
    return AutoModelForSeq2SeqLM.from_pretrained(
        name,
        torch_dtype=torch.float32,
        device_map="auto",
    )


print("\nLoading model for fine-tuning (GPU-optimized)...")
model = _load_fp32_seq2seq(MODEL_NAME)
model.train()  # this copy is the one handed to the Trainer
first_param = next(model.parameters())
print(f"✓ Model loaded. Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"  Model device: {first_param.device}")
print(f"  Model dtype: {first_param.dtype}")

print("\nLoading base model for comparison...")
base_model = _load_fp32_seq2seq(MODEL_NAME)
base_model.eval()  # reference copy: only used for inference, never trained
print("✓ Base model loaded")
print(f"  Base model device: {next(base_model.parameters()).device}")
print("=" * 80)

## 4. Tokenization

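Each review is prefixed with "Summarize this review: " and tokenized (truncated to 256 tokens); each reference summary is tokenized as the target (truncated to 32 tokens).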


In [None]:
# Token budgets for the encoder input and decoder target, plus the task prompt.
MAX_INPUT_LENGTH, MAX_TARGET_LENGTH = 256, 32
PREFIX = "Summarize this review: "

print("=" * 80)
print("STEP 4: Tokenizing datasets...")
print("=" * 80)
for config_line in (
    f"Max input length: {MAX_INPUT_LENGTH} tokens",
    f"Max target length: {MAX_TARGET_LENGTH} tokens",
    f"Prefix: '{PREFIX}'",
):
    print(config_line)

def preprocess_function(examples):
    """Tokenize a batch of reviews (encoder inputs) and their summaries (labels).

    `examples` is the batched mapping given by `Dataset.map(..., batched=True)`, with
    "Text" and "Summary" columns. Each review gets PREFIX prepended and is truncated to
    MAX_INPUT_LENGTH tokens. Summaries are encoded as targets truncated to
    MAX_TARGET_LENGTH tokens, and their ids are stored under "labels".
    """
    prompts = [PREFIX + review for review in examples["Text"]]
    encoded = tokenizer(prompts, max_length=MAX_INPUT_LENGTH, truncation=True)
    target_ids = tokenizer(
        text_target=examples["Summary"],
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
    )["input_ids"]
    encoded["labels"] = target_ids
    return encoded

def _tokenize_split(split):
    """Apply preprocess_function batch-wise and drop the raw text columns."""
    return split.map(preprocess_function, batched=True, remove_columns=split.column_names)


print("\nTokenizing training set...")
tokenized_train = _tokenize_split(train_ds)
print(f"✓ Train tokenized: {len(tokenized_train)} samples")
print(f"  Columns after tokenization: {tokenized_train.column_names}")

print("Tokenizing validation set...")
tokenized_val = _tokenize_split(val_ds)
print(f"✓ Val tokenized: {len(tokenized_val)} samples")

print("Tokenizing test set...")
tokenized_test = _tokenize_split(test_ds)
print(f"✓ Test tokenized: {len(tokenized_test)} samples")

# Decode the first training example to eyeball the prefix and truncation.
print("\nExample tokenized input:")
example = tokenized_train[0]
input_ids, label_ids = example["input_ids"], example["labels"]
print(f"  Input IDs length: {len(input_ids)}")
print(f"  Labels length: {len(label_ids)}")
print(f"  Decoded input: {tokenizer.decode(input_ids[:50])}...")
kept_label_ids = [tok for tok in label_ids if tok != -100]  # -100 = ignored label position
print(f"  Decoded label: {tokenizer.decode(kept_label_ids)}")
print("=" * 80)

## 5. Forgetting Analysis (Before Training)

We define a small set of general knowledge questions to test the "forgetting" hypothesis. We check how well the base model answers them.


In [None]:
# Forgetting analysis approach based on AI recommendation; see [4]
# General-knowledge probes mapped to their expected answers (insertion order is kept).
_QA_PROBES = {
    "What is the capital of France?": "Paris",
    "How many days are in a week?": "7",
    "What gas do plants absorb?": "carbon dioxide",
    "What is the largest planet in our solar system?": "Jupiter",
    "What is H2O?": "water",
    "Who wrote Romeo and Juliet?": "Shakespeare",
    "What color is the sky on a clear day?": "blue",
    "What is 2 + 2?": "4",
}
qa_pairs = list(_QA_PROBES.items())

def evaluate_forgetting(model_obj, tokenizer_obj, questions, device):
    """Score a seq2seq model on short general-knowledge questions.

    Each question is wrapped in a ``"Question: ...\\nAnswer:"`` prompt and answered with
    deterministic beam search (2 beams, at most 50 tokens). Grading is lenient but
    whole-word based:

    * Numeric answers (e.g. ``"7"``) count as correct only if the prediction contains
      that number as a standalone number (``"7 days"``) or as an English number word
      (``"seven"``). Digits embedded in a larger number (the ``4`` in ``24``) do not count.
    * Text answers count as correct in three cases: the full answer appears as whole
      words in the prediction; any answer word longer than two characters does; or the
      non-empty prediction appears as whole words inside the answer (``"carbon"`` for
      ``"carbon dioxide"``). An empty prediction is never correct.

    Args:
        model_obj: model exposing ``eval()`` and ``generate()``; it is switched to eval mode.
        tokenizer_obj: tokenizer used to encode prompts and decode generations.
        questions: sequence of ``(question, expected_answer)`` string pairs.
        device: device the prompt token ids are moved to before generation.

    Returns:
        ``(accuracy, results)``. ``accuracy`` is the fraction answered correctly (0.0
        when ``questions`` is empty). ``results`` is a list of per-question dicts with
        keys ``"Question"``, ``"Expected"``, ``"Predicted"`` and ``"Correct"``.
    """
    import re

    # English words a model may emit instead of digits for small numeric answers.
    number_words = {
        "0": "zero", "1": "one", "2": "two", "3": "three", "4": "four", "5": "five",
        "6": "six", "7": "seven", "8": "eight", "9": "nine", "10": "ten",
    }

    def contains_phrase(text, phrase):
        # Whole-word containment: `phrase` must not be glued to other letters/digits,
        # so "4" does not match inside "24", and an empty phrase never matches.
        if not phrase:
            return False
        return re.search(r"(?<!\w)" + re.escape(phrase) + r"(?!\w)", text) is not None

    model_obj.eval()
    correct = 0
    results = []

    print("--- Forgetting Analysis ---")
    for q, ans in questions:
        # FLAN-T5 prompt format based on AI guidance; see [5]
        prompt = f"Question: {q}\nAnswer:"
        input_ids = tokenizer_obj(prompt, return_tensors="pt", max_length=128, truncation=True).input_ids.to(device)

        with torch.no_grad():
            outputs = model_obj.generate(
                input_ids,
                max_length=50,
                num_beams=2,
                early_stopping=True,
                do_sample=False,
            )

        pred = tokenizer_obj.decode(outputs[0], skip_special_tokens=True).strip()
        pred_lower = pred.lower()
        ans_lower = ans.lower().strip()

        if ans_lower.isdecimal():
            # Numeric answer: compare whole numbers (or number words), never substrings.
            numbers_in_pred = {int(n) for n in re.findall(r"\d+", pred_lower)}
            word = number_words.get(str(int(ans_lower)))
            is_correct = int(ans_lower) in numbers_in_pred or (
                word is not None and contains_phrase(pred_lower, word)
            )
        else:
            is_correct = (
                contains_phrase(pred_lower, ans_lower)
                or contains_phrase(ans_lower, pred_lower)
                or any(contains_phrase(pred_lower, w) for w in ans_lower.split() if len(w) > 2)
            )

        if is_correct:
            correct += 1

        results.append({"Question": q, "Expected": ans, "Predicted": pred, "Correct": is_correct})
        print(f"Q: {q}")
        print(f"  Expected: {ans} | Predicted: {pred} | {'✓' if is_correct else '✗'}")

    n_questions = len(questions)
    # Guard against an empty question list instead of dividing by zero.
    accuracy = correct / n_questions if n_questions else 0.0
    print(f"\nAccuracy: {accuracy:.2%} ({correct}/{n_questions})")
    return accuracy, results

print("=" * 80)
print("STEP 5: Evaluating Base Model on QA set (Before Training)...")
print("=" * 80)
print(f"Number of QA pairs: {len(qa_pairs)}")
print(f"Device: {device}")
print(f"Base model on device: {next(base_model.parameters()).device}")

# Baseline accuracy on the probes; failures are reported with a traceback and re-raised.
try:
    base_qa_acc, base_qa_results = evaluate_forgetting(base_model, tokenizer, qa_pairs, device)
    print("\n✓ Base model evaluation complete!")
    print("=" * 80)
except Exception as err:
    import traceback

    print(f"\n[ERROR] Base model evaluation failed: {err}")
    traceback.print_exc()
    raise


## 6. Fine-Tuning

We use `Seq2SeqTrainer` to fine-tune the model.


In [None]:
# ROUGE metric implementation based on AI guidance; see [2]
rouge = evaluate.load("rouge")
print("✓ ROUGE metric loaded")

# Environment diagnostics (useful when chasing NaN / zero losses on GPU).
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    print(f"GPU capability: {torch.cuda.get_device_capability(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    # Native BF16 arithmetic requires compute capability 8.0+ (Ampere or newer).
    print(f"Supports BF16: {torch.cuda.get_device_capability()[0] >= 8}")
print("=" * 80)


def compute_metrics(eval_pred):
    """Compute ROUGE (as percentages) and the mean generated length for Seq2SeqTrainer.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays produced with
            ``predict_with_generate=True``. Label positions equal to -100 are ignored
            (mapped to the pad token before decoding).

    Returns:
        dict holding the ROUGE keys from ``rouge.compute`` scaled to percentages (rounded
        to 4 decimals), plus ``gen_len``. ``gen_len`` is the average number of non-pad
        tokens before the first EOS, rounded to 2 decimals. On any failure the traceback
        is printed and every metric is reported as 0.0, so training continues.
    """
    try:
        predictions, labels = eval_pred
        predictions = np.array(predictions)
        labels = np.array(labels)

        # Clip predictions to valid token ID range (0 to vocab_size-1) - fix for OverflowError; see [3]
        # (negative padding values such as -100 are clipped up to 0).
        vocab_size = tokenizer.vocab_size
        predictions = np.clip(predictions, 0, vocab_size - 1)
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

        # Replace -100 (ignored labels) with pad_token_id for decoding
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        labels = np.clip(labels, 0, vocab_size - 1)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

        # Generation length: non-padding tokens up to (not including) the first EOS.
        eos_id, pad_id = tokenizer.eos_token_id, tokenizer.pad_token_id
        gen_lens = []
        for row in predictions.tolist():
            length = 0
            for token_id in row:
                if token_id == eos_id:
                    break
                if token_id != pad_id:
                    length += 1
            gen_lens.append(length)
        avg_gen_len = float(np.mean(gen_lens)) if gen_lens else 0.0

        # ROUGE scores are fractions -> percentages; gen_len is a token count and is not scaled.
        final_result = {k: round(v * 100, 4) for k, v in result.items()}
        final_result["gen_len"] = round(avg_gen_len, 2)
        return final_result
    except Exception:
        # Best-effort: a metric failure should not abort training, but it must be visible,
        # because zeros here also affect best-checkpoint selection on "rouge1".
        import traceback
        traceback.print_exc()
        print("[WARN] compute_metrics failed; reporting all metrics as 0.0")
        return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0, "rougeLsum": 0.0, "gen_len": 0.0}


# Create data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Sanity check 1: collate two training examples and inspect the label padding.
sample_batch = [tokenized_train[i] for i in range(2)]
collated_batch = data_collator(sample_batch)
print(f"Batch keys: {list(collated_batch.keys())}")
print(f"Input IDs shape: {collated_batch['input_ids'].shape}")
print(f"Labels shape: {collated_batch['labels'].shape}")
print(f"\nSample input IDs (first 20): {collated_batch['input_ids'][0][:20].tolist()}")
print(f"Sample labels (first 20): {collated_batch['labels'][0][:20].tolist()}")

# Count how many labels are NOT -100 (i.e., actual labels vs padding)
labels_array = collated_batch['labels'].numpy()
non_ignore_labels = np.sum(labels_array != -100)
total_labels = labels_array.size
ignore_ratio = (total_labels - non_ignore_labels) / total_labels
print(f"\nLabel statistics:")
print(f"  Total label positions: {total_labels}")
print(f"  Non-ignore labels (not -100): {non_ignore_labels}")
print(f"  Ignore labels (-100): {total_labels - non_ignore_labels}")
print(f"  Ignore ratio: {ignore_ratio:.2%}")
if ignore_ratio > 0.95:
    print("   This will cause very low or zero loss!")

# Sanity check 2: try a forward pass to see actual loss
print("\n" + "=" * 80)
model.eval()
with torch.no_grad():
    # Move batch to device
    batch_device = {k: v.to(device) for k, v in collated_batch.items()}
    outputs = model(**batch_device)
    loss = outputs.loss
    loss_value = loss.item()
    print(f"Forward pass loss: {loss_value:.6f}")

    if np.isnan(loss_value):
        print("   This indicates a GPU-specific numerical instability issue.")
        print(f"   Model device: {next(model.parameters()).device}")
        print(f"   Batch device: {batch_device['input_ids'].device}")
        print(f"   Model dtype: {next(model.parameters()).dtype}")
    elif loss_value == 0.0:
        print("   This confirms the model is not computing loss correctly.")
        print("   The issue is likely with label preparation.")
    else:
        print(f"✓ Forward pass loss looks good: {loss_value:.6f} - Model should learn!")
model.train()  # restore training mode after the no-grad sanity check
print("=" * 80)

print("=" * 80)
print("STEP 6: Setting up training...")
print("=" * 80)

# GPU-optimized training arguments with BF16 support: BF16 mixed precision is enabled
# only on GPUs with native BF16 (compute capability >= 8.0). FP16 stays off.
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
args = Seq2SeqTrainingArguments(
    output_dir="./flan-t5-summarizer",
    eval_strategy="epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=1,
    num_train_epochs=2,
    predict_with_generate=True,
    generation_max_length=MAX_TARGET_LENGTH,
    generation_num_beams=4,
    fp16=False,  # Don't use FP16
    bf16=use_bf16,  # previously computed but never passed, so BF16 was always off
    logging_steps=100,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="rouge1",
)
print(f"Training configuration:")
print(f"  Epochs: {args.num_train_epochs}")
print(f"  Batch size: {args.per_device_train_batch_size}")
print(f"  Learning rate: {args.learning_rate}")
print(f"  Generation max length: {args.generation_max_length}")
print(f"  Generation num beams: {args.generation_num_beams}")
print(f"  BF16: {args.bf16} (GPU-optimized for {'modern' if use_bf16 else 'older'} GPU)")
print(f"  FP16: {args.fp16}")
print(f"  Output dir: {args.output_dir}")

print("\nCreating trainer...")
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
print(f"✓ Trainer created")
print(f"  Train samples: {len(tokenized_train)}")
print(f"  Eval samples: {len(tokenized_val)}")

print("\n" + "=" * 80)
print("Starting training...")
print("=" * 80)
trainer.train()

## 7. Evaluation: Summarization Quality

Compare ROUGE scores and look at qualitative examples.


In [None]:
print("=" * 80)
print("STEP 7: Evaluating on test set...")
print("=" * 80)
print(f"Test samples: {len(tokenized_test)}")

# Best-effort: a failure is reported (with traceback) but does not stop later cells.
try:
    test_results = trainer.evaluate(tokenized_test)
    print("\n✓ Test evaluation complete!")
    print("\nTest Results:")
    for metric_name, metric_value in test_results.items():
        shown = f"{metric_value:.4f}" if isinstance(metric_value, float) else metric_value
        print(f"  {metric_name}: {shown}")
    print("=" * 80)
except Exception as err:
    import traceback

    print(f"\n[ERROR] Test evaluation failed: {err}")
    traceback.print_exc()


In [None]:
# Qualitative Comparison: base vs. fine-tuned summaries on a few test reviews
for banner_line in ("=" * 80, "STEP 8: Qualitative Comparison...", "=" * 80):
    print(banner_line)

def generate_summary(model_obj, text, device):
    """Summarize one review with 4-beam search, at most MAX_TARGET_LENGTH tokens.

    The review is prefixed with PREFIX and truncated to MAX_INPUT_LENGTH tokens, matching
    the fine-tuning inputs. Any error is printed and returned as an "[ERROR: ...]" string,
    so one bad example does not abort the comparison loop.
    """
    try:
        encoded = tokenizer(
            PREFIX + text,
            return_tensors="pt",
            max_length=MAX_INPUT_LENGTH,
            truncation=True,
        ).to(device)
        generated = model_obj.generate(encoded.input_ids, max_length=MAX_TARGET_LENGTH, num_beams=4)
        summary = tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as err:
        print(f"[ERROR] generate_summary failed: {err}")
        return f"[ERROR: {str(err)}]"
    return summary

sample_indices = [0, 5, 10, 15, 20]
total = len(sample_indices)
print(f"Comparing {total} examples from test set...")
print("=" * 80)

for position, idx in enumerate(sample_indices, start=1):
    # Each example is handled independently; a failure is logged and the loop moves on.
    try:
        print(f"\nExample {position}/{total} (Index {idx}):")
        example = test_ds[idx]
        text, ref_summary = example["Text"], example["Summary"]

        print("  Generating base model summary...")
        base_summary = generate_summary(base_model, text, device)

        print("  Generating fine-tuned model summary...")
        ft_summary = generate_summary(model, text, device)

        print(f"\n  Review: {text[:200]}...")
        for heading, shown in (
            ("Reference", ref_summary),
            ("Base Model", base_summary),
            ("Fine-Tuned", ft_summary),
        ):
            print(f"  {heading}: {shown}")
        print("-" * 80)
    except Exception as err:
        import traceback

        print(f"[ERROR] Failed to process example {idx}: {err}")
        traceback.print_exc()

print("\n✓ Qualitative comparison complete!")
print("=" * 80)


## 8. Forgetting Analysis (After Training)

Check if the fine-tuned model has forgotten general knowledge.


In [None]:
print("=" * 80)
print("STEP 9: Evaluating Fine-Tuned Model on QA set (After Training)...")
print("=" * 80)
print(f"Number of QA pairs: {len(qa_pairs)}")
print(f"Device: {device}")
print(f"Fine-tuned model on device: {next(model.parameters()).device}")

try:
    ft_qa_acc, ft_qa_results = evaluate_forgetting(model, tokenizer, qa_pairs, device)
    print("\n✓ Fine-tuned model evaluation complete!")

    print("\n" + "=" * 80)
    print("FORGETTING ANALYSIS SUMMARY")
    print("=" * 80)
    n_questions = len(qa_pairs)
    for model_label, qa_accuracy in (("Base Model", base_qa_acc), ("Fine-Tuned Model", ft_qa_acc)):
        print(f"{model_label} QA Accuracy: {qa_accuracy:.2%} ({qa_accuracy * n_questions:.0f}/{n_questions})")

    # Negative diff = the fine-tuned model answers fewer probes correctly (forgetting).
    diff = ft_qa_acc - base_qa_acc
    print(f"Change in Accuracy: {diff:+.2%}")

    if diff < 0:
        verdict = f"⚠️  Forgetting detected! Model lost {abs(diff):.2%} accuracy on general knowledge."
    elif diff > 0:
        verdict = f"✓ Model improved by {diff:.2%} (unexpected but good!)"
    else:
        verdict = "→ No change in general knowledge performance."
    print(verdict)

    print("=" * 80)
except Exception as err:
    import traceback

    print(f"\n[ERROR] Fine-tuned model evaluation failed: {err}")
    traceback.print_exc()
    raise


## 9. Save Model

Save the fine-tuned model to be downloaded.


In [None]:
FINAL_MODEL_DIR = "./finetuned_summarizer_final"

# Write the trainer's model and the tokenizer files into the same directory.
trainer.save_model(FINAL_MODEL_DIR)
tokenizer.save_pretrained(FINAL_MODEL_DIR)

print(f"Model saved to {FINAL_MODEL_DIR}")
# To download from Colab:
# from google.colab import files
# !zip -r model.zip ./finetuned_summarizer_final
# files.download('model.zip')


# Project 3: Fine-Tuning FLAN-T5 for Summarization & Measuring Forgetting

**Authors:** Shaunak Kapur & Pranav Krishnan

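This notebook implements the Project 3 proposal: fine-tuning a small language model (`google/flan-t5-small`) on the Amazon Fine Food Reviews dataset to generate product review summaries. It also evaluates "forgetting" by checking the model's performance on a set of general knowledge questions before and after fine-tuning.

*Note:* this is a second copy of the notebook. It reads the data from a local `Reviews.csv` file instead of the Hugging Face Hub, and running its cells redefines the variables and functions created by the cells above.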


## 1. Setup and Installation

Installing required libraries: `transformers`, `datasets`, `evaluate`, `rouge_score`, `accelerate`, `sentencepiece`.


In [None]:
!pip install -q transformers datasets evaluate rouge_score accelerate sentencepiece


In [None]:
import torch
import pandas as pd
import numpy as np
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
import evaluate

# The code below was generated by AI; see [2].
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")


## 2. Load and Preprocess Data

We use the Amazon Fine Food Reviews dataset. 
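**Important:** Upload `Reviews.csv` to the Colab runtime files (left sidebar) before running this cell. If the file is missing, the cell prints an error and falls back to a tiny synthetic dataset (three example reviews repeated 100 times). The rest of the notebook still runs, but its results are meaningless.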

We will:
1. Load the CSV.
2. Keep only the `Summary` and `Text` columns, drop rows with missing values, and discard reviews longer than 512 characters.
3. Sample the data (e.g., 20,000 rows) to keep training time reasonable.
4. Split into Train (80%), Validation (10%), and Test (10%).


In [None]:
# Load dataset
# The code below was generated by AI; see [2].
try:
    df = pd.read_csv("Reviews.csv")
except FileNotFoundError:
    print("Error: Reviews.csv not found. Please upload it to the Colab runtime.")
    # Placeholder rows so the notebook can still 'run' structurally when the file is missing;
    # any metrics computed on them are meaningless.
    placeholder_rows = [
        ("Great product", "This is a really great product I loved it."),
        ("Not good", "This was terrible do not buy."),
        ("Okay item", "It was just okay nothing special."),
    ] * 100
    df = pd.DataFrame(placeholder_rows, columns=["Summary", "Text"])

# Keep the two columns we use, drop incomplete rows, and skip reviews over 512
# characters to bound memory/time.
df = (
    df[["Summary", "Text"]]
    .dropna()
    .loc[lambda frame: frame["Text"].str.len() <= 512]
)

# Sample data for faster training (adjust as needed)
SAMPLE_SIZE = 20000
if len(df) > SAMPLE_SIZE:
    df = df.sample(n=SAMPLE_SIZE, random_state=42)

print(f"Dataset size: {len(df)}")
df.head()


In [None]:
from sklearn.model_selection import train_test_split

# The code below was generated by AI; see [2].
# 80/10/10 split: carve off 20%, then divide that hold-out evenly into val and test.
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

train_ds, val_ds, test_ds = [
    Dataset.from_pandas(part.reset_index(drop=True)) for part in (train_df, val_df, test_df)
]

print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")


## 3. Model and Tokenizer Setup

We use `google/flan-t5-small`. We load two copies:
1. `base_model`: Keeps original weights to measure baseline performance and forgetting.
2. `model`: Will be fine-tuned.


In [None]:
MODEL_NAME = "google/flan-t5-small"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Two independent copies of the same checkpoint: `model` is passed to the Trainer,
# while `base_model` keeps the original weights for before/after comparisons.
model, base_model = (AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME) for _ in range(2))
base_model.to(device)
print("Models loaded.")


## 4. Tokenization

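Each review is prefixed with "Summarize this review: " and tokenized (truncated to 256 tokens); each reference summary is tokenized as the target (truncated to 32 tokens).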


In [None]:
# Token budgets for the encoder input and the decoder target, plus the task prompt.
MAX_INPUT_LENGTH, MAX_TARGET_LENGTH = 256, 32
PREFIX = "Summarize this review: "

def preprocess_function(examples):
    """Build seq2seq features for one batch: prefixed review ids plus summary ``labels``.

    Reads the "Text" and "Summary" columns of a batched `Dataset.map` call and truncates
    them to MAX_INPUT_LENGTH and MAX_TARGET_LENGTH tokens respectively.
    """
    source_texts = list(map(lambda review: PREFIX + review, examples["Text"]))
    features = tokenizer(source_texts, max_length=MAX_INPUT_LENGTH, truncation=True)
    summary_encoding = tokenizer(
        text_target=examples["Summary"],
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
    )
    features["labels"] = summary_encoding["input_ids"]
    return features

# Tokenize every split with the same batched preprocessing.
tokenized_train, tokenized_val, tokenized_test = (
    split.map(preprocess_function, batched=True) for split in (train_ds, val_ds, test_ds)
)


## 5. Forgetting Analysis (Before Training)

We define a small set of general knowledge questions to test the "forgetting" hypothesis. We check how well the base model answers them.


In [None]:
# Question -> expected answer (dict insertion order defines the evaluation order).
_qa_answers = {
    "What is the capital of France?": "Paris",
    "How many days are in a week?": "7",
    "What gas do plants absorb?": "carbon dioxide",
    "What is the largest planet in our solar system?": "Jupiter",
    "What is H2O?": "water",
    "Who wrote Romeo and Juliet?": "Shakespeare",
    "What color is the sky on a clear day?": "blue",
    "What is 2 + 2?": "4",
}
qa_pairs = [(question, answer) for question, answer in _qa_answers.items()]

def evaluate_forgetting(model_obj, tokenizer_obj, questions, device):
    """Score a seq2seq model on short general-knowledge questions.

    Each question is prompted as ``"Answer the question: <q>"`` and answered greedily
    with at most 20 tokens. Grading uses whole-word matching:

    * numeric answers (e.g. ``"4"``) are correct only if the prediction contains that
      number as a standalone number or as an English number word (``"four"``);
      ``"24"`` does NOT count as ``"4"``;
    * text answers are correct if the answer appears as whole words in the prediction
      (``"blue"`` matches ``"the sky is blue"`` but not ``"blueberry"``).

    Args:
        model_obj: model exposing ``eval()`` and ``generate()``; it is switched to eval mode.
        tokenizer_obj: tokenizer used to encode prompts and decode generations.
        questions: sequence of ``(question, expected_answer)`` string pairs.
        device: device the prompt token ids are moved to before generation.

    Returns:
        ``(accuracy, results)``. ``accuracy`` is the fraction correct (0.0 for an empty
        ``questions``). ``results`` holds per-question dicts with keys "Question",
        "Expected", "Predicted", "Correct".
    """
    import re

    # English words a model may emit instead of digits for small numeric answers.
    number_words = {
        "0": "zero", "1": "one", "2": "two", "3": "three", "4": "four", "5": "five",
        "6": "six", "7": "seven", "8": "eight", "9": "nine", "10": "ten",
    }

    def mentions(text, phrase):
        # Whole-word containment; an empty phrase never matches.
        if not phrase:
            return False
        return re.search(r"(?<!\w)" + re.escape(phrase) + r"(?!\w)", text) is not None

    model_obj.eval()
    correct = 0
    results = []

    print("--- Forgetting Analysis ---")
    for q, ans in questions:
        input_ids = tokenizer_obj("Answer the question: " + q, return_tensors="pt").input_ids.to(device)

        with torch.no_grad():
            outputs = model_obj.generate(input_ids, max_length=20)

        pred = tokenizer_obj.decode(outputs[0], skip_special_tokens=True)
        pred_lower = pred.lower()
        ans_lower = ans.lower().strip()

        if ans_lower.isdecimal():
            # Numeric answer: compare whole numbers (or number words), never substrings.
            numbers_in_pred = {int(n) for n in re.findall(r"\d+", pred_lower)}
            word = number_words.get(str(int(ans_lower)))
            is_correct = int(ans_lower) in numbers_in_pred or (
                word is not None and mentions(pred_lower, word)
            )
        else:
            is_correct = mentions(pred_lower, ans_lower)

        if is_correct:
            correct += 1

        results.append({"Question": q, "Expected": ans, "Predicted": pred, "Correct": is_correct})
        print(f"Q: {q} | Pred: {pred} | Expected: {ans}")

    n_questions = len(questions)
    # Guard against an empty question list instead of dividing by zero.
    accuracy = correct / n_questions if n_questions else 0.0
    print(f"Accuracy: {accuracy:.2%}")
    return accuracy, results

print("Evaluating Base Model on QA set...")
# Baseline accuracy before fine-tuning; compared against the fine-tuned model later.
base_qa_acc, base_qa_results = evaluate_forgetting(
    model_obj=base_model,
    tokenizer_obj=tokenizer,
    questions=qa_pairs,
    device=device,
)


## 6. Fine-Tuning

We use `Seq2SeqTrainer` to fine-tune the model.


In [None]:
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    """Compute ROUGE (as percentages) and the mean generated length for Seq2SeqTrainer.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays (generated ids, since
            the trainer uses ``predict_with_generate=True``). Positions equal to -100 are
            treated as padding in both arrays.

    Returns:
        dict of ROUGE scores scaled to percentages (rounded to 4 decimals), plus
        ``gen_len``. ``gen_len`` is the average count of non-pad tokens per prediction,
        left unscaled because it is a token count rather than a fraction.
    """
    predictions, labels = eval_pred
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    pad_id = tokenizer.pad_token_id

    # Predictions may contain -100 padding; negative ids cannot be decoded
    # (OverflowError in batch_decode), so map them to the pad token first.
    predictions = np.where(predictions != -100, predictions, pad_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, pad_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    metrics = {k: round(v * 100, 4) for k, v in result.items()}

    # gen_len must NOT be multiplied by 100 like the ROUGE fractions.
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    metrics["gen_len"] = round(float(np.mean(prediction_lens)), 4) if prediction_lens else 0.0
    return metrics

# Evaluate (and compute ROUGE via compute_metrics) on the validation split once per epoch.
args = Seq2SeqTrainingArguments(
    output_dir="./flan-t5-summarizer",
    # `evaluation_strategy` was renamed to `eval_strategy` (the name the GPU-optimized
    # training cell earlier in this notebook uses); recent transformers releases reject
    # the old keyword.
    eval_strategy="epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=1,
    num_train_epochs=2,
    predict_with_generate=True,
    # FP16 stays off, matching the GPU-optimized training cell: FP16 mixed precision is
    # widely reported to produce NaN losses with T5-family checkpoints.
    fp16=False,
    logging_steps=100,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# The code below was generated by AI; see [2].
trainer.train()


## 7. Evaluation: Summarization Quality

Compare ROUGE scores and look at qualitative examples.


In [None]:
print("Evaluating on Test Set...")
# Same metrics as validation (ROUGE + gen_len), computed on the held-out test split.
test_results = trainer.evaluate(eval_dataset=tokenized_test)
print(test_results)


In [None]:
# Qualitative Comparison
def generate_summary(model_obj, text, device):
    """Return a 4-beam summary of ``text`` using the fine-tuning prompt and length limits."""
    batch = tokenizer(PREFIX + text, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True)
    batch = batch.to(device)
    summary_ids = model_obj.generate(batch.input_ids, max_length=MAX_TARGET_LENGTH, num_beams=4)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

sample_indices = [0, 5, 10, 15, 20]
print("--- Qualitative Results ---\n")

for idx in sample_indices:
    record = test_ds[idx]
    text, ref_summary = record["Text"], record["Summary"]

    base_summary = generate_summary(base_model, text, device)
    ft_summary = generate_summary(model, text, device)

    print(f"Review: {text[:200]}...")
    for heading, shown in (
        ("Reference", ref_summary),
        ("Base Model", base_summary),
        ("Fine-Tuned", ft_summary),
    ):
        print(f"{heading}: {shown}")
    print("-" * 80)


## 8. Forgetting Analysis (After Training)

Check if the fine-tuned model has forgotten general knowledge.


In [None]:
print("Evaluating Fine-Tuned Model on QA set...")
ft_qa_acc, ft_qa_results = evaluate_forgetting(model, tokenizer, qa_pairs, device)

accuracy_lines = [
    f"\nBase Model QA Accuracy: {base_qa_acc:.2%}",
    f"Fine-Tuned Model QA Accuracy: {ft_qa_acc:.2%}",
]
print(*accuracy_lines, sep="\n")

# Negative values mean the fine-tuned model answers fewer probes correctly (forgetting).
diff = ft_qa_acc - base_qa_acc
print(f"Change in Accuracy: {diff:.2%}")


## 9. Save Model

Save the fine-tuned model to be downloaded.


In [None]:
FINAL_MODEL_DIR = "./finetuned_summarizer_final"

# Model (via the Trainer) and tokenizer files go into the same directory.
for save_to in (trainer.save_model, tokenizer.save_pretrained):
    save_to(FINAL_MODEL_DIR)

print(f"Model saved to {FINAL_MODEL_DIR}")
# To download from Colab:
# from google.colab import files
# !zip -r model.zip ./finetuned_summarizer_final
# files.download('model.zip')
