<a href="https://colab.research.google.com/github/tousifo/ml_notebooks/blob/main/ALS_QNN_PRO_ACT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
%pip install qiskit~=1.0 qiskit-machine-learning~=0.8.1 qiskit_algorithms

# Qiskit Imports
from qiskit.circuit.library import ZZFeatureMap, RealAmplitudes
from qiskit_algorithms.optimizers import COBYLA
from qiskit_machine_learning.algorithms.regressors import VQR
from qiskit.primitives import Sampler

Collecting qiskit~=1.0
  Downloading qiskit-1.4.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting qiskit-machine-learning~=0.8.1
  Downloading qiskit_machine_learning-0.8.4-py3-none-any.whl.metadata (13 kB)
Collecting qiskit_algorithms
  Downloading qiskit_algorithms-0.4.0-py3-none-any.whl.metadata (4.7 kB)
Collecting rustworkx>=0.15.0 (from qiskit~=1.0)
  Downloading rustworkx-0.17.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting stevedore>=3.0.0 (from qiskit~=1.0)
  Downloading stevedore-5.5.0-py3-none-any.whl.metadata (2.2 kB)
Collecting symengine<0.14,>=0.11 (from qiskit~=1.0)
  Downloading symengine-0.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting scipy>=1.5 (from qiskit~=1.0)
  Downloading scipy-1.15.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31

