# HE2RNA Validation Test

In [None]:
import sys
import os
import configparser
import pickle as pkl
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Subset
from pathlib import Path

sys.path.append('../src')
from if2rna.model import IF2RNA, fit, evaluate, predict
from if2rna.data import create_synthetic_data, IF2RNADataset

# Location of the vendored HE2RNA reference code and the config file whose
# architecture/training hyper-parameters this notebook reuses.
he2rna_path = Path('..', 'external', 'HE2RNA_code')
config_path = he2rna_path.joinpath('configs', 'config_all_genes.ini')

In [None]:
# Read the HE2RNA config so the model below is built with the same
# architecture hyper-parameters as the reference configuration.
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed. Fail loudly here instead of with an opaque
# KeyError on the section lookups below.
if not config.read(config_path):
    raise FileNotFoundError(f"HE2RNA config not found or unreadable: {config_path}")

architecture_cfg = config['architecture']
training_cfg = config['training']

# Both entries are comma-separated integer lists; int() tolerates surrounding spaces.
layers = [int(x) for x in architecture_cfg['layers'].split(',')]
ks = [int(x) for x in architecture_cfg['ks'].split(',')]
dropout = architecture_cfg.getfloat('dropout')
batch_size = training_cfg.getint('batch_size')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Config loaded: layers={layers}, ks={ks}, dropout={dropout}")
print(f"Device: {device}")

In [None]:
# Synthetic problem size: slides, genes (regression targets), tile feature width.
input_dim = 2048
n_genes = 1000
n_samples = 500

X, y, patients, projects = create_synthetic_data(
    n_samples=n_samples, n_tiles=100, feature_dim=input_dim, n_genes=n_genes,
)

# Placeholder Ensembl-style identifiers, one per synthetic gene column.
genes = [f"ENSG{gene_idx:08d}" for gene_idx in range(n_genes)]
dataset = IF2RNADataset(genes, patients, projects, X, y)

# 70/15/15 split; the test split absorbs any rounding remainder.
n_total = len(dataset)
train_size = int(0.7 * n_total)
val_size = int(0.15 * n_total)
test_size = n_total - (train_size + val_size)

# Fixed seed so the split is reproducible across runs.
split_rng = torch.Generator().manual_seed(42)
train_set, val_set, test_set = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=split_rng
)

print(f"Dataset sizes: train={len(train_set)}, val={len(val_set)}, test={len(test_set)}")

In [None]:
# Build the IF2RNA model with the architecture read from the HE2RNA config.
model = IF2RNA(
    input_dim=input_dim, output_dim=n_genes,
    layers=layers, ks=ks, dropout=dropout,
    device=device,
)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Total element count over every parameter tensor the model exposes.
n_model_params = sum(param.numel() for param in model.parameters())
print(f"Model initialized with {n_model_params} parameters")

In [None]:
# Project label for each validation sample, in val_set's order
# (random_split returns Subsets whose .indices point into the full dataset).
val_projects = np.array([projects[idx] for idx in val_set.indices])

# Short training budget for this validation run.
training_params = dict(
    max_epochs=10,
    patience=5,
    batch_size=batch_size,
    num_workers=0,
)

# NOTE(review): presumably `fit` returns predictions and ground truth for the
# held-out evaluation set (test_set is passed in) — confirm in if2rna.model.
preds, labels = fit(
    model=model,
    train_set=train_set,
    valid_set=val_set,
    valid_projects=val_projects,
    params=training_params,
    optimizer=optimizer,
    test_set=test_set,
    logdir='./logs_validation',
)

In [None]:
from scipy.stats import pearsonr


