# TCGA PANCAN Multi-Omics Data Loading and Cox Regression Feature Engineering

This notebook implements comprehensive loading and preprocessing of TCGA PANCAN multi-omics data with Cox regression analysis for feature engineering.

## Data Files Structure:
1. **Transcriptome**: `unc.edu_PANCAN_IlluminaHiSeq_RNASeqV2.geneExp_whitelisted.tsv` (log2 transformed)
2. **Copy Number Variation**: `CNV.GISTIC_call.all_data_by_genes_whitelisted.tsv` (log2 transformed)
3. **microRNA**: `bcgsc.ca_PANCAN_IlluminaHiSeq_miRNASeq.miRNAExp_whitelisted.tsv` (log2 transformed)
4. **RPPA**: `mdanderson.org_PANCAN_MDA_RPPA_Core.RPPA_whitelisted.tsv` (log2 transformed)
5. **Methylation**: `jhu-usc.edu_PANCAN_HumanMethylation450.betaValue_whitelisted.csv` (NO transformation, for tab-transformer)
6. **Mutations**: `tcga_pancancer_082115.vep.filter_whitelisted.maf.gz` (impact scores, NO transformation)
7. **Clinical**: `clinical_PANCAN_patient_with_followup.tsv`

## Output:
- Cox coefficient lookup tables
- Processed multi-omics data (patient × features)
- Methylation data for tab-transformer
- Feature importance rankings
- Data quality reports

## 1. Environment Setup and Library Imports

In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from lifelines import CoxPHFitter
from lifelines.statistics import logrank_test
from lifelines import KaplanMeierFitter
import gzip
from tqdm import tqdm
import pickle
import warnings
import json
import os
from pathlib import Path
from scipy import stats
import statsmodels.api as sm
from collections import defaultdict

warnings.filterwarnings('ignore')

# Set display options
pd.set_option('display.max_columns', 100)
pd.set_option('display.max_rows', 100)

# Set plot style
plt.style.use('default')
sns.set_palette("husl")

# Define paths
DATA_RAW_PATH = Path('../data/raw')
DATA_PROCESSED_PATH = Path('../data/processed')
RESULTS_PATH = Path('../results')

# Create directories if they don't exist
DATA_PROCESSED_PATH.mkdir(exist_ok=True)
RESULTS_PATH.mkdir(exist_ok=True)

print("Environment setup complete!")
print(f"Raw data path: {DATA_RAW_PATH}")
print(f"Processed data path: {DATA_PROCESSED_PATH}")
print(f"Results path: {RESULTS_PATH}")

Environment setup complete!
Raw data path: ../data/raw
Processed data path: ../data/processed
Results path: ../results


## 2. Data Loading Functions

In [4]:
def standardize_patient_id(patient_id):
    """Standardize TCGA patient IDs to 12-character format (TCGA-XX-XXXX)"""
    if isinstance(patient_id, str):
        # Remove any trailing parts after the sample type (e.g., -01A, -11A)
        parts = patient_id.split('-')
        if len(parts) >= 3:
            return f"{parts[0]}-{parts[1]}-{parts[2]}"
    return patient_id

def load_transcriptome_data(file_path):
    """Load and preprocess transcriptome data with log2 transformation"""
    print("Loading transcriptome data...")
    
    # Load data
    df = pd.read_csv(file_path, sep='\t', index_col=0)
    
    # Parse gene symbols from first column (Gene_Symbol|Entrez_ID)
    gene_info = df.index.str.split('|', expand=True)
    
    # Handle potential IndexError
    if hasattr(gene_info, 'shape') and len(gene_info.shape) > 1 and gene_info.shape[1] >= 2:
        gene_symbols = gene_info.iloc[:, 0]
        entrez_ids = gene_info.iloc[:, 1]
        gene_symbols = gene_symbols.where(gene_symbols != '?', 'Gene_' + entrez_ids.astype(str))
    else:
        # No "|" separator found, use original index
        gene_symbols = df.index
        print("Gene symbols not split - using original index")
    
    # Set gene symbols as index
    df.index = gene_symbols
    
    # Transpose to get patients as rows
    df = df.T
    
    # Standardize patient IDs
    df.index = [standardize_patient_id(pid) for pid in df.index]
    
    # Store original values for comparison
    original_stats = {
        'mean': df.values.mean(),
        'std': df.values.std(),
        'min': df.values.min(),
        'max': df.values.max(),
        'zeros': (df.values == 0).sum()
    }
    
    # Apply log2 transformation: log2(x + 1)
    df_log = np.log2(df + 1)
    
    # Store transformed stats
    transformed_stats = {
        'mean': df_log.values.mean(),
        'std': df_log.values.std(),
        'min': df_log.values.min(),
        'max': df_log.values.max(),
        'zeros': (df_log.values == 0).sum()
    }
    
    transformation_stats = {
        'original': original_stats,
        'transformed': transformed_stats,
        'n_patients': df_log.shape[0],
        'n_genes': df_log.shape[1]
    }
    
    print(f"Transcriptome data loaded: {df_log.shape[0]} patients × {df_log.shape[1]} genes (log2 transformed)")
    
    return df_log, transformation_stats

def load_cnv_data(file_path):
    """Load and preprocess CNV data with log2 transformation"""
    print("Loading CNV data...")
    
    # Load data
    df = pd.read_csv(file_path, sep='\t')
    
    # Skip first 3 annotation columns and set gene symbol as index
    gene_symbols = df.iloc[:, 0]  # First column is Gene Symbol
    df_values = df.iloc[:, 3:]  # Skip first 3 columns (Gene Symbol, Locus ID, Cytoband)
    df_values.index = gene_symbols
    
    # Transpose to get patients as rows
    df_values = df_values.T
    
    # Standardize patient IDs
    df_values.index = [standardize_patient_id(pid) for pid in df_values.index]
    
    # Apply log2 transformation: log2(x + 1) for positive values, handle negatives
    min_val = df_values.values.min()
    if min_val < 0:
        # Shift negative values to make all positive before log transformation
        df_log = np.log2(df_values - min_val + 1)
        print(f"Applied log2(x - {min_val:.3f} + 1) transformation for negative CNV values")
    else:
        df_log = np.log2(df_values + 1)
        print("Applied log2(x + 1) transformation")
    
    print(f"CNV data loaded: {df_log.shape[0]} patients × {df_log.shape[1]} genes (log2 transformed)")
    
    return df_log

