# 03 - QSAR Model Training

**TB Drug Discovery ML Pipeline - Phase 1**

This notebook covers:
1. Data preparation and splitting
2. Random Forest QSAR model training
3. 5-fold cross-validation
4. Model evaluation and metrics
5. Visualizations: ROC curve, confusion matrix, feature importance, and prediction distributions

**Target Metrics:**
- Test-set ROC-AUC ≥ 0.75
- Cross-validation ROC-AUC standard deviation < 0.05

In [None]:
# Imports
import sys
from pathlib import Path

# Put the project's src/ directory first on the import path so the local
# data/models/evaluation packages below resolve (assumes the kernel's cwd is
# one level below the project root, e.g. notebooks/).
SRC_DIR = Path.cwd().parent / "src"
sys.path.insert(0, str(SRC_DIR))

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay

# Project-local modules (importable only after the sys.path tweak above)
from data.descriptor_calculator import DescriptorCalculator
from data.data_preprocessor import DataPreprocessor
from models.qsar_model import QSARModel
from evaluation.cross_validation import cross_validate_model
from evaluation.metrics import get_roc_curve

plt.style.use('seaborn-v0_8-whitegrid')

# Global seed for reproducibility; components below also receive RANDOM_SEED explicitly.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

print("Imports successful!")

## 1. Load Data

In [None]:
# Load descriptor data
data_path = Path.cwd().parent / "data" / "processed" / "descriptors.csv"

# Fail fast: every later cell needs `df`, so merely printing here would let the
# notebook continue and break later with an unrelated NameError.
if not data_path.exists():
    raise FileNotFoundError(
        f"{data_path} not found. Run 02_descriptor_calculation.ipynb first!"
    )

df = pd.read_csv(data_path)
print(f"Loaded {len(df)} compounds with descriptors")
print(f"Columns: {len(df.columns)}")

In [None]:
# Define features and target
calculator = DescriptorCalculator(lipinski=True, topological=True, extended=True)
# Materialise as a list: it is concatenated with a list below, and pandas would
# treat a tuple key as a single column label rather than a list of columns.
feature_cols = list(calculator.descriptor_names)

# For classification
target_col = 'active'

# Filter valid data.
# NOTE: despite the name, df_train is the full filtered dataset (it is split into
# train/test later). pIC50 is also required because it is plotted in section 6.
df_train = df.dropna(subset=feature_cols + [target_col, 'pIC50'])
print(f"Training samples after filtering: {len(df_train)}")

# Check class balance
print("\nClass distribution:")
print(df_train[target_col].value_counts())

## 2. Data Preparation

In [None]:
# Extract features and target
# Convert to plain arrays: X holds one row per compound in df_train with one
# column per entry of feature_cols; y is row-aligned with X.
X, y = df_train[feature_cols].values, df_train[target_col].values

print(f"Feature matrix shape: {X.shape}")
print(f"Target shape: {y.shape}")

In [None]:
# Split data
preprocessor = DataPreprocessor(random_seed=RANDOM_SEED)

# Hold out 20% of compounds for the test set; stratify=True is requested so both
# partitions keep a similar active/inactive ratio.
(X_train, X_test,
 y_train, y_test) = preprocessor.split_data_simple(X, y, test_size=0.2, stratify=True)

n_train, n_test = len(X_train), len(X_test)
print(f"Training set: {n_train} samples")
print(f"Test set: {n_test} samples")
# Assumes binary 0/1 labels, so the mean is the fraction of actives.
print(f"\nClass balance in train: {np.mean(y_train):.2%} active")
print(f"Class balance in test: {np.mean(y_test):.2%} active")

In [None]:
# Scale features
# The preprocessor is fitted on X_train only; X_test is only transformed with
# those fitted parameters.
X_train_scaled = preprocessor.fit_transform(X_train)
X_test_scaled = preprocessor.transform(X_test)

# Persist the fitted preprocessor so later notebooks can reuse the same scaling
models_dir = Path.cwd().parent / "models"
scaler_path = models_dir / "qsar_scaler.pkl"
models_dir.mkdir(exist_ok=True)
preprocessor.save(str(scaler_path))

print(f"Features scaled. Scaler saved to: {scaler_path}")

## 3. Train QSAR Model

In [None]:
# Initialize and train model
# Random Forest hyperparameters, forwarded as keyword arguments to QSARModel.
rf_hyperparams = {
    'n_estimators': 100,
    'max_depth': None,
    'min_samples_split': 2,
    'min_samples_leaf': 1,
}
model = QSARModel(task='classification', random_seed=RANDOM_SEED, **rf_hyperparams)

# Train on the scaled training split; descriptor names are kept for importances.
model.fit(X_train_scaled, y_train, feature_names=feature_cols)

print("\nTraining metrics:")
for metric_name, metric_value in model.training_metrics.items():
    print(f"  {metric_name}: {metric_value:.4f}")

## 4. Model Evaluation

In [None]:
# Evaluate on test set
test_metrics = model.evaluate(X_test_scaled, y_test)

