In [1]:
import json
import numpy as np
import faiss
import igraph as ig
import leidenalg
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

In [2]:
from sentence_transformers import SentenceTransformer


model = SentenceTransformer("all-mpnet-base-v2")

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
def compute_cosine_similarity(vec1, vec2):
   
    norm1 = np.linalg.norm(vec1)
    norm2 = np.linalg.norm(vec2)
    if norm1 == 0 or norm2 == 0:
        return 0.0
    return np.dot(vec1, vec2) / (norm1 * norm2)

def compute_iou(entities_a, entities_b):
    
    intersection = entities_a.intersection(entities_b)
    union = entities_a.union(entities_b)
    if not union:
        return 0.0
    return len(intersection) / len(union)

def extract_entities(relationships):
    
    entities = set()
    for rel in relationships:
        if len(rel) >= 3:
            
            subject = rel[0].replace("\\", "").strip()
            object_ = rel[2].replace("\\", "").strip()
            entities.add(subject)
            entities.add(object_)
    return entities

In [3]:
def build_graph(index, article_ids, relationships_data, num_neighbors=50):
    
    article_to_entities = {}
    for aid in article_ids:
        data = relationships_data.get(aid, {})
        rels = data.get("relationships", [])
        article_to_entities[aid] = extract_entities(rels)
    
    
    g = ig.Graph()
    g.add_vertices([str(aid) for aid in article_ids])
    
    for aid in article_ids:
        # Convert article ID to int for FAISS.
        aid_int = int(aid)
        # Use the inner index for reconstructing the vector.
        query_vector = index.index.reconstruct(aid_int).reshape(1, -1)
        distances, retrieved_ids = index.index.search(query_vector, num_neighbors + 1)
        
        for neighbor in retrieved_ids[0]:
            # Convert neighbor to int explicitly.
            neighbor_int = int(neighbor)
            if neighbor_int == aid_int:
                continue
            neighbor_str = str(neighbor_int)
            neighbor_vector = index.index.reconstruct(neighbor_int)
            article_vector = query_vector.flatten()
            cos_sim = compute_cosine_similarity(article_vector, neighbor_vector)
            
            entities_a = article_to_entities.get(aid, set())
            entities_b = article_to_entities.get(neighbor_str, set())
            iou = compute_iou(entities_a, entities_b)
            
            weight = cos_sim + iou
            
            if not g.are_connected(str(aid), neighbor_str):
                g.add_edge(str(aid), neighbor_str, weight=weight)
    
    return g


In [None]:
def build_graph_parallel(index, article_ids, relationships_data, 
                         num_neighbors=50, num_workers=8, 
                         alpha=1.0, beta=1.0, weight_threshold=0.0):
    """
    Build a weighted undirected graph in parallel.
    Each node is an article and each edge weight is computed as:
        weight = alpha * cosine_similarity + beta * IOU
    Edges with weight below weight_threshold are skipped.
    Also counts the number of edges (per article) that are "cut" (i.e. not added).
    Uses the inner FAISS index (index.index) to reconstruct embeddings.
    """
    # Build mapping from article ID to its entity set.
    article_to_entities = {}
    for aid in article_ids:
        data = relationships_data.get(aid, {})
        rels = data.get("relationships", [])
        article_to_entities[aid] = extract_entities(rels)
    
    # Create an undirected graph with vertices named by article IDs.
    g = ig.Graph()
    g.add_vertices([str(aid) for aid in article_ids])
    
    def process_article(aid):
        """
        Process a single article: retrieve neighbors and compute edge weights.
        Returns:
            aid (str): the article ID.
            edges (list): list of (aid, neighbor, weight) tuples.
            cut_count (int): number of neighbor edges skipped due to low weight.
        """
        edges = []
        cut_count = 0
        aid_int = int(aid)
        query_vector = index.index.reconstruct(aid_int).reshape(1, -1)
        distances, retrieved_ids = index.index.search(query_vector, num_neighbors + 1)
        for neighbor in retrieved_ids[0]:
            neighbor_int = int(neighbor)
            if neighbor_int == aid_int:
                continue
            neighbor_str = str(neighbor_int)
            neighbor_vector = index.index.reconstruct(neighbor_int)
            article_vector = query_vector.flatten()
            cos_sim = compute_cosine_similarity(article_vector, neighbor_vector)
            entities_a = article_to_entities.get(aid, set())
            entities_b = article_to_entities.get(neighbor_str, set())
            iou = compute_iou(entities_a, entities_b)
            weight = alpha * cos_sim + beta * iou # time relevance bhi daal doun.
            #print(iou)
            if weight < weight_threshold:
                cut_count += 1
                continue  # Skip low-weight edge.
            edges.append((str(aid), neighbor_str, weight))
        #print(f"Article {aid}: processed {len(edges)} edges, cut {cut_count} edges.")
        return aid, edges, cut_count
    
    edges_all = []
    cut_counter = {}  # Dictionary to count cut edges per article.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(process_article, aid): aid for aid in article_ids}
        for future in tqdm(as_completed(futures), total=len(article_ids), desc="Processing articles"):
            try:
                aid, edges, cut_count = future.result()
                edges_all.extend(edges)
                cut_counter[aid] = cut_count
            except Exception as e:
                print("Error processing article", futures[future], e)
    
    # Print the summary of cut edges.
    total_cut = sum(cut_counter.values())
    print(f"Total edges cut (skipped): {total_cut}")
    
    # Add edges to the graph, avoiding duplicates (using are_adjacent now)
    unique_edges = {}
    for a, b, weight in edges_all:
        # For undirected graph, sort the vertex pair.
        key = tuple(sorted((a, b)))
        if key not in unique_edges:
            unique_edges[key] = weight
        else:
            # Optionally, update the weight (e.g., average or max); here, we keep the first.
            pass

    # Bulk add edges.
    edge_list = list(unique_edges.keys())
    weights = list(unique_edges.values())
    g.add_edges(edge_list)
    g.es["weight"] = weights
    print(f"Added {len(edge_list)} unique edges to the graph.")
    
    return g, cut_counter