def load_mirna_data(file_path):
    """Load and preprocess microRNA data with log2 transformation"""
    print("Loading microRNA data...")
    
    # Load data
    df = pd.read_csv(file_path, sep='\t', index_col=0)
    
    # Transpose to get patients as rows
    df = df.T
    
    # Standardize patient IDs
    df.index = [standardize_patient_id(pid) for pid in df.index]
    
    # Apply log2 transformation: log2(x + 1)
    df_log = np.log2(df + 1)
    
    print(f"microRNA data loaded: {df_log.shape[0]} patients × {df_log.shape[1]} miRNAs (log2 transformed)")
    
    return df_log

def load_rppa_data(file_path):
    """Load and preprocess RPPA protein data with log2 transformation"""
    print("Loading RPPA data...")
    
    # Load data
    df = pd.read_csv(file_path, sep='\t', index_col=0)
    
    # Transpose to get patients as rows
    df = df.T
    
    # Standardize patient IDs
    df.index = [standardize_patient_id(pid) for pid in df.index]
    
    # Apply log2 transformation: log2(x + 1) for positive values, handle negatives
    min_val = df.values.min()
    if min_val < 0:
        # Shift negative values to make all positive before log transformation
        df_log = np.log2(df - min_val + 1)
        print(f"Applied log2(x - {min_val:.3f} + 1) transformation for negative RPPA values")
    else:
        df_log = np.log2(df + 1)
        print("Applied log2(x + 1) transformation")
    
    print(f"RPPA data loaded: {df_log.shape[0]} patients × {df_log.shape[1]} proteins (log2 transformed)")
    
    return df_log

def load_methylation_data(file_path):
    """Load methylation data for tab-transformer (NO log2 transformation)"""
    print("Loading methylation data...")
    print("Note: NO log2 transformation applied - beta values (0-1) for tab-transformer")
    
    # Load data - try both .csv and .tsv extensions
    try:
        df = pd.read_csv(file_path, sep='\t', index_col=0)
    except FileNotFoundError:
        # Try with .csv extension
        csv_path = str(file_path).replace('.tsv', '.csv')
        df = pd.read_csv(csv_path, sep='\t', index_col=0)
    
    # Transpose to get patients as rows
    df = df.T
    
    # Standardize patient IDs
    df.index = [standardize_patient_id(pid) for pid in df.index]
    
    # Check data quality
    missing_values = df.isna().sum().sum()
    total_values = df.shape[0] * df.shape[1]
    missing_percentage = (missing_values / total_values) * 100
    
    print(f"Methylation data loaded: {df.shape[0]} patients × {df.shape[1]} probes")
    print(f"Missing values: {missing_values:,} ({missing_percentage:.2f}% of total)")
    print("This data is prepared for tab-transformer network (beta values preserved)")
    
    return df

def load_mutation_data(file_path):
    """Load and preprocess mutation data from MAF format (NO log2 transformation)"""
    print("Loading mutation data...")
    print("Note: Impact scores (0-2), NO log2 transformation")
    
    # Load MAF file with encoding handling, skipping the version line
    try:
        with gzip.open(file_path, 'rt', encoding='utf-8') as f:
            # Skip the first line (#version 2.4)
            first_line = f.readline()
            if first_line.startswith('#version'):
                # Read the rest of the file
                df = pd.read_csv(f, sep='\t', low_memory=False)
            else:
                # Reset file pointer and read normally
                f.seek(0)
                df = pd.read_csv(f, sep='\t', low_memory=False)
    except UnicodeDecodeError:
        with gzip.open(file_path, 'rt', encoding='latin-1') as f:
            first_line = f.readline()
            if first_line.startswith('#version'):
                df = pd.read_csv(f, sep='\t', low_memory=False)
            else:
                f.seek(0)
                df = pd.read_csv(f, sep='\t', low_memory=False)
    
    print(f"Raw MAF data: {df.shape[0]} mutations")
    
    # Define mutation impact scoring
    variant_impact = {
        'Silent': 0,
        'Missense_Mutation': 1,
        'Nonsense_Mutation': 2,
        'Frame_Shift_Del': 2,
        'Frame_Shift_Ins': 2,
        'Splice_Site': 2,
        'Translation_Start_Site': 1,
        'Nonstop_Mutation': 1,
        'In_Frame_Del': 1,
        'In_Frame_Ins': 1,
        "3'UTR": 0,
        "5'UTR": 0,
        'Intron': 0,
        'RNA': 0
    }
    
    # Filter for relevant columns
    required_cols = ['Hugo_Symbol', 'Tumor_Sample_Barcode', 'Variant_Classification']
    if not all(col in df.columns for col in required_cols):
        print(f"Missing required columns. Available columns: {list(df.columns[:10])}...")
        return pd.DataFrame()
    
    # Standardize patient IDs
    df['Patient_ID'] = df['Tumor_Sample_Barcode'].apply(standardize_patient_id)
    
    # Map variant classifications to impact scores
    df['Impact_Score'] = df['Variant_Classification'].map(variant_impact).fillna(0)
    
    # Aggregate mutations per patient-gene pair (take maximum impact)
    mutation_matrix = df.groupby(['Patient_ID', 'Hugo_Symbol'])['Impact_Score'].max().unstack(fill_value=0)
    
    print(f"Mutation matrix: {mutation_matrix.shape[0]} patients × {mutation_matrix.shape[1]} genes (impact scores)")
    
    return mutation_matrix

