In [None]:
import pandas as pd
import torch
import nltk
from nltk.tokenize import sent_tokenize
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModelForSequenceClassification
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, TensorDataset, Dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
import os
import re
from rouge import Rouge
import logging
import random
from torch.nn import CrossEntropyLoss
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset

In [None]:
# Configure root logging once for the notebook: INFO and above, with a
# timestamp, level and logger name in front of every message.
_LOG_SETTINGS = {
    "level": logging.INFO,
    "datefmt": "%m/%d/%Y %H:%M:%S",
    "format": "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
}
logging.basicConfig(**_LOG_SETTINGS)

# Module-level logger; the training loop uses it for periodic progress lines.
logger = logging.getLogger(__name__)

In [None]:
# Single source of truth for every random number generator seeded below.
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Prefer the GPU when one is visible; later cells move the model and batches
# to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Download NLTK's 'punkt' resource (presumably the model behind sent_tokenize,
# which later cells use for sentence splitting).
nltk.download('punkt', quiet=True)
# Also let NLTK search a notebook-local ./nltk_data directory.
# NOTE(review): the download above does not pass download_dir, so it presumably
# lands in NLTK's default location rather than this directory — confirm intent.
nltk_data_path = os.path.join(os.getcwd(), 'nltk_data')
nltk.data.path.append(nltk_data_path)
print("Loading CNN/DailyMail dataset...")
# Load configuration "3.0.0" of the "cnn_dailymail" dataset.
dataset = load_dataset("cnn_dailymail", "3.0.0")

Loading CNN/DailyMail dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# The recorded outputs of this notebook report 1000 / 100 / 100 examples, so
# each split is truncated to those sizes instead of using every row.
N_TRAIN_EXAMPLES = 1000
N_VAL_EXAMPLES = 100
N_TEST_EXAMPLES = 100


def _first_n(split, n):
    """Return the first `n` rows of a split, or the whole split if it is shorter.

    The `min` guard also makes this cell safe to re-run after `dataset` has
    already been replaced by the truncated splits below.
    """
    return split.select(range(min(n, len(split))))


train_subset = _first_n(dataset["train"], N_TRAIN_EXAMPLES)
val_subset = _first_n(dataset["validation"], N_VAL_EXAMPLES)
test_subset = _first_n(dataset["test"], N_TEST_EXAMPLES)

# Rebind `dataset` to the truncated splits so later lookups by split name see
# the small versions.
dataset = {
    "train": train_subset,
    "validation": val_subset,
    "test": test_subset
}

print(f"Train set: {len(dataset['train'])} examples")
print(f"Validation set: {len(dataset['validation'])} examples")
print(f"Test set: {len(dataset['test'])} examples")

# DataFrame views of each split; later cells read the 'article' and
# 'highlights' columns.
df_train = pd.DataFrame(dataset['train'])
df_val = pd.DataFrame(dataset['validation'])
df_test = pd.DataFrame(dataset['test'])

Train set: 1000 examples
Validation set: 100 examples
Test set: 100 examples