In [3]:
def detect_communities(graph):
    """
    Run the Leiden community detection algorithm on the given graph.
    Returns a dictionary mapping community IDs to lists of article IDs.
    """
    print("Starting community detection...")
    partition = leidenalg.find_partition(
        graph, 
        leidenalg.RBConfigurationVertexPartition, 
        weights='weight'
    )
    print("Community detection completed.")
    communities = {}
    for idx, community in enumerate(partition):
        nodes = [graph.vs[v]["name"] for v in community]
        communities[idx] = nodes
    print(f"Detected {len(communities)} communities.")
    return communities

def save_communities(communities, filename="communities.json"):
    """Save the communities dictionary as a JSON file."""
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(communities, f, indent=4)
    print(f"Communities saved to {filename}")

def load_communities(filename="communities.json"):
    """Load the communities dictionary from a JSON file."""
    with open(filename, "r", encoding="utf-8") as f:
        communities = json.load(f)
    print(f"Communities loaded from {filename}")
    return communities

# Detect & Assign Communities

In [5]:
def detect_and_assign_communities(graph):
    """
    Run the Leiden community detection algorithm on the given graph.
    This assigns a "community" attribute to each node and returns a dictionary
    mapping community IDs to lists of article IDs (node names).
    """
    print("Starting community detection...")
    partition = leidenalg.find_partition(
        graph, 
        leidenalg.RBConfigurationVertexPartition, 
        weights='weight'
    )
    print("Community detection completed.")
    communities = {}
    for idx, community in enumerate(partition):
        # Get node names for the community
        nodes = [graph.vs[v]["name"] for v in community]
        communities[idx] = nodes
        # Assign community label to each node in the community
        for v in community:
            graph.vs[v]["community"] = idx
    print(f"Detected {len(communities)} communities.")
    return communities

In [6]:
def save_graph_as_json(graph, filename="graph_with_communities.json"):
    """
    Export the graph (nodes and edges, including community and weight information)
    as a JSON file.
    """
    nodes = []
    for node in graph.vs:
        nodes.append({
            "id": node["name"],
            "community": node["community"] if "community" in node.attributes() else None
        })
    edges = []
    for edge in graph.es:
        # Get source and target node names.
        source = graph.vs[edge.tuple[0]]["name"]
        target = graph.vs[edge.tuple[1]]["name"]
        weight = edge["weight"] if "weight" in edge.attributes() else None
        # Convert weight to a Python float if it is not None.
        if weight is not None:
            weight = float(weight)
        edges.append({
            "source": source,
            "target": target,
            "weight": weight
        })
    data = {"nodes": nodes, "edges": edges}
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    print(f"Graph saved as JSON to {filename}")


In [None]:
def embed_query(query, index):
    """
    Embed the query text into the same vector space as your article embeddings using SentenceTransformer.
    
    Args:
        query (str): The query string.
        index (faiss.Index): The FAISS index (used here only for dimension consistency, if needed).
    
    Returns:
        np.ndarray: The embedding of the query.
    """
    # Use the model to encode the query. 
    # Ensure that your article embeddings were generated using the same model.
    return model.encode(query)

