In [None]:
%load_ext autoreload
%autoreload 2

import os
import json
import pickle

from tqdm import tqdm

from src.datasets import IndoSum
from src.common import get_device
from src.indobart.base import get_model, get_tokenizer, get_config

import stanza
import torch
import spacy
from spacy.tokens import Doc
from spacy import displacy

import numpy as np
import nltk
import evaluate
from transformers import DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

from transformers import BartModel, BartConfig
from transformers import MBartForConditionalGeneration, MBartConfig
from transformers.models.bart.modeling_bart import BartAttention
import torch.nn as nn

import diskcache as dc

from accelerate import Accelerator

In [None]:
# Create the Accelerate handle and take the device it exposes; the bare
# `device` expression on the last line shows it as the cell output.
accelerator = Accelerator()
device = accelerator.device
device

In [None]:
# Download and set up Stanza's Indonesian NLP model
# The pipeline is built with the tokenize, pos, lemma and depparse processors
# and is given the `device` obtained from the Accelerator above.
stanza.download("id")
nlp = stanza.Pipeline("id", processors="tokenize,pos,lemma,depparse", use_gpu=True, device=device)
nlp

### Data Loading

In [None]:
# Instantiate the project's IndoSum dataset wrapper and display its `ds` attribute.
indosum = IndoSum()
indosum.ds

In [None]:
# Preview the head of the "train" split as returned by `to_pd`.
indosum.to_pd("train").head()

### Dependency Parsing

In [None]:
# Convert Stanza Document to spaCy Doc for Visualization
def stanza_to_spacy(doc):
    """
    Convert a Stanza-parsed document into a spaCy Doc for dependency visualization.

    All sentences are flattened into a single token sequence. Stanza numbers
    ``word.id`` and ``word.head`` per sentence (1-based, ``head == 0`` marks the
    root), so every index is shifted by the number of words that precede the
    sentence in the document. Without that shift, each sentence after the first
    would attach its arcs to tokens of the first sentence.

    Args:
        doc: Stanza document exposing ``sentences``; each sentence exposes
            ``words`` carrying ``text``, ``id``, ``head`` and ``deprel``.

    Returns:
        A spaCy ``Doc`` built on a blank "id" vocab whose tokens carry the
        dependency label in ``dep_`` and the document-level head in ``head``.
        Root words point to themselves.
    """
    words, deps, heads = [], [], []
    offset = 0  # number of words in the sentences already flattened

    for sentence in doc.sentences:
        for word in sentence.words:
            words.append(word.text)
            deps.append(word.deprel)
            # A root has no head in Stanza; spaCy expresses that as a self-loop.
            local_head = word.id if word.head == 0 else word.head
            # 1-based sentence index -> 0-based document index.
            heads.append(offset + local_head - 1)
        offset += len(sentence.words)

    # Create a spaCy Doc object using the extracted information
    spacy_doc = Doc(spacy.blank("id").vocab, words=words)
    for token, head, dep in zip(spacy_doc, heads, deps):
        token.dep_ = dep
        token.head = spacy_doc[head]

    return spacy_doc

# Build Dependency Information Matrices (DIM) for each sentence in a document
def build_dependency_matrices(document):
    """
    Parse ``document`` with the Stanza pipeline ``nlp`` and derive one Dependency
    Information Matrix (DIM) per sentence.

    A DIM is a symmetric n x n float32 tensor (n = number of words in the
    sentence) holding 1 wherever two words are joined by a dependency arc and 0
    elsewhere. The root word (``head == 0``) contributes no entry.

    Returns:
        ``(pairs, doc)`` where ``pairs`` is a list of ``(DIM, sentence text)``
        tuples in sentence order and ``doc`` is the parsed Stanza document.
    """
    parsed = nlp(document)
    pairs = []

    for sent in parsed.sentences:
        size = len(sent.words)
        dim = torch.zeros((size, size), dtype=torch.float32)

        for token in sent.words:
            # Stanza marks the root with head == 0: it has no arc to record.
            if token.head <= 0:
                continue
            # Stanza ids/heads are 1-based within the sentence.
            child, parent = token.id - 1, token.head - 1
            dim[child, parent] = 1
            dim[parent, child] = 1  # mirror the arc so the matrix is symmetric

        pairs.append((dim, sent.text))

    return pairs, parsed

