In [None]:
from abc import ABC, abstractmethod

import numpy as np
from scipy.optimize import Bounds, LinearConstraint, minimize
from scipy.special import expit
from tqdm import tqdm



In [None]:
class DecisionModel(ABC):
    _NUM_ACTIONS = 2

    def __init__(
        self,
        params: np.ndarray,
        param_names: list[str],
        param_bounds: Bounds | None = None,
        param_constraints: LinearConstraint | None = None,
    ) -> None:
        if type(params) is not np.ndarray:
            raise ValueError("'params' must be a numpy array.")
        if len(params) != len(param_names):
            raise ValueError("Length of 'params' must match length of 'param_names'.")
        self._params = params
        self._param_names = param_names
        self._param_bounds = param_bounds
        self._param_constraints = param_constraints

    @property
    def params(self) -> np.ndarray:
        return self._params

    @params.setter
    def params(self, new_params: np.ndarray) -> None:
        self._params = new_params

    @property
    def param_bounds(self) -> Bounds | None:
        return self._param_bounds

    @property
    def param_constraints(self) -> LinearConstraint | None:
        return self._param_constraints

    @abstractmethod
    def action_probabilities(self, stimuli: float | np.ndarray) -> np.ndarray:
        raise NotImplementedError

    def sample(self, stimuli: float | np.ndarray) -> np.ndarray:
        probabilities = self.action_probabilities(stimuli)
        if probabilities.ndim == 1:
            probabilities = probabilities.reshape(1, -1)
        return np.array(
            [np.random.choice(self._NUM_ACTIONS, p=prob) for prob in probabilities]
        )

    def likelihood(
        self, stimuli: float | np.ndarray, actions: float | np.ndarray
    ) -> np.ndarray:
        probabilities = self.action_probabilities(stimuli)
        if type(actions) is not np.ndarray:
            return probabilities[0, actions]
        likelihoods = np.zeros_like(actions, dtype=float)
        for i in range(probabilities.shape[-1]):
            likelihoods[actions == i] = probabilities[actions == i, i]
        return likelihoods

    def __repr__(self) -> str:
        param_str = ", ".join(
            f"{name}={value}" for name, value in zip(self._param_names, self.params)
        )
        return f"{self.__class__.__name__}({param_str})"


class CategoricalDecisionModel(DecisionModel):
    """Stimulus-independent model: actions follow a fixed categorical distribution.

    The parameters are the action probabilities themselves, bounded to [0, 1]
    and constrained to sum to one.
    """

    def __init__(self, probabilities: np.ndarray | list | None = None) -> None:
        """Create the model.

        Args:
            probabilities: action probabilities (any array-like). If None, a
                random probability vector of length ``_NUM_ACTIONS`` is drawn.
        """
        if probabilities is None:
            probabilities = np.random.random(self._NUM_ACTIONS)
            probabilities /= probabilities.sum()
        else:
            # Copy into a float array: accepts lists/tuples/arrays alike and
            # decouples the model from later changes to the caller's object.
            probabilities = np.array(probabilities, dtype=float)
        num_probs = len(probabilities)
        param_names = [f"p{i}" for i in range(num_probs)]
        param_bounds = Bounds(0, 1, keep_feasible=True)
        # Sum-to-one constraint sized to the actual parameter vector.
        param_constraints = LinearConstraint(np.ones(num_probs), 1, 1)
        super().__init__(probabilities, param_names, param_bounds, param_constraints)

    def action_probabilities(self, stimuli: float | np.ndarray) -> np.ndarray:
        """Return the fixed probabilities, broadcast over ``stimuli``.

        A non-array stimulus returns the flat parameter vector; an array of
        stimuli returns shape ``stimuli.shape + (num_probs,)``.
        """
        if not isinstance(stimuli, np.ndarray):
            return self._params
        return np.tile(self._params, reps=stimuli.shape + (1,))


