In [None]:
# RNA to H&E Cell Image Generator with Rectified Flow
# ===================================================

# Cell 1: Import libraries
import os
import sys
import json
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import logging
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt

# Configure logging: INFO level, each record prefixed with timestamp,
# logger name and severity.
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Cell 2: Make the custom modules importable
# The *parent* of the current working directory is appended to sys.path
# (once). This presumably expects the notebook to run from a sub-folder of
# the project root; adjust the path if your layout differs.
notebook_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if notebook_dir not in sys.path:
    sys.path.append(notebook_dir)

# Import our modules - these must be importable from the path added above
from dataset import CellImageGeneDataset, PatchImageGeneDataset
from single_model import RNAtoHnEModel
from multi_model import MultiCellRNAtoHnEModel
from rectified_flow import RectifiedFlow, EulerSolver
from train import train_with_rectified_flow, generate_images_with_rectified_flow
from utils import setup_parser, parse_adata, analyze_gene_importance

# Cell 3: Set parameters (replacing command-line arguments)
# Configuration parameters
# All file paths below are relative, so they resolve against the current
# working directory of the notebook kernel.
config = {
    # CSV with gene expression; read with the first column as the index
    # (only used when 'adata' is None)
    'gene_expr': "cell_256_aux/normalized.csv",
    # JSON file loaded in Cell 6 (cell image paths)
    'image_paths': "cell_256_aux/input/cell_image_paths.json",
    # JSON files loaded in Cell 7 and handed to PatchImageGeneDataset
    'patch_image_paths': "cell_256_aux/input/patch_image_paths.json",
    'patch_cell_mapping': "cell_256_aux/input/patch_cell_mapping.json",
    # Directory receiving the plots and the final checkpoint
    'output_dir': 'cell_256_aux/output_rectified',
    # Passed to train_with_rectified_flow as `epochs`
    'epochs': 10,
    # Batch size of both the training and validation DataLoader
    'batch_size': 6,
    # AdamW learning rate and weight decay
    'lr': 1e-4,
    'weight_decay': 0.01,
    # Passed to the dataset, the model and the image generation call
    'img_size': 256,
    # NOTE(review): the visualisation cell only renders RGB when this is 3;
    # for any other value it shows channel 0 in grayscale
    'img_channels': 4,
    # Passed to train_with_rectified_flow as `use_amp` / `patience`
    'use_amp': False,
    'patience': 5,
    # Number of steps given to EulerSolver
    'gen_steps': 100,
    # Drawn at random from [0, 100) on every run, then used to seed
    # torch and NumPy in Cell 4
    'seed': np.random.randint(100),
    'adata': None  # Set this to your AnnData path if you're using AnnData
}

# Create output directory
os.makedirs(config['output_dir'], exist_ok=True)

# Cell 4: Set device and random seed
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Seed the global PyTorch and NumPy generators for reproducibility.
# config['seed'] is drawn at random on every run and is otherwise only
# persisted in the final checkpoint, so log it here: a run that aborts
# before the checkpoint is written can then still be reproduced.
torch.manual_seed(config['seed'])
np.random.seed(config['seed'])
logger.info(f"Using random seed: {config['seed']}")

# Cell 5: Load gene expression data
# Two sources are supported: an AnnData file (when config['adata'] is set)
# or a CSV whose first column is used as the row index.
missing_gene_symbols = None
if config['adata'] is not None:
    logger.info(f"Loading AnnData from {config['adata']}")
    # NOTE(review): parse_adata receives the whole config dict here;
    # confirm this matches the arguments the original function expects.
    expr_df, missing_gene_symbols = parse_adata(config)
else:
    # Routine progress message: log at INFO like every other loading step
    # (it was previously emitted as a WARNING).
    logger.info(f"Loading gene expression data from {config['gene_expr']}")
    expr_df = pd.read_csv(config['gene_expr'], index_col=0)
logger.info(f"Loaded gene expression data with shape: {expr_df.shape}")
# The column labels of the expression table are taken as the gene names
gene_names = expr_df.columns.tolist()