def load_clinical_data(file_path):
    """Load and preprocess clinical data"""
    print("Loading clinical data...")
    
    # Try different encodings to handle problematic characters
    encodings_to_try = ['utf-8', 'latin-1', 'iso-8859-1', 'cp1252']
    
    df = None
    for encoding in encodings_to_try:
        try:
            df = pd.read_csv(file_path, sep='\t', encoding=encoding, low_memory=False)
            print(f"Successfully loaded clinical data with {encoding} encoding")
            break
        except UnicodeDecodeError:
            continue
    
    if df is None:
        # Last resort: ignore problematic characters
        df = pd.read_csv(file_path, sep='\t', encoding='utf-8', errors='ignore', low_memory=False)
        print("Loaded clinical data with UTF-8 encoding, ignoring problematic characters")
    
    # Standardize patient IDs
    df['bcr_patient_barcode'] = df['bcr_patient_barcode'].apply(standardize_patient_id)
    
    # Set patient ID as index
    df = df.set_index('bcr_patient_barcode')
    
    print(f"Clinical data loaded: {df.shape[0]} patients × {df.shape[1]} features")
    
    return df

def clean_survival_data(clinical_df):
    """생존 데이터를 정리하여 올바른 숫자형으로 변환"""
    
    print("=== 생존 데이터 정리 ===")
    print()
    
    clinical_clean = clinical_df.copy()
    
    # 1. days_to_death 정리
    print("1. days_to_death 정리:")
    death_col = clinical_clean['days_to_death'].copy()
    
    # 비수치 값들을 NaN으로 변환
    invalid_death = death_col.isin(['[Not Applicable]', '[Not Available]', '[Discrepancy]', '[Unknown]'])
    print(f"   • 비수치 값: {invalid_death.sum()}개 → NaN으로 변환")
    
    death_col[invalid_death] = np.nan
    death_col = pd.to_numeric(death_col, errors='coerce')
    clinical_clean['days_to_death_clean'] = death_col
    
    # 2. days_to_last_followup 정리
    print("2. days_to_last_followup 정리:")
    followup_col = clinical_clean['days_to_last_followup'].copy()
    
    invalid_followup = followup_col.isin(['[Not Applicable]', '[Not Available]', '[Discrepancy]', '[Unknown]'])
    print(f"   • 비수치 값: {invalid_followup.sum()}개 → NaN으로 변환")
    
    followup_col[invalid_followup] = np.nan
    followup_col = pd.to_numeric(followup_col, errors='coerce')
    
    # 음수값 제거 (잘못된 데이터)
    negative_followup = followup_col < 0
    print(f"   • 음수 값: {negative_followup.sum()}개 → NaN으로 변환")
    followup_col[negative_followup] = np.nan
    
    clinical_clean['days_to_last_followup_clean'] = followup_col
    
    # 3. vital_status 정리
    print("3. vital_status 정리:")
    vital_status_counts = clinical_clean['vital_status'].value_counts()
    print(f"   • {vital_status_counts.to_dict()}")
    
    # 올바른 vital_status만 유지
    valid_vital_status = clinical_clean['vital_status'].isin(['Alive', 'Dead'])
    print(f"   • 유효한 vital_status: {valid_vital_status.sum()}개")
    
    # 4. 새로운 생존 시간과 이벤트 생성
    print("4. 새로운 survival_time과 survival_event 생성:")
    
    # survival_time 재계산
    survival_time_new = np.where(
        (clinical_clean['vital_status'] == 'Dead') & clinical_clean['days_to_death_clean'].notna(),
        clinical_clean['days_to_death_clean'],
        clinical_clean['days_to_last_followup_clean']
    )
    
    # survival_event 재계산
    survival_event_new = (clinical_clean['vital_status'] == 'Dead').astype(int)
    
    # 유효하지 않은 vital_status는 제외
    survival_event_new[~valid_vital_status] = np.nan
    survival_time_new[~valid_vital_status] = np.nan
    
    clinical_clean['survival_time_clean'] = survival_time_new
    clinical_clean['survival_event_clean'] = survival_event_new
    
    # 5. 유효한 생존 데이터만 남기기
    valid_survival = (
        pd.notna(clinical_clean['survival_time_clean']) & 
        pd.notna(clinical_clean['survival_event_clean']) &
        (clinical_clean['survival_time_clean'] >= 0)
    )
    
    print(f"   • 유효한 생존 데이터: {valid_survival.sum()}개")
    print(f"   • 사망 이벤트: {clinical_clean.loc[valid_survival, 'survival_event_clean'].sum()}개")
    print(f"   • 평균 생존 시간: {clinical_clean.loc[valid_survival, 'survival_time_clean'].mean():.1f}일")
    
    return clinical_clean, valid_survival

## 3. Load All Data

In [None]:
# Load all datasets
print("=" * 60)
print("LOADING TCGA PANCAN MULTI-OMICS DATA")
print("=" * 60)

# Load transcriptome data (log2 transformed)
expression_data, transformation_stats = load_transcriptome_data(
    DATA_RAW_PATH / 'unc.edu_PANCAN_IlluminaHiSeq_RNASeqV2.geneExp_whitelisted.tsv'
)

# Load CNV data (log2 transformed)
cnv_data = load_cnv_data(
    DATA_RAW_PATH / 'CNV.GISTIC_call.all_data_by_genes_whitelisted.tsv'
)

# Load microRNA data (log2 transformed)
mirna_data = load_mirna_data(
    DATA_RAW_PATH / 'bcgsc.ca_PANCAN_IlluminaHiSeq_miRNASeq.miRNAExp_whitelisted.tsv'
)

# Load RPPA data (log2 transformed)
rppa_data = load_rppa_data(
    DATA_RAW_PATH / 'mdanderson.org_PANCAN_MDA_RPPA_Core.RPPA_whitelisted.tsv'
)