In [None]:
import pandas as pd
import numpy as np
import warnings
from sklearn.preprocessing import LabelEncoder
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestRegressor
warnings.filterwarnings('ignore')
class ALSDataProcessor:
    """
    A robust class to load, clean, and process PRO-ACT data for predicting ALSFRS slope,
    replicating the methodology from the "Deep learning methods to predict amyotrophic
    lateral sclerosis disease progression" paper.
    """
    def __init__(self):
        self.label_encoders = {}
        # A list of columns to exclude from feature engineering
        self.id_and_delta_cols = [
            'subject_id', 'alsfrs_delta', 'fvc_delta', 'vitals_delta',
            'labs_delta', 'grip_delta', 'muscle_delta', 'onset_delta',
            'death_delta', 'history_delta'
        ]
    def _convert_alsfrs_r(self, alsfrs_df):
        """Convert ALSFRS-R questions to the original ALSFRS format."""
        df = alsfrs_df.copy()
        # Ensure ALSFRS_Total is numeric, coercing errors
        df['ALSFRS_Total'] = pd.to_numeric(df['ALSFRS_Total'], errors='coerce')
        return df
    def load_and_inspect_data(self, file_path=''):
        """Load all datasets and inspect their structure."""
        datasets = {}
        file_list = [
            'PROACT_ALSFRS.csv', 'PROACT_FVC.csv', 'PROACT_VITALSIGNS.csv',
            'PROACT_RILUZOLE.csv', 'PROACT_DEMOGRAPHICS.csv', 'PROACT_LABS.csv',
            'PROACT_DEATHDATA.csv', 'PROACT_HANDGRIPSTRENGTH.csv',
            'PROACT_MUSCLESTRENGTH.csv', 'PROACT_ALSHISTORY.csv' # Added missing file
        ]
        print("--- Loading and Inspecting Data ---")
        for file_name in file_list:
            try:
                df = pd.read_csv(file_path + file_name, on_bad_lines='skip')
                # --- CORRECTED RENAMING LOGIC ---
                # Check if 'subject_id' already exists. If not, find a candidate and rename only the first one found.
                if 'subject_id' not in df.columns:
                    potential_id_cols = [col for col in df.columns if 'subject' in col.lower()]
                    if potential_id_cols:
                        df.rename(columns={potential_id_cols[0]: 'subject_id'}, inplace=True)
                # --- END CORRECTION ---
                # Convert delta columns to numeric
                delta_cols = [col for col in df.columns if 'delta' in col.lower()]
                for col in delta_cols:
                    df[col] = pd.to_numeric(df[col], errors='coerce')
                datasets[file_name] = df
                print(f"✓ {file_name}: Loaded successfully with shape {df.shape}")
            except FileNotFoundError:
                print(f"✗ {file_name}: File not found. Will be skipped.")
        return datasets
    def calculate_alsfrs_slope(self, alsfrs_df):
        """Calculate the primary target variable: ALSFRS slope between months 3-12."""
        df = alsfrs_df.copy()
        df.rename(columns={c:'alsfrs_delta' for c in df.columns if 'delta' in c.lower()}, inplace=True)
        df['months'] = df['alsfrs_delta'] / 30.44
        df.sort_values(['subject_id', 'months'], inplace=True)
        slopes = {}
        for subject_id, subject_data in df.groupby('subject_id'):
            t1_candidates = subject_data[(subject_data['months'] > 3) & (subject_data['months'] <= 12)]
            t2_candidates = subject_data[subject_data['months'] >= 12]
            if not t1_candidates.empty and not t2_candidates.empty:
                t1_row = t1_candidates.iloc[0]
                t2_row = t2_candidates.iloc[0]
                t1, alsfrs_t1 = t1_row['months'], t1_row['ALSFRS_Total']
                t2, alsfrs_t2 = t2_row['months'], t2_row['ALSFRS_Total']
                if t2 > t1 and pd.notna(alsfrs_t1) and pd.notna(alsfrs_t2):
                    slope = (alsfrs_t2 - alsfrs_t1) / (t2 - t1)
                    slopes[subject_id] = slope
        return pd.DataFrame(list(slopes.items()), columns=['subject_id', 'alsfrs_slope'])
    def create_longitudinal_features(self, df, time_col, prefix):
        """Create the seven summary statistics from longitudinal data (first 3 months)."""
        df_sorted = df.sort_values(['subject_id', time_col])
        # Convert potential value columns to numeric
        potential_value_cols = [col for col in df_sorted.columns if col not in ['subject_id', time_col]]
        for col in potential_value_cols:
            df_sorted[col] = pd.to_numeric(df_sorted[col], errors='coerce')
        df_filtered = df_sorted[df_sorted[time_col] <= 90].copy()
        value_cols = [col for col in df_filtered.select_dtypes(include=np.number).columns
                      if col.lower() not in self.id_and_delta_cols]
        if not value_cols:
            return pd.DataFrame()
        summary_dfs = []
        for value_col in value_cols:
            grouped = df_filtered.groupby('subject_id')
            summary = grouped[value_col].agg(['min', 'max', 'median', 'first', 'last']).join(
                grouped[value_col].std(ddof=0).rename('std')
            )
            # Ensure there are at least two data points for slope calculation
            slope_df = grouped.apply(
                lambda g: (g[value_col].iloc[-1] - g[value_col].iloc[0]) / (g[time_col].iloc[-1] - g[time_col].iloc[0])
                if len(g) > 1 and (g[time_col].iloc[-1] - g[time_col].iloc[0]) > 0 else np.nan
            ).rename('slope')
            summary = summary.join(slope_df).fillna(0) # Fill NaN slopes with 0
            summary.columns = [f"{prefix}{value_col}_{stat}" for stat in summary.columns]
            summary_dfs.append(summary)
        return pd.concat(summary_dfs, axis=1).reset_index()
    def process_static_data(self, df):
        """Process static data files (like demographics, riluzole)."""
        processed = df.copy()
        for col in processed.select_dtypes(include=['object', 'category']).columns:
            if col != 'subject_id':
                le = self.label_encoders.setdefault(col, LabelEncoder())
                processed[col] = le.fit_transform(processed[col].astype(str))
        return processed.drop_duplicates(subset=['subject_id'])
    def merge_all_features(self, datasets):
        """Merge all static and longitudinal features into a single dataframe."""
        if 'PROACT_DEMOGRAPHICS.csv' not in datasets:
            raise ValueError("Demographics file is missing.")
        final_df = self.process_static_data(datasets['PROACT_DEMOGRAPHICS.csv'])
        static_files = ['PROACT_RILUZOLE.csv', 'PROACT_ALSHISTORY.csv'] # Added ALSHISTORY
        for file in static_files:
            if file in datasets:
                static_df = self.process_static_data(datasets[file])
                final_df = pd.merge(final_df, static_df, on='subject_id', how='left')
        longitudinal_configs = {
            'PROACT_ALSFRS.csv': 'alsfrs_',
            'PROACT_FVC.csv': 'fvc_',
            'PROACT_VITALSIGNS.csv': 'vitals_',
            'PROACT_LABS.csv': 'labs_',
            'PROACT_HANDGRIPSTRENGTH.csv': 'grip_',
            'PROACT_MUSCLESTRENGTH.csv': 'muscle_'
        }
        print("\n--- Generating Longitudinal Features (from first 3 months) ---")
        for file, prefix in longitudinal_configs.items():
            if file in datasets:
                df = datasets[file].copy()
                time_col_actual = next((c for c in df.columns if 'delta' in c.lower()), None)
                if not time_col_actual:
                    print(f"Warning: No time delta column found in {file}. Skipping.")
                    continue
                print(f"Processing {file}...")
                # Pivot long format files if necessary
                if file in ['PROACT_LABS.csv', 'PROACT_MUSCLESTRENGTH.csv', 'PROACT_HANDGRIPSTRENGTH.csv']:
                    try:
                        test_cols = [c for c in df.columns if any(keyword in c.lower() for keyword in ['test', 'exam', 'muscle', 'site', 'name', 'strength_test']) and c not in ['subject_id', time_col_actual]]
                        if test_cols:
                            test_col = test_cols[0]
                            value_cols = [c for c in df.columns if any(keyword in c.lower() for keyword in ['result', 'value', 'strength', 'score']) and c not in ['subject_id', time_col_actual]]
                            if value_cols:
                                value_col = value_cols[0]
                                df[value_col] = pd.to_numeric(df[value_col], errors='coerce')
                                df = df.pivot_table(index=['subject_id', time_col_actual], columns=test_col, values=value_col, aggfunc='mean').reset_index()
                    except Exception as e:
                        print(f"Warning: Pivoting failed for {file}: {e}")
                summary_features = self.create_longitudinal_features(df, time_col_actual, prefix)
                if not summary_features.empty:
                    final_df = pd.merge(final_df, summary_features, on='subject_id', how='left')
        return final_df
    def filter_eligible_patients(self, feature_df, alsfrs_df):
        """Filter for patients meeting the paper's criteria."""
        df = alsfrs_df.copy()
        df.rename(columns={c:'alsfrs_delta' for c in df.columns if 'delta' in c.lower()}, inplace=True)
        df['months'] = df['alsfrs_delta'] / 30.44
        eligibility = df.groupby('subject_id')['months'].agg(['min', 'max'])
        eligible_ids = eligibility[(eligibility['min'] <= 3) & (eligibility['max'] >= 12)].index
        print(f"\nFound {len(eligible_ids)} eligible patients out of {df['subject_id'].nunique()}.")
        return feature_df[feature_df['subject_id'].isin(eligible_ids)]
    def run_pipeline(self, file_path=''):
        """Execute the complete data preprocessing pipeline."""
        print("====== Starting ALS Data Preprocessing Pipeline ======")
        datasets = self.load_and_inspect_data(file_path)
        if 'PROACT_ALSFRS.csv' not in datasets:
            print("CRITICAL ERROR: PROACT_ALSFRS.csv not found. Aborting.")
            return None
        datasets['PROACT_ALSFRS.csv'] = self._convert_alsfrs_r(datasets['PROACT_ALSFRS.csv'])
        target_df = self.calculate_alsfrs_slope(datasets['PROACT_ALSFRS.csv'])
        print(f"\nCalculated ALSFRS slope for {len(target_df)} patients.")
        full_features = self.merge_all_features(datasets)
        eligible_features = self.filter_eligible_patients(full_features, datasets['PROACT_ALSFRS.csv'])
        final_df = pd.merge(eligible_features, target_df, on='subject_id', how='inner')
        print("\n--- Handling Missing Values ---")
        missing_thresh = 0.30
        initial_cols = len(final_df.columns)
        max_missing = len(final_df) * (1 - missing_thresh)
        final_df.dropna(axis=1, thresh=max_missing, inplace=True)
        print(f"Dropped {initial_cols - len(final_df.columns)} features with >{missing_thresh*100}% missing values.")
        X = final_df.drop(columns=['subject_id', 'alsfrs_slope'])
        y = final_df['alsfrs_slope']
        valid_y_mask = y.notna()
        X = X[valid_y_mask]
        y = y[valid_y_mask]
        subject_ids = final_df.loc[valid_y_mask, 'subject_id']
        imputer = SimpleImputer(strategy='median')
        X_imputed = pd.DataFrame(imputer.fit_transform(X), columns=X.columns)
        print("\n--- Performing Feature Selection (Top 30 via Random Forest) ---")
        rf = RandomForestRegressor(n_estimators=100, random_state=42, n_jobs=-1)
        rf.fit(X_imputed, y)
        importance_df = pd.DataFrame({
            'feature': X.columns,
            'importance': rf.feature_importances_
        }).sort_values('importance', ascending=False)
        selected_features = importance_df['feature'].head(30).tolist()
        X_selected = X_imputed[selected_features]
        print("\n====== Pipeline Complete ======")
        print(f"Final feature matrix shape: {X_selected.shape}")
        print(f"Final target vector shape: {y.shape}")
        # Save the final data for the next step
        final_output = pd.concat([subject_ids.reset_index(drop=True),
                                  y.reset_index(drop=True),
                                  X_selected.reset_index(drop=True)], axis=1)
        final_output.to_csv("final_processed_als_data.csv", index=False)
        print("\n✅ Successfully saved processed data to 'final_processed_als_data.csv'")
        return {
            'X': X_selected,
            'y': y,
            'subject_ids': subject_ids,
            'feature_importance': importance_df,
        }
