<a href="https://colab.research.google.com/github/tousifo/ml_notebooks/blob/main/als_new_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pennylane

Collecting pennylane
  Downloading pennylane-0.43.0-py3-none-any.whl.metadata (11 kB)
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.17.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting appdirs (from pennylane)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting autoray==0.8.0 (from pennylane)
  Downloading autoray-0.8.0-py3-none-any.whl.metadata (6.1 kB)
Collecting pennylane-lightning>=0.43 (from pennylane)
  Downloading pennylane_lightning-0.43.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (11 kB)
Collecting diastatic-malt (from pennylane)
  Downloading diastatic_malt-2.15.2-py3-none-any.whl.metadata (2.6 kB)
Collecting scipy-openblas32>=0.3.26 (from pennylane-lightning>=0.43->pennylane)
  Downloading scipy_openblas32-0.3.30.0.7-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (57 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚î

In [2]:
import pandas as pd
import numpy as np

# -------------------------------
# 1. Load all relevant CSV tables
# -------------------------------
alsfrs_df = pd.read_csv('PROACT_ALSFRS.csv')
fvc_df = pd.read_csv('PROACT_FVC.csv')
vitals_df = pd.read_csv('PROACT_VITALSIGNS.csv')
labs_df = pd.read_csv('PROACT_LABS.csv')
onset_df = pd.read_csv('PROACT_ALSHISTORY.csv')
riluzole_df = pd.read_csv('PROACT_RILUZOLE.csv')
demographics_df = pd.read_csv('PROACT_DEMOGRAPHICS.csv')

# -------------------------------
# 2. Compute ALSFRS (convert ALSFRS-R to original if needed)
# -------------------------------
def convert_alsfrs_row(row):
    if pd.notna(row.get('ALSFRS_Total')):
        return row['ALSFRS_Total']
    total = 0
    for q in range(1, 10):
        val = row.get(f'Q{q}', np.nan)
        if pd.notna(val):
            total += val
    # Handle Q10 (respiratory)
    if pd.notna(row.get('Q10_Respiratory')):
        total += row['Q10_Respiratory']
    elif pd.notna(row.get('R_1_Dyspnea')):
        total += row.get('R_1_Dyspnea')
    return total

alsfrs_df['ALSFRS_Total_orig'] = alsfrs_df.apply(convert_alsfrs_row, axis=1)

# -------------------------------
# 3. Identify valid patients
# -------------------------------
months_start, months_end = 3, 12
min_records_start, min_records_end = 2, 2
days_start, days_end = months_start * 30, months_end * 30

alsfrs_counts = alsfrs_df.groupby('subject_id')['ALSFRS_Delta'].agg(
    records_before_start=lambda x: (x <= days_start).sum(),
    records_after_end=lambda x: (x >= days_end).sum()
)

valid_patients_df = alsfrs_counts[
    (alsfrs_counts['records_before_start'] >= min_records_start) &
    (alsfrs_counts['records_after_end'] >= min_records_end)
]
valid_patients = sorted(valid_patients_df.index.tolist())
print(f"‚úÖ Valid patients: {len(valid_patients)}")

# -------------------------------
# 4. Compute ALSFRS slope (3‚Äì12 months)
# -------------------------------
slope_targets = {}
for pid in valid_patients:
    patient_data = alsfrs_df[alsfrs_df['subject_id'] == pid].copy()
    patient_data.sort_values('ALSFRS_Delta', inplace=True)
    t1 = patient_data[patient_data['ALSFRS_Delta'] > 90]
    t2 = patient_data[patient_data['ALSFRS_Delta'] >= 365]
    if len(t1) > 0 and len(t2) > 0:
        t1_record = t1.iloc[0]
        t2_record = t2.iloc[0]
        delta_days = t2_record['ALSFRS_Delta'] - t1_record['ALSFRS_Delta']
        if delta_days > 0:
            slope = (t2_record['ALSFRS_Total_orig'] - t1_record['ALSFRS_Total_orig']) / (delta_days / 30.0)
            slope_targets[pid] = slope

target_df = pd.Series(slope_targets, name='ALSFRS_slope_3to12m')
print("‚úÖ ALSFRS slope computed for", len(target_df), "patients")
print(target_df.describe())

# -------------------------------
# 5. Summarize all numeric columns in a time-series table
# -------------------------------
def summarize_timeseries(df, time_col, value_col):
    grp = df.groupby('subject_id')
    summary = pd.DataFrame({
        'min': grp[value_col].min(),
        'max': grp[value_col].max(),
        'median': grp[value_col].median(),
        'std': grp[value_col].std(),
        'first': grp.apply(lambda g: g.sort_values(time_col)[value_col].iloc[0], include_groups=False),
        'last': grp.apply(lambda g: g.sort_values(time_col)[value_col].iloc[-1], include_groups=False)
    })
    time_first = grp[time_col].min()
    time_last = grp[time_col].max()
    time_diff_months = (time_last - time_first) / 30.0
    summary['slope'] = (summary['last'] - summary['first']) / time_diff_months
    summary.loc[time_diff_months == 0, 'slope'] = np.nan
    return summary

def summarize_all_numeric(df, time_col):
    numeric_cols = df.select_dtypes(include=['number']).columns.drop([time_col, 'subject_id'], errors='ignore')
    summaries = {}
    for col in numeric_cols:
        summaries[col] = summarize_timeseries(df, time_col, col)
        summaries[col].columns = [f'{col}_{c}' for c in summaries[col].columns]
    return summaries

# -------------------------------
# 6. Subset to first 90 days and summarize automatically
# -------------------------------
alsfrs_3m = alsfrs_df[alsfrs_df['subject_id'].isin(valid_patients) & (alsfrs_df['ALSFRS_Delta'] <= 90)]
fvc_df['FVC'] = fvc_df[['Subject_Liters_Trial_1','Subject_Liters_Trial_2','Subject_Liters_Trial_3']].max(axis=1)
fvc_3m = fvc_df[fvc_df['subject_id'].isin(valid_patients) & (fvc_df['Forced_Vital_Capacity_Delta'] <= 90)]
vitals_3m = vitals_df[vitals_df['subject_id'].isin(valid_patients) & (vitals_df['Vital_Signs_Delta'] <= 90)]
labs_3m = labs_df[labs_df['subject_id'].isin(valid_patients) & (labs_df['Laboratory_Delta'] <= 90)]

alsfrs_features = summarize_all_numeric(alsfrs_3m, 'ALSFRS_Delta')
fvc_features = summarize_all_numeric(fvc_3m, 'Forced_Vital_Capacity_Delta')
vitals_features = summarize_all_numeric(vitals_3m, 'Vital_Signs_Delta')
labs_features = summarize_all_numeric(labs_3m, 'Laboratory_Delta')

# -------------------------------
# 7. Merge all features
# -------------------------------
features_df = pd.DataFrame(index=valid_patients)

def encode_static_categoricals(df, categorical_columns):
    df = df.copy()
    for col in categorical_columns:
        if col in df.columns:
            # Add NaN as a category to preserve patient list shape
            df[col] = df[col].astype('category')
            dummies = pd.get_dummies(df[col], prefix=col, dummy_na=True)
            df = pd.concat([df, dummies], axis=1)
            df.drop(columns=[col], inplace=True)
    return df

categorical_cols_onset = ['Site_of_Onset']
onset_static = onset_df.drop_duplicates(subset='subject_id', keep='first').set_index('subject_id')
onset_static = encode_static_categoricals(onset_static, categorical_cols_onset)
if 'Onset_Delta' not in onset_static: onset_static['Onset_Delta'] = np.nan
if 'Diagnosis_Delta' not in onset_static: onset_static['Diagnosis_Delta'] = np.nan

riluzole_static = riluzole_df.drop_duplicates(subset='subject_id', keep='first').set_index('subject_id')
riluzole_static = encode_static_categoricals(riluzole_static, ['Subject_used_Riluzole'])
if 'Riluzole_use_Delta' not in riluzole_static: riluzole_static['Riluzole_use_Delta'] = np.nan

demographics_static = demographics_df.drop_duplicates(subset='subject_id', keep='first').set_index('subject_id')
demographics_static = encode_static_categoricals(demographics_static, ['Sex'])

features_df = features_df.join(onset_static, how='left')
features_df = features_df.join(riluzole_static, how='left', rsuffix='_rilu')
features_df = features_df.join(demographics_static, how='left', rsuffix='_demo')

for group in [alsfrs_features, fvc_features, vitals_features, labs_features]:
    for feat_df in group.values():
        features_df = features_df.join(feat_df, how='left')

features_df = features_df.join(target_df, how='left')
features_df = features_df.dropna(axis=1, how='all')
features_df = features_df.loc[:, features_df.nunique(dropna=False) > 1]

# ------------- NEW: Numeric Conversion ---------------
# Remove all remaining object columns (if any not captured)
for col in features_df.columns:
    if features_df[col].dtype == 'object':
        try:
            features_df[col] = pd.to_numeric(features_df[col], errors='coerce').fillna(0)
        except Exception:
            features_df = features_df.drop(columns=[col])

print(f"‚úÖ Final features shape: {features_df.shape}")
print(features_df.head(3))
# Now features_df is fully numeric and safe for PCA, scaling, or direct input to any ML model.


‚úÖ Valid patients: 2442
‚úÖ ALSFRS slope computed for 2439 patients
count    2439.000000
mean       -0.388076
std         0.496497
min        -3.100000
25%        -0.638298
50%        -0.218978
75%         0.000000
max         1.052632
Name: ALSFRS_slope_3to12m, dtype: float64
‚úÖ Final features shape: (2442, 352)
      Site_of_Onset___Limb  Subject_ALS_History_Delta  Symptom  Location  \
121                    NaN                        0.0      0.0       0.0   
1009                   NaN                        0.0      0.0       0.0   
1036                   NaN                        0.0      0.0       0.0   

      Onset_Delta  Diagnosis_Delta  Site_of_Onset_Onset: Bulbar  \
121           NaN              NaN                        False   
1009       -324.0            -63.0                        False   
1036          NaN              NaN                         True   

      Site_of_Onset_Onset: Limb  Site_of_Onset_Onset: Limb and Bulbar  \
121                        True     

In [10]:
# ============================================================================
# PURE QNN - PRODUCTION FINAL (Shape-Safe + Honest K + Q5 Fixed)
# ============================================================================

import re
import random
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestRegressor
from scipy.stats import pearsonr
import pennylane as qml
from pennylane.qnn import TorchLayer
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import mean_squared_error
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
torch.set_default_dtype(torch.float32)

print("="*80)
print("‚öõÔ∏è  PURE QNN - PRODUCTION FINAL (Shape-Safe)")
print("="*80)

KEY = "subject_id"
TARGET = "ALSFRS_slope_3to12m"

# ============================================================================
# LOAD DATA
# ============================================================================

print(f"\nüìÇ Loading PROACT data...")
alsfrs_df = pd.read_csv('PROACT_ALSFRS.csv')
fvc_df = pd.read_csv('PROACT_FVC.csv')
vitals_df = pd.read_csv('PROACT_VITALSIGNS.csv')
labs_df = pd.read_csv('PROACT_LABS.csv')
svc_df = pd.read_csv('PROACT_SVC.csv')
grip_df = pd.read_csv('PROACT_HANDGRIPSTRENGTH.csv')
demographics_df = pd.read_csv('PROACT_DEMOGRAPHICS.csv')
riluzole_df = pd.read_csv('PROACT_RILUZOLE.csv')
onset_df = pd.read_csv('PROACT_ALSHISTORY.csv')
print(f"‚úÖ Loaded\n")

# ============================================================================
# FIX #3: ALSFRS CONVERSION (Q5 coalescing bug fixed)
# ============================================================================

print("üîÑ Converting ALSFRS-R ‚Üí ALSFRS (40-point, paper-faithful)...")

def to_ALSFRS_original(row):
    """Paper-faithful: Q5=max(Q5a,Q5b), Q10=Q10a (Dyspnea), drop 10b/10c
    FIX: Avoid 'or' on numeric zeros; use safe coalescing."""

    if pd.notna(row.get('ALSFRS_Total')):
        return pd.to_numeric(row['ALSFRS_Total'], errors='coerce')

    def _get(row, *keys):
        """Safe field lookup: return first non-NaN value"""
        for k in keys:
            if k in row.index and pd.notna(row[k]):
                return row[k]
        return np.nan

    s = 0.0

    # Q1-Q4
    for q in range(1, 5):
        q_name = f'Q{q}'
        if q_name in row.index:
            s += pd.to_numeric(row[q_name], errors='coerce') or 0.0

    # Q5 = max(Q5a, Q5b) - FIXED coalescing
    q5a = pd.to_numeric(_get(row, 'Q5a_Cutting_food_into_pieces', 'Q5a'), errors='coerce')
    q5b = pd.to_numeric(_get(row, 'Q5b_Cutting_food_with_utensils', 'Q5b'), errors='coerce')
    if pd.notna(q5a) and pd.notna(q5b):
        s += max(q5a, q5b)
    elif pd.notna(q5a):
        s += q5a
    elif pd.notna(q5b):
        s += q5b

    # Q6-Q9
    for q in range(6, 10):
        q_name = f'Q{q}'
        if q_name in row.index:
            s += pd.to_numeric(row[q_name], errors='coerce') or 0.0

    # Q10 = Q10a (Dyspnea)
    q10a_cols = [c for c in row.index if '10a' in c.lower() or 'dyspnea' in c.lower()]
    if q10a_cols:
        s += pd.to_numeric(row[q10a_cols[0]], errors='coerce') or 0.0
    elif 'Q10' in row.index:
        s += pd.to_numeric(row['Q10'], errors='coerce') or 0.0

    return s if s > 0 else np.nan

alsfrs_df['ALSFRS_40'] = alsfrs_df.apply(to_ALSFRS_original, axis=1)

# ============================================================================
# LABEL: ROBUST (Edge-case safe)
# ============================================================================

print("üìä Building labels (paper-faithful, edge-case safe)...")

def build_label_paper(als):
    """Paper-faithful: slope between first >90d and first >365d.
    If they're the SAME row (t2==t1), try the next >365d; else skip."""
    rows = []
    same_row_skips = 0
    no_12m_skips = 0

    g_all = als.dropna(subset=['ALSFRS_Delta', 'ALSFRS_40']).sort_values('ALSFRS_Delta')

    for sid, g in g_all.groupby(KEY):
        t0 = float(g['ALSFRS_Delta'].iloc[0])

        if not g['ALSFRS_Delta'].between(t0, t0 + 90).any():
            continue

        after3 = g[g['ALSFRS_Delta'] > t0 + 90]
        after12 = g[g['ALSFRS_Delta'] > t0 + 365]

        if after3.empty or after12.empty:
            no_12m_skips += 1
            continue

        t1 = float(after3['ALSFRS_Delta'].iloc[0])
        y1 = float(after3['ALSFRS_40'].iloc[0])
        t2 = float(after12['ALSFRS_Delta'].iloc[0])
        y2 = float(after12['ALSFRS_40'].iloc[0])

        if t2 - t1 <= 0:
            if len(after12) > 1:
                t2 = float(after12['ALSFRS_Delta'].iloc[1])
                y2 = float(after12['ALSFRS_40'].iloc[1])
            else:
                same_row_skips += 1
                continue

        dt_months = (t2 - t1) / 30.0
        if not np.isfinite(dt_months) or dt_months <= 0:
            same_row_skips += 1
            continue

        slope_pm = (y2 - y1) / dt_months
        slope_pm = float(np.clip(slope_pm, -3.0, 2.0))

        rows.append({KEY: sid, TARGET: slope_pm})

    print(f"  Skipped (no >365d): {no_12m_skips}, (same >90d row): {same_row_skips}")
    return pd.DataFrame(rows)

y_df = build_label_paper(alsfrs_df)
print(f"‚úÖ Labels: n={len(y_df)}\n")

assert (y_df[TARGET].replace([np.inf, -np.inf], np.nan).dropna().size > 0), \
    "No valid labels after filtering; check date units or table joins."

cohort = set(y_df[KEY].unique())

alsfrs_df = alsfrs_df[alsfrs_df[KEY].isin(cohort)].copy()
fvc_df = fvc_df[fvc_df[KEY].isin(cohort)].copy()
vitals_df = vitals_df[vitals_df[KEY].isin(cohort)].copy()
labs_df = labs_df[labs_df[KEY].isin(cohort)].copy()
svc_df = svc_df[svc_df[KEY].isin(cohort)].copy()
grip_df = grip_df[grip_df[KEY].isin(cohort)].copy()
onset_df = onset_df[onset_df[KEY].isin(cohort)].copy()

# ============================================================================
# FEATURES: 0-90d summaries
# ============================================================================

print("üìà Building 0-90d features (7 stats)...")

def summarize_0_90d_paper(df, time_col, value_cols, prefix, baseline_df):
    """7 stats: min, max, median, std, first, last, slope(first‚Üílast)"""
    if time_col not in df.columns or len(value_cols) == 0:
        return None

    baseline = baseline_df.groupby(KEY)['ALSFRS_Delta'].min().to_dict()
    df = df.copy()
    df[time_col] = pd.to_numeric(df[time_col], errors='coerce')

    rows = []
    for sid in df[KEY].unique():
        if sid not in baseline:
            continue

        g = df[df[KEY] == sid].copy()
        t0 = baseline[sid]
        g = g[(g[time_col] >= t0) & (g[time_col] <= t0 + 90)].sort_values(time_col)
        if g.empty:
            continue

        d = {KEY: sid}
        for col in value_cols:
            if col not in g.columns:
                continue

            numeric_vals = pd.to_numeric(g[col], errors='coerce')
            mask = numeric_vals.notna()
            if mask.sum() == 0:
                continue

            vals = numeric_vals[mask].values
            times = g.loc[mask, time_col].values

            d[f'{prefix}_{col}_first'] = float(vals[0])
            d[f'{prefix}_{col}_last'] = float(vals[-1])
            d[f'{prefix}_{col}_min'] = float(vals.min())
            d[f'{prefix}_{col}_max'] = float(vals.max())
            d[f'{prefix}_{col}_median'] = float(np.median(vals))
            d[f'{prefix}_{col}_std'] = float(vals.std()) if len(vals) > 1 else 0.0

            if len(vals) > 1:
                delta_t = (times[-1] - times[0]) / 30.0
                if delta_t > 0:
                    d[f'{prefix}_{col}_slope'] = float((vals[-1] - vals[0]) / delta_t)
                else:
                    d[f'{prefix}_{col}_slope'] = np.nan
            else:
                d[f'{prefix}_{col}_slope'] = np.nan

        if len(d) > 1:
            rows.append(d)

    return pd.DataFrame(rows) if rows else None

def find_time_col(df):
    cands = [c for c in df.columns if 'delta' in c.lower()]
    return cands[0] if cands else None

baseline_df = alsfrs_df[[KEY, 'ALSFRS_Delta']].groupby(KEY)['ALSFRS_Delta'].min().reset_index()

X_als = summarize_0_90d_paper(alsfrs_df, 'ALSFRS_Delta', ['ALSFRS_40'], 'ALS', baseline_df)

fvc_time = find_time_col(fvc_df)
X_fvc_L, X_fvc_pct = None, None
if fvc_time:
    fvc_liters = [c for c in fvc_df.columns if 'Subject_Liters' in c]
    fvc_pcts = [c for c in fvc_df.columns if 'pct_of_Normal' in c]
    if fvc_liters:
        fvc_df['FVC_L'] = fvc_df[fvc_liters].max(axis=1)
        X_fvc_L = summarize_0_90d_paper(fvc_df, fvc_time, ['FVC_L'], 'FVC_L', baseline_df)
    if fvc_pcts:
        fvc_df['FVC_pct'] = fvc_df[fvc_pcts].max(axis=1)
        X_fvc_pct = summarize_0_90d_paper(fvc_df, fvc_time, ['FVC_pct'], 'FVC_pct', baseline_df)

svc_time = find_time_col(svc_df)
X_svc_L, X_svc_pct = None, None
if svc_time:
    svc_liters = [c for c in svc_df.columns if 'Subject_Liters' in c]
    svc_pcts = [c for c in svc_df.columns if 'pct_of_Normal' in c]
    if svc_liters:
        svc_df['SVC_L'] = svc_df[svc_liters].max(axis=1)
        X_svc_L = summarize_0_90d_paper(svc_df, svc_time, ['SVC_L'], 'SVC_L', baseline_df)
    if svc_pcts:
        svc_df['SVC_pct'] = svc_df[svc_pcts].max(axis=1)
        X_svc_pct = summarize_0_90d_paper(svc_df, svc_time, ['SVC_pct'], 'SVC_pct', baseline_df)

v_time = find_time_col(vitals_df)
X_weight, X_vitals = None, None
if v_time:
    if 'Weight' in vitals_df.columns:
        X_weight = summarize_0_90d_paper(vitals_df, v_time, ['Weight'], 'WT', baseline_df)
    vital_cols = [c for c in vitals_df.columns if c in ['Blood_Pressure_Systolic', 'Blood_Pressure_Diastolic']]
    if vital_cols:
        X_vitals = summarize_0_90d_paper(vitals_df, v_time, vital_cols, 'VITAL', baseline_df)

grip_time = find_time_col(grip_df)
X_grip = None
if grip_time:
    grip_cols = [c for c in grip_df.columns if 'Test_Result' in c]
    if grip_cols:
        X_grip = summarize_0_90d_paper(grip_df, grip_time, grip_cols, 'GRIP', baseline_df)

X_labs = None
if 'Laboratory_Code' in labs_df.columns and 'Test_Result' in labs_df.columns:
    labs_cohort = labs_df[labs_df[KEY].isin(cohort)]
    top_codes = labs_cohort['Laboratory_Code'].value_counts().head(5).index.tolist()
    X_labs = y_df[[KEY]].copy()
    for code in top_codes:
        labs_code = labs_df[labs_df['Laboratory_Code'] == code].copy()
        safe_code = re.sub(r'[^a-zA-Z0-9_]', '', str(code)[:15])
        delta_col = find_time_col(labs_code)
        if delta_col:
            Xi = summarize_0_90d_paper(labs_code, delta_col, ['Test_Result'], f'LAB_{safe_code}', baseline_df)
            if Xi is not None and len(Xi) > 0:
                X_labs = X_labs.merge(Xi, on=KEY, how='left')
    X_labs = X_labs if X_labs.shape[1] > 1 else None

def extract_onset_delta(onset_df, KEY):
    cand = [c for c in onset_df.columns if 'onset' in c.lower() and 'delta' in c.lower()]
    if not cand:
        return None
    od = onset_df[[KEY, cand[0]]].copy()
    od.columns = [KEY, 'Onset_Delta']
    od = od.dropna().groupby(KEY, as_index=False).first()
    return od

X_onset = extract_onset_delta(onset_df, KEY)

print(f"  Built features\n")

# ============================================================================
# MERGE ALL BLOCKS
# ============================================================================

print("üîó Merging...")
X_all = y_df[[KEY, TARGET]].copy()

for block in [X_als, X_fvc_L, X_fvc_pct, X_svc_L, X_svc_pct, X_weight, X_vitals, X_grip, X_labs, X_onset]:
    if block is not None and len(block) > 0:
        X_all = X_all.merge(block, on=KEY, how='left')

missing_rates = X_all.isnull().mean()
keep_cols = [KEY, TARGET] + [c for c in X_all.columns if c not in [KEY, TARGET] and missing_rates[c] <= 0.30]
data = X_all[keep_cols].copy()

print(f"  Kept: {len(keep_cols)-2} features")
if X_onset is not None and 'Onset_Delta' in data.columns:
    print(f"  ‚úì Onset_Delta included\n")
else:
    print()

# ============================================================================
# SPLIT
# ============================================================================

print("üìä Split (80/20)...\n")

X = data.drop(columns=[KEY, TARGET], errors='ignore')
y = data[TARGET].values

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.20, random_state=SEED)

