## 1. Import Required Libraries and Dependencies

In [None]:
import sys
import logging
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Configure root logging for the notebook and create a module-level logger.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

# The project root is taken to be the parent of the current working directory.
# NOTE(review): this assumes the notebook is launched from a sub-folder of the
# project (e.g. `notebooks/`) — confirm the launch directory.
PROJECT_ROOT = Path.cwd().parent

# Make the `src.*` packages importable. The membership check keeps the cell
# idempotent: re-running it no longer stacks duplicate entries on sys.path.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Import custom modules
from src.ingest.fetch_genomes import get_genomes_by_type, get_genome_count
from src.features.genome_features import (
    compute_genome_embedding,
    load_genome_embedding,
    compute_all_genome_embeddings,
    build_genome_embedding_matrix,
)
from src.features.build_dataset import build_genome_media_dataset
from src.models.media_generator import ConditionalMediaVAE
from src.training.trainer import TrainConfig, train

# Report that every import above resolved, and show which project root is in use.
for status_line in (
    "✓ Libraries imported successfully",
    f"✓ Project root: {PROJECT_ROOT}",
):
    print(status_line)

## 2. Check Genome Database and Compute Embeddings

In [None]:
# Ask the ingest layer how many genomes are stored per organism type.
# The result is used as a mapping (organism type -> count) here and in the
# final summary cell.
genome_counts = get_genome_count()

print("\nGenome counts by organism type:")
# Sorting the (type, count) pairs lists the organism types in ascending order.
inventory = sorted(genome_counts.items())
for kind, n_genomes in inventory:
    print(f"  {kind:12s}: {n_genomes:5d}")

total_genomes = sum(n_genomes for _, n_genomes in inventory)
print(f"\n  Total genomes: {total_genomes}")

In [None]:
# Compute k-mer based embeddings through the project helper.
# The helper's return value is reported below as the number of newly computed
# embeddings.
embedding_method = "kmer_128"
print("Computing genome embeddings (k-mer method)...\n")

n_computed = compute_all_genome_embeddings(method=embedding_method)
print(f"\n✓ Computed {n_computed} new embeddings")

## 3. Build CVAE Dataset for Bacteria

In [None]:
# Build the bacteria-only genome/media dataset as a first, simpler target.
print("Building genome-media dataset for bacteria...\n")

bacteria_build_options = {
    "embedding_method": "kmer_128",
    "organism_type": "bacteria",
    "test_size": 0.15,
    "val_size": 0.15,
    "seed": 42,
    "save": True,
}
bacteria_dataset = build_genome_media_dataset(**bacteria_build_options)

# Gather the figures for the report: X_* holds the media compositions and
# C_* the genome embeddings (per the labels printed below).
n_train, n_val, n_test = (
    len(bacteria_dataset[split]) for split in ("X_train", "X_val", "X_test")
)
media_dim = bacteria_dataset["X_train"].shape[1]
embedding_dim = bacteria_dataset["C_train"].shape[1]
train_types = np.unique(bacteria_dataset["y_train"])

print("\nDataset summary for bacteria:")
print(f"  Training samples:   {n_train}")
print(f"  Validation samples: {n_val}")
print(f"  Test samples:       {n_test}")
print(f"\n  Media composition dim:  {media_dim}")
print(f"  Genome embedding dim:   {embedding_dim}")
print(f"\n  Organism types in set: {train_types}")

## 4. Train CVAE on Bacterial Genomes

In [None]:
# Hyper-parameters for the bacteria-only CVAE run.
bacteria_hparams = {
    "model_type": "cvae",
    "epochs": 100,
    "batch_size": 64,
    "latent_dim": 32,
    "hidden_dims": [256, 128],
    "beta": 1.0,  # weight of the KL term
    "lr": 1e-3,
    "seed": 42,
    "curriculum_phases": ["bacteria"],
    "embedding_method": "kmer_128",
}
bacteria_cfg = TrainConfig(**bacteria_hparams)

# Echo the main settings back from the config object so the notebook records
# what was actually used.
config_report = [
    "CVAE Configuration for Bacteria:",
    f"  Model type:      {bacteria_cfg.model_type}",
    f"  Epochs:          {bacteria_cfg.epochs}",
    f"  Batch size:      {bacteria_cfg.batch_size}",
    f"  Latent dim:      {bacteria_cfg.latent_dim}",
    f"  Learning rate:   {bacteria_cfg.lr}",
    f"  Beta (KL weight): {bacteria_cfg.beta}",
]
print("\n".join(config_report))

In [None]:
# Run the bacteria-only training and report what the trainer returned.
banner = "=" * 60

print(f"\n{banner}")
print("TRAINING CVAE ON BACTERIAL GENOMES")
print(f"{banner}\n")

bacteria_result = train(bacteria_cfg)

