## PyTorch supervised learning of a perceptual decision-making task

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/neurogym/neurogym/blob/master/examples/example_neurogym_pytorch.ipynb)

PyTorch-based example code for training an RNN (here, an LSTM) on a perceptual decision-making task.

### Installation (only needed when running on Google Colab)

In [None]:
# Install gym
# NOTE(review): gym is installed without a version pin. Newer gym releases appear
# to emit API-compatibility warnings with this neurogym version (e.g. the
# env.reset() warning captured in the dataset cell's output) -- consider pinning
# a compatible gym version.
! pip install gym
# Install neurogym from source: clone the repository ...
! git clone https://github.com/gyyang/neurogym.git
# ... change into the cloned checkout ...
%cd neurogym/
# ... and install it in editable mode from the current directory.
! pip install -e .

### Dataset

In [2]:
import numpy as np
import torch
import torch.nn as nn

import neurogym as ngym

# Environment configuration: which neurogym task to build and its options.
task = 'PerceptualDecisionMaking-v0'
kwargs = dict(dt=100)  # forwarded to the environment as env_kwargs (dt presumably in ms -- confirm)
seq_len = 100

# Wrap the task in a supervised dataset; calling `dataset()` yields a batch of
# (inputs, labels) sequences of length `seq_len`.
dataset = ngym.Dataset(
    task,
    env_kwargs=kwargs,
    batch_size=16,
    seq_len=seq_len,
)
env = dataset.env

# The network's input and output sizes come from the environment's spaces.
ob_size = env.observation_space.shape[0]
act_size = env.action_space.n

  f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(result)}`"


### Network and Training

In [3]:
class Net(nn.Module):
    """LSTM followed by a linear readout applied at every time step.

    Args:
        num_h: Number of LSTM hidden units.
        input_size: Observation dimension fed to the LSTM. Defaults to the
            notebook-level ``ob_size`` computed from the environment.
        output_size: Number of output classes (actions). Defaults to the
            notebook-level ``act_size`` computed from the environment.
    """

    def __init__(self, num_h, input_size=None, output_size=None):
        super().__init__()
        # Fall back to the notebook globals so `Net(num_h=64)` keeps working,
        # while allowing the network to be built without them.
        if input_size is None:
            input_size = ob_size
        if output_size is None:
            output_size = act_size
        self.lstm = nn.LSTM(input_size, num_h)
        self.linear = nn.Linear(num_h, output_size)

    def forward(self, x):
        """Return per-time-step logits for input sequence ``x``.

        nn.LSTM is built with the default ``batch_first=False``, so ``x`` is
        expected as (seq_len, batch, input_size) and the result has shape
        (seq_len, batch, output_size).
        """
        out, _ = self.lstm(x)  # final (h_n, c_n) states are not needed here
        return self.linear(out)

# Train on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

net = Net(num_h=64).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)

log_every = 200  # report the mean training loss over this many steps
running_loss = 0.0
for step in range(1, 2001):
    # Draw a fresh supervised batch and convert it to tensors on `device`;
    # labels are flattened to match the flattened logits below.
    inputs, labels = dataset()
    inputs = torch.from_numpy(inputs).type(torch.float).to(device)
    labels = torch.from_numpy(labels.flatten()).type(torch.long).to(device)

    # One optimization step: clear gradients, forward, loss, backward, update.
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs.view(-1, act_size), labels)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    if step % log_every == 0:
        print('{:d} loss: {:0.5f}'.format(step, running_loss / log_every))
        running_loss = 0.0

print('Finished Training')

200 loss: 0.09157
400 loss: 0.01893
600 loss: 0.01281
800 loss: 0.01198
1000 loss: 0.01114
1200 loss: 0.01116
1400 loss: 0.01104
1600 loss: 0.01079
1800 loss: 0.01112
2000 loss: 0.01062
Finished Training


In [5]:
# Show the layer structure of the trained network.
net_summary = str(net)
print(net_summary)

Net(
  (lstm): LSTM(3, 64)
  (linear): Linear(in_features=64, out_features=3, bias=True)
)


### Analysis

In [4]:
# Evaluate the trained network on fresh trials. A trial counts as correct when
# the predicted action at the final time step matches the ground truth there.
# TODO: Make this into a function in neurogym
perf = 0
num_trial = 200
# Send inputs to whichever device the network currently lives on. Using
# `device` directly breaks once a later cell has moved `net` to the CPU.
net_device = next(net.parameters()).device
with torch.no_grad():  # inference only: no autograd graph is needed
    for i in range(num_trial):
        env.new_trial()
        ob, gt = env.ob, env.gt
        ob = ob[:, np.newaxis, :]  # Add batch axis -> (seq_len, 1, ob_size)
        inputs = torch.from_numpy(ob).type(torch.float).to(net_device)

        action_pred = net(inputs)
        # .cpu() is required before .numpy() when the model ran on a GPU.
        action_pred = action_pred.cpu().numpy()
        action_pred = np.argmax(action_pred, axis=-1)
        perf += gt[-1] == action_pred[-1, 0]

perf /= num_trial
print('Average performance in {:d} trials'.format(num_trial))
print(perf)

Average performance in 200 trials
0.895


In [9]:
# Extract the LSTM activity of the trained network for `inputs`, which holds
# the last trial evaluated in the analysis cell above.
# nn.LSTM returns (output, (h_n, c_n)): `output` is the hidden state at every
# time step; h_n / c_n are the final hidden and cell states.
net = net.to('cpu')
# `inputs` may still be on the GPU; it must be on the same device as `net`.
lstm_inputs = inputs.to('cpu')
hidden = net.lstm(lstm_inputs)
lstm_out, (h_n, c_n) = hidden
hidden

(tensor([[[ 5.0934e-03, -1.2046e-01, -1.5882e-01,  ..., -2.0963e-01,
           -1.7633e-01,  1.0578e-01]],
 
         [[ 2.4956e-02, -1.5824e-01,  1.9214e-04,  ..., -3.2610e-01,
           -1.2877e-01,  1.7506e-01]],
 
         [[ 2.5560e-02, -1.2214e-01,  1.7761e-04,  ..., -1.8526e-01,
           -4.3815e-02,  1.2272e-01]],
 
         ...,
 
         [[ 2.6339e-02, -6.6062e-02,  3.5459e-04,  ..., -1.2552e-01,
           -3.7296e-02,  9.0264e-02]],
 
         [[ 2.6406e-02, -6.6035e-02,  1.6718e-04,  ..., -1.2250e-01,
           -2.6784e-02,  8.2449e-02]],
 
         [[ 5.4010e-01, -2.1344e-02,  4.8640e-05,  ..., -9.8092e-03,
           -7.3958e-03,  2.3478e-02]]], grad_fn=<StackBackward0>),
 (tensor([[[ 5.4010e-01, -2.1344e-02,  4.8640e-05,  5.5331e-03,  5.8322e-02,
            -2.6924e-02,  3.5993e-02,  6.7515e-01, -6.8902e-02,  7.8153e-04,
             7.8588e-01, -1.3600e-03,  1.5669e-04, -8.3393e-02, -9.6789e-02,
             1.9087e-02, -7.0471e-03, -6.5499e-03,  5.5878e-01,  6.