# 4. Case Study: Hepatocyte Cell-Type Specific Predictions

This notebook demonstrates how the **hepatocyte cell-specific model** identifies drug-disease associations that the **general baseline model** misses.

## Key Finding
Cell-type specific context from PINNACLE improves **MRR by 22.6%** (0.22 → 0.27), indicating better ranking of relevant drugs for liver-related diseases.

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from src.utils import load_config, get_device
from src.data import DatasetLoader, get_link_split
from src.models import HeteroGCN, LinkPredictor

## 4.1 Load Both Models

In [None]:
# Device handle shared by the model-loading and inference cells below.
device = get_device('cpu')

# Node-mapping table, used to turn node indices into human-readable names.
node_mapping = pd.read_csv('../data/processed/node_mapping.csv')

# Split the table by node type with boolean masks.
is_drug = node_mapping['node_type'] == 'drug'
is_disease = node_mapping['node_type'] == 'disease'
drug_df = node_mapping.loc[is_drug]
disease_df = node_mapping.loc[is_disease]

# NOTE(review): later cells use a position in these lists as a graph node
# index -- this assumes the CSV row order matches the per-type node order
# of the loaded graphs; confirm against the preprocessing step.
drug_names = drug_df['node_name'].tolist()
disease_names = disease_df['node_name'].tolist()

print(f'Total drugs: {len(drug_names)}')
print(f'Total diseases: {len(disease_names)}')