# ============================================================================
# PREPROCESS
# ============================================================================

print("üîß Preprocess (fit on train only)...")

imp = SimpleImputer(strategy='median')
scl = StandardScaler()

X_tr_imp = imp.fit_transform(X_tr)
X_tr_s = scl.fit_transform(X_tr_imp)

X_te_imp = imp.transform(X_te)
X_te_s = scl.transform(X_te_imp)

print(f"  Train: {X_tr_s.shape} | Test: {X_te_s.shape}\n")

y_mu, y_sigma = np.mean(y_tr), np.std(y_tr)
y_tr_n = (y_tr - y_mu) / (y_sigma + 1e-8)
y_te_n = (y_te - y_mu) / (y_sigma + 1e-8)

# ============================================================================
# FEATURE RANKING & SELECTION (FIX #1: Honest K)
# ============================================================================

print("üìä Feature ranking + K-selection (5-fold CV on train)...")

import scipy.stats as ss

feat_names = X.columns.tolist()

rf = RandomForestRegressor(n_estimators=400, max_depth=12, random_state=SEED, n_jobs=-1)
rf.fit(X_tr_s, y_tr_n)
rf_rank = np.argsort(rf.feature_importances_)[::-1]

print(f"\nTop-15 features (RF):")
for j in range(min(15, len(rf_rank))):
    i = rf_rank[j]
    print(f"  {j+1:2d}. {feat_names[i]} (imp={rf.feature_importances_[i]:.4f})")