class LogisticDecisionModel(DecisionModel):
    """Two-action model: p(action 1 | s) = sigmoid(bias + w_stim * s)."""

    def __init__(
        self, bias: float | None = None, stim_weight: float | None = None
    ) -> None:
        """Create the model; each parameter left as None is drawn uniformly from [0, 1)."""
        if bias is None:
            bias = np.random.random()
        if stim_weight is None:
            stim_weight = np.random.random()
        params = np.array([bias, stim_weight])
        param_names = ["bias", "w_stim"]
        super().__init__(params, param_names)

    def action_probabilities(
        self, stimuli: float | np.ndarray | None = None
    ) -> np.ndarray:
        """Return [p(action 0), p(action 1)] along the last axis.

        A non-array stimulus is wrapped into a length-1 array, so the result
        has shape ``(1, 2)``; an array input gives ``stimuli.shape + (2,)``.
        NOTE(review): the ``None`` default is not actually usable (the
        arithmetic below fails on it); a stimulus must be supplied.
        """
        if not isinstance(stimuli, np.ndarray):
            stimuli = np.array([stimuli])
        bias, weight = self._params
        logits = bias + weight * stimuli
        # expit(-z) equals 1 - expit(z) mathematically, but evaluating it
        # directly keeps the small tail probability accurate instead of
        # rounding 1 - p to exactly 0 (which makes log-likelihoods -inf).
        return np.stack([expit(-logits), expit(logits)], axis=-1)



In [None]:
# Ground-truth parameters of the simulated two-state HMM.
# Seeding the legacy global RNG makes data generation and the random EM
# initialization below reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)

# Stored as an array (not a list) so it can be passed straight to the HMM
# routines, which do element-wise arithmetic on it.
initial_probs = np.array([0.5, 0.5])
transition_probs = np.array(
    [
        [0.95, 0.05],
        [0.02, 0.98],
    ]
)
decision_models = [
    # State 0: choice depends strongly on the stimulus.
    LogisticDecisionModel(bias=0.0, stim_weight=5.0),
    # State 1: for stimuli in [-1, 1) the logit stays in [4.5, 5.5), so
    # action 1 is chosen almost always.
    LogisticDecisionModel(bias=5, stim_weight=0.5),
]
time_steps = 250
num_sequences = 100


def generate_data(
    initial_probs, transition_probs, decision_models, time_steps, num_sequences
):
    """Simulate ``num_sequences`` independent sequences from the HMM.

    Each sequence comes from ``generate_sequence`` with the same parameters.

    Returns:
        (stimuli, data): arrays stacking the per-sequence stimuli and sampled
        actions, one row per sequence.
    """
    simulated = [
        generate_sequence(initial_probs, transition_probs, decision_models, time_steps)
        for _ in tqdm(range(num_sequences))
    ]
    all_stimuli = np.array([stim for stim, _ in simulated])
    all_actions = np.array([actions for _, actions in simulated])
    return all_stimuli, all_actions


def generate_sequence(initial_probs, transition_probs, decision_models, time_steps):
    stimuli = np.random.random(time_steps) * 2 - 1
    num_states = len(initial_probs)
    states = [np.random.choice(num_states, p=initial_probs)]
    for _ in range(time_steps - 1):
        states.append(np.random.choice(num_states, p=transition_probs[states[-1]]))
    observations = []
    for stimulus, state in zip(stimuli, states):
        observations.append(decision_models[state].sample(stimulus))
    observations = np.concatenate(observations)
    return stimuli, observations



In [None]:
# Simulate the training set from the ground-truth model defined above.
stimuli, data = generate_data(
    initial_probs=initial_probs,
    transition_probs=transition_probs,
    decision_models=decision_models,
    time_steps=time_steps,
    num_sequences=num_sequences,
)


In [None]:
def forward(
    stimuli: np.ndarray,
    sequence: np.ndarray,
    initial_probs: np.ndarray,
    transition_probs: np.ndarray,
    decision_models: list[DecisionModel],
) -> np.ndarray:
    """Scaled forward pass of an HMM whose emissions are decision models.

    Args:
        stimuli: stimuli of one sequence, aligned with ``sequence``.
        sequence: observed action indices of one sequence.
        initial_probs: initial state distribution, length num_states.
        transition_probs: (num_states, num_states) matrix, rows = from-state.
        decision_models: one emission model per state.

    Returns:
        Array of shape (num_states, len(sequence)). Each column is the forward
        message rescaled to sum to one, i.e. the filtered posterior
        p(state_t | actions up to t, stimuli) when the rows of
        ``transition_probs`` sum to one. An empty sequence gives an empty array.
    """
    stimuli = np.asarray(stimuli)
    sequence = np.asarray(sequence)
    initial_probs = np.asarray(initial_probs, dtype=float)
    transition_probs = np.asarray(transition_probs, dtype=float)
    num_states = len(initial_probs)
    alpha = np.zeros((num_states, len(sequence)))
    if len(sequence) == 0:
        return alpha
    # emissions[s, t] = p(action_t | stimulus_t, state s): one vectorized call
    # per model instead of one scalar call per (state, time step).
    emissions = np.array([dm.likelihood(stimuli, sequence) for dm in decision_models])
    alpha[:, 0] = initial_probs * emissions[:, 0]
    alpha[:, 0] /= np.sum(alpha[:, 0])
    for t in range(1, len(sequence)):
        alpha[:, t] = (alpha[:, t - 1] @ transition_probs) * emissions[:, t]
        # Per-step rescaling keeps the messages from underflowing.
        alpha[:, t] /= np.sum(alpha[:, t])
    return alpha


