# 2. Graph Construction Debug

This notebook tests and debugs the graph building process.

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import pandas as pd
from src.data import GraphBuilder, DatasetLoader

## 2.1 Build Graph Step by Step

In [None]:
# Locations the builder reads from (relative to the notebook directory).
builder_paths = {
    'data_dir': '../data/processed',
    'raw_dir': '../data/raw',
}
builder = GraphBuilder(**builder_paths)

# Load the node mapping first, then print the builder's per-type counts
# (type_counts is presumably populated by load_node_mapping — verify).
builder.load_node_mapping()
print(f'Node type counts: {builder.type_counts}')

In [None]:
# Gene-name coverage: size of the name -> global-id mapping plus a peek at
# the first ten names (assumes the mapping iterates over its keys, as a
# dict does).
gene_lookup = builder.gene_name_to_global
print(f'Gene names in PrimeKG: {len(gene_lookup)}')
print('Sample gene names:', list(gene_lookup)[:10])

## 2.2 PINNACLE Data Check

In [None]:
builder.load_pinnacle_data()

# Report only when some per-cell-type PINNACLE data is present; an empty or
# missing container prints nothing.
cell_data = builder.pinnacle_cell_data
if cell_data:
    print(f'PINNACLE cell types: {len(cell_data)}')
    print('Sample cell types:', list(cell_data)[:5])

In [None]:
# Check ID alignment: count how many PrimeKG gene names also appear among the
# PINNACLE proteins. The whole check is skipped when no PINNACLE proteins are
# available on the builder.
if builder.pinnacle_all_proteins:
    primekg_genes = set(builder.gene_name_to_global.keys())
    # NOTE(review): the `&` below needs pinnacle_all_proteins to be set-like
    # — confirm against GraphBuilder.
    pinnacle_proteins = builder.pinnacle_all_proteins

    overlap = primekg_genes & pinnacle_proteins
    # Guard the percentage: with an empty PrimeKG gene mapping the division
    # would raise ZeroDivisionError; report 0.0% instead so the counts above
    # it are still readable.
    coverage = 100 * len(overlap) / len(primekg_genes) if primekg_genes else 0.0
    print(f'PrimeKG genes: {len(primekg_genes)}')
    print(f'PINNACLE proteins: {len(pinnacle_proteins)}')
    print(f'Overlap: {len(overlap)} ({coverage:.1f}%)')

## 2.3 Build Complete Graph

In [None]:
# Build the complete graph for the microglial cell-type file with PINNACLE
# enabled, then print the resulting object's summary.
build_options = {'use_pinnacle': True}
data = builder.build('microglial_cell.txt', **build_options)
print(data)

In [None]:
# List every edge type in the graph with its edge count.
print('Edge types:')
for edge_type in data.edge_types:
    edge_index = data[edge_type].edge_index
    # The count is the size of dimension 1 (assumes edge_index is laid out
    # as (2, num_edges) — TODO confirm).
    print(f'  {edge_type}: {edge_index.shape[1]} edges')

In [None]:
# For each node type, show the feature matrix shape and how many rows carry
# any non-zero value.
print('Node features:')
for node_type in data.node_types:
    features = data[node_type].x
    # A row counts as non-zero when the sum of its absolute values is > 0.
    row_mass = features.abs().sum(dim=1)
    populated_rows = int((row_mass > 0).sum())
    print(f'  {node_type}: shape={features.shape}, non-zero rows={populated_rows}')

## 2.4 Data Split Verification

In [None]:
from src.data import get_link_split

# Split the graph with the 'random' strategy and summarise the labelled
# edges of the drug-indication-disease relation in each part.
train, val, test = get_link_split(data, strategy='random')

target = ('drug', 'indication', 'disease')
print('Split statistics:')
for name, split in zip(('Train', 'Val', 'Test'), (train, val, test)):
    labels = split[target].edge_label
    # pos is the label sum and neg the number of zero labels
    # (presumably labels are binary 0/1 — verify against get_link_split).
    n_pos = int(labels.sum())
    n_neg = int((labels == 0).sum())
    print(f'  {name}: {len(labels)} edges, pos={n_pos}, neg={n_neg}')

## 2.5 Graph Visualization (Subgraph)

In [None]:
import networkx as nx
import matplotlib.pyplot as plt

# Visualise a small sample: only the first 50 drug -> gene 'targets' edges.
dg_edges = data['drug', 'targets', 'gene'].edge_index[:, :50]

# Drug and gene indices are prefixed with 'D' / 'G' so the two index spaces
# cannot collide inside a single networkx graph.
G = nx.Graph()
G.add_edges_from(
    (f'D{drug_idx}', f'G{gene_idx}')
    for drug_idx, gene_idx in zip(dg_edges[0].tolist(), dg_edges[1].tolist())
)

# Fixed seed keeps the spring layout reproducible between runs.
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(G, seed=42)

# Recover the two node groups from the name prefix.
drug_nodes = [node for node in G.nodes() if node.startswith('D')]
gene_nodes = [node for node in G.nodes() if node.startswith('G')]

# Draw each group in its own colour, then the edges on top.
for group, colour, legend_label in (
    (drug_nodes, 'blue', 'Drug'),
    (gene_nodes, 'green', 'Gene'),
):
    nx.draw_networkx_nodes(
        G, pos, nodelist=group, node_color=colour, node_size=100, label=legend_label
    )
nx.draw_networkx_edges(G, pos, alpha=0.5)

plt.legend()
plt.title('Drug-Gene Subgraph (Sample)')
plt.axis('off')
plt.tight_layout()
plt.show()