def _safe_pearson(x, y):
    """Return the Pearson correlation of ``x`` and ``y``, or 0.0 when undefined.

    Pearson's r is undefined when either input is constant or has fewer than
    two points. In those cases scipy warns and returns NaN. Short-circuit them
    so a degenerate model (e.g. constant predictions) does not emit one warning
    per gene, and map any remaining NaN to 0.0.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.size < 2 or np.unique(x).size < 2 or np.unique(y).size < 2:
        return 0.0
    corr, _ = pearsonr(x, y)
    return 0.0 if np.isnan(corr) else float(corr)


# One correlation per gene column. Assumes labels/preds are 2-D arrays of the
# same shape with genes along axis 1 — TODO confirm against what `fit` returns.
gene_correlations = np.array(
    [_safe_pearson(true_col, pred_col) for true_col, pred_col in zip(labels.T, preds.T)]
)
# Pooled correlation over every (sample, gene) entry; 0.0 if undefined.
overall_corr = _safe_pearson(labels.flatten(), preds.flatten())

print(f"Overall correlation: {overall_corr:.4f}")
print(f"Mean gene correlation: {np.mean(gene_correlations):.4f}")
print(f"Median gene correlation: {np.median(gene_correlations):.4f}")
print(f"Max gene correlation: {np.max(gene_correlations):.4f}")
print(f"Significant genes (|r| > 0.1): {np.sum(np.abs(gene_correlations) > 0.1)}/{len(gene_correlations)}")

In [None]:
# Per-gene correlation table, strongest |r| first.
results_df = (
    pd.DataFrame({'gene_id': genes, 'correlation': gene_correlations})
    .assign(abs_correlation=lambda frame: frame['correlation'].abs())
    .sort_values('abs_correlation', ascending=False)
)

top_genes = results_df.head(20)
print("Top 20 genes:")
print(top_genes.loc[:, ['gene_id', 'correlation']].to_string(index=False))

# Persist the table plus the raw prediction/label arrays for later inspection.
results_df.to_csv('he2rna_validation_results.csv', index=False)
for out_name, out_array in (
    ('he2rna_validation_predictions.npy', preds),
    ('he2rna_validation_labels.npy', labels),
):
    np.save(out_name, out_array)

## Paper Comparison

In [None]:
# Reference numbers quoted from the HE2RNA paper, keyed by cohort (TCGA-style
# project codes). 'correlation_threshold' is the per-cohort R cut-off quoted
# alongside the significant-gene count (None where none is recorded here).
paper_results = {
    'BRCA': {'samples': 1085, 'significant_genes': 786, 'correlation_threshold': 0.4},
    'LUNG': {'samples': 1046, 'significant_genes': 15391, 'correlation_threshold': 0.2},
    'LIHC': {'samples': 371, 'significant_genes': 765, 'correlation_threshold': 0.4},
    'COAD': {'samples': 463, 'significant_genes': 324, 'correlation_threshold': None},
    'DLBC': {'samples': 44, 'significant_genes': 7, 'correlation_threshold': 0.64}
}

# Summary of this run. Counts are over |r| so negative correlations count too.
our_results = {
    'samples': n_samples,
    'genes_tested': n_genes,
    'significant_genes_01': np.sum(np.abs(gene_correlations) > 0.1),
    'significant_genes_02': np.sum(np.abs(gene_correlations) > 0.2),
    'significant_genes_03': np.sum(np.abs(gene_correlations) > 0.3),
    'max_correlation': np.max(np.abs(gene_correlations)),
    'median_correlation': np.median(np.abs(gene_correlations))
}

# NOTE(review): 17759 appears to be the number of genes the paper evaluated for
# BRCA (denominator of its 786 significant genes) — confirm against the paper.
PAPER_BRCA_N_GENES = 17759

# Derive printed figures from the dicts above so text and data cannot drift apart.
dlbc, lung, brca = (paper_results[cohort] for cohort in ('DLBC', 'LUNG', 'BRCA'))
print(
    f"Paper thresholds: DLBC ({dlbc['samples']} samples) R>{dlbc['correlation_threshold']:.2f}, "
    f"LUNG ({lung['samples']}) R>{lung['correlation_threshold']:.2f}"
)
print(f"Our test ({our_results['samples']} samples): max R = {our_results['max_correlation']:.3f}")
print(
    f"Paper BRCA: {brca['significant_genes'] / PAPER_BRCA_N_GENES * 100:.1f}% genes "
    f"(R>{brca['correlation_threshold']})"
)
# NOTE: not like-for-like — our fraction uses a looser R>0.1 cut-off.
print(f"Our test: {our_results['significant_genes_01'] / our_results['genes_tested'] * 100:.1f}% genes (R>0.1)")

In [None]:
# Reference architecture this run is expected to reproduce.
EXPECTED_LAYERS = [1024, 1024]
EXPECTED_KS = [1, 2, 5, 10, 20, 50, 100]
EXPECTED_DROPOUT = 0.25
# Minimum fraction of tested genes that must clear |r| > 0.1 (300 of 1000 by default).
MIN_SIGNIFICANT_FRACTION = 0.3

# List equality already implies equal length, so no separate len() checks are
# needed. Dropout is a parsed float, so compare with a tolerance rather than ==.
architecture_match = all([
    layers == EXPECTED_LAYERS,
    ks == EXPECTED_KS,
    bool(np.isclose(dropout, EXPECTED_DROPOUT)),
])

genes_tested = our_results['genes_tested']
performance_reasonable = all([
    our_results['max_correlation'] > 0.3,
    # Scale with the number of genes instead of a hard-coded 300-of-1000.
    our_results['significant_genes_01'] > MIN_SIGNIFICANT_FRACTION * genes_tested,
    # Training-budget sanity check, not a performance metric: keeps a very
    # short run from being reported as OK.
    training_params['max_epochs'] > 5,
])

print(f"Architecture: {'OK' if architecture_match else 'FAIL'}")
print(f"  Layers: {layers}")
print(f"  Top-k: {ks}")
print(f"  Dropout: {dropout}")

print(f"\nPerformance: {'OK' if performance_reasonable else 'FAIL'}")
print(f"  Max correlation: {our_results['max_correlation']:.3f}")
print(f"  Significant genes: {our_results['significant_genes_01']}/{genes_tested}")