if __name__ == "__main__":
    # --- IMPORTANT ---
    # If your CSV files are in a different folder, change this path.
    # For example: file_path = "C:/Users/YourUser/Downloads/PROACT_data/"
    file_path = ""
    processor = ALSDataProcessor()
    processed_data = processor.run_pipeline(file_path=file_path)
    if processed_data:
        print("\n--- Top 15 Most Important Features ---")
        print(processed_data['feature_importance'].head(15))

--- Loading and Inspecting Data ---
✓ PROACT_ALSFRS.csv: Loaded successfully with shape (73845, 20)
✓ PROACT_FVC.csv: Loaded successfully with shape (49110, 10)
✓ PROACT_VITALSIGNS.csv: Loaded successfully with shape (84721, 36)
✓ PROACT_RILUZOLE.csv: Loaded successfully with shape (10363, 3)
✓ PROACT_DEMOGRAPHICS.csv: Loaded successfully with shape (12504, 14)
✓ PROACT_LABS.csv: Loaded successfully with shape (2937162, 5)
✓ PROACT_DEATHDATA.csv: Loaded successfully with shape (5043, 3)
✓ PROACT_HANDGRIPSTRENGTH.csv: Loaded successfully with shape (19032, 11)
✓ PROACT_MUSCLESTRENGTH.csv: Loaded successfully with shape (204875, 10)
✓ PROACT_ALSHISTORY.csv: Loaded successfully with shape (13765, 16)