def backward(stimuli, sequence, initial_probs, transition_probs, decision_models):
    """Scaled backward pass of an HMM whose emissions are decision models.

    Args:
        stimuli: stimuli of one sequence, aligned with ``sequence``.
        sequence: observed action indices of one sequence.
        initial_probs: used only for its length (the number of states).
        transition_probs: (num_states, num_states) matrix, rows = from-state.
        decision_models: one emission model per state.

    Returns:
        Array of shape (num_states, len(sequence)). The last column is all
        ones. Every earlier column is the backward message rescaled to sum to
        one. The per-step scale factors cancel when gamma/xi are normalized
        per time step, and the rescaling prevents underflow on long sequences.
    """
    stimuli = np.asarray(stimuli)
    sequence = np.asarray(sequence)
    transition_probs = np.asarray(transition_probs, dtype=float)
    num_states = len(initial_probs)
    num_steps = len(sequence)
    beta = np.zeros((num_states, num_steps))
    if num_steps == 0:
        return beta
    # emissions[s, t] = p(action_t | stimulus_t, state s), computed once.
    emissions = np.array([dm.likelihood(stimuli, sequence) for dm in decision_models])
    beta[:, -1] = 1
    for t in range(num_steps - 2, -1, -1):
        beta[:, t] = transition_probs @ (emissions[:, t + 1] * beta[:, t + 1])
        beta[:, t] /= np.sum(beta[:, t])
    return beta


def e_step(stimuli, data, initial_probs, transition_probs, decision_models):
    """E-step over all sequences: posterior state and transition marginals.

    ``stimuli`` and ``data`` are iterated in lockstep, one sequence per row.
    The per-sequence results of ``e_step_helper`` are stacked with
    ``np.array``. This presumably requires all sequences to have the same
    length.

    Returns:
        (gamma, xi) stacked along a new leading sequence axis.
    """
    per_sequence = [
        e_step_helper(stim_seq, action_seq, initial_probs, transition_probs, decision_models)
        for stim_seq, action_seq in zip(stimuli, data)
    ]
    gammas = [posterior for posterior, _ in per_sequence]
    xis = [pair_posterior for _, pair_posterior in per_sequence]
    return np.array(gammas), np.array(xis)


def e_step_helper(stimuli, sequence, initial_probs, transition_probs, decision_models):
    """Posterior marginals for a single sequence.

    Returns:
        gamma: (num_states, T) array; column t is p(state_t | all actions).
        xi: (num_states, num_states, T - 1) array; xi[i, j, t] is
            p(state_t = i, state_{t+1} = j | all actions).
        Both are normalized per time step, so any per-step scaling of the
        forward/backward messages cancels out.
    """
    stimuli = np.asarray(stimuli)
    sequence = np.asarray(sequence)
    transition_probs = np.asarray(transition_probs, dtype=float)
    alpha = forward(stimuli, sequence, initial_probs, transition_probs, decision_models)
    beta = backward(stimuli, sequence, initial_probs, transition_probs, decision_models)
    gamma = alpha * beta
    gamma /= np.sum(gamma, axis=0)
    # emissions[s, t] = p(action_t | stimulus_t, state s), one call per model.
    emissions = np.array([dm.likelihood(stimuli, sequence) for dm in decision_models])
    # xi[i, j, t] ∝ alpha[i, t] * A[i, j] * b_j(t + 1) * beta[j, t + 1],
    # built for all t at once by broadcasting (i, j, t) axes.
    xi = (
        alpha[:, np.newaxis, :-1]
        * transition_probs[:, :, np.newaxis]
        * (emissions[:, 1:] * beta[:, 1:])[np.newaxis, :, :]
    )
    xi /= np.sum(xi, axis=(0, 1))
    return gamma, xi


