In [2]:
pkgs = [
    "torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121"
]

for p in pkgs:
    print(f"\nInstalling {p} ...")
    !{sys.executable} -m pip install {p}


Installing torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 ...


'{sys.executable}' is not recognized as an internal or external command,
operable program or batch file.


In [None]:
import numpy as np
import pandas as pd
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight

from tqdm import tqdm
import re

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")

if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA version: 12.1
GPU device: NVIDIA GeForce RTX 2070 SUPER
GPU memory: 8.00 GB
Using device: cuda


In [7]:
import pandas as pd
import re

def extract_direction(description):
    """
    Determine if the interaction increases or decreases the effect
    """
    desc_lower = description.lower()
    
    # Check for increase variations
    increase_patterns = [
        r'\bincreas(e|ed|es|ing)\b',
        r'\b(enhance|enhanced|enhances|enhancing)\b',
        r'\b(elevat(e|ed|es|ing)|rais(e|ed|es|ing))\b',
        r'\bpotentiat(e|ed|es|ing)\b',
        r'\bamplif(y|ied|ies|ying)\b'
    ]
    
    # Check for decrease variations
    decrease_patterns = [
        r'\bdecreas(e|ed|es|ing)\b',
        r'\b(reduc(e|ed|es|ing)|diminish(ed|es|ing)?)\b',
        r'\b(lower(ed|s|ing)?)\b',
        r'\battenuat(e|ed|es|ing)\b',
        r'\bweaken(ed|s|ing)?\b',
        r'\blessen(ed|s|ing)?\b'
    ]
    
    has_increase = any(re.search(pattern, desc_lower) for pattern in increase_patterns)
    has_decrease = any(re.search(pattern, desc_lower) for pattern in decrease_patterns)
    
    # Handle cases with both (should be rare)
    if has_increase and has_decrease:
        # Count occurrences to determine dominant direction
        increase_count = sum(len(re.findall(pattern, desc_lower)) for pattern in increase_patterns)
        decrease_count = sum(len(re.findall(pattern, desc_lower)) for pattern in decrease_patterns)
        
        if increase_count > decrease_count:
            return 'increased'
        elif decrease_count > increase_count:
            return 'decreased'
        else:
            return 'mixed'
    elif has_increase:
        return 'increased'
    elif has_decrease:
        return 'decreased'
    else:
        return 'neutral'
    

