In [None]:
"""Script to Cross-Decode Orientations in movies using LDA *by relative position in a movie over time*"""

import os, mne, pickle, numpy as np, pandas as pd, time
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn import svm
from joblib import Parallel, delayed
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from scipy.ndimage import gaussian_filter1d

# Static orientation conditions, listed in the circular order that
# get_movie_orientation_sequence steps through ('Right' = forward through this
# list, 'Left' = backward, wrapping around at either end). Each value is
# zero-padded to 4 digits there to build condition labels such as '0022'.
# NOTE(review): the original note says to trim this list to look at only some
# orientations; doing so also changes which orientations count as neighbours
# in a movie's sequence, and a movie whose start orientation is removed raises.
input_numbers = [1, 22, 45, 67, 90, 112, 135, 157, 180, 202, 225, 247, 270, 292, 315, 337]


def get_movie_orientation_sequence(movie_condition, input_numbers, n_steps=7):
    """
    Walk the circular list of orientations starting at a movie's first orientation.

    movie_condition: movie label of the form '<orientation>_<direction>', e.g. '0022_Left'
    input_numbers: ordered list of integer orientations, treated as a ring
    n_steps: number of orientations to return, counting the starting one

    Returns: list of n_steps orientation labels, each zero-padded to 4 characters.
    'Right' moves forward through input_numbers, 'Left' moves backward; both wrap
    around at the ends of the list.

    Raises: ValueError if the starting orientation is not in input_numbers, or if
    the direction is neither 'Left' nor 'Right' (the direction is only checked
    when at least one step is requested).
    """
    tokens = movie_condition.split('_')
    first_orientation = int(tokens[0])
    direction = tokens[1]  # 'Left' or 'Right'

    if first_orientation not in input_numbers:
        raise ValueError(f"Starting orientation {first_orientation} not in input_numbers")

    origin = input_numbers.index(first_orientation)
    offsets = range(n_steps)

    # +1 walks forward through the list, -1 walks backward; None marks an unknown direction
    stride = {'Right': 1, 'Left': -1}.get(direction)
    if stride is None and len(offsets) > 0:
        raise ValueError(f"Unknown direction: {direction}")

    ring_size = len(input_numbers)
    # Modulo keeps the index inside the list, so the walk wraps around
    return [f"{input_numbers[(origin + stride * k) % ring_size]:04}" for k in offsets]

def time_resolved_decoding_per_position(x_train, y_train, x_test, movie_condition, times, input_numbers, n_positions=7):
    """
    Fit one classifier on the training data, then at every test timepoint measure
    how often its prediction equals each successive orientation of the movie.

    x_train, y_train: samples and labels handed to the pipeline's fit
    x_test: array indexed as [trial, feature, time]; the last axis is looped over
    movie_condition: movie label passed to get_movie_orientation_sequence
    times: returned unchanged alongside the result (not used in the computation)
    input_numbers: orientation list passed to get_movie_orientation_sequence
    n_positions: how many positions of the movie's sequence to score

    Returns: (accuracy_matrix, times). accuracy_matrix has shape
    (n_positions, n_times); entry [p, t] is the proportion of test trials whose
    prediction at timepoint t equals the p-th orientation of the sequence.
    """
    targets = get_movie_orientation_sequence(movie_condition, input_numbers, n_steps=n_positions)
    n_times = x_test.shape[2]

    # Standardise features, then LDA with fixed shrinkage; fitted once and
    # reused for every test timepoint.
    decoder = Pipeline([
        ('scaler', StandardScaler()),
        ('classifier', LinearDiscriminantAnalysis(solver='eigen', shrinkage=0.01)),
    ])
    decoder.fit(x_train, y_train)

    accuracy_matrix = np.zeros((n_positions, n_times))  # shape: (position, time)
    for t_idx in range(n_times):
        predicted = decoder.predict(x_test[:, :, t_idx])  # one label per trial
        # Fill one column: share of trials predicted as each target orientation
        accuracy_matrix[:, t_idx] = [np.mean(predicted == target) for target in targets]

    return accuracy_matrix, times