def m_step_latents(data, gamma, xi):
    """M-step for the hidden Markov chain parameters.

    Args:
        data: unused; kept so the signature parallels the other M-step.
        gamma: state posteriors indexed (sequence, state, time), as stacked
            by ``e_step``.
        xi: transition posteriors indexed (sequence, from, to, time).

    Returns:
        initial_probs: first-step state posteriors averaged over sequences.
        transition_probs: expected transition counts divided by the expected
            number of departures from each source state (steps 0..T-2).
    """
    first_step_posteriors = gamma[:, :, 0]
    new_initial = first_step_posteriors.mean(axis=0)
    expected_transitions = xi.sum(axis=(0, 3))
    expected_departures = gamma[:, :, :-1].sum(axis=(0, 2))
    new_transition = expected_transitions / expected_departures[:, np.newaxis]
    return new_initial, new_transition


def m_step_observations(stimuli, data, gamma, decision_models):
    """M-step for the per-state decision models; updates them in place.

    Model ``i`` is refit by minimizing the negative log-likelihood of every
    observed action, weighted by ``gamma[:, i, :]`` (the posterior of state
    ``i``). The model's own bounds and constraints are passed to
    ``scipy.optimize.minimize`` (trust-constr).

    Returns:
        None. Each model's ``params`` is set to the optimizer's solution.
    """
    # Floor likelihoods before the log so an impossible observation gives a
    # large finite penalty instead of -inf (and 0 * -inf = nan where the
    # state's posterior weight is zero, which would poison the objective).
    min_likelihood = np.finfo(float).tiny

    def weighted_nll(params, weights, decision_model):
        # The optimizer's candidate is written into the model so likelihood()
        # evaluates at it.
        decision_model.params = params
        likelihoods = decision_model.likelihood(stimuli, data)
        return -np.sum(weights * np.log(np.maximum(likelihoods, min_likelihood)))

    for state, dm in enumerate(decision_models):
        result = minimize(
            weighted_nll,
            dm.params,
            args=(gamma[:, state, :], dm),
            method="trust-constr",
            bounds=dm.param_bounds,
            constraints=dm.param_constraints,
        )
        dm.params = result.x


def baum_welch(stimuli, data, num_states, model_type, num_iters=100, verbose=True):
    """Fit an HMM with decision-model emissions by expectation-maximization.

    Args:
        stimuli: stimuli per sequence, aligned with ``data``.
        data: observed action indices per sequence.
        num_states: number of hidden states.
        model_type: zero-argument callable returning an emission model. The
            DecisionModel subclasses here initialize random parameters when
            called without arguments.
        num_iters: number of EM iterations to run (no convergence check).
        verbose: if True, print the current estimates after every iteration.

    Returns:
        (initial_probs, transition_probs, decision_models) after the last iteration.
    """
    # Random initial chain parameters, normalized to valid distributions.
    initial_probs = np.random.random(num_states)
    initial_probs /= initial_probs.sum()
    transition_probs = np.random.random((num_states, num_states))
    transition_probs /= transition_probs.sum(axis=1)[:, np.newaxis]
    decision_models = [model_type() for _ in range(num_states)]
    for _ in tqdm(range(num_iters)):
        gamma, xi = e_step(
            stimuli, data, initial_probs, transition_probs, decision_models
        )
        initial_probs, transition_probs = m_step_latents(data, gamma, xi)
        m_step_observations(stimuli, data, gamma, decision_models)
        if verbose:
            # tqdm.write prints above the progress bar instead of breaking it.
            tqdm.write(str(initial_probs))
            tqdm.write(str(transition_probs))
            tqdm.write(str([dm.params for dm in decision_models]))
    return initial_probs, transition_probs, decision_models



In [None]:
# Fit a model with as many states as the simulated ground truth, and keep the
# estimates so they can be compared against the generating parameters.
fitted_initial_probs, fitted_transition_probs, fitted_decision_models = baum_welch(
    stimuli, data, num_states=len(initial_probs), model_type=LogisticDecisionModel
)
fitted_initial_probs, fitted_transition_probs, fitted_decision_models

