In [1]:
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import (
    BertTokenizerFast, 
    EncoderDecoderModel,
    BertModel,
    BertConfig,
    DataCollatorWithPadding,
    get_linear_schedule_with_warmup
)
from sklearn.model_selection import train_test_split
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
import evaluate
from bert_score import BERTScorer
import numpy as np

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [3]:
dataset = load_dataset("ccdv/cnn_dailymail", "3.0.0")
train_data = dataset['train'].shuffle(seed=42)
val_data   = dataset['validation'].shuffle(seed=42)
test_data  = dataset['test'].shuffle(seed=42)

In [4]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

special_tokens = {
    'pad_token': '[PAD]',
    'bos_token': '[CLS]',
    'eos_token': '[SEP]'
}
tokenizer.add_special_tokens(special_tokens)

0

In [5]:
config_encoder = BertConfig.from_pretrained('bert-base-uncased')
config_decoder = BertConfig.from_pretrained('bert-base-uncased')
config_decoder.is_decoder = True
config_decoder.add_cross_attention = True

In [6]:
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    'bert-base-uncased', 
    'bert-base-uncased',
    encoder_config=config_encoder,
    decoder_config=config_decoder
)

model.encoder.resize_token_embeddings(len(tokenizer))
model.decoder.resize_token_embeddings(len(tokenizer))

model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = len(tokenizer)

model.config.max_length = 128
model.config.min_length = 10
model.config.no_repeat_ngram_size = 3
model.config.early_stopping = True
model.config.num_beams = 4

Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.e

In [7]:
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["query", "value", "key", "dense"]
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
model.to(device)

trainable params: 6,537,216 || all params: 253,900,602 || trainable%: 2.5747


