# SARSA Behavioural Dataset Demo

This notebook demonstrates how to fit the SARSA agent implemented in `sarsa.sarsa`
against the behavioural dataset bundled with the project.

## Project bootstrap
Resolve the repository root so relative imports such as `sarsa` keep working,
and record the path to the bundled behavioural dataset (`examples/M1.csv`).


In [13]:
from pathlib import Path
import sys

PROJECT_ROOT = Path.cwd()
while PROJECT_ROOT != PROJECT_ROOT.parent and not (PROJECT_ROOT / "pyproject.toml").exists():
    PROJECT_ROOT = PROJECT_ROOT.parent

for candidate in (PROJECT_ROOT / "src", PROJECT_ROOT / "examples"):
    path_str = str(candidate)
    if path_str not in sys.path:
        sys.path.insert(0, path_str)

DATA_PATH = PROJECT_ROOT / "examples" / "M1.csv"


## SARSA configuration
Define the action/state configuration, reward constants, and any custom parameter bounds
so we can initialise the solver consistently across runs.


In [14]:
import numpy as np
import pandas as pd

from sarsa import sarsa
from experiment import (
    Location,
    StateAxis,
    downsample_behavior_data,
    process_data,
    row_to_state,
    rt_lc_unp_state_spec,
)

rng = np.random.default_rng(0)


In [15]:
MIN_PENALTY = 1.0  # minimum shock penalty
REWARD_VALUE = 1.0  # value for liquid reward
CUSTOM_PARAM_BOUNDS = [
    (MIN_PENALTY, None),  # shock
    (0.0, None),  # avoidance
]

STATE_SPEC = rt_lc_unp_state_spec
ACTION_SIZE = 3  # size of action set


## Fit and inspect the policy
Run `run_session` to build quintuples, optimise the SARSA parameters, and
capture diagnostics (`loss`, `params`) for the fitted policy.


In [16]:
def calc_reward(params, state):
    """Calculate the net reward of a state."""
    reward_value = REWARD_VALUE
    shock_value = params[3]
    escape_value = params[4]
    val = 0.0

    if state[StateAxis.Loc] == Location.R and state[StateAxis.Light] > 0:
        val += reward_value

    if state[StateAxis.Tone] == 3:
        if state[StateAxis.Loc] == Location.P:
            # Successful avoidance!!!
            val += escape_value
        else:
            # Shock!!!
            val -= shock_value

    return val


def init_params(rng, bounds):
    """Initialise the SARSA parameter vector within provided bounds."""
    bmin = np.array([b[0] for b in bounds])
    p0 = bmin + 0.5 * rng.random(size=len(sarsa.ParamIndex) + 2)
    return p0


def make_quintuples(behavior_data):
    """Construct SARSA training quintuples from processed behavioural data."""
    quintuples = []
    horizon = len(behavior_data)
    for t in range(horizon - 2):
        t1 = behavior_data.iloc[t]
        t2 = behavior_data.iloc[t + 1]
        t3 = behavior_data.iloc[t + 2]
        s1 = row_to_state(t1)
        s2 = row_to_state(t2)
        s3 = row_to_state(t3)
        a1 = s2[StateAxis.Loc]
        a2 = s3[StateAxis.Loc]
        r2 = np.nan
        quintuples.append(sarsa.Quintuple(s1=s1, a1=a1, r2=r2, s2=s2, a2=a2))
    return quintuples


def run_session(path: Path, rng: np.random.Generator):
    """Fit SARSA to the behavioural dataset found at ``path``."""
    behavior_data = pd.read_csv(path, encoding="unicode_escape", header=0)
    behavior_data = behavior_data.rename(columns={behavior_data.columns[0]: "Time (s)"})
    behavior_data.columns = map(str.upper, behavior_data.columns)
    behavior_data = downsample_behavior_data(behavior_data, "1s")
    behavior_data = process_data(behavior_data, "LC")

    quintuples = make_quintuples(behavior_data)
    q0 = np.zeros((*STATE_SPEC, ACTION_SIZE))
    param_bounds = sarsa.PARAM_BOUNDS + CUSTOM_PARAM_BOUNDS
    p0 = init_params(rng, param_bounds)

    params, loss, q_trajectory, action_prob = sarsa.fit(
        quintuples,
        q0=q0,
        p0=p0,
        static_params=None,
        reward_func=calc_reward,
        custom_param_bounds=CUSTOM_PARAM_BOUNDS,
    )
    return params, loss, q_trajectory, action_prob, quintuples


In [17]:
params, loss, q_trajectory, action_prob, quintuples = run_session(DATA_PATH, rng)
loss


np.float64(1.0986122886681093)

In [18]:
quintuples

[Quintuple(s1=array([1, 0, 0]), a1=np.int64(0), r2=0.0, s2=array([0, 0, 0]), a2=np.int64(1)),
 Quintuple(s1=array([0, 0, 0]), a1=np.int64(1), r2=0.0, s2=array([1, 0, 0]), a2=np.int64(1)),
 Quintuple(s1=array([1, 0, 0]), a1=np.int64(1), r2=0.0, s2=array([1, 0, 0]), a2=np.int64(1)),
 Quintuple(s1=array([1, 0, 0]), a1=np.int64(1), r2=0.0, s2=array([1, 0, 0]), a2=np.int64(1)),
 Quintuple(s1=array([1, 0, 0]), a1=np.int64(1), r2=0.0, s2=array([1, 0, 0]), a2=np.int64(2)),
 Quintuple(s1=array([1, 0, 0]), a1=np.int64(2), r2=0.0, s2=array([2, 0, 0]), a2=np.int64(0)),
 Quintuple(s1=array([2, 0, 0]), a1=np.int64(0), r2=0.0, s2=array([0, 0, 0]), a2=np.int64(2)),
 Quintuple(s1=array([0, 0, 0]), a1=np.int64(2), r2=0.0, s2=array([2, 0, 0]), a2=np.int64(1)),
 Quintuple(s1=array([2, 0, 0]), a1=np.int64(1), r2=0.0, s2=array([1, 0, 0]), a2=np.int64(1)),
 Quintuple(s1=array([1, 0, 0]), a1=np.int64(1), r2=0.0, s2=array([1, 0, 0]), a2=np.int64(2)),
 Quintuple(s1=array([1, 0, 0]), a1=np.int64(2), r2=0.0, s2=a

In [19]:
params


array([0.31848085, 0.13489337, 0.02048677, 1.00826382, 0.40663512])

Save `params`, `loss`, `q_trajectory` and `action_prob` for analysis.