# Load methylation data (NO transformation - for tab-transformer)
methylation_data = load_methylation_data(
    DATA_RAW_PATH / 'jhu-usc.edu_PANCAN_HumanMethylation450.betaValue_whitelisted.tsv'
)

# Load mutation data (impact scores - NO transformation)
mutation_data = load_mutation_data(
    DATA_RAW_PATH / 'tcga_pancancer_082115.vep.filter_whitelisted.maf.gz'
)

# Load clinical data
clinical_data = load_clinical_data(
    DATA_RAW_PATH / 'clinical_PANCAN_patient_with_followup.tsv'
)

print("\n" + "=" * 60)
print("DATA LOADING SUMMARY")
print("=" * 60)
print(f"Expression: {expression_data.shape[0]} patients × {expression_data.shape[1]} genes (log2 transformed)")
print(f"CNV: {cnv_data.shape[0]} patients × {cnv_data.shape[1]} genes (log2 transformed)")
print(f"microRNA: {mirna_data.shape[0]} patients × {mirna_data.shape[1]} miRNAs (log2 transformed)")
print(f"RPPA: {rppa_data.shape[0]} patients × {rppa_data.shape[1]} proteins (log2 transformed)")
print(f"Methylation: {methylation_data.shape[0]} patients × {methylation_data.shape[1]} probes (NO transformation)")
print(f"Mutations: {mutation_data.shape[0]} patients × {mutation_data.shape[1]} genes (impact scores)")
print(f"Clinical: {clinical_data.shape[0]} patients × {clinical_data.shape[1]} features")

LOADING TCGA PANCAN MULTI-OMICS DATA
Loading transcriptome data...
Gene symbols not split - using original index
Transcriptome data loaded: 10327 patients × 20531 genes (log2 transformed)
Loading CNV data...
Applied log2(x - -1.290 + 1) transformation for negative CNV values
CNV data loaded: 10713 patients × 25128 genes (log2 transformed)
Loading microRNA data...
microRNA data loaded: 9350 patients × 1071 miRNAs (log2 transformed)
Loading RPPA data...
Applied log2(x + 1) transformation
RPPA data loaded: 7656 patients × 387 proteins (log2 transformed)
Loading methylation data...
Note: NO log2 transformation applied - beta values (0-1) for tab-transformer


## 4. Survival Data Cleaning

In [None]:
# 생존 데이터 정리 실행
clinical_data_clean, valid_survival_mask = clean_survival_data(clinical_data)

print()
print("=== 정리 후 검증 ===")
print(f"전체 환자: {len(clinical_data_clean)}")
print(f"유효한 생존 데이터 보유 환자: {valid_survival_mask.sum()}")

# 유효한 생존 데이터를 가진 환자 목록
valid_survival_patients = clinical_data_clean.index[valid_survival_mask]
print(f"유효한 생존 데이터 환자 목록: {len(valid_survival_patients)}명")

## 5. Patient ID Matching and Data Integration

In [None]:
# Analyze patient overlap across datasets
datasets = {
    'Expression': set(expression_data.index),
    'CNV': set(cnv_data.index),
    'microRNA': set(mirna_data.index),
    'RPPA': set(rppa_data.index),
    'Methylation': set(methylation_data.index),
    'Mutations': set(mutation_data.index),
    'Clinical': set(valid_survival_patients)  # Only valid survival patients
}

print("Patient counts per dataset:")
for name, patients in datasets.items():
    print(f"{name}: {len(patients)} patients")

# Find common patients across all datasets (including methylation)
common_all_datasets = set.intersection(*datasets.values())
print(f"\nCommon patients across ALL datasets (including methylation): {len(common_all_datasets)}")

# Find common patients excluding methylation (for Cox analysis)
cox_datasets = {k: v for k, v in datasets.items() if k != 'Methylation'}
common_cox_patients = set.intersection(*cox_datasets.values())
print(f"Common patients for Cox analysis (excluding methylation): {len(common_cox_patients)}")

# Convert to sorted lists for pandas indexing
final_all_patients = sorted(list(common_all_datasets))
final_cox_patients = sorted(list(common_cox_patients))

print(f"\n=== 정리 후 검증 ===")
print(f"전체 데이터셋 공통 환자: {len(final_all_patients)}")
print(f"Cox 분석 대상 환자: {len(final_cox_patients)}")
print(f"메틸레이션 데이터 (tab-transformer용): {len(final_all_patients)} 환자")

# 최종 분석 대상 환자 리스트 
final_patient_list_clean = final_cox_patients
print(f"최종 Cox 분석 대상: {len(final_patient_list_clean)} 환자")

## 6. Data Quality Visualization

In [None]:
# Create data quality visualization
plt.figure(figsize=(20, 15))

# Plot patient counts
plt.subplot(3, 3, 1)
dataset_counts = [len(patients) for patients in datasets.values()]
plt.bar(range(len(datasets)), dataset_counts, color='skyblue')
plt.xticks(range(len(datasets)), list(datasets.keys()), rotation=45)
plt.title('Patient Counts by Dataset')
plt.ylabel('Number of Patients')
for i, count in enumerate(dataset_counts):
    plt.text(i, count + 50, str(count), ha='center')

# Plot cancer type distribution
if 'acronym' in clinical_data_clean.columns:
    plt.subplot(3, 3, 2)
    cancer_counts = clinical_data_clean.loc[final_cox_patients, 'acronym'].value_counts()
    cancer_counts.head(15).plot(kind='bar', color='lightcoral')
    plt.title('Cancer Types (Cox Analysis Patients)')
    plt.xticks(rotation=45)
    plt.ylabel('Patient Count')

# Plot survival data quality
plt.subplot(3, 3, 3)
survival_times = clinical_data_clean.loc[final_cox_patients, 'survival_time_clean']
plt.hist(survival_times.dropna(), bins=50, alpha=0.7, color='lightgreen')
plt.title('Survival Time Distribution')
plt.xlabel('Days')
plt.ylabel('Patient Count')