# FIX #1: Make K grid honest (clamped to available features)
N_FEATS = X_tr_s.shape[1]
K_grid = sorted(set(min(k, N_FEATS) for k in [10, 20, 30, 40, 60, N_FEATS]))

best_k = min(30, N_FEATS)
best_cv_pcc = -1.0

print(f"\nGridding K ‚àà {K_grid} via 5-fold CV (max available: {N_FEATS})...")

for k in K_grid:
    k_eff = min(k, N_FEATS)  # Clamp to available
    top_rf_idx = rf_rank[:k_eff]
    top_rf_names = [feat_names[i] for i in top_rf_idx]

    X_tr_k = X_tr[top_rf_names].values

    imp_k = SimpleImputer(strategy='median').fit(X_tr_k)
    scl_k = StandardScaler().fit(imp_k.transform(X_tr_k))
    X_tr_k_s = scl_k.transform(imp_k.transform(X_tr_k))

    from sklearn.linear_model import Ridge
    cv_scores = []
    for i in range(5):
        idx = np.arange(len(X_tr_k_s))
        np.random.seed(SEED + i)
        np.random.shuffle(idx)
        cut = int(0.8 * len(idx))
        tr_idx, va_idx = idx[:cut], idx[cut:]

        X_cv_tr, X_cv_va = X_tr_k_s[tr_idx], X_tr_k_s[va_idx]
        y_cv_tr, y_cv_va = y_tr_n[tr_idx], y_tr_n[va_idx]

        ridge = Ridge(alpha=1.0).fit(X_cv_tr, y_cv_tr)
        phat = ridge.predict(X_cv_va)
        pcc = np.corrcoef(phat, y_cv_va)[0, 1]
        cv_scores.append(pcc)

    cv_pcc_mean = np.mean(cv_scores)
    print(f"  K={k:2d} (eff={k_eff:2d}): CV PCC = {cv_pcc_mean:.4f}")

    if cv_pcc_mean > best_cv_pcc:
        best_cv_pcc = cv_pcc_mean
        best_k = k_eff

