# Pediatric Osteosarcoma Disease Progression Analysis

This notebook provides comprehensive analysis and visualization of:
1. Real vs. Synthetic patient comparison
2. Latent space exploration
3. Conditional generation quality
4. Biological validation results
5. Survival analysis

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import torch
import yaml

# Visualization settings
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

%matplotlib inline

## 1. Load Data

In [None]:
# Read the YAML configuration that names every input/output location
config_path = Path('../config/config.yaml')
with config_path.open() as config_file:
    config = yaml.safe_load(config_file)

# Directories taken from the configuration
data_cfg = config['data']
output_cfg = config['output']
processed_dir = Path(data_cfg['processed_dir'])
synthetic_dir = Path(output_cfg['synthetic_data_dir'])
results_dir = Path(output_cfg['results_dir'])


def _read_indexed(filename):
    """Read a processed CSV, using its first column as the row index."""
    return pd.read_csv(processed_dir / filename, index_col=0)


# Real cohort: the three feature matrices use their first column as index,
# the clinical table is read with the default integer index.
real_mutations = _read_indexed('mutation_matrix_aligned.csv')
real_expression = _read_indexed('expression_matrix_aligned.csv')
real_pathways = _read_indexed('pathway_scores.csv')
real_clinical = pd.read_csv(processed_dir / 'clinical_aligned.csv')

# Quick size summary of what was loaded
n_real_samples, n_mutation_genes = real_mutations.shape
n_expression_genes = real_expression.shape[1]
n_pathways = real_pathways.shape[1]

print(f"Real data: {n_real_samples} samples")
print(f"Mutations: {n_mutation_genes} genes")
print(f"Expression: {n_expression_genes} genes")
print(f"Pathways: {n_pathways} pathways")

In [None]:
# Gather the synthetic tables of every configured scenario, one list per modality.
# Files are read in the order mutations -> expression -> pathways -> conditions.
modalities = ('mutations', 'expression', 'pathways', 'conditions')
synth_frames = {modality: [] for modality in modalities}

for scenario in config['generation']['scenarios']:
    scenario_name = scenario['name']
    scenario_dir = synthetic_dir / scenario_name

    for modality in modalities:
        frame = pd.read_csv(scenario_dir / f"{scenario_name}_{modality}.csv")
        # Tag every row with the scenario it was generated for
        synth_frames[modality].append(frame.assign(scenario=scenario_name))

# Stack the per-scenario tables into one frame per modality
synth_mutations = pd.concat(synth_frames['mutations'], ignore_index=True)
synth_expression = pd.concat(synth_frames['expression'], ignore_index=True)
synth_pathways = pd.concat(synth_frames['pathways'], ignore_index=True)
synth_conditions = pd.concat(synth_frames['conditions'], ignore_index=True)

n_synth_samples = len(synth_mutations)
scenario_names_seen = synth_mutations['scenario'].unique()
print(f"\nSynthetic data: {n_synth_samples} samples")
print(f"Scenarios: {scenario_names_seen}")

## 2. Mutation Analysis

In [None]:
# Mutation frequency comparison.
# `real_freq`, `synth_freq` and `corr` are read again by the summary report cell.
real_freq = real_mutations.mean(axis=0).sort_values(ascending=False)
synth_freq = synth_mutations.drop('scenario', axis=1).mean(axis=0).sort_values(ascending=False)

# Top 20 mutated genes, ranked by their frequency in the real cohort
top_genes = real_freq.head(20).index

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Real vs Synthetic frequency scatter, restricted to genes present in both tables
common_genes = real_freq.index.intersection(synth_freq.index)
axes[0].scatter(real_freq[common_genes], synth_freq[common_genes], alpha=0.5)
axes[0].plot([0, 1], [0, 1], 'r--', label='Perfect match')
axes[0].set_xlabel('Real Mutation Frequency')
axes[0].set_ylabel('Synthetic Mutation Frequency')
axes[0].set_title('Mutation Frequency: Real vs Synthetic')
axes[0].legend()

# Pearson correlation of the per-gene frequencies over the shared genes
corr = np.corrcoef(real_freq[common_genes], synth_freq[common_genes])[0, 1]
axes[0].text(0.05, 0.95, f'Correlation: {corr:.3f}',
             transform=axes[0].transAxes, fontsize=12, verticalalignment='top')

# Top 20 genes comparison. reindex() is used instead of label indexing so that
# a top real gene missing from the synthetic table yields a NaN (no bar)
# rather than a KeyError.
comparison_df = pd.DataFrame({
    'Real': real_freq.reindex(top_genes),
    'Synthetic': synth_freq.reindex(top_genes)
})

