# DeepGP Genomics Analysis

DeepGP (Deep learning-based Genome-wide Predictor) is a novel multimodal deep learning framework that incorporates bidirectional state space modules to predict cardiometabolic disease risk from genome-wide variants and demographic data.
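
The DeepGP layers themselves are not part of this notebook. As a rough intuition for what "bidirectional state space" means, the sketch below runs a scalar linear recurrence h_t = a·h_{t-1} + b·x_t, y_t = c·h_t over a toy genotype sequence in both directions and sums the outputs, so every position sees context from both its left and right neighbours. This is a deliberate simplification and not the DeepGP module: Mamba-style layers use vector-valued states and input-dependent (selective) parameters, and the values of `a`, `b`, `c` here are arbitrary.

```python
from typing import List


def linear_scan(x: List[float], a: float, b: float, c: float) -> List[float]:
    """Scalar linear state space recurrence: h_t = a*h_{t-1} + b*x_t, y_t = c*h_t."""
    h = 0.0
    ys = []
    for xt in x:
        h = a * h + b * xt
        ys.append(c * h)
    return ys


def bidirectional_scan(x: List[float], a: float, b: float, c: float) -> List[float]:
    """Scan left-to-right and right-to-left, then sum the two outputs per position."""
    forward = linear_scan(x, a, b, c)
    backward = linear_scan(x[::-1], a, b, c)[::-1]  # re-reverse to align positions
    return [f + r for f, r in zip(forward, backward)]


# Toy minor-allele counts (0/1/2) for five variants; a, b, c are illustrative only
genotypes = [0.0, 1.0, 2.0, 0.0, 1.0]
out = bidirectional_scan(genotypes, a=0.5, b=1.0, c=1.0)
print(out)
```

With a forward-only scan the first position would ignore every later variant; the backward pass is what lets it depend on the whole sequence.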


## Setup and Imports

Import all necessary libraries for deep learning, genomics data processing, and evaluation metrics: PyTorch Lightning for scalable training, scikit-learn (`sklearn`) for metrics, and the project's own modules (`models`, `utils`, `args_generator`) for genomics-specific functionality.

In [None]:
# -*- coding: utf-8 -*-
"""
@author: Taiyu Zhu
"""

import os
import torch
import random
import csv
import json
import datetime
import pickle
import numpy as np
import pytorch_lightning as pl

from sklearn import metrics
from models import DeepGP
from utils import SNPPCACHRDataModule
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from args_generator import args_initial

## Configuration and Reproducibility

Set random seeds for reproducible experiments and define the model registry and evaluation metrics to track across different runs.

In [None]:
# Set random seed for reproducibility
fix_seed = 33
random.seed(fix_seed)
torch.manual_seed(fix_seed)
np.random.seed(fix_seed)

base_folder = ''
model_dict = {'DeepGP': DeepGP}

# Metrics to track
fieldnames = ['pheno','accuracy','precision','recall', 'specificity','f1','auc','mcc']
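
Reseeding a generator with the same value replays the same random stream, which is what the seeding cell above relies on. The minimal stdlib check below uses a hypothetical helper, `draw_after_seed`, that reseeds Python's global `random` module. PyTorch Lightning also provides `pl.seed_everything(seed)`, which seeds Python, NumPy and PyTorch in one call. Seeds alone do not guarantee bitwise-identical GPU runs, because some CUDA kernels are nondeterministic.

```python
import random


def draw_after_seed(seed: int, n: int = 5) -> list:
    """Hypothetical helper: reseed Python's global RNG and return n uniform draws."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


first = draw_after_seed(33)
second = draw_after_seed(33)
other = draw_after_seed(34)

# Same seed -> identical stream; a different seed -> a different stream
print(first == second, first == other)
```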

## Load Configuration

Initialize experimental configuration parameters including data paths, model settings, and training hyperparameters from the args generator.

In [None]:
configs = args_initial()
configs.data_dir = base_folder+'pukb/genes'
print('Running experiment for the whole genome')
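
`args_initial()` lives in `args_generator`, which is not shown. The sketch below is a hypothetical `argparse` stand-in that declares only the attributes this notebook actually reads (`mn`, `dm`, `label`, `exp_name`, `gpus`, `patience`, `use_sim`, `data_dir`, `save_log`, `save_results`, `show`). Every default is a placeholder, not the project's real value, and treating `use_sim` as a string is an assumption (the notebook only calls `len()` on it).

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical stand-in for args_initial(); defaults are placeholders."""
    p = argparse.ArgumentParser(description="DeepGP experiment (sketch)")
    p.add_argument("--mn", default="DeepGP", help="key into model_dict")
    p.add_argument("--dm", default="snps_covs", choices=["snps", "snps_covs"],
                   help="data modality")
    p.add_argument("--label", default="pheno", help="phenotype to predict")
    p.add_argument("--exp_name", default="exp0", help="experiment sub-folder name")
    p.add_argument("--gpus", type=int, default=1, help="devices passed to pl.Trainer")
    p.add_argument("--patience", type=int, default=5,
                   help="EarlyStopping patience, counted in validation checks")
    p.add_argument("--use_sim", default="", help="non-empty value selects simulated data")
    p.add_argument("--data_dir", default="")
    # Boolean switches for logging, saving results and printing metrics
    p.add_argument("--save_log", action="store_true")
    p.add_argument("--save_results", action="store_true")
    p.add_argument("--show", action="store_true")
    return p


# Parse an explicit list so the cell ignores Jupyter's own sys.argv
example_configs = build_parser().parse_args(["--dm", "snps", "--save_log", "--show"])
print(vars(example_configs))
```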

## Data Loading

Load and preprocess genomic SNP data using a custom DataModule that handles PCA dimensionality reduction, chromosome organization, and train/validation/test splits with proper balancing.

In [None]:
print("Loading data...")
# SNPPCACHRDataModule: Custom PyTorch Lightning DataModule that:
# - Loads SNP data
# - Organizes variants by chromosome (CHR-aware processing)
snpdata = SNPPCACHRDataModule(configs)

# Update model configs with data-specific dimensions
configs.enc_len = len(snpdata.genes)                                    # Total number of genetic variants across genome
configs.enc_len_chr = [len(snp_chr) for snp_chr in snpdata.genes_chr]  # Number of variants per chromosome (for chr-aware attention)
configs.pos = snpdata.pos_chr                                          # Physical genomic positions for positional encoding
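
The fields read from `snpdata` (`genes`, `genes_chr`, `pos_chr`) imply a per-chromosome layout. The `SNPPCACHRDataModule` source is not shown, so this is only a sketch of how such fields could be derived from a variant table. It assumes that variants are sorted by base-pair position within each chromosome and that `genes` is the chromosome-major concatenation of `genes_chr`, so that `enc_len == sum(enc_len_chr)`. The variant IDs and positions are made up.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical toy variant table: (variant_id, chromosome, base-pair position)
variants: List[Tuple[str, int, int]] = [
    ("rs3", 2, 15_000), ("rs1", 1, 20_000), ("rs4", 2, 5_000),
    ("rs2", 1, 10_000), ("rs5", 3, 7_500),
]


def group_by_chromosome(rows):
    """Return per-chromosome variant IDs and positions, each sorted by position."""
    by_chr: Dict[int, List[Tuple[int, str]]] = defaultdict(list)
    for vid, chrom, pos in rows:
        by_chr[chrom].append((pos, vid))
    genes_chr, pos_chr = [], []
    for chrom in sorted(by_chr):
        ordered = sorted(by_chr[chrom])  # sort by position within the chromosome
        genes_chr.append([vid for _, vid in ordered])
        pos_chr.append([pos for pos, _ in ordered])
    return genes_chr, pos_chr


toy_genes_chr, toy_pos_chr = group_by_chromosome(variants)
toy_genes = [vid for chr_list in toy_genes_chr for vid in chr_list]  # chromosome-major
toy_enc_len = len(toy_genes)
toy_enc_len_chr = [len(c) for c in toy_genes_chr]
print(toy_enc_len, toy_enc_len_chr, toy_pos_chr)
```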

## Model Setup

Initialize the DeepGP architecture that matches the input data modality (`configs.dm`): `'snps'` for purely genomic analysis, or `'snps_covs'` for clinical prediction that combines SNPs with covariates.

In [None]:
# Select model architecture based on data modality
if configs.dm == 'snps':
    # SNP_Model_mamba: Pure genomic model
    # - Takes only SNP genotype data
    # - Uses Mamba (state-space model) for long-range genomic dependencies
    # - Incorporates chromosome-aware positional encoding
    model = model_dict[configs.mn].SNP_Model_mamba(configs)
elif configs.dm == 'snps_covs':
    # SNPCOV_Model_mamba: Multi-modal genomic + clinical model
    # - Combines SNP data with demographic covariates (age, sex, PCs, etc.)
    # - Uses separate encoders for genomic and clinical data
    # - Fuses representations for improved phenotype prediction
    model = model_dict[configs.mn].SNPCOV_Model_mamba(configs)
else:
    # Fail early instead of hitting a NameError on `model` during training
    raise ValueError(f"Unknown data modality {configs.dm!r}; expected 'snps' or 'snps_covs'")
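
The comments above describe separate encoders whose outputs are fused. How DeepGP fuses them is not shown here; a common pattern is late fusion by concatenation, sketched below in plain Python. The helper names (`dense_relu`, `fused_risk`) and all weights are illustrative assumptions, not DeepGP parameters.

```python
import math
from typing import List


def dense_relu(x: List[float], w: List[List[float]], b: List[float]) -> List[float]:
    """Fully connected layer (one output per row of w) followed by ReLU."""
    return [max(0.0, sum(wi * xi for wi, xi in zip(row, x)) + bi) for row, bi in zip(w, b)]


def fused_risk(snp_emb, cov, w_snp, b_snp, w_cov, b_cov, w_out, b_out) -> float:
    """Encode each modality separately, concatenate, and map to a probability."""
    h = dense_relu(snp_emb, w_snp, b_snp) + dense_relu(cov, w_cov, b_cov)  # list concat
    logit = sum(wi * hi for wi, hi in zip(w_out, h)) + b_out
    return 1.0 / (1.0 + math.exp(-logit))  # sigmoid


# Toy inputs: a 2-d genomic embedding and one scaled covariate (e.g. age)
risk = fused_risk(snp_emb=[1.0, -1.0], cov=[0.5],
                  w_snp=[[1.0, 0.0], [0.0, 1.0]], b_snp=[0.0, 0.0],
                  w_cov=[[2.0]], b_cov=[0.0],
                  w_out=[1.0, 1.0, -1.0], b_out=0.0)
print(risk)
```

Concatenation keeps the two encoders independent, so the SNP-only model can be seen as the same pipeline with the covariate branch removed.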

## Training Configuration

Set up experiment tracking, model checkpointing, and early stopping. Configure TensorBoard logging and save the best model based on validation AUC.

In [None]:
# Primary metric for model selection and early stopping
# AUC is preferred over accuracy for imbalanced genomic datasets
monitor_acc = 'val_auc'

# Setup logging and checkpointing for experiment tracking
if configs.save_log:
    # Hierarchical logging structure: model_datamode/phenotype/experiment
    log_name = f'{configs.mn}_{configs.dm}/{configs.label}'
    if len(configs.use_sim) > 0:
        log_name = log_name + '/sim'  # Separate simulated data experiments

    # TensorBoard logger for training visualization
    logger = TensorBoardLogger(save_dir=base_folder + 'DeepGP/logs/', name=log_name, version=configs.exp_name)

    # ModelCheckpoint: Save only the best model based on validation AUC
    checkpoint_callback = ModelCheckpoint(monitor=monitor_acc, save_top_k=1, mode='max')
    
    # Callbacks for training control
    callbacks = [
        EarlyStopping(monitor=monitor_acc, patience=configs.patience, mode="max"),  # Stop if no improvement
        checkpoint_callback  # Save best model
    ]
    ecb = True
else:
    # Minimal setup for quick experiments without logging
    logger, ecb = False, False
    callbacks = [EarlyStopping(monitor=monitor_acc, patience=configs.patience, mode="max")]
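
`EarlyStopping(mode="max")` and `ModelCheckpoint(save_top_k=1, mode="max")` both track the highest `val_auc` seen so far. The sketch below is a simplified re-implementation of that logic, not Lightning's code (which also supports `min_delta`, stopping thresholds and more). In Lightning, `patience` counts validation checks, so with validation every 300 steps, `patience=5` allows roughly 1,500 training steps without improvement. The AUC sequence is made up.

```python
from typing import List, Optional, Tuple


def simulate_early_stopping(scores: List[float], patience: int) -> Tuple[int, int, float]:
    """Simplified mode='max' early stopping over successive validation scores.

    Returns (index of the check where training stops, index of the best check, best score).
    """
    best: Optional[float] = None
    best_idx = -1
    wait = 0
    for i, score in enumerate(scores):
        if best is None or score > best:  # strict improvement resets the counter
            best, best_idx, wait = score, i, 0
        else:
            wait += 1
            if wait >= patience:
                return i, best_idx, best
    return len(scores) - 1, best_idx, best


# Hypothetical val_auc after each validation check
val_auc = [0.70, 0.74, 0.73, 0.75, 0.74, 0.745, 0.742]
stop_at, best_at, best_auc = simulate_early_stopping(val_auc, patience=3)
print(stop_at, best_at, best_auc)
```

The best check (index 3) is what the single retained checkpoint would hold, even though training continues for three more checks.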

## Initialize Trainer

Configure the PyTorch Lightning trainer with GPU acceleration and the callbacks defined above. Validation runs every 300 training steps rather than once per epoch, so long epochs over genome-wide data still get regular early-stopping and checkpoint checks.

In [None]:
trainer = pl.Trainer(
    accelerator="gpu",                    # Use GPU acceleration for large genomic datasets
    devices=configs.gpus,                 # Multi-GPU support for distributed training
    callbacks=callbacks,                  # Early stopping + model checkpointing
    max_epochs=30,                        # Maximum training epochs
    val_check_interval=300,               # Validate every 300 training steps (not epochs)
    logger=logger,                        # TensorBoard logging
    enable_checkpointing=ecb,             # Model checkpointing control
    enable_progress_bar=ecb,              # Progress bar display control
)
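
With an integer `val_check_interval`, the number of validation runs per epoch depends on the number of training batches each device sees. Recent Lightning versions reject an integer interval larger than the batches per epoch unless `check_val_every_n_epoch=None`. The hypothetical helper below does the arithmetic, assuming a `DistributedSampler`-style split (each device gets `ceil(n / devices)` samples) and `drop_last=False`. The sample count and batch size are illustrative, since the notebook does not show them.

```python
import math


def validations_per_epoch(n_train: int, batch_size: int, val_check_interval: int,
                          devices: int = 1) -> int:
    """Validation runs inside one epoch for an integer val_check_interval (sketch)."""
    per_device = math.ceil(n_train / devices)           # samples per device
    batches_per_epoch = math.ceil(per_device / batch_size)  # last partial batch counts
    if val_check_interval > batches_per_epoch:
        raise ValueError(
            f"val_check_interval={val_check_interval} exceeds "
            f"{batches_per_epoch} training batches per epoch"
        )
    return batches_per_epoch // val_check_interval


print(validations_per_epoch(100_000, 64, 300))             # single GPU
print(validations_per_epoch(100_000, 64, 300, devices=2))  # two GPUs
```

Adding GPUs shrinks the per-device epoch, so the same interval yields fewer validation checks per epoch.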

## Training and Testing

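Execute the training loop with automatic validation and early stopping, then evaluate on the test set. When logging is enabled, the best checkpoint (highest validation AUC) is reloaded first. Otherwise no checkpoint is saved, and the final in-memory weights are tested.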

In [None]:
print("Start training...")
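trainer.fit(model, snpdata)

# Load best model and test
if configs.save_log:
    # load_from_checkpoint is a classmethod: call it on the class, not on the instance
    best_model = type(model).load_from_checkpoint(checkpoint_callback.best_model_path)
else:
    # No checkpoints are written without logging, so evaluate the final weights
    best_model = model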

trainer.test(best_model, snpdata)

## Model Evaluation

Generate predictions on the test set and calculate comprehensive evaluation metrics including clinical metrics (specificity, sensitivity) and genomics-appropriate measures (AUC, MCC) for imbalanced datasets.

In [None]:
# Extract predictions from all test batches
outputs = trainer.predict(best_model, snpdata)
# .cpu() is a no-op for CPU tensors and avoids a TypeError if predict_step returns CUDA tensors
y_pred_proba = np.concatenate([opt[0].cpu().numpy() for opt in outputs])  # Predicted probabilities [0,1]
y_pred = np.concatenate([opt[1].cpu().numpy() for opt in outputs])        # Binary predictions {0,1}
y_true = np.concatenate([opt[2].cpu().numpy() for opt in outputs])        # True labels {0,1}

# Standard classification metrics
accuracy = metrics.accuracy_score(y_true, y_pred)                   # Overall correctness
precision = metrics.precision_score(y_true, y_pred)                 # PPV: True positives / Predicted positives
recall = metrics.recall_score(y_true, y_pred)                       # Sensitivity: True positives / Actual positives

# Calculate specificity (True Negative Rate) - important for medical screening
# Specificity = TN / (TN + FP) - ability to correctly identify non-cases
tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred).ravel()
specificity = tn / (tn + fp)

# Advanced metrics for imbalanced genomic data
f1 = metrics.f1_score(y_true, y_pred)                              # Harmonic mean of precision and recall
mcc = metrics.matthews_corrcoef(y_true, y_pred)                    # Matthews Correlation Coefficient (-1 to 1, robust to class imbalance)
auc = metrics.roc_auc_score(y_true, y_pred_proba)                  # Area Under ROC Curve (discrimination ability)

results_dict = {'pheno': configs.label,'accuracy':accuracy,'precision':precision,
                'recall':recall, 'specificity':specificity,'f1':f1,'auc':auc,'mcc':mcc}
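
To make the metric definitions concrete, and to sanity-check the `sklearn` numbers on small inputs, the dependency-free sketch below recomputes them from confusion-matrix counts. AUC is computed as the probability that a randomly chosen case scores higher than a randomly chosen control (ties count 0.5), which equals the ROC AUC. The labels, scores and 0.5 threshold are toy assumptions; in the notebook, `y_pred` comes from the model's predict step. The sketch has no zero-division guards and its pairwise AUC is O(cases × controls), so it is only meant for small inputs.

```python
import math
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Confusion-matrix metrics for binary labels in {0, 1} (no zero-division guards)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)  # sensitivity
    return {
        "accuracy": (tp + tn) / len(y_true),
        "precision": precision,
        "recall": recall,
        "specificity": tn / (tn + fp),
        "f1": 2 * precision * recall / (precision + recall),
        "mcc": (tp * tn - fp * fn)
               / math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)),
    }


def pairwise_auc(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC as P(case score > control score), counting ties as 0.5."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Toy data; the 0.5 decision threshold is an assumption
y_true_toy = [1, 0, 1, 1, 0, 0, 1, 0]
proba_toy = [0.9, 0.2, 0.4, 0.8, 0.3, 0.6, 0.7, 0.1]
y_pred_toy = [int(p >= 0.5) for p in proba_toy]
toy_metrics = binary_metrics(y_true_toy, y_pred_toy)
toy_auc = pairwise_auc(y_true_toy, proba_toy)
print(toy_metrics, toy_auc)
```

Here TP = TN = 3 and FP = FN = 1, so every ratio is 0.75, while MCC = (3·3 − 1·1) / √(4·4·4·4) = 0.5. MCC penalizes both error types jointly.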

## Save Results

Save all experimental results, hyperparameters, and raw predictions in an organized directory structure for reproducibility and downstream analysis.

In [None]:
now = datetime.datetime.now().strftime('%m-%d_%H-%M')

if configs.save_results:
    # Create hierarchical results directory structure
    # Format: DeepGP/results/{model}_{datamode}/{phenotype}/{experiment_name}/
    results_dir = base_folder+f'DeepGP/results/{configs.mn}_{configs.dm}/{configs.label}/{configs.exp_name}'
    if len(configs.use_sim)>0:
        # Separate directory for simulated data experiments
        results_dir = base_folder+f'DeepGP/results/{configs.mn}_{configs.dm}/{configs.label}/sim/{configs.exp_name}'

    os.makedirs(results_dir, exist_ok=True)
    
    # Save evaluation metrics to CSV for easy analysis
    results_csv = f'{results_dir}/results_{now}.csv'
    with open(results_csv, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerow(results_dict)
    
    # Save hyperparameters and configuration as JSON
    # Essential for reproducibility and experiment tracking
    with open(f'{results_dir}/params_{now}.txt', 'w') as f:
        configs.pos = None  # Remove large position arrays to reduce file size
        json.dump(configs.__dict__, f, indent=2)
    
    # Save raw predictions for additional analysis
    # Format: (true_labels, predicted_probabilities)
    # Useful for ROC curves, calibration plots, error analysis
    with open(f'{results_dir}/preds.pkl', 'wb') as f:
        pickle.dump((y_true, y_pred_proba), f)
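
Because each run writes its own `results_{now}.csv` under `DeepGP/results/`, comparing runs means reading those files back. The sketch below uses a hypothetical helper, `collect_results`, that gathers every `results_*.csv` below a root folder and converts the metric columns to floats. It is demonstrated on a throwaway directory containing two fake runs with made-up numbers. The matching `preds.pkl` can be reopened with `pickle.load`, which returns the `(y_true, y_pred_proba)` tuple.

```python
import csv
import glob
import os
import tempfile

FIELDNAMES = ['pheno', 'accuracy', 'precision', 'recall', 'specificity', 'f1', 'auc', 'mcc']


def collect_results(root: str):
    """Read every results_*.csv below root into dicts with float-valued metrics."""
    rows = []
    pattern = os.path.join(root, '**', 'results_*.csv')
    for path in sorted(glob.glob(pattern, recursive=True)):
        with open(path, newline='') as f:
            for row in csv.DictReader(f):
                for key in FIELDNAMES[1:]:  # every column except 'pheno'
                    row[key] = float(row[key])
                row['path'] = os.path.relpath(path, root)
                rows.append(row)
    return rows


# Demo on a temporary tree with two fake runs (numbers are placeholders)
with tempfile.TemporaryDirectory() as tmp:
    for exp, auc in [('exp_a', 0.71), ('exp_b', 0.74)]:
        run_dir = os.path.join(tmp, 'DeepGP_snps_covs', 'pheno', exp)
        os.makedirs(run_dir, exist_ok=True)
        with open(os.path.join(run_dir, 'results_01-01_00-00.csv'), 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=FIELDNAMES)
            writer.writeheader()
            writer.writerow({'pheno': 'pheno', 'accuracy': 0.8, 'precision': 0.5,
                             'recall': 0.6, 'specificity': 0.85, 'f1': 0.55,
                             'auc': auc, 'mcc': 0.4})
    results = collect_results(tmp)
    best = max(results, key=lambda r: r['auc'])
    print(best['path'], best['auc'])
```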

## Display Results

Print all evaluation metrics to console for immediate review of model performance.

In [None]:
if configs.show:
    print(f'Accuracy: {accuracy}')
    print(f'Precision: {precision}')
    print(f'Recall: {recall}')
    print(f'Specificity: {specificity}')
    print(f'F1 Score: {f1}')
    print(f'AUC: {auc}')
    print(f'MCC: {mcc}')

## Key Features

- **DeepGP Model**: Uses Mamba (state-space) architecture for modeling long-range genomic dependencies across chromosomes
- **Chromosome-aware Processing**: Maintains biological structure with per-chromosome encoding and positional information
- **Multi-modal Support**: Handles both SNPs-only and SNPs+covariates (age, sex, PCs) for clinical applications
- **Imbalance-aware Metrics**: Reports AUC and MCC, which stay informative under class imbalance, plus specificity for medical relevance
- **MLOps Ready**: Complete experiment tracking with TensorBoard logging, model checkpointing, and reproducible seeds
- **Scalable Architecture**: GPU acceleration and multi-device support for genome-wide datasets
- **Clinical Translation**: Saves raw predictions for downstream analysis (ROC curves, risk stratification, calibration)
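
The pickled `(y_true, y_pred_proba)` pair is enough for a basic calibration check. The sketch below builds an equal-width reliability table: for each probability bin, it compares the mean predicted risk with the observed case rate. It uses toy arrays in place of the pickle contents, and the bin count and equal-width binning are assumptions. scikit-learn's `sklearn.calibration.calibration_curve` offers a comparable computation.

```python
from typing import List, Sequence, Tuple


def reliability_table(y_true: Sequence[int], proba: Sequence[float],
                      n_bins: int = 5) -> List[Tuple[float, float, float, int]]:
    """Equal-width bins: (bin lower edge, mean predicted risk, observed rate, count)."""
    bins = [[] for _ in range(n_bins)]
    for t, p in zip(y_true, proba):
        idx = min(int(p * n_bins), n_bins - 1)  # p == 1.0 falls into the last bin
        bins[idx].append((t, p))
    table = []
    for i, members in enumerate(bins):
        if not members:
            continue  # skip empty bins
        mean_pred = sum(p for _, p in members) / len(members)
        observed = sum(t for t, _ in members) / len(members)
        table.append((i / n_bins, mean_pred, observed, len(members)))
    return table


# Toy stand-ins for the pickled arrays
y_true_toy = [1, 0, 1, 1, 0, 0, 1, 0]
proba_toy = [0.95, 0.15, 0.45, 0.85, 0.35, 0.65, 0.75, 0.05]
table = reliability_table(y_true_toy, proba_toy)
for low, mean_pred, observed, count in table:
    print(f"[{low:.1f}, {low + 0.2:.1f}): predicted {mean_pred:.3f}, observed {observed:.3f}, n={count}")
```

A well-calibrated model has observed rates close to the mean predicted risk in every bin. On real test sets, bins with very few samples should be read with caution.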