# IF2RNA with Real GeoMx Data

This notebook pairs real GeoMx gene expression data from GSE289483 (114 ROIs) with simulated 6-channel immunofluorescence (IF) images.

In [None]:
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# The project package lives in ``<parent of the working directory>/src``;
# put that directory at the front of the import path so ``if2rna`` resolves.
src_dir = Path.cwd().parent / 'src'
sys.path.insert(0, str(src_dir))

from if2rna.real_geomx_parser import RealGeoMxDataParser
from if2rna.simulated_if_generator import SimulatedIFGenerator
from if2rna.hybrid_dataset import (
    HybridIF2RNADataset,
    AggregatedIF2RNADataset,
    create_train_val_split,
)

# Reaching this line means every import above succeeded.
print("Imports loaded")

## 1. Load Real GeoMx Gene Expression Data

In [None]:
# Point the parser at the GSE289483 directory (path is relative to the
# notebook's working directory).
data_dir = Path("../data/geomx_datasets/GSE289483")
parser = RealGeoMxDataParser(data_dir)

# Fetch raw counts, processed expression and the metadata table, in that order.
raw_counts, processed_expr, metadata = (
    parser.load_raw_counts(),
    parser.load_processed_expression(),
    parser.create_metadata(),
)

# The summary reports raw_counts' first axis as genes and its second as
# samples, and the metadata length as the ROI count.
print(f"Data loaded: {raw_counts.shape[0]} genes, {raw_counts.shape[1]} samples, {len(metadata)} ROIs")

In [None]:
# Render the first rows of the metadata table.
print("Metadata preview:")
metadata_head = metadata.head()
display(metadata_head)

# Count how many entries fall into each value of the 'tissue_region' column.
print("Tissue regions:")
region_counts = metadata['tissue_region'].value_counts()
print(region_counts)

In [None]:
# Ask the parser for the combined expression + metadata bundle, using the
# processed expression values and requesting 1000 genes.
integration_options = {'use_processed': True, 'n_genes': 1000}
integrated_data = parser.get_integrated_data(**integration_options)

# Report the ROI/gene counts recorded in the bundle next to the actual shape
# of the expression table.
print(
    f"Data: {integrated_data['metadata']['n_rois']} ROIs, {integrated_data['metadata']['n_genes']} genes",
    f"Expression shape: {integrated_data['gene_expression'].shape}",
    sep="\n",
)

## 2. Visualize Real Gene Expression Data

In [None]:
# Three histograms side by side: all expression values, their log1p transform
# (strictly positive values only), and the per-column mean of the table.
expression_table = integrated_data['gene_expression']
expr_values = expression_table.to_numpy().flatten()
mean_expr = expression_table.mean(axis=0)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# One row per panel: (data, extra hist kwargs, x label, y label, title).
panel_specs = [
    (expr_values, {},
     'Expression Value', 'Frequency', 'Expression Distribution'),
    (np.log1p(expr_values[expr_values > 0]), {'color': 'orange'},
     'log(Expression + 1)', 'Frequency', 'Log-Transformed Expression'),
    (mean_expr, {'color': 'green'},
     'Mean Expression', 'Number of Genes', 'Mean Expression per Gene'),
]

for panel_ax, (panel_data, hist_extra, x_label, y_label, panel_title) in zip(axes, panel_specs):
    panel_ax.hist(panel_data, bins=50, alpha=0.7, edgecolor='black', **hist_extra)
    panel_ax.set_xlabel(x_label)
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)

plt.tight_layout()
plt.show()

# Summary statistics over every value in the flattened expression table.
print(f"Stats: mean={expr_values.mean():.2f}, median={np.median(expr_values):.2f}, std={expr_values.std():.2f}")

In [None]:
# Correlation matrix (sample subset)
#
# Plot the pairwise correlation between a random subset of columns of the
# expression table (columns are treated as genes here).
expr_df = integrated_data['gene_expression']

# Cap the subset at 50 columns, or fewer if the table is narrower.
n_sampled = min(50, expr_df.shape[1])

# Use a locally seeded generator so the heatmap is reproducible across runs
# and does not depend on (or disturb) NumPy's global random state.
rng = np.random.default_rng(42)
sample_genes = rng.choice(expr_df.shape[1], size=n_sampled, replace=False)