comparison_df.plot(kind='bar', ax=axes[1])
axes[1].set_xlabel('Gene')
axes[1].set_ylabel('Mutation Frequency')
axes[1].set_title('Top 20 Mutated Genes')
axes[1].legend(['Real', 'Synthetic'])
# Rotate the gene labels on the bar chart explicitly (no reliance on the
# pyplot "current axes" state).
plt.setp(axes[1].get_xticklabels(), rotation=45, ha='right')

plt.tight_layout()

# Make sure the output folder exists before writing the figure
figures_dir = results_dir / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figures_dir / 'mutation_frequency_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Driver gene analysis: mutation frequency of the configured driver genes in
# the real and the synthetic cohort.
driver_genes = config['evaluation']['driver_genes']

# A gene can only be compared when it is a column of BOTH mutation matrices;
# checking the real matrix alone would raise KeyError on the synthetic lookup.
driver_genes_present = [
    g for g in driver_genes
    if g in real_mutations.columns and g in synth_mutations.columns
]
driver_genes_skipped = [g for g in driver_genes if g not in driver_genes_present]
if driver_genes_skipped:
    print(f"Driver genes skipped (missing from real or synthetic mutations): {driver_genes_skipped}")

if driver_genes_present:
    fig, ax = plt.subplots(figsize=(10, 6))

    driver_comparison = pd.DataFrame({
        'Real': real_mutations[driver_genes_present].mean(),
        'Synthetic': synth_mutations[driver_genes_present].mean()
    })

    driver_comparison.plot(kind='bar', ax=ax)
    ax.set_xlabel('Driver Gene')
    ax.set_ylabel('Mutation Frequency')
    ax.set_title('Osteosarcoma Driver Genes: Real vs Synthetic')
    ax.legend(['Real', 'Synthetic'])
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    plt.tight_layout()

    # Make sure the output folder exists before writing the figure
    figures_dir = results_dir / 'figures'
    figures_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(figures_dir / 'driver_genes.png', dpi=300, bbox_inches='tight')
    plt.show()

## 3. Pathway Analysis

In [None]:
# Pathway score distributions: overlaid real / synthetic histograms, one panel
# per pathway, on a 2 x 3 grid.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes = axes.flatten()

# Select 6 important pathways
pathways_to_plot = [
    'HALLMARK_P53_PATHWAY',
    'HALLMARK_APOPTOSIS',
    'HALLMARK_E2F_TARGETS',
    'HALLMARK_G2M_CHECKPOINT',
    'HALLMARK_DNA_REPAIR',
    'HALLMARK_HYPOXIA'
]

# Keep only pathways scored in BOTH tables; a pathway present only in the real
# scores would raise KeyError on the synthetic lookup below.
pathways_to_plot = [
    p for p in pathways_to_plot
    if p in real_pathways.columns and p in synth_pathways.columns
][:len(axes)]

for ax, pathway in zip(axes, pathways_to_plot):
    ax.hist(real_pathways[pathway], bins=30, alpha=0.5, label='Real', density=True)
    ax.hist(synth_pathways[pathway], bins=30, alpha=0.5, label='Synthetic', density=True)
    ax.set_xlabel('Pathway Score')
    ax.set_ylabel('Density')
    ax.set_title(pathway.replace('HALLMARK_', '').replace('_', ' '))
    ax.legend()

# Hide panels that received no pathway so the figure has no empty frames
for unused_ax in axes[len(pathways_to_plot):]:
    unused_ax.set_visible(False)

plt.tight_layout()

# Make sure the output folder exists before writing the figure
figures_dir = results_dir / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figures_dir / 'pathway_distributions.png', dpi=300, bbox_inches='tight')
plt.show()

## 4. Latent Space Visualization

In [None]:
# UMAP visualization of the combined feature space.
# Note: the embedding is computed on the concatenated mutation / expression /
# pathway matrices themselves, not on a model's latent vectors.
from umap import UMAP


def _synthetic_features(synth_df, real_df):
    """Return the synthetic feature matrix, column-aligned to the real one.

    When every real column also exists in the synthetic table, the synthetic
    columns are selected in the real table's order, so extra columns (e.g. a
    saved index) or a different ordering cannot shift features between the two
    cohorts. Otherwise the original positional behaviour is kept: all columns
    except 'scenario' are used as they are.
    """
    features = synth_df.drop('scenario', axis=1)
    if set(real_df.columns).issubset(features.columns):
        return features[list(real_df.columns)].values
    return features.values


# Combine real and synthetic data (samples x [mutations | expression | pathways])
real_data = np.concatenate([
    real_mutations.values,
    real_expression.values,
    real_pathways.values
], axis=1)

synth_data = np.concatenate([
    _synthetic_features(synth_mutations, real_mutations),
    _synthetic_features(synth_expression, real_expression),
    _synthetic_features(synth_pathways, real_pathways)
], axis=1)

# UMAP embedding: real rows come first, synthetic rows follow
combined_data = np.vstack([real_data, synth_data])
labels = np.array(['Real'] * len(real_data) + ['Synthetic'] * len(synth_data))