print(f"\n{banner}")
print("TRAINING COMPLETE")
print(banner)
# The result is read as a dict with 'elapsed_seconds', 'save_path' and 'metrics'.
elapsed = bacteria_result["elapsed_seconds"]
print(f"\nElapsed time: {elapsed:.1f} seconds")
print(f"Model saved to: {bacteria_result['save_path']}")
print(f"\nMetrics: {bacteria_result['metrics']}")

In [None]:
# Plot the per-epoch loss curves recorded for the "bacteria" phase.
# Nothing is drawn or saved when no history was recorded for that phase.
history = bacteria_result["history"].get("bacteria", {})

# One entry per panel:
# (history key, legend label, y-axis label, panel title, line colour).
# A colour of None leaves matplotlib's default colour in place.
curve_specs = [
    ("train_loss", "Train Loss", "Total Loss", "Training Loss", None),
    ("recon_loss", "Reconstruction Loss", "MSE", "Reconstruction Loss", "green"),
    ("kl_loss", "KL Divergence", "KL Loss", "KL Divergence", "red"),
]

if history:
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    for ax, (key, label, ylabel, title, color) in zip(axes, curve_specs):
        if key not in history:
            # Leave the panel empty when this loss was not tracked.
            continue

        values = history[key]
        # Build the x-axis from this series' own length. Deriving it from
        # "train_loss" alone made plot() fail with mismatched x/y sizes
        # whenever that key was absent or a series had a different length.
        epochs = range(1, len(values) + 1)

        line_style = {"color": color} if color else {}
        ax.plot(epochs, values, label=label, **line_style)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ax.legend()

    plt.tight_layout()
    plt.savefig(PROJECT_ROOT / "notebooks" / "cvae_bacteria_training.png", dpi=100, bbox_inches="tight")
    plt.show()
    print("✓ Training curves saved")

## 5. Build and Train on Multi-Organism Dataset

In [None]:
# Build the dataset across all organism types (organism_type=None disables
# the per-type filter used for the bacteria-only dataset).
print("Building multi-organism dataset...\n")

multi_build_options = {
    "embedding_method": "kmer_128",
    "organism_type": None,  # Include all organisms
    "test_size": 0.15,
    "val_size": 0.15,
    "seed": 42,
    "save": True,
}
multi_dataset = build_genome_media_dataset(**multi_build_options)

# Tally how many training samples carry each organism-type label.
unique_types, counts = np.unique(multi_dataset["y_train"], return_counts=True)
n_train_total = len(multi_dataset["X_train"])

print("\nMulti-organism dataset summary:")
print(f"  Total training samples:   {n_train_total}")
print(f"  Total validation samples: {len(multi_dataset['X_val'])}")
print(f"  Total test samples:       {len(multi_dataset['X_test'])}")
print("\n  Organism type distribution (training):")

# Human-readable names for the integer labels; labels missing from the map
# are printed as "unknown_<label>".
# NOTE(review): the integer coding is assumed from this map — confirm it
# matches the dataset builder.
org_type_map = {1: "bacteria", 2: "archea", 3: "fungi", 4: "protist", 5: "virus"}
for type_code, n_of_type in zip(unique_types, counts):
    type_name = org_type_map.get(type_code, f"unknown_{type_code}")
    share = 100 * n_of_type / n_train_total
    print(f"    {type_name:12s}: {n_of_type:5d} ({share:5.1f}%)")

In [None]:
# Hyper-parameters for the curriculum run: same network settings as the
# bacteria-only configuration, but with four curriculum phases.
curriculum_hparams = {
    "model_type": "cvae",
    "epochs": 100,
    "batch_size": 64,
    "latent_dim": 32,
    "hidden_dims": [256, 128],
    "beta": 1.0,
    "lr": 1e-3,
    "seed": 42,
    "curriculum_phases": ["bacteria", "archea", "fungi", "protist"],
    "embedding_method": "kmer_128",
}
multi_cfg = TrainConfig(**curriculum_hparams)

phase_order = " → ".join(multi_cfg.curriculum_phases)
print("CVAE Curriculum Learning Configuration:")
print(f"  Curriculum phases: {phase_order}")
print(f"  Epochs per phase:  {multi_cfg.epochs}")

In [None]:
# Run curriculum training with the multi-organism configuration.
rule = "=" * 60
# Take the phase list from the config rather than repeating it as a literal,
# so the banner cannot drift from what is actually trained.
phase_chain = " → ".join(multi_cfg.curriculum_phases)

print("\n" + rule)
print("TRAINING CVAE WITH CURRICULUM LEARNING")
print(f"Phases: {phase_chain}")
print(rule + "\n")

multi_result = train(multi_cfg)