def get_top_k_articles_for_query(query, index, communities, k, embed_query_fn):
    """
    Given a query string, find the top k articles within the community that
    best matches the query. The process is:
      1. Compute the query's embedding.
      2. Retrieve the nearest article from FAISS.
      3. Identify the community that contains this article.
      4. Compute cosine similarities between the query and each article in that community.
      5. Return the top k articles from the community ranked by similarity.
    
    Parameters:
        query        : The query string from the user.
        index        : The FAISS index (with stored embeddings).
        communities  : Dictionary mapping community IDs to lists of article IDs.
        k            : Number of articles to return.
        embed_query_fn: Function that converts the query text into an embedding.
    
    Returns:
        A list of top k article IDs (as strings).
    """
    # Compute the query embedding.
    q_emb = embed_query_fn(query, index).reshape(1, -1)
    
    # Retrieve the nearest neighbor (we use 1 neighbor to decide the community).
    distances, retrieved_ids = index.index.search(q_emb, 1)
    nearest_id = retrieved_ids[0][0]
    nearest_id_str = str(nearest_id)
    
    # Identify the community that contains the nearest article.
    target_community = None
    for comm_articles in communities.values():
        if nearest_id_str in comm_articles:
            target_community = comm_articles
            break
    if target_community is None:
        print("No community found for the nearest article. Falling back to global top k from FAISS index.")
        distances, retrieved_ids = index.index.search(q_emb, k)
        return [str(x) for x in retrieved_ids[0]]
    
    # Compute cosine similarity between the query and each article in the target community.
    sims = []
    for article_id in target_community:
        article_emb = index.index.reconstruct(int(article_id))
        sim = compute_cosine_similarity(q_emb.flatten(), article_emb)
        sims.append((article_id, sim))
    
    # Sort articles by similarity in descending order.
    sims.sort(key=lambda x: x[1], reverse=True)
    
    # Return the top k article IDs.
    top_k_articles = [article_id for article_id, sim in sims[:k]]
    return top_k_articles


In [8]:
if __name__ == "__main__":
    # Load the relationships JSON file.
    with open("entity_relationships.json", "r", encoding="utf-8") as f:
        relationships_data = json.load(f)
    
    # Load the articles JSON file (mapping from article ID to article data).
    with open("id_to_article_D.json", "r", encoding="utf-8") as f:
        articles_data = json.load(f)
    
    # Extract article IDs. Assuming keys in articles_data are article IDs.
    article_ids = list(articles_data.keys())
    # Optionally, sort them by numeric value if the keys are numeric.
    article_ids.sort(key=int)
    
    # Load the FAISS index from file (which already contains the embeddings).
    index = faiss.read_index("embeddings_index.faiss")
    
    # Build the weighted graph using the FAISS index.
    alpha = 1.0      # weight multiplier for cosine similarity
    beta = 2.0       # weight multiplier for IOU
    weight_threshold = 0.5  # clip edges with weight below this threshold
    
    # Build the weighted graph in parallel using TQDM.
    g, cut_counter = build_graph_parallel(index, article_ids, relationships_data, 
                             num_neighbors=50, num_workers=8, 
                             alpha=alpha, beta=beta, weight_threshold=weight_threshold)

    print("GraphML file 'graph_with_communities.graphml' created. You can now load it in Gephi.")
    
    # Print a sample entry from the graph and its neighbors.
    sample_node = g.vs[0]["name"]
    print(f"\nSample node: {sample_node}")
    neighbors = g.neighbors(sample_node, mode="all")
    for n in neighbors:
        neighbor_name = g.vs[n]["name"]
        # Get the edge connecting sample_node and neighbor_name.
        edge_id = g.get_eid(sample_node, neighbor_name)
        edge_weight = g.es[edge_id]["weight"]
        print(f"  Neighbor: {neighbor_name} (weight: {edge_weight:.4f})")
    
    # Detect communities using the Leiden algorithm.
    communities = detect_and_assign_communities(g)
    
    # Save and then load the communities.
    save_communities(communities, filename="communities_b2.json")
    save_graph_as_json(g, filename="graph_with_communities.json")
    #g.write_graphml("graph_with_communities.graphml")
    

Processing articles: 100%|██████████| 61314/61314 [05:11<00:00, 196.82it/s]


Total edges cut (skipped): 94198
Added 2280094 unique edges to the graph.
GraphML file 'graph_with_communities.graphml' created. You can now load it in Gephi.