# (printed label prefix, key in test_metrics)
test_report_rows = [
    ("  ROC-AUC:    ", 'roc_auc'),
    ("  Accuracy:   ", 'accuracy'),
    ("  Precision:  ", 'precision'),
    ("  Recall:     ", 'recall'),
    ("  F1-Score:   ", 'f1'),
]
test_divider = "=" * 40

print("Test Set Metrics:")
print(test_divider)
for prefix, metric_key in test_report_rows:
    print(f"{prefix}{test_metrics[metric_key]:.4f}")
print(test_divider)

# Check target
target_met = test_metrics['roc_auc'] >= 0.75
if target_met:
    target_status = '✅ PASSED'
else:
    target_status = '❌ NOT MET'
print(f"\nTarget (ROC-AUC >= 0.75): {target_status}")

## 5. Cross-Validation

In [None]:
# 5-fold cross-validation
# Runs on the training split only, so the held-out test set stays untouched.
# model.model is presumably the wrapped scikit-learn estimator.
# NOTE(review): X_train_scaled was scaled with statistics from all of X_train, so
# each CV fold's validation part has (mildly) seen the scaler fit.
cv_results = cross_validate_model(
    model.model,
    X_train_scaled,
    y_train,
    n_folds=5,
    task='classification',
    random_seed=RANDOM_SEED,
    return_predictions=True
)

# (printed label prefix, metric stem in cv_results)
cv_report_rows = [
    ("  ROC-AUC:    ", 'roc_auc'),
    ("  Accuracy:   ", 'accuracy'),
    ("  Precision:  ", 'precision'),
    ("  Recall:     ", 'recall'),
    ("  F1-Score:   ", 'f1'),
]
cv_divider = "=" * 50

print("\n5-Fold Cross-Validation Results:")
print(cv_divider)
for prefix, stem in cv_report_rows:
    fold_mean = cv_results[stem + '_mean']
    fold_std = cv_results[stem + '_std']
    print(f"{prefix}{fold_mean:.4f} ± {fold_std:.4f}")
print(cv_divider)

# Check stability
stable = cv_results['roc_auc_std'] < 0.05
if stable:
    stability_status = '✅ STABLE'
else:
    stability_status = '⚠️ HIGH VARIANCE'
print(f"\nStability (std < 0.05): {stability_status}")

## 6. Visualizations

In [None]:
# ROC Curve
# Column 1 of predict_proba — presumably the 'active' (label 1) class.
y_proba = model.predict_proba(X_test_scaled)[:, 1]
fpr, tpr, thresholds, roc_auc = get_roc_curve(y_test, y_proba)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
# Diagonal reference line: chance-level ranking
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random classifier')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=12)
ax.set_ylabel('True Positive Rate', fontsize=12)
ax.set_title('QSAR Model - ROC Curve', fontsize=14)
ax.legend(loc='lower right', fontsize=11)
ax.grid(True, alpha=0.3)

# Save figure (fig_path is reused by the plotting cells below)
fig_path = Path.cwd().parent / "results" / "figures"
fig_path.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_path / 'qsar_roc_curve.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Saved: {fig_path / 'qsar_roc_curve.png'}")

In [None]:
# Confusion Matrix
y_pred = model.predict(X_test_scaled)
# Pin the label order so the matrix is always 2x2 and lines up with display_labels,
# even if one class is missing from y_test or y_pred (assumes 0/1 labels).
cm = confusion_matrix(y_test, y_pred, labels=[0, 1])

# Draw onto an explicit Axes: ConfusionMatrixDisplay.plot() without ax= creates its
# own figure, which previously ignored figsize and left an empty figure behind.
fig, ax = plt.subplots(figsize=(8, 6))
disp = ConfusionMatrixDisplay(cm, display_labels=['Inactive', 'Active'])
disp.plot(ax=ax, cmap='Blues', values_format='d')
ax.set_title('QSAR Model - Confusion Matrix', fontsize=14)