# Plot survival events
plt.subplot(3, 3, 4)
event_counts = clinical_data_clean.loc[final_cox_patients, 'survival_event_clean'].value_counts()
plt.pie(event_counts.values, labels=['Censored', 'Death'], autopct='%1.1f%%', colors=['lightblue', 'salmon'])
plt.title('Survival Events')

# Plot dataset overlap
plt.subplot(3, 3, 5)
overlap_data = {
    'All datasets\n(w/ methylation)': len(final_all_patients),
    'Cox datasets\n(no methylation)': len(final_cox_patients),
    'Valid survival': len(valid_survival_patients)
}
plt.bar(overlap_data.keys(), overlap_data.values(), color=['gold', 'darkblue', 'green'])
plt.title('Patient Overlap Analysis')
plt.ylabel('Number of Patients')
for i, (key, value) in enumerate(overlap_data.items()):
    plt.text(i, value + 20, str(value), ha='center')

# Plot transformation comparison for expression data
plt.subplot(3, 3, 6)
original_mean = transformation_stats['original']['mean']
transformed_mean = transformation_stats['transformed']['mean']
plt.bar(['Original', 'Log2+1'], [original_mean, transformed_mean], color=['orange', 'purple'])
plt.title('Expression Data Transformation')
plt.ylabel('Mean Value')

# Plot methylation data characteristics
plt.subplot(3, 3, 7)
meth_sample = methylation_data.loc[final_all_patients[:100]].values.flatten()
meth_sample_clean = meth_sample[~np.isnan(meth_sample)]
plt.hist(meth_sample_clean[:10000], bins=50, alpha=0.7, color='magenta')
plt.title('Methylation Beta Values Distribution\n(Sample)')
plt.xlabel('Beta Value (0-1)')
plt.ylabel('Frequency')

# Plot data completeness heatmap
plt.subplot(3, 3, 8)
completeness_matrix = []
dataset_names = list(datasets.keys())
for i, (name1, patients1) in enumerate(datasets.items()):
    row = []
    for j, (name2, patients2) in enumerate(datasets.items()):
        if i == j:
            overlap = 1.0
        else:
            intersection = len(patients1.intersection(patients2))
            union = len(patients1.union(patients2))
            overlap = intersection / union if union > 0 else 0
        row.append(overlap)
    completeness_matrix.append(row)

im = plt.imshow(completeness_matrix, cmap='Blues', vmin=0, vmax=1)
plt.xticks(range(len(dataset_names)), dataset_names, rotation=45)
plt.yticks(range(len(dataset_names)), dataset_names)
plt.title('Dataset Overlap Matrix')
plt.colorbar(im, shrink=0.6)

# Add text annotations
for i in range(len(dataset_names)):
    for j in range(len(dataset_names)):
        plt.text(j, i, f'{completeness_matrix[i][j]:.2f}',
                ha='center', va='center', color='black' if completeness_matrix[i][j] < 0.5 else 'white')

# Plot missing data analysis
plt.subplot(3, 3, 9)
missing_data = {
    'Methylation': (methylation_data.isna().sum().sum() / (methylation_data.shape[0] * methylation_data.shape[1])) * 100
}
plt.bar(missing_data.keys(), missing_data.values(), color='red', alpha=0.7)
plt.title('Missing Data Percentage')
plt.ylabel('Missing Data (%)')
for i, (key, value) in enumerate(missing_data.items()):
    plt.text(i, value + 0.1, f'{value:.1f}%', ha='center')

plt.tight_layout()
plt.show()

# Display transformation statistics
print(f"\nExpression Data Transformation Statistics:")
print(f"Original data - Mean: {transformation_stats['original']['mean']:.3f}, Std: {transformation_stats['original']['std']:.3f}")
print(f"Transformed data - Mean: {transformation_stats['transformed']['mean']:.3f}, Std: {transformation_stats['transformed']['std']:.3f}")
print(f"Zero values - Original: {transformation_stats['original']['zeros']}, Transformed: {transformation_stats['transformed']['zeros']}")

print(f"\nData Integration Summary:")
print(f"• Total datasets: {len(datasets)}")
print(f"• Log2 transformed: Expression, CNV, microRNA, RPPA")
print(f"• NO transformation: Methylation (beta values), Mutations (impact scores)")
print(f"• Cox analysis patients: {len(final_cox_patients)}")
print(f"• Tab-transformer patients (with methylation): {len(final_all_patients)}")

## 7. Cox Regression Analysis Functions

