In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Parameters
# Module-level settings read directly (as globals) by the functions below.
n_schedules = 20  # number of schedules generated by the loop at the end of the script
trials_per_reversal = 15  # base length, in stim3 trials, of one constant-state segment
sigma_switch = 2  # segment length is jittered by a uniform integer in [-sigma_switch, sigma_switch]
n_reversals = 9  # multiplied by trials_per_reversal to get the nominal stim3 trial count
stim_percentages = [0.25, 0.25, 0.5]  # stim1, stim2, stim3
reinforcement_rates = [0.2, 0.8]  # low, high
stim_names = ['stim1', 'stim2', 'stim3']
alpha = 0.2  # learning rate used by rescorla_wagner
no_decisions = 4  # per pair
oddball_cutoff = 6  # unexpected stim3 outcomes with trial_since_reversal below this are tagged 'oddball-early'

# Seed NumPy's global RNG once so the whole run is reproducible.
np.random.seed(42)
# Directory that receives the schedule CSV and PNG files.
os.makedirs("schedules", exist_ok=True)

def jittered_interval(base, sigma=200):
    """Return one normally distributed interval, truncated to an int.

    A single value is drawn from NumPy's global RNG with mean ``base`` and
    standard deviation ``sigma``; ``int()`` then drops the fractional part
    (truncation towards zero, not rounding).
    """
    draw = np.random.normal(loc=base, scale=sigma, size=1)
    return int(draw.item())

def generate_stim_sequence():
    """Return a shuffled list of stimulus names with runs of four broken up.

    The nominal number of stim3 trials is ``n_reversals * trials_per_reversal``;
    the total trial count is that number divided by the stim3 share in
    ``stim_percentages``.  Each stimulus receives ``int(total * share)``
    trials, the list is shuffled with NumPy's global RNG, and then every
    position that would complete four identical stimuli in a row is
    overwritten with a randomly chosen different stimulus.

    NOTE(review): the overwrite replaces a trial rather than swapping it, so
    the final per-stimulus counts can differ from the nominal ones, and the
    replaced entries are NumPy string scalars rather than plain ``str``.
    """
    nominal_stim3 = n_reversals * trials_per_reversal
    n_total = int(nominal_stim3 / stim_percentages[2])

    sequence = []
    for name, share in zip(stim_names[:3], stim_percentages[:3]):
        sequence.extend([name] * int(n_total * share))

    np.random.shuffle(sequence)

    # Walk forward; a position is rewritten only when it equals all three
    # of its immediate predecessors.
    for pos in range(3, len(sequence)):
        current = sequence[pos]
        if all(sequence[pos - back] == current for back in (1, 2, 3)):
            others = [name for name in stim_names if name != current]
            sequence[pos] = np.random.choice(others)
    return sequence

def insert_decision_trials(stim_sequence):
    """Insert decision-trial markers into ``stim_sequence`` in place.

    For each of the five stimulus contrasts, ``no_decisions`` markers of the
    form ``DECISION_<a>_vs_<b>_<state>`` are inserted.  Target positions are
    evenly spaced between index 20 and ``len(stim_sequence) - 20`` (measured
    before any insertion), each shifted by a random integer in [-5, 5] and
    clamped to the current list bounds at insertion time.

    The list passed in is modified and also returned.

    NOTE(review): markers are handed out in contrast order to ascending
    positions, so all markers of one contrast land in the same part of the
    session - confirm this is intended.
    """
    contrasts = (
        ('stim1', 'stim2', ''),
        ('stim1', 'stim3', 'low'),
        ('stim1', 'stim3', 'high'),
        ('stim2', 'stim3', 'low'),
        ('stim2', 'stim3', 'high'),
    )
    n_markers = no_decisions * len(contrasts)

    anchors = np.linspace(20, len(stim_sequence) - 20, n_markers, dtype=int)
    offsets = np.random.randint(-5, 6, n_markers)
    positions = anchors + offsets

    labels = [
        f"DECISION_{first}_vs_{second}_{state}"
        for first, second, state in contrasts
        for _ in range(no_decisions)
    ]

    for position, label in zip(positions, labels):
        # Clamp against the list's current (growing) length.
        upper_bounded = min(len(stim_sequence), position)
        stim_sequence.insert(max(0, upper_bounded), label)
    return stim_sequence