Calculated ALSFRS slope for 2023 patients.

--- Generating Longitudinal Features (from first 3 months) ---
Processing PROACT_ALSFRS.csv...
Processing PROACT_FVC.csv...
Processing PROACT_VITALSIGNS.csv...
Processing PROACT_LABS.csv...
Processing PROACT_HANDGRIPSTRENGTH.csv...
Processing PROAC

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
import warnings

warnings.filterwarnings('ignore')

def calculate_metrics(y_true, y_pred):
    """Calculates RMSD and PCC."""
    rmsd = np.sqrt(mean_squared_error(y_true, y_pred))
    pcc, _ = pearsonr(y_true, y_pred)
    return rmsd, pcc

def run_classical_pipeline():
    """
    Loads the processed data, trains baseline models, and evaluates their performance.
    """
    print("====== Starting Classical Baseline Model Pipeline ======")

    # --- 1. Load Data ---
    try:
        data = pd.read_csv("final_processed_als_data.csv")
        print(f"✓ Successfully loaded 'final_processed_als_data.csv' with shape {data.shape}")
    except FileNotFoundError:
        print("✗ ERROR: 'final_processed_als_data.csv' not found. Please run the preprocessing script first.")
        return

    # --- 2. Prepare Data ---
    X = data.drop(columns=['subject_id', 'alsfrs_slope'])
    y = data['alsfrs_slope']

    # 80/20 Train-Test Split
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    print(f"Data split into training ({X_train.shape[0]} samples) and testing ({X_test.shape[0]} samples).")

    # Scale data for SVR
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # --- 3. Train and Evaluate Models ---
    results = {}

    # Model 1: Random Forest Regressor
    print("\n--- Training Random Forest Regressor ---")
    rf_model = RandomForestRegressor(n_estimators=100, random_state=42, n_jobs=-1)
    rf_model.fit(X_train, y_train)
    rf_preds = rf_model.predict(X_test)
    rf_rmsd, rf_pcc = calculate_metrics(y_test, rf_preds)
    results['Random Forest'] = {'RMSD': rf_rmsd, 'PCC': rf_pcc}
    print("✓ Training and evaluation complete.")

    # Model 2: Support Vector Regressor
    print("\n--- Training Support Vector Regressor (SVR) ---")
    svr_model = SVR(kernel='rbf', C=1.0, epsilon=0.1)
    svr_model.fit(X_train_scaled, y_train)
    svr_preds = svr_model.predict(X_test_scaled)
    svr_rmsd, svr_pcc = calculate_metrics(y_test, svr_preds)
    results['Support Vector Regressor'] = {'RMSD': svr_rmsd, 'PCC': svr_pcc}
    print("✓ Training and evaluation complete.")

    # --- 4. Display Results ---
    print("\n====== Classical Model Performance ======")
    results_df = pd.DataFrame(results).T
    print(results_df)
    print("\nReminder:")
    print("  - RMSD (Root Mean Squared Deviation): Lower is better.")
    print("  - PCC (Pearson Correlation Coefficient): Higher is better (closer to 1.0).")

    return results_df