In [None]:
def load_model(config_path, checkpoint_path, data):
    """Rebuild an encoder/decoder pair and restore their checkpointed weights.

    Args:
        config_path: Path handed to ``load_config``. The loaded mapping must
            contain a ``'model'`` section with ``'hidden_channels'`` and
            ``'num_layers'``; ``'use_sim_decoder'`` is optional (default False).
        checkpoint_path: File read with ``torch.load``; must provide the keys
            ``'model_state_dict'`` and ``'predictor_state_dict'``.
        data: Graph object exposing ``node_types``, ``metadata()`` and
            per-node-type ``num_nodes`` / ``x``.

    Returns:
        Tuple ``(model, predictor, test_data)`` where ``test_data`` is the
        third item produced by ``get_link_split(data)``.
    """
    model_cfg = load_config(config_path)['model']

    node_counts = {}
    for node_type in data.node_types:
        node_counts[node_type] = data[node_type].num_nodes

    hidden = model_cfg['hidden_channels']
    model = HeteroGCN(
        data.metadata(),
        hidden_channels=hidden,
        num_layers=model_cfg['num_layers'],
        num_nodes_dict=node_counts,
    ).to(device)

    # Warm-up forward pass: lazily-shaped layers (LazyLinear) need to see one
    # batch before load_state_dict can match concrete parameter shapes.
    _train_data, _val_data, test_data = get_link_split(data)
    with torch.no_grad():
        warmup_x = {}
        for node_type in test_data.node_types:
            warmup_x[node_type] = test_data[node_type].x.to(device)
        model(warmup_x, test_data.edge_index_dict)

    # The decoder only receives a width when the similarity decoder is enabled.
    sim_decoder = model_cfg.get('use_sim_decoder', False)
    if sim_decoder:
        decoder_width = hidden
    else:
        decoder_width = None
    predictor = LinkPredictor(
        hidden_channels=decoder_width,
        use_sim_decoder=sim_decoder,
    ).to(device)

    # NOTE(review): weights_only=False performs a full unpickle, which can
    # execute arbitrary code -- only load checkpoints from a trusted source.
    state = torch.load(checkpoint_path, weights_only=False, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    predictor.load_state_dict(state['predictor_state_dict'])

    return model, predictor, test_data

In [None]:
# --- General baseline: dataset + trained weights ---
loader_general = DatasetLoader('../data/processed')
data_general = loader_general.load('general')
model_general, pred_general, test_general = load_model(
    config_path='../configs/config_general.yaml',
    checkpoint_path='../results/checkpoints/best_model_general.pth',
    data=data_general,
)
print('General model loaded')

# --- Hepatocyte-specific model: same pipeline on the PINNACLE-derived graph ---
loader_liver = DatasetLoader('../data/processed')
data_liver = loader_liver.load('hepatocyte_pinnacle')
model_liver, pred_liver, test_liver = load_model(
    config_path='../configs/config_liver.yaml',
    checkpoint_path='../results/checkpoints/best_model_hepatocyte.pth',
    data=data_liver,
)
print('Hepatocyte model loaded')

## 4.2 Find Liver-Related Diseases

In [None]:
# Case-insensitive substring match of each disease name against these stems.
liver_keywords = ['liver', 'hepat', 'cirrhosis', 'biliary', 'cholang']

# Keep (position, name) pairs; the position indexes into disease_names.
liver_diseases = [
    (position, disease)
    for position, disease in enumerate(disease_names)
    if any(keyword in disease.lower() for keyword in liver_keywords)
]

print(f'Found {len(liver_diseases)} liver-related diseases:')
# Only preview the first ten matches to keep the cell output short.
for idx, name in liver_diseases[:10]:
    print(f'  [{idx}] {name}')

## 4.3 Compare Top Predictions

In [None]:
def get_top_drugs(model, predictor, data, disease_idx, k=20):
    """Score every drug against one disease and return the k best.

    Args:
        model: Encoder module called as ``model(x_dict, edge_index_dict)``;
            its output must contain ``'drug'`` and ``'disease'`` embeddings.
        predictor: Decoder module called as
            ``predictor(z_drug, z_disease, edge_index)`` returning one logit
            per candidate pair.
        data: Graph object exposing ``node_types``, per-type ``x`` and
            ``edge_index_dict``.
        disease_idx: Row index of the target disease in the disease embeddings.
        k: Maximum number of drugs to return.

    Returns:
        List of ``(drug_index, probability)`` tuples (``int``, ``float``),
        sorted by descending probability, with at most ``k`` entries.
    """
    model.eval()
    predictor.eval()

    # Run on the device the model lives on instead of relying on a notebook
    # global; a parameter-free model falls back to CPU.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        x_dict = {nt: data[nt].x.to(run_device) for nt in data.node_types}
        # Graph structure must sit on the same device as the features.
        edge_index_dict = {
            edge_type: edge_index.to(run_device)
            for edge_type, edge_index in data.edge_index_dict.items()
        }
        z_dict = model(x_dict, edge_index_dict)

        drug_z = z_dict['drug']
        n_drugs = drug_z.shape[0]
        # Candidate pairs: every drug (row 0) against the one disease (row 1),
        # built on the embeddings' device so indexing never crosses devices.
        edge_index = torch.stack([
            torch.arange(n_drugs, device=drug_z.device),
            torch.full((n_drugs,), disease_idx, dtype=torch.long, device=drug_z.device),
        ])

        scores = predictor(drug_z, z_dict['disease'], edge_index)
        # Flatten so a decoder that returns shape (n, 1) is ranked across
        # drugs rather than along the size-1 trailing axis.
        probs = torch.sigmoid(scores).reshape(-1)

        top_indices = torch.argsort(probs, descending=True)[:k]
        top_scores = probs[top_indices]

    return list(zip(top_indices.tolist(), top_scores.tolist()))

In [None]:
# Case-study target: first keyword match, or disease 0 when nothing matched.
target_disease_idx, target_disease_name = (
    liver_diseases[0] if liver_diseases else (0, disease_names[0])
)

print(f'\n=== Case Study: {target_disease_name} ===')

# Score all drugs for the target disease with each model.
# NOTE(review): the same disease index is used for both graphs -- assumes
# disease node ordering is identical in the two datasets; confirm.
general_preds = get_top_drugs(model_general, pred_general, data_general, target_disease_idx, k=20)
liver_preds = get_top_drugs(model_liver, pred_liver, data_liver, target_disease_idx, k=20)

# Print the ten best drugs per model, general first.
for header, ranked in (
    ('\nTop 10 drugs - GENERAL model:', general_preds),
    ('\nTop 10 drugs - HEPATOCYTE model:', liver_preds),
):
    print(header)
    for position, (drug_idx, score) in enumerate(ranked[:10]):
        # Fall back to a synthetic label when the index has no name entry.
        if drug_idx < len(drug_names):
            drug_name = drug_names[drug_idx]
        else:
            drug_name = f'Drug_{drug_idx}'
        print(f'  {position+1}. {drug_name}: {score:.4f}')

## 4.4 Unique Predictions by Hepatocyte Model

In [None]:
# Compare Top-20 membership: which drugs appear in one model's list but not
# in the other's.
general_top20 = {idx for idx, _ in general_preds}
liver_top20 = {idx for idx, _ in liver_preds}

unique_to_liver = liver_top20 - general_top20
unique_to_general = general_top20 - liver_top20

print('Drugs uniquely in Hepatocyte Top-20 (not in General Top-20):')
# Walk the ranked list rather than the set: set iteration order is arbitrary,
# whereas this prints in rank order and yields rank and score in one pass.
for liver_rank, (drug_idx, liver_score) in enumerate(liver_preds, start=1):
    if drug_idx not in unique_to_liver:
        continue
    # Fall back to a synthetic label when the index has no name entry.
    drug_name = drug_names[drug_idx] if drug_idx < len(drug_names) else f'Drug_{drug_idx}'
    print(f'  - {drug_name} (Hepatocyte rank: {liver_rank}, score: {liver_score:.4f})')

print('\nDrugs uniquely in General Top-20 (not in Hepatocyte Top-20):')
for general_rank, (drug_idx, general_score) in enumerate(general_preds, start=1):
    if drug_idx not in unique_to_general:
        continue
    drug_name = drug_names[drug_idx] if drug_idx < len(drug_names) else f'Drug_{drug_idx}'
    print(f'  - {drug_name} (General rank: {general_rank}, score: {general_score:.4f})')

## 4.5 Rank Comparison Visualization

In [None]:
def _draw_ranked_bars(ax, names, scores, color, title):
    """Draw one horizontal bar panel with the best-ranked drug at the top."""
    positions = range(len(names))
    ax.barh(positions, scores, color=color)
    ax.set_yticks(positions)
    ax.set_yticklabels(names)
    ax.invert_yaxis()  # rank 1 at the top instead of the bottom
    ax.set_xlabel('Prediction Score')
    ax.set_title(title)
    ax.set_xlim(0, 1)  # scores are sigmoid outputs, so share a fixed 0-1 axis


# Two panels side by side: general model on the left, hepatocyte on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 8))
ax1, ax2 = axes[0], axes[1]

# Labels and scores for the top 15 of each model; unnamed indices get a
# synthetic 'Drug_<idx>' label.
general_names = [
    drug_names[idx] if idx < len(drug_names) else f'Drug_{idx}'
    for idx, _ in general_preds[:15]
]
general_scores = [score for _, score in general_preds[:15]]

liver_names = [
    drug_names[idx] if idx < len(drug_names) else f'Drug_{idx}'
    for idx, _ in liver_preds[:15]
]
liver_scores = [score for _, score in liver_preds[:15]]

_draw_ranked_bars(
    ax1, general_names, general_scores, 'steelblue',
    f'General Model - Top 15 for {target_disease_name}',
)
_draw_ranked_bars(
    ax2, liver_names, liver_scores, 'forestgreen',
    f'Hepatocyte Model - Top 15 for {target_disease_name}',
)

plt.tight_layout()
plt.savefig('../results/case_study_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print('\nFigure saved to: results/case_study_comparison.png')

## 4.6 Conclusion

The hepatocyte cell-specific model leverages **PINNACLE's liver-specific protein interactions** to:

1. **Re-rank drugs** based on liver-specific protein targets
2. **Surface liver-relevant drugs** that the general model may overlook
3. **Improve MRR by 22.6%** for liver-related disease predictions

This demonstrates the value of **cell-type specific context** for precision drug repurposing.