plt.figure(figsize=(10, 8))
corr_matrix = expr_df.iloc[:, sample_genes].corr()
sns.heatmap(corr_matrix, cmap='coolwarm', center=0, square=True,
            xticklabels=False, yticklabels=False, cbar_kws={'label': 'Correlation'})
# Report the number of genes actually drawn rather than a hard-coded 50.
plt.title(f'Gene-Gene Correlation ({n_sampled} random genes)')
plt.tight_layout()
plt.show()

## 3. Generate Simulated IF Images

In [None]:
# Instantiate the simulated-IF generator with an image size of 224 and a
# fixed seed of 42.
generator_config = dict(image_size=224, seed=42)
if_generator = SimulatedIFGenerator(**generator_config)

# Echo the image size and channel count the generator exposes.
print(f"Generator: {if_generator.image_size}x{if_generator.image_size}, {if_generator.n_channels} channels")

In [None]:
# Grid of simulated IF images: one row per tissue type, one column per channel.
tissue_types = ['Tumor', 'Immune_Aggregate', 'Stroma', 'Normal']

# squeeze=False keeps `axes` two-dimensional even if there is only a single
# row or column, so the axes[i, ch] indexing below is always valid.
fig, axes = plt.subplots(len(tissue_types), if_generator.n_channels,
                         figsize=(14, 8), squeeze=False)
fig.suptitle('Simulated IF Images by Tissue Type', fontsize=14, fontweight='bold')

for i, tissue in enumerate(tissue_types):
    # A different seed offset per row (0, 10, 20, ...) is passed to the generator.
    img = if_generator.generate_for_tissue_type(tissue, seed_offset=i*10)

    for ch, channel_name in enumerate(if_generator.channel_names):
        ax = axes[i, ch]
        # img is indexed by channel first; the colour scale is pinned to [0, 1].
        ax.imshow(img[ch], cmap='hot', vmin=0, vmax=1)

        # Channel names label the top row, tissue names label the first column.
        if i == 0:
            ax.set_title(channel_name, fontsize=10)
        if ch == 0:
            ax.set_ylabel(tissue, fontsize=10, fontweight='bold')

        # Do NOT use ax.axis('off') here: it hides the axis labels as well, so
        # the tissue row labels set above would never be drawn. Hide only the
        # ticks and the frame instead.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

plt.tight_layout()
plt.show()

## 4. Create Hybrid Dataset (Real Expression + Simulated IF)

In [None]:
# Tile-level dataset built from the integrated expression bundle and the
# simulated-IF generator, with 16 tiles requested per ROI.
tile_dataset_kwargs = {
    'integrated_data': integrated_data,
    'if_generator': if_generator,
    'n_tiles_per_roi': 16,
}
tile_dataset = HybridIF2RNADataset(**tile_dataset_kwargs)

# Report the dataset length alongside the ROI and tiles-per-ROI counts it exposes.
print(f"Tile dataset: {len(tile_dataset)} samples, {tile_dataset.n_rois} ROIs, {tile_dataset.n_tiles_per_roi} tiles/ROI")

In [None]:
# Fetch the first item of the tile dataset; it is accessed like a dict with
# 'image', 'expression', 'roi_name' and 'tissue_type' entries.
first_index = 0
sample = tile_dataset[first_index]

# Show the shapes of the image/expression entries and where the tile came from.
print(
    f"Sample: image {sample['image'].shape}, expression {sample['expression'].shape}",
    f"ROI: {sample['roi_name']}, tissue: {sample['tissue_type']}",
    sep="\n",
)

In [None]:
# Visualize a sample
#
# Figure 1: one panel per IF channel of `sample['image']`.
# Figure 2: histogram of the sample's expression vector and a bar chart of
#           its highest-valued entries.