if __name__ == "__main__":
    run_classical_pipeline()

✓ Successfully loaded 'final_processed_als_data.csv' with shape (2022, 32)
Data split into training (1617 samples) and testing (405 samples).

--- Training Random Forest Regressor ---
✓ Training and evaluation complete.

--- Training Support Vector Regressor (SVR) ---
✓ Training and evaluation complete.

                              RMSD       PCC
Random Forest             0.560428  0.266818
Support Vector Regressor  0.574893  0.246569

Reminder:
  - RMSD (Root Mean Squared Deviation): Lower is better.
  - PCC (Pearson Correlation Coefficient): Higher is better (closer to 1.0).


In [8]:
# qrf_qnn_residual_v4_fastbatch.py
import os, time, numpy as np, pandas as pd
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings("ignore")

# CPU hygiene
for k in ["OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"]:
    os.environ.setdefault(k, "1")
np.random.seed(42)

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.impute import SimpleImputer
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestRegressor
from sklearn.feature_selection import mutual_info_regression
from sklearn.linear_model import RidgeCV
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr

# --- Qiskit
from qiskit.circuit.library import ZZFeatureMap, EfficientSU2
from qiskit.quantum_info import SparsePauliOp
from qiskit.primitives import Estimator
try:
    from qiskit_aer.primitives import Estimator as AerEstimator
    AER_OK = True
except Exception:
    AER_OK = False

# ---------------- utils ----------------
def safe_pcc(a, b):
    a, b = np.asarray(a).ravel(), np.asarray(b).ravel()
    if a.std()==0 or b.std()==0: return 0.0
    v = pearsonr(a, b)[0]
    return float(v) if np.isfinite(v) else 0.0

def metrics(y_true, y_pred):
    rmsd = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    return rmsd, safe_pcc(y_true, y_pred), float(r2_score(y_true, y_pred))

def direction_accuracy(y_true, y_pred, stable_band=0.10):
    def cls(y):
        if y < -stable_band: return 0
        if y >  stable_band: return 2
        return 1
    yt = np.array([cls(v) for v in y_true])
    yp = np.array([cls(v) for v in y_pred])
    return float((yt == yp).mean())

def load_xy(path="final_processed_als_data.csv"):
    df = pd.read_csv(path)
    if "alsfrs_slope" not in df.columns:
        raise ValueError("Target 'alsfrs_slope' missing in CSV.")
    X = df.drop(columns=["subject_id", "alsfrs_slope"], errors="ignore")
    y = df["alsfrs_slope"].values
    m = ~np.isnan(y)
    X, y = X.loc[m].reset_index(drop=True), y[m]
    print(f"✓ Data loaded: X={X.shape}, y={y.shape}", flush=True)
    return X, y

def select_topk_features(X, y, k=8):
    imp = SimpleImputer(strategy="median")
    Xn = imp.fit_transform(X)
    rf = RandomForestRegressor(n_estimators=200, random_state=42, n_jobs=-1).fit(Xn, y)
    rf_rank = rf.feature_importances_
    mi = mutual_info_regression(Xn, y, random_state=42)
    corr = np.array([abs(np.corrcoef(Xn[:, i], y)[0,1]) if Xn[:, i].std()>0 else 0.0
                     for i in range(Xn.shape[1])])
    def nz(v): m=v.max(); return v/(m+1e-8) if m>0 else v
    score = nz(rf_rank) + nz(mi) + nz(corr)
    idx = np.argsort(score)[::-1][:k]
    cols = [X.columns[i] for i in idx]
    print(f"✓ Top-{k} features: {cols}", flush=True)
    return idx, cols

