# Getting Started with DTI Prediction

This notebook demonstrates how to use the multi-modal framework for drug-target interaction (DTI) prediction.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from models import DrugGNN, ProteinPLM, TDAEncoder, FusionAttention, EvidentialHead
from data import DTIDataset, ColdStartSplit
from data.preprocessing import DrugPreprocessor, ProteinPreprocessor, TDAPreprocessor

# Report the runtime environment so outputs can be matched to a PyTorch build / device.
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

## 1. Load Dataset

Load the training split of the Davis dataset and inspect its size and the number of unique drugs and proteins it contains.

In [None]:
# Load the 'train' split of the Davis dataset from ../data.
# NOTE(review): the cold-start cell below passes *this* object to ColdStartSplit, so its
# train/val/test partitions are carved out of the 'train' split only — confirm that is intended.
dataset = DTIDataset(
    root="../data",
    dataset_name="davis",
    split="train",
)

print(f"Dataset size: {len(dataset)}")
# dataset.data maps 'drugs' / 'proteins' to per-interaction entries; count distinct ones.
for entity in ("drugs", "proteins"):
    print(f"Number of unique {entity}: {len(set(dataset.data[entity]))}")

## 2. Preprocess Data

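Initialize the drug, protein, and TDA preprocessors. The drug preprocessor converts SMILES strings to molecular graphs, and the protein preprocessor converts sequences to embeddings. The example below converts a single drug SMILES string.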

In [None]:
# One preprocessor per input modality (drug, protein, TDA); only the drug
# preprocessor is exercised in this cell.
drug_preprocessor, protein_preprocessor, tda_preprocessor = (
    DrugPreprocessor(),
    ProteinPreprocessor(),
    TDAPreprocessor(),
)

# Worked example: turn ibuprofen's SMILES string into a molecular graph.
example_smiles = "CC(C)Cc1ccc(cc1)C(C)C(O)=O"  # Ibuprofen
drug_graph = drug_preprocessor.smiles_to_graph(example_smiles)

# drug_graph is indexed by string keys; summarise the entries this notebook relies on.
# NOTE(review): "Edges" reports edge_index.shape[1]; if edges are stored in both
# directions this is twice the bond count — confirm against DrugPreprocessor.
print("\n".join([
    "Drug graph features:",
    f"  Nodes: {drug_graph['num_nodes']}",
    f"  Node features shape: {drug_graph['node_features'].shape}",
    f"  Edges: {drug_graph['edge_index'].shape[1]}",
]))

## 3. Initialize Models

Create instances of each model component.

In [None]:
# Model parameters
SEED = 42
torch.manual_seed(SEED)  # reproducible weight initialisation under Restart & Run All

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dimensions shared between components. Each encoder's out_dim is passed as the matching
# FusionAttention input dim, and FusionAttention's hidden_dim is passed as EvidentialHead's
# in_dim, so naming each value once keeps the coupled constructors consistent.
DRUG_IN_FEATURES = 78     # per-node feature width given to DrugGNN
PROTEIN_EMBED_DIM = 1280  # embedding width given to ProteinPLM
TDA_INPUT_DIM = 100       # input width given to TDAEncoder
ENCODER_DIM = 256         # output width of every encoder / fusion input width
FUSION_DIM = 512          # FusionAttention hidden_dim == EvidentialHead in_dim
HIDDEN_DIM = 256          # hidden width inside DrugGNN and EvidentialHead

# Fail fast if the example graph from the preprocessing cell does not match DrugGNN's input
# width, rather than erroring later inside a forward pass.
# NOTE(review): assumes node_features is laid out as (num_nodes, num_features).
assert drug_graph['node_features'].shape[-1] == DRUG_IN_FEATURES, (
    f"DrugPreprocessor yields {drug_graph['node_features'].shape[-1]} node features, "
    f"but DrugGNN is configured for {DRUG_IN_FEATURES}"
)

# Initialize encoders
drug_gnn = DrugGNN(
    in_features=DRUG_IN_FEATURES,
    hidden_dim=HIDDEN_DIM,
    out_dim=ENCODER_DIM,
    num_layers=3
).to(device)

protein_plm = ProteinPLM(
    embedding_dim=PROTEIN_EMBED_DIM,
    out_dim=ENCODER_DIM
).to(device)

tda_encoder = TDAEncoder(
    input_dim=TDA_INPUT_DIM,
    out_dim=ENCODER_DIM
).to(device)

# Fusion and prediction
fusion = FusionAttention(
    drug_dim=ENCODER_DIM,
    protein_dim=ENCODER_DIM,
    tda_dim=ENCODER_DIM,
    hidden_dim=FUSION_DIM
).to(device)

prediction_head = EvidentialHead(
    in_dim=FUSION_DIM,
    hidden_dim=HIDDEN_DIM
).to(device)

print("Models initialized successfully!")

## 4. Cold-Start Splits

Generate cold-start evaluation splits. The example below creates a cold-drug split (`split_type='cold_drug'`), which holds out a set of test drugs.

In [None]:
# Build a cold-drug split (test_ratio=0.2, val_ratio=0.1 as passed to ColdStartSplit)
# and summarise the size of each partition.
cold_drug_splitter = ColdStartSplit(
    dataset,
    split_type="cold_drug",
    test_ratio=0.2,
    val_ratio=0.1,
)
split_info = cold_drug_splitter.generate_split()

print("Cold Drug Split:")
for label, key in (("Train", "train"), ("Val", "val"), ("Test", "test")):
    print(f"  {label} samples: {len(split_info[key])}")
# 'test_drugs' may be absent from split_info; report 0 in that case.
print(f"  Test drugs: {len(split_info.get('test_drugs', []))}")

## 5. Visualization

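Visualize model predictions and their uncertainties. The cell below uses simulated data as a placeholder; replace it with the outputs of a trained model.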

In [None]:
# Placeholder for visualization
# In practice, you would load a trained model and use its predictions here instead of
# the simulated arrays below.
rng = np.random.default_rng(42)  # seeded so the demo figure is reproducible

# Simulated data for demonstration: true affinities ~ N(7, 2), predictions = truth + N(0, 0.5)
# noise, and uncertainties ~ U(0, 2).
y_true = rng.normal(loc=7, scale=2, size=100)
y_pred = y_true + rng.normal(loc=0, scale=0.5, size=100)
uncertainty = rng.uniform(0, 2, size=100)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Scatter plot of predictions vs true values. The identity line spans the full range of
# both axes so every point can be compared against it.
lims = [min(y_true.min(), y_pred.min()), max(y_true.max(), y_pred.max())]
axes[0].scatter(y_true, y_pred, alpha=0.6)
axes[0].plot(lims, lims, 'r--', label='Perfect prediction')
axes[0].set_xlabel('True Affinity')
axes[0].set_ylabel('Predicted Affinity')
axes[0].set_title('Predictions vs Ground Truth')
axes[0].legend()

# Uncertainty distribution
axes[1].hist(uncertainty, bins=20, alpha=0.7)
axes[1].set_xlabel('Uncertainty')
axes[1].set_ylabel('Frequency')
axes[1].set_title('Prediction Uncertainty Distribution')

plt.tight_layout()
plt.show()

## 6. Training

To train the model, use the provided training script:

```bash
python train.py --config config/default_config.yaml --dataset davis --split_type random
```

## 7. Evaluation

Evaluate trained models on cold-start scenarios:

```bash
python eval.py --checkpoint checkpoints/best_model.pt --dataset davis --evaluate_all_splits
```