Sample node: 0
  Neighbor: 64 (weight: 0.5904)
  Neighbor: 137 (weight: 1.8388)
  Neighbor: 142 (weight: 0.5681)
  Neighbor: 151 (weight: 0.6971)
  Neighbor: 154 (weight: 0.5681)
  Neighbor: 155 (weight: 0.5695)
  Neighbor: 291 (weight: 0.7871)
  Neighbor: 573 (weight: 0.6800)
  Neighbor: 576 (weight: 0.7084)
  Neighbor: 1052 (weight: 0.5121)
  Neighbor: 1080 (weight: 0.7082)
  Neighbor: 1181 (weight: 0.6948)
  Neighbor: 1365 (weight: 0.7074)
  Neighbor: 1551 (weight: 0.7701)
  Neighbor: 2128 (weight: 0.7520)
  Neighbor: 3689 (weight: 0.5036)
  Neighbor: 4062 (weight: 0.6352)
  Neighbor: 4648 (weight: 0.5874)
  Neighbor: 9697 (weight: 0.6664)
  Neighbor: 11522 (weight: 0.6025)
  Neighbor: 11931 (weight: 0.5520)
  Neighbor: 12050 (weight: 0.6465)
  Neighbor: 12571 (weight: 0.5548)
  Neighbor: 15660 (weight: 0.5452)
  Neighbor: 16

In [13]:
g.write_graphml("graph_with_communities_p.graphml")


In [3]:
with open("id_to_article_D.json", "r", encoding="utf-8") as f:
        articles_data = json.load(f)
    
    # Extract article IDs. Assuming keys in articles_data are article IDs.
article_ids = list(articles_data.keys())
    # Optionally, sort them by numeric value if the keys are numeric.
article_ids.sort(key=int)
    

In [4]:
index = faiss.read_index("embeddings_index.faiss")

In [58]:
loaded_communities = load_communities(filename="communities.json")

Communities loaded from communities.json


In [71]:
user_query ="Russian and Ukrain War"
k = 1000  # Number of top articles to retrieve
    
top_articles = get_top_k_articles_for_query(user_query, index, loaded_communities, k, embed_query)
    
print(f"\nTop {k} articles for the query: '{user_query}'")
for idx, article_id in enumerate(top_articles):
    article_info = articles_data.get(article_id, {})
    title = article_info.get("title", "No Title")
    print(f"{idx+1}. Article ID: {article_id} - Title: {title}")


Top 1000 articles for the query: 'Russian and Ukrain War'
1. Article ID: 4980 - Title: Putin draws parallels between WWII and Ukraine conflict
2. Article ID: 34256 - Title: Kremlin aide rewrites Russian history for a society at war

3. Article ID: 26401 - Title: Russia reports fierce fighting on three sections of frontline

4. Article ID: 27319 - Title: Russian mercenary force halts march on Moscow

5. Article ID: 28399 - Title: Ukrainians and Russians fight  for every yard

6. Article ID: 8392 - Title: First Western tanks arrive in Ukraine

7. Article ID: 11075 - Title: Putin says Russia is fighting for its very existence
8. Article ID: 24928 - Title: Unwinnable war

9. Article ID: 8662 - Title: Putin calls Ukraine war a battle for Russia’s survival

10. Article ID: 8524 - Title: Ukraine anniversary

11. Article ID: 22311 - Title: Wagner group hands over  Bakhmut to Russian army

12. Article ID: 35154 - Title: Around 45,000 Muscovites fight in Ukraine: mayor

13. Article ID: 8965 - T

# Simple RAG

In [5]:
def embed_query(query, index):
    return model.encode(query)

def get_top_k_articles(query, index, k, embed_query_fn):
    
    q_emb = embed_query_fn(query, index).reshape(1, -1)
    
    
    distances, retrieved_ids = index.index.search(q_emb, k)
    
    
    return [str(article_id) for article_id in retrieved_ids[0]]


In [44]:
user_query ="Toshakhana Case"
k = 50
    
top_articles = get_top_k_articles(user_query, index, k, embed_query)
    
print(f"\nTop {k} articles for the query: '{user_query}'")
for idx, article_id in enumerate(top_articles):
    article_info = articles_data.get(article_id, {})
    title = article_info.get("title", "No Title")
    print(f"{idx+1}. Article ID: {article_id} - Title: {title}")


Top 50 articles for the query: 'Toshakhana Case'
1. Article ID: 5717 - Title: Toshakhana reference: Imran’s indictment deferred as Islamabad court accepts PTI chief’s exemption plea
2. Article ID: 29082 - Title: Toshakhana case against PTI chief declared maintainable
3. Article ID: 24745 - Title: Govt gets time to submit Toshakhana details

4. Article ID: 7819 - Title: Imran’s indictment in Toshakhana reference deferred yet again
5. Article ID: 33125 - Title: IHC set to decide fate of Toshakhana case

6. Article ID: 29183 - Title: Toshakhana reference ‘maintainable’, trial to begin next week

7. Article ID: 34000 - Title: Toshakhana gifts: Plea for cases against MPs

8. Article ID: 35944 - Title: SC to hear Imran’s Toshakhana appeal tomorrow

9. Article ID: 9888 - Title: Toshakhana case: IHC suspends Imran’s arrest warrants till March 13
10. Article ID: 32850 - Title: SC to hear Imran’s plea in Toshakhana case today