# Cell 6: Load image paths
# NOTE(review): `image_paths` is not referenced again in the visible part of
# this notebook (the dataset below is built from the patch files) - confirm
# whether this cell is still needed.
logger.info(f"Loading image paths from {config['image_paths']}")
with open(config['image_paths'], "r") as f:
    image_paths = json.load(f)
logger.info(f"Loaded {len(image_paths)} cell image paths")

# Keep only the entries whose file is actually present on disk
image_paths = {
    key: path
    for key, path in image_paths.items()
    if os.path.exists(path)
}
logger.info(f"Found {len(image_paths)} existing cell images")

# Cell 7: Create dataset
# Load the patch -> cells mapping (JSON)
with open(config['patch_cell_mapping'], "r") as f:
    patch_cell_mappings = json.load(f)

# Load the patch image paths (JSON)
with open(config['patch_image_paths'], "r") as f:
    patch_image_paths = json.load(f)

# Create PatchImageGeneDataset from the expression table and the two
# patch-level JSON structures loaded above
dataset = PatchImageGeneDataset(
    expr_df=expr_df,
    patch_image_paths=patch_image_paths,  # presumably patch_id -> image path; TODO confirm
    patch_to_cells=patch_cell_mappings,  # presumably patch_id -> cells in that patch; TODO confirm
    img_size=config['img_size'],
)