def select_conditions(epochs_stacked, column_name, items_to_select):
    """
    Select the epochs whose metadata value in `column_name` is one of the requested items.

    epochs_stacked: object exposing a pandas `metadata` table and supporting
        boolean-mask indexing (called in this file with MNE epochs)
    column_name: name of the metadata column to filter on
    items_to_select: a single value given as a string, or any iterable of
        values (list, tuple, set, array, ...)

    Returns: the result of indexing `epochs_stacked` with a boolean mask that is
    True where the metadata value equals one of the requested items. An empty
    iterable produces an all-False mask.
    """
    # A bare string is one condition, not an iterable of characters.
    if isinstance(items_to_select, str):
        wanted = [items_to_select]
    else:
        wanted = list(items_to_select)

    # Always index with a plain NumPy boolean mask, one entry per metadata row.
    keep = epochs_stacked.metadata[column_name].isin(wanted).to_numpy()
    return epochs_stacked[keep]
def train_test_split_cross(epo_stacked, epo_stackedMovie, test_condition, start_time, end_time):
    """
    Build a cross-decoding split: 'Still' epochs averaged over a time window for
    training, and the full time course of one movie condition for testing.

    epo_stacked: epochs holding the static trials (metadata columns 'block_type' and 'condition' are read)
    epo_stackedMovie: epochs holding the movie trials (metadata column 'condition' is read)
    test_condition: movie condition whose trials form the test set
    start_time, end_time: inclusive bounds of the training window, compared against the epochs' times

    Returns: x_train, x_test, y_train, y_test, test_times
        x_train: 'Still' data averaged over the window (time axis collapsed)
        x_test: test data with every timepoint kept
        y_train, y_test: 'condition' labels of the training / test trials
        test_times: the `times` attribute of the test epochs
    """
    # Test set: trials of the requested movie, all timepoints kept
    is_test_trial = epo_stackedMovie.metadata['condition'] == test_condition
    movie_epochs = epo_stackedMovie[is_test_trial]
    y_test = movie_epochs.metadata['condition'].to_numpy()
    x_test = movie_epochs._data  # presumably (n_trials, n_channels, n_times)

    # Train set: 'Still' trials, averaged across the [start_time, end_time] window
    is_still_trial = epo_stacked.metadata['block_type'] == 'Still'
    still_epochs = epo_stacked[is_still_trial]
    y_train = still_epochs.metadata['condition'].to_numpy()
    in_window = np.logical_and(still_epochs.times >= start_time, still_epochs.times <= end_time)
    x_train = still_epochs._data[:, :, in_window].mean(axis=2)  # mean over time

    return x_train, x_test, y_train, y_test, movie_epochs.times