print("\n" + rule)
print("CURRICULUM TRAINING COMPLETE")
print(rule)
# The result is read as a dict with 'elapsed_seconds' and 'save_path'.
print(f"\nTotal elapsed time: {multi_result['elapsed_seconds']:.1f} seconds")
print(f"Model saved to: {multi_result['save_path']}")

## 6. Evaluate Model Performance

In [None]:
# Reload the curriculum-trained model from the path the trainer reported.
# The import is repeated here so this cell can be run without the first cell.
from src.models.media_generator import ConditionalMediaVAE

model = ConditionalMediaVAE.load(multi_result["save_path"])
print("✓ Loaded CVAE model")

# Print the three size attributes exposed by the loaded model.
for caption, attribute in (
    ("Latent dimension", "latent_dim"),
    ("Condition dimension", "condition_dim"),
    ("Input dimension", "input_dim"),
):
    print(f"  {caption}: {getattr(model, attribute)}")

In [None]:
# Reconstruct the held-out split and measure how far the reconstructions are
# from the true media compositions.
X_test, C_test, y_test = (
    multi_dataset[key] for key in ("X_test", "C_test", "y_test")
)

# NOTE(review): predict() is assumed to return an array aligned element-wise
# with X_test — confirm against ConditionalMediaVAE.
recon = model.predict(X_test, C_test)

# Compute the residual once and reuse it for every metric below.
residual = X_test - recon
squared_error = residual ** 2

mse = np.mean(squared_error)
mae = np.mean(np.abs(residual))

print("\nReconstruction Performance (Test Set):")
print(f"  MSE: {mse:.6f}")
print(f"  MAE: {mae:.6f}")

# Averaging over axis 0 (samples) leaves one error value per column,
# i.e. per ingredient.
ingredient_mse = np.mean(squared_error, axis=0)
print(f"\n  Mean ingredient MSE: {np.mean(ingredient_mse):.6f}")
print(f"  Max ingredient MSE:  {np.max(ingredient_mse):.6f}")
print(f"  Min ingredient MSE:  {np.min(ingredient_mse):.6f}")

## 7. Visualize Latent Space

In [None]:
# Map the test samples into the CVAE latent space.
# NOTE(review): assumes encode() returns one 2-D array of shape
# (N_test, latent_dim) rather than e.g. a (mu, logvar) pair — confirm.
latent_test = model.encode(X_test, C_test)

# Project the latent codes onto their two leading principal components.
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
latent_2d = pca.fit_transform(latent_test)

print("\nLatent Space Visualization:")
for pc_number, ratio in enumerate(pca.explained_variance_ratio_, start=1):
    print(f"  PC{pc_number} explains {ratio:.1%} of variance")

In [None]:
# Two views of the 2-D latent projection: coloured by organism type on the
# left, by per-sample reconstruction error on the right.
fig, (ax_types, ax_error) = plt.subplots(1, 2, figsize=(14, 5))

# Display colour and legend name for each integer organism label; labels
# missing from the maps fall back to gray / "Type <label>".
colors = {1: "red", 2: "blue", 3: "green", 4: "orange", 5: "purple"}
org_labels = {1: "Bacteria", 2: "Archea", 3: "Fungi", 4: "Protist", 5: "Virus"}

# Left panel: one scatter call per organism type so each gets a legend entry.
for org_type in sorted(np.unique(y_test)):
    members = y_test == org_type
    ax_types.scatter(
        latent_2d[members, 0],
        latent_2d[members, 1],
        c=colors.get(org_type, "gray"),
        label=org_labels.get(org_type, f"Type {org_type}"),
        alpha=0.6,
        s=30,
    )
ax_types.set_title("Latent Space by Organism Type")
ax_types.legend()

# Right panel: every sample, coloured by its mean squared error across
# ingredients (axis=1); the reversed colormap puts high error in red.
recon_error = np.mean((X_test - recon) ** 2, axis=1)
scatter = ax_error.scatter(
    latent_2d[:, 0],
    latent_2d[:, 1],
    c=recon_error,
    cmap="RdYlGn_r",
    alpha=0.6,
    s=30,
)
ax_error.set_title("Latent Space by Reconstruction Error")
cbar = plt.colorbar(scatter, ax=ax_error)
cbar.set_label("MSE")

# Both panels share the same axis captions and grid styling.
pc1_ratio = pca.explained_variance_ratio_[0]
pc2_ratio = pca.explained_variance_ratio_[1]
for panel in (ax_types, ax_error):
    panel.set_xlabel(f"PC1 ({pc1_ratio:.1%})")
    panel.set_ylabel(f"PC2 ({pc2_ratio:.1%})")
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(PROJECT_ROOT / "notebooks" / "latent_space_visualization.png", dpi=100, bbox_inches="tight")
plt.show()
print("✓ Latent space visualization saved")