# Fixed random_state keeps the embedding reproducible between runs
reducer = UMAP(n_components=2, random_state=42)
embedding = reducer.fit_transform(combined_data)

# Plot
fig, ax = plt.subplots(figsize=(12, 8))

for label in ['Real', 'Synthetic']:
    mask = labels == label
    ax.scatter(embedding[mask, 0], embedding[mask, 1],
               alpha=0.5, s=50, label=label)

ax.set_xlabel('UMAP 1')
ax.set_ylabel('UMAP 2')
ax.set_title('UMAP Embedding: Real vs Synthetic Patients')
ax.legend()

plt.tight_layout()

# Make sure the output folder exists before writing the figure
figures_dir = results_dir / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figures_dir / 'umap_real_vs_synthetic.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Color by scenario: real patients in gray, synthetic patients coloured by the
# scenario they were generated for.
fig, ax = plt.subplots(figsize=(12, 8))

# The embedding was fitted on [real rows; synthetic rows], so the split point
# is the number of real samples. Slicing by position avoids depending on the
# global `labels` list, which another cell rebinds.
synth_start_idx = len(real_data)
real_embedding = embedding[:synth_start_idx]
synth_embedding = embedding[synth_start_idx:]

# Real data
ax.scatter(real_embedding[:, 0], real_embedding[:, 1],
           alpha=0.5, s=50, c='gray', label='Real')

# Synthetic by scenario
synth_scenarios = synth_mutations['scenario'].values
scenario_colors = {'early_stage_good_prognosis': 'green',
                   'metastatic_poor_prognosis': 'red',
                   'typical_patient': 'blue'}

# Scenarios that have no predefined colour get one from this fallback palette
# instead of being silently left out of the figure.
fallback_palette = ['orange', 'purple', 'brown', 'teal', 'magenta', 'olive']
scenarios_in_data = list(pd.unique(synth_scenarios))
known_scenarios = [s for s in scenario_colors if s in scenarios_in_data]
other_scenarios = [s for s in scenarios_in_data if s not in scenario_colors]
for position, scenario in enumerate(other_scenarios):
    scenario_colors[scenario] = fallback_palette[position % len(fallback_palette)]

# Predefined scenarios first (in their original order), then any others
for scenario in known_scenarios + other_scenarios:
    scenario_mask = synth_scenarios == scenario
    scenario_embedding = synth_embedding[scenario_mask]
    ax.scatter(scenario_embedding[:, 0], scenario_embedding[:, 1],
               alpha=0.6, s=50, c=scenario_colors[scenario],
               label=scenario.replace('_', ' ').title())

ax.set_xlabel('UMAP 1')
ax.set_ylabel('UMAP 2')
ax.set_title('UMAP Embedding: Synthetic Patients by Scenario')
ax.legend()

plt.tight_layout()

# Make sure the output folder exists before writing the figure
figures_dir = results_dir / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figures_dir / 'umap_by_scenario.png', dpi=300, bbox_inches='tight')
plt.show()

## 5. Survival Analysis

In [None]:
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test

# Constants that map `survival_days_norm` back to days:
#     days = norm * SURVIVAL_DAYS_SCALE + SURVIVAL_DAYS_OFFSET
# NOTE(review): presumably these mirror the normalisation applied when the
# conditions were generated - confirm against the preprocessing code.
SURVIVAL_DAYS_SCALE = 500
SURVIVAL_DAYS_OFFSET = 800

# Kaplan-Meier curves for real vs synthetic scenarios
fig, ax = plt.subplots(figsize=(12, 8))

kmf = KaplanMeierFitter()

# Real data. Rows with a missing duration or event flag are dropped first so
# that incomplete clinical records do not break the fit.
real_survival = real_clinical[['survival_days', 'event_occurred']].dropna()
kmf.fit(real_survival['survival_days'],
        real_survival['event_occurred'],
        label='Real Patients')
kmf.plot_survival_function(ax=ax, ci_show=True)

# Synthetic scenarios: the same fitter is refitted and drawn once per scenario
for scenario in config['generation']['scenarios']:
    scenario_name = scenario['name']
    scenario_mask = synth_conditions['scenario'] == scenario_name

    # Denormalize survival days
    survival_days = (
        synth_conditions.loc[scenario_mask, 'survival_days_norm'] * SURVIVAL_DAYS_SCALE
    ) + SURVIVAL_DAYS_OFFSET
    events = synth_conditions.loc[scenario_mask, 'event_occurred'].values

    kmf.fit(survival_days, events,
            label=scenario_name.replace('_', ' ').title())
    kmf.plot_survival_function(ax=ax, ci_show=False)

ax.set_xlabel('Time (days)')
ax.set_ylabel('Survival Probability')
ax.set_title('Kaplan-Meier Survival Curves')
ax.legend()

