In [1]:
# Step 1: Install required packages
!pip install -q transformers datasets evaluate rouge_score accelerate sacrebleu nltk

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
# Step 2: Imports and setup
import os
import torch
import numpy as np
import nltk
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM, 
    Seq2SeqTrainingArguments, Seq2SeqTrainer, 
    DataCollatorForSeq2Seq, pipeline
)
from datasets import load_dataset
import evaluate

2025-07-24 04:54:56.990118: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753332897.328914      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753332897.422733      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [6]:
# Step 3: Login to Hugging Face
print("=== Setting up Hugging Face authentication ===")

# Pull the access token from Kaggle's secret store (secret name: "HF_TOKEN")
# so it never appears in the notebook, then authenticate this session with it.
user_secrets = UserSecretsClient()
login(token=user_secrets.get_secret("HF_TOKEN"))

=== Setting up Hugging Face authentication ===


In [None]:
# Step 4: Check GPU and system info
# Report whether CUDA is usable and, if so, the name and total memory of device 0.
print("=== System Information ===")
if not torch.cuda.is_available():
    print("GPU not available. Training will run on CPU.")
else:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print("GPU is available!")
    print(f"GPU Name: {gpu_name}")
    print(f"GPU Memory: {gpu_mem_gb:.1f} GB")


In [7]:
# Step 5: Load model and tokenizer
print("\n=== Loading Model and Tokenizer ===")
model_name = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Decoding defaults belong on the model's GenerationConfig, which is what
# `generate()` reads. That covers Seq2SeqTrainer evaluation with
# predict_with_generate and the summarization pipeline. Setting them on
# `model.config` is a deprecated path in recent transformers releases and can
# trigger warnings or errors when the model is saved, depending on the version.
model.generation_config.no_repeat_ngram_size = 3  # Prevent repeating 3-word sequences
model.generation_config.length_penalty = 1.0      # Neutral length penalty (1.0 is the library default)

print(f"Model loaded: {model_name}")
print(f"Model parameters: {model.num_parameters():,}")


=== Loading Model and Tokenizer ===


config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

Model loaded: facebook/bart-base
Model parameters: 139,420,416


In [None]:
# Step 6: Load and prepare dataset
print("\n=== Loading Dataset ===")
dataset = load_dataset("cnn_dailymail", "3.0.0")

# Train/validate on the first N rows of each split to keep run time manageable.
train_size = 100000
val_size = 10000
full_train = dataset["train"]

train_dataset = full_train.select(range(train_size))
val_dataset = dataset["validation"].select(range(val_size))

for summary_line in (
    f"Original dataset size: {len(full_train):,} training examples",
    f"Using subset: {train_size:,} training examples",
    f"Validation subset: {val_size:,} examples",
):
    print(summary_line)

In [None]:
# Step 7: Show sample data
# Print the first 300 characters of the first training article and its reference summary.
print("\n=== Sample Data ===")
first_example = train_dataset[0]
article_preview = first_example["article"][:300]
print(f"Article preview: {article_preview}...")
print(f"Summary: {first_example['highlights']}")

In [None]:
# Step 8: Preprocessing function
print("\n=== Setting up Data Preprocessing ===")
# BART is not trained with T5-style task prefixes, and the summarization
# pipelines used for inference in this notebook pass raw article text.
# Prepending "summarize: " during training would create a train/inference
# mismatch and waste input tokens, so no prefix is used. Set one only for
# models that expect it (e.g. T5).
prefix = ""
max_input_length = 512  # Reduced for faster processing
max_target_length = 64

