# **Generate Reversal ABC Multi-Timestep Task Sequence**

This notebook generates reversal ABC task sequences with multi-timestep trial structure:

**Trial Structure:**
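- **Stimulus window**: `stim_window` timesteps showing the stimulus
- **Reward availability window**: `reward_window` timesteps in the `reward_unknown` state, during which reward can be obtained
- **Outcome**: there is no separate outcome timestep — the environment switches to the `rewarded`/`unrewarded` outcome state during the reward window, based on the agent's action
- **ITI**: a random number of timesteps (uniform integer between `min_iti` and `max_iti`, inclusive) showing no stimulus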

**Task Rules:**
- Stimulus A: Always rewarded in pre-reversal, never in post-reversal
- Stimulus B: Never rewarded in pre-reversal, always rewarded in post-reversal
- Stimulus C: Random reward (50% probability, doesn't reverse)

Sequences are saved as pickle files for use with Gymnasium environments.


In [None]:
import os
import pickle
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder

# Seed NumPy's global RNG up front so a fresh-kernel "Run All" is reproducible
np.random.seed(seed=42)


---
Define Stimuli and Rewards


In [None]:
# Define stimulus identities and reward values
stimuli = {"A": 0, "B": 1, "C": 2}
rewards = {"no_reward": 0, "reward": 1}

def generate_random_reward(reward1, reward2, prob):
    """Return ``reward1`` with probability ``prob``, otherwise ``reward2``.

    Draws exactly one uniform sample from NumPy's global RNG (``np.random``),
    so results are reproducible under ``np.random.seed``.

    Args:
        reward1: Value returned when the uniform draw falls below ``prob``.
        reward2: Value returned otherwise.
        prob: Probability of returning ``reward1``. Values >= 1 always return
            ``reward1``; values <= 0 always return ``reward2``.

    Returns:
        Either ``reward1`` or ``reward2``.
    """
    # np.random.rand() returns a scalar in [0, 1) (the same draw that
    # np.random.rand(1) would produce), avoiding reliance on the truth value
    # of a 1-element array.
    return reward1 if np.random.rand() < prob else reward2


---
Set Task Parameters


In [None]:
# Task parameters
num_pre_reversal_trials = 4000   # trials before the reversal
num_post_reversal_trials = 4000  # trials after the reversal

# Trial structure parameters (all counts are in timesteps)
stim_window = 5      # stimulus presentation
reward_window = 3    # reward availability
min_iti = 10         # shortest inter-trial interval
max_iti = 20         # longest inter-trial interval

# Re-seed the global RNG so everything from this cell onward is reproducible
seed = 42
np.random.seed(seed)

summary_lines = [
    "Task parameters:",
    f"  Pre-reversal trials: {num_pre_reversal_trials}",
    f"  Post-reversal trials: {num_post_reversal_trials}",
    f"  Stimulus window: {stim_window} timesteps",
    f"  Reward window: {reward_window} timesteps",
    f"  ITI range: {min_iti}-{max_iti} timesteps",
]
print("\n".join(summary_lines))


---
Generate Trial-Level Data


In [None]:
# Generate trial-level data (stimulus and reward availability)
#
# Per-phase probability that each stimulus has reward available. Exactly 0.0 and
# 1.0 are handled deterministically (no random draw); intermediate values are
# sampled via generate_random_reward.
# Pre-reversal: A always rewarded, B never, C random 50%
PRE_REVERSAL_REWARD_PROBS = {stimuli["A"]: 1.0, stimuli["B"]: 0.0, stimuli["C"]: 0.5}
# Post-reversal: A never rewarded, B always rewarded, C still random 50%
POST_REVERSAL_REWARD_PROBS = {stimuli["A"]: 0.0, stimuli["B"]: 1.0, stimuli["C"]: 0.5}


def _sample_phase_trials(num_trials, reward_probs, reversal_flag):
    """Sample ``num_trials`` trials for one task phase.

    Each trial draws a stimulus uniformly from A/B/C, then decides reward
    availability from ``reward_probs[stimulus]``.

    Args:
        num_trials: Number of trials to sample.
        reward_probs: Maps stimulus id -> probability that reward is available.
        reversal_flag: Value stored in the reversal mask for every trial.

    Returns:
        Tuple ``(stims, trial_rewards, reversal_mask)`` of equal-length lists.
    """
    stim_ids = [stimuli["A"], stimuli["B"], stimuli["C"]]
    stims, trial_rewards = [], []
    for _ in range(num_trials):
        stim = np.random.choice(stim_ids, p=[1/3, 1/3, 1/3])
        p_reward = reward_probs[stim]
        # Deterministic contingencies must not consume RNG draws, so the
        # generated sequence stays identical to the per-stimulus if/elif version.
        if p_reward >= 1.0:
            reward = rewards["reward"]
        elif p_reward <= 0.0:
            reward = rewards["no_reward"]
        else:
            reward = generate_random_reward(rewards["reward"], rewards["no_reward"], prob=p_reward)
        stims.append(stim)
        trial_rewards.append(reward)
    return stims, trial_rewards, [reversal_flag] * num_trials


# Phases are sampled in order (pre, then post) so the RNG stream is consumed
# in the same order as before.
pre_stims, pre_rewards, pre_mask = _sample_phase_trials(
    num_pre_reversal_trials, PRE_REVERSAL_REWARD_PROBS, 0
)
post_stims, post_rewards, post_mask = _sample_phase_trials(
    num_post_reversal_trials, POST_REVERSAL_REWARD_PROBS, 1
)

trial_data = {
    "stimuli": pre_stims + post_stims,
    "rewards": pre_rewards + post_rewards,
    "masks": {"reversal": pre_mask + post_mask}
}
assert len(trial_data["stimuli"]) == len(trial_data["rewards"]) == len(trial_data["masks"]["reversal"])

print(f"Generated trial-level data:")
print(f"  Total trials: {len(trial_data['stimuli'])}")
print(f"  Pre-reversal: {num_pre_reversal_trials} trials")
print(f"  Post-reversal: {num_post_reversal_trials} trials")


---
Expand to Timestep-Level Sequence


In [None]:
# Expand each trial into its per-timestep states and rewards.
# Timestep states: A=0, B=1, C=2, reward_unknown=3, unrewarded=4, rewarded=5, ITI=6
state_map = {"A": 0, "B": 1, "C": 2, "reward_unknown": 3, "unrewarded": 4, "rewarded": 5, "ITI": 6}
state_sequence = []
reward_sequence = []
trial_structure = []  # Per-trial record of the timesteps in each window, plus phase

timestep = 0
trial_iter = zip(trial_data["stimuli"], trial_data["rewards"], trial_data["masks"]["reversal"])
for trial_idx, (stim, reward_avail, reversal_phase) in enumerate(trial_iter):
    trial_start_timestep = timestep
    is_rewarded = reward_avail == rewards["reward"]

    # Stimulus window: the stimulus id is used directly as its state index;
    # no reward is delivered here.
    stim_timesteps = list(range(timestep, timestep + stim_window))
    state_sequence.extend([stim] * len(stim_timesteps))
    reward_sequence.extend([0.0] * len(stim_timesteps))
    timestep += len(stim_timesteps)

    # Reward availability window: "reward_unknown" states carrying the reward
    # value; the environment resolves rewarded/unrewarded from the action.
    # (There is no separate outcome timestep.)
    reward_timesteps = list(range(timestep, timestep + reward_window))
    state_sequence.extend([state_map["reward_unknown"]] * len(reward_timesteps))
    reward_sequence.extend([1.0 if is_rewarded else 0.0] * len(reward_timesteps))
    timestep += len(reward_timesteps)

    # ITI: random length drawn uniformly from [min_iti, max_iti]
    iti_length = np.random.randint(min_iti, max_iti + 1)
    iti_timesteps = list(range(timestep, timestep + iti_length))
    state_sequence.extend([state_map["ITI"]] * len(iti_timesteps))
    reward_sequence.extend([0.0] * len(iti_timesteps))
    timestep += len(iti_timesteps)

    trial_structure.append({
        "trial_idx": trial_idx,
        "stimulus": stim,
        "reward_available": is_rewarded,
        "reversal_phase": reversal_phase,
        "trial_start": trial_start_timestep,
        "stim_window": stim_timesteps,
        "reward_window": reward_timesteps,  # These are the "reward_unknown" outcome states
        "iti_window": iti_timesteps,
        "trial_end": timestep - 1
    })

print(f"Expanded to timestep-level sequence:")
print(f"  Total timesteps: {len(state_sequence)}")
print(f"  Average trial length: {len(state_sequence) / len(trial_structure):.1f} timesteps")


---
Convert to One-Hot Encoding (7D)


In [None]:
# Convert to one-hot encoding.
# Column order follows state_map: [A, B, C, reward_unknown, unrewarded, rewarded, ITI]
num_states = len(state_map)  # 7 states -> 7D encoding
state_indices = np.asarray(state_sequence, dtype=np.int64)

# Fail loudly on unknown states instead of silently emitting an all-zero row.
# This also catches negative indices, which fancy indexing would otherwise wrap.
if state_indices.size and (state_indices.min() < 0 or state_indices.max() >= num_states):
    raise ValueError(
        f"state_sequence contains indices outside [0, {num_states}): "
        f"min={state_indices.min()}, max={state_indices.max()}"
    )

# Row i of the identity matrix is the one-hot vector for state i.
state_sequence_ohe_7d = np.eye(num_states, dtype=np.float32)[state_indices]

print(f"Converted to one-hot encoding:")
print(f"  State sequence shape: {state_sequence_ohe_7d.shape}")
print(f"  Reward sequence shape: {len(reward_sequence)}")
print(f"  Number of trials: {len(trial_structure)}")


---
Calculate Phase Boundaries


In [None]:
# Calculate phase boundaries.
# The reversal point is the first timestep after the last pre-reversal trial.
# Guard the zero-pre-reversal case: trial_structure[-1] would pick the *last*
# trial and wrongly place the reversal at the very end of the sequence.
if num_pre_reversal_trials > 0:
    pre_reversal_end_timestep = trial_structure[num_pre_reversal_trials - 1]["trial_end"] + 1
else:
    pre_reversal_end_timestep = 0

# Take the total length from the sequence itself, not from the loop counter
# `timestep` left over in the expansion cell (hidden cross-cell state).
total_timesteps = len(state_sequence)

phase_boundaries = {
    "reversal_points": [pre_reversal_end_timestep],
    "pre_reversal": {
        "start": 0,
        "end": pre_reversal_end_timestep
    },
    "post_reversal": {
        "start": pre_reversal_end_timestep,
        "end": total_timesteps
    }
}

print(f"Phase boundaries:")
print(f"  Pre-reversal: timesteps {phase_boundaries['pre_reversal']['start']} to {phase_boundaries['pre_reversal']['end']}")
print(f"  Post-reversal: timesteps {phase_boundaries['post_reversal']['start']} to {phase_boundaries['post_reversal']['end']}")
print(f"  Reversal point: {phase_boundaries['reversal_points']}")


---
Prepare Data Dictionary


In [None]:
# Prepare the data dictionary that is pickled for use with the Gymnasium environment.
# Sanity check: exactly one reward entry per timestep state.
assert len(state_sequence_ohe_7d) == len(reward_sequence), "state/reward sequences differ in length"

data = {
    "state_sequence_ohe": state_sequence_ohe_7d,
    "reward_sequence": np.array(reward_sequence, dtype=np.float32),
    "sequence": {
        "stimuli": trial_data["stimuli"],
        "rewards": trial_data["rewards"],
        "masks": trial_data["masks"]
    },
    "phase_boundaries": phase_boundaries,
    "trial_structure": trial_structure,
    # Reuse the mapping that actually built state_sequence (instead of a
    # hand-copied literal) so saved metadata cannot drift from the encoding.
    "state_map": dict(state_map),
    "trial_params": {
        "stim_window": stim_window,
        "reward_window": reward_window,
        "min_iti": min_iti,
        "max_iti": max_iti
    }
}

print("Data dictionary prepared:")
print(f"  Keys: {list(data.keys())}")
print(f"  State sequence shape: {data['state_sequence_ohe'].shape}")
print(f"  Reward sequence length: {len(data['reward_sequence'])}")
print(f"  Trial structure length: {len(data['trial_structure'])}")


---
Save Data


In [None]:
# Save to pickle file.
# The output directory can be overridden with the COGNN_TASK_DATA_DIR environment
# variable so the notebook runs on other machines; the default location is unchanged.
default_output_dir = "/Users/pmccarthy/Documents/cogNN/task_data"
output_dir = Path(os.environ.get("COGNN_TASK_DATA_DIR", default_output_dir))
output_path = output_dir / "reversal_abc_multitimestep.pkl"
output_path.parent.mkdir(parents=True, exist_ok=True)

with open(output_path, "wb") as f:
    pickle.dump(data, f)

print(f"\nGenerated reversal_abc_multitimestep task:")
print(f"  Total timesteps: {len(data['state_sequence_ohe'])}")
print(f"  Pre-reversal: {phase_boundaries['pre_reversal']['end']} timesteps")
print(f"  Post-reversal: {phase_boundaries['post_reversal']['end'] - phase_boundaries['post_reversal']['start']} timesteps")
print(f"  Saved to: {output_path}")
