In [1]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from tqdm import tqdm
import logging
from rouge_score import rouge_scorer
import random

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Download NLTK resources
nltk.download('punkt')
nltk.download('punkt_tab')

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f'Using device: {device}')

# Function to load CSV with proper encoding
def load_csv(file_path, delimiter_symbol):
    try:
        return pd.read_csv(file_path, delimiter=delimiter_symbol, engine='python', encoding='iso-8859-1')
    except UnicodeDecodeError:
        logger.warning("ISO-8859-1 encoding failed, trying latin1")
        return pd.read_csv(file_path, delimiter=delimiter_symbol, engine='python', encoding='latin1')

# Build vocabulary for extractive model
def build_vocab(csv_file, min_freq=2):
    data = load_csv(csv_file, ";;;;;;")
    texts = data['content'].astype(str).tolist()
    word_counts = Counter()
    for text in tqdm(texts, desc="Building Extractive Vocabulary"):
        words = word_tokenize(text.lower())
        word_counts.update(words)
    vocab = {'<PAD>': 0, '<UNK>': 1}
    idx = len(vocab)
    for word, count in word_counts.items():
        if count >= min_freq:
            vocab[word] = idx
            idx += 1
    return vocab

# Custom Dataset for extractive model
class SummaryDataset(Dataset):
    def __init__(self, csv_file, vocab, max_length=40, max_sentences=7000):
        self.data = load_csv(csv_file, ";;;;;;")
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.max_length = max_length
        self.max_sentences = max_sentences
        self.chapters = []
        self.chapter_data = []
        
        # Group data by chapter
        grouped = self.data.groupby('chapter')
        for chapter_name, group in grouped:
            # Sort by sentence ID to ensure correct order
            group = group.sort_values('sentence', ascending=True)
            texts = []
            labels = []
            sentence_ids = []
            tokenized_indices = []
            
            for idx, row in group.iterrows():
                text = str(row['content'])
                label = row['in_summary']
                sentence_id = row['sentence']
                try:
                    label_int = int(label)
                    if label_int not in [0, 1]:
                        logger.warning(f"Skipping row {idx}: Invalid label value {label}")
                        continue
                    sentence_id_int = int(sentence_id)
                except (ValueError, TypeError):
                    logger.warning(f"Skipping row {idx}: Invalid label {label} or sentence ID {sentence_id}")
                    continue
                
                # Pre-tokenize
                words = word_tokenize(text.lower())
                indices = [self.vocab.get(word, self.vocab['<UNK>']) for word in words]
                if len(indices) > max_length:
                    indices = indices[:max_length]
                else:
                    indices += [self.vocab['<PAD>']] * (max_length - len(indices))
                
                # Validate indices
                indices = [max(0, min(idx, self.vocab_size - 1)) for idx in indices]
                if any(idx < 0 or idx >= self.vocab_size for idx in indices):
                    logger.warning(f"Invalid indices in text at index {idx}: {indices}")
                    indices = [self.vocab['<PAD>']] * max_length
                
                texts.append(text)
                labels.append(label_int)
                sentence_ids.append(sentence_id_int)
                tokenized_indices.append(indices)
            
            if texts:
                # Truncate at initialization to avoid oversized chapters
                if len(tokenized_indices) > max_sentences:
                    logger.warning(f"Chapter {chapter_name} has {len(tokenized_indices)} sentences, truncating to {max_sentences}")
                    tokenized_indices = tokenized_indices[:max_sentences]
                    labels = labels[:max_sentences]
                    texts = texts[:max_sentences]
                    sentence_ids = sentence_ids[:max_sentences]
                self.chapters.append(chapter_name)
                self.chapter_data.append({
                    'texts': texts,
                    'labels': labels,
                    'sentence_ids': sentence_ids,
                    'tokenized_indices': tokenized_indices
                })
        
        if not self.chapters:
            raise ValueError("No valid chapters found after cleaning")

    def __len__(self):
        return len(self.chapters)

    def __getitem__(self, idx):
        chapter = self.chapter_data[idx]
        tokenized_indices = chapter['tokenized_indices']
        labels = chapter['labels']
        num_sentences = len(tokenized_indices)  # Number of sentences after init truncation
        
        # Convert to tensors without padding to max_sentences
        input_ids = torch.tensor(tokenized_indices, dtype=torch.long)  # Shape: [num_sentences, max_length]
        labels = torch.tensor(labels, dtype=torch.long)               # Shape: [num_sentences]
        
        # Verify tensor sizes
        assert num_sentences <= self.max_sentences, \
            f"num_sentences {num_sentences} > max_sentences {self.max_sentences} for chapter {self.chapters[idx]}"
        assert input_ids.size() == (num_sentences, self.max_length), \
            f"input_ids size {input_ids.size()} != [{num_sentences}, {self.max_length}] for chapter {self.chapters[idx]}"
        assert labels.size() == (num_sentences,), \
            f"labels size {labels.size()} != [{num_sentences}] for chapter {self.chapters[idx]}"
        
        logger.debug(f"Chapter {self.chapters[idx]}: {num_sentences} sentences, input_ids shape {input_ids.size()}")
        
        return {
            'input_ids': input_ids,  # Shape: [num_sentences, max_length]
            'labels': labels,        # Shape: [num_sentences]
            'num_sentences': num_sentences  # Actual number of sentences
        }