In [None]:
def perform_cox_regression_by_cancer(omics_data, clinical_data, omics_type, min_patients=20, p_threshold=0.05):
    """
    Perform Cox regression analysis by cancer type for given omics data
    
    Parameters:
    - omics_data: DataFrame with patients as rows, features as columns
    - clinical_data: DataFrame with survival information
    - omics_type: String identifier for the omics type
    - min_patients: Minimum number of patients required per cancer type
    - p_threshold: P-value threshold for significance
    
    Returns:
    - cox_results: Dictionary with results by cancer type
    - summary_stats: Overall summary statistics
    """
    
    print(f"\nPerforming Cox regression analysis for {omics_type} data...")
    
    # Filter for common patients
    common_patients = list(set(omics_data.index).intersection(set(clinical_data.index)))
    
    # Get survival data for common patients (use cleaned survival data)
    survival_data = clinical_data.loc[common_patients, ['survival_time_clean', 'survival_event_clean', 'acronym']].copy()
    survival_data = survival_data.dropna()
    
    # Filter omics data for patients with survival data
    omics_filtered = omics_data.loc[survival_data.index].copy()
    
    print(f"Analysis dataset: {len(survival_data)} patients with {omics_filtered.shape[1]} features")
    
    # Group by cancer type
    cancer_types = survival_data['acronym'].value_counts()
    valid_cancers = cancer_types[cancer_types >= min_patients].index
    
    print(f"Cancer types with >= {min_patients} patients: {len(valid_cancers)}")
    
    cox_results = {}
    summary_stats = {
        'total_features': omics_filtered.shape[1],
        'total_patients': len(survival_data),
        'cancer_types': len(valid_cancers),
        'significant_features': {},
        'top_features': {}
    }
    
    for cancer in tqdm(valid_cancers, desc=f"Processing {omics_type}"):
        # Get patients for this cancer type
        cancer_patients = survival_data[survival_data['acronym'] == cancer].index
        
        # Get omics and survival data for this cancer
        cancer_omics = omics_filtered.loc[cancer_patients]
        cancer_survival = survival_data.loc[cancer_patients, ['survival_time_clean', 'survival_event_clean']]
        
        # Remove features with zero variance
        feature_vars = cancer_omics.var()
        valid_features = feature_vars[feature_vars > 0].index
        cancer_omics_filtered = cancer_omics[valid_features]
        
        print(f"\n{cancer}: {len(cancer_patients)} patients, {len(valid_features)} variable features")
        
        # Perform univariate Cox regression for each feature
        feature_results = []
        
        for feature in valid_features:
            try:
                # Create dataframe for Cox regression
                cox_data = pd.DataFrame({
                    'T': cancer_survival['survival_time_clean'],
                    'E': cancer_survival['survival_event_clean'],
                    feature: cancer_omics_filtered[feature]
                })
                
                # Remove rows with missing data
                cox_data = cox_data.dropna()
                
                if len(cox_data) < 5:  # Need at least 5 observations
                    continue
                
                # Fit Cox model
                cph = CoxPHFitter()
                cph.fit(cox_data, duration_col='T', event_col='E')
                
                # Extract results
                coef = cph.summary.loc[feature, 'coef']
                p_value = cph.summary.loc[feature, 'p']
                hr = np.exp(coef)
                ci_lower = np.exp(cph.summary.loc[feature, 'coef lower 95%'])
                ci_upper = np.exp(cph.summary.loc[feature, 'coef upper 95%'])
                
                feature_results.append({
                    'feature': feature,
                    'coef': coef,
                    'hr': hr,
                    'p_value': p_value,
                    'ci_lower': ci_lower,
                    'ci_upper': ci_upper,
                    'n_patients': len(cox_data)
                })
                
            except Exception as e:
                # Skip problematic features
                continue
        
        # Convert to DataFrame and sort by p-value
        if feature_results:
            results_df = pd.DataFrame(feature_results)
            results_df = results_df.sort_values('p_value')
            
            # Count significant features
            significant_count = sum(results_df['p_value'] < p_threshold)
            
            cox_results[cancer] = results_df
            summary_stats['significant_features'][cancer] = significant_count
            summary_stats['top_features'][cancer] = results_df.head(10)
            
            print(f"  Significant features (p < {p_threshold}): {significant_count}/{len(results_df)}")
        else:
            print(f"  No valid results for {cancer}")
    
    return cox_results, summary_stats

def create_cox_coefficient_lookup(cox_results_dict, omics_types):
    """Create a comprehensive Cox coefficient lookup table"""
    
    print("\nCreating Cox coefficient lookup tables...")
    
    # Initialize lookup dictionary
    lookup_tables = {}
    
    for omics_type in omics_types:
        if omics_type in cox_results_dict:
            print(f"\nProcessing {omics_type} results...")
            
            # Combine all cancer type results
            all_results = []
            
            for cancer, results_df in cox_results_dict[omics_type].items():
                if not results_df.empty:
                    results_copy = results_df.copy()
                    results_copy['cancer_type'] = cancer
                    results_copy['omics_type'] = omics_type
                    all_results.append(results_copy)
            
            if all_results:
                combined_df = pd.concat(all_results, ignore_index=True)
                
                # Create pivot table: features × cancer_types with coefficients
                pivot_coef = combined_df.pivot_table(
                    index='feature', 
                    columns='cancer_type', 
                    values='coef', 
                    fill_value=0
                )
                
                # Create pivot table for p-values
                pivot_pval = combined_df.pivot_table(
                    index='feature', 
                    columns='cancer_type', 
                    values='p_value', 
                    fill_value=1
                )
                
                # Create summary statistics per feature
                feature_stats = combined_df.groupby('feature').agg({
                    'coef': ['mean', 'std', 'count'],
                    'p_value': ['min', 'mean'],
                    'hr': ['mean']
                }).round(4)
                
                # Flatten column names
                feature_stats.columns = ['_'.join(col).strip() for col in feature_stats.columns]
                
                lookup_tables[omics_type] = {
                    'coefficients': pivot_coef,
                    'p_values': pivot_pval,
                    'feature_stats': feature_stats,
                    'raw_results': combined_df
                }
                
                print(f"  {omics_type}: {len(pivot_coef)} features across {len(pivot_coef.columns)} cancer types")
    
    return lookup_tables

