In [None]:
import h5py
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from scipy.sparse import csc_matrix
import numpy as np
from tqdm import tqdm

In [None]:
def select_highly_variable_genes(sparse_matrix, n_top_genes=5000):
    """Return the column indices of the ``n_top_genes`` highest-variance genes.

    Variance is the population variance (``ddof=0``, the same as ``np.var``) of
    each column. For scipy sparse input it is computed as ``E[x^2] - E[x]^2``
    from the stored non-zeros, so memory stays proportional to the number of
    non-zeros instead of ``n_cells * n_genes`` (no ``.toarray()``).

    Parameters
    ----------
    sparse_matrix : scipy.sparse matrix/array or array-like, shape (n_cells, n_genes)
        Expression matrix with cells as rows and genes as columns. Dense
        array-likes are accepted as well.
    n_top_genes : int
        Number of genes to keep. Values larger than ``n_genes`` keep every gene;
        values ``<= 0`` keep none.

    Returns
    -------
    numpy.ndarray
        Selected column indices, sorted ascending for a stable gene order.
    """
    from scipy.sparse import issparse

    if n_top_genes <= 0:
        # argsort(...)[-0:] would select *every* gene, so handle this explicitly.
        return np.array([], dtype=np.intp)

    if issparse(sparse_matrix):
        # Work in float64 so squaring integer counts cannot overflow.
        values = sparse_matrix.astype(np.float64)
        mean = np.asarray(values.mean(axis=0)).ravel()
        mean_sq = np.asarray(values.multiply(values).mean(axis=0)).ravel()
        # Clip tiny negative results caused by floating-point cancellation.
        gene_variances = np.clip(mean_sq - mean ** 2, 0.0, None)
    else:
        gene_variances = np.asarray(np.var(np.asarray(sparse_matrix), axis=0)).ravel()

    # Indices of the n largest variances, then sorted for consistent ordering.
    top_gene_indices = np.argsort(gene_variances)[-n_top_genes:]
    top_gene_indices.sort()
    return top_gene_indices

