# **Generate Multi-Reversal ABC Multi-Timestep Task Sequence**

This notebook generates multi-reversal ABC task sequences with multi-timestep trial structure:

**Trial Structure:**
- **Stimulus window**: Multiple timesteps showing the stimulus
- **Reward availability window**: Multiple timesteps where reward can be obtained
- **Outcome state**: Determined by action during reward window
- **ITI** (inter-trial interval): A random number of timesteps (between min_iti and max_iti, inclusive) showing no stimulus

**Task Rules:**
- Stimulus A: Rewarded in even phases (0, 2, 4, ...), never in odd phases (1, 3, 5, ...)
- Stimulus B: Never rewarded in even phases, rewarded in odd phases
- Stimulus C: Random reward (50% probability, doesn't reverse)

Sequences are saved as pickle files for use with Gymnasium environments.

In [None]:
import pickle
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# Set random seed for reproducibility: this seeds NumPy's global random state,
# which every np.random.* call in this notebook draws from.
np.random.seed(42)

---
Define Stimuli and Rewards

In [None]:
# Integer codes for the three stimulus identities and the two reward outcomes
stimuli = dict(A=0, B=1, C=2)
rewards = dict(no_reward=0, reward=1)

def generate_random_reward(reward1, reward2, prob):
    """Return ``reward1`` with probability ``prob``, otherwise ``reward2``.

    A single uniform sample on [0, 1) is drawn from NumPy's global random
    state; ``reward1`` is returned when that sample is strictly below ``prob``.
    """
    draw = np.random.rand(1)
    return reward1 if draw < prob else reward2

---
Set Task Parameters

In [None]:
# Multi-reversal task parameters
# One entry per reversal phase, giving the number of trials in that phase.
# The rewarded stimulus alternates across phases: A -> B -> A -> B -> ...
# C always has random 50% reward.
phase_trials = [2000] * 4  # 4 phases: A rewarded, B rewarded, A rewarded, B rewarded

# Trial structure parameters (all measured in timesteps)
stim_window, reward_window = 5, 3  # stimulus presentation / reward availability
min_iti, max_iti = 10, 20          # minimum / maximum ITI length

# Re-seed NumPy's global random state so the cells below start from a known state
seed = 42
np.random.seed(seed)

# Echo the configuration
print("Task parameters:")
print(f"  Number of phases: {len(phase_trials)}")
for i, num_trials in enumerate(phase_trials):
    phase_type = "B rewarded" if i % 2 else "A rewarded"
    print(f"  Phase {i} ({phase_type}): {num_trials} trials")
print(f"  Stimulus window: {stim_window} timesteps")
print(f"  Reward window: {reward_window} timesteps")
print(f"  ITI range: {min_iti}-{max_iti} timesteps")

---
Generate Trial-Level Data

In [None]:
# Trial-level record: one stimulus code, one reward-availability value and one
# phase index per trial (the three lists stay aligned by position).
trial_data = {"stimuli": [], "rewards": [], "masks": {"reversal": []}}

# Reward contingencies alternate from phase to phase:
#   even phases (0, 2, 4, ...): A rewarded, B not
#   odd phases  (1, 3, 5, ...): B rewarded, A not
# C is rewarded at random with probability 0.5 in every phase.
stimulus_codes = [stimuli["A"], stimuli["B"], stimuli["C"]]
uniform_probs = [1/3, 1/3, 1/3]

for phase_idx, num_trials in enumerate(phase_trials):
    a_rewarded = (phase_idx % 2 == 0)

    # Outcomes for A and B are fixed within a phase, so resolve them once here
    reward_for_a = rewards["reward"] if a_rewarded else rewards["no_reward"]
    reward_for_b = rewards["no_reward"] if a_rewarded else rewards["reward"]

    for _ in range(num_trials):
        # Each trial draws its stimulus uniformly from A, B and C
        stim = np.random.choice(stimulus_codes, p=uniform_probs)

        if stim == stimuli["C"]:
            # Only C consumes an extra random draw for its reward
            reward = generate_random_reward(rewards["reward"], rewards["no_reward"], prob=0.5)
        elif stim == stimuli["A"]:
            reward = reward_for_a
        else:  # B
            reward = reward_for_b

        trial_data["stimuli"].append(stim)
        trial_data["rewards"].append(reward)
        trial_data["masks"]["reversal"].append(phase_idx)

# Summary (the per-phase counts printed are the configured values)
print("Generated trial-level data:")
print(f"  Total trials: {len(trial_data['stimuli'])}")
for i, num_trials in enumerate(phase_trials):
    phase_type = "B rewarded" if i % 2 else "A rewarded"
    print(f"  Phase {i} ({phase_type}): {num_trials} trials")

---
Expand to Timestep-Level Sequence

In [None]:
# Expand every trial into consecutive timesteps:
#   stimulus window -> reward availability window -> ITI
# States: A=0, B=1, C=2, reward_unknown=3, unrewarded=4, rewarded=5, ITI=6
state_map = {"A": 0, "B": 1, "C": 2, "reward_unknown": 3, "unrewarded": 4, "rewarded": 5, "ITI": 6}
state_sequence = []
reward_sequence = []
trial_structure = []  # Per-trial bookkeeping: timestep indices, stimulus and phase

timestep = 0  # Index of the next timestep to be written

trial_iter = zip(
    trial_data["stimuli"],
    trial_data["rewards"],
    trial_data["masks"]["reversal"],
)
for trial_idx, (stim, reward_avail, reversal_phase) in enumerate(trial_iter):
    trial_start_timestep = timestep
    is_rewarded = reward_avail == rewards["reward"]

    # Stimulus window: the stimulus code itself is the state; reward is 0.0
    stim_timesteps = list(range(timestep, timestep + stim_window))
    state_sequence.extend([stim] * stim_window)
    reward_sequence.extend([0.0] * stim_window)
    timestep += len(stim_timesteps)

    # Reward availability window: every step is the "reward_unknown" state and
    # carries 1.0 when reward is available on this trial, else 0.0.
    # NOTE(review): per the original notes, the environment is expected to move
    # to rewarded/unrewarded based on the action — that logic is not in this file.
    reward_timesteps = list(range(timestep, timestep + reward_window))
    state_sequence.extend([state_map["reward_unknown"]] * reward_window)
    reward_sequence.extend([1.0 if is_rewarded else 0.0] * reward_window)
    timestep += len(reward_timesteps)

    # ITI: length drawn uniformly from min_iti..max_iti (upper bound inclusive)
    iti_length = np.random.randint(min_iti, max_iti + 1)
    iti_timesteps = list(range(timestep, timestep + iti_length))
    state_sequence.extend([state_map["ITI"]] * iti_length)
    reward_sequence.extend([0.0] * iti_length)
    timestep += len(iti_timesteps)

    trial_structure.append({
        "trial_idx": trial_idx,
        "stimulus": stim,
        "reward_available": is_rewarded,
        "reversal_phase": reversal_phase,
        "trial_start": trial_start_timestep,
        "stim_window": stim_timesteps,
        "reward_window": reward_timesteps,  # These are the "reward_unknown" outcome states
        "iti_window": iti_timesteps,
        "trial_end": timestep - 1,  # Last timestep of the trial (inclusive)
    })

print("Expanded to timestep-level sequence:")
print(f"  Total timesteps: {len(state_sequence)}")
print(f"  Average trial length: {len(state_sequence) / len(trial_structure):.1f} timesteps")

---
Convert to One-Hot Encoding (7D)

In [None]:
# Convert the integer state sequence to one-hot vectors (one row per timestep).
# States: A=0, B=1, C=2, reward_unknown=3, unrewarded=4, rewarded=5, ITI=6
# Column layout: [A, B, C, reward_unknown, unrewarded, rewarded, ITI]
num_states = len(state_map)  # 7 for the mapping above; derived so the two cannot drift
state_indices = np.asarray(state_sequence, dtype=np.int64)
state_sequence_ohe_7d = np.zeros((len(state_indices), num_states), dtype=np.float32)

# Vectorized assignment instead of a per-timestep Python loop. Rows whose state
# index falls outside [0, num_states) are left all-zero, as before.
valid_rows = np.flatnonzero((state_indices >= 0) & (state_indices < num_states))
state_sequence_ohe_7d[valid_rows, state_indices[valid_rows]] = 1.0

print("Converted to one-hot encoding:")
print(f"  State sequence shape: {state_sequence_ohe_7d.shape}")
print(f"  Reward sequence length: {len(reward_sequence)}")
print(f"  Number of trials: {len(trial_structure)}")

---
Calculate Phase Boundaries

In [None]:
# Calculate phase boundaries in timestep units.
# Each phase gets a half-open span [start, end); "reversal_points" lists the
# start timestep of every phase after the first.
phase_boundaries = {"reversal_points": []}
cumulative_trials = 0
phase_end_timestep = 0  # Exclusive end of the previous phase (0 before the first phase)

for phase_idx, num_trials in enumerate(phase_trials):
    # Trials are laid out back to back, so a phase starts exactly where the
    # previous one ended. Carrying the running end forward (rather than indexing
    # trial_structure for the start) also keeps phases with zero trials valid.
    phase_start_timestep = phase_end_timestep
    if phase_idx > 0:
        phase_boundaries["reversal_points"].append(phase_start_timestep)

    if num_trials > 0:
        # End of this phase: one past the last timestep of its final trial
        phase_end_trial_idx = cumulative_trials + num_trials - 1
        phase_end_timestep = trial_structure[phase_end_trial_idx]["trial_end"] + 1
    # An empty phase leaves phase_end_timestep untouched -> zero-length span.

    phase_boundaries[f"phase_{phase_idx}"] = {
        "start": phase_start_timestep,
        "end": phase_end_timestep,
    }

    cumulative_trials += num_trials

print("Phase boundaries:")
for i in range(len(phase_trials)):
    phase_type = "A rewarded" if i % 2 == 0 else "B rewarded"
    bounds = phase_boundaries[f"phase_{i}"]
    print(f"  Phase {i} ({phase_type}): timesteps {bounds['start']} to {bounds['end']}")
print(f"  Reversal points: {phase_boundaries['reversal_points']}")

---
Prepare Data Dictionary

In [None]:
# Prepare the data dictionary that gets saved to disk
data = {
    # One-hot states, one row per timestep
    "state_sequence_ohe": state_sequence_ohe_7d,
    # Per-timestep reward values as a float32 array
    "reward_sequence": np.array(reward_sequence, dtype=np.float32),
    # Trial-level data the timestep sequence was expanded from
    "sequence": {
        "stimuli": trial_data["stimuli"],
        "rewards": trial_data["rewards"],
        "masks": trial_data["masks"]
    },
    "phase_boundaries": phase_boundaries,
    "trial_structure": trial_structure,
    # Copy of the mapping actually used to build the sequence, instead of a
    # second hand-written literal that could drift out of sync with it
    "state_map": dict(state_map),
    "trial_params": {
        "stim_window": stim_window,
        "reward_window": reward_window,
        "min_iti": min_iti,
        "max_iti": max_iti
    }
}

# Phases are stored under "phase_<n>" keys next to "reversal_points"
num_phase_entries = sum(1 for key in data["phase_boundaries"] if key.startswith("phase_"))

print("Data dictionary prepared:")
print(f"  Keys: {list(data.keys())}")
print(f"  State sequence shape: {data['state_sequence_ohe'].shape}")
print(f"  Reward sequence length: {len(data['reward_sequence'])}")
print(f"  Trial structure length: {len(data['trial_structure'])}")
print(f"  Number of phases: {num_phase_entries}")

---
Save Data

In [None]:
# Serialize the data dictionary to a pickle file, creating the folder if needed.
# NOTE(review): the destination is an absolute, machine-specific path — adjust
# it before running this notebook on another machine.
output_path = Path("/Users/pmccarthy/Documents/cogNN/task_data/reversal_abc_multitimestep_multi.pkl")
output_path.parent.mkdir(parents=True, exist_ok=True)

with output_path.open("wb") as f:
    pickle.dump(data, f)

# Final summary of what was written
boundaries = data["phase_boundaries"]
num_phases = sum(1 for key in boundaries if key.startswith("phase_"))

print("\nGenerated reversal_abc_multitimestep_multi task:")
print(f"  Total timesteps: {len(data['state_sequence_ohe'])}")
print(f"  Number of phases: {num_phases}")
print(f"  Reversal points: {boundaries['reversal_points']}")
print(f"  Saved to: {output_path}")