def generate_stim3_outcomes(n_trials):
    """Generate binary outcomes and hidden states for ``n_trials`` stim3 trials.

    The hidden state starts as a random choice between 'low' and 'high' and
    flips after every segment.  A segment is ``trials_per_reversal`` trials
    long, plus a uniform integer jitter in [-sigma_switch, sigma_switch].
    Outcomes inside a segment are Bernoulli draws with the reinforcement
    rate of the current state; a segment is redrawn until the number of
    state-consistent outcomes (1s in 'high', 0s in 'low') lies in 11..13.

    Returns a tuple ``(outcomes, states)``, both truncated to ``n_trials``.

    NOTE(review): the 11..13 acceptance window is hard-coded; if the segment
    length can never reach 11 this loop does not terminate.
    """
    current_state = np.random.choice(['low', 'high'])
    all_outcomes, all_states = [], []

    while len(all_outcomes) < n_trials:
        jitter = np.random.randint(-sigma_switch, sigma_switch + 1)
        length = trials_per_reversal + jitter
        in_high = current_state == 'high'
        p_reward = reinforcement_rates[1] if in_high else reinforcement_rates[0]

        accepted = None
        while accepted is None:
            candidate = [int(np.random.rand() < p_reward) for _ in range(length)]
            rewarded = sum(candidate)
            consistent = rewarded if in_high else len(candidate) - rewarded
            if 11 <= consistent <= 13:
                accepted = candidate

        all_outcomes += accepted
        all_states += [current_state] * len(accepted)
        current_state = 'low' if in_high else 'high'

    return all_outcomes[:n_trials], all_states[:n_trials]

def generate_outcomes(stim_sequence):
    """Build the per-trial schedule table for one stimulus sequence.

    Entries of ``stim_sequence`` starting with ``'DECISION'`` become
    'decision' rows whose outcome, timing, state, reversal, expectedness and
    trial_since_reversal fields are empty strings.  Every other entry
    becomes a 'rating' row:

    * stim3 rows take their outcome and state from
      ``generate_stim3_outcomes``.  ``reversal`` is the number of state
      changes seen so far and ``trial_since_reversal`` counts stim3 trials
      within the current segment, starting at 1.  An outcome that matches
      the state (0 in 'low', 1 in 'high') is 'regular'; a mismatch is
      'oddball-early' while ``trial_since_reversal < oddball_cutoff`` and
      'oddball' afterwards.
    * stim1 / stim2 rows have a fixed state ('low' / 'high'), a Bernoulli
      outcome drawn with the matching reinforcement rate, and are tagged
      'regular' or 'oddball'; their reversal fields are empty strings.

    ISI and ITI are drawn with ``jittered_interval(2500)`` for every rating
    row.  The returned DataFrame also repeats the generation parameters as
    constant columns.
    """
    n_stim3 = sum(1 for stim in stim_sequence if stim == 'stim3')
    stim3_outcomes, stim3_states = generate_stim3_outcomes(n_stim3)

    outcomes = []
    ISIs, ITIs, states, reversals, trial_types = [], [], [], [], []
    expectedness, trial_since_reversal = [], []

    stim3_count = 0
    # Start at 0 so the very first stim3 trial is numbered 1, exactly like
    # the first trial after every later reversal (previously it was 2).
    stim3_reversal_counter = 0
    reversal_index = 0

    for stim in stim_sequence:
        if stim.startswith('DECISION'):
            outcomes.append('')
            ISIs.append('')
            ITIs.append('')
            states.append('')
            reversals.append('')
            trial_types.append('decision')
            expectedness.append('')
            trial_since_reversal.append('')
            continue

        trial_types.append('rating')
        if stim == 'stim3':
            outcome = stim3_outcomes[stim3_count]
            state = stim3_states[stim3_count]

            # A reversal happens when the state differs from the previous
            # stim3 trial's state; the first stim3 trial never counts.
            reversed_now = (
                stim3_count > 0 and state != stim3_states[stim3_count - 1]
            )
            if reversed_now:
                stim3_reversal_counter = 1
                reversal_index += 1
            else:
                stim3_reversal_counter += 1

            if (state == 'low' and outcome == 0) or (state == 'high' and outcome == 1):
                tag = 'regular'
            elif stim3_reversal_counter < oddball_cutoff:
                tag = 'oddball-early'
            else:
                tag = 'oddball'

            reversal = reversal_index
            current_tsr = stim3_reversal_counter
            stim3_count += 1
        else:
            state = 'low' if stim == 'stim1' else 'high'
            prob = reinforcement_rates[0 if stim == 'stim1' else 1]
            outcome = int(np.random.rand() < prob)
            expected = (state == 'low' and outcome == 0) or (state == 'high' and outcome == 1)
            tag = 'regular' if expected else 'oddball'
            reversal = ''
            current_tsr = ''

        outcomes.append(outcome)
        # ISI is drawn before ITI to keep the random stream order stable.
        ISIs.append(jittered_interval(2500))
        ITIs.append(jittered_interval(2500))
        states.append(state)
        reversals.append(reversal)
        expectedness.append(tag)
        trial_since_reversal.append(current_tsr)

    return pd.DataFrame({
        'stimulus': stim_sequence,
        'outcome': outcomes,
        'ISI': ISIs,
        'ITI': ITIs,
        'stim3_state': states,
        'reversal': reversals,
        'trial_type': trial_types,
        'expectedness': expectedness,
        'trial_since_reversal': trial_since_reversal,
        'sigma_switch': sigma_switch,
        'reinforcement_low': reinforcement_rates[0],
        'reinforcement_high': reinforcement_rates[1],
        'stim1_pct': stim_percentages[0],
        'stim2_pct': stim_percentages[1],
        'stim3_pct': stim_percentages[2],
        'alpha': alpha,
        'no_decisions_per_type': no_decisions,
        'oddball_cutoff': oddball_cutoff
    })

