In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import os

import tensorflow as tf

from plot_stats import *
from utils import *
from helper_rnn import *

# Default size [width, height] for every matplotlib figure in this notebook.
plt.rcParams['figure.figsize'] = [8,6]
# Apply seaborn's "darkgrid" theme to subsequent plots.
sns.set_style("darkgrid")

# Report how many GPU devices TensorFlow can see.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Fix the TensorFlow and NumPy global random seeds so repeated runs start
# from the same random state.
tf.random.set_seed(26)
np.random.seed(26)

INFO:tensorflow:Enabling eager execution
INFO:tensorflow:Enabling v2 tensorshape
INFO:tensorflow:Enabling resource variables
INFO:tensorflow:Enabling tensor equality
INFO:tensorflow:Enabling control flow v2


ModuleNotFoundError: No module named 'helper_rnn'

In [2]:
def train_a_model(path, train_arr, test_arr, num_of_trials, n_cells, lr, epochs, batch_size, cp_callback):
    """Fit one RNN model on a training split and validate it on a test split.

    Both splits are loaded with ``preprocessing`` and encoded with
    ``one_hot_encoding``; the model is built by ``create_model`` (2 actions,
    state size 2) and compiled by ``compile_model`` before being fitted
    silently (``verbose=0``).

    Parameters
    ----------
    path : location of the agent data, forwarded to ``preprocessing``.
    train_arr, test_arr : split selectors forwarded to ``preprocessing`` for
        the training and validation data respectively.
    num_of_trials : forwarded to ``preprocessing``.
    n_cells : forwarded to ``create_model`` and ``compile_model``.
    lr : learning rate forwarded to ``compile_model``.
    epochs, batch_size : forwarded to ``model.fit``.
    cp_callback : accepted for interface compatibility; not used in this body.

    Returns
    -------
    (model, sim_model, hist) : the fitted model, the companion model returned
        by ``create_model``, and the object returned by ``model.fit``.

    NOTE(review): the initial RNN state is built from the training sample
    count and reused for the validation inputs - this presumably requires both
    splits to have the same number of samples; confirm against
    ``compile_model``.
    """

    def _load_encoded(split):
        # Load one split and turn it into (inputs, targets).
        rewards, actions, states = preprocessing(path, num_of_trials, split)
        return one_hot_encoding(rewards, actions, states)

    # Training split first, then validation split (same order as before).
    train_inputs, train_targets = _load_encoded(train_arr)
    val_inputs, val_targets = _load_encoded(test_arr)

    model, sim_model = create_model(n_actions=2, state_size=2, n_cells=n_cells)

    n_train_samples = train_inputs.shape[0]
    model, initial_rnn_state = compile_model(model, lr, n_train_samples, n_cells)

    # Train the model; the same initial state feeds both training and validation.
    validation_pair = ([val_inputs, initial_rnn_state], val_targets)
    hist = model.fit(
        x=[train_inputs, initial_rnn_state],
        y=train_targets,
        epochs=epochs,
        batch_size=batch_size,
        verbose=0,
        validation_data=validation_pair,
    )

    return model, sim_model, hist

In [19]:
# Sweep over the number of training epochs. For every setting, fit one model
# per (agent, fold) pair and record the mean final training accuracy,
# validation accuracy and validation binary cross-entropy.
train_acc = []
test_acc = []
test_loss = []


cell = 5                      # n_cells passed to train_a_model
batch = 100                   # batch_size passed to train_a_model
epochs = [100,200,400,500]    # epoch counts to compare
n_agents = 50                 # agents hybrid_agent_0 .. hybrid_agent_49

# Two-fold cross-validation: each row is [train selector, test selector].
cv = np.array([
               [0,1],
                [1,0]
               ])

# Number of fits per epoch setting. Derived from the loops below so the
# averages stay correct if the agent count or the folds change
# (previously a hard-coded 100).
n_runs = n_agents * len(cv)

# NOTE(review): a fresh Keras model is built for every fit; if memory or
# runtime grows across iterations, consider tf.keras.backend.clear_session().
for epoch in tqdm(epochs):
    print(f'*** epoch {epoch} ***')
    # Running sums of the final-epoch metrics over all agents and folds.
    tt_acc = 0
    tt_acc_t = 0
    tt_loss = 0
    for a in range(n_agents):
        print(f'agent {a}')
        for t in cv:
            train_arr = [t[0]]
            test_arr = [t[1]]

            path = f'data/hybrid/sim_data/hybrid_agent_{a}'

            model, sim_model, hist = train_a_model(path,
                                                  train_arr,
                                                  test_arr,
                                                  num_of_trials=200,
                                                  n_cells=cell,
                                                  lr=0.001,
                                                  epochs=epoch,
                                                  batch_size=batch,
                                                  cp_callback=True
            )

            # Keep only the metric values from the last training epoch.
            tt_acc += hist.history['accuracy'][-1]
            tt_acc_t += hist.history['val_accuracy'][-1]
            tt_loss += hist.history['val_binary_crossentropy'][-1]

    # Average over every (agent, fold) fit for this epoch setting.
    train_acc.append(tt_acc/n_runs)
    test_acc.append(tt_acc_t/n_runs)
    test_loss.append(tt_loss/n_runs)
    print(test_loss)

  0%|                                                                                            | 0/4 [00:00<?, ?it/s]