In [None]:
def load_hdf5_in_chunks(h5_file_path, chunk_size=1000, n_top_genes=None):
    """Load an HDF5 expression matrix as a list of dense float32 tensor chunks.

    Reads the CSC components (``data``, ``indices``, ``indptr``, ``shape``) and
    ``barcodes`` from the file's ``matrix`` group, transposes the matrix so that
    rows line up with the returned barcodes, optionally keeps only the most
    variable columns, and densifies it ``chunk_size`` rows at a time.

    Parameters
    ----------
    h5_file_path : str or path-like
        Path of the HDF5 file.
    chunk_size : int
        Number of rows per returned tensor; must be positive.
    n_top_genes : int or None
        If given and smaller than the number of columns, keep only the
        ``n_top_genes`` highest-variance columns (via
        ``select_highly_variable_genes``).

    Returns
    -------
    chunks : list[torch.Tensor]
        float32 tensors of at most ``chunk_size`` rows each, in row order.
    cell_barcodes : list[str]
        Barcodes decoded as UTF-8.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        # A negative step makes ``range`` empty (silently returning no data) and
        # a zero step raises an opaque ``range()`` error; fail clearly instead.
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # Read the raw arrays eagerly and close the file before the CPU-heavy work.
    with h5py.File(h5_file_path, "r", libver='latest', swmr=True) as f:
        matrix_group = f["matrix"]
        data = matrix_group["data"][:]
        indices = matrix_group["indices"][:]
        indptr = matrix_group["indptr"][:]
        shape = tuple(matrix_group["shape"][:])
        cell_barcodes = [barcode.decode("utf-8") for barcode in matrix_group["barcodes"][:]]

    # Transposing the CSC matrix yields CSR, so the row slices below are cheap.
    # NOTE(review): assumes stored columns correspond to barcodes (genes x cells).
    sparse_matrix = csc_matrix((data, indices, indptr), shape=shape).T

    # Select top variable genes if specified
    if n_top_genes is not None and sparse_matrix.shape[1] > n_top_genes:
        print(f"Selecting {n_top_genes} most variable genes...")
        gene_indices = select_highly_variable_genes(sparse_matrix, n_top_genes)
        sparse_matrix = sparse_matrix[:, gene_indices]

    n_cells = sparse_matrix.shape[0]

    # Densify one chunk at a time to bound peak memory of the conversion.
    chunks = []
    for start_idx in tqdm(range(0, n_cells, chunk_size), desc="Loading chunks"):
        end_idx = min(start_idx + chunk_size, n_cells)
        chunk = sparse_matrix[start_idx:end_idx].toarray()
        chunks.append(torch.tensor(chunk, dtype=torch.float32))

    return chunks, cell_barcodes

In [None]:
class CachedGeneExpressionDataset(Dataset):
    """In-memory dataset of paired (V1, V2) expression profiles.

    V1 is loaded with all genes and concatenated into a single tensor; V2 is
    loaded with optional highly-variable-gene selection and kept as a list of
    ``chunk_size``-row tensors. Rows of the two files are paired by cell
    barcode, so item ``i`` returns the V1 and V2 profiles of the same barcode.
    When both files list identical barcodes in the same order, pairing is
    positional (row ``i`` with row ``i``).

    Parameters
    ----------
    v1_file_path, v2_file_path : str or path-like
        HDF5 files readable by ``load_hdf5_in_chunks``.
    chunk_size : int
        Rows per cached chunk.
    n_top_genes : int or None
        Gene selection applied to V2 only.

    Raises
    ------
    ValueError
        If the two files share no barcodes.
    """

    def __init__(self, v1_file_path, v2_file_path, chunk_size=1000, n_top_genes=5000):
        self.chunk_size = chunk_size

        # Load V1 data
        print("Loading V1 data...")
        v1_chunks, self.v1_barcodes = load_hdf5_in_chunks(v1_file_path, chunk_size)
        self.v1_data = torch.cat(v1_chunks)

        # Load V2 data with gene selection
        print("Loading V2 data...")
        self.v2_chunks, self.v2_barcodes = load_hdf5_in_chunks(
            v2_file_path,
            chunk_size=chunk_size,
            n_top_genes=n_top_genes
        )

        # Pair rows by barcode; indexing both files by the same position would
        # mispair cells (or overrun V2) whenever the files differ.
        self._v1_rows, self._v2_rows = self._pair_barcodes(self.v1_barcodes, self.v2_barcodes)
        self.total_cells = len(self._v1_rows)
        unmatched = len(self.v1_barcodes) - self.total_cells
        if unmatched:
            print(f"Skipping {unmatched} V1 cells without a matching V2 barcode")
        print(f"Total cells: {self.total_cells}")
        print(f"V1 genes: {self.v1_data.shape[1]}")
        print(f"V2 genes: {self.v2_chunks[0].shape[1]}")

    @staticmethod
    def _pair_barcodes(v1_barcodes, v2_barcodes):
        """Return parallel int64 arrays of (V1 row, V2 row) for barcodes in both lists."""
        if list(v1_barcodes) == list(v2_barcodes):
            rows = np.arange(len(v1_barcodes), dtype=np.int64)
            return rows, rows.copy()

        v2_row_of = {}
        for row, barcode in enumerate(v2_barcodes):
            v2_row_of.setdefault(barcode, row)  # first occurrence wins on duplicates

        v1_rows, v2_rows = [], []
        for row, barcode in enumerate(v1_barcodes):
            v2_row = v2_row_of.get(barcode)
            if v2_row is not None:
                v1_rows.append(row)
                v2_rows.append(v2_row)

        if not v1_rows:
            raise ValueError("V1 and V2 files share no cell barcodes; cannot pair cells")
        # numpy arrays instead of Python lists keep per-worker memory small.
        return np.asarray(v1_rows, dtype=np.int64), np.asarray(v2_rows, dtype=np.int64)

    def __len__(self):
        return self.total_cells

    def __getitem__(self, idx):
        v1_row = int(self._v1_rows[idx])
        v2_row = int(self._v2_rows[idx])
        # Every V2 chunk except possibly the last holds exactly chunk_size rows.
        chunk_idx, local_idx = divmod(v2_row, self.chunk_size)
        return self.v1_data[v1_row], self.v2_chunks[chunk_idx][local_idx]

In [None]:
class MultiEncoderAutoencoder(nn.Module):
    """Decode a V2 expression profile from V1 input, optionally fused with V2.

    Two encoders map V1 (``input_dim_v1`` features) and V2 (``input_dim_v2``
    features) inputs into a shared ``latent_dim`` space; one decoder maps the
    latent code to ``input_dim_v2`` non-negative outputs (final ReLU).

    Parameters
    ----------
    input_dim_v1 : int
        Number of V1 input features (default 541).
    input_dim_v2 : int
        Number of V2 features, i.e. encoder-V2 input and decoder output size
        (default 5000, matching the V2 gene selection).
    latent_dim : int
        Size of the shared latent code.

    Note: the BatchNorm1d layers need more than one sample per batch in
    training mode; call ``model.eval()`` before single-sample inference.
    """

    def __init__(self, input_dim_v1=541, input_dim_v2=5000, latent_dim=128):
        super().__init__()

        # Encoder for Xenium V1 (input_dim_v1 genes, 541 by default)
        self.encoder_v1 = nn.Sequential(
            nn.Linear(input_dim_v1, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, latent_dim),
            nn.ReLU()
        )

        # Encoder for Xenium V2 (input_dim_v2 genes, 5000 by default)
        self.encoder_v2 = nn.Sequential(
            nn.Linear(input_dim_v2, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, latent_dim),
            nn.ReLU()
        )

        # Decoder back to V2 space
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, input_dim_v2),
            nn.ReLU()  # ReLU since gene expression is non-negative
        )

    def forward(self, x_v1, x_v2=None, inference=False):
        """Return decoded V2-space output of shape ``(batch, input_dim_v2)``.

        If ``inference`` is True or ``x_v2`` is None, only the V1 encoder is
        used; otherwise the V1 and V2 latent codes are averaged before decoding.
        """
        latent = self.encoder_v1(x_v1)
        if not inference and x_v2 is not None:
            latent = (latent + self.encoder_v2(x_v2)) / 2
        return self.decoder(latent)

In [None]:
# Configuration shared by the cells below.
SEED = 42
N_TOP_GENES = 5000
BATCH_SIZE = 32

# Seed torch (weight init, DataLoader shuffling) and numpy for reproducible runs.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Initialize dataset with reduced dimensionality
print("Initializing dataset...")
dataset = CachedGeneExpressionDataset(
    "cell_feature_matrix_v1.h5",
    "cell_feature_matrix_v2.h5",
    chunk_size=2000,  # Reduced chunk size
    n_top_genes=N_TOP_GENES,  # Keep only the most variable V2 genes
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create dataloader
train_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only helps host-to-GPU copies; no benefit on CPU.
    pin_memory=(device.type == "cuda"),
    # BatchNorm1d cannot normalise a training batch of a single sample, so drop
    # the trailing batch only when it would contain exactly one cell.
    drop_last=(len(dataset) % BATCH_SIZE == 1),
)

# Size the model from the data actually loaded rather than hard-coded defaults:
# V1 may not have 541 genes, and V2 keeps fewer than N_TOP_GENES genes when the
# file has fewer genes than that.
model = MultiEncoderAutoencoder(
    input_dim_v1=dataset.v1_data.shape[1],
    input_dim_v2=dataset.v2_chunks[0].shape[1],
).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Training loop with progress bar
num_epochs = 100
# Weight of the V1-only reconstruction term. The fused path receives batch_v2
# as both input and target, so on its own it can be minimised by auto-encoding
# V2, leaving the V1-only path (model(..., inference=True)) without direct
# supervision. Set to 0.0 to recover the original fused-only objective.
V1_ONLY_LOSS_WEIGHT = 1.0
print(f"Training on {device}")

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    batch_count = 0

    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

    for batch_v1, batch_v2 in pbar:
        batch_v1 = batch_v1.to(device, non_blocking=True)
        batch_v2 = batch_v2.to(device, non_blocking=True)

        fused_output = model(batch_v1, batch_v2)
        loss = criterion(fused_output, batch_v2)
        if V1_ONLY_LOSS_WEIGHT:
            v1_only_output = model(batch_v1, inference=True)
            loss = loss + V1_ONLY_LOSS_WEIGHT * criterion(v1_only_output, batch_v2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        batch_count += 1
        # Format first: a format spec is not valid syntax inside a dict literal.
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if batch_count == 0:
        raise RuntimeError(
            "train_loader produced no batches; check dataset size, batch_size and drop_last"
        )
    avg_loss = total_loss / batch_count
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

    # Save checkpoint every 5 epochs
    if (epoch + 1) % 5 == 0:
        torch.save({
            'epoch': epoch,  # 0-based index; the file name uses the 1-based epoch
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, f'checkpoint_epoch_{epoch+1}.pth')