In [None]:
class BertSumExtractor(nn.Module):
    """Sentence-level extractive summarizer.

    Each sentence is encoded independently with BERT and represented by the
    hidden state of its first token. A small transformer encoder then lets the
    sentences of one document attend to each other, and a linear layer scores
    every sentence.

    Args:
        bert_model_name: Checkpoint name passed to ``BertModel.from_pretrained``.
        num_labels: Number of classes predicted per sentence.
    """

    def __init__(self, bert_model_name="bert-base-uncased", num_labels=2):
        super().__init__()
        self.num_labels = num_labels
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

        # batch_first=True is required: forward() feeds tensors shaped
        # (batch, num_sentences, hidden). With the default layout the encoder
        # would treat the batch axis as the sequence and attend across
        # documents instead of across the sentences of one document.
        # The layer is kept as an attribute so state_dict keys stay unchanged.
        self.doc_encoder = nn.TransformerEncoderLayer(
            d_model=self.bert.config.hidden_size,
            nhead=8,
            dim_feedforward=2048,
            batch_first=True,
        )
        # Nested tensors are disabled so the output always keeps the padded
        # (batch, num_sentences, hidden) shape when a padding mask is supplied.
        self.doc_transformer = nn.TransformerEncoder(
            self.doc_encoder,
            num_layers=2,
            enable_nested_tensor=False,
        )

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        """Score every sentence of every document in the batch.

        Args:
            input_ids: Token ids shaped (batch, num_sentences, seq_len).
            attention_mask: Token mask with the same shape as ``input_ids``.
                A sentence slot whose mask is entirely zero is treated as
                document-level padding.
            token_type_ids: Optional segment ids, same shape as ``input_ids``.
            labels: Optional per-sentence targets shaped (batch, num_sentences);
                entries equal to -100 are ignored by the loss.

        Returns:
            ``(logits, loss)`` where ``logits`` is shaped
            (batch, num_sentences, num_labels) and ``loss`` is ``None`` when no
            labels are given.
        """
        batch_size, num_sentences, seq_len = input_ids.size()

        # Padded sentence slots must not be attended to by the real sentences.
        sentence_is_padding = (
            attention_mask.view(batch_size, num_sentences, seq_len).sum(dim=-1) == 0
        )

        # Flatten documents so BERT sees one sentence per row.
        flat_input_ids = input_ids.view(-1, seq_len)
        flat_attention_mask = attention_mask.view(-1, seq_len)
        flat_token_type_ids = (
            token_type_ids.view(-1, seq_len) if token_type_ids is not None else None
        )

        outputs = self.bert(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_mask,
            token_type_ids=flat_token_type_ids,
        )

        # First-token hidden state of each sentence, regrouped per document.
        sentence_embeddings = outputs.last_hidden_state[:, 0, :]
        sentence_embeddings = sentence_embeddings.view(batch_size, num_sentences, -1)

        doc_embeddings = self.doc_transformer(
            sentence_embeddings,
            src_key_padding_mask=sentence_is_padding,
        )

        logits = self.classifier(self.dropout(doc_embeddings))

        loss = None
        if labels is not None:
            # -100 marks padded sentence slots and is skipped by the loss.
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return logits, loss

In [None]:
def _token_set(sentence):
    """Lowercase a sentence, collapse whitespace and return its set of tokens."""
    return set(re.sub(r'\s+', ' ', sentence.lower().strip()).split())


def _best_overlap(sent_tokens, highlight_token_sets):
    """Return the highest token-overlap ratio between a sentence and any highlight.

    The ratio is ``|intersection| / max(|sentence|, |highlight|)``. Empty token
    sets are skipped, so the result is 0 when nothing can be compared.
    """
    best = 0
    if not sent_tokens:
        return best
    for h_tokens in highlight_token_sets:
        if not h_tokens:
            continue
        ratio = len(sent_tokens & h_tokens) / max(len(sent_tokens), len(h_tokens))
        best = max(best, ratio)
    return best


def create_extractive_data(df, threshold=0.5, fallback_top_k=3):
    """Derive per-sentence extractive labels from article/highlight pairs.

    A sentence is labelled 1 when its best token overlap with any highlight
    sentence is strictly greater than ``threshold``. If no sentence of an
    article qualifies, the ``fallback_top_k`` highest-scoring sentences are
    labelled 1 instead (ties keep article order).

    Args:
        df: DataFrame with 'article' and 'highlights' text columns.
        threshold: Minimum (exclusive) overlap ratio for a positive label.
        fallback_top_k: Number of sentences marked positive when none pass
            the threshold.

    Returns:
        A list of dicts with keys 'article_sents' (list of sentences) and
        'labels' (list of 0/1 ints of the same length). Rows whose article or
        highlights produce no sentences are skipped.
    """
    data = []
    skipped = 0

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Creating extractive data"):
        article_sents = sent_tokenize(row['article'])
        highlight_sents = sent_tokenize(row['highlights'])

        if len(article_sents) == 0 or len(highlight_sents) == 0:
            skipped += 1
            continue

        # Tokenize highlights once per article rather than once per sentence pair.
        highlight_token_sets = [_token_set(h_sent) for h_sent in highlight_sents]
        scores = [
            _best_overlap(_token_set(sent), highlight_token_sets)
            for sent in article_sents
        ]

        labels = [1 if score > threshold else 0 for score in scores]

        if sum(labels) == 0:
            # sorted() is stable even with reverse=True, so equal scores keep
            # their original article order.
            ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
            for sent_idx in ranked[:fallback_top_k]:
                labels[sent_idx] = 1

        data.append({
            'article_sents': article_sents,
            'labels': labels
        })

    print(f"Skipped {skipped} articles due to missing sentences")
    return data