def categorize_template(template):
    """
    Categorize interaction templates into high-level classes
    """
    template_lower = template.lower()
    
    # 1. ADVERSE EFFECTS / TOXICITY (General)
    if 'adverse effects' in template_lower and 'risk or severity' in template_lower:
        return 'adverse_effects_general'
    
    # 2. SPECIFIC TOXICITIES
    if 'cardiotoxic' in template_lower:
        return 'cardiotoxicity'
    elif 'nephrotoxic' in template_lower:
        return 'nephrotoxicity'
    elif 'hepatotoxic' in template_lower:
        return 'hepatotoxicity'
    elif 'neurotoxic' in template_lower:
        return 'neurotoxicity'
    elif 'myelosuppressive' in template_lower:
        return 'myelosuppression'
    elif 'ototoxic' in template_lower:
        return 'ototoxicity'
    
    # 3. PHARMACOKINETIC INTERACTIONS
    if 'metabolism' in template_lower:
        return 'metabolism_change'
    elif 'serum concentration' in template_lower:
        if 'active metabolites' in template_lower:
            return 'metabolite_concentration_change'
        else:
            return 'serum_concentration_change'
    elif 'bioavailability' in template_lower:
        return 'bioavailability_change'
    elif 'absorption' in template_lower:
        return 'absorption_change'
    elif 'excretion rate' in template_lower:
        return 'excretion_change'
    elif 'protein binding' in template_lower:
        return 'protein_binding_change'
    
    # 4. CARDIOVASCULAR EFFECTS
    if 'qtc' in template_lower or 'qt' in template_lower:
        return 'qtc_prolongation'
    elif 'bradycardic' in template_lower:
        return 'bradycardia'
    elif 'tachycardic' in template_lower:
        return 'tachycardia'
    elif 'arrhythmogenic' in template_lower:
        return 'arrhythmia'
    elif 'hypotensive' in template_lower and 'orthostatic' not in template_lower:
        return 'hypotension'
    elif 'hypertensive' in template_lower:
        return 'hypertension'
    elif 'orthostatic hypotensive' in template_lower:
        return 'orthostatic_hypotension'
    elif 'antihypertensive' in template_lower:
        return 'antihypertensive_effect'
    elif 'vasoconstricting' in template_lower or 'vasopressor' in template_lower:
        return 'vasoconstriction'
    elif 'vasodilatory' in template_lower:
        return 'vasodilation'
    elif 'heart failure' in template_lower:
        return 'heart_failure'
    elif 'av block' in template_lower or 'atrioventricular' in template_lower:
        return 'av_block'
    
    # 5. BLOOD/COAGULATION EFFECTS
    if 'anticoagulant' in template_lower:
        return 'anticoagulation'
    elif 'antiplatelet' in template_lower:
        return 'antiplatelet_effect'
    elif 'bleeding' in template_lower:
        return 'bleeding_risk'
    elif 'thrombogenic' in template_lower:
        return 'thrombosis'
    
    # 6. ELECTROLYTE EFFECTS
    if 'hypokalemic' in template_lower:
        return 'hypokalemia'
    elif 'hyperkalemic' in template_lower or 'hyperkalemia' in template_lower:
        return 'hyperkalemia'
    elif 'hypocalcemic' in template_lower:
        return 'hypocalcemia'
    elif 'hypercalcemic' in template_lower:
        return 'hypercalcemia'
    elif 'hyponatremic' in template_lower:
        return 'hyponatremia'
    
    # 7. CENTRAL NERVOUS SYSTEM
    if 'cns depressant' in template_lower:
        if 'hypertensive' in template_lower:
            return 'cns_depression_and_hypertension'
        elif 'hypotensive' in template_lower:
            return 'cns_depression_and_hypotension'
        else:
            return 'cns_depression'
    elif 'neuroexcitatory' in template_lower:
        return 'neuroexcitation'
    elif 'sedative' in template_lower:
        return 'sedation'
    elif 'central neurotoxic' in template_lower:
        return 'central_neurotoxicity'
    elif 'serotonergic' in template_lower:
        return 'serotonergic_effect'
    elif 'antipsychotic' in template_lower:
        return 'antipsychotic_effect'
    
    # 8. METABOLIC EFFECTS
    if 'hypoglycemic' in template_lower:
        return 'hypoglycemia'
    elif 'hyperglycemic' in template_lower:
        return 'hyperglycemia'
    
    # 9. RESPIRATORY EFFECTS
    if 'respiratory depressant' in template_lower:
        return 'respiratory_depression'
    elif 'bronchodilatory' in template_lower:
        return 'bronchodilation'
    elif 'bronchoconstrictory' in template_lower:
        return 'bronchoconstriction'
    
    # 10. NEUROMUSCULAR EFFECTS
    if 'neuromuscular blocking' in template_lower:
        return 'neuromuscular_blockade'
    elif 'adverse neuromuscular' in template_lower:
        return 'adverse_neuromuscular'
    elif 'myopathic rhabdomyolysis' in template_lower:
        return 'rhabdomyolysis'
    
    # 11. OTHER PHARMACOLOGICAL EFFECTS
    if 'therapeutic efficacy' in template_lower:
        return 'therapeutic_efficacy'
    elif 'analgesic' in template_lower:
        return 'analgesic_effect'
    elif 'anticholinergic' in template_lower:
        return 'anticholinergic_effect'
    elif 'immunosuppressive' in template_lower:
        return 'immunosuppression'
    elif 'diuretic' in template_lower:
        return 'diuretic_effect'
    elif 'stimulatory' in template_lower:
        return 'stimulation'
    
    # 12. GASTROINTESTINAL
    if 'ulcerogenic' in template_lower:
        return 'ulcerogenic_effect'
    elif 'constipating' in template_lower:
        return 'constipation'
    
    # 13. FLUID/RENAL
    if 'fluid retaining' in template_lower:
        return 'fluid_retention'
    
    # 14. DERMATOLOGIC
    if 'dermatologic' in template_lower:
        return 'dermatologic_adverse'
    
    # 15. HYPERSENSITIVITY
    if 'hypersensitivity' in template_lower:
        return 'hypersensitivity'
    
    # 16. DIAGNOSTIC
    if 'diagnostic agent' in template_lower:
        return 'diagnostic_interference'
    
    # DEFAULT
    return 'other_interaction'

