# Project 3: Fine-Tuning FLAN-T5 for Summarization & Measuring Forgetting

**Authors:** Shaunak Kapur & Pranav Krishnan

This notebook implements the Project 3 proposal: fine-tuning a small language model (`google/flan-t5-small`) on the Amazon Fine Food Reviews dataset to generate product review summaries. It also evaluates "forgetting" by checking the model's performance on a set of general knowledge questions before and after fine-tuning.


## 1. Setup and Installation

Install the required libraries: `transformers`, `datasets`, `evaluate`, `rouge_score`, `accelerate`, `sentencepiece`.


In [None]:
!pip install -q transformers datasets evaluate rouge_score accelerate sentencepiece


  Preparing metadata (setup.py) ... done
  Building wheel for rouge_score (setup.py) ... done


In [None]:
import torch
import pandas as pd
import numpy as np
from datasets import Dataset, load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
import evaluate

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


Using device: cuda


## 2. Load and Preprocess Data

We use the Amazon Fine Food Reviews dataset, downloaded from the Hugging Face Hub (`jhan21/amazon-food-reviews-dataset`) with `load_dataset`.

We will:
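1. Download the dataset from Hugging Face.
2. Convert it to a pandas DataFrame.
3. Keep only the `Summary` and `Text` columns and drop rows with missing values.
4. Drop reviews longer than 512 characters to save memory and training time.
5. Randomly sample 20,000 rows to keep training time reasonable.
6. Split into train (80%), validation (10%), and test (10%) sets.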
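The steps above can be sketched in plain Python on a list of `(summary, text)` tuples. This is a toy stand-in, not the notebook's code: the notebook uses pandas (`dropna`, `sample`) and scikit-learn's `train_test_split`, whose shuffling differs, so the exact rows selected will not match. The function name `prepare_splits` is hypothetical.

```python
import random


def prepare_splits(rows, max_chars=512, sample_size=20000, seed=42):
    """Filter, sample and split (summary, text) pairs 80/10/10.

    Hypothetical toy stand-in for the pandas + scikit-learn cell:
    `rows` is a list of (summary, text) tuples, where None marks a missing value.
    """
    # Drop rows with a missing field, then drop overly long reviews
    kept = [(s, t) for s, t in rows if s is not None and t is not None]
    kept = [(s, t) for s, t in kept if len(t) <= max_chars]

    rng = random.Random(seed)
    # Sample without replacement only if there are more rows than needed
    if len(kept) > sample_size:
        kept = rng.sample(kept, sample_size)

    # Shuffle once, then cut at the 80% and 90% marks
    rng.shuffle(kept)
    n = len(kept)
    train_end, val_end = int(n * 0.8), int(n * 0.9)
    return kept[:train_end], kept[train_end:val_end], kept[val_end:]


rows = [(f"summary {i}", "x" * (i % 600)) for i in range(1000)] + [(None, "t"), ("s", None)]
train, val, test = prepare_splits(rows, sample_size=100)
print(len(train), len(val), len(test))  # 80 10 10
```

The notebook splits twice instead (80/20, then 50/50 on the held-out 20%); for 20,000 rows both approaches give 16,000 / 2,000 / 2,000.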


In [None]:
# Load dataset from Hugging Face
print("=" * 80)
print("STEP 1: Downloading dataset from Hugging Face...")
print("=" * 80)
ds = load_dataset("jhan21/amazon-food-reviews-dataset")
print(f"✓ Dataset loaded. Available splits: {list(ds.keys())}")

# Convert to pandas DataFrame (the dataset has a 'train' split)
print("\nConverting to pandas DataFrame...")
df = ds["train"].to_pandas()
print(f"✓ Original dataset size: {len(df)} rows, {len(df.columns)} columns")
print(f"  Columns: {list(df.columns)}")

# Keep relevant columns and drop NaNs
print("\nFiltering data...")
print(f"  Before filtering: {len(df)} rows")
df = df[["Summary", "Text"]].dropna()
print(f"  After dropping NaN: {len(df)} rows")

# Filter out very long reviews to save memory/time
df = df[df["Text"].str.len() <= 512]
print(f"  After filtering long reviews (<=512 chars): {len(df)} rows")

# Sample data for faster training (adjust as needed)
SAMPLE_SIZE = 20000
if len(df) > SAMPLE_SIZE:
    print(f"\nSampling {SAMPLE_SIZE} rows from {len(df)} total rows...")
    df = df.sample(SAMPLE_SIZE, random_state=42)
    print(f"✓ Sampled dataset size: {len(df)} rows")
else:
    print(f"\nUsing full dataset: {len(df)} rows")

print("\nSample data preview:")
print(df.head(3))
print(f"\n✓ Dataset ready: {len(df)} rows")
print("=" * 80)


STEP 1: Downloading dataset from Hugging Face...



✓ Dataset loaded. Available splits: ['train']

Converting to pandas DataFrame...
✓ Original dataset size: 568454 rows, 10 columns
  Columns: ['Id', 'ProductId', 'UserId', 'ProfileName', 'HelpfulnessNumerator', 'HelpfulnessDenominator', 'Score', 'Time', 'Summary', 'Text']

Filtering data...
  Before filtering: 568454 rows
  After dropping NaN: 568427 rows
  After filtering long reviews (<=512 chars): 420179 rows

Sampling 20000 rows from 420179 total rows...
✓ Sampled dataset size: 20000 rows

Sample data preview:
                 Summary                                               Text
310404   Very tasty bars  There are many varieties of bars on the market...
74819   Keep on Munching  My puppy loves this product.  The moment he ha...
271575       cody cramer  These meatballs look so delicious I wanted to ...

✓ Dataset ready: 20000 rows


In [None]:
from sklearn.model_selection import train_test_split

print("=" * 80)
print("STEP 2: Splitting dataset into train/val/test...")
print("=" * 80)

train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

print(f"✓ Split complete:")
print(f"  Train: {len(train_df)} rows (80%)")
print(f"  Validation: {len(val_df)} rows (10%)")
print(f"  Test: {len(test_df)} rows (10%)")

train_ds = Dataset.from_pandas(train_df.reset_index(drop=True))
val_ds = Dataset.from_pandas(val_df.reset_index(drop=True))
test_ds = Dataset.from_pandas(test_df.reset_index(drop=True))

print(f"\n✓ Datasets created:")
print(f"  Train: {len(train_ds)} samples")
print(f"  Val: {len(val_ds)} samples")
print(f"  Test: {len(test_ds)} samples")
print("=" * 80)


STEP 2: Splitting dataset into train/val/test...
✓ Split complete:
  Train: 16000 rows (80%)
  Validation: 2000 rows (10%)
  Test: 2000 rows (10%)

✓ Datasets created:
  Train: 16000 samples
  Val: 2000 samples
  Test: 2000 samples


## 3. Model and Tokenizer Setup

We use `google/flan-t5-small`. We load two copies:
1. `base_model`: Keeps original weights to measure baseline performance and forgetting.
2. `model`: Will be fine-tuned.
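Loading two copies roughly doubles the weight memory. A back-of-the-envelope estimate using the parameter count printed below (76,961,152): at 4 bytes per FP32 parameter, one copy is about 0.31 GB, consistent with the ~308 MB `model.safetensors` download. The estimate for the fine-tuned copy assumes the default AdamW optimizer (two FP32 moment buffers per parameter) plus one gradient per parameter, and ignores activations, the CUDA context, and mixed-precision buffers.

```python
# Parameter count printed by the model-loading cell
N_PARAMS = 76_961_152
BYTES_FP32 = 4


def gigabytes(n_bytes):
    """Convert bytes to decimal gigabytes (1 GB = 1e9 bytes)."""
    return n_bytes / 1e9


weights = N_PARAMS * BYTES_FP32        # one FP32 copy of the weights
# Assumption: AdamW keeps two FP32 moment tensors per parameter, and the
# fine-tuned copy also holds one FP32 gradient per parameter.
trained_copy = weights * (1 + 1 + 2)   # weights + grads + exp_avg + exp_avg_sq
frozen_copy = weights                  # base_model: weights only, never trained

print(f"one copy of weights:   {gigabytes(weights):.2f} GB")                      # 0.31 GB
print(f"fine-tuned copy+state: {gigabytes(trained_copy):.2f} GB")                 # 1.23 GB
print(f"total (both models):   {gigabytes(trained_copy + frozen_copy):.2f} GB")   # 1.54 GB
```

Even with this rough accounting, both copies fit comfortably on a 16 GB GPU, so keeping a frozen `base_model` for the forgetting comparison is cheap at this model size.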


In [None]:
# Model selection based on AI recommendation; see [1]
MODEL_NAME = "google/flan-t5-small"

print("=" * 80)
print("STEP 3: Loading model and tokenizer...")
print("=" * 80)
print(f"Model: {MODEL_NAME}")
print(f"Device: {device}")

print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print(f"✓ Tokenizer loaded. Vocab size: {tokenizer.vocab_size}")
print(f"  Pad token: {tokenizer.pad_token_id}, EOS token: {tokenizer.eos_token_id}")

print("\nLoading model for fine-tuning (GPU-optimized)...")
# GPU-OPTIMIZED: Load model directly on GPU with explicit dtype
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # Explicit FP32 dtype
    device_map="auto"  # Auto-place model on GPU (no CPU intermediate step)
)
model.train()  # Explicit training mode
print(f"✓ Model loaded. Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"  Model device: {next(model.parameters()).device}")
print(f"  Model dtype: {next(model.parameters()).dtype}")

print("\nLoading base model for comparison...")
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    device_map="auto"
)
base_model.eval()  # Set to eval mode (won't be trained)
print(f"✓ Base model loaded")
print(f"  Base model device: {next(base_model.parameters()).device}")
print("=" * 80)

STEP 3: Loading model and tokenizer...
Model: google/flan-t5-small
Device: cuda

Loading tokenizer...


✓ Tokenizer loaded. Vocab size: 32100
  Pad token: 0, EOS token: 1

Loading model for fine-tuning (GPU-optimized)...


`torch_dtype` is deprecated! Use `dtype` instead!

✓ Model loaded. Parameters: 76,961,152
  Model device: cuda:0
  Model dtype: torch.float32

Loading base model for comparison...
✓ Base model loaded
  Base model device: cuda:0


## 4. Tokenization

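Each review is prepended with the instruction prefix `"Summarize this review: "` and truncated to 256 tokens; the reference summaries become the labels, truncated to 32 tokens.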
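Labels are left at their natural length by `preprocess_function`; padding happens later in `DataCollatorForSeq2Seq`, which by default pads labels with -100 (`label_pad_token_id`), the value PyTorch's cross-entropy loss ignores. A plain-Python sketch of that padding (the helpers `pad_labels` and `ignore_ratio` are illustrative, not library functions):

```python
IGNORE_INDEX = -100  # label value that the cross-entropy loss skips


def pad_labels(batch_labels, pad_value=IGNORE_INDEX):
    """Right-pad each label sequence to the longest one in the batch."""
    longest = max(len(seq) for seq in batch_labels)
    return [seq + [pad_value] * (longest - len(seq)) for seq in batch_labels]


def ignore_ratio(padded):
    """Fraction of label positions that the loss will ignore."""
    total = sum(len(seq) for seq in padded)
    ignored = sum(tok == IGNORE_INDEX for seq in padded for tok in seq)
    return ignored / total


# First row reuses the label IDs printed by the batch diagnostic in Section 6;
# the second row is a made-up 6-token summary.
batch = pad_labels([[1651, 805, 55, 1], [100, 200, 300, 400, 500, 1]])
print(batch[0])                      # [1651, 805, 55, 1, -100, -100]
print(f"{ignore_ratio(batch):.2%}")  # 16.67%
```

This reproduces the 16.67% ignore ratio reported by that diagnostic, where one label row carried two -100 entries.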


In [None]:
MAX_INPUT_LENGTH = 256
MAX_TARGET_LENGTH = 32
PREFIX = "Summarize this review: "

print("=" * 80)
print("STEP 4: Tokenizing datasets...")
print("=" * 80)
print(f"Max input length: {MAX_INPUT_LENGTH} tokens")
print(f"Max target length: {MAX_TARGET_LENGTH} tokens")
print(f"Prefix: '{PREFIX}'")

def preprocess_function(examples):
    inputs = [PREFIX + doc for doc in examples["Text"]]
    model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)

    labels = tokenizer(text_target=examples["Summary"], max_length=MAX_TARGET_LENGTH, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

print("\nTokenizing training set...")
tokenized_train = train_ds.map(preprocess_function, batched=True, remove_columns=train_ds.column_names)
print(f"✓ Train tokenized: {len(tokenized_train)} samples")
print(f"  Columns after tokenization: {tokenized_train.column_names}")

print("Tokenizing validation set...")
tokenized_val = val_ds.map(preprocess_function, batched=True, remove_columns=val_ds.column_names)
print(f"✓ Val tokenized: {len(tokenized_val)} samples")

print("Tokenizing test set...")
tokenized_test = test_ds.map(preprocess_function, batched=True, remove_columns=test_ds.column_names)
print(f"✓ Test tokenized: {len(tokenized_test)} samples")

# Show example
print("\nExample tokenized input:")
example = tokenized_train[0]
print(f"  Input IDs length: {len(example['input_ids'])}")
print(f"  Labels length: {len(example['labels'])}")
print(f"  Decoded input: {tokenizer.decode(example['input_ids'][:50])}...")
print(f"  Decoded label: {tokenizer.decode([l for l in example['labels'] if l != -100])}")
print("=" * 80)

STEP 4: Tokenizing datasets...
Max input length: 256 tokens
Max target length: 32 tokens
Prefix: 'Summarize this review: '

Tokenizing training set...


✓ Train tokenized: 16000 samples
  Columns after tokenization: ['input_ids', 'attention_mask', 'labels']
Tokenizing validation set...


✓ Val tokenized: 2000 samples
Tokenizing test set...


✓ Test tokenized: 2000 samples

Example tokenized input:
  Input IDs length: 68
  Labels length: 4
  Decoded input: Summarize this review: This is the second time I purchased this Fancy Feast Chunky Turkey Feast for my finicky cats. This product/flavor is no longer carried in the markets here, so I'm...
  Decoded label: Great buy!</s>


## 5. Forgetting Analysis (Before Training)

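To test the forgetting hypothesis, we use eight general-knowledge questions and record how many the base model answers correctly before fine-tuning. Answers are scored leniently: a prediction counts as correct if it contains the expected answer (or vice versa), shares a word of more than two characters with it, or, for numeric answers, contains the expected number.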
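The scoring rule can be factored into a standalone function and checked against the predictions recorded below. This sketch (`is_correct` is a hypothetical name) mirrors the matching logic in `evaluate_forgetting`, including a guard so that an empty generation is not counted as correct. The last example shows how lenient the rule is: a fragment such as "car" is accepted for "carbon dioxide".

```python
import re


def is_correct(pred: str, ans: str) -> bool:
    """Lenient match mirroring evaluate_forgetting, with an empty-output guard."""
    p, a = pred.lower().strip(), ans.lower().strip()
    if not p:
        return False  # "" is a substring of every string, so reject it explicitly
    match = (
        a in p
        or p in a
        or any(word in p for word in a.split() if len(word) > 2)
    )
    if ans.isdigit():
        # Numeric answers also match any standalone number in the output
        match = match or ans in re.findall(r"\d+", pred)
    return match


for pred, ans in [("7 days", "7"), ("london", "Paris"), ("2 + 2", "4"), ("car", "carbon dioxide")]:
    print(f"{pred!r} vs {ans!r}: {is_correct(pred, ans)}")
```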


In [None]:
# Forgetting analysis approach based on AI recommendation; see [4]
qa_pairs = [
    ("What is the capital of France?", "Paris"),
    ("How many days are in a week?", "7"),
    ("What gas do plants absorb?", "carbon dioxide"),
    ("What is the largest planet in our solar system?", "Jupiter"),
    ("What is H2O?", "water"),
    ("Who wrote Romeo and Juliet?", "Shakespeare"),
    ("What color is the sky on a clear day?", "blue"),
    ("What is 2 + 2?", "4")
]

import re


def evaluate_forgetting(model_obj, tokenizer_obj, questions, device):
    model_obj.eval()
    correct = 0
    results = []

    print("--- Forgetting Analysis ---")
    for q, ans in questions:
        # FLAN-T5 prompt format based on AI guidance; see [5]
        prompt = f"Question: {q}\nAnswer:"
        input_ids = tokenizer_obj(prompt, return_tensors="pt", max_length=128, truncation=True).input_ids.to(device)

        with torch.no_grad():
            outputs = model_obj.generate(
                input_ids,
                max_length=50,
                num_beams=2,
                early_stopping=True,
                do_sample=False
            )

        pred = tokenizer_obj.decode(outputs[0], skip_special_tokens=True).strip()

        # Lenient, case-insensitive answer matching
        pred_lower = pred.lower()
        ans_lower = ans.lower()

        # Substring match in either direction, or a shared word longer than 2 chars.
        # Guard against empty output: "" is a substring of every answer.
        is_correct = bool(pred_lower) and (
            ans_lower in pred_lower or
            pred_lower in ans_lower or
            any(word in pred_lower for word in ans_lower.split() if len(word) > 2)
        )

        # Numeric answers: also accept any standalone number in the prediction
        if ans.isdigit():
            numbers = re.findall(r'\d+', pred)
            is_correct = ans in numbers or is_correct

        if is_correct:
            correct += 1

        results.append({"Question": q, "Expected": ans, "Predicted": pred, "Correct": is_correct})
        print(f"Q: {q}")
        print(f"  Expected: {ans} | Predicted: {pred} | {'✓' if is_correct else '✗'}")

    accuracy = correct / len(questions)
    print(f"\nAccuracy: {accuracy:.2%} ({correct}/{len(questions)})")
    return accuracy, results

print("=" * 80)
print("STEP 5: Evaluating Base Model on QA set (Before Training)...")
print("=" * 80)
print(f"Number of QA pairs: {len(qa_pairs)}")
print(f"Device: {device}")
print(f"Base model on device: {next(base_model.parameters()).device}")

try:
    base_qa_acc, base_qa_results = evaluate_forgetting(base_model, tokenizer, qa_pairs, device)
    print(f"\n✓ Base model evaluation complete!")
    print("=" * 80)
except Exception as e:
    print(f"\n[ERROR] Base model evaluation failed: {e}")
    import traceback
    traceback.print_exc()
    raise


STEP 5: Evaluating Base Model on QA set (Before Training)...
Number of QA pairs: 8
Device: cuda
Base model on device: cuda:0
--- Forgetting Analysis ---
Q: What is the capital of France?
  Expected: Paris | Predicted: london | ✗
Q: How many days are in a week?
  Expected: 7 | Predicted: 7 days | ✓
Q: What gas do plants absorb?
  Expected: carbon dioxide | Predicted: helium | ✗
Q: What is the largest planet in our solar system?
  Expected: Jupiter | Predicted: venus | ✗
Q: What is H2O?
  Expected: water | Predicted: H2O | ✗
Q: Who wrote Romeo and Juliet?
  Expected: Shakespeare | Predicted: edward wilson | ✗
Q: What color is the sky on a clear day?
  Expected: blue | Predicted: blue | ✓
Q: What is 2 + 2?
  Expected: 4 | Predicted: 2 + 2 | ✗

Accuracy: 25.00% (2/8)

✓ Base model evaluation complete!


## 6. Fine-Tuning

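We fine-tune with `Seq2SeqTrainer` for 2 epochs (learning rate 2e-4, batch size 4), computing ROUGE on the validation set after each epoch with 4-beam generation and keeping the checkpoint with the best ROUGE-1. Before training, the cell also runs GPU, weight, and forward-pass diagnostics.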
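As a sanity check on the configuration, the number of optimizer steps and the learning-rate schedule can be worked out by hand. This sketch assumes the `Trainer` defaults that the notebook does not override: a linear decay schedule with no warmup, a single GPU, and no gradient accumulation. `linear_lr` is a hypothetical helper written to follow the formula of `transformers.get_linear_schedule_with_warmup`.

```python
import math

TRAIN_SAMPLES = 16_000
BATCH_SIZE = 4      # per_device_train_batch_size; assumes one GPU, no accumulation
EPOCHS = 2
PEAK_LR = 2e-4

steps_per_epoch = math.ceil(TRAIN_SAMPLES / BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS


def linear_lr(step, peak=PEAK_LR, total=total_steps, warmup=0):
    """Linear warmup (none by default), then linear decay to 0 at `total`."""
    if step < warmup:
        return peak * step / max(1, warmup)
    return peak * max(0.0, (total - step) / max(1, total - warmup))


print(steps_per_epoch, total_steps)   # 4000 8000
for step in (100, 200, 400):
    print(step, f"{linear_lr(step):.6f}")
```

The learning rates logged during training (0.000198, 0.000195 and 0.000190 at steps 100, 200 and 400) agree with this linear decay from 2e-4 over 8,000 steps to the printed precision.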


In [None]:
# ROUGE metric implementation based on AI guidance; see [2]
rouge = evaluate.load("rouge")
print("✓ ROUGE metric loaded")

# GPU DIAGNOSTICS - Check environment before training
print("\n" + "=" * 80)
print("GPU DIAGNOSTICS")
print("=" * 80)
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    print(f"GPU capability: {torch.cuda.get_device_capability(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
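    # Note: native BF16 needs compute capability >= 8.0 (Ampere or newer); a T4 (7.5)
    # has no BF16 tensor cores, so this ">= 7" check over-reports hardware support.
    print(f"Supports BF16: {torch.cuda.get_device_capability()[0] >= 7}")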
print("=" * 80)

# Check model for NaN weights
print("\nChecking model weights for NaN...")
has_nan = False
nan_params = []
for name, param in model.named_parameters():
    if torch.isnan(param).any():
        print(f"  ❌ NaN found in: {name}")
        has_nan = True
        nan_params.append(name)
if not has_nan:
    print("  ✓ No NaN in model weights")
else:
    print(f"\n  ⚠️  WARNING: Found NaN in {len(nan_params)} parameters!")
    print(f"  This will cause training to fail. Model needs to be reloaded.")
print(f"Model device: {next(model.parameters()).device}")
print(f"Model dtype: {next(model.parameters()).dtype}")
print("=" * 80)

def compute_metrics(eval_pred):
    try:
        predictions, labels = eval_pred
        print(f"\n[DEBUG] compute_metrics called - predictions shape: {np.array(predictions).shape}, labels shape: {np.array(labels).shape}")

        # Convert to numpy if needed and ensure valid token IDs
        predictions = np.array(predictions)
        labels = np.array(labels)

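        # Generated predictions are padded with -100 when batches are concatenated, and
        # batch_decode raises OverflowError on negative IDs. Clipping maps -100 to 0 (the
        # pad ID, dropped by skip_special_tokens) and caps IDs at vocab_size - 1; see [3]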
        vocab_size = tokenizer.vocab_size
        predictions = np.clip(predictions, 0, vocab_size - 1)
        print(f"[DEBUG] Predictions clipped to vocab range [0, {vocab_size-1}]")

        # Decode predictions
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        print(f"[DEBUG] Decoded {len(decoded_preds)} predictions")

        # Replace -100 (ignored labels) with pad_token_id for decoding
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        labels = np.clip(labels, 0, vocab_size - 1)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        print(f"[DEBUG] Decoded {len(decoded_labels)} labels")

        # Compute ROUGE
        result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        print(f"[DEBUG] ROUGE computed: {result}")

        # Calculate actual generation length (only count non-padding tokens up to EOS)
        gen_lens = []
        for pred in predictions:
            # Find EOS token or count non-padding tokens
            pred_list = pred.tolist() if hasattr(pred, 'tolist') else list(pred)
            # Remove padding tokens (0) and count until EOS (1 for T5)
            length = 0
            for token_id in pred_list:
                if token_id == tokenizer.eos_token_id or token_id == 1:  # EOS token
                    break
                if token_id != tokenizer.pad_token_id and token_id != 0:
                    length += 1
            gen_lens.append(length)

        avg_gen_len = np.mean(gen_lens) if gen_lens else 0
        result["gen_len"] = avg_gen_len
        print(f"[DEBUG] Average generation length: {avg_gen_len:.2f}")

        # Convert ROUGE scores to percentages (but NOT gen_len)
        final_result = {}
        for k, v in result.items():
            if k == "gen_len":
                final_result[k] = round(v, 2)  # Keep gen_len as-is, just round
            else:
                final_result[k] = round(v * 100, 4)  # Convert ROUGE to percentage

        print(f"[DEBUG] Final metrics: {final_result}")
        return final_result
    except Exception as e:
        print(f"[ERROR] compute_metrics failed: {e}")
        import traceback
        traceback.print_exc()
        return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0, "rougeLsum": 0.0, "gen_len": 0.0}

# DEBUG: Let's inspect a batch to see what's being fed to the model
print("\n" + "=" * 80)
print("DIAGNOSTIC: Inspecting training data batch...")
print("=" * 80)

# Create data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Get a small batch
sample_batch = [tokenized_train[i] for i in range(2)]
collated_batch = data_collator(sample_batch)

print(f"Batch keys: {list(collated_batch.keys())}")
print(f"Input IDs shape: {collated_batch['input_ids'].shape}")
print(f"Labels shape: {collated_batch['labels'].shape}")
print(f"\nSample input IDs (first 20): {collated_batch['input_ids'][0][:20].tolist()}")
print(f"Sample labels (first 20): {collated_batch['labels'][0][:20].tolist()}")

# Count how many labels are NOT -100 (i.e., actual labels vs padding)
labels_array = collated_batch['labels'].numpy()
non_ignore_labels = np.sum(labels_array != -100)
total_labels = labels_array.size
ignore_ratio = (total_labels - non_ignore_labels) / total_labels

print(f"\nLabel statistics:")
print(f"  Total label positions: {total_labels}")
print(f"  Non-ignore labels (not -100): {non_ignore_labels}")
print(f"  Ignore labels (-100): {total_labels - non_ignore_labels}")
print(f"  Ignore ratio: {ignore_ratio:.2%}")

if ignore_ratio > 0.95:
    print("\n⚠️  WARNING: More than 95% of labels are -100 (ignore)!")
    print("   This will cause very low or zero loss!")

# Try a forward pass to see actual loss
print("\n" + "=" * 80)
print("DIAGNOSTIC: Testing forward pass with sample batch...")
print("=" * 80)

model.eval()
with torch.no_grad():
    # Move batch to device
    batch_device = {k: v.to(device) for k, v in collated_batch.items()}
    outputs = model(**batch_device)
    loss = outputs.loss
    loss_value = loss.item()
    print(f"Forward pass loss: {loss_value:.6f}")

    if np.isnan(loss_value):
        print("\n⚠️  CRITICAL: Forward pass loss is NaN!")
        print("   This indicates a GPU-specific numerical instability issue.")
        print(f"   Model device: {next(model.parameters()).device}")
        print(f"   Batch device: {batch_device['input_ids'].device}")
        print(f"   Model dtype: {next(model.parameters()).dtype}")
        print("\n   SOLUTION: Will enable BF16 training for GPU compatibility.")
    elif loss_value == 0.0:
        print("\n⚠️  CRITICAL: Forward pass loss is 0.0!")
        print("   This confirms the model is not computing loss correctly.")
        print("   The issue is likely with label preparation.")
    else:
        print(f"✓ Forward pass loss looks good: {loss_value:.6f} - Model should learn!")

print("=" * 80)

# Custom callback to print training loss in real-time
from transformers import TrainerCallback

class LossLoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            # Print training loss if available
            if "loss" in logs:
                step = state.global_step
                loss = logs["loss"]
                print(f"\n[TRAINING] Step {step}: Loss = {loss:.6f}")
                # If loss is 0 or NaN, print a warning
                if np.isnan(loss):
                    print("  ⚠️  WARNING: Loss is NaN - training has a problem!")
                elif loss == 0.0:
                    print("  ⚠️  WARNING: Loss is 0.0 - model may not be learning!")
                elif loss > 0.01:
                    print("  ✓ Loss looks good - model is learning!")

            # Print learning rate if available
            if "learning_rate" in logs:
                lr = logs["learning_rate"]
                print(f"[TRAINING] Learning Rate = {lr:.6f}")

print("=" * 80)
print("STEP 6: Setting up training...")
print("=" * 80)

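# BF16 mixed precision is enabled when this capability check passes.
# Caveat: native BF16 requires compute capability >= 8.0 (Ampere or newer); the T4 used
# here (7.5) passes this ">= 7" check but has no BF16 tensor cores.
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7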

args = Seq2SeqTrainingArguments(
    output_dir="./flan-t5-summarizer",
    eval_strategy="epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=1,
    num_train_epochs=2,
    predict_with_generate=True,
    generation_max_length=MAX_TARGET_LENGTH,
    generation_num_beams=4,
    bf16=use_bf16,  # BF16 autocast; avoids the FP16 overflow that can give NaN loss with T5
    fp16=False,  # Don't use FP16
    logging_steps=100,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="rouge1",
)

print(f"Training configuration:")
print(f"  Epochs: {args.num_train_epochs}")
print(f"  Batch size: {args.per_device_train_batch_size}")
print(f"  Learning rate: {args.learning_rate}")
print(f"  Generation max length: {args.generation_max_length}")
print(f"  Generation num beams: {args.generation_num_beams}")
print(f"  BF16: {args.bf16} (GPU-optimized for {'modern' if use_bf16 else 'older'} GPU)")
print(f"  FP16: {args.fp16}")
print(f"  Output dir: {args.output_dir}")

print("\nCreating trainer...")
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
    processing_class=tokenizer,  # newer name for the deprecated `tokenizer=` argument
    compute_metrics=compute_metrics,
    callbacks=[LossLoggingCallback()],
)
print(f"✓ Trainer created")
print(f"  Train samples: {len(tokenized_train)}")
print(f"  Eval samples: {len(tokenized_val)}")

print("\n" + "=" * 80)
print("Starting training...")
print("=" * 80)

try:
    trainer.train()
    print("\n" + "=" * 80)
    print("✓ Training completed successfully!")
    print("=" * 80)
except Exception as e:
    print(f"\n[ERROR] Training failed: {e}")
    import traceback
    traceback.print_exc()
    raise

✓ ROUGE metric loaded

GPU DIAGNOSTICS
CUDA available: True
CUDA version: 12.6
PyTorch version: 2.9.0+cu126
GPU name: Tesla T4
GPU capability: (7, 5)
GPU memory: 15.83 GB
Supports BF16: True

Checking model weights for NaN...
  ✓ No NaN in model weights
Model device: cuda:0
Model dtype: torch.float32

DIAGNOSTIC: Inspecting training data batch...
Batch keys: ['input_ids', 'attention_mask', 'labels', 'decoder_input_ids']
Input IDs shape: torch.Size([2, 124])
Labels shape: torch.Size([2, 6])

Sample input IDs (first 20): [12198, 1635, 1737, 48, 1132, 10, 100, 19, 8, 511, 97, 27, 3907, 48, 377, 6833, 377, 11535, 4004, 6513]
Sample labels (first 20): [1651, 805, 55, 1, -100, -100]

Label statistics:
  Total label positions: 12
  Non-ignore labels (not -100): 10
  Ignore labels (-100): 2
  Ignore ratio: 16.67%

DIAGNOSTIC: Testing forward pass with sample batch...
Forward pass loss: 3.075017
✓ Forward pass loss looks good: 3.075017 - Model should learn!
STEP 6: Setting up training...


The model is already on multiple devices. Skipping the move to device specified in `args`.


Training configuration:
  Epochs: 2
  Batch size: 4
  Learning rate: 0.0002
  Generation max length: 32
  Generation num beams: 4
  BF16: True (GPU-optimized for modern GPU)
  FP16: False
  Output dir: ./flan-t5-summarizer

Creating trainer...
✓ Trainer created
  Train samples: 16000
  Eval samples: 2000

Starting training...



| Epoch | Training Loss | Validation Loss | ROUGE-1 | ROUGE-2 | ROUGE-L | ROUGE-Lsum | Gen Len |
|---|---|---|---|---|---|---|---|
| 1 | 3.107 | 3.028076 | 15.8881 | 5.8383 | 15.5122 | 15.5552 | 7.43 |
| 2 | 2.8723 | 2.997926 | 17.0325 | 6.1056 | 16.6849 | 16.7319 | 6.77 |



[TRAINING] Step 100: Loss = 3.652800
  ✓ Loss looks good - model is learning!
[TRAINING] Learning Rate = 0.000198

[TRAINING] Step 200: Loss = 3.444100
  ✓ Loss looks good - model is learning!
[TRAINING] Learning Rate = 0.000195

[TRAINING] Step 300: Loss = 3.552100
  ✓ Loss looks good - model is learning!
[TRAINING] Learning Rate = 0.000193

[TRAINING] Step 400: Loss = 3.321600
  ✓ Loss looks good - model is learning!
[TRAINING] Learning Rate = 0.000190

[TRAINING] Step 500: Loss = 3.409300
  ✓ Loss looks good - model is learning!
[TRAINING] Learning Rate = 0.000188

[TRAINING] Step 600: Loss = 3.428800
  ✓ Loss looks good - model is learning!
[TRAINING] Learning Rate = 0.000185

[TRAINING] Step 700: Loss = 3.404900
  ✓ Loss looks good - model is learning!
[TRAINING] Learning Rate = 0.000183

[TRAINING] Step 800: Loss = 3.320700
  ✓ Loss looks good - model is learning!
[TRAINING] Learning Rate = 0.000180

[TRAINING] Step 900: Loss = 3.327500
  ✓ Loss looks good - model is learning!

There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].



✓ Training completed successfully!


## 7. Evaluation: Summarization Quality

We evaluate the fine-tuned model's ROUGE scores on the held-out test set, then compare base-model and fine-tuned summaries on a few test reviews.
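For intuition about what the ROUGE-1 numbers measure, here is a minimal unigram-overlap F1 for a single prediction/reference pair. It approximates the `rouge_score` package used by `evaluate` (lowercasing and splitting on non-alphanumeric characters) but omits the Porter stemming enabled by `use_stemmer=True`; the corpus score reported by `evaluate` aggregates per-example scores, and the notebook scales it by 100. The helper names `rouge_tokens` and `rouge1_f` are illustrative.

```python
import re
from collections import Counter


def rouge_tokens(text):
    """Lowercase and split on non-alphanumeric characters (no stemming)."""
    return re.sub(r"[^a-z0-9]+", " ", text.lower()).split()


def rouge1_f(prediction, reference):
    """Unigram-overlap F1 between one prediction and one reference."""
    pred, ref = Counter(rouge_tokens(prediction)), Counter(rouge_tokens(reference))
    overlap = sum((pred & ref).values())  # matches clipped to the smaller count
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


print(rouge1_f("Great product!", "Great buy!"))  # 0.5
print(rouge1_f("A little tomato taste but a little mild",
               "Good snack, not very hot"))      # 0.0
```

The second pair is Example 1 from the qualitative comparison below: the fine-tuned summary is reasonable but shares no words with the reference, so it scores 0, which is one reason ROUGE on very short review titles stays low.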


In [None]:
print("=" * 80)
print("STEP 7: Evaluating on test set...")
print("=" * 80)
print(f"Test samples: {len(tokenized_test)}")

try:
    # Metrics are computed on the test split but keep the default "eval_" key prefix
    test_results = trainer.evaluate(tokenized_test)
    print("\n✓ Test evaluation complete!")
    print("\nTest Results:")
    for key, value in test_results.items():
        if isinstance(value, float):
            print(f"  {key}: {value:.4f}")
        else:
            print(f"  {key}: {value}")
    print("=" * 80)
except Exception as e:
    print(f"\n[ERROR] Test evaluation failed: {e}")
    import traceback
    traceback.print_exc()


STEP 7: Evaluating on test set...
Test samples: 2000



[DEBUG] compute_metrics called - predictions shape: (2000, 32), labels shape: (2000, 32)
[DEBUG] Predictions clipped to vocab range [0, 32099]
[DEBUG] Decoded 2000 predictions
[DEBUG] Decoded 2000 labels
[DEBUG] ROUGE computed: {'rouge1': np.float64(0.17236512279656996), 'rouge2': np.float64(0.07056816339276491), 'rougeL': np.float64(0.16915911774918535), 'rougeLsum': np.float64(0.16872712370953644)}
[DEBUG] Average generation length: 6.62
[DEBUG] Final metrics: {'rouge1': np.float64(17.2365), 'rouge2': np.float64(7.0568), 'rougeL': np.float64(16.9159), 'rougeLsum': np.float64(16.8727), 'gen_len': np.float64(6.62)}

✓ Test evaluation complete!

Test Results:
  eval_loss: 2.9697
  eval_rouge1: 17.2365
  eval_rouge2: 7.0568
  eval_rougeL: 16.9159
  eval_rougeLsum: 16.8727
  eval_gen_len: 6.6200
  eval_runtime: 168.0712
  eval_samples_per_second: 11.9000
  eval_steps_per_second: 2.9750
  epoch: 2.0000


In [None]:
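# Optional (Colab only): mount Google Drive. Nothing below writes to it; the model
# is saved to the local ./finetuned_summarizer_final directory in Section 9.
from google.colab import drive
drive.mount('/content/drive')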

In [None]:
# Qualitative Comparison
print("=" * 80)
print("STEP 8: Qualitative Comparison...")
print("=" * 80)

def generate_summary(model_obj, text, device):
    try:
        inputs = tokenizer(PREFIX + text, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True).to(device)
        # Pass input_ids and attention_mask together; generate() already runs without gradients
        outputs = model_obj.generate(**inputs, max_length=MAX_TARGET_LENGTH, num_beams=4)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        print(f"[ERROR] generate_summary failed: {e}")
        return f"[ERROR: {str(e)}]"

sample_indices = [0, 5, 10, 15, 20]
print(f"Comparing {len(sample_indices)} examples from test set...")
print("=" * 80)

for i, idx in enumerate(sample_indices, 1):
    try:
        print(f"\nExample {i}/{len(sample_indices)} (Index {idx}):")
        example = test_ds[idx]
        text = example["Text"]
        ref_summary = example["Summary"]

        print(f"  Generating base model summary...")
        base_summary = generate_summary(base_model, text, device)

        print(f"  Generating fine-tuned model summary...")
        ft_summary = generate_summary(model, text, device)

        print(f"\n  Review: {text[:200]}...")
        print(f"  Reference: {ref_summary}")
        print(f"  Base Model: {base_summary}")
        print(f"  Fine-Tuned: {ft_summary}")
        print("-" * 80)
    except Exception as e:
        print(f"[ERROR] Failed to process example {idx}: {e}")
        import traceback
        traceback.print_exc()
        continue

print("\n✓ Qualitative comparison complete!")
print("=" * 80)


STEP 8: Qualitative Comparison...
Comparing 5 examples from test set...

Example 1/5 (Index 0):
  Generating base model summary...
  Generating fine-tuned model summary...

  Review: This chip has a little tomato taste but the jalepeno seems mild. Good alternative to plain tortilla chips. Not overwhelmed by the taste but a good change....
  Reference: Good snack, not very hot
  Base Model: This chip has a little tomato taste but the jalepeno seems mild. Good alternative to plain tortilla chips. Not overwhelmed by the taste but 
  Fine-Tuned: A little tomato taste but a little mild
--------------------------------------------------------------------------------

Example 2/5 (Index 5):
  Generating base model summary...
  Generating fine-tuned model summary...

  Review: I received my first jar as a gift and promptly fell in love with it! When I found it on this website I was so thrilled.  The transaction went very smoothly.  Not a problem to be found.  I think the pr...
  Reference: Bes

## 8. Forgetting Analysis (After Training)

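We rerun the same eight-question probe on the fine-tuned model to check whether it has lost general knowledge, and compare its accuracy with the baseline from Section 5.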
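Aggregate accuracy can stay flat while individual answers change; in the run recorded below, for example, "london" became "French capital". A per-question comparison of the two `results` lists returned by `evaluate_forgetting` makes such shifts visible. `diff_predictions` is a hypothetical helper; in the notebook it would be called as `diff_predictions(base_qa_results, ft_qa_results)`.

```python
def diff_predictions(before, after):
    """Return the questions whose prediction or correctness changed between two runs.

    `before` and `after` are lists of dicts shaped like the `results` list built
    by evaluate_forgetting (keys include Question, Predicted, Correct).
    """
    if len(before) != len(after):
        raise ValueError("both runs must cover the same questions")
    changes = []
    for b, a in zip(before, after):
        if b["Question"] != a["Question"]:
            raise ValueError("both runs must use the same question order")
        if b["Predicted"] != a["Predicted"] or b["Correct"] != a["Correct"]:
            changes.append({
                "Question": b["Question"],
                "Before": b["Predicted"],
                "After": a["Predicted"],
                "Flipped": b["Correct"] != a["Correct"],  # right <-> wrong change
            })
    return changes


# Excerpt of the predictions recorded in Sections 5 and 8
base = [
    {"Question": "What is the capital of France?", "Predicted": "london", "Correct": False},
    {"Question": "How many days are in a week?", "Predicted": "7 days", "Correct": True},
    {"Question": "What is H2O?", "Predicted": "H2O", "Correct": False},
]
tuned = [
    {"Question": "What is the capital of France?", "Predicted": "French capital", "Correct": False},
    {"Question": "How many days are in a week?", "Predicted": "7 days", "Correct": True},
    {"Question": "What is H2O?", "Predicted": "H2O", "Correct": False},
]
for c in diff_predictions(base, tuned):
    print(f"{c['Question']} {c['Before']!r} -> {c['After']!r} (flipped: {c['Flipped']})")
```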


In [None]:
print("=" * 80)
print("STEP 9: Evaluating Fine-Tuned Model on QA set (After Training)...")
print("=" * 80)
print(f"Number of QA pairs: {len(qa_pairs)}")
print(f"Device: {device}")
print(f"Fine-tuned model on device: {next(model.parameters()).device}")

try:
    ft_qa_acc, ft_qa_results = evaluate_forgetting(model, tokenizer, qa_pairs, device)
    print(f"\n✓ Fine-tuned model evaluation complete!")

    print("\n" + "=" * 80)
    print("FORGETTING ANALYSIS SUMMARY")
    print("=" * 80)
    print(f"Base Model QA Accuracy: {base_qa_acc:.2%} ({base_qa_acc * len(qa_pairs):.0f}/{len(qa_pairs)})")
    print(f"Fine-Tuned Model QA Accuracy: {ft_qa_acc:.2%} ({ft_qa_acc * len(qa_pairs):.0f}/{len(qa_pairs)})")

    diff = ft_qa_acc - base_qa_acc
    print(f"Change in Accuracy: {diff:+.2%}")

    if diff < 0:
        print(f"⚠️  Forgetting detected! Model lost {abs(diff):.2%} accuracy on general knowledge.")
    elif diff > 0:
        print(f"✓ Model improved by {diff:.2%} (unexpected but good!)")
    else:
        print(f"→ No change in general knowledge performance.")

    print("=" * 80)
except Exception as e:
    print(f"\n[ERROR] Fine-tuned model evaluation failed: {e}")
    import traceback
    traceback.print_exc()
    raise


STEP 9: Evaluating Fine-Tuned Model on QA set (After Training)...
Number of QA pairs: 8
Device: cuda
Fine-tuned model on device: cuda:0
--- Forgetting Analysis ---
Q: What is the capital of France?
  Expected: Paris | Predicted: French capital | ✗
Q: How many days are in a week?
  Expected: 7 | Predicted: 7 days | ✓
Q: What gas do plants absorb?
  Expected: carbon dioxide | Predicted: gas | ✗
Q: What is the largest planet in our solar system?
  Expected: Jupiter | Predicted: Earth | ✗
Q: What is H2O?
  Expected: water | Predicted: H2O | ✗
Q: Who wrote Romeo and Juliet?
  Expected: Shakespeare | Predicted: edmund wilson | ✗
Q: What color is the sky on a clear day?
  Expected: blue | Predicted: blue sky | ✓
Q: What is 2 + 2?
  Expected: 4 | Predicted: 2 + 2 | ✗

Accuracy: 25.00% (2/8)

✓ Fine-tuned model evaluation complete!

FORGETTING ANALYSIS SUMMARY
Base Model QA Accuracy: 25.00% (2/8)
Fine-Tuned Model QA Accuracy: 25.00% (2/8)
Change in Accuracy: +0.00%
→ No change in general knowle

## 9. Save Model

Save the fine-tuned model and tokenizer so they can be downloaded and reloaded later.


In [None]:
trainer.save_model("./finetuned_summarizer_final")
tokenizer.save_pretrained("./finetuned_summarizer_final")

print("Model saved to ./finetuned_summarizer_final")
# To download from Colab:
# from google.colab import files
# !zip -r model.zip ./finetuned_summarizer_final
# files.download('model.zip')


Model saved to ./finetuned_summarizer_final
