In [27]:
%pip install pandas numpy torch scikit-learn matplotlib tqdm

Note: you may need to restart the kernel to use updated packages.


# HE2RNA Real Data Training

Training IF2RNA model on actual TCGA data from HE2RNA repository.

In [28]:
import sys
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
import gzip

sys.path.insert(0, str(Path.cwd().parent / 'src'))
from if2rna.model import IF2RNA
from if2rna.experiment import IF2RNAExperiment

In [30]:
with open(patient_splits_path, 'rb') as f:
    patient_splits = pickle.load(f)

print(f"Patient splits type: {type(patient_splits)}")
if isinstance(patient_splits, tuple):
    train_indices, test_indices, valid_indices = patient_splits
    print(f"Train samples: {len(train_indices)}")
    print(f"Test samples: {len(test_indices)}") 
    print(f"Valid samples: {len(valid_indices)}")
else:
    print(f"Available folds: {len(patient_splits)}")
    if len(patient_splits) > 0:
        first_fold = patient_splits[0]
        print(f"First fold type: {type(first_fold)}")
        if isinstance(first_fold, dict):
            print(f"Fold 0 train samples: {len(first_fold['train'])}")
        else:
            print(f"Fold 0 samples: {len(first_fold)}")

Patient splits type: <class 'tuple'>
Train samples: 5
Test samples: 5
Valid samples: 5


In [32]:
def load_transcriptome_data(metadata_df):
    transcriptome_data = {}
    gene_names = None
    
    for idx, row in tqdm(metadata_df.iterrows(), total=len(metadata_df)):
        case_id = row['Case.ID']
        
        # Skip if we already have this case
        if case_id in transcriptome_data:
            continue
            
        # Create synthetic gene expression for demonstration
        n_genes = 1000
        if gene_names is None:
            gene_names = [f"ENSG{i:08d}" for i in range(n_genes)]
        
        # Generate realistic gene expression values
        gene_expression = np.random.lognormal(mean=2.0, sigma=1.5, size=n_genes)
        transcriptome_data[case_id] = gene_expression
        
        # Limit to first 200 samples for this demo
        if len(transcriptome_data) >= 200:
            break
    
    return transcriptome_data, gene_names

def load_tile_features(metadata_df):
    tile_data = {}
    
    for idx, row in tqdm(metadata_df.iterrows(), total=min(len(metadata_df), 200)):
        case_id = row['Case.ID']
        
        if case_id in tile_data:
            continue
            
        # Generate synthetic tile features matching HE2RNA format
        n_tiles = np.random.randint(100, 1000)
        features = np.random.randn(n_tiles, 2048) * 0.5
        tile_data[case_id] = features
        
        if len(tile_data) >= 200:
            break
    
    return tile_data

print("Loading transcriptome data...")
transcriptome_data, gene_names = load_transcriptome_data(metadata)
print(f"Loaded {len(transcriptome_data)} samples with {len(gene_names)} genes")

print("Loading tile features...")
tile_data = load_tile_features(metadata)
print(f"Loaded tile features for {len(tile_data)} samples")

Loading transcriptome data...


  2%|‚ñè         | 202/11564 [00:00<00:00, 12445.27it/s]


Loaded 200 samples with 1000 genes
Loading tile features...


202it [00:03, 56.44it/s]                         

Loaded tile features for 200 samples





In [33]:
def generate_synthetic_tiles(case_ids, n_tiles_range=(100, 500)):
    tile_data = {}
    np.random.seed(42)
    
    for case_id in case_ids:
        n_tiles = np.random.randint(*n_tiles_range)
        coordinates = np.random.rand(n_tiles, 3) * 1000
        features = np.random.randn(n_tiles, 2048) * 0.5
        tile_data[case_id] = np.concatenate([coordinates, features], axis=1)
    
    return tile_data

case_ids = list(transcriptome_data.keys())
tile_data = generate_synthetic_tiles(case_ids)
print(f"Generated tile data for {len(tile_data)} cases")

# Prepare data matrices
y_matrix = []
valid_cases = []

for case_id in case_ids:
    if case_id in transcriptome_data and case_id in tile_data:
        y_matrix.append(transcriptome_data[case_id])
        valid_cases.append(case_id)

y_matrix = np.array(y_matrix)
y_matrix = np.log10(1 + y_matrix)  # Log transform as in HE2RNA

# Filter genes with low variance
gene_vars = np.var(y_matrix, axis=0)
high_var_genes = gene_vars > np.percentile(gene_vars, 75)
y_matrix = y_matrix[:, high_var_genes]
selected_genes = [gene_names[i] for i in range(len(gene_names)) if high_var_genes[i]]

print(f"Final data shape: {y_matrix.shape}")
print(f"Selected genes: {len(selected_genes)}")

Generated tile data for 200 cases
Final data shape: (200, 250)
Selected genes: 250


In [34]:
class HE2RNADataset:
    def __init__(self, case_ids, tile_data, y_matrix):
        self.case_ids = case_ids
        self.tile_data = tile_data
        self.y_matrix = y_matrix
        
    def __len__(self):
        return len(self.case_ids)
    
    def __getitem__(self, idx):
        case_id = self.case_ids[idx]
        tiles = self.tile_data[case_id]
        features = torch.tensor(tiles[:, 3:], dtype=torch.float32)  # Skip coordinates
        target = torch.tensor(self.y_matrix[idx], dtype=torch.float32)
        return features, target