11. Article ID: 12705 - Title: Toshakhana in its true perspective

1

# Code For Graphml

In [9]:
import json
import numpy as np
import faiss  # Make sure FAISS is installed and imported
import random

# These functions are assumed to be defined elsewhere:
# def embed_query(query, index): ...
# def compute_cosine_similarity(vec1, vec2): ...

def load_graph_json(filename="graph_with_communities.json"):
    """
    Load the graph (nodes and edges, with community info) from a JSON file.
    The JSON file is expected to have the structure:
      {"nodes": [{"id": <node_id>, "community": <community_label>}, ...],
       "edges": [{"source": <node_id>, "target": <node_id>, "weight": <weight>}, ...]}
    """
    with open(filename, "r", encoding="utf-8") as f:
        graph_data = json.load(f)
    return graph_data




In [10]:
def build_neighbors_dict(graph_data):
    """
    Build a dictionary mapping each node id to a list of its one-hop neighbors
    using the "edges" information from the graph JSON.
    """
    neighbors = {}
    # Initialize neighbors for each node from the nodes list.
    for node in graph_data["nodes"]:
        neighbors[node["id"]] = set()
    # Populate the dictionary using edges.
    for edge in graph_data["edges"]:
        source = edge["source"]
        target = edge["target"]
        neighbors[source].add(target)
        neighbors[target].add(source)
    # Convert neighbor sets to lists.
    for node_id in neighbors:
        neighbors[node_id] = list(neighbors[node_id])
    return neighbors

In [11]:
def build_multi_hop_neighbors_dict(graph_data, hops=1):
    """
    Build a neighbor dictionary mapping each node id to all nodes that are reachable
    within 'hops' hops (excluding the node itself).
    
    Parameters:
       graph_data: JSON data with "nodes" and "edges"
       hops: number of hops to include (e.g., hops=2 returns one- and two-hop neighbors)
    
    Returns:
       A dictionary where keys are node IDs and values are lists of reachable node IDs.
    """
    one_hop = build_neighbors_dict(graph_data)
    multi_hop = {}
    for node in one_hop.keys():
        visited = {node}  # include self to avoid cycles
        current_level = {node}
        for _ in range(hops):
            next_level = set()
            for n in current_level:
                for neighbor in one_hop.get(n, []):
                    if neighbor not in visited:
                        visited.add(neighbor)
                        next_level.add(neighbor)
            current_level = next_level
            if not current_level:
                break
        # Exclude the node itself from its neighbor list.
        visited.remove(node)
        multi_hop[node] = list(visited)
    return multi_hop

In [12]:
def query_retriever_with_multi_hop(query, index, embed_query_fn, multi_hop_neighbors_dict, threshold=0.5, n_seed=3):
    # Compute the query embedding.
    q_emb = embed_query_fn(query, index).reshape(1, -1)
    
    # Retrieve n_seed seed articles from FAISS.
    distances, seed_ids = index.index.search(q_emb, n_seed)
    seed_ids = [str(x) for x in seed_ids[0]]
    
    # Initialize candidate set with the seed articles.
    candidate_ids = set(seed_ids)
    # need their entities
    # For each seed, add its multi-hop neighbors.
    for seed in seed_ids:
        if seed in multi_hop_neighbors_dict:
            candidate_ids.update(multi_hop_neighbors_dict[seed])
    candidate_ids = list(candidate_ids)
    
    # Compute cosine similarity between the query and each candidate's embedding.
    sims = []
    for article_id in candidate_ids:
        # Reconstruct the candidate embedding using FAISS.
        article_emb = index.index.reconstruct(int(article_id))
        sim = compute_cosine_similarity(q_emb.flatten(), article_emb)
        sims.append((article_id, sim))
    
    # Filter candidates by the similarity threshold.
    filtered_candidates = [ (aid, s) for aid, s in sims if s >= threshold ]
    
    # Sort the filtered candidates by similarity in descending order.
    filtered_candidates.sort(key=lambda x: x[1], reverse=True)
    
    # Return just the article IDs that meet the threshold.
    return [article_id for article_id, sim in filtered_candidates]


# For Ent Overlap Test