*** epoch 100 ***
agent 0
agent 1
agent 2
agent 3
agent 4
agent 5
agent 6
agent 7
agent 8
agent 9
agent 10
agent 11
agent 12
agent 13
agent 14
agent 15
agent 16
agent 17
agent 18
agent 19
agent 20
agent 21
agent 22
agent 23
agent 24
agent 25
agent 26
agent 27
agent 28
agent 29
agent 30
agent 31
agent 32
agent 33
agent 34
agent 35
agent 36
agent 37
agent 38
agent 39
agent 40
agent 41
agent 42
agent 43
agent 44
agent 45
agent 46
agent 47
agent 48
agent 49


 25%|████████████████████▊                                                              | 1/4 [05:05<15:16, 305.35s/it]

[0.6328739780187607]
*** epoch 200 ***
agent 0
agent 1
agent 2
agent 3
agent 4
agent 5
agent 6
agent 7
agent 8
agent 9
agent 10
agent 11
agent 12
agent 13
agent 14
agent 15
agent 16
agent 17
agent 18
agent 19
agent 20
agent 21
agent 22
agent 23
agent 24
agent 25
agent 26
agent 27
agent 28
agent 29
agent 30
agent 31
agent 32
agent 33
agent 34
agent 35
agent 36
agent 37
agent 38
agent 39
agent 40
agent 41
agent 42
agent 43
agent 44
agent 45
agent 46
agent 47
agent 48
agent 49


 50%|█████████████████████████████████████████▌                                         | 2/4 [13:24<13:58, 419.14s/it]

[0.6328739780187607, 0.5694618272781372]
*** epoch 400 ***
agent 0
agent 1
agent 2
agent 3
agent 4
agent 5
agent 6
agent 7
agent 8
agent 9
agent 10
agent 11
agent 12
agent 13
agent 14
agent 15
agent 16
agent 17
agent 18
agent 19
agent 20
agent 21
agent 22
agent 23
agent 24
agent 25
agent 26
agent 27
agent 28
agent 29
agent 30
agent 31
agent 32
agent 33
agent 34
agent 35
agent 36
agent 37
agent 38
agent 39
agent 40
agent 41
agent 42
agent 43
agent 44
agent 45
agent 46
agent 47
agent 48
agent 49


 75%|██████████████████████████████████████████████████████████████▎                    | 3/4 [28:10<10:32, 632.47s/it]

[0.6328739780187607, 0.5694618272781372, 0.5219461995363236]
*** epoch 500 ***
agent 0
agent 1
agent 2
agent 3
agent 4
agent 5
agent 6
agent 7
agent 8
agent 9
agent 10
agent 11
agent 12
agent 13
agent 14
agent 15
agent 16
agent 17
agent 18
agent 19
agent 20
agent 21
agent 22
agent 23
agent 24
agent 25
agent 26
agent 27
agent 28
agent 29
agent 30
agent 31
agent 32
agent 33
agent 34
agent 35
agent 36
agent 37
agent 38
agent 39
agent 40
agent 41
agent 42
agent 43
agent 44
agent 45
agent 46
agent 47
agent 48
agent 49


100%|███████████████████████████████████████████████████████████████████████████████████| 4/4 [45:51<00:00, 687.83s/it]

[0.6328739780187607, 0.5694618272781372, 0.5219461995363236, 0.541953761279583]





In [9]:
# Recorded results (apparently mean test loss): cells = 30, lr = 0.001
# epochs = 200, 500, 800, 1000, 2000

[0.6792920002341271,
 0.9734493246674538,
 1.198500207066536,
 1.4444124647974967,
 2.061407516300678]

In [11]:
# Recorded results (apparently mean test loss): epochs = 200 and 500, cells = 10, lr = 0.001

[0.5288263574242592, 0.6753430104255677]

In [13]:
# Recorded results (apparently mean test loss): epochs = 200 and 500, cells = 10, lr = 0.01

[1.0166496016085147, 1.6456142246723175]

In [15]:
# Recorded results (apparently mean test loss): epochs = 200 and 500, cells = 5, lr = 0.001

[0.5631232652068138, 0.5391559466719628]

In [17]:
# Recorded results (apparently mean test loss): epochs = 400, cells = [1, 2, 4, 5, 8], lr = 0.001

[0.6196852266788483,
 0.57273413002491,
 0.5340087789297104,
 0.5263048848509788,
 0.5673557734489441]

In [None]:
# Best setting so far: epochs = 400, cells = 5 -> mean test loss 0.52194