plt.tight_layout()

# Make sure the output folder exists before writing the figure
figures_dir = results_dir / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figures_dir / 'kaplan_meier.png', dpi=300, bbox_inches='tight')
plt.show()

## 6. Validation Results

In [None]:
# Load validation results (the first row holds the metric values plotted below)
validation_results = pd.read_csv(results_dir / 'validation_results.csv')

print("=" * 60)
print("BIOLOGICAL VALIDATION RESULTS")
print("=" * 60)
print(validation_results.T)

# Visualize key metrics
fig, ax = plt.subplots(figsize=(10, 6))

metrics_to_plot = [
    'mutation_frequency_correlation',
    'cooccurrence_pattern_correlation',
    'pathway_coherence_correlation',
    'overall_biological_score'
]

# Only metrics that exist as columns are plotted. Dedicated names are used
# here (not `labels`) so the sample labels of the UMAP cells are not
# overwritten, which would break re-running those cells afterwards.
metrics_present = [m for m in metrics_to_plot if m in validation_results.columns]
metric_values = [validation_results[m].values[0] for m in metrics_present]
metric_labels = [m.replace('_', ' ').title() for m in metrics_present]

# Traffic-light colouring: > 0.8 green, > 0.6 orange, otherwise red
colors = ['green' if v > 0.8 else 'orange' if v > 0.6 else 'red' for v in metric_values]
bars = ax.barh(metric_labels, metric_values, color=colors)

ax.set_xlabel('Score')
ax.set_title('Biological Validation Metrics')
ax.set_xlim([0, 1])
ax.axvline(x=0.85, color='red', linestyle='--', label='Target Threshold')
ax.legend()

# Add value labels just to the right of each bar
for bar, value in zip(bars, metric_values):
    ax.text(value + 0.02, bar.get_y() + bar.get_height()/2,
            f'{value:.3f}', va='center')

plt.tight_layout()

# Make sure the output folder exists before writing the figure
figures_dir = results_dir / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figures_dir / 'validation_metrics.png', dpi=300, bbox_inches='tight')
plt.show()

## 7. Summary Report

In [None]:
# Text summary of the whole analysis. Relies on variables produced by earlier
# cells: the real/synthetic tables, `corr`, `real_freq`, `synth_freq` and
# `validation_results`.
print("=" * 70)
print("PEDIATRIC OSTEOSARCOMA DISEASE PROGRESSION MODEL - SUMMARY REPORT")
print("=" * 70)

print("\n1. DATA STATISTICS")
print("-" * 70)
print(f"Real samples:                 {len(real_mutations)}")
print(f"Synthetic samples:            {len(synth_mutations)}")
print(f"Number of genes (mutations):  {real_mutations.shape[1]}")
print(f"Number of genes (expression): {real_expression.shape[1]}")
print(f"Number of pathways:           {real_pathways.shape[1]}")

print("\n2. MUTATION ANALYSIS")
print("-" * 70)
print(f"Mutation frequency correlation: {corr:.3f}")
print(f"Most mutated gene (real):       {real_freq.index[0]} ({real_freq.iloc[0]:.2%})")
print(f"Most mutated gene (synthetic):  {synth_freq.index[0]} ({synth_freq.iloc[0]:.2%})")

if 'mutation_frequency_correlation' in validation_results.columns:
    print(f"\nValidation - Mutation frequency correlation: "
          f"{validation_results['mutation_frequency_correlation'].values[0]:.3f}")

print("\n3. BIOLOGICAL VALIDATION")
print("-" * 70)
# `score` is always bound here (None when the metric is missing) so the
# recommendations below can neither raise NameError nor pick up a stale value
# left in the kernel by an earlier run.
score = None
if 'overall_biological_score' in validation_results.columns:
    score = validation_results['overall_biological_score'].values[0]
    status = "✓ PASS" if score > 0.85 else "⚠ REVIEW" if score > 0.7 else "✗ FAIL"
    print(f"Overall Biological Score:     {score:.3f} {status}")
else:
    print("Overall score not available")

print("\n4. RECOMMENDATIONS")
print("-" * 70)
if score is None:
    print("No recommendations: overall biological score is not available")
elif score > 0.85:
    print("✓ Synthetic data quality is EXCELLENT")
    print("✓ Ready for downstream analysis")
    print("  - Can be used for data augmentation")
    print("  - Suitable for model pre-training")
elif score > 0.7:
    print("⚠ Synthetic data quality is GOOD but needs improvement")
    print("  - Consider increasing biological constraint weights")
    print("  - Review violation cases manually")
else:
    print("✗ Synthetic data quality is POOR")
    print("  - Retrain with higher constraint weights")
    print("  - Check for data preprocessing issues")

print("\n" + "=" * 70)
print("END OF REPORT")
print("=" * 70)