In [13]:
def query_retriever_with_multi_hop_ret(query, index, embed_query_fn, multi_hop_neighbors_dict, threshold=0.5, n_seed=3):
    """
    Retrieve articles for a query using multi-hop neighbor expansion while tracking each seed.
    
    Steps:
      1. Embed the query and retrieve n_seed seed articles from FAISS.
      2. For each seed article, collect its multi-hop neighbors (the candidate set).
      3. Compute cosine similarity between the query embedding and each candidate's embedding.
      4. Filter candidates by the similarity threshold.
      5. Return a dictionary mapping each seed article to its list of candidate articles (and similarities).
    
    Returns:
      A dictionary: { seed_article_id: [(candidate_article_id, similarity), ...], ... }
    """
    # Compute the query embedding.
    q_emb = embed_query_fn(query, index).reshape(1, -1)
    
    # Retrieve n_seed seed articles from FAISS.
    distances, seed_ids = index.index.search(q_emb, n_seed)
    seed_ids = [str(x) for x in seed_ids[0]]
    
    # For each seed, compute candidates and track their similarities.
    results = {}
    for seed in seed_ids:
        # Initialize candidate set with the seed itself.
        candidate_ids = set([seed])
        # For each seed, add its multi-hop neighbors.
        if seed in multi_hop_neighbors_dict:
            candidate_ids.update(multi_hop_neighbors_dict[seed])
        candidate_ids = list(candidate_ids)
        
        # Compute cosine similarity for each candidate.
        sims = []
        for article_id in candidate_ids:
            # Reconstruct the candidate embedding using FAISS.
            article_emb = index.index.reconstruct(int(article_id))
            sim = compute_cosine_similarity(q_emb.flatten(), article_emb)
            if sim >= threshold:
                sims.append((article_id, sim))
        
        # Sort candidates by similarity in descending order.
        sims.sort(key=lambda x: x[1], reverse=True)
        results[seed] = sims
        
    return results


In [14]:




# Example usage:
if __name__ == "__main__":
    # Load the graph JSON (which should have been exported with node and edge information)
    graph_data = load_graph_json("graph_with_communities.json")
    #neighbors_dict = build_neighbors_dict(graph_data)
    hops = 2  # Change this value to control how many hops to include.
    multi_hop_neighbors = build_multi_hop_neighbors_dict(graph_data, hops=hops)
    #with open("multi_hop_neighbors_dict.json", "w") as f:
    #    json.dump(multi_hop_neighbors, f, indent=4)
    
    # Example: Retrieve one-hop neighbors for a given article.
    #example_article_id = "123"  # Replace with an actual article id from your graph.
    #one_hop_articles = one_hop_retriever(example_article_id, neighbors_dict)
    #print(f"One-hop neighbors for article {example_article_id}:", one_hop_articles)
    
    # Example: Retrieve articles for a given query using FAISS and one-hop connections.
    

In [17]:
query_text = "Toshakhana Case"
top_articles_with_multi_hop = query_retriever_with_multi_hop_ret(query_text, index, embed_query, multi_hop_neighbors, threshold=0.5, n_seed=3)
print("Top articles (with multi-hop expansion|):", top_articles_with_multi_hop)