# ---------------- quantum helpers ----------------
def make_observables(n_qubits=4, use_pairs=True):
    obs = []
    for i in range(n_qubits):  # <Z_i>
        p = ['I']*n_qubits; p[i] = 'Z'
        obs.append(SparsePauliOp.from_list([("".join(p[::-1]), 1.0)]))
    if use_pairs:              # <Z_i Z_j>
        for i in range(n_qubits):
            for j in range(i+1, n_qubits):
                p = ['I']*n_qubits; p[i]=p[j]='Z'
                obs.append(SparsePauliOp.from_list([("".join(p[::-1]), 1.0)]))
    return obs

def _idx_from_param_name(name: str) -> int:
    if '[' in name and ']' in name:
        return int(name[name.find('[')+1:name.find(']')])
    if '_' in name:
        return int(name.split('_')[-1])
    return int(''.join(ch for ch in name if ch.isdigit()))

def build_random_sink(n_qubits=4, fmap_reps=2, ansatz_reps=1, rng=None):
    rng = np.random.default_rng(None if rng is None else rng)
    fmap = ZZFeatureMap(feature_dimension=n_qubits, reps=fmap_reps)
    ans  = EfficientSU2(num_qubits=n_qubits, reps=ansatz_reps, entanglement="linear")
    circ = fmap.compose(ans)
    # freeze random ansatz weights
    rand_theta = {p: float(rng.normal(0, 0.25)) for p in ans.parameters}
    circ = circ.assign_parameters(rand_theta, inplace=False)
    # pull feature params from composed circuit
    feat_params = [p for p in circ.parameters if p.name.startswith("x")]
    feat_params = sorted(feat_params, key=lambda p: _idx_from_param_name(p.name))
    assert len(feat_params) == n_qubits, f"Expected {n_qubits} feature params, got {len(feat_params)}"
    return circ, feat_params

def build_estimator():
    if AER_OK:
        # statevector mode is typically fastest for exact expectations
        return AerEstimator(run_options={"shots": None}, backend_options={"method": "statevector"})
    return Estimator()

# -------- batched featurization (BIG speed win) --------
def qrf_features_batched(estimator, sinks, observables, X_theta, batch_size=64):
    """
    sinks: list of (circ, feat_params)
    X_theta: (N, n_qubits) in [0, π]
    Returns: (N, T*D)
    """
    N = X_theta.shape[0]
    T = len(sinks)
    D = len(observables)
    feats = np.empty((N, T * D), dtype=float)

    for t, (circ, feat_params) in enumerate(sinks):
        col0 = t * D
        for s in tqdm(range(0, N, batch_size), desc=f"Sink {t+1}/{T}", leave=False):
            e = min(N, s + batch_size)
            circuits, obs_list = [], []
            for i in range(s, e):
                pmap = {feat_params[k]: float(X_theta[i, k]) for k in range(len(feat_params))}
                c_bound = circ.assign_parameters(pmap, inplace=False)
                circuits.extend([c_bound] * D)
                obs_list.extend(observables)
            vals = estimator.run(circuits, obs_list).result().values
            feats[s:e, col0:col0+D] = np.array(vals).reshape(e - s, D)
    return feats