# Derive the grid from the channel count instead of hard-coding 2x3, so the
# cell keeps working if the generator exposes more or fewer than 6 channels.
channel_names = list(if_generator.channel_names)
n_cols = 3
n_rows = max(1, -(-len(channel_names) // n_cols))  # ceiling division

# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows), squeeze=False)
fig.suptitle(f"Sample from {sample['tissue_type']} region", fontsize=14, fontweight='bold')

sample_image = sample['image']
for i, channel_name in enumerate(channel_names):
    ax = axes[i // n_cols, i % n_cols]
    ax.imshow(sample_image[i].numpy(), cmap='hot')
    ax.set_title(channel_name)
    ax.axis('off')

# Blank any grid cells that have no channel to show.
for unused_ax in axes.ravel()[len(channel_names):]:
    unused_ax.axis('off')

plt.tight_layout()
plt.show()

# Plot expression distribution for this sample
# Convert the expression entry to NumPy once and reuse it below.
expression = sample['expression'].numpy()

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.hist(expression, bins=50, alpha=0.7, edgecolor='black')
plt.xlabel('Expression Value')
plt.ylabel('Frequency')
plt.title('Expression Distribution for this ROI')

plt.subplot(1, 2, 2)
# Show at most 20 entries; capping at the vector length avoids a shape
# mismatch between the bar positions and heights when fewer genes exist.
n_top = min(20, expression.size)
# argsort is ascending: take the last n_top indices and reverse them so the
# largest value comes first.
top_genes = expression.argsort()[-n_top:][::-1]
plt.bar(range(n_top), expression[top_genes])
plt.xlabel('Gene Rank')
plt.ylabel('Expression')
plt.title(f'Top {n_top} Expressed Genes')

plt.tight_layout()
plt.show()

## 5. Create Aggregated Dataset (Multiple Tiles per ROI)

In [None]:
# ROI-level dataset built from the same expression bundle and generator as the
# tile dataset, again with 16 tiles per ROI.
agg_dataset = AggregatedIF2RNADataset(
    n_tiles_per_roi=16,
    if_generator=if_generator,
    integrated_data=integrated_data,
)

# len() is reported as the number of ROIs for this dataset.
print(f"Aggregated dataset: {len(agg_dataset)} ROIs, {agg_dataset.n_tiles_per_roi} tiles/ROI")

In [None]:
# Fetch the first aggregated item and show the shape of its 'tiles' entry
# together with its tissue type.
first_roi_index = 0
agg_sample = agg_dataset[first_roi_index]

print(f"Sample: {agg_sample['tiles'].shape}, {agg_sample['tissue_type']}")

## 6. Train/Val Split

In [None]:
# Hold out 20% of the tile dataset for validation, with a fixed seed.
# NOTE(review): the split is applied to the tile-level dataset. Confirm that
# create_train_val_split keeps all tiles of one ROI on the same side;
# otherwise tiles sharing an expression target could land in both splits.
split_options = {'val_fraction': 0.2, 'seed': 42}
train_dataset, val_dataset = create_train_val_split(tile_dataset, **split_options)

print(f"Split: {len(train_dataset)} train, {len(val_dataset)} val")

## 7. Create DataLoaders for Training

In [None]:
import torch
from torch.utils.data import DataLoader

# Settings shared by both loaders: batches of 32, loaded in the main process
# (num_workers=0).
loader_options = {'batch_size': 32, 'num_workers': 0}

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

# len() of a DataLoader is its number of batches.
print(f"DataLoaders: {len(train_loader)} train batches, {len(val_loader)} val batches")

In [None]:
# Smoke test: draw a single batch from the training loader and show the
# shapes of its 'image' and 'expression' entries.
train_iterator = iter(train_loader)
batch = next(train_iterator)

print(f"Batch: images {batch['image'].shape}, expression {batch['expression'].shape}")

## 8. Summary

Real gene expression (GSE289483; 114 ROIs, 1000 genes) is now paired with simulated 6-channel IF images (224x224), and the resulting dataset is ready for training.

In [None]:
# Final recap: expression bundle size, simulated image layout, and split sizes.
summary_lines = (
    f"Real expression: {integrated_data['metadata']['n_rois']} ROIs, {integrated_data['metadata']['n_genes']} genes",
    f"Simulated IF: {if_generator.n_channels} channels, {if_generator.image_size}x{if_generator.image_size}",
    f"Train: {len(train_dataset)}, Val: {len(val_dataset)}",
)
for summary_line in summary_lines:
    print(summary_line)