# Extractive Model Definition
class SummaryLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim=100, hidden_dim=64, num_layers=1):
        super(SummaryLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, 2)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, num_sentences):
        # x shape: [batch_size, max_sentences, max_length]
        batch_size = x.size(0)
        outputs = []
        
        for i in range(batch_size):
            # Process sentences for one chapter
            chapter_input = x[i, :num_sentences[i]]  # Shape: [num_sentences, max_length]
            embedded = self.embedding(chapter_input)  # Shape: [num_sentences, max_length, embedding_dim]
            # Average embeddings over words to get sentence embeddings
            sentence_embeds = embedded.mean(dim=1)    # Shape: [num_sentences, embedding_dim]
            lstm_out, (hidden, _) = self.lstm(sentence_embeds.unsqueeze(0))  # Shape: [1, num_sentences, hidden_dim*2]
            lstm_out = lstm_out.squeeze(0)  # Shape: [num_sentences, hidden_dim*2]
            lstm_out = self.dropout(lstm_out)
            output = self.fc(lstm_out)  # Shape: [num_sentences, 2]
            # Pad output to max_sentences
            if num_sentences[i] < x.size(1):
                pad_size = x.size(1) - num_sentences[i]
                output = torch.cat([
                    output,
                    torch.zeros(pad_size, 2, device=output.device)
                ], dim=0)
            outputs.append(output)
        
        return torch.stack(outputs)  # Shape: [batch_size, max_sentences, 2]

# Function to generate extractive summary
def generate_summary(chapter_sentences, vocab, model, device, max_length=40, max_sentences=7000, target_ratio=0.15):
    model.eval()
    scores = []
    chapter_words = sum(len(word_tokenize(sentence)) for sentence in chapter_sentences)
    target_word_count = int(chapter_words * target_ratio)
    tokenized_indices = []
    for sentence in chapter_sentences:
        words = word_tokenize(sentence.lower())
        indices = [vocab.get(word, vocab['<UNK>']) for word in words]
        if len(indices) > max_length:
            indices = indices[:max_length]
        else:
            indices += [vocab['<PAD>']] * (max_length - len(indices))
        indices = [max(0, min(idx, len(vocab) - 1)) for idx in indices]
        tokenized_indices.append(indices)
    num_sentences = len(tokenized_indices)
    if num_sentences > max_sentences:
        tokenized_indices = tokenized_indices[:max_sentences]
        chapter_sentences = chapter_sentences[:max_sentences]
        num_sentences = max_sentences
    elif num_sentences < max_sentences:
        tokenized_indices += [[vocab['<PAD>']] * max_length] * (max_sentences - num_sentences)
    input_ids = torch.tensor([tokenized_indices], dtype=torch.long).to(device)
    with torch.no_grad():
        outputs = model(input_ids, torch.tensor([num_sentences], device=device))
        scores = torch.softmax(outputs[0], dim=1)[:, 1].cpu().numpy()
    sentence_scores = [(sentence, score, len(word_tokenize(sentence)), idx) 
                      for idx, (sentence, score) in enumerate(zip(chapter_sentences, scores[:num_sentences]))]
    sentence_scores.sort(key=lambda x: x[1], reverse=True)
    selected = []
    current_word_count = 0
    for sentence, score, word_count, sentence_id in sentence_scores:
        if current_word_count + word_count <= target_word_count or len(selected) < 1:
            selected.append((sentence, sentence_id))
            current_word_count += word_count
        else:
            break
    selected.sort(key=lambda x: x[1])
    selected_sentences = [sentence for sentence, _ in selected]
    return ' '.join(selected_sentences)