# First create templates (from previous code)
def extract_interaction_template(description, drug1, drug2):
    template = description
    template = re.sub(re.escape(drug1), 'DRUG_A', template, flags=re.IGNORECASE)
    template = re.sub(re.escape(drug2), 'DRUG_B', template, flags=re.IGNORECASE)
    template = re.sub(r'\bincreas(e|ed|es|ing)\b', 'DIRECTION', template, flags=re.IGNORECASE)
    template = re.sub(r'\bdecreas(e|ed|es|ing)\b', 'DIRECTION', template, flags=re.IGNORECASE)
    template = re.sub(r'\b(enhance|enhanced|enhances|enhancing)\b', 'DIRECTION', template, flags=re.IGNORECASE)
    template = re.sub(r'\b(reduc(e|ed|es|ing)|diminish(ed|es|ing)?|lower(ed|s|ing)?)\b', 'DIRECTION', template, flags=re.IGNORECASE)
    template = re.sub(r'\b(elevat(e|ed|es|ing)|rais(e|ed|es|ing))\b', 'DIRECTION', template, flags=re.IGNORECASE)
    return template


df = pd.read_csv('drug_interactions_cleaned.csv')
print(f"Loaded {len(df):,} interactions")

# Create templates
df['Template'] = df.apply(
    lambda row: extract_interaction_template(
        row['Interaction Description'], 
        row['Drug 1'], 
        row['Drug 2']
    ), 
    axis=1
)
print(f"Created {df['Template'].nunique()} unique templates")

# Extract direction from ORIGINAL description (before replacing with DIRECTION)
df['Direction'] = df['Interaction Description'].apply(extract_direction)
print(f"\nDirection distribution:")
print(df['Direction'].value_counts())

# Categorize the templates
df['Interaction_Category'] = df['Template'].apply(categorize_template)
print(f"\nCreated {df['Interaction_Category'].nunique()} base categories")

# Combine category + direction
df['Category_With_Direction'] = df['Interaction_Category'] + '_' + df['Direction']
print(f"Combined into {df['Category_With_Direction'].nunique()} final categories")


le = LabelEncoder()
df['Interaction_Label'] = le.fit_transform(df['Category_With_Direction'])

num_classes = len(le.classes_)
print(f"\n{'='*80}")
print(f"✓ Number of classes for training: {num_classes}")
print("="*80)

# Show distribution
print("\nFINAL CLASS LABELS (with direction):")
print("="*80)
for i, label in enumerate(le.classes_):
    count = (df['Category_With_Direction'] == label).sum()
    # Parse out category and direction for display
    if '_increased' in label:
        category = label.replace('_increased', '')
        direction = '↑'
    elif '_decreased' in label:
        category = label.replace('_decreased', '')
        direction = '↓'
    elif '_neutral' in label:
        category = label.replace('_neutral', '')
        direction = '='
    elif '_mixed' in label:
        category = label.replace('_mixed', '')
        direction = '±'
    else:
        category = label
        direction = '?'
    
    print(f"  {i:3d}. [{direction}] {category:.<45} {count:>6,} samples")


df.to_csv('drug_interactions_categorized.csv', index=False)
print(f"\n✓ Saved to 'drug_interactions_categorized.csv'")

with open('label_encoder.pkl', 'wb') as f:
    pickle.dump(le, f)
print("✓ Saved label encoder")