## 8. Generate Novel Media Compositions

In [None]:
# Generate new media conditioned on one bacterial genome from the test split.
# Label 1 is the bacteria code used throughout this notebook.
bacteria_mask = y_test == 1
bacteria_indices = np.where(bacteria_mask)[0]

# Skip silently when the test split holds no bacteria.
if bacteria_indices.size:
    # Take the first bacterial test sample (a deterministic choice, not random).
    sample_idx = bacteria_indices[0]
    # Slice rather than index so the embedding keeps a leading batch axis.
    genome_emb = C_test[sample_idx : sample_idx + 1]

    # NOTE(review): generate() is called with a single-row condition for
    # n_novel samples — confirm the model repeats the condition internally.
    n_novel = 5
    novel_media = model.generate(n_novel, genome_emb)

    print(f"Generated {n_novel} novel media for bacterial genome (sample {sample_idx}):")
    print("\nGenerated media compositions (first 10 ingredients):")
    print(novel_media[:, :10])

    # Summary statistics over all generated values.
    media_stats = [
        "\nGenerated media statistics:",
        f"  Mean concentration: {np.mean(novel_media):.4f}",
        f"  Std concentration:  {np.std(novel_media):.4f}",
        f"  Min concentration:  {np.min(novel_media):.4f}",
        f"  Max concentration:  {np.max(novel_media):.4f}",
    ]
    print("\n".join(media_stats))

## 9. Cross-Organism Validation

In [None]:
# Break the test-set reconstruction error down by organism type.
print("\nReconstruction Error by Organism Type:")
print("\n  Organism Type | Samples | Mean MSE | Std MSE")
print("  " + "-" * 52)

# Defined here (same mapping as in the latent-space plotting cell) so this
# cell no longer raises NameError when that plotting cell has been skipped.
org_labels = {1: "Bacteria", 2: "Archea", 3: "Fungi", 4: "Protist", 5: "Virus"}

for org_type in sorted(np.unique(y_test)):
    mask = y_test == org_type
    # Squared error for this organism's samples, computed once and reused.
    group_sq_err = (X_test[mask] - recon[mask]) ** 2

    org_mse = np.mean(group_sq_err)
    # Spread of the per-sample MSE (mean over ingredients, axis=1).
    org_std = np.std(np.mean(group_sq_err, axis=1))
    n_samples = int(np.sum(mask))

    org_name = org_labels.get(org_type, f"Type {org_type}")
    print(f"  {org_name:13s} | {n_samples:7d} | {org_mse:8.6f} | {org_std:7.6f}")

## 10. Summary and Next Steps

In [None]:
# Final recap of the run. Values are read from the objects produced by the
# earlier cells so the summary cannot drift from what was actually executed.
rule = "=" * 60

# The embedding width is measured from the data; the previous hard-coded text
# ("128-dim, 4096-dimensional") was self-contradictory.
embedding_dim = multi_dataset["C_test"].shape[1]
# The phase list comes from the config instead of a duplicated literal.
phase_chain = " → ".join(multi_cfg.curriculum_phases)

print("\n" + rule)
print("CVAE TRAINING SUMMARY")
print(rule)

print("\n1. GENOME PROCESSING")
print(f"   - Genomes by type: {dict(sorted(genome_counts.items()))}")
print(f"   - Embedding method: {multi_cfg.embedding_method} ({embedding_dim}-dimensional)")

print("\n2. DATASET BUILDING")
print(f"   - Training samples: {len(multi_dataset['X_train'])}")
print(f"   - Validation samples: {len(multi_dataset['X_val'])}")
print(f"   - Test samples: {len(multi_dataset['X_test'])}")

print("\n3. MODEL TRAINING")
print("   - Architecture: CVAE with genome conditioning")
print(f"   - Curriculum phases: {phase_chain}")
print(f"   - Training time: {multi_result['elapsed_seconds']:.1f} seconds")

print("\n4. EVALUATION")
print(f"   - Test MSE: {mse:.6f}")
print(f"   - Test MAE: {mae:.6f}")
# Sum of the variance ratios of the two PCA components fitted on the latent codes.
print(f"   - Latent dim explained variance: {sum(pca.explained_variance_ratio_):.1%}")

print("\n5. MODEL OUTPUTS")
print(f"   - Trained model: {multi_result['save_path']}")
print("   - Can generate novel media for unseen organisms")
print("   - Can interpolate between organism types")

print("\n" + rule)
print("NEXT STEPS:")
print(rule)
print(f"1. Load model: model = ConditionalMediaVAE.load('{multi_result['save_path']}')")
print("2. Generate media: novel_media = model.generate(n_samples=10, condition=genome_emb)")
print("3. Wet-lab validation: Test predictions experimentally")
print("4. Iterative improvement: Add experimental data to training set")
print()