def visualize_cox_results(cox_results, omics_type, top_n=20):
    """Visualize Cox regression results"""
    
    # Combine results across cancer types
    all_results = []
    for cancer, results_df in cox_results.items():
        if not results_df.empty:
            results_copy = results_df.copy()
            results_copy['cancer_type'] = cancer
            all_results.append(results_copy)
    
    if not all_results:
        print(f"No results to visualize for {omics_type}")
        return
    
    combined_df = pd.concat(all_results, ignore_index=True)
    
    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle(f'Cox Regression Results: {omics_type}', fontsize=16)
    
    # 1. P-value distribution
    axes[0, 0].hist(combined_df['p_value'], bins=50, alpha=0.7)
    axes[0, 0].axvline(x=0.05, color='red', linestyle='--', label='p=0.05')
    axes[0, 0].set_xlabel('P-value')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].set_title('P-value Distribution')
    axes[0, 0].legend()
    
    # 2. Coefficient distribution
    axes[0, 1].hist(combined_df['coef'], bins=50, alpha=0.7)
    axes[0, 1].axvline(x=0, color='red', linestyle='--', label='coef=0')
    axes[0, 1].set_xlabel('Cox Coefficient')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].set_title('Coefficient Distribution')
    axes[0, 1].legend()
    
    # 3. Top significant features
    significant_features = combined_df[combined_df['p_value'] < 0.05]
    if len(significant_features) > 0:
        top_features = significant_features.nsmallest(top_n, 'p_value')
        
        # Create a color map for cancer types
        cancer_types = top_features['cancer_type'].unique()
        colors = plt.cm.Set3(np.linspace(0, 1, len(cancer_types)))
        color_map = dict(zip(cancer_types, colors))
        
        scatter_colors = [color_map[cancer] for cancer in top_features['cancer_type']]
        
        scatter = axes[1, 0].scatter(top_features['coef'], -np.log10(top_features['p_value']), 
                                   c=scatter_colors, alpha=0.7)
        axes[1, 0].axhline(y=-np.log10(0.05), color='red', linestyle='--', label='p=0.05')
        axes[1, 0].axvline(x=0, color='red', linestyle='--')
        axes[1, 0].set_xlabel('Cox Coefficient')
        axes[1, 0].set_ylabel('-log10(p-value)')
        axes[1, 0].set_title(f'Top {top_n} Significant Features')
        
        # Add legend for cancer types
        legend_elements = [plt.Line2D([0], [0], marker='o', color='w', 
                                    markerfacecolor=color_map[cancer], markersize=8, label=cancer)
                         for cancer in cancer_types[:10]]  # Limit legend size
        axes[1, 0].legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left')
    
    # 4. Significant features by cancer type
    cancer_sig_counts = combined_df[combined_df['p_value'] < 0.05]['cancer_type'].value_counts()
    if len(cancer_sig_counts) > 0:
        cancer_sig_counts.head(15).plot(kind='bar', ax=axes[1, 1])
        axes[1, 1].set_xlabel('Cancer Type')
        axes[1, 1].set_ylabel('Significant Features')
        axes[1, 1].set_title('Significant Features by Cancer')
        axes[1, 1].tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    # Print summary statistics
    total_tests = len(combined_df)
    significant_tests = len(combined_df[combined_df['p_value'] < 0.05])
    
    print(f"\n{omics_type} Summary:")
    print(f"Total tests: {total_tests:,}")
    print(f"Significant tests (p < 0.05): {significant_tests:,} ({100*significant_tests/total_tests:.1f}%)")
    print(f"Cancer types analyzed: {combined_df['cancer_type'].nunique()}")
    print(f"Unique features tested: {combined_df['feature'].nunique():,}")

## 8. Perform Cox Regression Analysis

In [None]:
# Prepare filtered datasets for analysis
filtered_data = {}

# Filter all datasets to common patients with survival data
for name, data in [
    ('Expression', expression_data),
    ('CNV', cnv_data),
    ('microRNA', mirna_data),
    ('RPPA', rppa_data),
    ('Mutations', mutation_data)
]:
    # Filter to common survival patients
    common_patients_data = data.loc[final_patient_list_clean]
    filtered_data[name] = common_patients_data
    print(f"{name}: {common_patients_data.shape[0]} patients × {common_patients_data.shape[1]} features")

# Also filter clinical data
filtered_clinical = clinical_data_clean.loc[final_patient_list_clean]
print(f"Clinical: {filtered_clinical.shape[0]} patients × {filtered_clinical.shape[1]} features")

print(f"\\nAll datasets now have {len(final_patient_list_clean)} patients with complete omics and survival data")

In [None]:
# Perform Cox regression analysis for each omics type
print("\n" + "=" * 60)
print("PERFORMING COX REGRESSION ANALYSIS")
print("=" * 60)

# Store all results
all_cox_results = {}
all_summary_stats = {}

# Define omics types and their corresponding data
omics_data_map = {
    'Expression': filtered_data['Expression'],
    'CNV': filtered_data['CNV'],
    'microRNA': filtered_data['microRNA'],
    'RPPA': filtered_data['RPPA'],
    'Mutations': filtered_data['Mutations']
}

# Run Cox regression for each omics type
for omics_type, omics_data in omics_data_map.items():
    print(f"\n{'='*50}")
    print(f"Processing {omics_type}")
    print(f"{'='*50}")
    
    # Perform Cox regression analysis
    cox_results, summary_stats = perform_cox_regression_by_cancer(
        omics_data=omics_data,
        clinical_data=filtered_clinical,
        omics_type=omics_type,
        min_patients=20,  # Minimum 20 patients per cancer type
        p_threshold=0.05
    )
    
    # Store results
    all_cox_results[omics_type] = cox_results
    all_summary_stats[omics_type] = summary_stats
    
    # Display summary
    print(f"\n{omics_type} Analysis Summary:")
    print(f"  Total features analyzed: {summary_stats['total_features']:,}")
    print(f"  Total patients: {summary_stats['total_patients']:,}")
    print(f"  Cancer types analyzed: {summary_stats['cancer_types']}")
    
    # Display significant features by cancer type
    if summary_stats['significant_features']:
        print(f"  Significant features by cancer type:")
        for cancer, count in summary_stats['significant_features'].items():
            print(f"    {cancer}: {count} significant features")

print(f"\n{'='*60}")
print("COX REGRESSION ANALYSIS COMPLETED")
print(f"{'='*60}")

## 10. Create Cox Coefficient Lookup Tables and Save Results

In [None]:
# Create comprehensive lookup tables
print("\n" + "=" * 60)
print("CREATING COX COEFFICIENT LOOKUP TABLES")
print("=" * 60)

# Create lookup tables for all omics types
omics_types = list(all_cox_results.keys())
lookup_tables = create_cox_coefficient_lookup(all_cox_results, omics_types)

# Visualize results for each omics type
for omics_type in omics_types:
    if omics_type in all_cox_results and all_cox_results[omics_type]:
        print(f"\nVisualizing {omics_type} results...")
        visualize_cox_results(all_cox_results[omics_type], omics_type, top_n=20)

