In [2]:
import polars as pl
import numpy as np
from scipy import stats
from tqdm import tqdm
import os
from numpy.lib.stride_tricks import sliding_window_view

In [4]:
import polars as pl
import numpy as np
from scipy import stats
from scipy.signal import welch
from tqdm import tqdm
import os
from numpy.lib.stride_tricks import sliding_window_view
import math
import gc
import joblib
# ... other imports ...

# --- Feature Extraction with Larger Overlapping Windows ---
def extract_features_large_windows(group: pl.DataFrame) -> pl.DataFrame:
    # Get constant features
    event_id = group['event_id'][0]
    prev_marker = group['prev_marker'][0]
    marker = group['marker'][0] # Target

    eeg_cols = [col for col in group.columns if col not in
                ['event_id', 'time', 'marker', 'prev_marker', 'orig_marker']]

    # --- Configuration for Large Windows ---
    fs = 500 # Sampling Frequency
    target_len = 2000 # Samples per epoch (4 seconds)
    window_size = 1000 # Samples (2 seconds) <-- ADJUSTABLE
    step_size = 250    # Samples (0.5 seconds for 75% overlap) <-- ADJUSTABLE
    # Calculate expected number of windows
    num_windows = math.floor((target_len - window_size) / step_size) + 1 if target_len >= window_size else 0
    # --- End Configuration ---

    features_dict = {
        'event_id': event_id,
        'prev_marker': prev_marker,
        'marker': marker
    }

    bands = { 'delta': (1, 4), 'theta': (4, 8), 'alpha': (8, 13), 'beta': (13, 30) }
    nperseg_welch = min(500, window_size) # Welch window, e.g., 1 sec or window_size if smaller
    noverlap_welch = nperseg_welch // 2

    # Process each EEG channel
    for col in eeg_cols:
        signal = group[col].cast(pl.Float64).to_numpy()
        signal = np.nan_to_num(signal, nan=0.0)
        if len(signal) < target_len:
            signal = np.pad(signal, (0, target_len - len(signal)), constant_values=0.0)
        elif len(signal) > target_len:
            signal = signal[:target_len]

        if len(signal) < window_size: # If epoch too short even for one window
             for window_idx in range(num_windows): # Add NaNs for all expected features/windows
                 # Time domain NaNs
                 features_dict[f"{col}_mean_w{window_idx}"] = np.nan
                 features_dict[f"{col}_std_w{window_idx}"] = np.nan
                 # ... add NaNs for min, max, skew, kurtosis, mobility, complexity ...
                 # Freq domain NaNs
                 for band in bands: features_dict[f"{col}_{band}_power_w{window_idx}"] = np.nan
             continue

        # Create sliding windows view
        windows = sliding_window_view(signal, window_shape=window_size)[::step_size]

        # Calculate features for each window
        for window_idx, window_signal in enumerate(windows):
            # Ensure we don't process more windows than expected (safety)
            if window_idx >= num_windows: break

            with np.errstate(divide='ignore', invalid='ignore'):
                # Basic Stats
                features_dict[f"{col}_mean_w{window_idx}"] = np.mean(window_signal)
                features_dict[f"{col}_std_w{window_idx}"] = np.std(window_signal, ddof=1)
                features_dict[f"{col}_min_w{window_idx}"] = np.min(window_signal)
                features_dict[f"{col}_max_w{window_idx}"] = np.max(window_signal)
                features_dict[f"{col}_skew_w{window_idx}"] = stats.skew(window_signal, bias=False)
                features_dict[f"{col}_kurtosis_w{window_idx}"] = stats.kurtosis(window_signal, bias=False)

                # Hjorth Parameters
                diff1 = np.diff(window_signal)
                diff2 = np.diff(diff1)
                var_signal = np.var(window_signal, ddof=1)
                var_diff1 = np.var(diff1, ddof=1)
                var_diff2 = np.var(diff2, ddof=1)
                mobility = np.sqrt(var_diff1 / var_signal) if var_signal > 1e-9 else 0.0
                complexity = np.sqrt(var_diff2 / var_diff1) / mobility if mobility > 1e-9 and var_diff1 > 1e-9 else 0.0
                features_dict[f"{col}_mobility_w{window_idx}"] = mobility
                features_dict[f"{col}_complexity_w{window_idx}"] = complexity

                # Band Power
                try:
                    freqs, psd = welch(window_signal, fs=fs, nperseg=nperseg_welch, noverlap=noverlap_welch, scaling='density', average='mean')
                    for band, (low_hz, high_hz) in bands.items():
                        idx_band = np.logical_and(freqs >= low_hz, freqs < high_hz)
                        band_power = np.mean(psd[idx_band]) if np.any(idx_band) else 0.0
                        features_dict[f"{col}_{band}_power_w{window_idx}"] = band_power
                except ValueError as e:
                    print(f"Warning: Welch failed for {col} w{window_idx}, event {event_id}: {e}")
                    for band in bands: features_dict[f"{col}_{band}_power_w{window_idx}"] = np.nan

    return pl.DataFrame([features_dict])


