<a href="https://colab.research.google.com/github/tozanni/nma_wcst_rl/blob/main/human_rl_wcst.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using RL to Model Wisconsin Card Sorting Task


**Original notebook credits:**

__Content creators:__ Morteza Ansarinia, Yamil Vidal

__Production editor:__ Spiros Chavlis


---
# Objective

- This project aims to use behavioral data to train an agent and then use the agent to investigate data produced by human subjects. Having a computational agent that mimics humans in such tests, we will be able to compare its mechanics with human data.

- In another conception, we could fit an agent that learns many cognitive tasks that require abstract-level constructs such as executive functions. This is a multi-task control problem.




---
# Setup

In [1]:
# @title Install dependencies
!pip install jedi --quiet
!pip install --upgrade pip setuptools wheel --quiet
!pip install dm-acme[jax] --quiet
!pip install dm-sonnet --quiet
!pip install trfl --quiet
!pip install numpy==1.23.3 --quiet --ignore-installed
!pip uninstall seaborn -y --quiet
!pip install seaborn --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m804.0/804.0 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m314.1/314.1 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.4/6.4 MB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m28.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.0/4.0 MB[0m [31m39.1 MB/s[0m eta [36

In [6]:
# Imports
import time
import numpy as np
import pandas as pd
import sonnet as snt
import seaborn as sns
import matplotlib.pyplot as plt

import dm_env

import acme
from acme import specs
from acme import wrappers
from acme import EnvironmentLoop
from acme.agents.tf import dqn
from acme.utils import loggers

In [None]:
# @title Figure settings
from IPython.display import clear_output, display, HTML
%matplotlib inline
sns.set()

  and should_run_async(code)


---
# Background

- Cognitive scientists use standard lab tests to tap into specific processes in the brain and behavior. Some examples of those tests are Stroop, N-back, Digit Span, TMT (Trail making tests), and WCST (Wisconsin Card Sorting Tests).

## Datasets

This notebook works on simulated data only.

## Wisconsin Card Sorting task (WCST)

TODO: Describe Task

TODO: Describe metrics

---
# Cognitive Tests Environment


## Implementation scheme


In [35]:
import WCST

  and should_run_async(code)


module

### Environment

The following cell implments an envinronment for the WCST:
- Rewards the agent once the action was correct (i.e., a normative model of the environment).
- **Future work**: Receives human data and returns what participants performed as the observation.

In [73]:
class WCST_Env(dm_env.Environment):
    ACTIONS = [0, 1, 2, 3]

    def __init__(self,seed=1):

        self.episode_steps = 36  #36 cards or steps per episode
        self._current_step = 0  #Current episode step counter
        self._reset_next_step = True
        self._action_history = []

        #Init WCST variables
        self.nb_dim = 3
        self.nb_features = 4
        self.nb_templates = self.nb_features
        r = 3  #rules number, we have 3 rules

        self.sample_card = np.array([0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1])

        self.nbTS = 0
        self.nb_win = 0
        self.t_criterion = 0
        self.t_err = 0
        self.criterions = []

        self.winstreak = 0
        self.m_percep = WCST.perception(self.nb_dim, self.nb_templates, self.nb_features)
        self.reasoning_list = []
        self.rule = 0

        #Last card info, it's dealt for the first time on reset method
        self.np_data = []
        self.v_data = []

    def new_card(self):
        v_data = [] #list type
        np_data = WCST.response_item_Reasoning(self.nb_dim, self.nb_features, self.m_percep, self.reasoning_list) #Modified WCST version

        #Transform into a vector
        for arr in np_data:
            for e in arr:
                v_data.append(e)

        #Save last card info
        self.np_data = np_data
        self.v_data = v_data

        return np_data, v_data

    def reset(self):
        self._reset_next_step = False
        self._current_step = 0
        self._action_history.clear()

        #Deal new card
        obs = self._observation()
        #self._current_step += 1
        return dm_env.restart(self._observation())

    def _episode_return(self):
      return 0.0

    def rule_switching(self, rule):
        """
        Serially changing the rules : color - form - number.
        """
        if rule!=2:
            rule = rule+1
        else:
            rule = 0
        return rule

    def external_feedback(self, action):
        """
        Returns a true reward according to the success or not of the card chosen.
        """
        response_card = self.np_data
        reference_cards = self.m_percep
        right_action_i = 0

        #print("Determine reward for rule:", self.rule, "and card: ")
        #print(response_card)

        for i in range(0, self.nb_templates):

            if np.array_equal(reference_cards[i][self.rule], response_card[self.rule]):
                right_action_i = i

        if right_action_i == action:
            #0 to decrease error activity
            return 0
        else:
            #1 to activate error cluster
            return 1


    def step(self, action: int):

        if self._reset_next_step:
            return self.reset()

        agent_action = WCST_Env.ACTIONS[action]

        #Compute reward
        step_reward = self.external_feedback(agent_action)
        print("Reward:", step_reward)

        ##Winstreak count
        if step_reward == 0:
            self.t_err = 0
            self.nb_win += 1
            self.winstreak += 1
            #ptrial.append(1)
            #ntrial.append(0)

        if step_reward == 1:
            self.t_criterion += 1
            self.t_err += 1

            #FIXME: In the original code winstreak is reset
            #after a positive reward. This doesn't look right.

            #self.winstreak = 0
            #ptrial.append(0)
            #ntrial.append(1)

        # Criterion test
        # After 3 wins, then change the rule
        if self.winstreak==3:
            self.rule = self.rule_switching(self.rule)
            self.criterions.append(self.t_criterion)
            #Reset some variables and increment nbTS
            self.t_criterion = 0
            self.winstreak = 0
            self.nbTS +=1
            print("winstreak=3, New rule is ", self.rule, "NbTS=", self.nbTS)
        else:
            #print("Winstreak=", self.winstreak)
            pass

        self._action_history.append(agent_action)
        self._current_step += 1

        # Check for termination.
        if self.nbTS >= 6 or self._current_step == self.episode_steps:
            self._reset_next_step = True
            print("A. Return last observation and terminate, NbTS=", self.nbTS)
            return dm_env.termination(reward=self._episode_return(), observation=self._observation())
        else:
            #Send reward to agent and a new observation
            #Uncomment in notebook
            print("B. Step: ", self._current_step, "Return observation")
            return dm_env.transition(reward=step_reward, observation=self._observation())

    def observation_spec(self):
        return dm_env.specs.BoundedArray(
            shape=self.sample_card.shape,
            dtype=self.sample_card.dtype,
            name='card',
            minimum=len(self.sample_card),
            maximum=len(self.sample_card)
        )

    def action_spec(self):
        return dm_env.specs.DiscreteArray(
            num_values=len(WCST_Env.ACTIONS),
            dtype=np.int32,
            name='action')
        pass

    def _observation(self):
        # agent observes only the current trial

        #INPUT new card, (Environment)
        print("Calling new_card...")
        np_data, card = self.new_card()

        #print("New card is np_data", np_data)
        print("New card is v_data", card)

        obs = card
        return obs

    @staticmethod
    def create_environment():
        """Utility function to create a N-back environment and its spec."""

        # Make sure the environment outputs single-precision floats.
        environment = wrappers.SinglePrecisionWrapper(WCST_Env())

        # Grab the spec of the environment.
        environment_spec = specs.make_environment_spec(environment)
        return environment, environment_spec


In [74]:
env, env_spec = WCST_Env.create_environment()
agent = RandomAgent(env_spec)

#print('actions:\n', env_spec.actions)
#print('observations:\n', env_spec.observations)
#print('rewards:\n', env_spec.rewards)

In [56]:
## Test run
env = WCST_Env()

# First observation
timestep = env.reset()

import random

## FIXME: This stalls after 36 cards, which is the limit per game or episode
## In the RL framework you should iterate as it is indicated in the acme loop.
for i in range(0,37):

    print("Step ", i)
    action = random.randint(0,2)
    env.step(action)

print("END")

  and should_run_async(code)


Calling new_card...
New card is v_data [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1]
Calling new_card...
New card is v_data [0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0]
Step  0
Reward: 1
B. Step:  1 Return observation
Calling new_card...
New card is v_data [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1]
Step  1
Reward: 1
B. Step:  2 Return observation
Calling new_card...
New card is v_data [1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0]
Step  2
Reward: 0
B. Step:  3 Return observation
Calling new_card...
New card is v_data [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0]
Step  3
Reward: 1
B. Step:  4 Return observation
Calling new_card...
New card is v_data [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0]
Step  4
Reward: 1
B. Step:  5 Return observation
Calling new_card...
New card is v_data [1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0]
Step  5
Reward: 1
B. Step:  6 Return observation
Calling new_card...
New card is v_data [0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]
Step  6
Reward: 1
B. Step:  7 Return observation
Calling new_card...
New card is v_data [1, 0, 0, 

KeyboardInterrupt: ignored

### Define a random agent

For more information you can refer to NMA-DL W3D2 Basic Reinforcement learning.

In [26]:
class RandomAgent(acme.Actor):

  def __init__(self, environment_spec):
    """Gets the number of available actions from the environment spec."""
    self._num_actions = environment_spec.actions.num_values

  def select_action(self, observation):
    """Selects an action uniformly at random."""
    action = np.random.randint(self._num_actions)
    return action

  def observe_first(self, timestep):
    """Does not record as the RandomAgent has no use for data."""
    pass

  def observe(self, action, next_timestep):
    """Does not record as the RandomAgent has no use for data."""
    pass

  def update(self):
    """Does not update as the RandomAgent does not learn from data."""
    pass

### Initialize the environment and the agent

In [66]:
env, env_spec = WCST_Env.create_environment()
agent = RandomAgent(env_spec)

#print('actions:\n', env_spec.actions)
#print('observations:\n', env_spec.observations)
#print('rewards:\n', env_spec.rewards)

  and should_run_async(code)


AttributeError: ignored

### Run the loop

In [76]:
# fitting parameters
n_episodes = 100
n_total_steps = 0
log_loss = False
n_steps = n_episodes * 32
all_returns = []

# main loop
for episode in range(n_episodes):
  episode_steps = 0
  episode_return = 0
  episode_loss = 0

  start_time = time.time()

  timestep = env.reset()

  # Make the first observation.
  agent.observe_first(timestep)

  # Run an episode
  while not timestep.last():

    # DEBUG
    # print(timestep)

    # Generate an action from the agent's policy and step the environment.
    action = agent.select_action(timestep.observation)
    timestep = env.step(action)

    # Have the agent observe the timestep and let the agent update itself.
    agent.observe(action, next_timestep=timestep)
    agent.update()

    # Book-keeping.
    episode_steps += 1
    n_total_steps += 1
    episode_return += timestep.reward

    if log_loss:
      episode_loss += agent.last_loss

    if n_steps is not None and n_total_steps >= n_steps:
      break

  # Collect the results and combine with counts.
  steps_per_second = episode_steps / (time.time() - start_time)
  result = {
      'episode': episode,
      'episode_length': episode_steps,
      'episode_return': episode_return,
  }
  if log_loss:
    result['loss_avg'] = episode_loss/episode_steps

  all_returns.append(episode_return)

  display(env.plot_state())
  # Log the given results.
  print(result)

  if n_steps is not None and n_total_steps >= n_steps:
    break

clear_output()

# Histogram of all returns
#plt.figure()
#sns.histplot(all_returns, stat="density", kde=True, bins=12)
#plt.xlabel('Return [a.u.]')
#plt.ylabel('Density')
#plt.show()

Calling new_card...


KeyboardInterrupt: ignored

**Note:** You can simplify the environment loop using [DeepMind Acme](https://github.com/deepmind/acme).

In [None]:
# init a new N-back environment
env, env_spec = NBack.create_environment()

# DEBUG fake testing environment.
# Uncomment this to debug your agent without using the N-back environment.
# env = fakes.DiscreteEnvironment(
#     num_actions=2,
#     num_observations=1000,
#     obs_dtype=np.float32,
#     episode_length=32)
# env_spec = specs.make_environment_spec(env)

In [None]:
def dqn_make_network(action_spec: specs.DiscreteArray) -> snt.Module:
  return snt.Sequential([
      snt.Flatten(),
      snt.nets.MLP([50, 50, action_spec.num_values]),
  ])

# construct a DQN agent
agent = dqn.DQN(
    environment_spec=env_spec,
    network=dqn_make_network(env_spec.actions),
    epsilon=[0.5],
    logger=loggers.InMemoryLogger(),
    checkpoint=False,
)

Now, we run the environment loop with the DQN agent and print the training log.

In [None]:
# training loop
loop = EnvironmentLoop(env, agent, logger=loggers.InMemoryLogger())
#loop.run(n_episodes)

# print logs
#logs = pd.DataFrame(loop._logger._data)
#logs.tail()