# Save all results
print("\n" + "=" * 60)
print("SAVING PROCESSED DATA AND RESULTS")
print("=" * 60)

# Save lookup tables
for omics_type, tables in lookup_tables.items():
    # Save coefficient matrix
    coef_file = DATA_PROCESSED_PATH / f'cox_coefficients_{omics_type.lower()}.parquet'
    tables['coefficients'].to_parquet(coef_file)
    print(f"Saved {omics_type} coefficients: {coef_file}")
    
    # Save p-values matrix
    pval_file = DATA_PROCESSED_PATH / f'cox_pvalues_{omics_type.lower()}.parquet'
    tables['p_values'].to_parquet(pval_file)
    print(f"Saved {omics_type} p-values: {pval_file}")
    
    # Save feature statistics
    stats_file = DATA_PROCESSED_PATH / f'cox_feature_stats_{omics_type.lower()}.parquet'
    tables['feature_stats'].to_parquet(stats_file)
    print(f"Saved {omics_type} feature stats: {stats_file}")
    
    # Save raw results
    raw_file = DATA_PROCESSED_PATH / f'cox_raw_results_{omics_type.lower()}.parquet'
    tables['raw_results'].to_parquet(raw_file)
    print(f"Saved {omics_type} raw results: {raw_file}")

# Save processed omics data
for omics_type, data in filtered_data.items():
    processed_file = DATA_PROCESSED_PATH / f'processed_{omics_type.lower()}_data.parquet'
    data.to_parquet(processed_file)
    print(f"Saved processed {omics_type} data: {processed_file}")

# Save methylation data separately for tab-transformer
methylation_file = DATA_PROCESSED_PATH / 'methylation_data_for_tabtransformer.parquet'
# Filter methylation data to common patients
if len(final_all_patients) > 0:
    methylation_filtered = methylation_data.loc[final_all_patients]
    methylation_filtered.to_parquet(methylation_file)
    print(f"Saved methylation data for tab-transformer: {methylation_file}")
    print(f"  Shape: {methylation_filtered.shape[0]} patients × {methylation_filtered.shape[1]} probes")
    print(f"  Beta values preserved (0-1 range) for tab-transformer network")
else:
    print("Warning: No common patients found for methylation data")

# Save processed clinical data
clinical_file = DATA_PROCESSED_PATH / 'processed_clinical_data.parquet'
filtered_clinical.to_parquet(clinical_file)
print(f"Saved processed clinical data: {clinical_file}")

# Save analysis summary
summary_file = RESULTS_PATH / 'cox_analysis_summary.json'
with open(summary_file, 'w') as f:
    # Convert numpy types to native Python types for JSON serialization
    summary_for_json = {}
    for omics_type, stats in all_summary_stats.items():
        summary_for_json[omics_type] = {
            'total_features': int(stats['total_features']),
            'total_patients': int(stats['total_patients']),
            'cancer_types': int(stats['cancer_types']),
            'significant_features': {k: int(v) for k, v in stats['significant_features'].items()}
        }
    
    json.dump(summary_for_json, f, indent=2)
print(f"Saved analysis summary: {summary_file}")

# Save transformation statistics
transform_file = RESULTS_PATH / 'transformation_stats.json'
with open(transform_file, 'w') as f:
    # Convert numpy types to native Python types
    transform_for_json = {}
    for key, value in transformation_stats.items():
        if isinstance(value, dict):
            transform_for_json[key] = {k: float(v) for k, v in value.items()}
        else:
            transform_for_json[key] = int(value) if isinstance(value, (int, np.integer)) else float(value)
    
    json.dump(transform_for_json, f, indent=2)
print(f"Saved transformation stats: {transform_file}")

# Save data processing metadata
metadata = {
    'data_processing_info': {
        'total_datasets': len(datasets),
        'log2_transformed': ['Expression', 'CNV', 'microRNA', 'RPPA'],
        'no_transformation': ['Methylation', 'Mutations'],
        'cox_analysis_patients': len(final_cox_patients),
        'methylation_patients': len(final_all_patients),
        'transformation_applied': 'log2(x+1) for Expression, CNV, microRNA, RPPA',
        'methylation_note': 'Beta values (0-1) preserved for tab-transformer',
        'mutation_note': 'Impact scores (0-2) for variant classification'
    }
}

metadata_file = RESULTS_PATH / 'data_processing_metadata.json'
with open(metadata_file, 'w') as f:
    json.dump(metadata, f, indent=2)
print(f"Saved data processing metadata: {metadata_file}")

print(f"\n{'='*60}")
print("ALL DATA PROCESSING AND ANALYSIS COMPLETED!")
print(f"{'='*60}")
print(f"Processed data saved to: {DATA_PROCESSED_PATH}")
print(f"Analysis results saved to: {RESULTS_PATH}")
print(f"Cox analysis patients: {len(final_patient_list_clean)}")
print(f"Methylation data patients: {len(final_all_patients)}")
print(f"Omics types processed: {', '.join(omics_types)}")

# Display final summary
print(f"\nFinal Analysis Summary:")
for omics_type, stats in all_summary_stats.items():
    total_significant = sum(stats['significant_features'].values())
    print(f"  {omics_type}:")
    print(f"    - Features: {stats['total_features']:,}")
    print(f"    - Cancer types: {stats['cancer_types']}")
    print(f"    - Significant features: {total_significant:,}")

print(f"\nData Files Summary:")
print(f"  Cox Regression Analysis:")
print(f"    - Patients: {len(final_patient_list_clean)}")
print(f"    - Omics types: Expression, CNV, microRNA, RPPA, Mutations")
print(f"    - All data log2 transformed (except mutations)")
print(f"  Tab-Transformer Data:")
print(f"    - Patients: {len(final_all_patients)}")
print(f"    - Methylation probes: {methylation_filtered.shape[1] if len(final_all_patients) > 0 else 0}")
print(f"    - Beta values preserved (0-1 range)")