In [None]:
# Per-subject cross-decoding: train on static ('Still') epochs averaged over a
# fixed window, test on every timepoint of each movie, and store for each
# subject a (n_positions, n_times) matrix averaged over movies.
all_subjects_results = []
# NOTE(review): never filled or read in the visible cells; possibly a leftover.
all_static_accuracies = []
subjects = [f"S{i:02}" for i in range(1, 21)]  # 'S01' ... 'S20'
data_path = '/System/Volumes/Data/misc/data12/sjapee/Sebastian-OrientationImagery/Data/Bids/derivatives/preprocessed/'
csv_path = '/System/Volumes/Data/misc/data12/sjapee/Sebastian-OrientationImagery/!Important Data/LDA-16way Static/Peak_Times.csv'
output_dir = '/System/Volumes/Data/misc/data12/sjapee/Sebastian-OrientationImagery/!Important Data/ProportionsOverTime'
peak_df = pd.read_csv(csv_path, index_col='Subject')
# Loop over subjects: compute and save one accuracy matrix each (no plotting happens here)
for subject in subjects:
    print(f"Processing {subject}...")
    # Load the static and movie epochs of this subject
    fn_still = f'sub-{subject}_Still_preprocessed-epo.fif'
    fn_dynamic = f'sub-{subject}_Dynamic_preprocessed-epo.fif'
    epochs_still = mne.read_epochs(data_path + fn_still, preload=True)
    epochs_dynamic = mne.read_epochs(data_path + fn_dynamic, preload=True)

    # Select only relevant conditions (exclude catch trials):
    # 16 static orientations, and 8 movie start orientations x 2 directions
    epochs_selected = select_conditions(epochs_still, column_name='condition', items_to_select=[
        '0001', '0022', '0045', '0067', '0090', '0112', '0135', '0157',
        '0180', '0202', '0225', '0247', '0270', '0292', '0315', '0337'])
    epochs_selectedMovie = select_conditions(epochs_dynamic, column_name='condition', items_to_select=[
        '0022_Left', '0022_Right', '0067_Left', '0067_Right', '0112_Left', '0112_Right',
        '0157_Left', '0157_Right', '0202_Left', '0202_Right', '0247_Left', '0247_Right',
        '0292_Left', '0292_Right', '0337_Left', '0337_Right'])

    # Training window for cross-decoding: from 18 samples before to 18 samples
    # after the within-still decoding peak read from Peak_Times.csv.
    # Use 'peak1_sample' for early peaks and 'peak2_sample' for later peaks.
    # NOTE(review): the peak is read from the 'all' row, i.e. the same value for
    # every subject; use peak_df.loc[subject, ...] if a per-subject peak is wanted.
    peak_sample = int(peak_df.loc['all', 'peak2_sample'])
    start_time = epochs_selected.times[peak_sample - 18]
    end_time = epochs_selected.times[peak_sample + 18]


    # Number of sequence positions scored per movie (overrides the helpers' default of 7)
    n_positions = 9
    times = epochs_selectedMovie.times
    subject_acc_matrix = np.zeros((n_positions, len(times)))  # running sum over movies

    movie_conditions = epochs_selectedMovie.metadata['condition'].unique()
    for cond in movie_conditions:
        # Same training data for every movie; the test set is the current movie's trials
        x_train, x_test, y_train, y_test, times = train_test_split_cross(
            epochs_selected, epochs_selectedMovie, cond, start_time, end_time)

        acc_matrix, times = time_resolved_decoding_per_position(
            x_train, y_train, x_test, cond, times, input_numbers, n_positions=n_positions)

        subject_acc_matrix += acc_matrix

    subject_acc_matrix /= len(movie_conditions)  # mean across movies
    np.save(f"{output_dir}/{subject}_SequencedAccuracyLateall.npy", subject_acc_matrix)
    all_subjects_results.append(subject_acc_matrix)  # each entry: (n_positions, n_times)
# Stack the per-subject matrices and average them into the group result
all_subjects_results = np.array(all_subjects_results)  # shape: (n_subjects, n_positions, n_times)
group_avg = np.mean(all_subjects_results, axis=0)  # shape: (n_positions, n_times)
np.save(f"{output_dir}/GroupAverage_SequencedAccuracyLateall.npy", group_avg)

In [None]:
# Plotting code: draws the group-average curve for each sequence position.
# Skip this cell and use the plotting script to load the saved .npy files
# if you don't want to plot here.
# Reads n_positions, times and group_avg left behind by the processing cell.
plt.figure(figsize=(15, 6))
colors = plt.cm.viridis(np.linspace(0, 1, n_positions))

for pos in range(n_positions):
    # Light Gaussian smoothing along time, for display only
    smoothed = gaussian_filter1d(group_avg[pos], sigma=2)
    # \u00b0 is the degree/ordinal sign; written as an escape so the label
    # survives file re-encoding (it previously rendered as mojibake).
    plt.plot(times, smoothed, label=f'{pos+1}\u00b0 in sequence', color=colors[pos])
    # Dashed marker at pos * 0.375 s, drawn only when it falls inside the time axis.
    # NOTE(review): 0.375 is presumably the per-orientation duration in seconds; confirm.
    onset_time = pos * 0.375
    if times[0] <= onset_time <= times[-1]:
        plt.axvline(onset_time, linestyle='--', color=colors[pos], alpha=0.8, linewidth=1.2)

# Reference lines: 1/16 as chance level, and time zero
plt.axhline(1/16, linestyle='--', color='gray', label='Chance')
plt.axvline(0, linestyle='--', color='black', label='Movie Onset')

plt.xlabel('Time (s)')
plt.ylabel('Proportion Correct')
plt.title('Cross-Decoding Accuracy by Sequence Position')

# Move legend outside the axes
plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0))


plt.tight_layout(rect=[0, 0, 0.85, 1])  # leave space on the right
# Display the figure (the original called plt.plot(), which draws nothing and
# does not show the figure outside an inline notebook backend).
plt.show()

