In [None]:
"""Train/test within stills using a training time window determined by the 2 peak times in diagonal within-still decoding"""

import os
import mne 
import pandas as pd 
import numpy as np 
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from joblib import Parallel, delayed
import matplotlib.pyplot as plt

# Input and output locations
bids_dir = '/System/Volumes/Data/misc/data12/sjapee/Sebastian-OrientationImagery/Data/Bids/'
data_path = f'{bids_dir}/derivatives/preprocessed/'
output_dir = '/System/Volumes/Data/misc/data12/sjapee/Sebastian-OrientationImagery/!Important Data/LDA-16way Static'
peak_csv_path = f'{output_dir}/Mean/Peak_Times.csv'

# Make sure every output sub-folder exists before anything is written to it.
for subfolder in ("data", "plots", "mean"):
    os.makedirs(f"{output_dir}/{subfolder}", exist_ok=True)

# Functions
def train_test_split(epochs, test_chunk, target_id, cv_id='run_nr'):
    """Hold out one cross-validation chunk and return train/test data and labels.

    Parameters
    ----------
    epochs
        Object exposing a ``metadata`` table and a ``_data`` array, and
        supporting boolean-mask indexing (presumably an ``mne.Epochs``
        instance with preloaded data -- confirm against the caller).
    test_chunk
        Value of the ``cv_id`` metadata column that marks the test set.
    target_id
        Name of the metadata column holding the labels to decode.
    cv_id
        Name of the metadata column used to define the folds.

    Returns
    -------
    tuple
        ``(x_train, x_test, y_train, y_test)``: the ``_data`` arrays of the
        training and test subsets followed by their labels as NumPy arrays.
    """
    held_out = epochs[epochs.metadata[cv_id] == test_chunk]
    remaining = epochs[epochs.metadata[cv_id] != test_chunk]

    y_train = remaining.metadata[target_id].to_numpy()
    y_test = held_out.metadata[target_id].to_numpy()
    return remaining._data, held_out._data, y_train, y_test

def run_decoding_windowed(x_train, y_train, x_test, y_test, train_window):
    """Train once on window-averaged data, then score every test index.

    The training data are averaged over the ``train_window`` indices of
    their last axis, a StandardScaler + shrinkage-LDA pipeline is fitted on
    that average, and the fitted pipeline is applied separately to each
    index along the last axis of ``x_test``.

    Parameters
    ----------
    x_train, x_test
        3-D arrays; assumed to be (trials, channels, times) -- TODO confirm.
    y_train, y_test
        Labels for the training and test trials.
    train_window
        Indices along the last axis of ``x_train`` to average over.

    Returns
    -------
    list
        One accuracy value (fraction of correct predictions) per index
        along the last axis of ``x_test``.
    """
    window_average = x_train[:, :, train_window].mean(axis=2)

    scaler = StandardScaler()
    lda = LinearDiscriminantAnalysis(solver='eigen', shrinkage=0.01)
    model = Pipeline([('scaler', scaler), ('classifier', lda)])
    model.fit(window_average, y_train)

    scores = []
    for t in range(x_test.shape[2]):
        is_correct = model.predict(x_test[:, :, t]) == y_test
        scores.append(np.mean(is_correct))
    return scores

# Main loop over early/late
#
# For each stage, a training window is centred on the corresponding peak
# sample read from the peak-times CSV. Every subject is then decoded with
# leave-one-run-out cross-validation (train on the window average, test on
# every time point), and the per-subject and group-average timecourses are
# saved and plotted.

# These inputs do not depend on the stage, so load/build them once.
peak_df = pd.read_csv(peak_csv_path)
subjects = [f"S{i:02}" for i in range(1, 21)]

# Number of samples taken on each side of the peak sample (window = 2 * 18 + 1).
window_half_width = 18