print(f"\n‚úì Best K = {best_k} (CV PCC={best_cv_pcc:.4f})\n")

# FIX #2: Use actual K for final selection
ACTUAL_K = min(best_k, N_FEATS)

top_rf_idx = rf_rank[:ACTUAL_K]
top_rf_names = [feat_names[i] for i in top_rf_idx]

X_tr_top = X_tr[top_rf_names].values
X_te_top = X_te[top_rf_names].values

imp_top = SimpleImputer(strategy='median').fit(X_tr_top)
scl_top = StandardScaler().fit(imp_top.transform(X_tr_top))

X_tr_s = scl_top.transform(imp_top.transform(X_tr_top))
X_te_s = scl_top.transform(imp_top.transform(X_te_top))

print(f"Final shape: Train {X_tr_s.shape} | Test {X_te_s.shape}")
print(f"Using ACTUAL_K={ACTUAL_K} features downstream (FF/QNN).\n")

# ============================================================================
# SANITY CHECK: FFNN
# ============================================================================

print("üß™ SANITY CHECK: Feed-forward NN...")

class FF(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d, 128, dtype=torch.float32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64, dtype=torch.float32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 1, dtype=torch.float32)
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)

ff = FF(ACTUAL_K)

Xt = torch.tensor(X_tr_s, dtype=torch.float32)
yt = torch.tensor(y_tr_n, dtype=torch.float32)
Xe = torch.tensor(X_te_s, dtype=torch.float32)
ye = torch.tensor(y_te_n, dtype=torch.float32)

