# Dataset
[ccdv/arxiv-summarization dataset on Hugging Face](https://huggingface.co/datasets/ccdv/arxiv-summarization)

In [None]:
!pip install -q datasets transformers spacy
!python -m spacy download en_core_web_sm

# Load Dataset and Inspect

In [None]:
from datasets import load_dataset

# Pull only the training split of the arXiv summarization corpus.
dataset = load_dataset("ccdv/arxiv-summarization", split="train")

# Peek at the first record: its field names, the opening 500 characters of the
# article body, and the full reference abstract.
first_record = dataset[0]
article_preview = first_record['article'][:500]
print(first_record.keys())
print("Sample Article:", article_preview)
print("Sample Abstract:", first_record['abstract'])

# Sentence Segmentation & Named Entity Extraction

In [None]:
import spacy

# Load the full English pipeline once. It supplies both named entities and
# sentence boundaries; loading "en_core_web_sm" a second time only cost extra
# start-up time and memory.
nlp_ner = spacy.load("en_core_web_sm")

# Keep the historical `nlp` name as an alias of the same pipeline object so any
# cell that still refers to it keeps working.
nlp = nlp_ner

# Lightweight rule-based pipeline (sentence splitting only) for the case where
# entities are not needed.
nlp_sent = spacy.blank("en")
nlp_sent.add_pipe("sentencizer")

def preprocess_article(text, max_chars=50000, use_entities=True):
    """Split an article into filtered sentences and (optionally) named entities.

    Args:
        text: Raw article text. ``None`` or an empty string yields ``([], [])``.
        max_chars: Only the first ``max_chars`` characters are analysed.
        use_entities: When True, run the NER pipeline (``nlp_ner``) and collect
            entities labelled PERSON, ORG, GPE, DATE or WORK_OF_ART. When False,
            use the sentence-only pipeline (``nlp_sent``) and return no entities.

    Returns:
        ``(sentences, entities)`` where ``sentences`` is a list of stripped
        sentence strings longer than 40 characters that do not start with ``*``
        and are not purely digits, and ``entities`` is a sorted list of unique,
        non-empty entity strings.
    """
    if not text:
        return [], []

    text = text[:max_chars]

    if use_entities:
        doc = nlp_ner(text)
        wanted_labels = {'PERSON', 'ORG', 'GPE', 'DATE', 'WORK_OF_ART'}
        mentions = (ent.text.strip() for ent in doc.ents if ent.label_ in wanted_labels)
        # Sorted so that entity indices (used later as graph node ids) are
        # reproducible across runs; plain set iteration order is not.
        entities = sorted({mention for mention in mentions if mention})
    else:
        doc = nlp_sent(text)
        entities = []

    # Keep only sentences that look like real prose.
    sentences = []
    for sent in doc.sents:
        stripped = sent.text.strip()
        if len(stripped) > 40 and not stripped.startswith('*') and not stripped.isdigit():
            sentences.append(stripped)

    return sentences, entities

# Graph Construction with DGL

In [None]:
pip install -q torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124

In [None]:
pip install -q dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html

In [None]:
import dgl
import torch
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

## build graph

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

## Build and Inspect the Graph

In [None]:
# !pip install sentence-transformers==2.2.2

In [None]:
# # Install accelerate
!pip install -q accelerate==0.29.3

# # Install transformers
!pip install -q transformers==4.39.3

# # Install sentence-transformers
# !pip install sentence-transformers==2.2.2

# # Install bertopic
# !pip install bertopic==0.16.0

# # Ensure spaCy and its model are installed/downloaded
# !pip install spacy
# !python -m spacy download en_core_web_sm

In [None]:
!pip uninstall -y sentence-transformers

In [None]:
!pip install -q sentence-transformers==2.6.0

In [None]:
!pip install -q bertopic==0.16.0

In [None]:
# Report whether PyTorch can see a CUDA device before the embedding step below.
import torch
print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Device Capability: {torch.cuda.get_device_capability(0)}")

from bertopic import BERTopic
from sentence_transformers import SentenceTransformer
import umap # Import UMAP for custom model
import hdbscan # Import HDBSCAN for custom model (good practice for BERTopic)
import numpy as np # For potential min/max operations

# --- Build the sentence corpus that BERTopic is fitted on ---
# Sentences come from the first `num_articles_to_process` articles of `dataset`,
# each passed through preprocess_article (entities are discarded here).
all_sentences_for_bertopic = []
num_articles_to_process = 200 # Adjust this number!

print(f"Collecting sentences from {num_articles_to_process} articles...")
for i in range(num_articles_to_process):
    # Stop early if the dataset has fewer articles than requested.
    if i >= len(dataset):
        print(f"Reached end of dataset at article {i}. Stopping collection.")
        break
    if i % 100 == 0 and i != 0: # Print progress every 100 articles after the first
        print(f"Processing article {i}...")
    article_text = dataset[i]['article']
    sents, _ = preprocess_article(article_text)
    all_sentences_for_bertopic.extend(sents)

print(f"Total sentences collected for BERTopic: {len(all_sentences_for_bertopic)}")

# Only warns; execution continues even with a very small corpus.
if len(all_sentences_for_bertopic) < 50:
    print("WARNING: Insufficient sentences collected for meaningful topic modeling. Consider increasing `num_articles_to_process`.")

# --- Embedding step, using `all_sentences_for_bertopic` ---

# Embed the sentences
# SentenceTransformer will automatically use GPU if available and properly configured with PyTorch-CUDA
sbert_model = SentenceTransformer('all-MiniLM-L6-v2')

print("Embedding sentences (this might take a while for large datasets)...")
# The encoding process will now run on GPU if available
# sentence_embeddings = sbert_model.encode(all_sentences_for_bertopic, show_progress_bar=True)
from tqdm import tqdm

# Encode in chunks of `batch_size` sentences so tqdm can show progress.
# `sentence_embeddings` stays a Python list here (one item per sentence,
# presumably a vector each -- confirm against sbert_model.encode); the next
# cell converts it with np.array before fitting.
batch_size = 256
sentence_embeddings = []
for i in tqdm(range(0, len(all_sentences_for_bertopic), batch_size)):
    batch = all_sentences_for_bertopic[i:i+batch_size]
    sentence_embeddings.extend(sbert_model.encode(batch))

In [None]:
# Configure UMAP and HDBSCAN for BERTopic
# n_neighbors is capped at 15 and kept below the corpus size (but never under 2).
umap_n_neighbors = max(2, min(15, len(all_sentences_for_bertopic) - 1))
umap_model = umap.UMAP(n_neighbors=umap_n_neighbors,
                       n_components=5,
                       min_dist=0.0,
                       random_state=42)

hdbscan_model = hdbscan.HDBSCAN(min_cluster_size=50,
                                metric='euclidean',
                                prediction_data=True)

print("Fitting BERTopic model (this might take a while)...")
topic_model = BERTopic(embedding_model=sbert_model,
                       umap_model=umap_model,
                       hdbscan_model=hdbscan_model,
                       calculate_probabilities=True,
                       verbose=True)

import numpy as np

# The previous cell accumulates embeddings in a list; fit_transform is given an array.
sentence_embeddings = np.array(sentence_embeddings)
topics, probs = topic_model.fit_transform(all_sentences_for_bertopic, embeddings=sentence_embeddings)

# Query the topic table once and reuse it below.
topic_info = topic_model.get_topic_info()

# Count real topics explicitly. The outlier row (Topic == -1) is only present
# when some sentences were left unassigned, so "len(topic_info) - 1"
# under-counts whenever there are no outliers.
num_real_topics = int((topic_info['Topic'] != -1).sum())

print("\nBERTopic ran successfully!")
print(f"Number of topics found: {num_real_topics}")

print("\nTop 10 Topics:")
print(topic_info.head(11))

if 0 in topic_info['Topic'].values:
    print("\nWords for Topic 0:")
    print(topic_model.get_topic(0))
else:
    print("\nTopic 0 does not exist (possibly all data are outliers or no clear topics).")

In [None]:
from bertopic import plotting

# `topics` holds one entry per sentence, so its length is the corpus size, not
# the number of topics. Count the distinct non-outlier topic ids instead.
num_real_topics = len(set(topics) - {-1})

# Only visualize if there is more than one real topic.
# NOTE(review): visualize_topics may still fail with very few topics -- confirm
# the minimum against the installed BERTopic version.
if num_real_topics > 1:
    fig = topic_model.visualize_topics()
    fig.show()

    # Other useful visualizations:
    fig = topic_model.visualize_barchart(top_n_topics=10)
    fig.show()
    fig = topic_model.visualize_heatmap()
    fig.show()

# Add topic nodes and sentence-to-topic edges

# Coreference (`corefers_with`) and Discourse (`discourse_follows`) Edges Between Sentence Nodes

In [None]:
# NOTE(review): neuralcoref was built for spaCy 2.x. The URL below is the
# upstream huggingface/neuralcoref repository, not a spaCy 3.x-compatible fork,
# so this install may fail or conflict with the spaCy installed earlier --
# confirm. No visible cell imports neuralcoref (the commented-out draft that
# follows uses `coreferee` instead).
!pip install git+https://github.com/huggingface/neuralcoref.git

In [None]:
# import spacy
# import coreferee

# # Load coreference model (do once)
# nlp = spacy.load("en_coreference_web_trf")
# nlp.add_pipe("coreferee")

# def build_graph_with_coref_and_discourse(sentences, entities, topic_ids=None):
#     import torch
#     import dgl
#     from sklearn.feature_extraction.text import TfidfVectorizer
#     from sklearn.metrics.pairwise import cosine_similarity

#     num_sents = len(sentences)
#     num_ents = len(entities)
#     num_topics = len(set(topic_ids)) if topic_ids else 0

#     # --- Edge containers
#     edge_dict = {
#         ('sentence', 'mentions', 'entity'): ([], []),
#         ('sentence', 'belongs_to', 'topic'): ([], []),
#         ('sentence', 'corefers_with', 'sentence'): ([], []),
#         ('sentence', 'discourse_follows', 'sentence'): ([], [])
#     }

#     # --- Sentence-to-entity edges
#     for s_idx, sent in enumerate(sentences):
#         for e_idx, ent in enumerate(entities):
#             if ent.lower() in sent.lower():
#                 edge_dict[('sentence', 'mentions', 'entity')][0].append(s_idx)
#                 edge_dict[('sentence', 'mentions', 'entity')][1].append(e_idx)

#     # --- Sentence-to-topic edges
#     if topic_ids:
#         topic_map = {t: i for i, t in enumerate(sorted(set(topic_ids)))}
#         for s_idx, t_id in enumerate(topic_ids):
#             if t_id != -1:
#                 topic_idx = topic_map[t_id]
#                 edge_dict[('sentence', 'belongs_to', 'topic')][0].append(s_idx)
#                 edge_dict[('sentence', 'belongs_to', 'topic')][1].append(topic_idx)

#     # --- Co-reference edges
#     doc = nlp(" ".join(sentences))
#     token_to_sent = {token.i: i for i, s in enumerate(doc.sents) for token in s}

#     if doc._.has_coref:
#         for chain in doc._.coref_chains:
#             mentions = chain.get_mentions()
#             sent_indices = list(set(
#                 token_to_sent.get(m.start)
#                 for m in mentions if token_to_sent.get(m.start) is not None
#             ))
#             for i in range(len(sent_indices)):
#                 for j in range(i + 1, len(sent_indices)):
#                     a, b = sent_indices[i], sent_indices[j]
#                     edge_dict[('sentence', 'corefers_with', 'sentence')][0] += [a, b]
#                     edge_dict[('sentence', 'corefers_with', 'sentence')][1] += [b, a]

#     # --- Discourse edges (sequential)
#     for i in range(num_sents - 1):
#         edge_dict[('sentence', 'discourse_follows', 'sentence')][0] += [i, i + 1]
#         edge_dict[('sentence', 'discourse_follows', 'sentence')][1] += [i + 1, i]

#     # --- Graph init
#     node_dict = {'sentence': num_sents, 'entity': num_ents}
#     if topic_ids:
#         node_dict['topic'] = num_topics

#     # Filter empty edge types
#     graph_data = {
#         etype: (torch.tensor(src), torch.tensor(dst))
#         for etype, (src, dst) in edge_dict.items()
#         if src and dst
#     }

#     g = dgl.heterograph(graph_data, num_nodes_dict=node_dict)
#     return g

In [None]:
def build_graph_with_topics(sentences, entities, precomputed_topics=None):
    """Build a DGL heterograph linking sentences, entities and topics.

    Edge types (each is only created when it has at least one edge):
      * ('sentence', 'similar_to', 'sentence'): TF-IDF cosine similarity > 0.2
      * ('sentence', 'mentions', 'entity'): entity string occurs in the sentence
      * ('sentence', 'belongs_to', 'topic'): sentence assigned a topic id != -1

    Node ids are local to each node type, as DGL heterographs require: sentence
    i is sentence node i, entities[j] is entity node j, and topic_set[k] is
    topic node k.

    Args:
        sentences: List of sentence strings (must be non-empty).
        entities: List of entity strings.
        precomputed_topics: Optional per-sentence topic ids aligned with
            ``sentences``; -1 means "no topic".

    Returns:
        ``(g, sentences, entities, topic_set)`` where ``topic_set`` is the
        sorted list of distinct non-outlier topic ids.

    Raises:
        ValueError: If ``sentences`` is empty or no edge of any type was found.
    """
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity
    import numpy as np
    import torch
    import dgl

    num_sentences = len(sentences)
    num_entities = len(entities)
    if num_sentences == 0:
        raise ValueError("Cannot build a graph from an empty sentence list.")

    # TF-IDF for sentence similarity
    vectorizer = TfidfVectorizer(max_features=300)
    tfidf_matrix = vectorizer.fit_transform(sentences)
    sim_matrix = cosine_similarity(tfidf_matrix)

    # Sentence-to-sentence edges: all ordered pairs (i, j), i != j, above the
    # threshold. np.nonzero walks the matrix row by row, so the edge order is
    # the same as a nested i/j loop.
    threshold = 0.2
    above_threshold = sim_matrix > threshold
    np.fill_diagonal(above_threshold, False)
    sim_rows, sim_cols = np.nonzero(above_threshold)
    sentence_sim_src = sim_rows.tolist()
    sentence_sim_dst = sim_cols.tolist()

    # Sentence-to-entity edges. Destination ids are plain entity indices:
    # offsetting them by num_sentences would make DGL allocate phantom,
    # unconnected entity nodes.
    sent_ent_src = []
    sent_ent_dst = []
    for sent_idx, sentence in enumerate(sentences):
        for ent_idx, entity in enumerate(entities):
            # An empty string is a substring of everything; skip it.
            if entity and entity in sentence:
                sent_ent_src.append(sent_idx)
                sent_ent_dst.append(ent_idx)

    graph_data = {}

    if sentence_sim_src:
        graph_data[('sentence', 'similar_to', 'sentence')] = (
            torch.tensor(sentence_sim_src, dtype=torch.int32),
            torch.tensor(sentence_sim_dst, dtype=torch.int32)
        )

    if sent_ent_src:
        graph_data[('sentence', 'mentions', 'entity')] = (
            torch.tensor(sent_ent_src, dtype=torch.int32),
            torch.tensor(sent_ent_dst, dtype=torch.int32)
        )

    # Sentence-to-topic edges
    sent_topic_src = []
    sent_topic_dst = []
    topic_set = []

    if precomputed_topics is not None:
        topic_set = sorted(set(t for t in precomputed_topics if t != -1))
        # Map raw topic ids to contiguous topic node ids 0..len(topic_set)-1.
        topic_id_map = {topic_id: idx for idx, topic_id in enumerate(topic_set)}
        for sent_idx, topic_id in enumerate(precomputed_topics):
            if topic_id != -1:
                sent_topic_src.append(sent_idx)
                sent_topic_dst.append(topic_id_map[topic_id])

        if sent_topic_src:
            graph_data[('sentence', 'belongs_to', 'topic')] = (
                torch.tensor(sent_topic_src, dtype=torch.int32),
                torch.tensor(sent_topic_dst, dtype=torch.int32)
            )

    print(f"Sent-Sent edges: {len(sentence_sim_src)}")
    print(f"Sent-Ent edges: {len(sent_ent_src)}")
    print(f"Sent-Topic edges: {len(sent_topic_src)}")
    print(f"Unique topics found: {len(topic_set)}")

    if not graph_data:
        raise ValueError("Graph has no edges.")

    # State the node counts explicitly so that every sentence, entity and topic
    # gets a node even when it has no edges, and so all three node types exist.
    num_nodes_dict = {
        'sentence': num_sentences,
        'entity': num_entities,
        'topic': len(topic_set),
    }
    g = dgl.heterograph(graph_data, num_nodes_dict=num_nodes_dict)

    return g, sentences, entities, topic_set

In [None]:
# Combine multiple articles for better topic matching
combined_text = "\n".join(dataset[i]['article'] for i in range(200))  # <- try 20 or even 50

# Preprocess the combined article text
# NOTE(review): preprocess_article keeps only the first `max_chars` characters
# (default 50000), so most of the 200 joined articles are dropped here --
# confirm this is intended.
sentences, entities = preprocess_article(combined_text)

# Get topic assignments using the trained BERTopic model
document_topics, _ = topic_model.transform(sentences)

# Build the heterogeneous graph with sentence-entity-topic structure
graph, sent_nodes, ent_nodes, topic_ids = build_graph_with_topics(
    sentences,
    entities,
    precomputed_topics=document_topics
)

# Inspect graph structure
print("\n--- Graph Structure ---")
print(graph)
print("Node types:", graph.ntypes)
print("Edge types:", graph.etypes)

# Optional: check sent-topic edge count
if ('sentence', 'belongs_to', 'topic') in graph.canonical_etypes:
    print(f"‚úÖ Sent-Topic edges: {graph.num_edges(('sentence', 'belongs_to', 'topic'))}")
else:
    print("‚ùå No Sent-Topic edges created.")

In [None]:
# Distinct topic ids assigned to the combined-text sentences
# (build_graph_with_topics treats -1 as "no topic").
print("Assigned topics:", set(document_topics))

In [None]:
# Assign topics to this document's sentences with the already-fitted global
# BERTopic model (no refitting).
document_topics, _ = topic_model.transform(sentences)

# Report the assignments just computed for this document. `topics` holds the
# per-sentence result of fit_transform over the whole training corpus, so
# printing it here showed the wrong data.
print("Topics assigned:", document_topics)
print("Unique topics (excluding -1):", set(t for t in document_topics if t != -1))

In [None]:
import pickle

# Persist the fitted topic model (pickled) and the embedding matrix (NumPy .npy)
# in the current working directory.
model_path = "bertopic_model.pkl"
with open(model_path, "wb") as model_file:
    pickle.dump(topic_model, model_file)

np.save("sentence_embeddings.npy", sentence_embeddings)

# Graph Encoding

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn import HeteroGraphConv, GraphConv

# Graph Encoder for Heterogeneous Graph
class HeteroGraphEncoder(nn.Module):
    """Two-layer heterogeneous graph convolution encoder.

    Each layer applies one ``GraphConv`` per relation name and averages the
    per-relation results for every destination node type.

    Args:
        in_feats: Input feature size (shared by all node types).
        hidden_feats: Hidden size after the first layer.
        out_feats: Output embedding size.
        rel_names: Iterable of relation (edge type) names to build modules for.
    """

    def __init__(self, in_feats, hidden_feats, out_feats, rel_names):
        super().__init__()
        # Materialise once so a generator argument can feed both layers.
        rel_names = list(rel_names)

        # allow_zero_in_degree=True: in the sparse similarity/mention/topic
        # relations many nodes receive no edge, and GraphConv raises on such
        # nodes by default. Self-loops are not an option for relations whose
        # source and destination node types differ.
        self.layer1 = HeteroGraphConv({
            rel: GraphConv(in_feats, hidden_feats, allow_zero_in_degree=True)
            for rel in rel_names
        }, aggregate='mean')

        self.layer2 = HeteroGraphConv({
            rel: GraphConv(hidden_feats, out_feats, allow_zero_in_degree=True)
            for rel in rel_names
        }, aggregate='mean')

    def forward(self, g, inputs):
        """Encode node features.

        Args:
            g: The heterograph to run message passing on.
            inputs: Dict mapping node type to its feature tensor.

        Returns:
            Dict mapping node type to the embeddings produced by the second layer.
        """
        h = self.layer1(g, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.layer2(g, h)
        return h

In [None]:
# 1. Define input features (TF-IDF for sentences, random init for the rest)
from sklearn.feature_extraction.text import TfidfVectorizer

FEAT_DIM = 300  # must equal `in_feats` of the encoder instantiated below

vectorizer = TfidfVectorizer(max_features=FEAT_DIM)
sentence_feats = vectorizer.fit_transform(sent_nodes).toarray()
# max_features is only an upper bound: a small vocabulary yields fewer columns,
# which would not match the encoder's input size. Zero-pad up to FEAT_DIM.
if sentence_feats.shape[1] < FEAT_DIM:
    sentence_feats = np.pad(
        sentence_feats,
        ((0, 0), (0, FEAT_DIM - sentence_feats.shape[1])),
        mode='constant',
    )
sentence_feats = torch.tensor(sentence_feats, dtype=torch.float32)

# Entity/topic features are random. A node type that is absent from the graph
# gets an empty (0 x FEAT_DIM) tensor instead of raising in num_nodes().
num_entity_nodes = graph.num_nodes('entity') if 'entity' in graph.ntypes else 0
num_topic_nodes = graph.num_nodes('topic') if 'topic' in graph.ntypes else 0
entity_feats = torch.randn(num_entity_nodes, FEAT_DIM)
topic_feats = torch.randn(num_topic_nodes, FEAT_DIM)

# 2. Combine into dictionary
features = {
    'sentence': sentence_feats,
    'entity': entity_feats,
    'topic': topic_feats
}

# 3. Instantiate the model
rel_names = list(graph.etypes)
model = HeteroGraphEncoder(in_feats=FEAT_DIM, hidden_feats=128, out_feats=64, rel_names=rel_names)

# 4. Forward pass to get updated node embeddings
model.eval()  # Disable dropout/batchnorm for inference
with torch.no_grad():
    node_embeddings = model(graph, features)

# 5. Extract sentence node embeddings for summary ranking.
# Stored under its own name so the SBERT matrix in `sentence_embeddings`
# (saved to disk and used to fit BERTopic) is not overwritten.
sentence_node_embeddings = node_embeddings['sentence']
print("Sentence Embeddings Shape:", sentence_node_embeddings.shape)

In [None]:
import torch.nn.functional as F
import torch.nn as nn
import torch

class MultiLevelGraphAttentionPooling(nn.Module):
    """Score sentences by fusing sentence-, entity- and topic-level signals.

    A sentence's final score is ``alpha * sentence + beta * entity + gamma * topic``
    where the three weights are a softmax over a learnable 3-vector, and the
    entity/topic terms are linear scores of the summed (projected) features of
    the entities/topics the sentence is connected to.

    Args:
        embed_dim: Size of the sentence embeddings (and of the projected
            entity/topic features).
        node_feat_dim: Size of the raw entity/topic features that are projected
            to ``embed_dim``. Defaults to 300.
    """

    def __init__(self, embed_dim, node_feat_dim=300):
        super().__init__()
        self.sentence_attn = nn.Linear(embed_dim, 1)
        self.entity_attn = nn.Linear(embed_dim, 1)
        self.topic_attn = nn.Linear(embed_dim, 1)

        self.entity_proj = nn.Linear(node_feat_dim, embed_dim)
        self.topic_proj = nn.Linear(node_feat_dim, embed_dim)

        # Learnable fusion weights
        self.raw_weights = nn.Parameter(torch.tensor([1.0, 1.0, 1.0]))  # [alpha, beta, gamma]

    def dense_adj_from_edges(self, g, etype, src_size, dst_size):
        """Return a dense 0/1 ``(src_size, dst_size)`` adjacency matrix for ``etype``."""
        src, dst = g.edges(etype=etype)
        adj = torch.zeros((src_size, dst_size), device=src.device)
        # The graph may store int32 ids; cast to int64, which every PyTorch
        # version accepts for index tensors.
        adj[src.long(), dst.long()] = 1
        return adj

    def forward(self, sentence_embs, entity_embs, topic_embs, graph, attention_breakdown=False):
        """Compute one importance score per sentence.

        Args:
            sentence_embs: ``(num_sentences, embed_dim)`` sentence embeddings.
            entity_embs: ``(num_entities, node_feat_dim)`` raw entity features.
            topic_embs: ``(num_topics, node_feat_dim)`` raw topic features.
            graph: Graph exposing ``canonical_etypes``, ``edges(etype=...)`` and
                ``num_nodes(ntype)``.
            attention_breakdown: When True, also return the per-source scores.

        Returns:
            ``final_scores`` of shape ``(num_sentences,)``; with
            ``attention_breakdown=True`` a tuple ``(final_scores, parts)`` where
            ``parts`` has detached CPU tensors under 'sentence', 'entity',
            'topic' and 'weights'.
        """
        sent_scores = self.sentence_attn(sentence_embs).squeeze(1)

        entity_embs_proj = self.entity_proj(entity_embs)
        topic_embs_proj = self.topic_proj(topic_embs)

        # Entity attention: sum the projected features of the entities each
        # sentence mentions, then score the sum. Zero if the relation is absent.
        if ('sentence', 'mentions', 'entity') in graph.canonical_etypes:
            adj_se = self.dense_adj_from_edges(graph, ('sentence', 'mentions', 'entity'),
                                               graph.num_nodes('sentence'), graph.num_nodes('entity'))
            ent_context = adj_se @ entity_embs_proj
            ent_scores = self.entity_attn(ent_context).squeeze(1)
        else:
            ent_scores = torch.zeros_like(sent_scores)

        # Topic attention: same scheme over sentence -> topic edges.
        if ('sentence', 'belongs_to', 'topic') in graph.canonical_etypes:
            adj_st = self.dense_adj_from_edges(graph, ('sentence', 'belongs_to', 'topic'),
                                               graph.num_nodes('sentence'), graph.num_nodes('topic'))
            topic_context = adj_st @ topic_embs_proj
            topic_scores = self.topic_attn(topic_context).squeeze(1)
        else:
            topic_scores = torch.zeros_like(sent_scores)

        # Softmax keeps the three fusion weights positive and summing to 1.
        weights = F.softmax(self.raw_weights, dim=0)
        alpha, beta, gamma = weights

        final_scores = alpha * sent_scores + beta * ent_scores + gamma * topic_scores

        if attention_breakdown:
            return final_scores, {
                'sentence': sent_scores.detach().cpu(),
                'entity': ent_scores.detach().cpu(),
                'topic': topic_scores.detach().cpu(),
                'weights': weights.detach().cpu()
            }

        return final_scores

In [None]:
# # Instantiate ML-GAP
# mlgap_scorer = MultiLevelGraphAttentionPooling(embed_dim=64)
# mlgap_scorer.eval()

# with torch.no_grad():
#     importance_scores = mlgap_scorer(
#         sentence_embs=node_embeddings['sentence'],  # ‚úÖ use only sentence embeddings
#         entity_embs=features['entity'],             # ‚úÖ entity features (300-dim)
#         topic_embs=features['topic'],               # ‚úÖ topic features (300-dim)
#         graph=graph
#     )

# # Select top-ranked sentences
# top_k = 5
# top_indices = torch.topk(importance_scores, k=top_k).indices
# top_indices_sorted = sorted(top_indices.tolist())

# # Generate final summary
# summary_sentences = [sent_nodes[i] for i in top_indices_sorted]

# import re
# def clean_sentence(text):
#     text = re.sub(r'@xmath\d+', '', text)
#     text = re.sub(r'@xcite', '', text)
#     text = re.sub(r'\s+', ' ', text)
#     return text.strip()

# cleaned_summary = [clean_sentence(sent) for sent in summary_sentences]

# # Display output
# for i, sent in enumerate(cleaned_summary, 1):
#     print(f"{i}. {sent}")

In [None]:
# Install ROUGE scorer (only once needed)
!pip install -q rouge-score

In [None]:
# Per-document attention breakdowns; the evaluation cell re-initialises this list.
attention_data = []

In [None]:
from rouge_score import rouge_scorer
from statistics import mean
from tqdm import tqdm
import torch
import numpy as np
import re
from sklearn.feature_extraction.text import TfidfVectorizer

def clean_sentence(text):
    """Remove ``@xmath<digits>`` and ``@xcite`` placeholders and tidy whitespace.

    The substitutions run one after another in the listed order; the result has
    every whitespace run collapsed to a single space and is stripped at both ends.
    """
    substitutions = (
        (r'@xmath\d+', ''),
        (r'@xcite', ''),
        (r'\s+', ' '),
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()

# ROUGE evaluator
rouge_eval = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

# ML-GAP model in eval mode
mlgap_scorer = MultiLevelGraphAttentionPooling(embed_dim=64)
mlgap_scorer.eval()

FEAT_DIM = 300       # input feature size expected by the graph encoder
SUMMARY_SIZE = 5     # number of sentences extracted per document

# TF-IDF vectorizer for sentence features, fitted on (at most) the first 50 articles
vectorizer = TfidfVectorizer(max_features=FEAT_DIM)
vectorizer.fit([
    " ".join(preprocess_article(dataset[i]['article'])[0])
    for i in range(min(50, len(dataset)))
])

# Initialize metrics and storage
rouge1_f, rouge2_f, rougel_f = [], [], []
attention_data = []  # per-document scores kept for attention visualization
num_docs_to_eval = 20

# --- Evaluation loop ---
for i in tqdm(range(min(num_docs_to_eval, len(dataset)))):
    article = dataset[i]['article']
    reference = clean_sentence(dataset[i]['abstract'])

    sents, ents = preprocess_article(article)
    # Too short to extract a SUMMARY_SIZE-sentence summary from.
    if len(sents) < SUMMARY_SIZE:
        continue

    doc_topics, _ = topic_model.transform(sents)
    g, sent_nodes, ent_nodes, topic_ids = build_graph_with_topics(sents, ents, precomputed_topics=doc_topics)

    sent_feats_np = vectorizer.transform(sent_nodes).toarray()
    # Zero-pad when the fitted vocabulary has fewer than FEAT_DIM terms.
    if sent_feats_np.shape[1] < FEAT_DIM:
        pad_width = FEAT_DIM - sent_feats_np.shape[1]
        sent_feats_np = np.pad(sent_feats_np, ((0, 0), (0, pad_width)), mode='constant')
    sent_feats = torch.tensor(sent_feats_np, dtype=torch.float32)

    # A document without entity or topic edges may lack that node type
    # entirely; use an empty feature matrix rather than failing in num_nodes().
    num_ent_nodes = g.num_nodes('entity') if 'entity' in g.ntypes else 0
    num_topic_nodes = g.num_nodes('topic') if 'topic' in g.ntypes else 0
    ent_feats = torch.randn(num_ent_nodes, FEAT_DIM)
    topic_feats = torch.randn(num_topic_nodes, FEAT_DIM)
    feats = {'sentence': sent_feats, 'entity': ent_feats, 'topic': topic_feats}

    # --- Forward pass with attention breakdown ---
    with torch.no_grad():
        encoded = model(g, feats)
        # The encoder only emits node types that receive edges; a document with
        # no sentence-to-sentence edge has no sentence embeddings to rank.
        if 'sentence' not in encoded:
            continue
        node_embeds = encoded['sentence']
        scores, parts = mlgap_scorer(
            sentence_embs=node_embeds,
            entity_embs=feats['entity'],
            topic_embs=feats['topic'],
            graph=g,
            attention_breakdown=True
        )

        top_idx = torch.topk(scores, k=SUMMARY_SIZE).indices
        # Restore document order so the summary reads naturally.
        top_idx_sorted = sorted(top_idx.tolist())
        summary_sents = [clean_sentence(sent_nodes[j]) for j in top_idx_sorted]
        generated_summary = " ".join(summary_sents)

        # Print per-sentence attention for summary
        print(f"\nüîç Attention Breakdown for Document {i + 1}:")
        print(f"Learned weights ‚Äî alpha: {parts['weights'][0]:.4f}, beta: {parts['weights'][1]:.4f}, gamma: {parts['weights'][2]:.4f}")
        for j in top_idx_sorted:
            print(f"- {clean_sentence(sent_nodes[j])}")
            print(f"  Sentence: {parts['sentence'][j]:.4f} | Entity: {parts['entity'][j]:.4f} | Topic: {parts['topic'][j]:.4f}")

        # Save full attention values for visualization
        attention_data.append({
            "doc_id": i,
            "sentence_texts": [clean_sentence(s) for s in sent_nodes],
            "sentence_scores": parts['sentence'].cpu().numpy().tolist(),
            "entity_scores": parts['entity'].cpu().numpy().tolist(),
            "topic_scores": parts['topic'].cpu().numpy().tolist(),
            "weights": [x.item() for x in parts['weights']]
        })

    # --- ROUGE ---
    results = rouge_eval.score(reference, generated_summary)
    rouge1_f.append(results['rouge1'].fmeasure)
    rouge2_f.append(results['rouge2'].fmeasure)
    rougel_f.append(results['rougeL'].fmeasure)

# --- Final report ---
print("\nüìä Average ROUGE F1 Scores over", len(rouge1_f), "documents:")
if rouge1_f:
    print(f"ROUGE-1 F1: {mean(rouge1_f):.4f}")
    print(f"ROUGE-2 F1: {mean(rouge2_f):.4f}")
    print(f"ROUGE-L F1: {mean(rougel_f):.4f}")
else:
    # mean() raises on an empty list; every document was skipped.
    print("No documents were evaluated; ROUGE averages are undefined.")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_attention_breakdown(attention_data, doc_index=0, sentence_texts=None, top_k=5, sort_by_score=False):
    """Print the top-scoring sentences of a document and plot its score breakdown.

    Args:
        attention_data: List of per-document dicts with 'doc_id',
            'sentence_scores', 'entity_scores' and 'topic_scores' (and
            optionally 'sentence_texts').
        doc_index: Position of the document inside ``attention_data``.
        sentence_texts: Optional sentence strings aligned with the scores. If
            their count does not match the scores, the texts stored in the
            document entry are used when they match; otherwise labels such as
            "S3" are printed.
        top_k: Number of top sentences to print (clamped to the sentence count).
        sort_by_score: When True, bars are ordered by descending total score;
            otherwise they stay in document order.

    The total score of a sentence is the plain sum of its three components.
    """
    doc = attention_data[doc_index]

    sentence_scores = np.array(doc['sentence_scores'])
    entity_scores = np.array(doc['entity_scores'])
    topic_scores = np.array(doc['topic_scores'])
    final_scores = sentence_scores + entity_scores + topic_scores
    num_sents = len(final_scores)

    # Texts that do not line up with the scores would label the wrong
    # sentences; prefer the document's own stored texts in that case.
    if sentence_texts is not None and len(sentence_texts) != num_sents:
        stored_texts = doc.get('sentence_texts')
        if stored_texts is not None and len(stored_texts) == num_sents:
            sentence_texts = stored_texts
        else:
            sentence_texts = None

    # Ranking by total score, best first (always computed: the printed top-k
    # list must be the best sentences regardless of the bar order).
    ranking = np.argsort(final_scores)[::-1]

    # Print top-k sentences
    top_k = max(0, min(top_k, num_sents))
    print(f"\nüìå Top-{top_k} Summary Sentences for Document {doc['doc_id'] + 1}:\n")
    for rank in range(top_k):
        idx = ranking[rank]
        text = sentence_texts[idx] if sentence_texts else f"S{idx+1}"
        print(f"{rank+1}. {text}  (Score: {final_scores[idx]:.4f})")

    # Bar order: by score if requested, else document order.
    order = ranking if sort_by_score else np.arange(num_sents)
    sentence_scores = sentence_scores[order]
    entity_scores = entity_scores[order]
    topic_scores = topic_scores[order]
    labels = [f"S{i+1}" for i in order]

    # Plot stacked contributions per sentence
    x = np.arange(num_sents)
    width = 0.6
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(x, sentence_scores, width, label='Sentence', color='skyblue')
    ax.bar(x, entity_scores, width, bottom=sentence_scores, label='Entity', color='orange')
    ax.bar(x, topic_scores, width, bottom=sentence_scores + entity_scores, label='Topic', color='lightgreen')

    ax.set_ylabel('Attention Contribution')
    ax.set_title(f'Attention Breakdown per Sentence (Doc {doc["doc_id"] + 1})')

    # Only show every Nth tick label so long documents stay readable
    tick_interval = max(1, num_sents // 30)
    ax.set_xticks(x[::tick_interval])
    ax.set_xticklabels([labels[i] for i in range(0, num_sents, tick_interval)], rotation=45, ha='right')

    ax.legend(loc='upper right')
    ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)

    plt.tight_layout()
    plt.show()

In [None]:
# Use the sentences stored with the plotted entry itself: `sent_nodes` holds the
# sentences of whichever document the evaluation loop processed last, which is
# generally not the document at index 0.
plot_attention_breakdown(
    attention_data,
    doc_index=0,
    sentence_texts=attention_data[0]["sentence_texts"],
    top_k=5,
    sort_by_score=True,
)

# Graph Visualizations

## Heterogeneous Graph Structure

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
import dgl

def visualize_graph_structure(g):
    """
    Draw two bar charts for a heterograph: nodes per node type, then edges per
    canonical edge type (edge types are shown via their ``str()`` form).
    """
    node_counts = {}
    for ntype in g.ntypes:
        node_counts[ntype] = g.num_nodes(ntype)

    edge_counts = {}
    for etype in g.canonical_etypes:
        edge_counts[str(etype)] = g.num_edges(etype)

    # One entry per chart: data, figure size, extra bar styling, texts, and
    # whether the x tick labels are rotated.
    chart_specs = (
        (node_counts, (6, 4), {}, "Node Type Counts", "Node Type", False),
        (edge_counts, (8, 5), {'color': 'orange'}, "Edge Type Counts", "Edge Type", True),
    )

    for counts, size, bar_style, title, x_label, rotate_ticks in chart_specs:
        plt.figure(figsize=size)
        plt.bar(counts.keys(), counts.values(), **bar_style)
        plt.title(title)
        plt.xlabel(x_label)
        plt.ylabel("Count")
        if rotate_ticks:
            plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.show()

# Plot node/edge type counts for `graph`, the heterograph built earlier from the
# combined article text. Must run after that cell.
visualize_graph_structure(graph)


## Mini Graph Visualization with NetworkX

In [None]:
import networkx as nx
import matplotlib.pyplot as plt

def draw_mini_hetero_graph(g, sent_nodes, max_nodes=20, output_pdf_path='graph_output.pdf'):
    """Draw a small excerpt of a sentence/entity/topic heterograph and save it as PDF.

    Only the first ``max_nodes`` sentences, the first 10 entity nodes and the
    first 5 topic nodes are shown, together with the edges among them.

    Args:
        g: Heterograph exposing ``ntypes``, ``canonical_etypes``,
            ``num_nodes(ntype)`` and ``edges(etype=...)``.
        sent_nodes: Sentence list; only its length is used.
        max_nodes: Maximum number of sentence nodes to draw.
        output_pdf_path: Destination of the saved figure.
    """
    max_entities = 10
    max_topics = 5

    # A simple DiGraph is enough here: each ordered node pair carries at most
    # one relation. It also keeps edge-label keys as (u, v) pairs, which every
    # networkx version of draw_networkx_edge_labels accepts (multigraph keys
    # are (u, v, key) triples, which older versions cannot unpack).
    G = nx.DiGraph()

    # Limit sentence nodes for clarity; a set makes the membership tests O(1).
    num_sent_shown = min(len(sent_nodes), max_nodes)
    shown_sentences = set(range(num_sent_shown))
    for i in range(num_sent_shown):
        G.add_node(f"Sent_{i}", type='sentence')

    num_ent_shown = min(g.num_nodes('entity'), max_entities) if 'entity' in g.ntypes else 0
    for i in range(num_ent_shown):
        G.add_node(f"Ent_{i}", type='entity')

    num_topic_shown = min(g.num_nodes('topic'), max_topics) if 'topic' in g.ntypes else 0
    for i in range(num_topic_shown):
        G.add_node(f"Topic_{i}", type='topic')

    # Add edges. Every relation is optional, so check before querying it.
    if ('sentence', 'similar_to', 'sentence') in g.canonical_etypes:
        src, dst = g.edges(etype=('sentence', 'similar_to', 'sentence'))
        for s, d in zip(src.tolist(), dst.tolist()):
            if s in shown_sentences and d in shown_sentences:
                G.add_edge(f"Sent_{s}", f"Sent_{d}", label='similar_to')

    if ('sentence', 'mentions', 'entity') in g.canonical_etypes:
        src, dst = g.edges(etype=('sentence', 'mentions', 'entity'))
        for s, e in zip(src.tolist(), dst.tolist()):
            if s in shown_sentences and e < num_ent_shown:
                G.add_edge(f"Sent_{s}", f"Ent_{e}", label='mentions')

    if ('sentence', 'belongs_to', 'topic') in g.canonical_etypes:
        src, dst = g.edges(etype=('sentence', 'belongs_to', 'topic'))
        for s, t in zip(src.tolist(), dst.tolist()):
            if s in shown_sentences and t < num_topic_shown:
                G.add_edge(f"Sent_{s}", f"Topic_{t}", label='belongs_to')

    # Create the plot
    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G, seed=42, k=0.6)
    color_by_type = {'sentence': 'skyblue', 'entity': 'orange'}
    node_colors = [
        color_by_type.get(attrs['type'], 'lightgreen')
        for _, attrs in G.nodes(data=True)
    ]

    # Draw the graph
    nx.draw(G, pos, with_labels=True, node_color=node_colors, node_size=800, font_size=8, arrows=True)
    edge_labels = {(u, v): attrs['label'] for u, v, attrs in G.edges(data=True)}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=7)

    # Title and appearance
    plt.title("Mini Heterogeneous Graph")
    plt.axis('off')
    plt.tight_layout()

    # Save the graph as a PDF file (before show(), which may clear the figure)
    plt.savefig(output_pdf_path, format='pdf')

    # Show the plot (optional)
    plt.show()

    print(f"Graph saved to {output_pdf_path}")



## Attention Score Heatmap per Sentence

In [None]:
import seaborn as sns
import numpy as np

def plot_attention_heatmap(attention_data, doc_index=0, top_k=10):
    """Show a heatmap of sentence/entity/topic scores for a document's first sentences.

    Args:
        attention_data: List of per-document dicts with 'doc_id',
            'sentence_scores', 'entity_scores' and 'topic_scores'.
        doc_index: Position of the document inside ``attention_data``.
        top_k: Number of leading sentences (document order, not ranked by
            score) to include; clamped to the number of sentences available.
    """
    doc = attention_data[doc_index]
    scores = np.stack([
        doc["sentence_scores"],
        doc["entity_scores"],
        doc["topic_scores"]
    ])

    # Clamp so the tick labels always match the number of plotted columns.
    num_shown = max(0, min(top_k, scores.shape[1]))
    if num_shown == 0:
        print("No sentences to plot.")
        return

    scores = scores[:, :num_shown]  # Limit to the first sentences for clarity
    labels = [f"S{i+1}" for i in range(num_shown)]
    sources = ['Sentence', 'Entity', 'Topic']

    plt.figure(figsize=(10, 4))
    sns.heatmap(scores, annot=True, xticklabels=labels, yticklabels=sources, cmap='YlGnBu')
    plt.title(f"Attention Breakdown Heatmap (Doc {doc['doc_id'] + 1})")
    plt.xlabel("Sentence Index")
    plt.ylabel("Attention Source")
    plt.tight_layout()
    plt.show()

# Heatmap for the first entry of `attention_data` (filled by the evaluation
# loop), limited to its first 10 sentences.
plot_attention_heatmap(attention_data, doc_index=0, top_k=10)


## Combine with Summary Output

In [None]:
def print_summary_from_attention(attention_data, sent_nodes, doc_index=0, top_k=5):
    """Print the ``top_k`` sentences of a document ranked by summed attention score.

    Args:
        attention_data: List of per-document dicts with 'doc_id',
            'sentence_scores', 'entity_scores' and 'topic_scores' (and
            optionally 'sentence_texts').
        sent_nodes: Sentence strings aligned with the scores. If their count
            does not match the scores (for example because they belong to a
            different document), the texts stored in the document entry are
            used instead when available.
        doc_index: Position of the document inside ``attention_data``.
        top_k: Number of sentences to print.
    """
    doc = attention_data[doc_index]
    sentence_scores = np.array(doc['sentence_scores']) + \
                      np.array(doc['entity_scores']) + \
                      np.array(doc['topic_scores'])

    # Guard against sentence lists that do not belong to this document.
    texts = sent_nodes
    if texts is None or len(texts) != len(sentence_scores):
        stored_texts = doc.get('sentence_texts')
        if stored_texts is not None and len(stored_texts) == len(sentence_scores):
            texts = stored_texts

    top_idx = np.argsort(sentence_scores)[::-1][:top_k]
    print(f"\nüìå Top-{top_k} Summary Sentences for Document {doc['doc_id'] + 1}:\n")
    for rank, idx in enumerate(top_idx):
        print(f"{rank + 1}. {texts[idx]}")

# Call after inference. Pass the sentences stored with entry 0 itself:
# `sent_nodes` holds the sentences of the last document the evaluation loop
# processed, which is generally a different document.
print_summary_from_attention(attention_data, attention_data[0]["sentence_texts"], doc_index=0)