# Parse and Visualize Dependencies
def visualize_dependencies(doc):
    """
    Render the dependency arcs of a Stanza-parsed document inline in the notebook.

    The document is converted with ``stanza_to_spacy`` and the result is handed
    to spaCy's displacy renderer with ``style="dep"`` in Jupyter mode.
    """
    displacy.render(stanza_to_spacy(doc), style="dep", jupyter=True)

#### Data Exploration

In [None]:
# Take the first validation example's document as a sample and display it.
sample_doc = indosum.ds["validation"][0]['document']
sample_doc

In [None]:
# Build the DIMs for the sample document; the parsed Stanza document is kept as well.
sample_dim_sentence_pairs, sample_stanza_doc = build_dependency_matrices(sample_doc)
print("Dependency Information Matrices for each sentence:")
# Sentences are numbered from 1 in the printout.
for i, (matrix, sentence) in enumerate(sample_dim_sentence_pairs, 1):
    print(f"Raw Sentence {i}:\n{sentence}")
    print(f"DIM Sentence {i}:\n{matrix}\n")

### Linguistic-Guided Attention in the Encoder

In [None]:
# Define a custom attention layer for the encoder to use the DIM during the attention calculation
class EncoderLinguisticGuidedAttention(BartAttention):
    """
    Encoder self-attention whose probabilities can be re-weighted by a
    Dependency Information Matrix (DIM).

    For a DIM ``D`` the attention probabilities ``A`` become
    ``(alpha * D + I) * A`` (element-wise) and the layer output is recomputed
    from those weights.

    Args:
        embed_dim: Model hidden size.
        num_heads: Number of attention heads.
        dropout: Attention dropout passed to ``BartAttention``.
        alpha: Scale of the DIM term relative to the identity term.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, alpha=1.0):
        super().__init__(embed_dim, num_heads, dropout, is_decoder=False)
        self.alpha = alpha

    def forward(self, hidden_states, dim_matrix=None, **kwargs):
        """
        Run self-attention, optionally guided by ``dim_matrix``.

        Args:
            hidden_states: Layer input, shape (batch, seq_len, embed_dim).
            dim_matrix: Optional dependency matrix at attention resolution:
                (seq_len, seq_len), (batch, seq_len, seq_len), or a 4-D tensor
                broadcastable to the attention weights. When omitted, the call
                is forwarded unchanged to ``BartAttention.forward``.
            **kwargs: Forwarded to ``BartAttention.forward``.

        Returns:
            A tuple laid out like the parent's return value: attention output
            first, (guided) attention weights second, then whatever further
            items the parent returned.

        Raises:
            RuntimeError: If the parent does not return attention weights even
                though they were requested.
        """
        # The stock encoder layer calls `self_attn` without a `dim_matrix`
        # argument; in that case behave exactly like the parent so the model
        # stays usable and the tuple arity the layer expects is preserved.
        if dim_matrix is None:
            return super().forward(hidden_states, **kwargs)

        # The guided path needs the attention probabilities, which the parent
        # only returns when asked to.
        kwargs["output_attentions"] = True
        outputs = super().forward(hidden_states, **kwargs)
        attn_weights = outputs[1]
        if attn_weights is None:
            raise RuntimeError(
                "BartAttention did not return attention weights; "
                "linguistic-guided attention requires them."
            )

        # NOTE(review): assumes attn_weights is (batch, heads, tgt_len, src_len)
        # as returned by BartAttention with output_attentions=True — confirm
        # for the installed transformers version.
        dim_matrix = dim_matrix.to(device=attn_weights.device, dtype=attn_weights.dtype)
        if dim_matrix.dim() == 2:
            dim_matrix = dim_matrix.unsqueeze(0)  # add batch axis
        if dim_matrix.dim() == 3:
            dim_matrix = dim_matrix.unsqueeze(1)  # add head axis so it broadcasts over heads

        identity = torch.eye(
            dim_matrix.size(-1), device=attn_weights.device, dtype=attn_weights.dtype
        )
        lg_attn_weights = (self.alpha * dim_matrix + identity) * attn_weights

        # Recompute the layer output from the guided weights using the value and
        # output projections; multiplying the weights with the raw hidden states
        # would bypass both projections and does not match the per-head shape.
        bsz, tgt_len, _ = hidden_states.size()
        value = (
            self.v_proj(hidden_states)
            .view(bsz, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        attn_output = torch.matmul(lg_attn_weights, value)
        attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        # NOTE(review): the collator in this notebook produces word-level DIMs of
        # shape (batch, sentences, words, words); they still have to be aligned
        # to token-level (batch, seq_len, seq_len) before they can be used here.
        return (attn_output, lg_attn_weights) + tuple(outputs[2:])

In [None]:
# Now, replace the encoder's attention mechanism with EncoderLinguisticGuidedAttention.
class CustomIndoMBartWithLGA(MBartForConditionalGeneration):
    """
    MBart conditional-generation model whose encoder self-attention modules are
    swapped for ``EncoderLinguisticGuidedAttention``.

    Args:
        config: MBart configuration; ``d_model``, ``encoder_attention_heads``
            and ``attention_dropout`` size the replacement attention modules.
        alpha: DIM scale handed to every replacement attention module.
    """

    def __init__(self, config: MBartConfig, alpha=1.0):
        super().__init__(config)

        # Every encoder layer gets its own freshly constructed attention module.
        for encoder_layer in self.model.encoder.layers:
            encoder_layer.self_attn = EncoderLinguisticGuidedAttention(
                embed_dim=config.d_model,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                alpha=alpha,
            )


### Load Model

In [None]:
# Load the tokenizer through the project helper and display it.
tokenizer = get_tokenizer()
tokenizer


In [None]:
# Load the model configuration through the project helper and display it.
config = get_config()
config

In [None]:
# Load the baseline model through the project helper.
pretrained_model = get_model()

In [None]:
# Display the baseline model.
pretrained_model

In [None]:
# Build the custom model from the configuration; alpha is the DIM scale.
model = CustomIndoMBartWithLGA(config, alpha=1.0)


In [None]:
# Display the custom model before the pretrained weights are loaded.
model

In [None]:
# Copy weights from a model returned by get_model() into the custom model.
# strict=False tolerates keys that do not match; the returned value (shown as
# the cell output) reports them.
# NOTE(review): get_model() is called again here instead of reusing
# `pretrained_model` from the earlier cell — confirm that is intended.
model.load_state_dict(get_model().state_dict(), strict=False)

In [None]:
# Display the custom model again after loading the weights.
model

### Train Model

In [None]:
# Setup evaluation
# "punkt_tab" is fetched for nltk.sent_tokenize, which compute_metrics uses;
# `metric` is the ROUGE scorer loaded via `evaluate`.
nltk.download("punkt_tab", quiet=True)
metric = evaluate.load("rouge")

#### Preparation

In [None]:
# Update data collator to include DIM
class CustomDataCollator(DataCollatorForSeq2Seq):
    """
    Seq2seq collator that additionally batches per-sentence Dependency
    Information Matrices (DIMs).

    Every feature must carry ``"dim_matrices"``: a sequence of square matrices
    (tensors or nested lists), one per sentence. They are zero-padded into one
    float32 tensor of shape (batch, max_sentences, max_tokens, max_tokens) and
    stored under ``batch["dim_matrices"]``.
    """

    def __call__(self, features, return_tensors=None):
        """
        Collate ``features`` into a padded batch.

        Args:
            features: List of per-example dicts, each with a ``"dim_matrices"`` entry.
            return_tensors: Forwarded to ``DataCollatorForSeq2Seq.__call__``.

        Returns:
            The parent's batch with an extra ``"dim_matrices"`` tensor.

        Raises:
            KeyError: If a feature has no ``"dim_matrices"`` entry.
        """
        dim_lists = [feature["dim_matrices"] for feature in features]

        # The parent pads the features it receives through the tokenizer, which
        # cannot build a tensor from ragged matrix lists, so that column is
        # split off first. New dicts are built so the caller's features are not
        # mutated.
        plain_features = [
            {key: value for key, value in feature.items() if key != "dim_matrices"}
            for feature in features
        ]
        batch = super().__call__(plain_features, return_tensors=return_tensors)

        # Matrices read back from a dataset are nested lists rather than
        # tensors; as_tensor accepts both.
        converted = [
            [torch.as_tensor(matrix, dtype=torch.float32) for matrix in matrices]
            for matrices in dim_lists
        ]

        # default=0 keeps an all-empty batch from raising on max().
        max_sentences = max((len(matrices) for matrices in converted), default=0)
        max_tokens = max(
            (matrix.size(0) for matrices in converted for matrix in matrices),
            default=0,
        )

        # Initialize padded tensor for batched DIMs
        dim_matrices_padded = torch.zeros(
            (len(features), max_sentences, max_tokens, max_tokens)
        )

        for i, matrices in enumerate(converted):
            for j, matrix in enumerate(matrices):
                if matrix.numel() == 0:
                    continue  # nothing to copy for an empty sentence
                dim_matrices_padded[i, j, :matrix.size(0), :matrix.size(1)] = matrix

        batch["dim_matrices"] = dim_matrices_padded
        return batch

data_collator = CustomDataCollator(tokenizer=tokenizer, model=model)

In [None]:
def compute_metrics(eval_preds):
    """
    Decode predictions and references and score them with the ROUGE ``metric``.

    Args:
        eval_preds: ``(predictions, labels)`` pair of token-id arrays.
            ``predictions`` may also be a tuple whose first item holds the ids.
            Positions equal to -100 are treated as padding in both arrays.

    Returns:
        The dictionary returned by ``metric.compute``.
    """
    preds, labels = eval_preds

    # Some models hand back a tuple; the generated ids are its first item.
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is not a valid token id, so it must be replaced before decoding.
    # Predictions can carry it as well as labels when batches are padded
    # to a common length.
    pad_id = tokenizer.pad_token_id
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    # decode preds and labels
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # rougeLSum expects newline after each sentence
    decoded_preds = [
        "\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds
    ]
    decoded_labels = [
        "\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels
    ]

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    return result

In [None]:
# Set up the cache
# The directory is created up front so the pickle file can be written into it later.
os.makedirs(f"./results/00-indobart-dp/dim_cache", exist_ok=True)
cache_path = "./results/00-indobart-dp/dim_cache/dim_cache.pkl"

# Dictionary to hold all precomputed DIMs
# Module-level state: it is filled by precompute_dims, written by save_cache
# and rebound by prepare_dim_cache.
dim_cache = {}

# Precompute DIMs and store them in the dictionary with progress tracking
def precompute_dims(documents):
    """
    Compute the DIMs of every document and store them in the global ``dim_cache``.

    Each document's matrices are stored as nested lists (easier to serialize
    than tensors) under two keys that share the same list object:

    * its position in ``documents`` (the original key), and
    * the document text itself, so a lookup does not have to depend on
      dataset order, batching or split.

    A document that yields no sentences is cached as an empty list.

    Args:
        documents: Iterable of documents accepted by ``build_dependency_matrices``.
    """
    print("Precomputing Dependency Information Matrices (DIMs) for documents...")
    for i, document in enumerate(tqdm(documents, desc="Computing DIMs")):
        dim_sentence_pairs, _ = build_dependency_matrices(document)
        # Iterating the pairs (instead of unpacking zip(*pairs)) also works
        # when the document produced no sentences at all.
        dims = [matrix.tolist() for matrix, _ in dim_sentence_pairs]
        dim_cache[i] = dims
        if isinstance(document, str):
            dim_cache[document] = dims

# Save the entire DIM cache as a single file with progress tracking
def save_cache(filename="dim_cache.pkl"):
    """
    Pickle the global ``dim_cache`` dictionary into ``filename``.

    Progress messages are printed before and after the write.
    """
    print(f"Saving DIM cache to {filename}...")
    with open(filename, "wb") as handle:
        pickle.Pickler(handle).dump(dim_cache)
    print("DIM cache saved successfully.")

# Load the DIM cache from a single file
def load_cache(filename="dim_cache.pkl"):
    """
    Read a pickled DIM cache from ``filename`` and return it.

    Progress messages are printed before and after the read.
    """
    print(f"Loading DIM cache from {filename}...")
    # Unpickling can execute arbitrary code: only load cache files this
    # notebook wrote itself.
    with open(filename, "rb") as handle:
        restored = pickle.Unpickler(handle).load()
    print("DIM cache loaded successfully.")
    return restored

# Main function to check cache existence, load or precompute DIMs, and save if necessary
def prepare_dim_cache(documents, cache_path="dim_cache.pkl"):
    """
    Make the global ``dim_cache`` available, reusing a cache file when one exists.

    If ``cache_path`` exists, its content replaces the global ``dim_cache``.
    Otherwise the DIMs of ``documents`` are computed into the current
    ``dim_cache`` and written to ``cache_path``.
    """
    global dim_cache  # rebound below when a cache file is loaded

    if not os.path.exists(cache_path):
        # No cache on disk yet: compute everything once and persist it.
        precompute_dims(documents)
        save_cache(cache_path)
        return

    dim_cache = load_cache(cache_path)

# Example usage
# Only the "train" split's documents are passed in, so the cache built here
# covers that split only.
documents = indosum.ds["train"]["document"]  # Replace with your dataset's document list
prepare_dim_cache(documents, cache_path)

In [None]:
# Prepare and tokenize dataset
def preprocess_function(examples):
    """
    Tokenize a batch of examples and attach the DIMs of every document.

    DIMs are looked up in the global ``dim_cache`` by document text. The
    position of a document inside ``examples`` is local to the batch, so it
    cannot identify the document: keying by position would hand every batch
    (and every split) the matrices of the first cached documents. On a cache
    miss the DIMs are computed and stored under the document text.

    Args:
        examples: Batch dict with ``"document"`` and ``"summary"`` lists.
            Documents are used as dictionary keys and must therefore be
            hashable (they are passed to the tokenizer as text).

    Returns:
        The tokenizer output for the documents, extended with ``"labels"``
        (token ids of the summaries) and ``"dim_matrices"`` (one list of
        tensors per document).
    """
    model_inputs = tokenizer(examples["document"], max_length=768, truncation=True)
    labels = tokenizer(text_target=examples["summary"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]

    dim_matrices_batch = []
    for document in examples["document"]:
        cached = dim_cache.get(document)
        if cached is None:
            dim_sentence_pairs, _ = build_dependency_matrices(document)
            # Iterating the pairs also works for a document without sentences,
            # where unpacking zip(*pairs) would raise.
            cached = [matrix.tolist() for matrix, _ in dim_sentence_pairs]
            dim_cache[document] = cached  # Add to cache for future use

        # Convert the cached DIMs (stored as lists) back to tensors
        dim_matrices_batch.append([torch.tensor(dm) for dm in cached])

    model_inputs["dim_matrices"] = dim_matrices_batch
    return model_inputs


# Apply preprocess_function to indosum.ds in batched mode; the function
# receives a dict of lists per call.
tokenized_ds = indosum.ds.map(preprocess_function, batched=True)

def train_model(output_dir, per_device_batch_size, learning_rate, num_train_epochs, generation_max_length):
    """
    Fine-tune a copy of the notebook's ``model`` and return the trainer.

    Each call trains its own deep copy of the global ``model``, so consecutive
    runs all start from the same weights instead of continuing from the
    previous run's fine-tuned state. If ``<output_dir>/checkpoint`` already
    contains a checkpoint, training resumes from the most recent one.

    Args:
        output_dir: Run directory; checkpoints go to ``<output_dir>/checkpoint``
            and logs to ``<output_dir>/logs``.
        per_device_batch_size: Batch size used for both training and evaluation.
        learning_rate: Optimizer learning rate.
        num_train_epochs: Number of training epochs.
        generation_max_length: Maximum length used when generating during evaluation.

    Returns:
        The ``Seq2SeqTrainer`` after ``train()`` has finished.
    """
    # Local imports: neither name is imported at the top of the notebook.
    import copy
    from transformers.trainer_utils import get_last_checkpoint

    checkpoint_dir = output_dir + "/checkpoint"

    training_args = Seq2SeqTrainingArguments(
        output_dir=checkpoint_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=learning_rate,
        per_device_train_batch_size=per_device_batch_size,
        per_device_eval_batch_size=per_device_batch_size,
        weight_decay=0.01,
        num_train_epochs=num_train_epochs,
        fp16=True,
        predict_with_generate=True,
        generation_max_length=generation_max_length,
        log_level="info",
        logging_first_step=True,
        logging_dir=output_dir + "/logs",
        # Kept for the recorded configuration. The Trainer does not act on this
        # field by itself; resuming is done through train() below.
        resume_from_checkpoint=True,
        save_total_limit=1,
    )

    # Train a private copy so the shared `model` keeps its initial weights
    # for the next run.
    run_model = copy.deepcopy(model)

    trainer = Seq2SeqTrainer(
        model=run_model,
        args=training_args,
        train_dataset=tokenized_ds["train"],
        eval_dataset=tokenized_ds["validation"],
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    # Resume only when a checkpoint really exists: asking train() to resume
    # from an empty directory is an error.
    last_checkpoint = None
    if os.path.isdir(checkpoint_dir):
        last_checkpoint = get_last_checkpoint(checkpoint_dir)

    trainer.train(resume_from_checkpoint=last_checkpoint)

    return trainer
    
def evaluate_model(trainer):
    """Evaluate ``trainer`` on the "test" split of ``tokenized_ds`` and return what it reports."""
    return trainer.evaluate(eval_dataset=tokenized_ds["test"])


def train_and_evaluate(output_dir, per_device_batch_size, learning_rate, num_train_epochs, generation_max_length):
    """
    Train with ``train_model`` and score the result with ``evaluate_model``.

    Returns:
        ``(trainer, eval_results)``: the trainer returned by ``train_model`` and
        whatever ``evaluate_model`` returns for it.
    """
    run_args = (
        output_dir,
        per_device_batch_size,
        learning_rate,
        num_train_epochs,
        generation_max_length,
    )
    fitted_trainer = train_model(*run_args)
    return fitted_trainer, evaluate_model(fitted_trainer)


#### Training & Evaluation

Try several values of the generation max length while keeping the remaining parameters fixed.
Record the best score and the generation max length that produced it.

In [None]:
# Build one experiment per generation max length (60, 70, 80, 90, 100);
# every other hyperparameter is the same across experiments.
experiments = []

for i in range(1, 6):
    generation_max_length = 50 + i * 10
    experiments.append({
        "output_dir": f"./results/00-indobart-dp/model/0{i}",
        "per_device_batch_size": 8,
        "learning_rate": 3.75e-5,
        "num_train_epochs": 3,
        "generation_max_length": generation_max_length
    })

# Run the experiments one after another; each one trains, evaluates, prints its
# parameters and results, and stores both as JSON files in its output directory.
for exp in experiments:
    os.makedirs(exp["output_dir"], exist_ok=True)
    
    trainer, eval_results = train_and_evaluate(
        exp["output_dir"],
        exp["per_device_batch_size"],
        exp["learning_rate"],
        exp["num_train_epochs"],
        exp["generation_max_length"]
    )
    
    # print params and the results
    print("=== Results for experiment ===")
    print("-- Params --") 
    print(json.dumps(exp, indent=4))
    print("-- Eval results --")
    print(json.dumps(eval_results, indent=4))
    
    # save mapping between params and results
    with open(exp["output_dir"] + "/params.json", "w") as f:
        json.dump(exp, f)
    
    with open(exp["output_dir"] + "/eval_results.json", "w") as f:
        json.dump(eval_results, f)

