# Aptamer Model Training

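This notebook trains machine learning models for aptamer binding affinity prediction and cross-reactivity analysis. It expects the feature-enriched dataset produced by `02_feature_engineering.ipynb` (`../data/processed/aptamers_with_features.csv`) and saves the trained models to `../models`.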

In [None]:
import sys
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score, roc_curve, auc, confusion_matrix

# Render figures inline in the notebook output (the default in modern Jupyter; kept explicit).
%matplotlib inline
# Global plot styling for every figure below: seaborn's white-grid theme and a
# default figure size of (width, height) = (10, 6) inches.
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (10, 6)

In [None]:
# Make the repository root (one directory above the notebook's working directory)
# importable so the `src.*` imports succeed. Appending only when absent keeps
# re-runs of this cell from growing sys.path.
parent_dir = os.path.join(os.getcwd(), os.pardir)
project_root = os.path.abspath(parent_dir)
already_on_path = project_root in sys.path
if not already_on_path:
    sys.path.append(project_root)

In [None]:
from src.data_processing.data_loader import AptamerDataLoader
from src.models.binding_affinity import BindingAffinityPredictor
from src.models.cross_reactivity import CrossReactivityAnalyzer
from src.visualization.plot_utils import AptamerVisualizer

## Load Feature-Enriched Data

In [None]:
# Load the feature-enriched dataset produced by 02_feature_engineering.ipynb.
# Fail fast when it is missing: every later cell depends on `df`, and merely
# printing a message would only resurface as a confusing NameError further down.
feature_path = '../data/processed/aptamers_with_features.csv'

if not os.path.exists(feature_path):
    raise FileNotFoundError(
        f"Feature-enriched data not found at {feature_path}. "
        "Please run the 02_feature_engineering.ipynb notebook first"
    )

df = pd.read_csv(feature_path)
print(f"Loaded feature-enriched data: {len(df)} rows, {len(df.columns)} columns")

In [None]:
# Preview five rows as a rich table to sanity-check column names and value ranges
df.head(5)

## Split Data for Training

In [None]:
# Partition the data into train / validation / test splits. The fractions are
# passed straight to AptamerDataLoader.split_data; whether validation_size is
# relative to the full set or to the remainder is defined there (not visible here).
# A fixed random_state keeps the split reproducible across runs.
data_loader = AptamerDataLoader()
split_kwargs = dict(test_size=0.2, validation_size=0.1, random_state=42)
train_df, val_df, test_df = data_loader.split_data(df, **split_kwargs)

for split_name, split_frame in (("Training", train_df), ("Validation", val_df), ("Test", test_df)):
    print(f"{split_name} set: {len(split_frame)} samples")

## Binding Affinity Model Training

In [None]:
# Pick the binding-affinity target column: the first column whose name mentions
# 'binding', 'affinity' or 'kd' (case-insensitive). If none exists, fabricate a
# synthetic target so the rest of the notebook can run end to end.
binding_keywords = ('binding', 'affinity', 'kd')
binding_col = next(
    (col for col in df.columns if any(keyword in col.lower() for keyword in binding_keywords)),
    None,
)

if binding_col:
    print(f"Using '{binding_col}' as binding affinity target")
else:
    print("No binding affinity column found. Creating a synthetic target for demonstration.")

    # Synthetic target: 50% GC content (assumed to be a percentage, hence /100),
    # 30% closeness of sequence length to 30 nt, 20% uniform noise.
    # This should be replaced with real data in production
    np.random.seed(42)
    gc_col = 'GC_Content' if 'GC_Content' in df.columns else 'gc_content'
    df['binding_affinity'] = (
        0.5 * (df[gc_col] / 100) +
        0.3 * (1 - abs(df['length'] - 30) / 30) +
        0.2 * np.random.random(len(df))
    )
    binding_col = 'binding_affinity'

    # Propagate the new target to each split by index *label* (not position), and
    # build new frames with .assign instead of assigning into possible slices of
    # `df` (which raises SettingWithCopyWarning and may not stick).
    # NOTE(review): assumes split_data preserves df's original index labels — confirm.
    train_df = train_df.assign(binding_affinity=df.loc[train_df.index, 'binding_affinity'].to_numpy())
    val_df = val_df.assign(binding_affinity=df.loc[val_df.index, 'binding_affinity'].to_numpy())
    test_df = test_df.assign(binding_affinity=df.loc[test_df.index, 'binding_affinity'].to_numpy())

In [None]:
# Train one BindingAffinityPredictor per model family and keep both the fitted
# models and their reported metrics for the comparison table below.
model_types = ['random_forest', 'gradient_boosting', 'xgboost']

# Per-family hyperparameters; a family missing from this table gets only
# {'model_type': ...} and therefore the predictor's own defaults.
model_hyperparams = {
    'random_forest': {'n_estimators': 100, 'max_depth': 10},
    'gradient_boosting': {'n_estimators': 100, 'max_depth': 5, 'learning_rate': 0.1},
    'xgboost': {'n_estimators': 100, 'max_depth': 6, 'learning_rate': 0.05},
}

binding_models = {}
binding_metrics = {}

for model_type in model_types:
    print(f"\nTraining {model_type} model...")

    config = {'model_type': model_type, **model_hyperparams.get(model_type, {})}

    # Initialize and train model (validation metrics are reported when val_df is used)
    predictor = BindingAffinityPredictor(config)
    metrics = predictor.train(train_df, binding_col, validation_df=val_df)

    binding_models[model_type] = predictor
    binding_metrics[model_type] = metrics

    # Format each metric only when it is present: applying ':.6f' to an 'N/A'
    # fallback string raises ValueError, which would abort the whole loop.
    metric_labels = [('Training MSE', 'training_mse'), ('Training R²', 'training_r2')]
    if 'validation_mse' in metrics:
        metric_labels += [('Validation MSE', 'validation_mse'), ('Validation R²', 'validation_r2')]
    for label, key in metric_labels:
        value = metrics.get(key)
        print(f"{label}: {value:.6f}" if value is not None else f"{label}: N/A")

In [None]:
# Tabulate training/validation metrics side by side, one row per model family;
# metrics a model did not report show up as NaN.
comparison_rows = []
for name, scores in binding_metrics.items():
    comparison_rows.append({
        'Model': name,
        'Training MSE': scores.get('training_mse', float('nan')),
        'Training R²': scores.get('training_r2', float('nan')),
        'Validation MSE': scores.get('validation_mse', float('nan')),
        'Validation R²': scores.get('validation_r2', float('nan')),
    })
model_comparison = pd.DataFrame(comparison_rows)

model_comparison

In [None]:
# Select the model with the highest validation R² (NaN entries are skipped).
# idxmax returns an index *label*, so it is paired with .loc rather than .iloc.
validation_r2 = model_comparison['Validation R²']
if validation_r2.isna().all():
    # idxmax on an all-NaN column returns NaN or raises, depending on the pandas
    # version; fail with an actionable message instead.
    raise ValueError(
        "No model reported a validation R²; cannot select the best binding affinity model."
    )
best_model_type = model_comparison.loc[validation_r2.idxmax(), 'Model']
best_binding_model = binding_models[best_model_type]

print(f"Best binding affinity model: {best_model_type}")

In [None]:
# Score the selected model on the held-out test split, which played no part in
# training or model selection.
test_metrics = best_binding_model.evaluate_model(test_df, binding_col)

print("Test set evaluation:")
reported_metrics = [
    ("MSE", "mse"),
    ("RMSE", "rmse"),
    ("R²", "r2"),
    ("MAE", "mae"),
    ("Correlation", "correlation"),
]
for label, key in reported_metrics:
    print(f"{label}: {test_metrics[key]:.6f}")

In [None]:
# Scatter actual vs predicted test-set affinities, with a dashed y = x reference line.
test_predictions = best_binding_model.predict(test_df)
actual_values = np.asarray(test_df[binding_col], dtype=float).ravel()
predicted_values = np.asarray(test_predictions, dtype=float).ravel()

# Span the identity line over the data's own range: real affinity targets
# (e.g. Kd values) are not confined to [0, 1], so a hard-coded [0, 1] diagonal
# would be misleading. nanmin/nanmax ignore any missing values.
all_values = np.concatenate([actual_values, predicted_values])
lower, upper = np.nanmin(all_values), np.nanmax(all_values)

fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(actual_values, predicted_values, alpha=0.7)
ax.plot([lower, upper], [lower, upper], 'k--')
ax.set(
    xlabel='Actual Binding Affinity',
    ylabel='Predicted Binding Affinity',
    title=f'Binding Affinity Prediction ({best_model_type})',
)
ax.grid(alpha=0.3)
fig.tight_layout()
plt.show()

In [None]:
# Rank features by the selected model's importance scores and plot the top 20
# (or all of them, if there are fewer).
feature_importance = best_binding_model.get_feature_importance()

top_n = min(20, len(feature_importance))
top_features = feature_importance.sort_values('importance', ascending=False).head(top_n)

fig, ax = plt.subplots(figsize=(12, 8))
ax.barh(top_features['feature'], top_features['importance'])
# barh draws the first row at the bottom; flip the axis so the most important
# feature appears at the top of the chart.
ax.invert_yaxis()
ax.set(
    xlabel='Importance',
    ylabel='Feature',
    title=f'Top {top_n} Features for Binding Affinity Prediction',
)
ax.grid(alpha=0.3)
fig.tight_layout()
plt.show()

## Cross-Reactivity Model Training

In [None]:
# Train the cross-reactivity classifier. It needs a 'Target_Name' column; later
# cells re-check the same condition before using `cross_reactivity_model`.
if 'Target_Name' not in df.columns:
    print("No 'Target_Name' column found. Cross-reactivity analysis requires target information.")
else:
    print("Training cross-reactivity model...")

    cross_reactivity_model = CrossReactivityAnalyzer({
        'model_type': 'xgboost',
        'n_estimators': 100,
        'max_depth': 5,
        'learning_rate': 0.05
    })

    cr_metrics = cross_reactivity_model.train_cross_reactivity_model(train_df, validation_df=val_df)

    # Apply ':.4f' only to real values; formatting an 'N/A' fallback string with
    # a float format spec raises ValueError.
    training_accuracy = cr_metrics.get('training_accuracy')
    print(f"Training accuracy: {training_accuracy:.4f}"
          if training_accuracy is not None else "Training accuracy: N/A")

    if 'validation_accuracy' in cr_metrics:
        print(f"Validation accuracy: {cr_metrics['validation_accuracy']:.4f}")

    # Per-target ROC AUC scores, plus the 'average' entry reported separately
    if 'training_roc_auc' in cr_metrics:
        roc_auc_by_target = cr_metrics['training_roc_auc']
        print("\nROC AUC scores:")
        for target, auc_score in roc_auc_by_target.items():
            if target != 'average':
                print(f"{target}: {auc_score:.4f}")
        average_auc = roc_auc_by_target.get('average')
        print(f"Average: {average_auc:.4f}" if average_auc is not None else "Average: N/A")

In [None]:
# Evaluate the cross-reactivity model on the held-out test split: confusion
# matrix of actual vs predicted target, plus overall accuracy.
if 'Target_Name' in df.columns:
    test_cr_predictions = cross_reactivity_model.predict_cross_reactivity(test_df)

    # Compare by position, not index label: the prediction frame's index is not
    # guaranteed to match test_df's.
    actual_targets = np.asarray(test_df['Target_Name'])
    predicted_targets = np.asarray(test_cr_predictions['predicted_target'])

    # Pin the matrix's row/column order to target_names so the tick labels below
    # describe the cells they sit next to. Without `labels`, sklearn orders rows
    # by the sorted union of observed labels, which can differ from target_names.
    # NOTE(review): assumes target_names holds the same label values as Target_Name — confirm.
    target_names = cross_reactivity_model.target_names
    conf_matrix = confusion_matrix(actual_targets, predicted_targets, labels=list(target_names))

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=target_names, yticklabels=target_names, ax=ax)
    ax.set(title='Confusion Matrix for Target Prediction',
           xlabel='Predicted Target', ylabel='Actual Target')
    fig.tight_layout()
    plt.show()

    accuracy = (predicted_targets == actual_targets).mean()
    print(f"Test set accuracy: {accuracy:.4f}")

In [None]:
# Flag test-set aptamers the model considers cross-reactive (threshold=0.3 as
# passed to identify_cross_reactive_aptamers) and show a few examples.
if 'Target_Name' in df.columns:
    crossreact_df = cross_reactivity_model.identify_cross_reactive_aptamers(
        test_cr_predictions, threshold=0.3
    )

    cross_reactive_mask = crossreact_df['is_cross_reactive']
    cross_reactive_count = int(cross_reactive_mask.sum())
    total_count = len(crossreact_df)
    # Guard the percentage against an empty result frame (ZeroDivisionError)
    cross_reactive_share = cross_reactive_count / total_count if total_count else 0.0

    print(f"Found {cross_reactive_count} potentially cross-reactive aptamers out of {total_count} ({cross_reactive_share:.1%})")

    if cross_reactive_count > 0:
        cr_examples = crossreact_df[cross_reactive_mask].head(5)
        seq_col = 'Sequence' if 'Sequence' in cr_examples.columns else 'sequence'
        candidate_cols = [seq_col, 'Target_Name', 'predicted_target',
                          'cross_reactive_targets', 'specificity_score']
        # Show only columns that exist, so one missing optional column does not
        # turn the whole preview into a KeyError.
        display_cols = [col for col in candidate_cols if col in cr_examples.columns]

        print("\nExamples of cross-reactive aptamers:")
        display(cr_examples[display_cols])

In [None]:
# Draw the cross-reactivity matrix for the flagged test-set aptamers. output_path=None
# presumably means "show inline, don't save" — TODO confirm in AptamerVisualizer.
has_target_labels = 'Target_Name' in df.columns
if has_target_labels:
    visualizer = AptamerVisualizer()
    visualizer.plot_cross_reactivity_matrix(crossreact_df, output_path=None)

## Save Trained Models

In [None]:
# Create output directory
model_dir = '../models'
os.makedirs(model_dir, exist_ok=True)

In [None]:
# Persist the selected binding-affinity model and, when one was trained, the
# cross-reactivity model into model_dir.
# NOTE(review): the .pkl extension suggests pickle serialization — only load
# these files from trusted locations.
binding_model_path = os.path.join(model_dir, 'binding_affinity_model.pkl')
best_binding_model.save_model(binding_model_path)
print(f"Binding affinity model saved to {binding_model_path}")

has_target_labels = 'Target_Name' in df.columns
if has_target_labels:
    # cross_reactivity_model only exists when the training cell saw Target_Name
    cross_reactivity_model_path = os.path.join(model_dir, 'cross_reactivity_model.pkl')
    cross_reactivity_model.save_model(cross_reactivity_model_path)
    print(f"Cross-reactivity model saved to {cross_reactivity_model_path}")

## Conclusions

Key findings from model training:

1. Binding affinity prediction: **TODO:** fill in after running the notebook.
2. Most important features for binding prediction: **TODO:** fill in after running the notebook.
3. Cross-reactivity analysis: **TODO:** fill in after running the notebook.
4. Most promising aptamer candidates: **TODO:** fill in after running the notebook.

These models will be used to select and optimize aptamers in the following notebooks.