for stage in ['early', 'late']:
    print(f"\n===== Processing {stage.upper()} window decoding =====")
    stage_label = stage.capitalize()
    peak_col = 'peak1_sample' if stage == 'early' else 'peak2_sample'

    # The window is defined from the row labelled 'all' (presumably the
    # group-level peak -- confirm against the script that writes the CSV).
    group_peaks = peak_df.loc[peak_df['Subject'] == 'all', peak_col]
    if group_peaks.empty:
        raise ValueError(
            f"No row with Subject == 'all' in {peak_csv_path}; "
            f"cannot define the {stage} training window."
        )
    peak_sample = int(group_peaks.iloc[0])
    train_window = np.arange(peak_sample - window_half_width,
                             peak_sample + window_half_width + 1)

    # Negative indices would not fail: NumPy would silently wrap them to the
    # end of the epoch and the classifier would train on the wrong samples.
    if train_window[0] < 0:
        raise ValueError(
            f"{stage} training window starts at sample {train_window[0]} "
            f"(peak sample {peak_sample}); negative indices would wrap "
            f"around to the end of the epoch."
        )

    for Subject in subjects:
        print(f"  Processing {Subject}...")
        fn_still = f'sub-{Subject}_Still_preprocessed-epo.fif'
        epochs_still = mne.read_epochs(data_path + fn_still, preload=True)

        # Clean and relabel metadata: the label is the last '/'-separated
        # part of trial_type; 'catch' trials are dropped before decoding.
        epochs_still.metadata['degrees_string'] = [k.split('/')[-1] for k in epochs_still.metadata['trial_type']]
        epochs_still = epochs_still[epochs_still.metadata['degrees_string'] != 'catch']
        # Integer copy of the label; the decoding below uses 'degrees_string'.
        epochs_still.metadata['degrees'] = [int(i) for i in epochs_still.metadata['degrees_string']]

        # Leave-one-run-out cross-validation: each run is the test set once.
        accuracy = []
        for test_run in np.unique(epochs_still.metadata.run_nr):
            x_train, x_test, y_train, y_test = train_test_split(epochs_still, test_run, 'degrees_string')
            accs = run_decoding_windowed(x_train, y_train, x_test, y_test, train_window)
            accuracy.append(accs)

        # Average the accuracy timecourses over the held-out runs.
        mean_accuracy = np.mean(accuracy, axis=0)
        np.save(f"{output_dir}/data/{Subject}_DecodingAccuracyTimecourse{stage_label}Window.npy", mean_accuracy)

        # Plotting
        times = epochs_still.times
        plt.figure(figsize=(10, 5))
        plt.plot(times, mean_accuracy, label="Decoding Accuracy")
        plt.xlabel('Time (s)')
        plt.ylabel('Accuracy')
        plt.title(f'{Subject} Window-Trained Decoding ({stage_label})')
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"{output_dir}/plots/{Subject}_DecodingWindow{stage_label}.png")
        plt.close()

    # ===== Save group mean across subjects =====
    # Timecourses are re-read from disk; subjects without a file are skipped.
    all_subject_means = []
    valid_subjects = []
    for Subject in subjects:
        acc_path = f"{output_dir}/data/{Subject}_DecodingAccuracyTimecourse{stage_label}Window.npy"
        if os.path.exists(acc_path):
            mean_accuracy = np.load(acc_path)
            all_subject_means.append(mean_accuracy)
            valid_subjects.append(Subject)
        else:
            print(f"  Missing data for {Subject}, skipping.")

    if not all_subject_means:
        print(f"No valid data for {stage}. Skipping group plot.")
        continue

    group_mean = np.mean(np.vstack(all_subject_means), axis=0)
    np.save(f"{output_dir}/data/Mean_DecodingAccuracyTimecourse{stage_label}Window.npy", group_mean)

    # Use times from one subject
    sample_subject = valid_subjects[0]
    times = mne.read_epochs(data_path + f"sub-{sample_subject}_Still_preprocessed-epo.fif", preload=False).times

    plt.figure(figsize=(10, 5))
    plt.plot(times, group_mean, color='black', label="Group Average")
    plt.xlabel('Time (s)')
    plt.ylabel('Accuracy')
    plt.title(f"Group Average Window-Trained Decoding ({stage_label})")
    plt.legend()
    plt.tight_layout()
    plt.savefig(f"{output_dir}/Mean/Mean_DecodingWindow{stage_label}.png")
    plt.close()



===== Processing EARLY window decoding =====
  Processing S01...
Reading /System/Volumes/Data/misc/data12/sjapee/Sebastian-OrientationImagery/Data/Bids/derivatives/preprocessed/sub-S01_Still_preprocessed-epo.fif ...
    Found the data of interest:
        t =    -200.00 ...     600.00 ms
        0 CTF compensation matrices available
Adding metadata with 8 columns
1104 matching events found
No baseline correction applied
0 projection items activated
  Processing S02...
Reading /System/Volumes/Data/misc/data12/sjapee/Sebastian-OrientationImagery/Data/Bids/derivatives/preprocessed/sub-S02_Still_preprocessed-epo.fif ...
    Found the data of interest:
        t =    -200.00 ...     600.00 ms
        0 CTF compensation matrices available
Adding metadata with 8 columns
1104 matching events found
No baseline correction applied
0 projection items activated
  Processing S03...
Reading /System/Volumes/Data/misc/data12/sjapee/Sebastian-OrientationImagery/Data/Bids/derivatives/preprocessed/sub-S0

In [None]:
# Recompute the group-average timecourse for the EARLY window from the
# per-subject .npy files saved by the decoding loop, then save and plot it.
all_subject_means = []
valid_subjects = []

for Subject in subjects:
    acc_path = f"{output_dir}/data/{Subject}_DecodingAccuracyTimecourseEarlyWindow.npy"
    if os.path.exists(acc_path):
        mean_accuracy = np.load(acc_path)
        all_subject_means.append(mean_accuracy)
        valid_subjects.append(Subject)
    else:
        print(f"Missing data for {Subject}, skipping in group average.")

# Check before averaging: np.vstack raises its own ValueError on an empty
# list, which would otherwise pre-empt this more informative error.
if not valid_subjects:
    raise RuntimeError("No valid subject data found to compute group average.")

# Stack the per-subject timecourses and compute the grand average
group_mean = np.mean(np.vstack(all_subject_means), axis=0)
np.save(f"{output_dir}/data/Mean_DecodingAccuracyTimecourseEarlyWindow.npy", group_mean)

# Load the time axis from the first subject that has data
sample_subject = valid_subjects[0]
times = mne.read_epochs(data_path + f"sub-{sample_subject}_Still_preprocessed-epo.fif", preload=False).times

# Plot
plt.figure(figsize=(10, 5))
plt.plot(times, group_mean, color='black', label="Group Average")
plt.xlabel('Time (s)')
plt.ylabel('Accuracy')
plt.title("Group Average Window-Trained Decoding")
plt.legend()
plt.tight_layout()
plt.savefig(f"{output_dir}/mean/Mean_DecodingWindowEarly.png")
plt.show()
plt.close()