# Training and evaluation functions for extractive model
def train_epoch(model, data_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for batch in tqdm(data_loader, desc="Training Extractive Batches"):
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        num_sentences = batch['num_sentences'].to(device)
        outputs = model(input_ids, num_sentences)
        outputs = outputs.view(-1, 2)
        labels = labels.view(-1)
        loss = criterion(outputs, labels)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return total_loss / len(data_loader)

def eval_model(model, data_loader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            num_sentences = batch['num_sentences'].to(device)
            outputs = model(input_ids, num_sentences)
            outputs = outputs.view(-1, 2)
            labels = labels.view(-1)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
    return total_loss / len(data_loader)

# Main execution
# Step 1: Train extractive model
vocab = build_vocab('Dataset/train_GAlabelled.csv')
train_dataset = SummaryDataset('Dataset/train_GAlabelled.csv', vocab)
val_dataset = SummaryDataset('Dataset/val_GAlabelled.csv', vocab)

# Custom collate function to handle variable-length chapters
def custom_collate_fn(batch):
    max_sentences = 7000  # Same as max_sentences in SummaryDataset
    max_length = 40     # Same as max_length in SummaryDataset

    # Find the maximum number of sentences in this batch (capped at max_sentences)
    batch_max_sentences = min(max(item['num_sentences'] for item in batch), max_sentences)
    logger.debug(f"Batch max sentences: {batch_max_sentences}")

    input_ids_list = []
    labels_list = []
    num_sentences_list = []

    for item in batch:
        input_ids = item['input_ids']  # Shape: [num_sentences, max_length]
        labels = item['labels']        # Shape: [num_sentences]
        num_sentences = item['num_sentences']

        # Truncate if necessary
        if num_sentences > batch_max_sentences:
            logger.debug(f"Truncating chapter from {num_sentences} to {batch_max_sentences} sentences")
            input_ids = input_ids[:batch_max_sentences]
            labels = labels[:batch_max_sentences]
            num_sentences = batch_max_sentences
        # Pad if necessary
        elif num_sentences < batch_max_sentences:
            pad_size = batch_max_sentences - num_sentences
            logger.debug(f"Padding chapter from {num_sentences} to {batch_max_sentences} sentences")
            input_ids = torch.cat([
                input_ids,
                torch.zeros(pad_size, max_length, dtype=torch.long)
            ], dim=0)
            labels = torch.cat([
                labels,
                torch.full((pad_size,), -1, dtype=torch.long)
            ], dim=0)

        # Verify tensor size before stacking
        assert input_ids.size() == (batch_max_sentences, max_length), \
            f"input_ids size {input_ids.size()} != [{batch_max_sentences}, {max_length}]"
        assert labels.size() == (batch_max_sentences,), \
            f"labels size {labels.size()} != [{batch_max_sentences}]"

        input_ids_list.append(input_ids)
        labels_list.append(labels)
        num_sentences_list.append(num_sentences)

    # Stack into batch tensors
    input_ids_batch = torch.stack(input_ids_list)  # Shape: [batch_size, batch_max_sentences, max_length]
    labels_batch = torch.stack(labels_list)        # Shape: [batch_size, batch_max_sentences]
    num_sentences_batch = torch.tensor(num_sentences_list, dtype=torch.long)

    logger.debug(f"Batch shapes: input_ids {input_ids_batch.size()}, labels {labels_batch.size()}")

    return {
        'input_ids': input_ids_batch,
        'labels': labels_batch,
        'num_sentences': num_sentences_batch
    }

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=custom_collate_fn)

model = SummaryLSTM(vocab_size=len(vocab), embedding_dim=100, hidden_dim=64, num_layers=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=-1)
patience = 2
best_val_loss = float('inf')
patience_counter = 0
max_epochs = 100

for epoch in range(max_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss = eval_model(model, val_loader, criterion, device)
    logger.info(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pt')
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            logger.info("Early stopping triggered")
            break

[nltk_data] Downloading package punkt to C:\Users\Viet-Dung
[nltk_data]     Nguyen\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to C:\Users\Viet-Dung
[nltk_data]     Nguyen\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
INFO:__main__:Using device: cuda
Building Extractive Vocabulary: 100%|███████| 2308987/2308987 [02:26<00:00, 15807.58it/s]
Training Extractive Batches: 100%|███████████████████| 1200/1200 [01:23<00:00, 14.45it/s]
INFO:__main__:Epoch 1, Train Loss: 0.3213, Val Loss: 0.3393
Training Extractive Batches: 100%|███████████████████| 1200/1200 [01:24<00:00, 14.13it/s]
INFO:__main__:Epoch 2, Train Loss: 0.2974, Val Loss: 0.3285
Training Extractive Batches: 100%|███████████████████| 1200/1200 [01:24<00:00, 14.28it/s]
INFO:__main__:Epoch 3, Train Loss: 0.2876, Val Loss: 0.3318
Training Extractive Batches: 100%|███████████████████| 1200/1200 [01:26<00:00, 13.89it/s]
INF

ParserError: field larger than field limit (131072)

In [6]:
# Build vocabulary for abstractive model
def build_vocab_abstractive(texts, min_freq=2):
    word_counts = Counter()
    for text in tqdm(texts, desc="Building Abstractive Vocabulary"):
        tokens = word_tokenize(text.lower())
        word_counts.update(tokens)
    vocab = {'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3}
    idx = len(vocab)
    for word, count in word_counts.items():
        if count >= min_freq:
            vocab[word] = idx
            idx += 1
    return vocab, {v: k for k, v in vocab.items()}

# Custom Dataset for abstractive model
class AbstractiveDataset(Dataset):
    def __init__(self, pairs, vocab, max_input_len=400, max_target_len=100):
        self.pairs = pairs
        self.vocab = vocab
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        extractive, summary = self.pairs[idx]
        input_ids = [self.vocab.get(token, self.vocab['<UNK>']) for token in word_tokenize(extractive.lower())]
        target_ids = [self.vocab.get(token, self.vocab['<UNK>']) for token in word_tokenize(summary.lower())]
        input_ids = [self.vocab['<SOS>']] + input_ids + [self.vocab['<EOS>']]
        target_ids = [self.vocab['<SOS>']] + target_ids + [self.vocab['<EOS>']]
        input_ids = input_ids[:self.max_input_len] + [self.vocab['<PAD>']] * (self.max_input_len - len(input_ids))
        target_ids = target_ids[:self.max_target_len] + [self.vocab['<PAD>']] * (self.max_target_len - len(target_ids))
        return {'input_ids': torch.tensor(input_ids), 'target_ids': torch.tensor(target_ids)}

# Seq2Seq Model for abstractive summarization
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        embedded = self.embedding(x)
        _, (hidden, cell) = self.lstm(embedded)
        return hidden, cell

class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden, cell):
        embedded = self.embedding(x)
        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        return self.fc(output), hidden, cell

class Seq2Seq(nn.Module):
    def __init__(self, vocab_size, embedding_dim=100, hidden_dim=128):
        super(Seq2Seq, self).__init__()
        self.encoder = Encoder(vocab_size, embedding_dim, hidden_dim)
        self.decoder = Decoder(vocab_size, embedding_dim, hidden_dim)

    def forward(self, source, decoder_input):
        batch_size = source.size(0)
        seq_len = decoder_input.size(1)
        vocab_size = self.decoder.fc.out_features
        outputs = torch.zeros(batch_size, seq_len, vocab_size).to(source.device)
        hidden, cell = self.encoder(source)
        for t in range(seq_len):
            input = decoder_input[:, t].unsqueeze(1)
            output, hidden, cell = self.decoder(input, hidden, cell)
            outputs[:, t] = output.squeeze(1)
        return outputs

    def generate_summary(self, input_ids, vocab, idx2word, max_length=100):
        self.eval()
        with torch.no_grad():
            hidden, cell = self.encoder(input_ids.unsqueeze(0))
            decoder_input = torch.tensor([[vocab['<SOS>']]], device=input_ids.device)
            outputs = []
            for _ in range(max_length):
                output, hidden, cell = self.decoder(decoder_input, hidden, cell)
                top_idx = output.argmax(2).item()
                outputs.append(top_idx)
                if top_idx == vocab['<EOS>']:
                    break
                decoder_input = torch.tensor([[top_idx]], device=input_ids.device)
            return ' '.join([idx2word.get(idx, '<UNK>') for idx in outputs if idx != vocab['<EOS>']])

# Training function for abstractive model
def train_abstractive_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    for batch in tqdm(loader, desc="Training Abstractive"):
        source = batch['input_ids'].to(device)
        target_ids = batch['target_ids'].to(device)
        decoder_input = target_ids[:, :-1]
        decoder_target = target_ids[:, 1:]
        outputs = model(source, decoder_input)
        loss = criterion(
            outputs.contiguous().view(-1, outputs.size(-1)),
            decoder_target.contiguous().view(-1)
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

# Generate extractive summaries for abstractive training
def generate_extractive_summaries(data_df, model, vocab, device):
    summaries = []
    for idx, row in tqdm(data_df.iterrows(), total=len(data_df), desc="Generating Extractive Summaries"):
        chapter_text = row['chapter']
        sentences = sent_tokenize(chapter_text)
        extractive_summary = generate_summary(sentences, vocab, model, device)
        summaries.append((idx, extractive_summary, row['summary_text']))
    return summaries

# Function to extract summaries from labelled CSV and align with reference summaries
def extract_summaries_with_reference(labelled_csv, reference_csv, delimiter=";;;;;;"):
    labelled_data = load_csv(labelled_csv, delimiter)
    reference_data = pd.read_csv(reference_csv)
    
    summaries = []
    grouped = labelled_data.groupby('chapter')
    labelled_chapters = list(grouped.groups.keys())
    
    # Ensure reference data is aligned by chapter order
    reference_data['chapter'] = reference_data['chapter'].astype(str)
    reference_chapters = reference_data['chapter'].tolist()
    
    # Verify chapter order alignment
    if len(labelled_chapters) != len(reference_chapters):
        logger.warning(f"Chapter count mismatch: {len(labelled_chapters)} in labelled, {len(reference_chapters)} in reference")
    
    for chapter_name, group in grouped:
        group = group.sort_values('sentence', ascending=True)
        extractive_summary = ' '.join(
            str(row['content']) for _, row in group.iterrows() if int(row['in_summary']) == 1
        )
        # Find corresponding reference summary
        try:
            ref_idx = labelled_chapters.index(chapter_name)
            ref_row = reference_data.iloc[ref_idx]
            reference_summary = str(ref_row['summary_text'])
        except (ValueError, IndexError):
            logger.warning(f"No reference summary found for chapter {chapter_name}")
            reference_summary = ''
        summaries.append((chapter_name, extractive_summary, reference_summary))
    return summaries

# Step 2: Prepare abstractive data
train_data = extract_summaries_with_reference('Dataset/train_GAlabelled.csv', 'Dataset/train.csv')
val_data = extract_summaries_with_reference('Dataset/val_GAlabelled.csv', 'Dataset/dev.csv')
test_data = pd.read_csv('Dataset/test.csv')
test_abstractive_data = generate_extractive_summaries(test_data, model, vocab, device)

# Step 3: Prepare abstractive dataset
train_pairs = [(extractive, summary) for _, extractive, summary in train_data if extractive and summary]
val_pairs = [(extractive, summary) for _, extractive, summary in val_data if extractive and summary]
test_pairs = [(extractive, summary) for _, extractive, summary in test_abstractive_data]
all_texts = [text for pair in (train_pairs + val_pairs + test_pairs) for text in pair if text]
vocab_abstractive, idx2word = build_vocab_abstractive(all_texts)
train_dataset_abstractive = AbstractiveDataset(train_pairs, vocab_abstractive)
val_dataset_abstractive = AbstractiveDataset(val_pairs, vocab_abstractive)
test_dataset_abstractive = AbstractiveDataset(test_pairs, vocab_abstractive)
train_loader_abstractive = DataLoader(train_dataset_abstractive, batch_size=8, shuffle=True)
val_loader_abstractive = DataLoader(val_dataset_abstractive, batch_size=8)
test_loader_abstractive = DataLoader(test_dataset_abstractive, batch_size=8)

Generating Extractive Summaries: 100%|███████████████| 1431/1431 [01:44<00:00, 13.68it/s]
Building Abstractive Vocabulary: 100%|████████████| 25026/25026 [00:38<00:00, 655.48it/s]
Training Abstractive:  10%|██▌                        | 114/1200 [01:19<12:35,  1.44it/s]


KeyboardInterrupt: 

In [7]:
# Step 4: Train abstractive model
model_abstractive = Seq2Seq(len(vocab_abstractive)).to(device)
criterion_abstractive = nn.CrossEntropyLoss(ignore_index=vocab_abstractive['<PAD>'])
optimizer_abstractive = torch.optim.Adam(model_abstractive.parameters(), lr=0.001)
patience = 3
best_val_loss = float('inf')
patience_counter = 0
max_epochs_abstractive = 10

for epoch in range(max_epochs_abstractive):
    train_loss = train_abstractive_epoch(model_abstractive, train_loader_abstractive, criterion_abstractive, optimizer_abstractive, device)
    # Evaluate validation loss
    model_abstractive.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader_abstractive:
            source = batch['input_ids'].to(device)
            target_ids = batch['target_ids'].to(device)
            decoder_input = target_ids[:, :-1]
            decoder_target = target_ids[:, 1:]
            outputs = model_abstractive(source, decoder_input)
            loss = criterion_abstractive(
                outputs.contiguous().view(-1, outputs.size(-1)),
                decoder_target.contiguous().view(-1)
            )
            val_loss += loss.item()
    val_loss /= len(val_loader_abstractive)
    logger.info(f"Abstractive Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model_abstractive.state_dict(), 'abstractive_model.pt')
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            logger.info("Early stopping triggered for abstractive model")
            break

Training Abstractive: 100%|██████████████████████████| 1200/1200 [18:57<00:00,  1.06it/s]
INFO:__main__:Abstractive Epoch 1, Train Loss: 6.2602, Val Loss: 6.1520
Training Abstractive: 100%|██████████████████████████| 1200/1200 [15:09<00:00,  1.32it/s]
INFO:__main__:Abstractive Epoch 2, Train Loss: 5.4900, Val Loss: 6.0022
Training Abstractive: 100%|██████████████████████████| 1200/1200 [16:39<00:00,  1.20it/s]
INFO:__main__:Abstractive Epoch 3, Train Loss: 5.2086, Val Loss: 5.9417
Training Abstractive: 100%|██████████████████████████| 1200/1200 [16:17<00:00,  1.23it/s]
INFO:__main__:Abstractive Epoch 4, Train Loss: 5.0183, Val Loss: 5.9322
Training Abstractive: 100%|██████████████████████████| 1200/1200 [16:17<00:00,  1.23it/s]
INFO:__main__:Abstractive Epoch 5, Train Loss: 4.8697, Val Loss: 5.9499
Training Abstractive: 100%|██████████████████████████| 1200/1200 [17:51<00:00,  1.12it/s]
INFO:__main__:Abstractive Epoch 6, Train Loss: 4.7437, Val Loss: 5.9869
Training Abstractive: 100%|█

In [13]:
# Step 5: Test and evaluate
rouge_scorer_instance = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
rouge_scores = []
output_file = 'test_abstractive_summaries.txt'

with open(output_file, 'w', encoding='utf-8') as f:
    f.write("Test Summaries (Extractive and Abstractive)\n")
    f.write("=" * 50 + "\n\n")

    for chapter_id, extractive, reference in tqdm(test_abstractive_data, desc="Evaluating Test Summaries"):
        input_ids = [vocab_abstractive.get(token, vocab_abstractive['<UNK>']) for token in word_tokenize(extractive.lower())]
        input_ids = [vocab_abstractive['<SOS>']] + input_ids + [vocab_abstractive['<EOS>']]
        input_ids = input_ids[:400] + [vocab_abstractive['<PAD>']] * (400 - len(input_ids))
        input_ids = torch.tensor(input_ids, dtype=torch.long).to(device)
        generated_summary = model_abstractive.generate_summary(input_ids, vocab_abstractive, idx2word)
        scores = rouge_scorer_instance.score(reference, generated_summary)
        rouge_scores.append(scores)

        # Write to file
        f.write(f"Chapter ID: {chapter_id}\n")
        f.write("-" * 50 + "\n")
        f.write("Extractive Summary:\n")
        f.write(f"{extractive}\n\n")
        f.write("Abstractive Summary:\n")
        f.write(f"{generated_summary}\n\n")
        f.write("Reference Summary:\n")
        f.write(f"{reference}\n\n")
        f.write("ROUGE Scores (Abstractive vs Reference):\n")
        for metric in ['rouge1', 'rouge2', 'rougeL']:
            f.write(f"{metric.upper()}:\n")
            f.write(f"  Precision: {scores[metric].precision:.4f}\n")
            f.write(f"  Recall: {scores[metric].recall:.4f}\n")
            f.write(f"  F1: {scores[metric].fmeasure:.4f}\n")
        f.write("\n" + "=" * 50 + "\n\n")

avg_rouge = {
    'rouge1': {'precision': 0, 'recall': 0, 'fmeasure': 0},
    'rouge2': {'precision': 0, 'recall': 0, 'fmeasure': 0},
    'rougeL': {'precision': 0, 'recall': 0, 'fmeasure': 0}
}
for scores in tqdm(rouge_scores, desc="Computing ROUGE Scores"):
    for metric in avg_rouge:
        avg_rouge[metric]['precision'] += scores[metric].precision / len(rouge_scores)
        avg_rouge[metric]['recall'] += scores[metric].recall / len(rouge_scores)
        avg_rouge[metric]['fmeasure'] += scores[metric].fmeasure / len(rouge_scores)

# Append average ROUGE scores to the file
with open(output_file, 'a', encoding='utf-8') as f:
    f.write("Average ROUGE Scores (Abstractive vs Reference)\n")
    f.write("-" * 50 + "\n")
    for metric, values in avg_rouge.items():
        f.write(f"{metric.upper()}:\n")
        f.write(f"  Precision: {values['precision']:.4f}\n")
        f.write(f"  Recall: {values['recall']:.4f}\n")
        f.write(f"  F1: {values['fmeasure']:.4f}\n")
    f.write("\n" + "=" * 50 + "\n")

logger.info("Average ROUGE Scores:")
for metric, values in avg_rouge.items():
    logger.info(f"{metric}: Precision={values['precision']:.4f}, Recall={values['recall']:.4f}, F1={values['fmeasure']:.4f}")

# Save abstractive model
torch.save(model_abstractive.state_dict(), 'abstractive_model.pt')

INFO:absl:Using default tokenizer.
Evaluating Test Summaries: 100%|█████████████████████| 1431/1431 [03:31<00:00,  6.76it/s]
Computing ROUGE Scores: 100%|████████████████████| 1431/1431 [00:00<00:00, 470158.94it/s]
INFO:__main__:Average ROUGE Scores:
INFO:__main__:rouge1: Precision=0.3691, Recall=0.1329, F1=0.1762
INFO:__main__:rouge2: Precision=0.0483, Recall=0.0141, F1=0.0196
INFO:__main__:rougeL: Precision=0.2755, Recall=0.0986, F1=0.1303