# Save human-readable mapping
with open('label_mapping.txt', 'w') as f:
    f.write("LABEL MAPPING (Category + Direction)\n")
    f.write("="*80 + "\n\n")
    for i, label in enumerate(le.classes_):
        count = (df['Category_With_Direction'] == label).sum()
        f.write(f"{i:3d}. {label:<60} {count:>6,} samples\n")
print("✓ Saved label mapping to 'label_mapping.txt'")

print("\n" + "="*80)
print("Next step: Run your fingerprint generation script!")
print("="*80)

Loaded 189,932 interactions
Created 104 unique templates

Direction distribution:
Direction
increased    128502
decreased     61399
mixed            31
Name: count, dtype: int64

Created 60 base categories
Combined into 75 final categories

✓ Number of classes for training: 75

FINAL CLASS LABELS (with direction):
    0. [↓] absorption_change............................     45 samples
    1. [↑] adverse_effects_general...................... 60,010 samples
    2. [↑] adverse_neuromuscular........................     94 samples
    3. [↓] analgesic_effect.............................     10 samples
    4. [↑] analgesic_effect.............................    270 samples
    5. [↑] anticholinergic_effect.......................    323 samples
    6. [↓] anticoagulation..............................    225 samples
    7. [↑] anticoagulation..............................  3,055 samples
    8. [↓] antiplatelet_effect..........................      7 samples
    9. [↑] antiplatelet_effect......

In [8]:
def smiles_to_fingerprint(smiles, radius=2, nBits=2048):
    """Convert SMILES to Morgan fingerprint"""
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=nBits)
        fp = gen.GetFingerprint(mol)
        return np.array(fp)
    except:
        return None

df = pd.read_csv('drug_interactions_categorized.csv')

unique_smiles = pd.concat([df['Drug1_SMILES'], df['Drug2_SMILES']]).unique()
print(f"\nUnique SMILES strings: {len(unique_smiles)}")

# unique SMILES to fingerprints
smiles_to_fp_dict = {}
for smiles in tqdm(unique_smiles):
    fp = smiles_to_fingerprint(smiles)
    smiles_to_fp_dict[smiles] = fp

# fingerprints back to dataframe
df['Drug1_FP'] = df['Drug1_SMILES'].map(smiles_to_fp_dict)
df['Drug2_FP'] = df['Drug2_SMILES'].map(smiles_to_fp_dict)

before_removal = len(df)
df = df.dropna(subset=['Drug1_FP', 'Drug2_FP'])
print(f"Removed: {before_removal - len(df)} rows with failed fingerprints")

df.to_pickle('drug_interactions_with_fingerprints.pkl')
print("\n✓ Saved: drug_interactions_with_fingerprints.pkl")

print(f"Sample fingerprint shape: {df.iloc[0]['Drug1_FP'].shape}")


Unique SMILES strings: 1656


100%|██████████| 1656/1656 [00:01<00:00, 966.55it/s]


Removed: 0 rows with failed fingerprints

✓ Saved: drug_interactions_with_fingerprints.pkl
Sample fingerprint shape: (2048,)


In [9]:
df = pd.read_pickle('drug_interactions_with_fingerprints.pkl')
print(f"Dataset size: {len(df)} interactions")

# Load label encoder
with open('label_encoder.pkl', 'rb') as f:
    le = pickle.load(f)

num_classes = len(le.classes_)
print(f"Number of interaction types: {num_classes}")

Dataset size: 189932 interactions
Number of interaction types: 75


In [10]:
X1 = np.stack(df['Drug1_FP'].values).astype(np.float32)  # Shape: (n_samples, 2048)
X2 = np.stack(df['Drug2_FP'].values).astype(np.float32)  # Shape: (n_samples, 2048)
y = df['Interaction_Label'].values.astype(np.int64)

print(f"X1 shape: {X1.shape}")
print(f"X2 shape: {X2.shape}")
print(f"y shape: {y.shape}")

# Split data
X1_train, X1_test, X2_train, X2_test, y_train, y_test = train_test_split(
    X1, X2, y, test_size=0.2, random_state=42
)

print(f"\nTraining set: {len(X1_train)} samples")
print(f"Test set: {len(X1_test)} samples")