fig.savefig(fig_path / 'qsar_confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Feature Importance
importance_df = model.get_feature_importance(top_n=15)
bar_positions = range(len(importance_df))

fig, ax = plt.subplots(figsize=(10, 8))
ax.barh(bar_positions, importance_df['importance'].values, color='steelblue')
ax.set_yticks(bar_positions)
ax.set_yticklabels(importance_df['feature'].values)
ax.set_xlabel('Feature Importance', fontsize=12)
ax.set_ylabel('Descriptor', fontsize=12)
ax.set_title('QSAR Model - Top 15 Important Features', fontsize=14)
# First row of importance_df is drawn at the top
ax.invert_yaxis()

fig.savefig(fig_path / 'qsar_feature_importance.png', dpi=150, bbox_inches='tight')
plt.show()

# Display table
print("\nTop 15 Features:")
importance_df

In [None]:
# Predicted vs Actual (for regression context, showing predicted probability)

# split_data_simple() returns bare arrays without row indices, so recover which
# df_train rows went to the test set by matching each X_test row back to X.
# (Positional slicing such as df_train['pIC50'].values[:len(y_test)] pairs the test
# predictions with the first N compounds of the table, not with the test compounds.)
row_positions = {}
for pos, row in enumerate(np.ascontiguousarray(X, dtype=float)):
    row_positions.setdefault(row.tobytes(), []).append(pos)

test_row_idx = []
for row in np.ascontiguousarray(X_test, dtype=float):
    candidates = row_positions.get(row.tobytes())
    if not candidates:
        raise ValueError(
            "Could not map a test row back to df_train; split row indices alongside X "
            "to recover test-set pIC50 values."
        )
    # Compounds with identical descriptor vectors are matched in order of appearance.
    test_row_idx.append(candidates.pop(0))
pic50_test = df_train['pIC50'].values[test_row_idx]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Probability distribution by class
axes[0].hist(y_proba[y_test == 0], bins=30, alpha=0.7, label='Inactive', color='blue')
axes[0].hist(y_proba[y_test == 1], bins=30, alpha=0.7, label='Active', color='red')
axes[0].axvline(x=0.5, color='black', linestyle='--', label='Decision threshold')
axes[0].set_xlabel('Predicted Probability (Active)', fontsize=12)
axes[0].set_ylabel('Count', fontsize=12)
axes[0].set_title('Prediction Distribution by Class', fontsize=14)
axes[0].legend()

# pIC50 vs predicted probability (both now refer to the same test compounds)
axes[1].scatter(pic50_test, y_proba, alpha=0.5, c=y_test, cmap='coolwarm')
axes[1].axhline(y=0.5, color='black', linestyle='--', alpha=0.5)
# NOTE(review): assumes 'active' was derived as pIC50 >= 6.0 upstream — confirm.
axes[1].axvline(x=6.0, color='green', linestyle='--', alpha=0.5, label='Activity threshold')
axes[1].set_xlabel('pIC50', fontsize=12)
axes[1].set_ylabel('Predicted Probability (Active)', fontsize=12)
axes[1].set_title('pIC50 vs Predicted Probability', fontsize=14)
axes[1].legend()

plt.tight_layout()
plt.savefig(fig_path / 'qsar_predictions_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Save Model and Results

In [None]:
# Save model
model_path = Path.cwd().parent / "models" / "qsar_rf_model.pkl"
model.save(str(model_path))
print(f"Model saved: {model_path}")

# Save metrics
import json


def _json_default(obj):
    """Fallback encoder for json.dump: convert NumPy scalars/arrays to builtins.

    np.float64 already subclasses float, but np.int64, np.float32 and np.ndarray
    do not and would otherwise make json.dump raise TypeError.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


cv_metric_names = ("roc_auc", "accuracy", "precision", "recall", "f1")

all_metrics = {
    "task": "classification",
    "n_samples_train": len(X_train),
    "n_samples_test": len(X_test),
    "n_features": len(feature_cols),
    "test_metrics": test_metrics,
    # Every CV metric reported in section 5 (mean and std), not only ROC-AUC/accuracy.
    "cv_results": {
        f"{name}_{stat}": cv_results[f"{name}_{stat}"]
        for name in cv_metric_names
        for stat in ("mean", "std")
    },
    "config": {
        # Read from the trained estimator so it cannot drift from section 3;
        # None if the wrapped estimator exposes no n_estimators attribute.
        "n_estimators": getattr(model.model, "n_estimators", None),
        "n_folds": 5,  # must match n_folds passed to cross_validate_model in section 5
        "random_seed": RANDOM_SEED,
    }
}

metrics_path = Path.cwd().parent / "results" / "metrics" / "qsar_metrics.json"
metrics_path.parent.mkdir(parents=True, exist_ok=True)
with open(metrics_path, 'w', encoding='utf-8') as f:
    json.dump(all_metrics, f, indent=2, default=_json_default)
print(f"Metrics saved: {metrics_path}")

# Save feature importance for every descriptor (importance_df above keeps only the
# top 15 for plotting). Assumes get_feature_importance(top_n=k) returns the top k rows.
all_importance_df = model.get_feature_importance(top_n=len(feature_cols))
importance_path = Path.cwd().parent / "models" / "feature_importance.csv"
all_importance_df.to_csv(importance_path, index=False)
print(f"Feature importance saved: {importance_path}")

## Summary

### Final Results:

| Metric | Value | Target | Status |
|--------|-------|--------|--------|
| Test ROC-AUC | See above | ≥ 0.75 | Check above |
| CV ROC-AUC | See above | - | - |
| CV ROC-AUC Std | See above | < 0.05 | Check above |

### Top Features:
Check the feature importance plot and table above.

### Artifacts Saved:
- `models/qsar_rf_model.pkl` - Trained model
- `models/qsar_scaler.pkl` - Feature scaler
- `models/feature_importance.csv` - Feature importances
- `results/metrics/qsar_metrics.json` - All metrics
- `results/figures/qsar_*.png` - Visualizations

### Next Steps:
1. If test ROC-AUC < 0.75: consider hyperparameter tuning or additional descriptors
2. If target met: Proceed to Phase 2 (Molecular Docking)
3. Prepare for first GitHub commit