def preprocess_function(examples):
    """Tokenize a batch of CNN/DailyMail rows for seq2seq training.

    Args:
        examples: A batched ``datasets`` slice with list-valued ``"article"``
            and ``"highlights"`` columns.

    Returns:
        The tokenizer encoding of the (optionally prefixed) articles, truncated
        to ``max_input_length`` tokens. It also has a ``"labels"`` entry holding
        the token ids of the highlights, truncated to ``max_target_length``.
    """
    inputs = [prefix + doc for doc in examples["article"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Tokenize the reference summaries as decoder targets.
    labels = tokenizer(
        text_target=examples["highlights"],
        max_length=max_target_length,
        truncation=True,
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Apply preprocessing. The raw text columns are dropped so the tokenized
# datasets hold only model inputs (by default the Trainer would ignore them anyway).
print("Preprocessing training data...")
tokenized_train = train_dataset.map(
    preprocess_function, batched=True, remove_columns=train_dataset.column_names
)

print("Preprocessing validation data...")
tokenized_val = val_dataset.map(
    preprocess_function, batched=True, remove_columns=val_dataset.column_names
)

print(f"Tokenized training examples: {len(tokenized_train):,}")
print(f"Tokenized validation examples: {len(tokenized_val):,}")

In [None]:
# Step 9: Setup evaluation metrics
print("\n=== Setting up Evaluation Metrics ===")

# compute_metrics splits text with nltk.sent_tokenize. Recent NLTK releases
# load the "punkt_tab" resource for it; older ones use "punkt". Fetch both so
# evaluation does not fail with a LookupError mid-training.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

# Load ROUGE metric
rouge_metric = evaluate.load("rouge")

def compute_metrics(eval_pred):
    """Compute ROUGE scores (as percentages) for generated summaries.

    Args:
        eval_pred: ``(predictions, labels)`` from the trainer. With
            ``predict_with_generate=True`` the predictions are generated token
            ids. Padding positions in either array may hold ``-100``.

    Returns:
        dict mapping each key returned by ``rouge_metric.compute`` to its score
        multiplied by 100 and rounded to 2 decimal places.
    """
    predictions, labels = eval_pred
    # Some models return a tuple of outputs; the generated ids come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 marks padding and is not a valid token id, so it can't be decoded.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # ROUGE-Lsum expects one sentence per line.
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    return {key: round(value * 100, 2) for key, value in result.items()}

In [None]:
# Step 10: Training arguments
print("\n=== Setting up Training Arguments ===")

# Make sure to change this to your Hugging Face username
YOUR_HF_USERNAME = "souradeepdutta"
MODEL_HUB_ID = f"{YOUR_HF_USERNAME}/bart-base-summarizer"

training_args = Seq2SeqTrainingArguments(
    # Local directory for checkpoints and the final model. It is derived from
    # the Hub id so that changing the username above updates both.
    output_dir=MODEL_HUB_ID,

    # Evaluate and checkpoint at the end of every epoch. A fixed step interval
    # only fires when global_step is a multiple of it. The total step count
    # depends on dataset size, batch size, gradient accumulation and GPU count,
    # so the final weights might never be evaluated. load_best_model_at_end
    # could then fall back to an earlier checkpoint, or have none at all.
    # Per-epoch evaluation always covers the last epoch.
    eval_strategy="epoch",
    save_strategy="epoch",

    # Batch sizes (effective train batch = 32 * 2 * number of devices)
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,

    # Learning parameters
    learning_rate=5e-5,
    weight_decay=0.01,
    warmup_steps=500,

    # Training duration
    num_train_epochs=3,
    max_steps=-1,

    # Mixed precision needs a CUDA device. Fall back to fp32 on CPU instead of
    # failing when the arguments are constructed.
    fp16=torch.cuda.is_available(),

    report_to="none",

    # Logging
    logging_steps=100,
    logging_strategy="steps",

    # Model saving and pushing to hub: keep at most 2 checkpoints. When
    # training ends, reload the one with the best validation ROUGE-1.
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="rouge1",
    push_to_hub=True,
    hub_model_id=MODEL_HUB_ID,

    # Generation settings for evaluation
    predict_with_generate=True,
    generation_max_length=64,
    generation_num_beams=4,
)

In [None]:
# Step 11: Data collator
print("\n=== Setting up Data Collator ===")
# Dynamically pads each batch to its longest sequence, rounded up to a multiple
# of 8 (friendlier to tensor cores under fp16). Passing the model lets the
# collator build decoder inputs from the labels.
collator_kwargs = dict(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    pad_to_multiple_of=8,
)
data_collator = DataCollatorForSeq2Seq(**collator_kwargs)

In [None]:
# Step 12: Create trainer
print("\n=== Creating Trainer ===")
# Wire together the model, arguments, tokenized splits, padding collator and
# ROUGE metric.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; switch once the installed version is confirmed to support it.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    compute_metrics=compute_metrics,
)

In [None]:
# Step 13: Model optimizations for memory
print("\n=== Applying Model Optimizations ===")

# Gradient checkpointing saves memory at the cost of a slower backward pass.
# It is off by default; set this to True if training runs out of GPU memory.
USE_GRADIENT_CHECKPOINTING = False
if USE_GRADIENT_CHECKPOINTING:
    model.gradient_checkpointing_enable()

# Disable the key/value cache during training; it is only useful for
# incremental decoding at inference time.
if hasattr(model.config, 'use_cache'):
    model.config.use_cache = False

In [None]:
# Step 14: Start training
print("\n=== Starting Training ===")

try:
    trainer.train()
    print("\n✅ Training completed successfully!")
except Exception as err:
    # Surface the failure in the cell output, then re-raise so the cell is
    # still marked as failed.
    failure_message = f"\n❌ Training failed: {str(err)}"
    print(failure_message)
    raise


In [None]:
# Step 15: Save and push model to Hub
print("\n=== Saving and Pushing Final Model ===")
final_model_dir = training_args.output_dir
try:
    # save_model() writes the final model to output_dir. With push_to_hub=True
    # in training_args it also triggers an upload. push_to_hub() then uploads
    # the folder again, together with a generated model card.
    trainer.save_model()
    trainer.push_to_hub()
    print(f"✅ Model saved and pushed to {MODEL_HUB_ID} on the Hugging Face Hub!")
except Exception as e:
    print(f"❌ Error saving or pushing to hub: {str(e)}")
    # Only claim a local copy when output_dir holds a saved model at its top
    # level. Training checkpoints live in checkpoint-* subfolders instead.
    if os.path.isfile(os.path.join(final_model_dir, "config.json")):
        print(f"Model was saved locally in {final_model_dir}/")
    else:
        print(f"❌ No saved model found in {final_model_dir}/")

In [None]:
# Step 16: Test the fine-tuned model
print("\n=== Testing the Trained Model ===")
try:
    # Load the fine-tuned model for inference from training_args.output_dir,
    # which is where trainer.save_model() writes by default. Use GPU 0 when
    # CUDA is available, otherwise the CPU.
    summarizer = pipeline(
        "summarization",
        model=training_args.output_dir,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1,
    )

    test_article = """
    NASA's James Webb Space Telescope has captured its first direct image of a planet outside our solar system. 
    The exoplanet, known as HIP 65426 b, is a gas giant about six to 12 times the mass of Jupiter. 
    This observation is a transformative moment for astronomy, as it points the way toward future observations 
    that will reveal more information than ever before about exoplanets. The telescope's advanced infrared 
    capabilities allow it to see past the glare of the host star to capture the faint planet.
    """

    print("--- Test Article ---")
    print(f"Original: {test_article.strip()}")

    # Deterministic decoding (no sampling) with explicit length bounds
    summary = summarizer(
        test_article,
        max_length=80,
        min_length=15,
        do_sample=False
    )

    print(f"\nGenerated Summary: {summary[0]['summary_text']}")

except Exception as e:
    print(f"❌ Error during testing: {str(e)}")

In [None]:
# Step 17: Final model evaluation on the test set
print("\n=== Final Model Evaluation on Test Set ===")
test_size = 5000  # number of test-split examples to score
try:
    # Evaluate on a subset of the test split. The size is clamped so select()
    # never asks for more rows than the split contains.
    test_split = dataset["test"]
    test_dataset = test_split.select(range(min(test_size, len(test_split))))
    tokenized_test = test_dataset.map(
        preprocess_function, batched=True, remove_columns=test_dataset.column_names
    )

    print("Running evaluation...")
    # metric_key_prefix="test" reports the scores as test_* rather than eval_*,
    # so they are not mistaken for validation metrics in the trainer's logs.
    eval_results = trainer.evaluate(tokenized_test, metric_key_prefix="test")

    print("\n--- Test Set ROUGE Scores ---")
    for key, value in eval_results.items():
        if 'rouge' in key:
            print(f"  {key}: {value}")

except Exception as e:
    print(f"❌ Error during final evaluation: {str(e)}")

print("\n=== Notebook Complete! ===")

In [10]:
# Step 18: Interactive Summarization Demo
import ipywidgets as widgets
from IPython.display import display

print("\n=== Interactive Summarization Demo ===")

# Hint text shown in the empty input box. The click handler compares against
# this same constant, so the hint and the "nothing pasted" check cannot drift apart.
ARTICLE_PLACEHOLDER = 'Paste your article here...'

try:
    # 1. Load the fine-tuned model and tokenizer. Use GPU 0 when CUDA is
    #    available, otherwise the CPU, so the demo also works without a GPU.
    print("Loading your fine-tuned model...")
    model_path = "souradeepdutta/bart-base-summarizer"
    summarizer = pipeline(
        "summarization",
        model=model_path,
        tokenizer=model_path,
        device=0 if torch.cuda.is_available() else -1,
    )
    print("✅ Model loaded successfully!")

    # 2. Create a text area for user input. It starts empty; the placeholder is
    #    only a greyed-out hint.
    print("\nPaste your article into the text box below and click 'Summarize'.")
    article_input = widgets.Textarea(
        value='',
        placeholder=ARTICLE_PLACEHOLDER,
        description='Article:',
        layout={'height': '200px', 'width': '95%'},
        disabled=False
    )

    # 3. Create a button to trigger summarization
    summarize_button = widgets.Button(
        description='Summarize',
        button_style='success',
        tooltip='Click to generate summary',
        icon='check'
    )

    # 4. Create an output area to display the result
    summary_output = widgets.Output()

    # 5. Define the function to run on button click
    def on_summarize_button_clicked(b):
        """Summarize the text-area contents and print the result into ``summary_output``.

        Args:
            b: The clicked ``widgets.Button`` (unused; ``Button.on_click`` passes it).
        """
        with summary_output:
            summary_output.clear_output()  # Clear previous summary

            # Validate before announcing work, so a bad submission shows only the error.
            article_text = article_input.value.strip()
            if not article_text or article_text == ARTICLE_PLACEHOLDER:
                print("❌ Please paste an article first.")
                return

            print("Generating summary...")
            try:
                # do_sample=True combined with num_beams=4 performs beam-search
                # sampling, so repeated clicks on the same text can differ.
                result = summarizer(
                    article_text,
                    max_length=128,
                    min_length=30,
                    do_sample=True,
                    num_beams=4,
                    temperature=0.8,
                    top_p=0.95,
                    top_k=50,
                    no_repeat_ngram_size=3
                )

                print("\n--- Generated Summary ---")
                print(result[0]['summary_text'])

            except Exception as e:
                print(f"An error occurred during summarization: {e}")

    # 6. Link the button to the function
    summarize_button.on_click(on_summarize_button_clicked)

    # 7. Display the widgets
    display(article_input, summarize_button, summary_output)

except Exception as e:
    print(f"\n❌ An error occurred while setting up the demo: {str(e)}")
    print("Please ensure that you have successfully trained and saved the model in the 'souradeepdutta/bart-base-summarizer' directory.")


=== Interactive Summarization Demo ===
Loading your fine-tuned model...


Device set to use cuda:0


✅ Model loaded successfully!

Paste your article into the text box below and click 'Summarize'.


Textarea(value='Paste your article herePaste your article here......', description='Article:', layout=Layout(h…

Button(button_style='success', description='Summarize', icon='check', style=ButtonStyle(), tooltip='Click to g…

Output()