X1 shape: (189932, 2048)
X2 shape: (189932, 2048)
y shape: (189932,)

Training set: 151945 samples
Test set: 37987 samples


In [18]:
from torch.utils.data import WeightedRandomSampler

class DDIDataset(Dataset):
    """Custom Dataset for Drug-Drug Interactions"""
    def __init__(self, drug1_fps, drug2_fps, labels):
        self.drug1_fps = torch.FloatTensor(drug1_fps)
        self.drug2_fps = torch.FloatTensor(drug2_fps)
        self.labels = torch.LongTensor(labels)
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return self.drug1_fps[idx], self.drug2_fps[idx], self.labels[idx]

# Create datasets
train_dataset = DDIDataset(X1_train, X2_train, y_train)
test_dataset = DDIDataset(X1_test, X2_test, y_test)

# -------------------------------
# ADD WEIGHTED SAMPLER HERE
# -------------------------------
class_counts = torch.bincount(torch.tensor(y_train))          # count samples per class
class_weights = 1.0 / (class_counts + 1e-8)                   # avoid div-by-zero
sample_weights = class_weights[y_train]                       # weight each training sample

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

batch_size = 64

# Replace shuffle=True with the sampler
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=sampler
)

# Test loader stays the same
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False
)

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of test batches: {len(test_loader)}")


Number of training batches: 2375
Number of test batches: 594


In [19]:
class DDIPredictor(nn.Module):
    """CNN model for Drug-Drug Interaction prediction"""
    
    def __init__(self, fp_size=2048, num_classes=100):
        super(DDIPredictor, self).__init__()
        
        # Drug encoder (shared for both drugs)
        self.drug_encoder = nn.Sequential(
            # First conv block
            nn.Conv1d(1, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.MaxPool1d(2),
            
            # Second conv block
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.MaxPool1d(2),
            
            # Third conv block
            nn.Conv1d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.AdaptiveMaxPool1d(1)  # Global max pooling
        )
        
        # Fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * 2, 1024),  # *2 because we concatenate two drugs
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, drug1, drug2):
        # Reshape for Conv1d: (batch, channels, length)
        drug1 = drug1.unsqueeze(1)  # (batch, 1, 2048)
        drug2 = drug2.unsqueeze(1)  # (batch, 1, 2048)
        
        # Encode both drugs
        drug1_features = self.drug_encoder(drug1).squeeze(-1)  # (batch, 512)
        drug2_features = self.drug_encoder(drug2).squeeze(-1)  # (batch, 512)
        
        # Concatenate features
        combined = torch.cat([drug1_features, drug2_features], dim=1)  # (batch, 1024)
        
        # Pass through FC layers
        output = self.fc_layers(combined)
        
        return output

# Create model
model = DDIPredictor(fp_size=2048, num_classes=num_classes)
model = model.to(device)

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=None, reduction="mean"):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits, targets):
        # logits: (N, C)
        # targets: (N)

        ce_loss = F.cross_entropy(
            logits,
            targets,
            reduction='none',
            weight=self.alpha
        )

        pt = torch.exp(-ce_loss)  # probability of the true class

        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.reduction == "mean":
            return focal_loss.mean()
        elif self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss


In [21]:
# Loss function and optimizer
criterion = FocalLoss(gamma=2.0).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

print("✓ Training setup complete")
print(f"Optimizer: {optimizer.__class__.__name__}")
print(f"Loss function: {criterion.__class__.__name__}")
print(f"Learning rate: {optimizer.param_groups[0]['lr']}")

✓ Training setup complete
Optimizer: Adam
Loss function: FocalLoss
Learning rate: 0.001


In [22]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Train for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(loader, desc='Training')
    for drug1, drug2, labels in pbar:
        # Move to device
        drug1, drug2, labels = drug1.to(device), drug2.to(device), labels.to(device)
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(drug1, drug2)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item() * drug1.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({'loss': loss.item(), 'acc': correct/total})
    
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    
    return epoch_loss, epoch_acc