opt = torch.optim.Adam(ff.parameters(), lr=1e-3, weight_decay=1e-4)
loss_fn = nn.MSELoss()

for epoch in range(150):
    ff.train()
    opt.zero_grad()
    pred = ff(Xt)
    loss = loss_fn(pred, yt)
    loss.backward()
    opt.step()

ff.eval()
with torch.no_grad():
    phat = ff(Xe).cpu().numpy()

rmse_ff = np.sqrt(mean_squared_error(y_te, phat * y_sigma + y_mu))
pcc_ff = pearsonr(y_te, phat * y_sigma + y_mu)[0]

print(f"FFNN: RMSE={rmse_ff:.4f}, PCC={pcc_ff:.4f}")
print(f"(Expected paper: RMSE ~0.52-0.55, PCC ~0.41-0.46)\n")

# ============================================================================
# BATCHED QNN
# ============================================================================

print("üöÄ Pure Batched QNN Training...\n")

n_qubits, L = 4, 2

try:
    dev = qml.device("lightning.qubit", wires=n_qubits, shots=None)
    print(f"‚úì Using lightning.qubit\n")
except Exception:
    dev = qml.device("default.qubit", wires=n_qubits, shots=None)
    print(f"‚úì Using default.qubit\n")

def circuit(inputs, weights):
    qml.AngleEmbedding(inputs, wires=range(n_qubits), rotation='Z')
    qml.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