# ---------------- QRFR fit/predict ----------------
def qrfr_fit_predict(
    X_tr, X_te, y_tr,
    n_qubits=4, topk_for_pca=8,
    fmap_reps=2, ansatz_reps=1, use_pairs=True,
    n_sinks=8, whiten=True,
    alphas=(0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0),
    train_cap=900,       # coreset size for speed; set None to use all
    batch_size=64
):
    # top-k → PCA→ n_qubits angles
    idxk, colsk = select_topk_features(X_tr, y_tr, k=topk_for_pca)
    imp = SimpleImputer(strategy="median")
    std = StandardScaler()
    Xtr_k = std.fit_transform(imp.fit_transform(X_tr.iloc[:, idxk]))
    Xte_k = std.transform(imp.transform(X_te.iloc[:, idxk]))
    pca = PCA(n_components=n_qubits, random_state=42)
    Xtr = pca.fit_transform(Xtr_k)
    Xte = pca.transform(Xte_k)

    ang = MinMaxScaler(feature_range=(0.0, np.pi))
    Xtr_th = ang.fit_transform(Xtr)
    Xte_th = ang.transform(Xte)

    # coreset for TRAIN features
    tr_idx = np.arange(Xtr_th.shape[0])
    if train_cap and len(tr_idx) > train_cap:
        rng = np.random.default_rng(42)
        tr_idx = rng.choice(tr_idx, size=train_cap, replace=False)
    Xtr_th_sub = Xtr_th[tr_idx]
    y_tr_sub   = y_tr[tr_idx]

    # sinks/obs
    est = build_estimator()
    obs = make_observables(n_qubits=n_qubits, use_pairs=use_pairs)
    sinks = [build_random_sink(n_qubits=n_qubits, fmap_reps=fmap_reps, ansatz_reps=ansatz_reps)
             for _ in range(n_sinks)]

    # features (batched)
    Z_tr = qrf_features_batched(est, sinks, obs, Xtr_th_sub, batch_size=batch_size)
    Z_te = qrf_features_batched(est, sinks, obs, Xte_th,     batch_size=batch_size)

    if whiten:
        zs = StandardScaler()
        Z_tr = zs.fit_transform(Z_tr)
        Z_te = zs.transform(Z_te)

    # y standardize for ridge
    ysc = StandardScaler()
    y_tr_s = ysc.fit_transform(y_tr_sub.reshape(-1,1)).ravel()

    head = RidgeCV(alphas=np.array(alphas), cv=3)
    head.fit(Z_tr, y_tr_s)
    y_hat_s = head.predict(Z_te)
    y_hat = ysc.inverse_transform(y_hat_s.reshape(-1,1)).ravel()

    return y_hat, dict(obs_dim=Z_tr.shape[1], sinks=n_sinks, pairs=use_pairs, fmap=fmap_reps, ans=ansatz_reps,
                       topk=topk_for_pca, coreset=len(tr_idx), batch=batch_size, features_used=colsk)

# ---------------- main experiment ----------------
def run_all(
    data_path="final_processed_als_data.csv",
    n_qubits=4, topk_for_pca=8,
    fmap_reps=2, ansatz_reps=1, use_pairs=True,
    n_sinks=8, whiten=True, stable_band=0.10,
    train_cap=900, batch_size=64
):
    t0 = time.time()
    X, y = load_xy(data_path)
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)

    # RF baseline
    imp_full = SimpleImputer(strategy="median")
    std_full = StandardScaler()
    Xtr_full = std_full.fit_transform(imp_full.fit_transform(X_tr))
    Xte_full = std_full.transform(imp_full.transform(X_te))
    rf = RandomForestRegressor(n_estimators=200, random_state=42, n_jobs=-1)
    rf.fit(Xtr_full, y_tr)
    y_rf_te = rf.predict(Xte_full)
    rmsd_rf, pcc_rf, r2_rf = metrics(y_te, y_rf_te)
    acc3_rf = direction_accuracy(y_te, y_rf_te, stable_band=stable_band)

    # Plain QRFR
    t1 = time.time()
    y_qrfr, info_q = qrfr_fit_predict(
        X_tr, X_te, y_tr,
        n_qubits=n_qubits, topk_for_pca=topk_for_pca,
        fmap_reps=fmap_reps, ansatz_reps=ansatz_reps, use_pairs=use_pairs,
        n_sinks=n_sinks, whiten=whiten, train_cap=train_cap, batch_size=batch_size
    )
    t2 = time.time()
    rmsd_q, pcc_q, r2_q = metrics(y_te, y_qrfr)
    acc3_q = direction_accuracy(y_te, y_qrfr, stable_band=stable_band)

    # Residual-QRFR
    y_res_tr = y_tr - rf.predict(Xtr_full)
    # Fit residual QRFR on SAME settings (coreset + batched)
    y_res_hat, info_r = qrfr_fit_predict(
        X_tr, X_te, y_res_tr,
        n_qubits=n_qubits, topk_for_pca=topk_for_pca,
        fmap_reps=fmap_reps, ansatz_reps=ansatz_reps, use_pairs=use_pairs,
        n_sinks=n_sinks, whiten=whiten, train_cap=train_cap, batch_size=batch_size
    )
    y_blend = y_rf_te + y_res_hat
    rmsd_r, pcc_r, r2_r = metrics(y_te, y_blend)
    acc3_r = direction_accuracy(y_te, y_blend, stable_band=stable_band)
    t3 = time.time()

    print("\n===== RESULTS (batched + coreset) =====")
    print(f"RF baseline:           RMSD={rmsd_rf:.4f}  PCC={pcc_rf:.4f}  R²={r2_rf:.4f}  ACC3={acc3_rf*100:.1f}%")
    print(f"QRFR (plain):          RMSD={rmsd_q:.4f}  PCC={pcc_q:.4f}  R²={r2_q:.4f}  ACC3={acc3_q*100:.1f}%")
    print(f"Residual-QRFR (blend): RMSD={rmsd_r:.4f}  PCC={pcc_r:.4f}  R²={r2_r:.4f}  ACC3={acc3_r*100:.1f}%")
    print(f"\nDims: QRFR out={info_q['obs_dim']}  | sinks={info_q['sinks']}  | pairs={info_q['pairs']}  "
          f"| fmap={info_q['fmap']}  | ans={info_q['ans']}")
    print(f"Top-k={info_q['topk']}  | coreset={info_q['coreset']}  | batch={info_q['batch']}")
    print(f"Angles from: {info_q['features_used']}")
    print(f"Timings: QRFR={t2-t1:.1f}s  Residual={t3-t2:.1f}s  Total={t3-t0:.1f}s")

    return {
        "rf":   dict(rmsd=rmsd_rf, pcc=pcc_rf, r2=r2_rf, acc3=acc3_rf),
        "qrfr": dict(rmsd=rmsd_q,  pcc=pcc_q,  r2=r2_q,  acc3=acc3_q),
        "res":  dict(rmsd=rmsd_r,  pcc=pcc_r,  r2=r2_r,  acc3=acc3_r),
    }