Top articles (with multi-hop expansion|): {'5717': [('5717', np.float32(0.6368177)), ('29082', np.float32(0.62874717)), ('24745', np.float32(0.6226366)), ('7819', np.float32(0.6076301)), ('33125', np.float32(0.60314494)), ('29183', np.float32(0.60306144)), ('34000', np.float32(0.60108346)), ('35944', np.float32(0.59950286)), ('9888', np.float32(0.59555775)), ('32850', np.float32(0.5951112)), ('12705', np.float32(0.5919544)), ('31017', np.float32(0.59173936)), ('8119', np.float32(0.59102863)), ('60319', np.float32(0.59080213)), ('4662', np.float32(0.5892629)), ('7823', np.float32(0.5880902)), ('10932', np.float32(0.58222616)), ('8287', np.float32(0.58105904)), ('8184', np.float32(0.57713073)), ('29653', np.float32(0.57635087)), ('32999', np.float32(0.5762695)), ('15910', np.float32(0.57610816)), ('2860', np.float32(0.575664)), ('1481', np.float32(0.5756075)), ('31986', np.float32(0.57509357)), ('20355', np.float32(0.57444876)), ('28566', np.float32(0.5740169)), ('18255', np.float32(0.57

In [18]:
print(len(top_articles_with_multi_hop))

3


In [34]:
top_articles_with_multi_hop

{'5717': [('5717', np.float32(0.6368177)),
  ('29082', np.float32(0.62874717)),
  ('24745', np.float32(0.6226366)),
  ('7819', np.float32(0.6076301)),
  ('33125', np.float32(0.60314494)),
  ('29183', np.float32(0.60306144)),
  ('34000', np.float32(0.60108346)),
  ('35944', np.float32(0.59950286)),
  ('9888', np.float32(0.59555775)),
  ('32850', np.float32(0.5951112)),
  ('12705', np.float32(0.5919544)),
  ('31017', np.float32(0.59173936)),
  ('8119', np.float32(0.59102863)),
  ('60319', np.float32(0.59080213)),
  ('4662', np.float32(0.5892629)),
  ('7823', np.float32(0.5880902)),
  ('10932', np.float32(0.58222616)),
  ('8287', np.float32(0.58105904)),
  ('8184', np.float32(0.57713073)),
  ('29653', np.float32(0.57635087)),
  ('32999', np.float32(0.5762695)),
  ('15910', np.float32(0.57610816)),
  ('2860', np.float32(0.575664)),
  ('1481', np.float32(0.5756075)),
  ('31986', np.float32(0.57509357)),
  ('20355', np.float32(0.57444876)),
  ('28566', np.float32(0.5740169)),
  ('18255', np.

In [20]:
def extract_article_ids_from_results(results):
    article_ids = set()
    for seed, candidates in results.items():
        article_ids.add(seed)
        for candidate, _ in candidates:
            article_ids.add(candidate)
    return article_ids

In [36]:
def get_articles_from_seeds(seed_results, k):
    articles_set = set()
    
    for seed, neighbors in seed_results.items():
        # Get the top k neighbors; assumes the list is sorted in descending order of similarity.
        top_k_neighbors = neighbors[:k]
        for article_id, similarity in top_k_neighbors:
            articles_set.add(article_id)
    
    return articles_set

In [42]:
comp_articles = get_articles_from_seeds(top_articles_with_multi_hop,50)

In [43]:
print(f"\nTop {len(comp_articles)} articles for the query: '{query_text}'")
for idx, article_id in enumerate(comp_articles):
    article_info = articles_data.get(article_id, {})
    title = article_info.get("title", "No Title")
    print(f"{idx+1}. Article ID: {article_id} - Title: {title}")


Top 50 articles for the query: 'Toshakhana Case'
1. Article ID: 18255 - Title: Islamabad court summons Imran on May 10 for indictment in Toshakhana case
2. Article ID: 9888 - Title: Toshakhana case: IHC suspends Imran’s arrest warrants till March 13
3. Article ID: 32999 - Title: SC rejects Imran’s plea to halt trial in Toshakhana case

4. Article ID: 10788 - Title: ‘Incomplete list’: PTI demands disclosure of Toshakhana gifts retained by generals and judges
5. Article ID: 7819 - Title: Imran’s indictment in Toshakhana reference deferred yet again
6. Article ID: 31149 - Title: PTI counsel tries to drag on Toshakhana case hearing

7. Article ID: 29183 - Title: Toshakhana reference ‘maintainable’, trial to begin next week

8. Article ID: 59080 - Title: NAB team formed for Toshakhana case

9. Article ID: 10902 - Title: Edict issued against Toshakhana gifts
10. Article ID: 12705 - Title: Toshakhana in its true perspective

11. Article ID: 10913 - Title: LHC orders govt to submit remaining 

In [18]:
def load_entity_relationships_for_articles(article_ids, filename="entity_relationships.json"):
    with open(filename, "r", encoding="utf-8") as f:
        data = json.load(f)
    
    filtered_relationships = {}
    for article_id in article_ids:
        if article_id in data:
            triplets = data[article_id].get("relationships", [])
            cleaned_triplets = set()
            for triplet in triplets:
                # Remove backslashes from each element.
                cleaned = tuple(elem.replace("\\", "") for elem in triplet)
                cleaned_triplets.add(cleaned)
            filtered_relationships[article_id] = cleaned_triplets
    return filtered_relationships

In [20]:
def compute_entity_overlap(results, filtered_entity_relationships):
    new_results = {}
    for seed, candidates in results.items():
        seed_entities = filtered_entity_relationships.get(seed, set())
        candidate_results = []
        for candidate, sim in candidates:
            candidate_entities = filtered_entity_relationships.get(candidate, set())
            overlap = list(seed_entities.intersection(candidate_entities))
            candidate_results.append({
                "candidate": candidate,
                "similarity": sim,
                "overlap": overlap
            })
        candidate_results.sort(key=lambda x: x["similarity"], reverse=True)
        new_results[seed] = candidate_results
    return new_results

In [21]:
article_ids = extract_article_ids_from_results(top_articles_with_multi_hop)
print("Extracted article IDs:", article_ids)

# 2. Load only the relevant entity relationships.
filtered_entity_relationships = load_entity_relationships_for_articles(article_ids, filename="entity_relationships.json")
print("Loaded entity relationships for", len(filtered_entity_relationships), "articles.")

# 3. Compute the entity overlap.
final_results = compute_entity_overlap(top_articles_with_multi_hop, filtered_entity_relationships)
print("Final results with entity overlaps:")
print(final_results)

Extracted article IDs: {'10932', '36163', '31293', '17129', '12092', '54268', '35051', '33062', '59258', '31017', '11524', '10913', '29764', '32902', '31791', '55567', '19429', '2463', '30596', '8184', '18371', '52719', '34000', '35944', '33436', '24544', '33200', '3089', '29183', '31986', '28463', '45634', '37158', '24745', '28322', '14944', '12197', '14558', '10918', '52213', '29677', '12705', '10630', '36994', '10907', '42949', '28772', '2694', '36311', '36318', '15910', '33818', '7819', '18397', '19233', '10788', '8119', '5746', '60319', '4745', '32850', '5717', '35690', '11433', '10786', '59080', '31149', '20355', '24664', '7823', '4662', '33772', '10301', '5827', '1481', '12108', '10902', '31937', '14622', '32999', '2860', '29082', '33125', '10787', '28872', '11275', '28566', '15178', '28375', '24679', '2060', '31906', '34039', '18255', '9888', '29653', '2989', '33203', '8287', '9749', '10747'}
Loaded entity relationships for 100 articles.
Final results with entity overlaps:
{'57

In [23]:
print(f"\nTop results for the query: '{query_text}'")
for seed, candidates in final_results.items():
    seed_info = articles_data.get(seed, {})
    seed_title = seed_info.get("title", "No Title")
    print(f"\nSeed Article ID: {seed} - Title: {seed_title}")
    for idx, candidate_info in enumerate(candidates, start=1):
        candidate_id = candidate_info["candidate"]
        similarity = candidate_info["similarity"]
        overlap = candidate_info["overlap"]
        candidate_info_data = articles_data.get(candidate_id, {})
        candidate_title = candidate_info_data.get("title", "No Title")
        print(f"  {idx}. Candidate Article ID: {candidate_id} - Title: {candidate_title} - Similarity: {similarity:.2f}")
        if overlap:
            print("       Overlapping Triplets:")
            for triplet in overlap:
                print("         -", triplet)
        else:
            print("       No overlapping triplets.")


Top results for the query: 'Toshakhana Case'

Seed Article ID: 5717 - Title: Toshakhana reference: Imran’s indictment deferred as Islamabad court accepts PTI chief’s exemption plea
  1. Candidate Article ID: 5717 - Title: Toshakhana reference: Imran’s indictment deferred as Islamabad court accepts PTI chief’s exemption plea - Similarity: 0.64
       Overlapping Triplets:
         - ('Islamabad', 'accepted', "Imran Khan's")
         - ('ECP', 'represented', 'Gohar Ali Khan')
         - ('Ali Bukhari', 'represented', 'ECP')
         - ("Imran Khan's", 'accepted', 'Islamabad')
         - ('ECP', 'represented', 'Ali Bukhari')
         - ('Toshakhana', 'accepted', 'Islamabad')
         - ('Bukhari', 'According', '"If Imran\'s')
         - ('Islamabad', 'accepted', 'PTI')
         - ('ECP', 'represented', 'PTI')
         - ('Saad Hasan', 'represented', 'ECP')
         - ('PTI', 'responded', 'ECP')
         - ('PTI', 'accepted', "Imran Khan's")
         - ('Zafar Iqbal', 'conducted', 'Additi

In [22]:
from pyvis.network import Network

def build_overlap_graph(final_results):
    """
    Build a directed graph from overlapping triplets.
    Each triplet is assumed to be a tuple (subject, relation, object).
    Nodes are entities and directed edges go from subject to object with the relation as the label.
    
    Parameters:
      final_results: Dictionary mapping seed article IDs to a list of candidate dictionaries,
                     where each candidate dict has an "overlap" key that is a list of triplets.
    
    Returns:
      A pyvis Network object.
    """
    # Create a directed network.
    net = Network(directed=True)
    
    # Keep track of added nodes and edges to avoid duplicates.
    added_nodes = set()
    added_edges = set()  # each edge is a tuple (source, target, relation)
    
    # Iterate over each seed and its candidate results.
    for seed, candidates in final_results.items():
        for candidate in candidates:
            # For each candidate, process its overlapping triplets.
            for triplet in candidate.get("overlap", []):
                # Unpack the triplet. Each triplet is assumed to be (subject, relation, object).
                subject, relation, object_ = triplet
                
                # Add nodes if not already added.
                if subject not in added_nodes:
                    net.add_node(subject, label=subject)
                    added_nodes.add(subject)
                if object_ not in added_nodes:
                    net.add_node(object_, label=object_)
                    added_nodes.add(object_)
                
                # Use a tuple (subject, object, relation) to check if edge is already added.
                edge_key = (subject, object_, relation)
                # Also create the reciprocal key.
                reciprocal_edge_key = (object_, subject, relation)
                
                # Only add the edge if neither the edge nor its reciprocal exists.
                if edge_key not in added_edges and reciprocal_edge_key not in added_edges:
                    net.add_edge(subject, object_, title=relation, label=relation)
                    added_edges.add(edge_key)
    return net


In [23]:
overlap_net = build_overlap_graph(final_results)

In [24]:
overlap_net.write_html("overlapgraph.html")