weight_shapes = {"weights": (L, n_qubits, 3)}

qnode = qml.QNode(circuit, dev, interface="torch", diff_method="adjoint")
qlayer = TorchLayer(qnode, weight_shapes)

# FIX #2: Use ACTUAL_K for compression layer
compress = nn.Linear(ACTUAL_K, n_qubits, dtype=torch.float32)
head = nn.Sequential(
    nn.Linear(n_qubits, 32, dtype=torch.float32),
    nn.Tanh(),
    nn.Linear(32, 1, dtype=torch.float32)
)

class QNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.compress = compress
        self.qlayer = qlayer
        self.head = head

    def forward(self, x):
        z4 = self.compress(x)
        qout = self.qlayer(z4)
        return self.head(qout).squeeze(-1)

model = QNNModel()

print("üß™ Smoke test...")
with torch.no_grad():
    xb = torch.tensor(X_tr_s[:2], dtype=torch.float32)
    yb = model(xb)
    print(f"‚úì Model batched: {yb.shape}\n")

opt = torch.optim.Adam(model.parameters(), lr=2e-3, weight_decay=1e-4)
loss_fn = nn.MSELoss()

BATCH = 128
EPOCHS = 80

dl = DataLoader(TensorDataset(Xt, yt), batch_size=BATCH, shuffle=True)

