In [None]:
from datetime import datetime
import os 
from src.Dataset import load_and_explore_data, preprocess_data
from src.Autoencoder import train_autoencoder, evaluate_autoencoder
from src.MIL import AttentionMIL
from src.train import leave_one_out_cross_validation


def _print_step_header(title):
    """Print a pipeline step title framed by two 80-character '=' rules."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)


def run_pipeline_loocv(input_file, output_dir='results',
                       latent_dim=64, num_epochs_ae=200,
                       num_epochs=50, num_classes=2,
                       hidden_dim=128, sample_source_dim=4,
                       project_name="tcellMIL"):
    """Run the complete pipeline with leave-one-out cross-validation.

    Steps: load the data, preprocess it into train/val/test loaders, train
    and evaluate an autoencoder, save its latent representation, drop
    patients whose 'Response_3m' value is missing, then run
    leave_one_out_cross_validation on the latent data.

    Parameters:
    - input_file: path to the input file (passed to load_and_explore_data)
    - output_dir: base directory; each run writes to
      '<output_dir>/run_<YYYYmmdd_HHMMSS>' with 'autoencoder' and 'mil'
      subdirectories
    - latent_dim: dimension of the autoencoder latent space; also passed
      as the MIL input dimension
    - num_epochs_ae: number of epochs for autoencoder training
    - num_epochs: number of epochs for MIL training (forwarded to
      leave_one_out_cross_validation)
    - num_classes: number of classes
    - hidden_dim: dimension of the MIL hidden layer
    - sample_source_dim: forwarded to leave_one_out_cross_validation
    - project_name: experiment-tracking project name; currently unused
      because wandb logging is disabled

    Returns:
    - dict with keys 'adata', 'autoencoder', 'latent_data', 'mil_results'
      and 'results_dir', or None if the latent data has no 'Response_3m'
      column or no cells remain after dropping missing responses
    """
    # TODO: wandb logging (project=project_name) is disabled; re-enable
    # wandb.init / wandb.config.update / wandb.finish when tracking is needed.

    # Create output directories. makedirs creates result_dir as an
    # intermediate directory of the subdirectories.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    result_dir = os.path.join(output_dir, f"run_{timestamp}")
    ae_dir = os.path.join(result_dir, "autoencoder")
    mil_dir = os.path.join(result_dir, "mil")
    for directory in (ae_dir, mil_dir):
        os.makedirs(directory, exist_ok=True)

    # Step 1: Load and explore data
    _print_step_header("STEP 1: LOADING AND EXPLORING DATA")
    adata = load_and_explore_data(input_file)

    # Step 2: Build data loaders (no training happens here)
    _print_step_header("STEP 2: PREPROCESSING DATA")
    train_loader, val_loader, test_loader, input_dim = preprocess_data(adata)

    # Step 3: Train and evaluate the autoencoder
    _print_step_header("STEP 3: TRAINING AUTOENCODER")
    # Loss histories are not used here; train_autoencoder receives save_path.
    model, _train_losses, _val_losses = train_autoencoder(
        train_loader, val_loader, input_dim, latent_dim, num_epochs_ae, save_path=ae_dir
    )
    adata_latent, _test_loss = evaluate_autoencoder(
        model, test_loader, adata, adata.var_names.tolist(), save_path=ae_dir
    )

    # Save latent representations
    latent_file = os.path.join(ae_dir, "latent_representation.h5ad")
    adata_latent.write(latent_file)

    # Step 4: Run LOOCV
    _print_step_header("STEP 4: RUNNING LEAVE-ONE-OUT CROSS-VALIDATION")

    # MIL needs the response label
    if 'Response_3m' not in adata_latent.obs.columns:
        print("ERROR: 'Response_3m' column not found in the data. Cannot proceed with MIL.")
        return None

    # Remove every cell belonging to a patient that has any missing response
    missing_mask = adata_latent.obs['Response_3m'].isna()
    patients_with_missing = adata_latent.obs.loc[missing_mask, 'patient_id'].unique()
    if len(patients_with_missing) > 0:
        print(f"Removing {len(patients_with_missing)} patients with missing responses")
        keep_mask = ~adata_latent.obs['patient_id'].isin(patients_with_missing)
        adata_latent = adata_latent[keep_mask].copy()

    if adata_latent.n_obs == 0:
        print("ERROR: no cells remain after removing patients with missing responses. "
              "Cannot proceed with MIL.")
        return None

    cv_results = leave_one_out_cross_validation(
        adata_latent,
        input_dim=latent_dim,
        num_classes=num_classes,
        hidden_dim=hidden_dim,
        sample_source_dim=sample_source_dim,
        num_epochs=num_epochs,
        save_path=mil_dir
    )

    print(f"Pipeline completed successfully! Results saved to {result_dir}")

    return {
        'adata': adata,
        'autoencoder': model,
        'latent_data': adata_latent,
        'mil_results': cv_results,
        'results_dir': result_dir
    }