# RihkyeTask + LSTM_MD model

In [17]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from task import RihkyeTask

## Dataset

In [18]:
# Task structure passed straight through to RihkyeTask.
num_cueingcontext = 2
num_cue = 2
num_rule = 2
rule = [0, 1, 0, 1]

# Block schedule. NOTE(review): presumably blocklen[k] is the length of block k
# and block_cueingcontext[k] its cueing context -- confirm against task.py.
blocklen = [500, 500, 200]
block_cueingcontext = [0, 1, 0]

# Timing and batching. NOTE(review): tsteps/cuesteps look like per-trial and
# cue-period step counts -- confirm against task.py.
tsteps = 200
cuesteps = 100
batch_size = 10

# Each setting is forwarded by keyword under the same name.
dataset = RihkyeTask(
    num_cueingcontext=num_cueingcontext,
    num_cue=num_cue,
    num_rule=num_rule,
    rule=rule,
    blocklen=blocklen,
    block_cueingcontext=block_cueingcontext,
    tsteps=tsteps,
    cuesteps=cuesteps,
    batch_size=batch_size,
)

## Model

In [19]:
class LSTM_MD(nn.Module):
    """LSTM followed by a linear readout applied at every time step.

    Despite the name, this class contains no separate MD layer: it is an
    ``nn.LSTM`` whose per-step hidden states are fed to an ``nn.Linear``.

    Parameters:
    input_size: int, number of features per time step fed to the LSTM
    hidden_size: int, number of hidden units in each LSTM layer
    output_size: int, number of units in the linear output layer
    num_layers: int, number of stacked LSTM layers
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()

        # Recurrent core; batch_first is left at its default (False).
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )
        # Readout from the hidden state to the output units.
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Run ``x`` through the LSTM and map every hidden state to an output.

        The LSTM's final (h_n, c_n) state is discarded; only the per-step
        hidden sequence is used, so the result keeps the leading dimensions
        of ``x`` and has ``output_size`` as its last dimension.
        """
        hidden_seq = self.lstm(x)[0]
        return self.fc(hidden_seq)

## Training

In [20]:
import os  # needed by the checkpointing branch below (os.path.join / os.makedirs)
import time

# Model settings
input_size = 4 # 4 cues
hidden_size = 200
output_size = 2 # 2 rules
num_layers = 1

model = LSTM_MD(input_size=input_size, hidden_size=hidden_size, output_size=output_size, num_layers=num_layers)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


# One optimizer step per dataset() call.
# NOTE(review): assumes each call consumes batch_size trials of the block
# schedule -- confirm against task.py.
total_step = sum(blocklen)//batch_size
print_step = 10      # report (and optionally checkpoint) every print_step steps
running_loss = 0.0   # loss accumulated since the last report
model_name = 'model-' + str(int(time.time()))  # timestamped checkpoint name
savemodel = False    # set to True to write checkpoints under ./models


for i in range(total_step):
    inputs, labels = dataset()

    # dataset() yields NumPy arrays; convert them to float32 tensors
    inputs = torch.from_numpy(inputs).type(torch.float)
    labels = torch.from_numpy(labels).type(torch.float)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = model(inputs)

    loss = criterion(outputs, labels)
    loss.backward()
    # gradient clipping: rescale gradients so their total norm is at most 1.0
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()


    # print statistics
    running_loss += loss.item()
    if i % print_step == (print_step - 1):
        print('Total step: {:d}'.format(total_step))
        print('Training sample index: {:d}-{:d}'.format(i+1-print_step, i+1))

        # mean loss over the last print_step steps
        print('loss: {:0.5f}'.format(running_loss / print_step), '\n')
        running_loss = 0.0

        if savemodel:
            # make sure the output directory exists before writing into it
            os.makedirs('models', exist_ok=True)

            # save model every print_step (the same file is overwritten each time)
            fname = os.path.join('models', model_name + '.pt')
            torch.save(model.state_dict(), fname)

            # save info of the model
            fpath = os.path.join('models', model_name + '.txt')
            with open(fpath, 'w') as f:
                f.write('input_size = ' + str(input_size) + '\n')
                f.write('hidden_size = ' + str(hidden_size) + '\n')
                f.write('output_size = ' + str(output_size) + '\n')
                f.write('num_layers = ' + str(num_layers) + '\n')


print('Finished Training')

Total step: 120
Training sample index: 0-10
loss: 0.41148 

Total step: 120
Training sample index: 10-20
loss: 0.27202 

Total step: 120
Training sample index: 20-30
loss: 0.15606 

Total step: 120
Training sample index: 30-40
loss: 0.05314 

Total step: 120
Training sample index: 40-50
loss: 0.03295 

Total step: 120
Training sample index: 50-60
loss: 0.26729 

Total step: 120
Training sample index: 60-70
loss: 0.16601 

Total step: 120
Training sample index: 70-80
loss: 0.05359 

Total step: 120
Training sample index: 80-90
loss: 0.03323 

Total step: 120
Training sample index: 90-100
loss: 0.02649 

Total step: 120
Training sample index: 100-110
loss: 0.02525 

Total step: 120
Training sample index: 110-120
loss: 0.02133 

Finished Training