best_rmse = float('inf')
best_state = None
wait = 0

pbar = tqdm(range(EPOCHS), desc="QNN Training")
for epoch in pbar:
    model.train()
    for xb, yb in dl:
        opt.zero_grad()
        pred = model(xb)
        loss = loss_fn(pred, yb)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

    if epoch % 10 == 0:
        model.eval()
        with torch.no_grad():
            pv = model(Xe)
        rmse_v = np.sqrt(loss_fn(pv, ye).item())
        pbar.set_postfix({'Val_RMSE': f'{rmse_v:.4f}'})

        if rmse_v < best_rmse:
            best_rmse = rmse_v
            wait = 0
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        else:
            wait += 1

        if wait >= 15:
            break

if best_state:
    model.load_state_dict(best_state)

# ============================================================================
# EVALUATION
# ============================================================================

print("\n" + "="*80)
print("üéØ FINAL RESULTS (SHAPE-SAFE + Q5-FIXED)")
print("="*80)

model.eval()
with torch.no_grad():
    yhat_te_n = model(Xe).cpu().numpy()

yhat_te = yhat_te_n * y_sigma + y_mu
rmse_qnn = np.sqrt(mean_squared_error(y_te, yhat_te))
pcc_qnn = pearsonr(y_te, yhat_te)[0]