# --- Function to Create Dataset (calls extract_features_1000s) ---
def create_ml_dataset_large_windows(df_path: str, output_path: str = "ML_dataset_large_window_features.parquet", batch_size: int = 100):
    # This function remains structurally the same as create_ml_dataset_epoch_features,
    # just make sure it calls extract_features_large_windows instead.
    # Remember to use a different temp directory name like 'temp_large_window_files'

    # Initialize
    temp_counter = 0
    temp_files = []
    processed_batch = []
    temp_dir = "temp_large_window_files" # Different temp dir
    os.makedirs(temp_dir, exist_ok=True)

    lf = pl.scan_parquet(df_path).with_columns(
        pl.col(['prev_marker', 'marker']).cast(pl.Utf8),
        pl.all().exclude(['event_id', 'time', 'marker', 'prev_marker', 'orig_marker']).cast(pl.Float64)
    )

    print("Fetching unique event IDs...")
    event_ids = lf.select('event_id').unique(maintain_order=True).collect()['event_id'].to_list()
    total_events = len(event_ids)
    print(f"Found {total_events} unique events.")

    with tqdm(total=total_events, desc="Processing events (large windows)") as pbar:
        for event_id in event_ids:
            try:
                event_group_lf = lf.filter(pl.col('event_id') == event_id)
                event_group_df = event_group_lf.sort('time').collect()

                if event_group_df.is_empty():
                     pbar.update(1); continue
                # Check length, although padding/truncation is handled inside
                # if event_group_df.height != 2000:
                #      print(f"Warning: Event {event_id} has {event_group_df.height} samples")

                # *** Call the large window feature extractor ***
                processed_df = extract_features_large_windows(event_group_df)
                processed_batch.append(processed_df)

                if len(processed_batch) >= batch_size:
                    temp_file = os.path.join(temp_dir, f"temp_{temp_counter}.parquet")
                    pl.concat(processed_batch).write_parquet(temp_file, compression="zstd")
                    temp_files.append(temp_file)
                    processed_batch = []
                    temp_counter += 1
                    gc.collect()

                pbar.update(1)

            except Exception as e:
                print(f"Error processing event_id {event_id}: {str(e)}")
                import traceback
                traceback.print_exc()
                pbar.update(1)
                continue

    # Write remaining batch
    if processed_batch:
        temp_file = os.path.join(temp_dir, f"temp_{temp_counter}.parquet")
        pl.concat(processed_batch).write_parquet(temp_file, compression="zstd")
        temp_files.append(temp_file)
        print(f"Wrote final batch of {len(processed_batch)} events.")

    # Combine temporary files
    print(f"Combining {len(temp_files)} temporary batch files...")
    if temp_files:
        # ... (Combination and cleanup logic is identical to previous function) ...
        try:
            lazy_frames = [pl.scan_parquet(f) for f in temp_files]
            pl.concat(lazy_frames, rechunk=False).sink_parquet(
                output_path, compression="zstd", statistics=True
            )
            print(f"Successfully created final dataset: {output_path}")
        except Exception as e:
             print(f"ERROR: Failed during final concatenation/writing: {e}")
             print(f"Temporary files are kept for inspection in '{temp_dir}'.")
             return None
        finally:
             if os.path.exists(output_path):
                 print("Cleaning up temporary files...")
                 for f in temp_files:
                     try: os.remove(f)
                     except OSError as e: print(f"Warning: Could not remove temp file {f}: {e}")
                 try: os.rmdir(temp_dir); print(f"Removed temporary directory: {temp_dir}")
                 except OSError: print(f"Temporary directory {temp_dir} not empty, not removed.")
                 print("Cleanup complete.")
             else:
                 print(f"Final file not created. Keeping temporary files in '{temp_dir}'.")
        return output_path
    else:
        print("No temporary files were generated. No output created.")
        return None


