# 3. Result Visualization

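This notebook loads the best training checkpoint, scores the held-out drug–disease `indication` links, and visualizes test-set performance (ROC, precision-recall, score distributions, top-ranked predictions, and a cell-specific vs. general comparison).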

In [None]:
import sys
sys.path.insert(0, '..')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import auc, average_precision_score, precision_recall_curve, roc_curve

## 3.1 Load Model and Data

In [None]:
from src.utils import load_config, get_device
from src.data import DatasetLoader, get_link_split
from src.models import HeteroGCN, LinkPredictor

# Runtime setup: inference in this notebook runs on CPU.
device = get_device('cpu')
config = load_config('../configs/config_base.yaml')

# Heterogeneous graph for the microglial-cell (PINNACLE) context.
loader = DatasetLoader('../data/processed')
data = loader.load('microglial_cell_pinnacle')

# Only the held-out split is used for evaluation; train/val are discarded.
# NOTE(review): if get_link_split samples edges/negatives randomly, this split
# only matches the training run's test split when seeded identically — confirm,
# otherwise test edges may have been seen in training and metrics are inflated.
_, _, test_data = get_link_split(data)

In [None]:
# Load model
num_nodes = {nt: data[nt].num_nodes for nt in data.node_types}
input_channels = {nt: data[nt].x.shape[1] for nt in data.node_types}

model = HeteroGCN(
    data.metadata(),
    hidden_channels=config['model']['hidden_channels'],
    num_layers=config['model']['num_layers'],
    num_nodes_dict=num_nodes,
    input_channels_dict=input_channels
).to(device)

predictor = LinkPredictor().to(device)

# Load checkpoint
ckpt = torch.load('../results/checkpoints/best_model.pth', weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
predictor.load_state_dict(ckpt['predictor_state_dict'])
print(f"Loaded checkpoint from epoch {ckpt['epoch']}")
print(f"Saved metrics: {ckpt['metrics']}")

## 3.2 Generate Predictions

In [None]:
# Score every labelled test pair for the drug -> disease 'indication' relation.
model.eval()
predictor.eval()

test_data = test_data.to(device)
target = ('drug', 'indication', 'disease')

with torch.no_grad():
    x_dict = {nt: test_data[nt].x for nt in test_data.node_types}
    z_dict = model(x_dict, test_data.edge_index_dict)

    # edge_label_index: row 0 = drug indices, row 1 = disease indices of the
    # pairs to score; edge_label: 1 for positive pairs, 0 for negative pairs.
    edge_idx = test_data[target].edge_label_index
    # .cpu() so the NumPy conversion also works when `device` is a GPU.
    edge_label = test_data[target].edge_label.cpu().numpy()

    scores = predictor(z_dict['drug'], z_dict['disease'], edge_idx)
    # Flatten in case the predictor returns shape (N, 1); sklearn needs 1-D scores.
    probs = torch.sigmoid(scores).reshape(-1).cpu().numpy()

# Vectorized integer counts (builtin sum() over a float array printed e.g. "500.0").
n_pos = int((edge_label == 1).sum())
n_neg = int((edge_label == 0).sum())
print(f'Test samples: {len(edge_label)}')
print(f'Positive: {n_pos}, Negative: {n_neg}')

## 3.3 ROC Curve

In [None]:
# ROC: true-positive rate vs. false-positive rate across all score thresholds,
# summarised by the area under the curve.
fpr, tpr, _ = roc_curve(edge_label, probs)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
# Diagonal = a classifier that scores at random.
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
ax.set(
    xlim=[0.0, 1.0],
    ylim=[0.0, 1.05],
    xlabel='False Positive Rate',
    ylabel='True Positive Rate',
    title='Receiver Operating Characteristic (ROC)',
)
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 3.4 Precision-Recall Curve

In [None]:
# Precision-recall curve for the test pairs.
precision, recall, _ = precision_recall_curve(edge_label, probs)
# Summarise with average precision (step-wise weighted mean of precisions)
# instead of trapezoidal auc(recall, precision): linearly interpolating between
# PR points is over-optimistic, as noted in the scikit-learn docs.
pr_auc = average_precision_score(edge_label, probs)
# Precision of a no-skill classifier equals the positive-class prevalence.
prevalence = edge_label.mean()

fig, ax = plt.subplots(figsize=(8, 6))
# PR curves are piecewise constant, so draw them as steps (as sklearn does).
ax.plot(recall, precision, color='green', lw=2, drawstyle='steps-post',
        label=f'PR curve (AP = {pr_auc:.3f})')
ax.axhline(y=prevalence, color='navy', linestyle='--', label='Random')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Curve')
# A good PR curve hugs the top-right, so the lower-left corner is free space.
ax.legend(loc='lower left')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 3.5 Score Distribution

In [None]:
# Overlay score histograms for positive vs. negative test pairs.
# Both classes share the same bin edges over [0, 1] (probs are sigmoid outputs);
# with bins=50 each class got its own min..max range, so bars were not aligned
# and heights were not comparable.
score_bins = np.linspace(0.0, 1.0, 51)
pos_scores = probs[edge_label == 1]
neg_scores = probs[edge_label == 0]

fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(neg_scores, bins=score_bins, alpha=0.7, label='Negative', color='red')
ax.hist(pos_scores, bins=score_bins, alpha=0.7, label='Positive', color='blue')
ax.set_xlabel('Prediction Score')
ax.set_ylabel('Frequency')
ax.set_title('Score Distribution')
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 3.6 Top Predictions Analysis

In [None]:
# Get node names
node_mapping = pd.read_csv('../data/processed/node_mapping.csv')

drugs = node_mapping[node_mapping['node_type'] == 'drug'].sort_values('node_idx')
diseases = node_mapping[node_mapping['node_type'] == 'disease'].sort_values('node_idx')

drug_names = drugs['node_name'].tolist()
disease_names = diseases['node_name'].tolist()

In [None]:
# Top predictions: one row per scored test pair.
# .cpu() keeps the NumPy conversion working if `device` is switched to a GPU.
drug_ids = edge_idx[0].cpu().numpy()
disease_ids = edge_idx[1].cpu().numpy()

results = pd.DataFrame({
    'drug_idx': drug_ids,
    'disease_idx': disease_ids,
    'score': probs,
    'label': edge_label,
})

# Resolve names by position in the node_idx-sorted name lists. Out-of-range
# (including negative) indices become 'Unknown' rather than raising or
# silently wrapping around to the end of the list.
results['drug_name'] = [
    drug_names[i] if 0 <= i < len(drug_names) else 'Unknown' for i in drug_ids
]
results['disease_name'] = [
    disease_names[i] if 0 <= i < len(disease_names) else 'Unknown' for i in disease_ids
]

# Highest-scoring pairs among label-0 test edges, i.e. pairs not recorded as
# indications in this split — candidate novel indications.
# NOTE(review): if negatives are randomly sampled, this ranks only the sampled
# pairs, not every unobserved drug–disease combination.
unknown = results[results['label'] == 0].nlargest(10, 'score')
print('Top 10 Novel Predictions:')
print(unknown[['drug_name', 'disease_name', 'score']])

## 3.7 Comparison: Cell-Specific vs General

In [None]:
# Compare results from different configs.
# The cell-specific AUROC/AUPRC are taken from the test-set predictions computed
# earlier in this notebook (`roc_auc`, `pr_auc`), so the table cannot drift
# from the figures above.
# FIXME: the General (Baseline) row and both MRR values are still hard-coded
# placeholders — replace them with metrics from the baseline checkpoint's
# evaluation (computed the same way) before drawing conclusions.
results_comparison = {
    'Model': ['Cell-Specific (PINNACLE)', 'General (Baseline)'],
    'AUROC': [float(roc_auc), 0.975],
    'AUPRC': [float(pr_auc), 0.980],
    'MRR': [0.45, 0.42],
}

comparison_df = pd.DataFrame(results_comparison)
print(comparison_df)

In [None]:
# Bar plot comparison
metrics = ['AUROC', 'AUPRC', 'MRR']
x = np.arange(len(metrics))
width = 0.35

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(x - width/2, comparison_df.iloc[0][metrics], width, label='Cell-Specific')
ax.bar(x + width/2, comparison_df.iloc[1][metrics], width, label='General')

ax.set_ylabel('Score')
ax.set_title('Cell-Specific vs General Model Performance')
ax.set_xticks(x)
ax.set_xticklabels(metrics)
ax.legend()
ax.set_ylim([0, 1])
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()