# Inference for Lifespan Prediction

This notebook demonstrates how to:
1. Load a trained model
2. Featurize new molecules with the same pipeline used in training
3. Generate predictions
4. Visualize predictions
5. Save the results and print summary statistics

It relies on the refactored `lifespan_predictor` package for configuration, featurization, and the model itself.

## 1. Setup and Imports

In [None]:
import os
import sys
import torch
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

# Import from the refactored package
from lifespan_predictor.config import Config
from lifespan_predictor.data.preprocessing import load_and_clean_csv
from lifespan_predictor.data.featurizers import CachedGraphFeaturizer
from lifespan_predictor.data.fingerprints import FingerprintGenerator
from lifespan_predictor.data.dataset import LifespanDataset
from lifespan_predictor.models.predictor import LifespanPredictor
from lifespan_predictor.utils.logging import setup_logger
from lifespan_predictor.utils.io import load_checkpoint
from lifespan_predictor.utils.visualization import plot_predictions

from torch_geometric.loader import DataLoader

# Notebook-wide logger used by every subsequent cell.
LOG_LEVEL = "INFO"
logger = setup_logger("inference", level=LOG_LEVEL)
logger.info("Starting inference notebook")

## 2. Load Configuration and Model

Load the configuration saved during training, then rebuild the model and restore its trained weights.

In [None]:
# Artifacts written by the training run: the YAML config and the best checkpoint.
model_dir = "../results"  # Adjust this to your output directory
config_path = os.path.join(model_dir, "training_config.yaml")
model_path = os.path.join(model_dir, "best_model.pt")

logger.info(f"Loading configuration from: {config_path}")
config = Config.from_yaml(config_path)

# Use the GPU only when the config asks for it AND one is actually available.
use_gpu = config.device.use_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")
logger.info(f"Using device: {device}")

In [None]:
# Rebuild the architecture described by the training config on the target device.
logger.info("Initializing model...")
model = LifespanPredictor(config).to(device)

# Restore trained weights; the returned checkpoint mapping may also carry training metadata.
logger.info(f"Loading model weights from: {model_path}")
checkpoint = load_checkpoint(model_path, model, device=device)

# Evaluation mode: layers such as dropout/batch-norm (if present) switch to inference behavior.
model.eval()
logger.info("Model loaded successfully")

# Report whichever training metadata the checkpoint recorded.
has_epoch = 'epoch' in checkpoint
has_best_metric = 'best_metric' in checkpoint
if has_epoch:
    logger.info(f"Model trained for {checkpoint['epoch']} epochs")
if has_best_metric:
    logger.info(f"Best validation metric: {checkpoint['best_metric']:.4f}")

## 3. Load and Process New Molecules

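Load the molecules to score and process them with the same pipeline used in training. If no preprocessed test data is found, a few well-known drug molecules are used for demonstration.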

In [None]:
# Pick the molecules to score. Sources are tried in this order:
#   1. Your own CSV file, if INFERENCE_CSV is set below.
#   2. The preprocessed test split written during training, if present.
#   3. A handful of well-known drugs, purely for demonstration.

# Option 1: set to a CSV path to score your own molecules (None skips this option).
INFERENCE_CSV = None  # e.g. "path/to/your/molecules.csv"

# Option 2: preprocessed test data from the training run.
test_data_dir = os.path.join(config.data.output_dir, "test")
test_csv_path = os.path.join(test_data_dir, "processed_data.csv")

if INFERENCE_CSV is not None:
    logger.info(f"Loading molecules from: {INFERENCE_CSV}")
    inference_df = load_and_clean_csv(
        csv_path=INFERENCE_CSV,
        # Use the configured column so the featurization cell can find the SMILES.
        smiles_column=config.data.smiles_column,
        label_column=None,  # No labels for inference
    )
elif os.path.isfile(test_csv_path):
    # Check the file itself, not just its directory, before reading it.
    logger.info(f"Loading test data from: {test_data_dir}")
    inference_df = pd.read_csv(test_csv_path)
else:
    # Option 3: Create sample molecules for demonstration
    logger.info("Creating sample molecules for demonstration")
    sample_smiles = [
        "CC(C)Cc1ccc(cc1)C(C)C(O)=O",  # Ibuprofen
        "CC(=O)Oc1ccccc1C(=O)O",  # Aspirin
        "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # Caffeine
        "CC(C)NCC(O)COc1cccc2ccccc12",  # Propranolol (1-naphthyloxy)
        "CC(C)Cc1ccc(cc1)C(C)C(=O)O",  # Ibuprofen again, written with a different SMILES
    ]
    inference_df = pd.DataFrame({config.data.smiles_column: sample_smiles})

# Fail early with a readable message rather than a KeyError during featurization.
assert config.data.smiles_column in inference_df.columns, (
    f"SMILES column '{config.data.smiles_column}' not found; "
    f"available columns: {list(inference_df.columns)}"
)

logger.info(f"Loaded {len(inference_df)} molecules for inference")
inference_df.head()

## 4. Generate Features for New Molecules

Featurize the molecules (graph features and fingerprints) using the same featurization settings as training, read from the saved configuration.

In [None]:
# SMILES strings to featurize, taken from the configured column.
inference_smiles = inference_df[config.data.smiles_column].tolist()

# Reuse the training-time featurization settings so inputs match what the model saw.
feat_cfg = config.featurization

# NOTE(review): these cache directories come from the training config. Confirm the caches
# are keyed by SMILES content, so that inference molecules are not served training features.
graph_featurizer = CachedGraphFeaturizer(
    cache_dir=config.data.graph_features_dir,
    max_atoms=feat_cfg.max_atoms,
    atom_feature_dim=feat_cfg.atom_feature_dim,
    n_jobs=feat_cfg.n_jobs,
)

fp_generator = FingerprintGenerator(
    morgan_radius=feat_cfg.morgan_radius,
    morgan_nbits=feat_cfg.morgan_nbits,
    rdkit_fp_nbits=feat_cfg.rdkit_fp_nbits,
    n_jobs=feat_cfg.n_jobs,
)

logger.info("Featurizers initialized")

In [None]:
# Graph featurization: yields (adjacency, node features, third value). The third value is
# presumably labels; it is unused at inference time and discarded.
logger.info("Generating graph features...")
graph_result = graph_featurizer.featurize(
    smiles_list=inference_smiles,
    labels=None,
    force_recompute=False,  # allow cached features to be reused
)
inference_adj, inference_features, _ = graph_result

logger.info(f"Graph features shape: adj={inference_adj.shape}, features={inference_features.shape}")

In [None]:
# Compute hashed and non-hashed fingerprints; results may be cached under the configured directory.
logger.info("Generating fingerprints...")
fp_result = fp_generator.generate_fingerprints(
    smiles_list=inference_smiles,
    cache_dir=config.data.fingerprints_dir,
)
inference_fp_hashed, inference_fp_nonhashed = fp_result

logger.info(f"Fingerprints shape: hashed={inference_fp_hashed.shape}, non-hashed={inference_fp_nonhashed.shape}")

## 5. Create Dataset and DataLoader

In [None]:
# Bundle graph features and fingerprints into a label-free dataset.
logger.info("Creating inference dataset...")
graph_inputs = (inference_adj, inference_features)
fingerprint_inputs = (inference_fp_hashed, inference_fp_nonhashed)
inference_dataset = LifespanDataset(
    smiles_list=inference_smiles,
    graph_features=graph_inputs,
    fingerprints=fingerprint_inputs,
    labels=None,  # No labels for inference
)

# shuffle=False keeps input order, so predictions line up row-for-row with inference_df.
loader_options = dict(
    batch_size=config.training.batch_size,
    shuffle=False,
    num_workers=0,
)
inference_loader = DataLoader(inference_dataset, **loader_options)

logger.info(f"Created inference dataset with {len(inference_dataset)} samples")

## 6. Generate Predictions

Run the model on the processed molecules. For classification tasks, the outputs are passed through a sigmoid to obtain probabilities.

In [None]:
# Score every batch without tracking gradients; results are collected in loader (input) order.
logger.info("Generating predictions...")
is_classification = config.training.task == 'classification'
batch_predictions = []  # one numpy array per batch; concatenated below

model.eval()
with torch.no_grad():
    for batch in inference_loader:
        batch = batch.to(device)
        outputs = model(batch)

        # Outputs are presumably logits; map them to probabilities for classification.
        if is_classification:
            outputs = torch.sigmoid(outputs)

        batch_predictions.append(outputs.cpu().numpy())

# np.concatenate([]) raises an opaque "need at least one array" error; say what went wrong.
if not batch_predictions:
    raise ValueError("No molecules to predict: the inference dataset is empty.")

predictions = np.concatenate(batch_predictions, axis=0)
logger.info(f"Generated predictions for {len(predictions)} molecules")

## 7. Process and Display Results

In [None]:
# Attach predictions to the input table, one value per molecule row.
flat_predictions = predictions.flatten()

# A mismatch means the model emits several outputs per molecule, or rows were lost upstream.
assert len(flat_predictions) == len(inference_df), (
    f"Got {len(flat_predictions)} predictions for {len(inference_df)} molecules; "
    "the model may emit several outputs per molecule or some SMILES may have been dropped."
)

if config.training.task == 'classification':
    inference_df['Probability'] = flat_predictions
    inference_df['Predicted_Class'] = (flat_predictions >= 0.5).astype(int)
    # Distance from the 0.5 decision threshold, rescaled to the range [0, 1].
    inference_df['Confidence'] = np.abs(flat_predictions - 0.5) * 2
else:
    inference_df['Predicted_Value'] = flat_predictions

# Display results
print("\n" + "="*60)
print("INFERENCE RESULTS")
print("="*60 + "\n")
print(inference_df)
print("\n" + "="*60)
print("\n" + "="*60)

## 8. Visualize Predictions

In [None]:
# Plot how the predictions are distributed, then save the figure next to the trained model.
flat_predictions = predictions.flatten()
plot_path = os.path.join(model_dir, "inference_predictions.png")

if config.training.task == 'classification':
    # Left: predicted probabilities vs. the 0.5 threshold. Right: molecule count per class.
    fig, (prob_ax, class_ax) = plt.subplots(1, 2, figsize=(14, 5))

    prob_ax.hist(flat_predictions, bins=20, edgecolor='black', alpha=0.7)
    prob_ax.axvline(x=0.5, color='red', linestyle='--', label='Decision Threshold')
    prob_ax.set_xlabel('Predicted Probability')
    prob_ax.set_ylabel('Count')
    prob_ax.set_title('Distribution of Predicted Probabilities')
    prob_ax.legend()
    prob_ax.grid(True, alpha=0.3)

    class_counts = inference_df['Predicted_Class'].value_counts().sort_index()
    class_ax.bar(class_counts.index, class_counts.values, edgecolor='black', alpha=0.7)
    class_ax.set_xlabel('Predicted Class')
    class_ax.set_ylabel('Count')
    class_ax.set_title('Distribution of Predicted Classes')
    class_ax.set_xticks([0, 1])
    class_ax.grid(True, alpha=0.3, axis='y')
else:
    # Regression: a single histogram of predicted values.
    fig, ax = plt.subplots(figsize=(8, 6))

    ax.hist(flat_predictions, bins=20, edgecolor='black', alpha=0.7)
    ax.set_xlabel('Predicted Value')
    ax.set_ylabel('Count')
    ax.set_title('Distribution of Predicted Values')
    ax.grid(True, alpha=0.3)

# Shared by both task types: lay out, save to disk, then display.
fig.tight_layout()
fig.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.show()

logger.info(f"Prediction plot saved to: {plot_path}")

## 9. Save Results

In [None]:
# Persist results next to the trained model: a CSV table plus the raw prediction array.
output_path = os.path.join(model_dir, "inference_results.csv")
inference_df.to_csv(output_path, index=False)
logger.info(f"Results saved to: {output_path}")

predictions_path = os.path.join(model_dir, "predictions.npy")
np.save(predictions_path, predictions)
logger.info(f"Predictions array saved to: {predictions_path}")

banner = "=" * 60
saved_artifacts = (
    ("CSV", "inference_results.csv"),
    ("Predictions", "predictions.npy"),
    ("Visualization", "inference_predictions.png"),
)

print("\n" + banner)
print("INFERENCE COMPLETE")
print(banner)
print(f"\nResults saved to: {model_dir}")
for artifact_label, artifact_name in saved_artifacts:
    print(f"  - {artifact_label}: {artifact_name}")
print("\n" + banner)

## 10. Summary Statistics

In [None]:
# Summarize the predictions. The column name comes from the config, not a hard-coded "SMILES".
smiles_col = config.data.smiles_column
high_confidence_threshold = 0.8  # confidence is on a 0-1 scale (see the results cell)

print("\n" + "="*60)
print("SUMMARY STATISTICS")
print("="*60 + "\n")

print(f"Total molecules: {len(inference_df)}")

if config.training.task == 'classification':
    print(f"\nPredicted as Class 0: {(inference_df['Predicted_Class'] == 0).sum()}")
    print(f"Predicted as Class 1: {(inference_df['Predicted_Class'] == 1).sum()}")
    print(f"\nMean probability: {predictions.mean():.4f}")
    print(f"Std probability: {predictions.std():.4f}")
    print(f"Min probability: {predictions.min():.4f}")
    print(f"Max probability: {predictions.max():.4f}")
    print(f"\nMean confidence: {inference_df['Confidence'].mean():.4f}")

    # Molecules whose probability lies far from the 0.5 decision threshold.
    high_conf = inference_df[inference_df['Confidence'] > high_confidence_threshold]
    print(f"\nHigh confidence predictions (>{high_confidence_threshold}): {len(high_conf)}")
    if len(high_conf) > 0:
        print("\nTop 5 most confident predictions:")
        print(high_conf.nlargest(5, 'Confidence')[[smiles_col, 'Probability', 'Predicted_Class', 'Confidence']])
else:
    print(f"\nMean predicted value: {predictions.mean():.4f}")
    print(f"Std predicted value: {predictions.std():.4f}")
    print(f"Min predicted value: {predictions.min():.4f}")
    print(f"Max predicted value: {predictions.max():.4f}")
    print(f"Median predicted value: {np.median(predictions):.4f}")

print("\n" + "="*60)