# Hierarchical Temporal Model Training

This notebook demonstrates training the Hierarchical Temporal Model, which combines hierarchical gene organization with temporal dynamics for comprehensive cellular modeling.

## Model Features
- **Hierarchical + Temporal**: Combines pathway organization with temporal kinetics
- **Pathway-temporal correlations**: Analyzes relationships between pathways and time
- **Compartment-temporal processing**: Handles spatiotemporal (compartment × time) dynamics
- **Combined attention**: Multi-level attention with temporal awareness


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import lightning as L
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import logging
from typing import Dict, List, Optional, Tuple

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Seed every RNG the notebook may touch so runs are reproducible.
# torch.manual_seed covers the torch generators (CPU and all CUDA devices);
# numpy's legacy global RNG and Python's stdlib `random` are seeded separately
# because neither is affected by torch's seed.
import random

random.seed(42)
torch.manual_seed(42)
np.random.seed(42)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Import enhanced models and loaders
from src.state.emb.nn.enhanced_models import HierarchicalTemporalModel
from src.state.emb.data.enhanced_loaders import HierarchicalTemporalLoader
from src.state.emb.data.loader import create_dataloader
from src.state.emb.nn.model import StateEmbeddingModel
from src.state import utils

# Only reached when every project import in this cell resolved without raising.
_import_status_message = "Enhanced models and loaders imported successfully!"
print(_import_status_message)


## Configuration Setup


In [None]:
# Model configuration for Hierarchical Temporal Model.
# A nested dict grouped by concern:
#   'model'                 - network hyperparameters (dims, heads, layers, LR schedule)
#   'dataset'               - data-loading sizes (pad_length, P, N, S)
#   'training'              - training-loop options (batch size, epochs, clipping, precision);
#                             key names match Lightning Trainer kwargs - presumably passed
#                             through to the Trainer later in the notebook (TODO confirm)
#   'hierarchical_temporal' - pathway / compartment / time-step settings and annotation files
config = {
    'model': {
        'token_dim': 512,
        'd_model': 512,
        'nhead': 8,
        'd_hid': 2048,
        'nlayers': 6,
        'output_dim': 128,
        'dropout': 0.1,
        'warmup_steps': 1000,
        'max_lr': 4e-4,
        'compiled': False,
        'dataset_correction': True,
        'counts': True,
        'rda': True
    },
    'dataset': {
        'pad_length': 512,
        'P': 50,
        'N': 50,
        'S': 20
    },
    'training': {
        'batch_size': 32,
        'max_epochs': 100,
        'gradient_clip_val': 1.0,
        'accumulate_grad_batches': 1,
        # fp16 AMP is only supported on CUDA; Lightning rejects
        # precision='16-mixed' on the CPU accelerator. Use the same
        # availability test as the device selection and fall back to full
        # fp32 so CPU-only runs still work.
        'precision': '16-mixed' if torch.cuda.is_available() else '32-true'
    },
    'hierarchical_temporal': {
        'num_pathways': 1000,
        'num_compartments': 5,
        'time_steps': 5,
        'pathway_annotation_file': 'data/pathways.txt',
        'compartment_annotation_file': 'data/compartments.txt'
    }
}

print("Configuration loaded:")
print(json.dumps(config, indent=2))