In [None]:
# NOTE(review): presumably needed because newer NLTK releases look up
# 'punkt_tab' (rather than 'punkt') inside sent_tokenize — confirm, and consider
# moving this next to the other nltk.download call so setup lives in one cell.
nltk.download('punkt_tab', quiet=True)

True

In [None]:
print("Processing the dataset for extractive summarization...")

# Label the training frame first and the validation frame second, so the two
# progress bars appear in that order.
extractive_train_data, extractive_val_data = [
    create_extractive_data(frame) for frame in (df_train, df_val)
]

for split_name, split_data in (
    ("training", extractive_train_data),
    ("validation", extractive_val_data),
):
    print(f"Created {len(split_data)} {split_name} examples")




Processing the dataset for extractive summarization...


Creating extractive data: 100%|██████████| 1000/1000 [00:03<00:00, 276.37it/s]


Skipped 0 articles due to missing sentences


Creating extractive data: 100%|██████████| 100/100 [00:00<00:00, 426.33it/s]

Skipped 0 articles due to missing sentences
Created 1000 training examples
Created 100 validation examples





In [None]:
# Tokenizer for the same 'bert-base-uncased' checkpoint that BertSumExtractor
# loads by default; presumably the two must stay in sync if either is changed.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
class ExtractiveDataset(Dataset):
    """Torch dataset that tokenizes the sentences of one article per item.

    Args:
        data: List of dicts with 'article_sents' (list of str) and 'labels'
            (list of int), one dict per article.
        tokenizer: Callable tokenizer invoked on the list of sentences; must
            return a mapping with 'input_ids' and 'attention_mask' tensors.
        max_length: Number of tokens each sentence is padded/truncated to.
        max_sentences: Only the first ``max_sentences`` sentences (and their
            labels) of an article are kept; ``None`` keeps every sentence.
    """

    def __init__(self, data, tokenizer, max_length=256, max_sentences=20):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_sentences = max_sentences

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return tokenized sentences and labels for article ``idx``.

        Returns:
            Dict with 'input_ids' and 'attention_mask' as produced by the
            tokenizer for the kept sentences, and 'labels' as a long tensor
            with one entry per kept sentence.
        """
        item = self.data[idx]
        # Cap the document length; sentences and labels are sliced together so
        # they stay aligned.
        sentences = item['article_sents'][:self.max_sentences]
        labels = item['labels'][:self.max_sentences]

        encodings = self.tokenizer(
            sentences,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'input_ids': encodings['input_ids'],
            'attention_mask': encodings['attention_mask'],
            'labels': torch.tensor(labels, dtype=torch.long)
        }



In [None]:
# Wrap both labelled splits with the same tokenizer and a 256-token sentence length.
train_dataset, val_dataset = (
    ExtractiveDataset(split, tokenizer, max_length=256)
    for split in (extractive_train_data, extractive_val_data)
)

In [None]:
def collate_fn(batch):
    """Pad documents in a batch to the same number of sentences and stack them.

    Each item provides 'input_ids' and 'attention_mask' shaped
    (num_sentences, seq_len) and 'labels' shaped (num_sentences,). Shorter
    documents receive extra all-zero sentence rows, and their labels are
    extended with -100 for those rows.

    Returns:
        Dict with stacked 'input_ids', 'attention_mask' and 'labels' tensors.
    """
    target_rows = max(len(item['input_ids']) for item in batch)

    stacked = {'input_ids': [], 'attention_mask': [], 'labels': []}

    for item in batch:
        ids = item['input_ids']
        mask = item['attention_mask']
        label = item['labels']

        missing = target_rows - len(ids)
        if missing > 0:
            # Append all-zero sentence rows; a zero attention mask marks them as padding.
            ids = torch.cat(
                [ids, torch.zeros((missing, ids.size(1)), dtype=ids.dtype)], dim=0
            )
            mask = torch.cat(
                [mask, torch.zeros((missing, mask.size(1)), dtype=mask.dtype)], dim=0
            )
            # -100 labels mark the padded sentence slots.
            label = torch.cat(
                [label, torch.full((missing,), -100, dtype=torch.long)], dim=0
            )

        stacked['input_ids'].append(ids)
        stacked['attention_mask'].append(mask)
        stacked['labels'].append(label)

    return {key: torch.stack(tensors) for key, tensors in stacked.items()}


In [None]:
batch_size = 2  # Adjust based on your GPU memory

# Both loaders share the batch size and the document-level padding in `collate_fn`.
_loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn)

# Only the training loader shuffles; validation keeps dataset order.
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_dataloader = DataLoader(val_dataset, **_loader_kwargs)

In [None]:
print("Initializing the BertSum model...")
# Build with the default checkpoint and label count, then move to the selected device.
model = BertSumExtractor()
model = model.to(device)


Initializing the BertSum model...




In [None]:
num_epochs = 3
optimizer = AdamW(model.parameters(), lr=2e-5)

# The scheduler is stepped once per training batch, so the schedule spans every
# batch of every epoch; the first 10% of those steps are warmup.
total_steps = len(train_dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_training_steps=total_steps,
    num_warmup_steps=int(0.1 * total_steps),
)

In [None]:
import gc
def train():
    """Fine-tune the global `model` and keep the checkpoint with the best validation loss.

    Uses the notebook-level `model`, `optimizer`, `scheduler`, `num_epochs`,
    `train_dataloader`, `val_dataloader`, `device` and `logger`. After every
    epoch the model is evaluated on the validation loader; whenever the average
    validation loss improves, the weights are written to
    "best_bertsum_model.pt".
    """
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        print(f"\n======== Epoch {epoch+1} / {num_epochs} ========")

        model.train()
        total_train_loss = 0

        for step, batch in enumerate(tqdm(train_dataloader, desc="Training")):
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()

            logits, loss = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss.backward()
            # Clip the global gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            total_train_loss += loss.item()

            if step % 100 == 0 and step != 0:
                logger.info(f"Epoch: {epoch+1}/{num_epochs} | Step: {step}/{len(train_dataloader)} | Loss: {loss.item()}")

        # Release cached memory once per epoch. Doing this after every step
        # adds overhead and does not lower the memory held by live tensors.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        avg_train_loss = total_train_loss / len(train_dataloader)
        print(f"Average training loss: {avg_train_loss}")

        model.eval()
        total_val_loss = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc="Validation"):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                logits, loss = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                total_val_loss += loss.item()

                preds = torch.argmax(logits, dim=-1)
                # Sentence slots labelled -100 are padding added by collate_fn;
                # leave them out of the accuracy.
                valid_indices = labels != -100

                all_preds.extend(preds[valid_indices].cpu().numpy())
                all_labels.extend(labels[valid_indices].cpu().numpy())

        avg_val_loss = total_val_loss / len(val_dataloader)
        accuracy = np.mean(np.array(all_preds) == np.array(all_labels))

        print(f"Validation loss: {avg_val_loss}")
        print(f"Validation accuracy: {accuracy}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            print("Saving best model...")
            torch.save(model.state_dict(), "best_bertsum_model.pt")

    print("Training completed!")

In [None]:
# Run the full fine-tuning loop; train() writes the best checkpoint to
# "best_bertsum_model.pt", which the next cell reloads.
print("Starting training...")
train()

Starting training...



Training: 100%|██████████| 500/500 [16:28<00:00,  1.98s/it]


Average training loss: 0.3687620839253068


Validation: 100%|██████████| 50/50 [00:25<00:00,  1.94it/s]


Validation loss: 0.37115209728479387
Validation accuracy: 0.8764501160092807
Saving best model...



Training: 100%|██████████| 500/500 [16:30<00:00,  1.98s/it]


Average training loss: 0.30995405465364456


Validation: 100%|██████████| 50/50 [00:25<00:00,  1.95it/s]


Validation loss: 0.3929674586653709
Validation accuracy: 0.8723897911832946



Training: 100%|██████████| 500/500 [16:29<00:00,  1.98s/it]


Average training loss: 0.24253470274806022


Validation: 100%|██████████| 50/50 [00:25<00:00,  1.94it/s]

Validation loss: 0.4343030548095703
Validation accuracy: 0.8683294663573086
Training completed!





In [None]:
print("Loading best model for inference...")
# map_location lets a checkpoint saved on a GPU be restored on whatever device
# this session selected (including a CPU-only runtime).
model.load_state_dict(torch.load("best_bertsum_model.pt", map_location=device))
# The trailing semicolon keeps the notebook from printing the full module repr.
model.eval();

Loading best model for inference...


BertSumExtractor(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementw

In [None]:
def summarize_article(article, model, tokenizer, top_n=3, max_length=256):
    """Build an extractive summary from the highest-scoring sentences of an article.

    Args:
        article: Raw article text; it is split into sentences with sent_tokenize.
        model: Module returning ``(logits, loss)`` with logits shaped
            (1, num_sentences, num_labels); class index 1 is used as the
            "include in summary" score.
        tokenizer: Callable tokenizer applied to the list of sentences.
        top_n: Maximum number of sentences to return.
        max_length: Token limit per sentence. Defaults to 256, the same length
            ExtractiveDataset uses by default for training inputs.

    Returns:
        The selected sentences joined by single spaces in their original
        article order, or "" when the article has no sentences or
        ``top_n`` is not positive.
    """
    sentences = sent_tokenize(article)
    # top_n <= 0 must short-circuit: a [-0:] slice would select every sentence.
    if len(sentences) == 0 or top_n <= 0:
        return ""

    # Tokenize all sentences in one call; padding to the longest sentence keeps
    # the single forward pass as small as possible.
    encoded = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )

    # Run on whatever device the model lives on instead of relying on a global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    # Add the leading batch dimension expected by the model: (1, sentences, tokens).
    input_ids = encoded['input_ids'].unsqueeze(0).to(target_device)
    attention_mask = encoded['attention_mask'].unsqueeze(0).to(target_device)

    with torch.no_grad():
        logits, _ = model(input_ids=input_ids, attention_mask=attention_mask)

    sentence_scores = torch.softmax(logits, dim=-1)[0, :, 1].cpu().numpy()

    top_n = min(top_n, len(sentences))
    # Take the best-scoring sentences, then restore article order for readability.
    top_indices = sorted(np.argsort(sentence_scores)[-top_n:])

    return ' '.join(sentences[i] for i in top_indices)


def evaluate_model(model, tokenizer, test_df, sample_size=100, random_state=42,
                   output_path='bertsum_results.csv'):
    """Evaluate model on test set and calculate ROUGE scores.

    Args:
        model: Trained model passed through to summarize_article.
        tokenizer: Tokenizer passed through to summarize_article.
        test_df: DataFrame with 'article' and 'highlights' columns.
        sample_size: Maximum number of rows to sample from ``test_df``.
        random_state: Seed for the row sampling so repeated runs score the
            same articles.
        output_path: CSV file that receives the article, reference highlights
            and generated summary of every sampled row.

    Returns:
        A new DataFrame holding the sampled rows plus a 'generated_summary'
        column; ``test_df`` itself is not modified.
    """
    print("Evaluating the model on test set...")
    summaries = []
    rouge_scores = []
    rouge_calculator = Rouge()

    test_sample = test_df.sample(min(sample_size, len(test_df)), random_state=random_state)

    for idx, row in tqdm(test_sample.iterrows(), total=len(test_sample), desc="Generating summaries"):
        article = row['article']
        original_summary = row['highlights']

        # Generation and scoring fail independently, and exactly one summary
        # and one score are recorded per row. This keeps `summaries` aligned
        # with `test_sample` even when only the ROUGE step raises.
        try:
            generated_summary = summarize_article(article, model, tokenizer)
        except Exception as e:
            print(f"Error processing article {idx}: {str(e)}")
            generated_summary = ""

        score = {'rouge-1': {'f': 0}, 'rouge-2': {'f': 0}, 'rouge-l': {'f': 0}}
        if generated_summary:
            try:
                score = rouge_calculator.get_scores(generated_summary, original_summary)[0]
            except Exception as e:
                print(f"Error scoring article {idx}: {str(e)}")

        summaries.append(generated_summary)
        rouge_scores.append(score)

    rouge_1_f = np.mean([score['rouge-1']['f'] for score in rouge_scores])
    rouge_2_f = np.mean([score['rouge-2']['f'] for score in rouge_scores])
    rouge_l_f = np.mean([score['rouge-l']['f'] for score in rouge_scores])

    print(f"\nAverage ROUGE-1 F1: {rouge_1_f:.4f}")
    print(f"Average ROUGE-2 F1: {rouge_2_f:.4f}")
    print(f"Average ROUGE-L F1: {rouge_l_f:.4f}")

    # assign() returns a new frame, so neither the sample nor test_df is mutated in place.
    test_sample = test_sample.assign(generated_summary=summaries)
    test_sample[['article', 'highlights', 'generated_summary']].to_csv(output_path, index=False)
    print(f"Results saved to '{output_path}'")

    return test_sample




In [None]:
# Score the reloaded model on up to 100 rows sampled from df_test;
# evaluate_model also writes the results to 'bertsum_results.csv'.
test_results = evaluate_model(model, tokenizer, df_test, sample_size=100)

Evaluating the model on test set...


Generating summaries: 100%|██████████| 100/100 [01:22<00:00,  1.21it/s]


Average ROUGE-1 F1: 0.2673
Average ROUGE-2 F1: 0.0959
Average ROUGE-L F1: 0.2465
Results saved to 'bertsum_results.csv'





In [None]:

# Summarize and score every row of df_test (no sampling), then store the
# summaries on df_test and print a few examples.
summaries = []
rouge_scores = []
rouge_calculator = Rouge()

for idx, row in tqdm(df_test.iterrows(), total=len(df_test), desc="Generating summaries"):
    article = row['article']
    original_summary = row['highlights']

    # Exactly one summary and one score are recorded per article, so
    # `summaries` always lines up with df_test even when only ROUGE fails.
    score = {'rouge-1': {'f': 0}, 'rouge-2': {'f': 0}, 'rouge-l': {'f': 0}}
    try:
        generated_summary = summarize_article(article, model, tokenizer)
    except Exception as e:
        print(f"Error processing article {idx}: {str(e)}")
        generated_summary = "[SUMMARY GENERATION FAILED]"  # Placeholder for failed cases
    else:
        try:
            score = rouge_calculator.get_scores(generated_summary, original_summary)[0]
        except Exception as e:
            print(f"Error processing article {idx}: {str(e)}")

    summaries.append(generated_summary)
    rouge_scores.append(score)

# NOTE: this adds a column to df_test in place and overwrites the CSV written
# by evaluate_model.
df_test['bertsum_summary'] = summaries
df_test[['article', 'highlights', 'bertsum_summary']].to_csv('bertsum_results.csv', index=False)
print("Results saved to 'bertsum_results.csv'")

print("\n===== Sample Summaries =====")
for i in range(min(5, len(df_test))):
    sample_row = df_test.iloc[i]
    print(f"\nOriginal Article (first 100 chars): {sample_row['article'][:100]}...")
    print(f"\nOriginal Summary: {sample_row['highlights']}")
    print(f"\nGenerated Summary: {sample_row['bertsum_summary']}")
    print("\n" + "-"*50)

if len(rouge_scores) > 0:
    rouge_1_f = np.mean([score['rouge-1']['f'] for score in rouge_scores])
    rouge_2_f = np.mean([score['rouge-2']['f'] for score in rouge_scores])
    rouge_l_f = np.mean([score['rouge-l']['f'] for score in rouge_scores])

    print(f"\nAverage ROUGE-1 F1: {rouge_1_f:.4f}")
    print(f"Average ROUGE-2 F1: {rouge_2_f:.4f}")
    print(f"Average ROUGE-L F1: {rouge_l_f:.4f}")

print("\nComplete! Evaluation finished.")

Generating summaries: 100%|██████████| 100/100 [01:22<00:00,  1.21it/s]

Results saved to 'bertsum_results.csv'

===== Sample Summaries =====

Original Article (first 100 chars): (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Cour...

Original Summary: Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .
Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .

Generated Summary: (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead.

------------------------------------