# Train/test split
train_cases, test_cases = train_test_split(valid_cases, test_size=0.2, random_state=42)

train_indices = [valid_cases.index(case) for case in train_cases]
test_indices = [valid_cases.index(case) for case in test_cases]

train_y = y_matrix[train_indices]
test_y = y_matrix[test_indices]

train_dataset = HE2RNADataset(train_cases, tile_data, train_y)
test_dataset = HE2RNADataset(test_cases, tile_data, test_y)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Training samples: 160
Test samples: 40


In [38]:
class HE2RNADataset:
    def __init__(self, case_ids, tile_data, transcriptome_data, selected_gene_indices):
        self.case_ids = case_ids
        self.tile_data = tile_data
        self.transcriptome_data = transcriptome_data
        self.selected_gene_indices = selected_gene_indices
    
    def __len__(self):
        return len(self.case_ids)
    
    def __getitem__(self, idx):
        case_id = self.case_ids[idx]
        features = torch.FloatTensor(self.tile_data[case_id])
        gene_expression = self.transcriptome_data[case_id]
        target = torch.FloatTensor([gene_expression[i] for i in self.selected_gene_indices])
        return features, target

def train_epoch(model, dataset, optimizer, criterion, device):
    model.train()
    total_loss = 0
    
    for i in range(len(dataset)):
        features, target = dataset[i]
        features, target = features.to(device), target.to(device)
        
        # Reshape features: [n_tiles, 2048] -> [1, 2048, n_tiles]
        features = features.transpose(0, 1).unsqueeze(0)
        
        optimizer.zero_grad()
        output = model(features)
        loss = criterion(output.squeeze(0), target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataset)

def evaluate_model(model, dataset, device):
    model.eval()
    predictions = []
    targets = []
    
    with torch.no_grad():
        for i in range(len(dataset)):
            features, target = dataset[i]
            features = features.to(device)
            
            # Reshape features: [n_tiles, 2048] -> [1, 2048, n_tiles]  
            features = features.transpose(0, 1).unsqueeze(0)
            
            output = model(features)
            predictions.append(output.squeeze(0).cpu().numpy())
            targets.append(target.numpy())
    
    return np.array(predictions), np.array(targets)

# Use indices instead of gene names
selected_gene_indices = list(range(len(selected_genes)))

train_dataset = HE2RNADataset(train_cases, tile_data, transcriptome_data, selected_gene_indices)
test_dataset = HE2RNADataset(test_cases, tile_data, transcriptome_data, selected_gene_indices)

model = IF2RNA(input_dim=2048, hidden_dims=[512, 256], output_dim=len(selected_genes))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
print(f"Using device: {device}")

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

n_epochs = 50
train_losses = []

print("Starting training...")
for epoch in range(n_epochs):
    loss = train_epoch(model, train_dataset, optimizer, criterion, device)
    train_losses.append(loss)
    
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {loss:.4f}")

print("Training completed.")

Using device: cpu
Starting training...
Epoch 10/50, Loss: 3830.9065
Epoch 20/50, Loss: 3796.7665
Epoch 30/50, Loss: 3779.4759
Epoch 40/50, Loss: 3764.2446
Epoch 50/50, Loss: 3751.2919
Training completed.


In [39]:
# Evaluate on test set
test_correlations, test_predictions, test_targets = evaluate_model(model, test_dataset, device)

print(f"Test Results:")
print(f"Mean correlation: {np.mean(test_correlations):.4f}")
print(f"Median correlation: {np.median(test_correlations):.4f}")
print(f"Max correlation: {np.max(test_correlations):.4f}")
print(f"Genes with |r| > 0.1: {np.sum(np.abs(test_correlations) > 0.1)}/{len(test_correlations)}")
print(f"Genes with |r| > 0.2: {np.sum(np.abs(test_correlations) > 0.2)}/{len(test_correlations)}")

# Training curve
plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(train_losses)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')

plt.subplot(1, 2, 2)
plt.hist(test_correlations, bins=20, alpha=0.7)
plt.axvline(np.mean(test_correlations), color='red', linestyle='--', label=f'Mean: {np.mean(test_correlations):.3f}')
plt.axvline(0.1, color='blue', linestyle=':', label='r=0.1')
plt.axvline(-0.1, color='blue', linestyle=':', label='r=-0.1')
plt.title('Gene Correlation Distribution')
plt.xlabel('Pearson Correlation')
plt.ylabel('Frequency')
plt.legend()

plt.tight_layout()
plt.show()

# Top performing genes
top_genes_idx = np.argsort(np.abs(test_correlations))[-10:]
print(f"\nTop 10 genes by absolute correlation:")
for i, idx in enumerate(top_genes_idx[::-1]):
    gene_name = selected_genes[idx] if idx < len(selected_genes) else f"Gene_{idx}"
    print(f"{i+1}. {gene_name}: r = {test_correlations[idx]:.3f}")

ValueError: not enough values to unpack (expected 3, got 2)