# 3. Result Visualization

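This notebook loads the best saved checkpoint and analyzes its drug–disease link-prediction results on the test split: ROC and precision-recall curves, score distributions, top-ranked drugs for a chosen disease, and edge statistics of the graph.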

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, precision_recall_curve, auc

## 3.1 Load Model and Data

In [None]:
from src.utils import load_config, get_device
from src.data import DatasetLoader, get_link_split
from src.models import HeteroGCN, LinkPredictor

# Experiment configuration and the device used throughout this notebook.
config = load_config('../configs/config_base.yaml')
device = get_device('cpu')

# Read the processed graph for the dataset under analysis.
loader = DatasetLoader('../data/processed')
data = loader.load('microglial_cell_pinnacle')

# get_link_split yields three splits; only the last one is kept, as the
# test split the rest of the notebook evaluates on.
splits = get_link_split(data)
_, _, test_data = splits

In [None]:
# Load model: rebuild the encoder and link predictor from the config, then
# restore their weights from the saved checkpoint.
num_nodes = {nt: data[nt].num_nodes for nt in data.node_types}
model_cfg = config['model']

model = HeteroGCN(
    data.metadata(),
    hidden_channels=model_cfg['hidden_channels'],
    num_layers=model_cfg['num_layers'],
    num_nodes_dict=num_nodes
).to(device)

# Dummy forward to initialize LazyLinear. Both the features and the edge
# indices are moved to `device` so the pass also works when the model does
# not live on the CPU.
with torch.no_grad():
    dummy_x = {nt: test_data[nt].x.to(device) for nt in test_data.node_types}
    dummy_edges = {
        et: edge_index.to(device)
        for et, edge_index in test_data.edge_index_dict.items()
    }
    _ = model(dummy_x, dummy_edges)

# LinkPredictor with optional SimGNN decoder
use_sim = model_cfg.get('use_sim_decoder', False)
predictor = LinkPredictor(
    hidden_channels=model_cfg['hidden_channels'] if use_sim else None,
    use_sim_decoder=use_sim
).to(device)

# Load checkpoint.
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was written on a GPU can still be opened on a CPU-only machine.
# NOTE(security): weights_only=False unpickles arbitrary Python objects; only
# load checkpoint files that come from a trusted source.
ckpt = torch.load(
    '../results/checkpoints/best_model_microglial_cell.pth',
    map_location=device,
    weights_only=False,
)
model.load_state_dict(ckpt['model_state_dict'])
predictor.load_state_dict(ckpt['predictor_state_dict'])
print(f"Loaded checkpoint from epoch {ckpt['epoch']}")
print(f"Saved metrics: {ckpt['metrics']}")

## 3.2 Generate Predictions

In [None]:
# Score every labelled (drug, disease) pair of the test split.
model.eval()
predictor.eval()

test_data = test_data.to(device)
target = ('drug', 'indication', 'disease')

with torch.no_grad():
    x_dict = {nt: test_data[nt].x for nt in test_data.node_types}
    z_dict = model(x_dict, test_data.edge_index_dict)

    edge_idx = test_data[target].edge_label_index
    # .cpu() before .numpy(): tensors were just moved to `device`, and
    # .numpy() raises on anything that is not in host memory.
    edge_label = test_data[target].edge_label.cpu().numpy()

    scores = predictor(z_dict['drug'], z_dict['disease'], edge_idx)
    # The predictor output is passed through a sigmoid to obtain
    # probabilities in [0, 1].
    probs = torch.sigmoid(scores).cpu().numpy()

print(f'Test samples: {len(edge_label)}')
# Vectorized reductions instead of Python-level sum() over NumPy arrays.
print(f'Positive: {edge_label.sum()}, Negative: {(edge_label == 0).sum()}')

## 3.3 ROC Curve

In [None]:
# ROC curve of the test predictions and the area under it.
fpr, tpr, _ = roc_curve(edge_label, probs)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
# Diagonal reference line, labelled as the random baseline.
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic (ROC)')
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 3.4 Precision-Recall Curve

In [None]:
# Precision-recall curve of the test predictions and the area under it.
precision, recall, _ = precision_recall_curve(edge_label, probs)
pr_auc = auc(recall, precision)

# Horizontal baseline drawn at the fraction of positive labels.
positive_fraction = sum(edge_label) / len(edge_label)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(recall, precision, color='green', lw=2, label=f'PR curve (AUC = {pr_auc:.3f})')
ax.axhline(y=positive_fraction, color='navy', linestyle='--', label='Random')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Curve')
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 3.5 Score Distribution

In [None]:
# Split the predicted probabilities by ground-truth label.
pos_scores = probs[edge_label == 1]
neg_scores = probs[edge_label == 0]

fig, ax = plt.subplots(figsize=(10, 5))
# Negatives are drawn first so the positive histogram is overlaid on top.
for class_scores, class_label, class_color in (
    (neg_scores, 'Negative', 'red'),
    (pos_scores, 'Positive', 'blue'),
):
    ax.hist(class_scores, bins=50, alpha=0.7, label=class_label, color=class_color)
ax.set_xlabel('Prediction Score')
ax.set_ylabel('Frequency')
ax.set_title('Score Distribution')
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 3.6 Top Drug Predictions for a Disease

In [None]:
# Get node mappings
# NOTE(review): assumes the rows of each node type appear in node-index
# order, so list position i is the name of node i -- confirm against the
# preprocessing step that wrote this CSV.
node_mapping = pd.read_csv('../data/processed/node_mapping.csv')

drug_names = node_mapping[node_mapping['node_type'] == 'drug']['node_name'].tolist()
disease_names = node_mapping[node_mapping['node_type'] == 'disease']['node_name'].tolist()

# Example: get top predictions for a specific disease
disease_idx = 0
disease_name = disease_names[disease_idx] if disease_idx < len(disease_names) else f'Disease_{disease_idx}'

print(f'\nTop 10 drug predictions for: {disease_name}')

# Candidate drugs are taken from the embedding matrix, not from the CSV:
# a name list that is longer than the matrix would index out of range, and a
# shorter one would silently leave drugs unscored.
drug_emb = z_dict['drug']
num_drugs = drug_emb.size(0)

with torch.no_grad():  # inference only; no autograd graph needed
    # Pair every drug with the chosen disease, using the same
    # (src embeddings, dst embeddings, edge index) call form as the test-set
    # scoring. Index tensors are created on the embeddings' device.
    drug_ids = torch.arange(num_drugs, device=drug_emb.device)
    pair_index = torch.stack([drug_ids, torch.full_like(drug_ids, disease_idx)])
    # reshape(-1) keeps one score per drug even if the predictor returns a
    # trailing singleton dimension.
    all_drug_scores = predictor(drug_emb, z_dict['disease'], pair_index).reshape(-1)

top_indices = torch.argsort(all_drug_scores, descending=True)[:10]

# .tolist() yields plain ints, so the fallback label prints as 'Drug_7'
# rather than 'Drug_tensor(7)'.
for i, idx in enumerate(top_indices.tolist(), start=1):
    drug_name = drug_names[idx] if idx < len(drug_names) else f'Drug_{idx}'
    score = torch.sigmoid(all_drug_scores[idx]).item()
    print(f'{i}. {drug_name}: {score:.4f}')

## 3.7 Multi-Relational PPI Edge Statistics

In [None]:
# Count the edges of every relation once, then report them.
edge_counts = {et: data[et].edge_index.shape[1] for et in data.edge_types}

print('Edge types in graph:')
for et, n_edges in edge_counts.items():
    print(f'  {et}: {n_edges:,} edges')

# The two PPI relations between genes get their own summary lines when the
# graph contains them.
ppi_generic = ('gene', 'ppi_generic', 'gene')
ppi_cell = ('gene', 'ppi_cell', 'gene')

if ppi_generic in edge_counts:
    print(f'\nGeneric PPI edges: {edge_counts[ppi_generic]:,}')
if ppi_cell in edge_counts:
    print(f'Cell-specific PPI edges: {edge_counts[ppi_cell]:,}')