# --- Usage Example ---
INPUT_FILE = "/home/owner/Documents/DEV/BrainLabyrinth/data/combined.parquet"
OUTPUT_FILE_LARGE_WINDOW = "ML_dataset_1000_features.parquet" # New name
BATCH_SIZE = 200

final_large_window_file_path = create_ml_dataset_large_windows(INPUT_FILE, OUTPUT_FILE_LARGE_WINDOW, BATCH_SIZE)

if final_large_window_file_path:
    print(f"\nLarge window feature dataset creation complete. Output file: {final_large_window_file_path}")
    # Verify schema or head
    print("\nSchema of the large window feature dataset:")
    final_lf = pl.scan_parquet(final_large_window_file_path)
    print(final_lf.schema)
    # Estimate feature count more accurately now
    num_features_estimate = len(final_lf.columns) - 3 # Subtract event_id, marker, prev_marker
    print(f"\nEstimated number of feature columns: {num_features_estimate}")

else:
    print("\nLarge window feature dataset creation failed.")

Fetching unique event IDs...
Found 2772 unique events.


Processing events (large windows): 100%|██████████| 2772/2772 [12:56<00:00,  3.57it/s]


Wrote final batch of 172 events.
Combining 14 temporary batch files...
Successfully created final dataset: ML_dataset_1000_features.parquet
Cleaning up temporary files...
Removed temporary directory: temp_large_window_files
Cleanup complete.

Large window feature dataset creation complete. Output file: ML_dataset_1000_features.parquet