def evaluate(model, loader, criterion, device):
    """Evaluate the model"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for drug1, drug2, labels in loader:
            drug1, drug2, labels = drug1.to(device), drug2.to(device), labels.to(device)
            
            outputs = model(drug1, drug2)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * drug1.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    
    return epoch_loss, epoch_acc

print("✓ Training functions defined")

✓ Training functions defined


In [24]:
# Training parameters
num_epochs = 100
best_val_acc = 0.0
patience = 5
patience_counter = 0

# History
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'learning_rate': []
}

print("Starting training...\n")

for epoch in range(num_epochs):
    # Get current learning rate
    current_lr = optimizer.param_groups[0]['lr']
    
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate (use test set as validation)
    val_loss, val_acc = evaluate(model, test_loader, criterion, device)
    
    # Update scheduler
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]['lr']
    
    # Check if LR was reduced
    if new_lr < old_lr:
        print(f"  Learning rate reduced: {old_lr:.6f} -> {new_lr:.6f}")
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['learning_rate'].append(current_lr)
    
    # Print progress
    print(f"Epoch [{epoch+1}/{num_epochs}] LR: {current_lr:.6f}")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}")
    
    # Early stopping
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_ddi_model.pth')
        print(f"  ✓ New best model saved! (Val Acc: {val_acc:.4f})")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break
    
    print()

print("\n✓ Training complete!")
print(f"Best validation accuracy: {best_val_acc:.4f}")

# Load best model
model.load_state_dict(torch.load('best_ddi_model.pth'))

from sklearn.metrics import classification_report

y_true = []
y_pred = []

model.eval()
with torch.no_grad():
    for drug1, drug2, labels in test_loader:
        drug1, drug2 = drug1.to(device), drug2.to(device)
        labels = labels.to(device)

        outputs = model(drug1, drug2)
        _, predicted = outputs.max(1)

        y_true.extend(labels.cpu().numpy())
        y_pred.extend(predicted.cpu().numpy())

print("\nClassification Report:")
print(classification_report(
    y_true, 
    y_pred, 
    digits=4,
    zero_division=0
))

Starting training...



Training: 100%|██████████| 2375/2375 [02:46<00:00, 14.26it/s, loss=3.39, acc=0.0831]


Epoch [1/100] LR: 0.001000
  Train Loss: 3.4862 | Train Acc: 0.0831
  Val Loss:   3.6193 | Val Acc:   0.0115
  ✓ New best model saved! (Val Acc: 0.0115)



Training: 100%|██████████| 2375/2375 [02:47<00:00, 14.18it/s, loss=2.57, acc=0.106]


Epoch [2/100] LR: 0.001000
  Train Loss: 3.3058 | Train Acc: 0.1063
  Val Loss:   3.4120 | Val Acc:   0.0091



Training: 100%|██████████| 2375/2375 [02:48<00:00, 14.11it/s, loss=3.48, acc=0.122]


Epoch [3/100] LR: 0.001000
  Train Loss: 3.1967 | Train Acc: 0.1220
  Val Loss:   3.4407 | Val Acc:   0.0086



Training: 100%|██████████| 2375/2375 [02:48<00:00, 14.08it/s, loss=2.66, acc=0.135]


Epoch [4/100] LR: 0.001000
  Train Loss: 3.1231 | Train Acc: 0.1353
  Val Loss:   3.4710 | Val Acc:   0.0071



Training: 100%|██████████| 2375/2375 [02:46<00:00, 14.23it/s, loss=2.55, acc=0.144]


Epoch [5/100] LR: 0.001000
  Train Loss: 3.0678 | Train Acc: 0.1436
  Val Loss:   3.6942 | Val Acc:   0.0049



Training: 100%|██████████| 2375/2375 [02:45<00:00, 14.39it/s, loss=3.5, acc=0.156] 


  Learning rate reduced: 0.001000 -> 0.000500
Epoch [6/100] LR: 0.001000
  Train Loss: 2.9902 | Train Acc: 0.1557
  Val Loss:   3.8681 | Val Acc:   0.0063

Early stopping triggered after 6 epochs

✓ Training complete!
Best validation accuracy: 0.0115


  model.load_state_dict(torch.load('best_ddi_model.pth'))



Classification Report:
              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         8
           1     0.0000    0.0000    0.0000     11972
           2     0.0000    0.0000    0.0000        15
           3     0.0000    0.0000    0.0000         4
           4     0.0000    0.0000    0.0000        54
           5     0.0000    0.0000    0.0000        68
           6     0.0121    0.6538    0.0238        52
           7     0.0333    0.0905    0.0486       630
           8     0.0000    0.0000    0.0000         1
           9     0.0031    0.1127    0.0060        71
          10     0.0000    0.0000    0.0000        19
          11     0.5000    0.0143    0.0278        70
          12     0.0000    0.0000    0.0000       132
          13     0.0271    0.4255    0.0510        94
          14     0.0000    0.0000    0.0000         2
          15     0.0000    0.0000    0.0000        27
          16     0.0000    0.0000    0.0000       238
   

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Accuracy plot
axes[0].plot(history['train_acc'], label='Train Accuracy')
axes[0].plot(history['val_acc'], label='Val Accuracy')
axes[0].set_title('Model Accuracy')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Accuracy')
axes[0].legend()
axes[0].grid(True)

# Loss plot
axes[1].plot(history['train_loss'], label='Train Loss')
axes[1].plot(history['val_loss'], label='Val Loss')
axes[1].set_title('Model Loss')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.show()

best_epoch = np.argmax(history['val_acc'])
print(f"Best epoch: {best_epoch + 1}")
print(f"Best validation accuracy: {history['val_acc'][best_epoch]:.4f}")

In [None]:
# Evaluate on test set with best model
test_loss, test_acc = evaluate(model, test_loader, criterion, device)

print(f"{'='*60}")
print(f"Final Test Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
print(f"Final Test Loss: {test_loss:.4f}")
print(f"{'='*60}")

In [None]:
def smiles_to_fingerprint(smiles, radius=2, nBits=2048):
    """Convert SMILES to Morgan fingerprint"""
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=nBits)
        fp = gen.GetFingerprint(mol)
        return np.array(fp, dtype=np.float32)
    except:
        return None

def predict_interaction(drug1_smiles, drug2_smiles, top_k=3):
    """Predict interaction between two drugs"""
    model.eval()
    
    # Convert to fingerprints
    fp1 = smiles_to_fingerprint(drug1_smiles)
    fp2 = smiles_to_fingerprint(drug2_smiles)
    
    if fp1 is None or fp2 is None:
        return [("Invalid SMILES", 0.0)]
    
    # Convert to tensors
    fp1 = torch.FloatTensor(fp1).unsqueeze(0).to(device)
    fp2 = torch.FloatTensor(fp2).unsqueeze(0).to(device)
    
    # Predict
    with torch.no_grad():
        outputs = model(fp1, fp2)
        probabilities = torch.softmax(outputs, dim=1)[0]
    
    # Get top k predictions
    top_probs, top_indices = torch.topk(probabilities, top_k)
    
    results = []
    for prob, idx in zip(top_probs, top_indices):
        interaction_type = le.inverse_transform([idx.item()])[0]
        results.append((interaction_type, prob.item()))
    
    return results

print("✓ Prediction function defined")

In [None]:
# Test predictions
print("Testing predictions on random samples:\n")
print("="*80)

test_samples = df.sample(5, random_state=42)

for idx, row in test_samples.iterrows():
    print(f"\n{'='*80}")
    print(f"Drug 1: {row['Drug 1']}")
    print(f"Drug 2: {row['Drug 2']}")
    print(f"\nActual Interaction:")
    print(f"  {row['Interaction Description']}")
    
    predictions = predict_interaction(row['Drug1_SMILES'], row['Drug2_SMILES'], top_k=3)
    
    print(f"\nTop 3 Predictions:")
    for i, (interaction, conf) in enumerate(predictions, 1):
        match = "✓" if interaction == row['Interaction Description'] else "✗"
        print(f"  {i}. [{match}] {interaction}")
        print(f"      Confidence: {conf:.2%}")

print(f"\n{'='*80}")