def rescorla_wagner(outcomes, stim_sequence, target_stim, *,
                    learning_rate=None, initial_value=0.5):
    """Track a Rescorla-Wagner value estimate for one stimulus.

    ``outcomes`` and ``stim_sequence`` are walked in lockstep.  On every
    trial whose stimulus equals ``target_stim`` the estimate is updated as
    ``V += learning_rate * (int(outcome) - V)`` and the new value is
    recorded; every other trial (other stimuli and decision markers alike)
    records ``nan``.

    Args:
        outcomes: Per-trial outcomes; only entries on target trials are
            converted with ``int()``.
        stim_sequence: Per-trial stimulus labels.
        target_stim: Label of the stimulus whose value is tracked.
        learning_rate: Update step size.  ``None`` (the default) uses the
            module-level ``alpha``.
        initial_value: Estimate before the first update.

    Returns:
        A list with one entry per zipped trial: the updated estimate on
        target trials and ``nan`` elsewhere.
    """
    rate = alpha if learning_rate is None else learning_rate
    V = initial_value
    values = []
    for outcome, stim in zip(outcomes, stim_sequence):
        if stim == target_stim:
            V += rate * (int(outcome) - V)
            values.append(V)
        else:
            values.append(np.nan)
    return values

def plot_schedule(df, sid):
    """Plot outcomes and Rescorla-Wagner values per stimulus and save a PNG.

    One subplot is drawn per entry of ``stim_names``.  The x axis is the
    trial index within that stimulus; outcomes are red dots and the
    Rescorla-Wagner estimate (see ``rescorla_wagner``) is a line.  On the
    stim3 subplot, the stim3 trial closest (by row index) to each decision
    trial is highlighted with a shaded band.

    Args:
        df: Schedule table as returned by ``generate_outcomes``.
        sid: Schedule identifier used in titles and in the output file name
            ``schedules/schedule_<sid>.png``.
    """
    fig, axs = plt.subplots(3, 1, figsize=(12, 9), sharex=False)

    # Built once: the same full-length lists are needed for every stimulus.
    all_outcomes = df['outcome'].tolist()
    all_stimuli = df['stimulus'].tolist()

    for ax, stim in zip(axs, stim_names):
        stim_df = df[df['stimulus'] == stim].reset_index()
        ax.set_title(f'{stim} - Schedule {sid}')
        ax.set_ylim(-0.1, 1.1)
        ax.plot(stim_df.index, stim_df['outcome'], 'ro', label='Outcome')

        rw_values = rescorla_wagner(all_outcomes, all_stimuli, stim)
        rw_filtered = [v for v, s in zip(rw_values, all_stimuli) if s == stim]
        ax.plot(stim_df.index, rw_filtered, label='RW Value')

        if stim == 'stim3':
            stim3_trial_indices = df[df['stimulus'] == 'stim3'].index.tolist()
            if stim3_trial_indices:
                # Row index -> position among stim3 trials, computed once
                # instead of re-filtering the frame for every decision.
                position_of = {
                    row: pos for pos, row in enumerate(stim3_trial_indices)
                }
                for idx in df[df['trial_type'] == 'decision'].index:
                    closest_stim3 = min(stim3_trial_indices, key=lambda x: abs(x - idx))
                    stim3_pos = position_of[closest_stim3]
                    ax.axvspan(stim3_pos - 0.5, stim3_pos + 0.5, color='crimson', alpha=0.2)

        ax.legend()
    plt.xlabel("Trial Index (within stimulus)")
    plt.tight_layout()
    plt.savefig(f'schedules/schedule_{sid}.png')
    plt.close()

# Generate n_schedules schedules: build the stimulus order, insert the
# decision markers, attach outcomes/timing, then write a CSV and a PNG per
# schedule into the "schedules" directory.
for sid in range(n_schedules):
    stim_seq = generate_stim_sequence()
    stim_seq = insert_decision_trials(stim_seq)
    df = generate_outcomes(stim_seq)
    df.to_csv(f'schedules/schedule_{sid}.csv', index=False)
    plot_schedule(df, sid)