# Cell 8: Split the dataset 80/20 and wrap both parts in DataLoaders
n_total = len(dataset)
train_size = int(0.8 * n_total)
val_size = n_total - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Settings shared by both loaders; memory is pinned only when CUDA is available
loader_kwargs = dict(
    batch_size=config['batch_size'],
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

# Only the training loader shuffles its samples
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

logger.info(f"Created DataLoader with {len(train_loader)} training batches and {len(val_loader)} validation batches")

# Cell 9: Define model
# Define input and output dimensions
input_dim = len(gene_names)  # Number of genes
# Flattened image size: channels * height * width
output_dim = config['img_channels'] * config['img_size'] * config['img_size']  # Image dimensions

# Create the model and move it to the selected device
# NOTE(review): the output recorded with this notebook ends in
# "TypeError: RNAtoHnEModel.__init__() got an unexpected keyword argument
# 'input_dim'", which presumably originates from this call - the keyword
# names used here likely do not match the constructor in multi_model.py /
# single_model.py. Confirm the real signature before running.
model = MultiCellRNAtoHnEModel(
    input_dim=input_dim,
    output_dim=output_dim,
    hidden_dims=[512, 256, 128],
    img_channels=config['img_channels'],
    img_size=config['img_size'],
).to(device)

logger.info(f"Created model with input dimension {input_dim} and output dimension {output_dim}")

# Cell 10: Setup Rectified Flow and optimizer
# The flow wraps the model; the Euler solver uses config['gen_steps'] steps
flow = RectifiedFlow(model=model)
solver = EulerSolver(flow=flow, steps=config['gen_steps'])

# AdamW over all model parameters, hyper-parameters taken from the config
adamw_settings = {'lr': config['lr'], 'weight_decay': config['weight_decay']}
optimizer = torch.optim.AdamW(model.parameters(), **adamw_settings)

# Cell 11: Train the model
# Run rectified-flow training; the call returns the training and validation
# loss histories that are plotted in the next cell.
train_losses, val_losses = train_with_rectified_flow(
    # model-side objects
    flow=flow, solver=solver, optimizer=optimizer,
    # data
    train_loader=train_loader, val_loader=val_loader,
    # run settings
    device=device,
    epochs=config['epochs'],
    patience=config['patience'],
    use_amp=config['use_amp'],
    output_dir=config['output_dir'],
)

# Cell 12: Plot training and validation losses
# Both curves share one axes; the figure is saved to the output directory
# before being displayed.
loss_fig, loss_ax = plt.subplots(figsize=(10, 6))
loss_ax.plot(train_losses, label='Training Loss')
loss_ax.plot(val_losses, label='Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.set_title('Training and Validation Losses')
loss_fig.savefig(os.path.join(config['output_dir'], 'loss_curve.png'))
plt.show()

# Cell 13: Generate images with trained model
# Select a few validation samples for generation. Sampling without
# replacement raises a ValueError when more items are requested than exist,
# so clamp the count to the size of the validation set.
num_samples = min(5, len(val_dataset))
sample_indices = np.random.choice(len(val_dataset), num_samples, replace=False)
sample_data = [val_dataset[i] for i in sample_indices]

# Generate one image per selected sample with the trained flow and solver
generated_images = generate_images_with_rectified_flow(
    flow=flow,
    solver=solver,
    samples=sample_data,
    device=device,
    img_size=config['img_size'],
    img_channels=config['img_channels'],
)

# Cell 14: Visualize generated images
# One row per sample: original on the left, generated on the right.
# squeeze=False keeps `axes` two-dimensional even for a single row, so the
# axes[i, col] indexing below works for any number of samples.
fig, axes = plt.subplots(num_samples, 2, figsize=(10, 3 * num_samples), squeeze=False)

for i, (sample, gen_img) in enumerate(zip(sample_data, generated_images)):
    panels = (
        ("Original Image", sample['image']),
        ("Generated Image", gen_img),
    )
    for col, (label, img) in enumerate(panels):
        ax = axes[i, col]
        if config['img_channels'] == 3:
            # Move the channel axis last, as imshow expects for colour images
            ax.imshow(img.permute(1, 2, 0).cpu().numpy())
        else:
            # Any other channel count: only channel 0 is shown, in grayscale
            ax.imshow(img[0].cpu().numpy(), cmap='gray')
        ax.set_title(f"{label} {i+1}")
        ax.axis('off')

plt.tight_layout()
plt.savefig(os.path.join(config['output_dir'], 'sample_generated_images.png'))
plt.show()

# Cell 15: Analyze gene importance
# Score every gene with the project helper, then chart the 20 highest scores
importance_scores = analyze_gene_importance(model, gene_names, device)

# argsort is ascending, so the last 20 positions hold the largest scores
top_indices = np.argsort(importance_scores)[-20:]
top_gene_labels = np.array(gene_names)[top_indices]
top_gene_scores = importance_scores[top_indices]

# Horizontal bar chart of the selected genes
plt.figure(figsize=(12, 6))
plt.barh(top_gene_labels, top_gene_scores)
plt.xlabel('Importance Score')
plt.ylabel('Gene Name')
plt.title('Top 20 Most Important Genes for Image Generation')
plt.tight_layout()
plt.savefig(os.path.join(config['output_dir'], 'gene_importance.png'))
plt.show()

# Cell 16: Save the model
# Bundle the weights, the optimizer state, the loss histories, the run
# configuration and the gene names into a single checkpoint file.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_losses': train_losses,
    'val_losses': val_losses,
    'config': config,
    'gene_names': gene_names,
}
model_save_path = os.path.join(config['output_dir'], 'rna_to_image_model.pt')
torch.save(checkpoint, model_save_path)

logger.info(f"Model saved to {model_save_path}")

2025-05-05 16:06:39,699 - __main__ - INFO - Using device: cuda
2025-05-05 16:06:40,479 - __main__ - INFO - Loaded gene expression data with shape: (87499, 382)
2025-05-05 16:06:40,480 - __main__ - INFO - Loading image paths from cell_256_aux/input/cell_image_paths.json
2025-05-05 16:06:40,481 - __main__ - INFO - Loaded 1483 cell image paths
2025-05-05 16:06:40,482 - __main__ - INFO - Found 1483 existing cell images
2025-05-05 16:06:40,528 - dataset - INFO - Dataset contains 21805 valid patches
2025-05-05 16:06:40,529 - dataset - INFO - Total number of cells across all patches: 412871
2025-05-05 16:06:40,531 - __main__ - INFO - Created DataLoader with 2908 training batches and 727 validation batches


TypeError: RNAtoHnEModel.__init__() got an unexpected keyword argument 'input_dim'