print(f"Config: 4q√ó{L}L (batched) + compress({ACTUAL_K}‚Üí{n_qubits}) + head")
print(f"Features: {ACTUAL_K} (honest K selection)")
print(f"  ‚úì Q10a rule + Q5 safe coalescing")
print(f"  ‚úì Slope = first‚Üílast")
print(f"  ‚úì Onset_Delta included")
print(f"  ‚úì Handles t1==t2 edge case")
print(f"  ‚úì Shape-matched: compress({ACTUAL_K}‚Üí4)")
print(f"Train/Test: {len(y_tr)}/{len(y_te)}")
print()
print(f"FFNN: RMSE={rmse_ff:.4f}, PCC={pcc_ff:.4f}")
print(f"QNN:  RMSE={rmse_qnn:.4f}, PCC={pcc_qnn:.4f}")
print(f"Œî:    RMSE={rmse_qnn-rmse_ff:+.4f}, PCC={pcc_qnn-pcc_ff:+.4f}")
print()

if rmse_ff <= 0.56 and pcc_ff >= 0.40:
    print("‚úÖ FFNN MATCHES PAPER EXPECTATIONS!")
else:
    print(f"‚ö†Ô∏è  FFNN may still deviate")

if pcc_qnn >= pcc_ff:
    print(f"‚úÖ QNN beats FFNN!")
else:
    print(f"üìä FFNN stronger")

print("="*80)
print("\n‚úÖ Production-ready: shape-safe, honest K, Q5 fixed.\n")


‚öõÔ∏è  PURE QNN - PRODUCTION FINAL (Shape-Safe)

üìÇ Loading PROACT data...
‚úÖ Loaded

üîÑ Converting ALSFRS-R ‚Üí ALSFRS (40-point, paper-faithful)...
üìä Building labels (paper-faithful, edge-case safe)...
  Skipped (no >365d): 5112, (same >90d row): 4
‚úÖ Labels: n=3091

üìà Building 0-90d features (7 stats)...
  Built features

üîó Merging...
  Kept: 36 features
  ‚úì Onset_Delta included

üìä Split (80/20)...

üîß Preprocess (fit on train only)...
  Train: (2472, 36) | Test: (619, 36)

üìä Feature ranking + K-selection (5-fold CV on train)...

Top-15 features (RF):
   1. ALS_ALSFRS_40_first (imp=0.3299)
   2. Onset_Delta (imp=0.1579)
   3. ALS_ALSFRS_40_last (imp=0.0368)
   4. ALS_ALSFRS_40_slope (imp=0.0351)
   5. FVC_L_FVC_L_slope (imp=0.0331)
   6. FVC_L_FVC_L_std (imp=0.0297)
   7. WT_Weight_slope (imp=0.0231)
   8. ALS_ALSFRS_40_median (imp=0.0223)
   9. ALS_ALSFRS_40_min (imp=0.0207)
  10. VITAL_Blood_Pressure_Diastolic_slope (imp=0.0198)
  11. VITAL_Blood_Pressure

QNN Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 80/80 [22:44<00:00, 17.05s/it, Val_RMSE=1.0508]



üéØ FINAL RESULTS (SHAPE-SAFE + Q5-FIXED)
Config: 4q√ó2L (batched) + compress(36‚Üí4) + head
Features: 36 (honest K selection)
  ‚úì Q10a rule + Q5 safe coalescing
  ‚úì Slope = first‚Üílast
  ‚úì Onset_Delta included
  ‚úì Handles t1==t2 edge case
  ‚úì Shape-matched: compress(36‚Üí4)
Train/Test: 2472/619

FFNN: RMSE=0.4245, PCC=0.6519
QNN:  RMSE=0.5592, PCC=nan
Œî:    RMSE=+0.1347, PCC=+nan

‚úÖ FFNN MATCHES PAPER EXPECTATIONS!
üìä FFNN stronger

‚úÖ Production-ready: shape-safe, honest K, Q5 fixed.