Schema of the large window feature dataset:
Schema([('event_id', Int64), ('prev_marker', String), ('marker', String), ('Fp1_mean_w0', Float64), ('Fp1_std_w0', Float64), ('Fp1_min_w0', Float64), ('Fp1_max_w0', Float64), ('Fp1_skew_w0', Float64), ('Fp1_kurtosis_w0', Float64), ('Fp1_mobility_w0', Float64), ('Fp1_complexity_w0', Float64), ('Fp1_delta_power_w0', Float64), ('Fp1_theta_power_w0', Float64), ('Fp1_alpha_power_w0', Float64), ('Fp1_beta_power_w0', Float64), ('Fp1_mean_w1', Float64), ('Fp1_std_w1', Float64), ('Fp1_min_w1', Float64), ('Fp1_max_w1', Float64), ('Fp1_skew_w1', Float64), ('Fp1_kurtosis_w1', Float64), ('Fp1_mobility_w1', Float64), ('Fp1_c

  print(final_lf.schema)
  num_features_estimate = len(final_lf.columns) - 3 # Subtract event_id, marker, prev_marker


In [5]:
import polars as pl
import numpy as np
from tqdm import tqdm
import os
import math
import gc
import joblib
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.feature_selection import VarianceThreshold, SelectKBest, mutual_info_classif
import traceback # Import traceback for detailed error printing

# --- Configuration ---
# *** IMPORTANT: Update INPUT_FILE to the correct large window features file ***
INPUT_FILE = "ML_dataset_1000_features.parquet" # Use the file generated by the previous step
OUTPUT_SCALER_MODEL = "fs_scaler_1000.joblib" # Use distinct names
OUTPUT_FINAL_FEATURE_LIST = "final_selected_features_1000.joblib" # Use distinct names

# Feature Selection Parameters
VARIANCE_THRESHOLD = 0.01
MI_SELECT_K = 500
MI_SAMPLE_SIZE = 20000 # Keep or adjust based on memory
BATCH_SIZE = 1000 # For scaler fitting

# --- Identify Feature and Target Columns ---
print("Reading schema to identify columns...")
try:
    # Make sure INPUT_FILE exists
    if not os.path.exists(INPUT_FILE):
        raise FileNotFoundError(f"Input file not found: {INPUT_FILE}")
    schema = pl.read_parquet(INPUT_FILE, n_rows=0).schema
except Exception as e:
    print(f"Error reading schema from {INPUT_FILE}: {e}")
    print("Please ensure the file exists and is a valid Parquet file.")
    exit(1) # Use non-zero exit code for errors

IDENTIFIER_COLS = ['event_id', 'prev_marker']
TARGET_COLUMN = 'marker'
# Calculate initial feature columns
FEATURE_COLUMNS = [col for col in schema if col not in IDENTIFIER_COLS + [TARGET_COLUMN]]
n_original = len(FEATURE_COLUMNS) # Store original count here
print(f"Found {n_original} initial feature columns.")
print(f"Target column: {TARGET_COLUMN}")

if not FEATURE_COLUMNS:
    raise ValueError("No feature columns identified. Check IDENTIFIER_COLS and TARGET_COLUMN.")
if TARGET_COLUMN not in schema:
     raise ValueError(f"Target column '{TARGET_COLUMN}' not found in the dataset.")

# --- Step 1: Low Variance Threshold Preparation ---

# 1a. Fit StandardScaler Incrementally on ALL original features
print("\n--- Step 1: Low Variance Threshold ---")
scaler = StandardScaler()
print(f"Fitting StandardScaler incrementally (batch size: {BATCH_SIZE})...")
selected_mask_variance = None # Initialize mask variable

try:
    total_rows = pl.scan_parquet(INPUT_FILE).select(pl.len()).collect().item()
    n_batches = math.ceil(total_rows / BATCH_SIZE)
    row_iterator = pl.read_parquet(INPUT_FILE, columns=FEATURE_COLUMNS).iter_slices(n_rows=BATCH_SIZE)

    for i, data_chunk_pl in enumerate(tqdm(row_iterator, total=n_batches, desc="Fitting Scaler")):
        if data_chunk_pl.height > 0:
            features_np = data_chunk_pl.to_numpy()
            if features_np.size > 0:
                # Check for NaNs/Infs which cause issues in partial_fit
                if np.any(~np.isfinite(features_np)):
                     print(f"\nWarning: Non-finite values found in chunk {i}. Replacing with 0 before scaling.")
                     features_np = np.nan_to_num(features_np, nan=0.0, posinf=0.0, neginf=0.0)
                scaler.partial_fit(features_np)
            del features_np
        del data_chunk_pl
        gc.collect()

    print("StandardScaler fitting complete.")
    joblib.dump(scaler, OUTPUT_SCALER_MODEL)
    print(f"Saved fitted scaler to {OUTPUT_SCALER_MODEL}")

    # 1b. Determine Variance Threshold Mask (based on the fitted scaler)
    if not hasattr(scaler, 'var_') or scaler.var_ is None:
        raise ValueError("Scaler variance (var_) attribute not found or is None. Fitting likely failed.")

    variances = scaler.var_
    selector_variance = VarianceThreshold(threshold=VARIANCE_THRESHOLD)
    selected_mask_variance = variances > VARIANCE_THRESHOLD # This mask corresponds to ORIGINAL features
    features_after_variance = [
        feature for feature, selected in zip(FEATURE_COLUMNS, selected_mask_variance) if selected
    ]
    n_after_variance = len(features_after_variance)

    print(f"Applied Variance Threshold > {VARIANCE_THRESHOLD}")
    print(f"Features remaining after variance check: {n_after_variance} (removed {n_original - n_after_variance})")

    if not features_after_variance:
        print("Error: No features remaining after variance thresholding. Check threshold or data.")
        exit(1)

except Exception as e:
    print(f"Error during StandardScaler fitting or Variance Threshold prep: {e}")
    traceback.print_exc()
    exit(1)


# --- Step 2: Univariate Selection (Mutual Information) ---
print("\n--- Step 2: Mutual Information Selection ---")
features_after_mi = [] # Initialize final list
n_after_mi = 0
k_features = 0 # Initialize k_features

# 2a. Sample Data (including ALL original features needed for scaling)
print(f"Sampling data ({MI_SAMPLE_SIZE} rows) for MI calculation...")
# We need all original features for scaling, plus the target
cols_to_sample = FEATURE_COLUMNS + [TARGET_COLUMN]

try:
    if total_rows <= MI_SAMPLE_SIZE:
         print("Total rows <= sample size, using all data for MI.")
         sample_df = pl.read_parquet(INPUT_FILE, columns=cols_to_sample)
    else:
        sample_df = pl.scan_parquet(INPUT_FILE, columns=cols_to_sample)\
                       .sample(n=MI_SAMPLE_SIZE, shuffle=True, seed=42)\
                       .collect()

    print(f"Sampled {sample_df.height} rows.")
    if sample_df.is_empty(): raise ValueError("Sampled DataFrame is empty.")

    # Separate features and target
    X_sample_full = sample_df.select(FEATURE_COLUMNS).to_numpy() # Has n_original columns
    y_sample_raw = sample_df.select(TARGET_COLUMN).to_numpy().ravel()
    del sample_df
    gc.collect()

    # Encode target variable
    le = LabelEncoder()
    y_sample = le.fit_transform(y_sample_raw)
    print(f"Encoded target variable. Found classes: {le.classes_}")

except Exception as e:
    print(f"Error during data sampling: {e}")
    traceback.print_exc()
    exit(1)

# 2b. Scale the FULL Sample Data
print("Scaling the sample data (using all original features)...")
try:
    # Check for NaNs/Infs before transform
    if np.any(~np.isfinite(X_sample_full)):
         print("Warning: Non-finite values found in sample features. Replacing with 0 before transform.")
         X_sample_full = np.nan_to_num(X_sample_full, nan=0.0, posinf=0.0, neginf=0.0)

    X_sample_scaled_full = scaler.transform(X_sample_full) # Transform uses the fitted scaler
    del X_sample_full # Free memory
    gc.collect()

    # 2c. Apply Variance Threshold filtering to the SCALED sample data
    print("Applying variance threshold filter to scaled sample data...")
    if selected_mask_variance is None:
         raise ValueError("Variance threshold mask was not calculated.")
    # Use the mask calculated in step 1b to select columns from the scaled data
    X_sample_scaled_variance_filtered = X_sample_scaled_full[:, selected_mask_variance]
    del X_sample_scaled_full # Free memory
    gc.collect()
    print(f"Shape of data for MI: {X_sample_scaled_variance_filtered.shape}") # Should have n_after_variance columns

except Exception as e:
     print(f"Error scaling sample data or applying variance filter: {e}")
     traceback.print_exc()
     exit(1)


# 2d. Calculate Mutual Information Scores and Select K Best
print(f"Calculating Mutual Information scores and selecting top {MI_SELECT_K} features...")
try:
    # Check if input data for MI exists and has expected shape
    if 'X_sample_scaled_variance_filtered' not in locals() or X_sample_scaled_variance_filtered.shape[1] != n_after_variance:
         raise ValueError("Data for MI calculation is missing or has incorrect shape.")

    # Ensure k is not larger than the number of features available AFTER variance thresholding
    num_features_for_mi = X_sample_scaled_variance_filtered.shape[1]
    k_features = min(MI_SELECT_K, num_features_for_mi) # Define k_features here
    if k_features < MI_SELECT_K:
        print(f"Warning: Requested K={MI_SELECT_K}, but only {k_features} available after variance threshold. Selecting all {k_features}.")
    if k_features == 0:
        raise ValueError("k_features is 0, cannot select features.")

    # Use random_state for reproducibility if MI uses it
    selector_mi = SelectKBest(lambda X, y: mutual_info_classif(X, y, discrete_features=False, random_state=42), k=k_features)
    selector_mi.fit(X_sample_scaled_variance_filtered, y_sample)

    # Get the mask relative to the variance-filtered features
    selected_mask_mi = selector_mi.get_support()

    # Map this mask back to the original feature names that passed the variance threshold
    features_after_mi = [
        feature for feature, selected in zip(features_after_variance, selected_mask_mi) if selected
    ]
    n_after_mi = len(features_after_mi)

    print("Mutual Information selection complete.")
    print(f"Final features selected: {n_after_mi} (removed {n_after_variance - n_after_mi} based on MI)")

    if n_after_mi == 0 :
        print("Warning: Mutual information selected 0 features.")


except Exception as e:
    print(f"Error during Mutual Information calculation/selection: {e}")
    traceback.print_exc()
    # No exit here, allow summary to print if possible, but feature list might be empty


# --- Final Results ---
print("\n--- Feature Selection Summary ---")
print(f"Initial features: {n_original}")
print(f"After Variance Threshold (> {VARIANCE_THRESHOLD}): {n_after_variance}")
# Check if MI step completed successfully enough to have n_after_mi and k_features
if 'n_after_mi' in locals() and 'k_features' in locals():
    print(f"After Mutual Information (Top {k_features}): {n_after_mi}")
else:
    print("Mutual Information step did not complete successfully.")

# Save the final list of selected features
if features_after_mi: # Only save if the list is not empty
    try:
        joblib.dump(features_after_mi, OUTPUT_FINAL_FEATURE_LIST)
        print(f"\nSaved the final list of {n_after_mi} selected features to: {OUTPUT_FINAL_FEATURE_LIST}")
    except Exception as e:
        print(f"Error saving the final feature list: {e}")
else:
    print("\nNo features selected by Mutual Information, final list not saved.")

print("\nFeature selection process finished.")

Reading schema to identify columns...
Found 3780 initial feature columns.
Target column: marker

--- Step 1: Low Variance Threshold ---
Fitting StandardScaler incrementally (batch size: 1000)...


Fitting Scaler: 100%|██████████| 3/3 [00:00<00:00, 12.48it/s]


StandardScaler fitting complete.
Saved fitted scaler to fs_scaler_1000.joblib
Applied Variance Threshold > 0.01
Features remaining after variance check: 3443 (removed 337)

--- Step 2: Mutual Information Selection ---
Sampling data (20000 rows) for MI calculation...
Total rows <= sample size, using all data for MI.
Sampled 2772 rows.
Encoded target variable. Found classes: ['Left' 'Right']
Scaling the sample data (using all original features)...
Applying variance threshold filter to scaled sample data...
Shape of data for MI: (2772, 3443)
Calculating Mutual Information scores and selecting top 500 features...
Mutual Information selection complete.
Final features selected: 500 (removed 2943 based on MI)

--- Feature Selection Summary ---
Initial features: 3780
After Variance Threshold (> 0.01): 3443
After Mutual Information (Top 500): 500

Saved the final list of 500 selected features to: final_selected_features_1000.joblib

Feature selection process finished.


In [8]:
import polars as pl
import numpy as np
import joblib
import os
import warnings
import gc # Import gc for memory management
import traceback # Import traceback for error details
from tqdm import tqdm

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.exceptions import ConvergenceWarning

# Import popular classifiers
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.neural_network import MLPClassifier

# Suppress convergence warnings for models like Logistic Regression or MLP
warnings.filterwarnings("ignore", category=ConvergenceWarning)
warnings.filterwarnings("ignore", category=UserWarning, module='sklearn')

# --- Configuration ---
INPUT_FEATURES_FILE = "ML_dataset_1000_features.parquet" # The dataset with ~4k features
SELECTED_FEATURES_LIST_FILE = "final_selected_features_1000.joblib" # List of ~500 feature names
TARGET_COLUMN = "marker"
IDENTIFIER_COLS = ['event_id', 'prev_marker'] # Columns to ignore for features/target

TEST_SIZE = 0.25
RANDOM_STATE = 42

OUTPUT_MODEL_DIR = "trained_models_in_memory"
os.makedirs(OUTPUT_MODEL_DIR, exist_ok=True)

# --- 1. Load Selected Features ---
print(f"Loading selected features from: {SELECTED_FEATURES_LIST_FILE}")
try:
    selected_features = joblib.load(SELECTED_FEATURES_LIST_FILE)
    if not isinstance(selected_features, list) or len(selected_features) == 0:
        raise ValueError("Loaded features are not a valid non-empty list.")
    print(f"Loaded {len(selected_features)} selected features.")
except FileNotFoundError:
    print(f"Error: Selected features file not found at {SELECTED_FEATURES_LIST_FILE}")
    exit(1)
except Exception as e:
    print(f"Error loading selected features: {e}")
    exit(1)

columns_to_load = selected_features + [TARGET_COLUMN]

# --- 2. Load Dataset into Memory ---
print(f"\nLoading data from: {INPUT_FEATURES_FILE} (Cols: {len(columns_to_load)})")
try:
    if not os.path.exists(INPUT_FEATURES_FILE):
        raise FileNotFoundError(f"Input file not found: {INPUT_FEATURES_FILE}")
    df_pl = pl.read_parquet(INPUT_FEATURES_FILE, columns=columns_to_load)
    print(f"Loaded DataFrame shape: {df_pl.shape}")

    X = df_pl.select(selected_features).to_numpy()
    y_raw = df_pl.select(TARGET_COLUMN).to_numpy().ravel()
    del df_pl
    gc.collect()
    print("Converted data to NumPy arrays.")

except FileNotFoundError:
     print(f"Error: Input file not found at {INPUT_FEATURES_FILE}")
     exit(1)
except pl.exceptions.ColumnNotFoundError as e:
     print(f"Error: One or more selected columns not found in {INPUT_FEATURES_FILE}. Details: {e}")
     print("Ensure the feature list file corresponds to the columns in the Parquet file.")
     exit(1)
except Exception as e:
    print(f"Error loading or processing data: {e}")
    traceback.print_exc()
    exit(1)

# --- 3. Handle Missing Values ---
if np.any(~np.isfinite(X)):
    print("Warning: Non-finite values (NaN/Inf) found in features. Replacing with 0.")
    X = np.nan_to_num(X, nan=0.0, posinf=np.finfo(X.dtype).max, neginf=np.finfo(X.dtype).min)

# --- 4. Encode Target Variable ---
print("\nEncoding target variable...")
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(y_raw)
print(f"Target classes: {label_encoder.classes_}")
# Find index of the 'positive' class (e.g., 'Right', assuming it's the second class)
positive_class_label = 'Right' # Or 'Left', choose one consistently
positive_class_index = np.where(label_encoder.classes_ == positive_class_label)[0][0]
print(f"Positive class '{positive_class_label}' encoded as: {positive_class_index}")
print(f"Encoded target shape: {y.shape}")

# --- 5. Split Data into Train/Test ---
print(f"\nSplitting data into Training ({1-TEST_SIZE:.0%}) and Testing ({TEST_SIZE:.0%})...")
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE, stratify=y
)
print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
print(f"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")
del X, y_raw, y
gc.collect()

# --- 6. Scale Features ---
print("\nScaling features (fitting scaler on training data only)...")
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
print("Scaling complete.")
scaler_path = os.path.join(OUTPUT_MODEL_DIR, "scaler_in_memory.joblib")
joblib.dump(scaler, scaler_path)
print(f"Saved fitted scaler to: {scaler_path}")

# --- 7. Initialize Models ---
print("\nInitializing models...")
# Ensure models that need probability=True have it set
models = {
    "LogisticRegression": LogisticRegression(random_state=RANDOM_STATE, max_iter=1000, solver='liblinear'),
    "SGDClassifier_Log": SGDClassifier(loss='log_loss', random_state=RANDOM_STATE, max_iter=1000, tol=1e-3), # Supports predict_proba
    # SGDClassifier_Hinge does NOT support predict_proba directly
    "SVC_Linear": SVC(kernel='linear', probability=True, random_state=RANDOM_STATE), # Needs probability=True
    "SVC_RBF": SVC(kernel='rbf', probability=True, random_state=RANDOM_STATE),       # Needs probability=True
    "RandomForest": RandomForestClassifier(n_estimators=150, random_state=RANDOM_STATE, n_jobs=-1, max_depth=20, min_samples_leaf=5),
    "GradientBoosting": GradientBoostingClassifier(n_estimators=100, random_state=RANDOM_STATE, learning_rate=0.1, max_depth=3), # Supports predict_proba
    "KNeighbors": KNeighborsClassifier(n_neighbors=5, n_jobs=-1), # Supports predict_proba
    "GaussianNB": GaussianNB(), # Supports predict_proba
    "DecisionTree": DecisionTreeClassifier(random_state=RANDOM_STATE, max_depth=15, min_samples_leaf=10), # Supports predict_proba
    "MLPClassifier": MLPClassifier(hidden_layer_sizes=(64, 32), random_state=RANDOM_STATE, max_iter=500, early_stopping=True) # Supports predict_proba
}

# --- Helper Function for Threshold Tuning ---
def find_best_threshold(y_true, y_pred_proba, pos_label_index=1, steps=100):
    """Finds the probability threshold that maximizes accuracy."""
    best_threshold = 0.5 # Default
    best_accuracy = 0.0
    min_prob = np.min(y_pred_proba)
    max_prob = np.max(y_pred_proba)
    
    # Calculate initial accuracy with default threshold
    y_pred_default = (y_pred_proba >= 0.5).astype(int)
    best_accuracy = accuracy_score(y_true, y_pred_default) # Start with default accuracy

    # Iterate through thresholds
    thresholds = np.linspace(min_prob, max_prob, steps + 1) # Check edge cases too
    for threshold in thresholds:
        y_pred_tuned = (y_pred_proba >= threshold).astype(int)
        current_accuracy = accuracy_score(y_true, y_pred_tuned)
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            best_threshold = threshold
            
    return best_threshold, best_accuracy

# --- 8. Train, Evaluate, and Tune Threshold ---
print("\n--- Training, Evaluating Models, and Tuning Threshold ---")

results = {}

for name, model in models.items():
    print(f"\n--- Processing Model: {name} ---")
    try:
        # Train the model
        print("Training...")
        model.fit(X_train_scaled, y_train)
        print("Training complete.")

        # --- Original Evaluation (Default Threshold) ---
        print("Evaluating with default threshold (0.5)...")
        y_pred_original = model.predict(X_test_scaled)
        accuracy_original = accuracy_score(y_test, y_pred_original)
        report_original = classification_report(y_test, y_pred_original, target_names=label_encoder.classes_, zero_division=0)
        cm_original = confusion_matrix(y_test, y_pred_original)

        # --- Threshold Tuning ---
        best_threshold = 0.5 # Default
        accuracy_tuned = accuracy_original # Start with original

        if hasattr(model, "predict_proba"):
            print("Tuning threshold...")
            try:
                # Get probabilities for the positive class
                y_pred_proba = model.predict_proba(X_test_scaled)[:, positive_class_index]
                best_threshold, accuracy_tuned = find_best_threshold(y_test, y_pred_proba, positive_class_index)
                print(f"Best threshold found: {best_threshold:.4f}")
            except Exception as te:
                print(f"Could not tune threshold for {name}: {te}")
                best_threshold = 'N/A'
                accuracy_tuned = accuracy_original # Fallback to original if tuning fails
        else:
            print(f"Model {name} does not support predict_proba. Skipping threshold tuning.")
            best_threshold = 'N/A'
            accuracy_tuned = accuracy_original

        # Store results
        results[name] = {
            'accuracy_original': accuracy_original,
            'accuracy_tuned': accuracy_tuned,
            'best_threshold': best_threshold,
            'report_original': report_original, # Report based on default threshold
            'cm_original': cm_original
        }

        # Print metrics
        print(f"\n--- Evaluation Results: {name} ---")
        print(f"Accuracy (Default Threshold 0.5): {accuracy_original:.4f}")
        if best_threshold != 'N/A':
            print(f"Best Threshold Found            : {best_threshold:.4f}")
            print(f"Accuracy (Best Threshold)       : {accuracy_tuned:.4f}")
        print("\nClassification Report (Default Threshold):")
        print(report_original)
        print("Confusion Matrix (Default Threshold):")
        print(cm_original)
        print("-" * 40)

        # Save the trained model
        model_path = os.path.join(OUTPUT_MODEL_DIR, f"{name}.joblib")
        joblib.dump(model, model_path)
        # print(f"Saved trained model to: {model_path}") # Optional: uncomment if needed

    except MemoryError:
         print(f"MemoryError occurred while processing {name}.")
         results[name] = {'accuracy_original': 'MemoryError', 'accuracy_tuned': 'MemoryError', 'best_threshold': 'MemoryError', 'report_original': 'MemoryError', 'cm_original': 'MemoryError'}
    except Exception as e:
        print(f"An error occurred while processing {name}: {e}")
        results[name] = {'accuracy_original': 'Error', 'accuracy_tuned': 'Error', 'best_threshold': 'Error', 'report_original': str(e), 'cm_original': 'Error'}
        traceback.print_exc()

# --- 9. Final Summary ---
print("\n--- Final Accuracy Summary ---")
print(f"{'Model':<25} | {'Acc (Default)':<15} | {'Best Threshold':<15} | {'Acc (Tuned)':<15}")
print("-" * 75)
for name, metrics in results.items():
    acc_orig_str = f"{metrics['accuracy_original']:.4f}" if isinstance(metrics['accuracy_original'], float) else str(metrics['accuracy_original'])
    acc_tuned_str = f"{metrics['accuracy_tuned']:.4f}" if isinstance(metrics['accuracy_tuned'], float) else str(metrics['accuracy_tuned'])
    thresh_str = f"{metrics['best_threshold']:.4f}" if isinstance(metrics['best_threshold'], float) else str(metrics['best_threshold'])

    print(f"{name:<25} | {acc_orig_str:<15} | {thresh_str:<15} | {acc_tuned_str:<15}")
print("-" * 75)


Loading selected features from: final_selected_features_1000.joblib
Loaded 500 selected features.

Loading data from: ML_dataset_1000_features.parquet (Cols: 501)
Loaded DataFrame shape: (2772, 501)
Converted data to NumPy arrays.

Encoding target variable...
Target classes: ['Left' 'Right']
Positive class 'Right' encoded as: 1
Encoded target shape: (2772,)

Splitting data into Training (75%) and Testing (25%)...
X_train shape: (2079, 500), y_train shape: (2079,)
X_test shape: (693, 500), y_test shape: (693,)

Scaling features (fitting scaler on training data only)...
Scaling complete.
Saved fitted scaler to: trained_models_in_memory/scaler_in_memory.joblib

Initializing models...

--- Training, Evaluating Models, and Tuning Threshold ---

--- Processing Model: LogisticRegression ---
Training...
Training complete.
Evaluating with default threshold (0.5)...
Tuning threshold...
Best threshold found: 0.4700

--- Evaluation Results: LogisticRegression ---
Accuracy (Default Threshold 0.5): 