if __name__ == "__main__":
    _ = run_all(
        data_path="final_processed_als_data.csv",
        n_qubits=4, topk_for_pca=8,
        fmap_reps=2, ansatz_reps=1, use_pairs=True,
        n_sinks=8, whiten=True, stable_band=0.10,
        train_cap=900,    # set to None to use full train set
        batch_size=64     # bump to 96/128 if memory allows
    )


✓ Data loaded: X=(2022, 30), y=(2022,)
✓ Top-8 features: ['fvc_Subject_Liters_Trial_1_std', 'alsfrs_ALSFRS_Total_std', 'fvc_Subject_Liters_Trial_1_slope', 'vitals_Vital_Signs_Delta_std', 'vitals_Pulse_median', 'fvc_Subject_Liters_Trial_1_last', 'fvc_pct_of_Normal_Trial_1_std', 'alsfrs_ALSFRS_Total_slope']


Sink 1/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 2/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 3/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 4/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 5/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 6/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 7/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 8/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 1/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 2/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 3/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 4/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 5/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 6/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 7/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 8/8:   0%|          | 0/7 [00:00<?, ?it/s]

✓ Top-8 features: ['alsfrs_ALSFRS_Total_slope', 'alsfrs_Q9_Climbing_Stairs_median', 'vitals_Vital_Signs_Delta_std', 'alsfrs_Q7_Turning_in_Bed_slope', 'fvc_Subject_Liters_Trial_1_std', 'labs_AST(SGOT)_slope', 'labs_Creatinine_slope', 'vitals_Weight_slope']


Sink 1/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 2/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 3/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 4/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 5/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 6/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 7/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 8/8:   0%|          | 0/15 [00:00<?, ?it/s]

Sink 1/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 2/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 3/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 4/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 5/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 6/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 7/8:   0%|          | 0/7 [00:00<?, ?it/s]

Sink 8/8:   0%|          | 0/7 [00:00<?, ?it/s]


===== RESULTS (batched + coreset) =====
RF baseline:           RMSD=0.5586  PCC=0.2742  R²=0.0709  ACC3=86.2%
QRFR (plain):          RMSD=0.5849  PCC=0.0911  R²=-0.0186  ACC3=86.2%
Residual-QRFR (blend): RMSD=0.5588  PCC=0.2733  R²=0.0702  ACC3=86.2%

Dims: QRFR out=80  | sinks=8  | pairs=True  | fmap=2  | ans=1
Top-k=8  | coreset=900  | batch=64
Angles from: ['fvc_Subject_Liters_Trial_1_std', 'alsfrs_ALSFRS_Total_std', 'fvc_Subject_Liters_Trial_1_slope', 'vitals_Vital_Signs_Delta_std', 'vitals_Pulse_median', 'fvc_Subject_Liters_Trial_1_last', 'fvc_pct_of_Normal_Trial_1_std', 'alsfrs_ALSFRS_Total_slope']
Timings: QRFR=378.3s  Residual=385.9s  Total=781.1s