PeftModelForSeq2SeqLM(
  (base_model): LoraModel(
    (model): EncoderDecoderModel(
      (encoder): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSdpaSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=768, out_features=768, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_feature

In [8]:
class SummarizationDataset(Dataset):
    def __init__(self, data, tokenizer, max_input_length=512, max_target_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        article = self.data[idx]['article']
        highlights = self.data[idx]['highlights']
        
        inputs = self.tokenizer(
            article,
            max_length=self.max_input_length,
            padding=False,
            truncation=True,
            return_tensors="pt"
        )
        
        targets = self.tokenizer(
            highlights,
            max_length=self.max_target_length,
            padding=False,
            truncation=True,
            return_tensors="pt"
        )
        
        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'labels': targets['input_ids'].squeeze()
        }

In [9]:
def collate_fn(batch):
    input_ids = [item['input_ids'] for item in batch]
    attention_masks = [item['attention_mask'] for item in batch]
    labels = [item['labels'] for item in batch]
    
    # Pad sequences
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_masks = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=0)
    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
    
    return {
        'input_ids': input_ids,
        'attention_mask': attention_masks,
        'labels': labels
    }

In [10]:
batch_size = 8
train_dataset = SummarizationDataset(train_data, tokenizer)
val_dataset   = SummarizationDataset(val_data, tokenizer)
test_dataset  = SummarizationDataset(test_data, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [11]:
optimizer = AdamW(model.parameters(), lr=5e-4)
epochs = 3
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=total_steps // 10, 
    num_training_steps=total_steps
)

rouge = evaluate.load("rouge")
bert_scorer = BERTScorer(lang="en", rescale_with_baseline=True)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
def generate_summaries(dataloader, num_samples=100):
    model.eval()
    generated_summaries = []
    reference_summaries = []
    
    with torch.no_grad():
        sample_count = 0
        for batch in dataloader:
            if sample_count >= num_samples:
                break
                
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            # Generate summaries
            generated_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=128,
                min_length=10,
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=3,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
            
            # Decode generated and reference summaries
            for i in range(len(generated_ids)):
                if sample_count >= num_samples:
                    break
                    
                generated_text = tokenizer.decode(generated_ids[i], skip_special_tokens=True)
                reference_text = tokenizer.decode(batch['labels'][i][batch['labels'][i] != -100], skip_special_tokens=True)
                
                generated_summaries.append(generated_text)
                reference_summaries.append(reference_text)
                sample_count += 1
    
    return generated_summaries, reference_summaries

In [13]:
def evaluate_model(dataloader, dataset_name="Validation"):
    print(f"\n=== {dataset_name} Evaluation ===")
    generated_summaries, reference_summaries = generate_summaries(dataloader, num_samples=100)
    
    rouge_results = rouge.compute(
        predictions=generated_summaries,
        references=reference_summaries,
        use_stemmer=True
    )
    
    P, R, F1 = bert_scorer.score(generated_summaries, reference_summaries)
    bert_score = {
        'precision': P.mean().item(),
        'recall': R.mean().item(),
        'f1': F1.mean().item()
    }
    
    print(f"ROUGE-1: {rouge_results['rouge1']:.4f}")
    print(f"ROUGE-2: {rouge_results['rouge2']:.4f}")
    print(f"ROUGE-L: {rouge_results['rougeL']:.4f}")
    print(f"BERTScore F1: {bert_score['f1']:.4f}")
    print(f"BERTScore Precision: {bert_score['precision']:.4f}")
    print(f"BERTScore Recall: {bert_score['recall']:.4f}")
    
    return rouge_results, bert_score

In [14]:
def train_epoch(dataloader, epoch):
    model.train()
    total_loss = 0.0
    
    for step, batch in enumerate(dataloader, 1):
        optimizer.zero_grad()
        
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        loss = outputs.loss
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
        
        if step % 50 == 0:
            print(f"Epoch {epoch} Step {step}/{len(dataloader)} - Loss: {total_loss/step:.4f}")
    
    return total_loss / len(dataloader)

In [15]:
def eval_epoch(dataloader):
    model.eval()
    total_loss = 0.0
    
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = outputs.loss
            total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [None]:
for epoch in range(1, epochs + 1):
    print(f"\n=== Epoch {epoch}/{epochs} ===")
    train_loss = train_epoch(train_loader, epoch)
    print(f"Train Loss: {train_loss:.4f}")
    
    val_loss = eval_epoch(val_loader)
    print(f"Val Loss: {val_loss:.4f}")
    evaluate_model(val_loader, "Validation")

Epoch 1 Step 2000/287113 — Loss: 10.4970
Epoch 1 Step 4000/287113 — Loss: 10.4915
Epoch 1 Step 6000/287113 — Loss: 10.4860
Epoch 1 Step 8000/287113 — Loss: 10.4806
Epoch 1 Step 10000/287113 — Loss: 10.4751
Epoch 1 Step 12000/287113 — Loss: 10.4696
Epoch 1 Step 14000/287113 — Loss: 10.4641
Epoch 1 Step 16000/287113 — Loss: 10.4586
Epoch 1 Step 18000/287113 — Loss: 10.4531
Epoch 1 Step 20000/287113 — Loss: 10.4476
Epoch 1 Step 22000/287113 — Loss: 10.4422
Epoch 1 Step 24000/287113 — Loss: 10.4367
Epoch 1 Step 26000/287113 — Loss: 10.4312
Epoch 1 Step 28000/287113 — Loss: 10.4257
Epoch 1 Step 30000/287113 — Loss: 10.4202
Epoch 1 Step 32000/287113 — Loss: 10.4147
Epoch 1 Step 34000/287113 — Loss: 10.4092
Epoch 1 Step 36000/287113 — Loss: 10.4038
Epoch 1 Step 38000/287113 — Loss: 10.3983
Epoch 1 Step 40000/287113 — Loss: 10.3928
Epoch 1 Step 42000/287113 — Loss: 10.3873
Epoch 1 Step 44000/287113 — Loss: 10.3818
Epoch 1 Step 46000/287113 — Loss: 10.3763
Epoch 1 Step 48000/287113 — Loss: 10.3

In [None]:
test_loss = eval_epoch(test_loader)
print(f"Test Loss: {test_loss:.4f}")
rouge_results, bert_score = evaluate_model(test_loader, "Test")

Test Loss: 5.2265
=== Test Evaluation ===
ROUGE-1: 0.2688
ROUGE-2: 0.1269
ROUGE-L: 0.2452
BERTScore F1: 0.5623
BERTScore Precision: 0.5514
BERTScore Recall: 0.5568
