## TD Models

TD-learning family models:
1. <a href='#oneStepTD'> TD(0) </a>
2. Actor-critic:
    * Some Theory: <a href='http://incompleteideas.net/book/first/ebook/node66.html'>Actor-Critic Methods</a>
    * Possible Implementation: <a href='https://www.nature.com/articles/s41598-017-18004-7'> A hippocampo-cerebellar centred network for the learning and execution of sequence-based navigation </a>

In [1]:
# Imports
from __future__ import print_function
import pickle
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import patches
from matplotlib.collections import LineCollection
from matplotlib import cm
from copy import deepcopy
import plotly.graph_objects as go
from scipy.optimize import curve_fit
from dataclasses import make_dataclass
import pandas as pd
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

import sys

module_path = 'src' 
if module_path not in sys.path:
    sys.path.append(module_path)
    
# Markus's code
from MM_Plot_Utils import plot, hist
from MM_Maze_Utils import *
from MM_Traj_Utils import *

%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
# Defining global variables

# Some lists of nicknames for mice
RewNames = ['B1', 'B2', 'B3', 'B4', 'C1', 'C3', 'C6', 'C7', 'C8', 'C9']
UnrewNames = ['B5', 'B6', 'B7', 'D3', 'D4', 'D5', 'D6', 'D7', 'D8', 'D9']
AllNames = RewNames + UnrewNames
UnrewNamesSub = ['B5', 'B6', 'B7', 'D3', 'D4', 'D5', 'D7', 'D8', 'D9']  # excluding D6 which barely entered the maze

# Node numbers at each depth of the binary maze: level k holds nodes 2**k - 1 .. 2**(k+1) - 2,
# so level 6 contains the end/leaf nodes 63-126.
lv6_nodes = list(range(63, 127))
lv5_nodes = list(range(31, 63))
lv4_nodes = list(range(15, 31))
lv3_nodes = list(range(7, 15))
lv2_nodes = list(range(3, 7))
lv1_nodes = list(range(1, 3))
lv0_nodes = list(range(0, 1))
lvl_dict = {0: lv0_nodes, 1: lv1_nodes, 2: lv2_nodes, 3: lv3_nodes, 4: lv4_nodes, 5: lv5_nodes, 6: lv6_nodes}

# Nodes in each maze quadrant: the subtree rooted at level-2 node 3, 4, 5 or 6 (31 nodes each).
quad1 = [3, 7, 8, 15, 16, 17, 18, 31, 32, 33, 34, 35, 36, 37, 38,
         63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78]
quad2 = [4, 9, 10, 19, 20, 21, 22, 39, 40, 41, 42, 43, 44, 45, 46,
         79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94]
# range() stops are exclusive: level-5 nodes 47-54 / 55-62 and leaves 95-110 / 111-126
quad3 = [5, 11, 12, 23, 24, 25, 26] + list(range(47, 55)) + list(range(95, 111))
quad4 = [6, 13, 14, 27, 28, 29, 30] + list(range(55, 63)) + list(range(111, 127))

# Parameters for simulating new trajectories
InvalidState = -1   # fill value for unused entries in trajectory arrays and the node map
RewardNode = 116    # reward (water port) node
HomeNode = 127      # maze entry/exit state
StartNode = 0       # first node of every simulated episode
S = 128  # Number of states (nodes 0-126 plus HomeNode)
A = 3    # neighbours per state: parent node and two child nodes
RewardNodeMag = 1   # reward received on entering RewardNode
main_dir = 'stan/'
real_traj_dir = main_dir + 'traj_data/real_traj/'
pred_traj_dir = main_dir + 'traj_data/pred_traj/'
stan_results_dir = main_dir + 'stan_results/'
# Close the file after loading instead of leaking the handle (local project file, not untrusted input)
with open(main_dir + 'nodemap.p', 'rb') as _nodemap_fh:
    nodemap = pickle.load(_nodemap_fh)

In [3]:
def extract_rew_traj(traj_type):
    '''
    Extract segments of the real rewarded-mouse trajectories and pickle them for model fitting.

    traj_type: 'first_Rvisit', store every bout up until the first reward node visit (assume immediate reward)
                               and discard all subsequent bouts -> 'rewMICE_first_Rvisit.p'
               'post_firstR_every_Rvisit', skip bouts up to and including the first bout that reaches the
                               reward node, then store every later bout with more than one node, each truncated
                               at its first reward node visit -> 'rewMICE_post_firstR_every_Rvisit.p'
               'first_drink', store every bout up until the first reward receipt (could be multiple reward node
                               visits) -> 'rewMICE_first_drink.p', plus the number of unrewarded reward node
                               visits per bout -> 'nonRew_RVisits.p'
    All files are written to 'real_traj_dir'.

    Saved arrays: TrajS, ndarray[(N, TrajNo, TrajSize), float], node numbers padded with InvalidState
                  nonRew_RVisits, ndarray[(N, TrajNo), int] ('first_drink' only)
    Returns: None
    Raises: ValueError for an unknown traj_type
    '''

    def _save(obj, filename):
        # Pickle obj into real_traj_dir, closing the file afterwards
        with open(real_traj_dir + filename, 'wb') as fh:
            pickle.dump(obj, fh)

    def _drink_follows_visit(visit_frames, frame, reward_frame):
        # True if the first reward frame lies at/after this reward node visit and, unless this is the
        # last visit of the bout, no later than the next visit
        wID = np.where(visit_frames == frame)[0][0]
        if visit_frames[wID] > reward_frame:
            return False
        if visit_frames[-1] == frame:
            return True
        return reward_frame <= visit_frames[wID + 1]

    N = len(RewNames)  # number of rewarded mice
    TrajSize = 1000    # max nodes stored per bout
    TrajNo = 350       # max bouts stored per mouse
    TrajS = np.ones((N, TrajNo, TrajSize)) * InvalidState

    if traj_type == 'first_Rvisit':
        # Store each bout node by node until the first visit to the reward node, then stop for this mouse
        for mouseID, nickname in enumerate(RewNames):
            tf = LoadTraj(nickname + '-tf')
            reward_found = False
            for boutID, bout in enumerate(tf.no):
                for step, (node, _frame) in enumerate(bout):
                    TrajS[mouseID, boutID, step] = node
                    if node == RewardNode:
                        reward_found = True
                        break
                if reward_found:
                    break
        _save(TrajS, 'rewMICE_first_Rvisit.p')

    elif traj_type == 'post_firstR_every_Rvisit':
        # Extract the bouts after the first bout reaching the reward node, each up to its first reward node visit
        for mouseID, nickname in enumerate(RewNames):
            tf = LoadTraj(nickname + '-tf')
            first_reward = False
            ID = -1  # storage slot; becomes 0 right after the first rewarded bout
            for bout in tf.no:
                if len(bout) <= 1:
                    continue
                for step, (node, _frame) in enumerate(bout):
                    if node == RewardNode:
                        if first_reward:
                            TrajS[mouseID, ID, step] = node
                        first_reward = True
                        break
                    elif first_reward:
                        TrajS[mouseID, ID, step] = node
                if first_reward:
                    ID += 1
        _save(TrajS, 'rewMICE_post_firstR_every_Rvisit.p')

    elif traj_type == 'first_drink':
        # Extract trajectories of each rewarded mouse up until the first drink is obtained at the reward node
        nonRew_RVisits = np.zeros((N, TrajNo), dtype=int)

        for mouseID, nickname in enumerate(RewNames):
            tf = LoadTraj(nickname + '-tf')
            reward_found = False
            for boutID, reFrames in enumerate(tf.re):
                bout = tf.no[boutID]
                if len(reFrames) == 0:
                    # No reward in this bout: store it whole
                    TrajS[mouseID, boutID, 0:len(bout)] = bout[:, 0]
                    continue

                # Frames at which the mouse was at the reward node during this bout
                waterport_visit_frames = bout[bout[:, 0] == RewardNode, 1]
                first_reward_frame = reFrames[0][0]
                for step, (node, frame) in enumerate(bout):
                    # Every node is stored, including reward node visits that did not yield the drink
                    TrajS[mouseID, boutID, step] = node
                    if node == RewardNode and _drink_follows_visit(waterport_visit_frames, frame, first_reward_frame):
                        reward_found = True
                        break
                if reward_found:
                    break

            # Count reward node visits per stored bout; the last stored bout is assumed to end with the
            # rewarded visit, which is not counted as unrewarded
            for boutID in range(TrajNo):
                nonRew_RVisits[mouseID, boutID] = np.count_nonzero(TrajS[mouseID, boutID, :] == RewardNode)
                is_last_bout = boutID + 1 == TrajNo or TrajS[mouseID, boutID + 1, 0] == InvalidState
                if is_last_bout:
                    nonRew_RVisits[mouseID, boutID] -= 1
                    break
        _save(nonRew_RVisits, 'nonRew_RVisits.p')
        _save(TrajS, 'rewMICE_first_drink.p')

    else:
        raise ValueError('Unknown traj_type: %r' % (traj_type,))

In [4]:
# Build and pickle the bouts after the first rewarded bout (each truncated at its first reward node visit)
extract_rew_traj('post_firstR_every_Rvisit')

In [None]:
def get_SAnodemap(S,A):
    '''
    Build the state -> neighbouring-state map of the binary maze.

    Column 0 holds the parent (shallower) node and columns 1-2 the two child (deeper) nodes 2n+1 and 2n+2.
    Node 0's parent is HomeNode, leaf nodes (lv6_nodes) have no children, and HomeNode only leads to node 0.
    Missing neighbours are InvalidState.

    S: number of states; rows 0..S-2 are maze nodes and HomeNode must be a valid row (S > HomeNode)
    A: number of neighbour columns (3)

    Returns: SAnodemap, ndarray[(S, A), int]
             Also pickles SAnodemap to main_dir + 'nodemap.p'
    Raises: ValueError if S is too small to contain HomeNode
    '''
    if S <= HomeNode:
        raise ValueError('S=%d is too small: HomeNode=%d must be a valid state index' % (S, HomeNode))

    leaves = set(lv6_nodes)
    SAnodemap = np.full((S, A), InvalidState, dtype=int)
    for node in range(S - 1):
        # Parent in heap numbering; node 0 gives -1 (== InvalidState), i.e. leaving the maze to HomeNode
        parent = (node - 1) // 2
        SAnodemap[node, 0] = HomeNode if parent == InvalidState else parent

        if node not in leaves:
            # Deeper level nodes available from current node
            SAnodemap[node, 1] = node * 2 + 1
            SAnodemap[node, 2] = node * 2 + 2

    # From the maze entry only node 0 is reachable
    SAnodemap[HomeNode, 0] = InvalidState
    SAnodemap[HomeNode, 1] = 0
    SAnodemap[HomeNode, 2] = InvalidState

    with open(main_dir + 'nodemap.p', 'wb') as fh:
        pickle.dump(SAnodemap, fh)

    return SAnodemap

In [None]:
def plot_trajectory(state_hist_all, episode, save_dir=None, mouse=None, figtitle=None):
    '''
    Plot simulated trajectories on the maze layout, each coloured from entry to exit.

    state_hist_all: dictionary of node trajectories simulated by a model. Eg. state_hist_all{0:[0,1,3..], 1:[]..}
    episode: 'all', to plot all trajectories in state_hist_all
             int, to plot one trajectory, indexed by its position in state_hist_all's iteration order
    save_dir: optional path prefix; if given the figure is saved as save_dir + mouse + '.png'
    mouse: file name stem used when saving (must be a str when save_dir is given)
    figtitle: optional figure title

    Returns: None. Shows one maze figure with the plotted trajectories and, if anything was drawn,
             a color bar indicating nodes from entry to exit.
    '''
    ma = NewMaze(6)

    def jitter(n):
        # Random offsets in (-0.5, 0.5) so overlapping paths stay distinguishable
        return np.random.choice([-1, 1], n, p=[0.5, 0.5]) * np.random.rand(n) / 2

    def nodes2cell(state_hist_all):
        '''
        Convert node trajectories into maze cell paths and jittered x, y cell positions.

        Returns: state_hist_cell, list of cell paths (one per trajectory)
                 state_hist_xy, {position index: ndarray[(n_cells, 2), float]}
        '''
        state_hist_cell = []
        state_hist_xy = {}
        for epID, ep in enumerate(state_hist_all.values()):
            cells = [7]  # every path starts at cell 7
            for prev, node in zip(ep[:-1], ep[1:]):
                if node == HomeNode:
                    continue
                if node > prev:
                    # going to a deeper node: walk the new node's run
                    cells.extend(ma.ru[node])
                elif node < prev:
                    # going to a shallower node: walk the previous node's run backwards
                    reverse_path = list(reversed(ma.ru[prev])) + [ma.ru[node][-1]]
                    cells.extend(reverse_path[1:])
            # Check this trajectory's own last node (an empty trajectory has none)
            if len(ep) > 0 and ep[-1] == HomeNode:
                cells.extend(list(reversed(ma.ru[0]))[1:])  # cells from node 0 to maze exit
            state_hist_cell.append(cells)
            xy = np.zeros((len(cells), 2))
            xy[:, 0] = ma.xc[cells] + jitter(len(cells))
            xy[:, 1] = ma.yc[cells] + jitter(len(cells))
            state_hist_xy[epID] = xy
        return state_hist_cell, state_hist_xy

    def add_path(ax, xy):
        # Draw one trajectory as line segments coloured by normalised time (0 = entry, 1 = exit)
        t = np.linspace(0, 1, xy.shape[0])
        points = xy.reshape(-1, 1, 2)
        segs = np.concatenate([points[:-1], points[1:]], axis=1)
        lc = LineCollection(segs, cmap=plt.get_cmap('viridis'), linewidths=2)
        lc.set_array(t)
        return ax.add_collection(lc)

    state_hist_cell, state_hist_xy = nodes2cell(state_hist_all)

    # Draw the maze outline
    fig, ax = plt.subplots(figsize=(9, 9))
    plot(ma.wa[:, 0], ma.wa[:, 1], fmts=['k-'], equal=True, linewidth=2, yflip=True,
         xhide=True, yhide=True, axes=ax)
    re = [[-0.5, 0.5, 1, 1], [-0.5, 4.5, 1, 1], [-0.5, 8.5, 1, 1], [-0.5, 12.5, 1, 1],
          [2.5, 13.5, 1, 1], [6.5, 13.5, 1, 1], [10.5, 13.5, 1, 1],
          [13.5, 12.5, 1, 1], [13.5, 8.5, 1, 1], [13.5, 4.5, 1, 1], [13.5, 0.5, 1, 1],
          [10.5, -0.5, 1, 1], [6.5, -0.5, 1, 1], [2.5, -0.5, 1, 1],
          [6.5, 1.5, 1, 1], [6.5, 11.5, 1, 1], [10.5, 5.5, 1, 1], [10.5, 7.5, 1, 1],
          [5.5, 4.5, 1, 1], [5.5, 8.5, 1, 1], [7.5, 4.5, 1, 1], [7.5, 8.5, 1, 1], [2.5, 5.5, 1, 1], [2.5, 7.5, 1, 1],
          [-0.5, 2.5, 3, 1], [-0.5, 10.5, 3, 1], [11.5, 10.5, 3, 1], [11.5, 2.5, 3, 1], [5.5, 0.5, 3, 1], [5.5, 12.5, 3, 1],
          [7.5, 6.5, 7, 1]]
    for r in re:
        rect = patches.Rectangle((r[0], r[1]), r[2], r[3], linewidth=1, edgecolor='lightgray', facecolor='lightgray')
        ax.add_patch(rect)

    # Plot the trajectories; `lines` keeps the last collection for the colour bar
    if episode == 'all':
        lines = None
        for xy in state_hist_xy.values():
            lines = add_path(ax, xy)
    else:
        lines = add_path(ax, state_hist_xy[episode])

    if lines is not None:
        cax = fig.add_axes([1.05, 0.05, 0.05, 0.9])
        cbar = fig.colorbar(lines, cax=cax)
        cbar.set_ticks([0, 1])
        cbar.set_ticklabels(['Entry', 'Exit'])
        cbar.ax.tick_params(labelsize=18)
    fig.suptitle(figtitle)
    if save_dir:
        fig.savefig(save_dir + mouse + '.png')
    plt.show()

In [None]:
def visualize_pred(avg_V, state_hist_all, save_dir=None, params=None):
    '''
    Show a heatmap of state values and the predicted trajectories on the maze layout.

    avg_V: vector of S state values averaged across multiple runs of the model
           ndarray[(1,S), float] (any array with S elements)
    state_hist_all: dictionary of trajectories simulated by a model. Eg. state_hist_all{0:[0,1,3..], 1:[]..}
    save_dir: optional path prefix; if given, the trajectory figure is saved as
              save_dir + 'predicted_trajectories.png'
    params: optional (alpha, beta, gamma) shown in the heatmap title. If omitted, the module-level
            alpha, beta and gamma are used as before (NameError if they are not defined).

    Returns: None
    '''
    if params is None:
        # Backward compatible fallback: the original version read these notebook-level globals
        params = (alpha, beta, gamma)

    # Plotting state values as a 1 x S strip
    # NOTE(review): an 800-inch-tall figure is very large for a 1 x S strip - confirm intended
    fig, ax = plt.subplots(figsize=(30, 800))
    axhandle = ax.imshow(np.reshape(avg_V, (1, S)), cmap='YlGnBu')
    ax.invert_yaxis()
    ax.set_ylabel('V (s)')
    ticks = np.arange(0, 127, 5)
    ax.set_xticks(ticks)
    ax.set_xticklabels([str(val) for val in ticks])
    ax.set_xlabel('Nodes')
    ax.set_title('Average state values for TD(0) with alpha: %.2f, beta: %.2f and gamma: %.2f' % tuple(params))
    fig.colorbar(axhandle, fraction=0.005)

    print('Max state value ', np.max(avg_V))
    print('Min state value', np.min(avg_V))

    # Plot predicted trajectories (previously save_dir was silently ignored)
    if save_dir:
        plot_trajectory(state_hist_all, 'all', save_dir=save_dir, mouse='predicted_trajectories')
    else:
        plot_trajectory(state_hist_all, 'all', save_dir=None)

In [None]:
def TD0_first_Rvisit(sub_fits,fit_group,fit_group_data):
    '''
    Predicts trajectories with first reward visit TD(0) using the parameters fitted for each rewarded mouse.

    HomeNode and RewardNode are terminal states; entering RewardNode gives a reward of RewardNodeMag.
    Predicted trajectories can't be longer than their corresponding bout in the real mouse trajectories:
    over-long episodes are discarded, their value updates rolled back and the episode retried
    (at most episode_cap attempts per bout).
    Set MatchEndNode = True to set an additional constraint on generating predicted trajectories with the
    same end node as the real counterpart.

    Note: use this to generate predicted trajectories from a range of parameters for parameter recovery
          or to generate predicted trajectories from fitted parameters

    sub_fits: dictionary of fitted parameters for each rewarded mouse; the first three entries of each value
              are alpha, beta, gamma (extra entries such as the log likelihood are ignored).
              sub_fits{0:[alpha_fit, beta_fit, gamma_fit, LL], 1:[]...., 9:[]}
    fit_group: 'Rew' (the only supported group), specify mouse group to load real trajectory data for
    fit_group_data: str, file path for the pickled real trajectory array (mouse x bout x step)

    Returns: state_hist_AllMice, {mouseID: {bout index: [nodes...]}} of simulated trajectories
             int success, 1 normally; 0 if some episode could not be generated within episode_cap attempts
    Raises: ValueError if fit_group is not 'Rew'
    '''
    if fit_group != 'Rew':
        raise ValueError("Only fit_group='Rew' is supported, got %r" % (fit_group,))

    # Set environment parameters
    n_mice = 10
    state_hist_AllMice = {}
    episode_cap = 500
    value_cap = 1e5
    success = 1
    MatchEndNode = False

    with open(fit_group_data, 'rb') as fh:
        TrajS = pickle.load(fh).astype(int)

    def choose_action(V, s, beta):
        # Softmax over the values of the states reachable from s; moves to InvalidState get probability 0.
        # Subtracting the max logit keeps np.exp finite without changing the probabilities.
        neighbours = nodemap[s, :]
        valid = neighbours != InvalidState
        logits = np.full(len(neighbours), -np.inf)
        logits[valid] = beta * V[neighbours[valid]]
        logits -= np.max(logits[valid])
        prob = np.exp(logits)
        prob /= np.sum(prob)
        try:
            return np.random.choice(len(neighbours), 1, p=prob)[0]
        except ValueError:
            print('Error with probabilities. prob: ', prob, ' nodes: ', neighbours, ' state-values: ', V[neighbours])
            raise

    for mouseID in range(n_mice):
        # Set model parameters
        alpha, beta, gamma = sub_fits[mouseID][:3]
        bout_ids = np.where(TrajS[mouseID, :, 0] != InvalidState)[0]

        V = np.zeros(S)  # state values; HomeNode and RewardNode stay 0 because they are never updated
        state_hist_mouse = {}

        for bout in bout_ids:
            # Real trajectory of this bout (everything before the InvalidState padding)
            traj = TrajS[mouseID, bout]
            padding = np.where(traj == InvalidState)[0]
            valid_traj = traj[:padding[0]] if len(padding) else traj

            # Back-up a copy of state-values to use in case the episode has to be discarded
            V_backup = np.copy(V)
            valid_episode = False
            episode_attempt = 0

            while not valid_episode and episode_attempt < episode_cap:
                s = StartNode
                state_hist = []
                abort_episode = False

                while s != HomeNode and s != RewardNode:
                    state_hist.append(s)
                    a = choose_action(V, s, beta)

                    # Take action, observe reward and next state
                    sprime = int(nodemap[s, a])
                    R = RewardNodeMag if sprime == RewardNode else 0

                    # TD(0) update of the previous state's value
                    V[s] += alpha * (R + gamma * V[sprime] - V[s])
                    if np.isnan(V[s]):
                        print('Warning invalid state-value: ', s, sprime, V[s], V[sprime], alpha, beta, gamma, R)
                    elif np.isinf(V[s]):
                        print('Warning infinite state-value: ', V)
                    elif V[s] > value_cap:
                        V[s] = value_cap

                    s = sprime

                    # Abort if the trajectory is already longer than the real bout
                    if len(state_hist) > len(valid_traj):
                        abort_episode = True
                        break
                state_hist.append(s)

                if not abort_episode and (not MatchEndNode or s == valid_traj[-1]):
                    state_hist_mouse[int(bout)] = state_hist
                    valid_episode = True
                else:
                    # Discard: undo this attempt's value updates and try again
                    V = np.copy(V_backup)
                    episode_attempt += 1

            if not valid_episode:
                print('Failed to generate episodes for mouse ', mouseID, ' with parameter set: ', alpha, beta, gamma)
                success = 0
                break
        state_hist_AllMice[mouseID] = state_hist_mouse

    return state_hist_AllMice, success

In [None]:
def TD0_first_drink(sub_fits,fit_group):
    '''
    Generating simulated data from the TD(0) model for STAN fitting ('first drink' setting).

    Episodes start at node 0 and end at HomeNode, or at RewardNode once it pays out. In each episode the
    reward is given on the k-th reward node visit, where k is the number of reward node visits in the
    matching real bout. An episode is accepted only if it ends at the same node as the real bout, has
    more than 2 states and is no longer than the real bout. Otherwise its value updates are rolled back
    and it is retried, up to episode_cap attempts.

    sub_fits: {mouseID: [alpha, beta, gamma, ...]}; extra entries are ignored
    fit_group: 'Rew' (the only supported group)

    Returns: state_hist_AllMice, {mouseID: {bout index n: [nodes...]}}
             valid_bouts, list with the number of simulated bouts per mouse
             success, 1 normally; 0 if some episode could not be generated within episode_cap attempts
    Raises: ValueError if fit_group is not 'Rew'
    '''
    if fit_group != 'Rew':
        raise ValueError("Only fit_group='Rew' is supported, got %r" % (fit_group,))

    # Set environment parameters
    S = 127  # maze nodes 0-126; HomeNode (127) is the extra state, hence S + 1 below
    A = 3
    n_mice = 10
    # rows index the current state, columns index 3 available neighboring states.
    # get_SAnodemap needs a row for HomeNode, so it is built with S + 1 states (S alone raises IndexError).
    nodemap = get_SAnodemap(S + 1, A)
    state_hist_AllMice = {}
    valid_bouts = []
    episode_cap = 500
    value_cap = 1e5
    success = 1

    with open(real_traj_dir + 'rewMICE.p', 'rb') as fh:
        TrajS = pickle.load(fh).astype(int)

    def action_probs(V, s, beta):
        # Leaf nodes can only move back to their parent; otherwise softmax over neighbouring state values.
        # Subtracting the max logit avoids exp overflow without changing the probabilities.
        if s in lv6_nodes:
            return np.array([1.0, 0.0, 0.0])
        logits = beta * V[nodemap[s, :]]
        logits = logits - np.max(logits)
        p = np.exp(logits)
        return p / np.sum(p)

    for mouseID in range(n_mice):
        # Set model parameters
        alpha, beta, gamma = sub_fits[mouseID][:3]  # learning rate, softmax inverse temperature, discount

        # number of episodes to train over which are real bouts beginning at node 0
        # and exploring deeper into the maze, which is > than a trajectory length of 2 (node 0 -> node 127)
        valid_boutID = np.where(TrajS[mouseID, :, 2] != InvalidState)[0]
        n_bouts = len(valid_boutID)
        valid_bouts.append(n_bouts)

        # Initialize model parameters
        V = np.random.rand(S + 1)  # state values
        V[HomeNode] = 0
        V[RewardNode] = 0
        state_hist_mouse = {}

        for n in range(n_bouts):
            # Real trajectory of this bout (everything before the InvalidState padding)
            traj = TrajS[mouseID, valid_boutID[n]]
            total_R_visits = np.count_nonzero(traj == RewardNode)
            padding = np.where(traj == InvalidState)[0]
            valid_traj = traj[:padding[0]] if len(padding) else traj

            # Back-up a copy of state-values to use in case the episode has to be discarded
            V_backup = np.copy(V)
            valid_episode = False
            episode_attempt = 0

            while not valid_episode and episode_attempt < episode_cap:
                s = 0
                state_hist = []
                R = 0
                R_visits = 0  # counted per attempt so rejected attempts and earlier bouts do not leak in

                while s != HomeNode and (s != RewardNode or R == 0):
                    state_hist.append(s)

                    aprob = action_probs(V, s, beta)
                    if np.any(np.isnan(aprob)):
                        print('Invalid action probabilities ', aprob, s)
                        print(alpha, beta, gamma, mouseID, n)
                    if np.sum(aprob) < 0.999:
                        print('Invalid action probabilities, failed summing to 1: ', aprob, s)
                    a = np.random.choice([0, 1, 2], 1, p=aprob)[0]

                    # Take action, observe reward and next state
                    sprime = int(nodemap[s, a])
                    if sprime == RewardNode:
                        R_visits += 1
                        R = 1 if R_visits >= total_R_visits else 0
                    else:
                        R = 0

                    # NOTE(review): this is not the standard TD(0) rule V[s] += alpha*(R + gamma*V[s'] - V[s])
                    # used in TD0_first_Rvisit; kept as-is in case it mirrors the Stan model - confirm.
                    V[s] += R + gamma * V[sprime] - alpha * V[s]
                    if np.isnan(V[s]):
                        print('Warning invalid state-value: ', s, sprime, V[s], V[sprime], alpha, beta, gamma, R)
                    elif np.isinf(V[s]):
                        print('Warning infinite state-value: ', V)
                    elif V[s] > value_cap:
                        V[s] = value_cap

                    s = sprime

                    # Stop once the trajectory is longer than the real bout
                    if len(state_hist) > len(valid_traj):
                        break
                state_hist.append(s)

                if s == valid_traj[-1] and 2 < len(state_hist) <= len(valid_traj):
                    state_hist_mouse[n] = state_hist
                    valid_episode = True
                else:
                    # Reject (wrong end node or wrong length): undo value updates and retry
                    V = np.copy(V_backup)
                    episode_attempt += 1

            if not valid_episode:
                print('Failed to generate episodes for mouse ', mouseID, ' with parameter set: ', alpha, beta, gamma)
                success = 0
                break
        state_hist_AllMice[mouseID] = state_hist_mouse

    return state_hist_AllMice, valid_bouts, success

In [None]:
def TDlambda_Rvisit(sub_fits, fit_group, fit_group_data, MatchEndNode=False, init='ZERO'):
    '''
    Predict trajectories using a TD(lambda) model. Home and reward node are terminal states.

    For every mouse one episode is simulated per bout of the real trajectory. An episode is
    rejected when the number of states it has visited exceeds the length of the real bout.
    On rejection, state-values and eligibility traces are rolled back to their values at
    the start of the bout and the episode is re-simulated. With MatchEndNode=True an episode
    is also rejected unless it ends on the same terminal node as the real bout. After
    episode_cap rejections for one bout, the remaining bouts of that mouse are skipped and
    success is set to 0.

    Parameters
    ----------
    sub_fits : dict
        mouseID -> (alpha, beta, gamma, lamda) for mice 0..9.
    fit_group : str
        Group of mice being simulated. Only 'Rew' is supported.
    fit_group_data : str
        Path of a pickled trajectory array indexed [mouse, bout, step] and padded with
        InvalidState.
    MatchEndNode : bool, optional
        Also require each simulated episode to end on the real bout's terminal node.
    init : str, optional
        State-value initialisation. Only 'ZERO' is supported.

    Returns
    -------
    state_hist_AllMice : dict
        mouseID -> {bout index: list of visited nodes, including the final terminal node}.
    success : int
        1 if every episode was generated, 0 if some bout exhausted episode_cap attempts.

    Raises
    ------
    ValueError
        For an unsupported fit_group or init, or when the softmax probabilities are invalid.
    '''
    N = 10  # number of mice simulated (mouseIDs 0..N-1)
    state_hist_AllMice = {}
    episode_cap = 500  # maximum simulation attempts per bout
    value_cap = 1e5  # upper bound applied to the state-value that was just updated
    success = 1

    if fit_group != 'Rew':
        raise ValueError("Unsupported fit_group %r: only 'Rew' is implemented" % (fit_group,))
    if init != 'ZERO':
        raise ValueError("Unsupported init %r: only 'ZERO' is implemented" % (init,))

    # NOTE: pickle can execute arbitrary code on load; only use trusted, locally generated files
    with open(fit_group_data, 'rb') as f:
        TrajS = pickle.load(f).astype(int)

    for mouseID in range(N):
        # Set model parameters
        alpha, beta, gamma, lamda = sub_fits[mouseID]
        # Number of bouts, counted as the bouts whose first step is valid
        TrajNo = len(np.where(TrajS[mouseID, :, 0] != InvalidState)[0])

        # Initialize model parameters
        V = np.zeros(S)  # state values
        V[HomeNode] = 0  # setting state-value of maze entry to 0
        V[RewardNode] = 0  # setting state-value of reward port to 0
        e = np.zeros(S)  # eligibility trace vector for all states
        state_hist_mouse = {}

        for bout in range(TrajNo):
            # Real trajectory of this bout: every step before the first InvalidState padding
            invalid_steps = np.where(TrajS[mouseID, bout] == InvalidState)[0]
            end = invalid_steps[0] if len(invalid_steps) else TrajS.shape[2]
            valid_traj = TrajS[mouseID, bout, 0:end]

            # Snapshot to roll back to if an episode is rejected
            V_backup = np.copy(V)
            e_backup = np.copy(e)

            valid_episode = False
            episode_attempt = 0
            while not valid_episode and episode_attempt < episode_cap:
                s = StartNode
                state_hist = []
                abort_episode = False

                while s != HomeNode and s != RewardNode:
                    # Record current state
                    state_hist.append(s)

                    # Softmax over the values of the neighbouring nodes; invalid neighbours get weight 0
                    betaV = [0 if node == InvalidState else np.exp(beta * V[node]) for node in nodemap[s, :]]
                    prob = betaV / np.sum(betaV)
                    try:
                        a = np.random.choice([0, 1, 2], 1, p=prob)[0]
                    except ValueError:
                        # Continuing would step with a stale or undefined action, so report and re-raise
                        print('Error with probabilities. betaV: ', betaV, ' nodes: ', nodemap[s, :], ' state-values: ', V[nodemap[s, :]])
                        raise

                    # Take action, observe reward and next state
                    sprime = nodemap[s, a]
                    R = RewardNodeMag if sprime == RewardNode else 0

                    # TD error for the current state; accumulate its eligibility
                    td_error = R + gamma * V[sprime] - V[s]
                    e[s] += 1

                    # Every state moves in proportion to its eligibility, then all traces decay
                    # by gamma*lambda (vectorised form of a per-state loop, same arithmetic)
                    V += alpha * td_error * e
                    e *= gamma * lamda

                    if np.isnan(V[s]):
                        print('Warning invalid state-value: ', s, sprime, V[s], V[sprime], sub_fits)
                    elif np.isinf(V[s]):
                        print('Warning infinite state-value: ', V)
                    elif V[s] > value_cap:
                        V[s] = value_cap

                    # Update future state to current state
                    s = sprime

                    # Reject episodes that visit more states than the real bout has steps
                    if len(state_hist) > len(valid_traj):
                        abort_episode = True
                        break
                state_hist.append(s)

                # Optional extra requirement: end on the same terminal node as the real bout
                if not abort_episode and MatchEndNode and s != valid_traj[-1]:
                    abort_episode = True

                if abort_episode:
                    # Discard this episode's learning and try again
                    V = np.copy(V_backup)
                    e = np.copy(e_backup)
                    episode_attempt += 1
                else:
                    # Store every accepted episode, keyed by bout index
                    state_hist_mouse[bout] = state_hist
                    valid_episode = True

            if episode_attempt >= episode_cap:
                print('Failed to generate episodes for mouse ', mouseID, ' with parameter set: ', alpha, beta, gamma, lamda)
                success = 0
                break
        state_hist_AllMice[mouseID] = state_hist_mouse

    return state_hist_AllMice, success

### Visualize trajectory lengths over time

In [None]:
# Plotting trajectory lengths vs bouts for all mice
fig, axs = plt.subplots(2,5, constrained_layout=True, figsize=(20,15))
for mouseID, nickname in enumerate(RewNames):
    i, j = divmod(mouseID, 5)  # position of this mouse in the 2 x 5 grid
    tf = LoadTraj(nickname+'-tf')
    traj_lengths = [len(bout) for bout in tf.no]
    axs[i,j].bar(np.arange(1,len(tf.no)+1), traj_lengths)
    axs[i,j].set_title(nickname)
    axs[1,j].set_xlabel('Bouts')
    axs[i,0].set_ylabel('No. of steps')
plt.savefig('figures/all_trajectories.png')

# First 30 bouts of each mouse, marking every reward-port (node 116) visit and the first drink.
# Rows of a bout in tf.no are used as (node, frame); tf.re[bout][0][0] is used as the frame
# of the first drink in that bout.
fig2, axs2 = plt.subplots(2,5, constrained_layout=True, figsize=(20,15))
for mouseID, nickname in enumerate(RewNames):
    i, j = divmod(mouseID, 5)
    tf = LoadTraj(nickname+'-tf')
    first_bouts = tf.no[0:30]
    traj_lengths2 = [len(bout) for bout in first_bouts]
    rewardID = []
    visitsteps = []
    # Reset per mouse so a mouse without a drink never reuses another mouse's marker
    drinkID = None
    drinkstep = None
    for boutID, bout in enumerate(first_bouts):
        # Compare the node column only; frame numbers equal to 116 must not count as visits
        visits = np.where(bout[:,0] == 116)[0]
        rewardID.extend([boutID]*len(visits))
        visitsteps.extend(visits)
        if drinkID is None and len(tf.re[boutID]) > 0:
            drinkID = boutID
            drinkframe = tf.re[boutID][0][0]
            # First reward-port visit of this bout whose frame precedes the drink frame
            for attempt in visits:
                if drinkframe > bout[attempt][1]:
                    drinkstep = attempt
                    break

    axs2[i,j].bar(np.arange(len(first_bouts)), traj_lengths2, color='#FFFAC8', edgecolor='#FFD500')
    axs2[i,j].plot(rewardID,visitsteps,'b*',label='reward port visit')
    if drinkstep is not None:
        axs2[i,j].plot(drinkID,drinkstep,'k*',label='first drink')
    axs2[i,j].set_title(nickname)
    axs2[1,j].set_xlabel('Bouts')
    axs2[i,0].set_ylabel('No. of steps')
axs2[0,0].legend()
plt.savefig('figures/first30_trajectories.png')

### Visualize Quadrant Behavior

In [None]:
# Sorting which quarters of the maze a mouse visits during each trajectory (mouse C1, bouts 4-20)
tf = LoadTraj('C1'+'-tf')

# Map every node to its quadrant once; setdefault keeps the first matching quadrant,
# mirroring an if/elif chain over quad1..quad4. Nodes in no quadrant are skipped.
node_quadrant = {}
for quad_label, quad_nodes in enumerate([quad1, quad2, quad3, quad4], start=1):
    for node in quad_nodes:
        node_quadrant.setdefault(node, quad_label)

for bout in np.arange(3,20):
    # Rebuilt per bout so each figure only shows the bout named in its title
    quad_visit = [node_quadrant[val[0]] for val in tf.no[bout] if val[0] in node_quadrant]
    plt.figure(figsize=(20,15))
    plt.plot(np.arange(1,len(quad_visit)+1),quad_visit,'*')
    plt.title('Bout %i' %(bout+1))

### TD(0) <a id='oneStepTD'></a>
Online TD-control algorithm that estimates state values, V(s)
 - States: 128 maze nodes (including home node)
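 - Terminal states: maze entry (node 127) and reward port (node 116). In the toy simulation below, only the maze entry ends an episode; the reward port is not terminal and pays out again only after a timeout.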
 - Rewards: 0 on all states except for 1 on the water port
 
 Pseudocode
- Softmax action selection policy: $\pi(a_j \mid s_i) = \frac{e^{\beta V(s_{ij})}}{\sum_{k=1}^{3} e^{\beta V(s_{ik})}}$, where $s_{ij}$ ($j = 1, 2, 3$) are the 3 neighboring nodes of $s_i$ and action $a_j$ moves to $s_{ij}$
- State-value update: $V(s) \leftarrow V(s) + \alpha*(R + \gamma*V(s') - V(s))$

 Re-parametrized version, substituting $\beta \rightarrow \alpha\beta$ and $V(s) \rightarrow \frac{V(s)}{\alpha}$
- original state-value update: $V(s) \leftarrow (1-\alpha)V(s) + \alpha(R + \gamma V(s'))$
- new state-value update: $V(s) \leftarrow (1-\alpha)V(s) + R + \alpha \gamma V(s')$

In [None]:
# Define TD(0) model
'''
Toy TD(0) model to simulate trajectories with

Returns: state_hist_all, dictionary of trajectories most recently simulated by the model. 
         If avg_count > 1, state_hist_all represents the latest simulation results
         Type: state_hist_all{0:[0,1,3..], 1:[]..}
         
         avg_V, average state-values resulting from simulation.
         Type: ndarray[(S,), float]
'''

# Set environment parameters
nodemap = get_SAnodemap(S,A)  # rows index the current state, columns index 3 available neighboring states

# Set model parameters
alpha = 0.01  # learning rate
gamma = 1
beta = 10  # softmax exploration - exploitation
N = 5  # number of episodes to train over
speed = 10  # mice speed in nodes per sec
timeout = 10  # units of seconds that the reward port times out
# Each node transition takes 1/speed seconds, so the timeout lasts timeout*speed transitions.
# Counting whole steps avoids float drift from repeatedly adding 1/speed.
timeout_steps = int(round(timeout * speed))

# Simulation settings
avg_count = 50  # Number of times to repeat the simulation. Only state-values will be averaged, avg_V
avg_V = np.zeros(S)

for count in np.arange(avg_count):
    # Initialize model parameters
    V = np.zeros(S)  # state values
    V[HomeNode] = 0  # setting state-value of maze entry to 0
    V[RewardNode] = 0  # setting state-value of reward port to 0
    state_hist_all = {}
    total_reward = 0
    t = timeout_steps  # node transitions since the last reward; the port starts available

    for n in np.arange(N):
        # Every episode starts at node 0
        s = 0
        state_hist = []

        # Only the maze entry ends an episode; the reward port is not terminal in this toy model
        while s != HomeNode:
            state_hist.append(s)

            # Softmax weights of the neighbouring nodes; invalid (negative) entries get weight 0
            betaV = [np.exp(beta*V[int(val)]) if val >= 0 else 0 for val in nodemap[s,:]]
            if s in lv6_nodes:
                a = 0  # when s is an end node, choose the action towards the lower-level node
            else:
                total = np.nansum(betaV)
                aprob = []
                for weight in betaV:
                    prob = weight/total
                    # A NaN ratio (e.g. inf/inf after exp overflow) is replaced by 1
                    aprob.append(1 if np.isnan(prob) else prob)
                a = np.random.choice([0,1,2],1,p=aprob)[0]

            # Observe reward and next state based on selected action
            sprime = int(nodemap[s,a])
            if sprime == RewardNode and t >= timeout_steps:
                R = RewardNodeMag  # reward for entering the reward port once its timeout has passed
                total_reward += 1
                t = 0  # Reset timer
            else:
                R = 0
                t += 1  # one more node transition (1/speed seconds) since the last reward

            # TD(0) update of the current state's value
            V[s] += alpha * (R + gamma*V[sprime] - V[s])

            # Shift state values for the next time step
            s = sprime
        state_hist.append(s)
        state_hist_all[n] = state_hist

    avg_V += V  # in-place add copies the values; no explicit copy needed
avg_V /= avg_count

print('alpha: ', alpha, ' beta: ', beta, ' gamma: ', gamma)
visualize_pred(avg_V, state_hist_all)

### TD($\lambda$)

 - States: 128 maze nodes (including home node)
 - Terminal states: maze entry (node 127) and reward port (node 116)
 - Rewards: 0 on all states except for 1 on the water port
 
 Pseudocode
- Softmax action selection policy: $\pi(a_j \mid s_i) = \frac{e^{\beta V(s_{ij})}}{\sum_{k=1}^{3} e^{\beta V(s_{ik})}}$, where $s_{ij}$ ($j = 1, 2, 3$) are the 3 neighboring nodes of $s_i$ and action $a_j$ moves to $s_{ij}$
- Temporal difference error: $\delta = R + \gamma*V(s') - V(s)$
- Eligibility trace of the current state $s$: $e(s) \leftarrow e(s) + 1$
- State-value update: $V(s) \leftarrow V(s) + \alpha*(\delta)*e(s)$        [all state-values are updated at each step]
- Eligibility trace decay: $e(s) \leftarrow \gamma*\lambda*e(s)$           [eligibility traces of all states are decayed at each step]

In [None]:
# Define TD(lambda) model
'''
Toy TD(lambda) model to predict trajectories

Returns: state_hist_all, dictionary of trajectories most recently simulated by the model. 
         If avg_count > 1, state_hist_all represents the latest simulation results
         Type: state_hist_all{0:[0,1,3..], 1:[]..}
         
         avg_V, average state-values resulting from simulation.
         Type: ndarray[(S,), float]
'''

# Set environment parameters
nodemap = get_SAnodemap(S,A)  # rows index the current state, columns index 3 available neighboring states

# Set model parameters
alpha = 0.1   # learning rate, 0 < alpha < 1
beta = 10      # softmax exploration - exploitation
gamma = 1      # degree of discounting future state values, 0 < gamma <= 1
lamda = 0.5    # extent of past states to update, 0 < lambda < 1
N = 10          # number of episodes to train over

# Simulation settings
avg_count = 1  # Number of times to repeat the simulation. Only state-values will be averaged, avg_V
avg_V = np.zeros(S)

for count in np.arange(avg_count):
    # Initialize model parameters
    V = np.zeros(S)    # state values
    V[HomeNode] = 0    # setting state-values of maze entry to 0
    V[RewardNode] = 0  # setting state-values of reward port to 0
    e = np.zeros(S)    # eligibility trace vector for all states
    state_hist_all = {}

    for n in np.arange(N):
        # Every episode starts at node 0
        s = 0
        state_hist = []
        # NOTE(review): e is not reset here, so traces left at the end of one episode still
        # receive updates during the next; episodic TD(lambda) usually zeroes e per episode - confirm

        # Begin episode; home and reward port are terminal
        while s != HomeNode and s != RewardNode:
            state_hist.append(s)

            # Softmax over the values of the neighbouring nodes; invalid neighbours get weight 0
            betaV = [0 if node == InvalidState else np.exp(beta*V[node]) for node in nodemap[s,:]]
            prob = betaV/np.sum(betaV)
            try:
                a = np.random.choice([0,1,2],1,p=prob)[0]
            except ValueError:
                print('Error with aprob')
                print('Current state: ', s, 'Potential future states: ', nodemap[s,:], ' prob: ', prob)
                raise  # continuing would step with a stale or undefined action

            # Observe reward and next state based on selected action
            sprime = nodemap[s,a]
            R = RewardNodeMag if sprime == RewardNode else 0

            # TD error for the current state; accumulate its eligibility
            td_error = R + gamma*V[sprime] - V[s]
            e[s] += 1

            # Every state moves in proportion to its eligibility, then all traces decay
            # (vectorised form of a per-state loop, same arithmetic)
            V += alpha * td_error * e
            e *= gamma * lamda

            # Update future state to current state
            s = sprime

        state_hist.append(s)
        state_hist_all[n] = state_hist

    avg_V += V  # in-place add copies the values; no explicit copy needed
avg_V /= avg_count

print('alpha: ', alpha, ' beta: ', beta, ' gamma: ', gamma, 'lambda: ', lamda)
visualize_pred(avg_V, state_hist_all)

## Simulating Fake Data to test Parameter Recovery for a Model

In [None]:
# Generating simulated data for a range of parameter values and saving these trajectories in the directory 'pred_traj_dir'

# Set variables
file_name = 'set5_a0.05_b2_g0.7.p'
sim_data = False
save_traj = False
N = 10
TrajSize = 3000
TrajNo = 20
TD0_type = 'first_visit'  # Set to 'first_visit' or 'multiple_visit' to run a specific TD(0) model


def _traj_dict_to_array(state_hist_AllMice, N, TrajNo, TrajSize):
    '''
    Pack simulated trajectories {mouseID: {boutID: [nodes]}} into an int array of shape
    (N, TrajNo, TrajSize), padded with InvalidState.
    '''
    simTrajS = np.ones((N,TrajNo,TrajSize), dtype=int) * InvalidState
    for mouseID in np.arange(N):
        for boutID in np.arange(len(state_hist_AllMice[mouseID])):
            traj = state_hist_AllMice[mouseID][boutID]
            simTrajS[mouseID,boutID,0:len(traj)] = traj
    return simTrajS


if sim_data:
    # Pick the model variant and the matching output sub-directory once
    if TD0_type == 'first_visit':
        simulate_TD0 = TD0_first_visit
        out_subdir = 'full_search_first_visit/'
    elif TD0_type == 'multiple_visit':
        simulate_TD0 = TD0_multiple_visit
        out_subdir = 'full_search/'
    else:
        raise ValueError("TD0_type must be 'first_visit' or 'multiple_visit', got %r" % (TD0_type,))

    # Simulating data for parameter recovery
    da = 0.1
    db = 0.1
    dg = 0.1
    alpha_range = np.arange(0,1+da,da)
    beta_range = np.arange(0,1+db,db)
    gamma_range = np.arange(0,1+dg,dg)
    true_param = {}
    set_counter = 0
    for gamma in gamma_range:
        for beta in beta_range:
            for alpha in alpha_range:
                print('Now simulating: ', set_counter, alpha, beta, gamma)
                true_param[set_counter] = [alpha, beta, gamma]

                # Same parameter set for every simulated mouse
                sub_fits = {}
                for mouseID in np.arange(10):
                    sub_fits[mouseID] = [alpha,beta,gamma]

                state_hist_AllMice, valid_bouts, success = simulate_TD0(sub_fits,'Rew')
                simTrajS = _traj_dict_to_array(state_hist_AllMice, N, TrajNo, TrajSize)
                if success == 0:
                    print('Not saving set ', set_counter)
                elif success == 1:
                    with open(pred_traj_dir+out_subdir+'set'+str(set_counter)+'.p','wb') as f:
                        pickle.dump(simTrajS, f)

                # Increment counter
                set_counter += 1

    # Save true parameter sets next to the simulated sets of the same model variant
    with open(pred_traj_dir+out_subdir+'true_param.p','wb') as f:
        pickle.dump(true_param, f)
        
    # Convert the last simulated trajectory from a dictionary to an array and save
    if save_traj:
        simTrajS = _traj_dict_to_array(state_hist_AllMice, N, TrajNo, TrajSize)
        with open(pred_traj_dir+file_name,'wb') as f:
            pickle.dump(simTrajS, f)

## Testing Simulated Data with the Model (running model forward on simulated choices)

In [None]:
# Running the TD(0) model forward on any set of simulated fake data to see what state-values evolve.
'''
Input: needs simulated trajectories of the form, ndarray[(N, TrajNo, TrajSize), int]

Usage: checking how the model might respond to certain sets of trajectories
       can also test out on real trajectories
'''

# Set input
simdata_path = 'test_search_cl_nonrp_abgtest_smalldata/set0.p'

# Set environment parameters
stan_simdata_dir = main_dir+'stan/pre_reward_traj/'
RT = 1
nodemap = get_SAnodemap(S,A)  # rows index the current state, columns index 3 available neighboring states
V = np.ones(S)
V[HomeNode] = 0  # setting state-value of maze entry to 0
V[RewardNode] = 0  # setting state-value of reward port to 0

# Load a simulated trajectory
with open(stan_simdata_dir+simdata_path,'rb') as f:
    TrajS = pickle.load(f)

# Create action sequence from trajectories: TrajA[n, b, step] is the 1-based nodemap column
# leading from TrajS[n, b, step] to the next node. It stays 0 when there is no next node or the
# current node is HomeNode.
N = np.shape(TrajS)[0]          # setting the number of rewarded mice
B = np.shape(TrajS)[1]          # setting the maximum number of bouts until the first reward was sampled
BL = np.shape(TrajS)[2]
TrajA = np.zeros(np.shape(TrajS), dtype=int)
for n in np.arange(N):
    for b in np.arange(B):
        for bl in np.arange(BL - 1):
            if TrajS[n, b, bl + 1] != InvalidState and TrajS[n, b, bl] != HomeNode:
                TrajA[n, b, bl] = np.where(nodemap[TrajS[n, b, bl], :] == TrajS[n, b, bl + 1])[0][0] + 1

# Set model parameters corresponding to simulated trajectory
alpha = 0.2
beta = 2
gamma = 0.2
mouseID = 0
boutsize = np.arange(len(TrajS[mouseID]))

for boutID in boutsize:
    for stepID, node in enumerate(TrajS[mouseID,boutID,:]):
        if node == InvalidState or node == HomeNode or node == RewardNode:
            continue
        s = node  # Retrieve current state

        a = TrajA[mouseID,boutID,stepID]  # Retrieve action taken at state s by the selected mouse
        if a == 0:
            # No recorded transition out of s; nodemap[s, a-1] would silently use the last column
            continue

        # Observe reward and next state based on selected action
        sprime = int(nodemap[s,a-1])
        R = 1 if sprime == RewardNode else 0  # reward of 1 for entering the reward port

        # Update value of current state
        V[s] += alpha * (R + gamma*V[sprime] - V[s])

fig, ax = plt.subplots(figsize=(30,800))
axhandle = ax.imshow(np.transpose(np.reshape(V,(S,1))),cmap='RdPu')
ax.invert_yaxis()
ax.set_ylabel('V (s)')
ax.set_xticks(np.arange(0,127,5))
ax.set_xticklabels([str(val) for val in np.arange(0,127,5)])
ax.set_xlabel('Nodes')
ax.set_title('Current state values with alpha: %.2f, beta: %.2f and gamma: %.2f' %(alpha,beta,gamma))
fig.colorbar(axhandle,fraction=0.005)

## Evaluating Model Fits

In [None]:
# Simulate data with fitted parameters
'''
Needs: set of best fit parameters in the form, sub_fits{0:[alpha_fit, beta_fit, gamma_fit, LL], 1:[]...., 9:[]}

Usage: After getting fitted parameters from model-fitting (like in STAN), use them to run model simulations
       and plot resulting trajectories with the fit results
'''

with open(stan_results_dir+'TD0_cl_nonrp_real/sub_fits.p','rb') as f:
    sub_fits = pickle.load(f)
fit_group = 'Rew'
fit_group_data = stan_data_dir+'real_traj/rewMICE_first_Rvisit.p'
state_hist_AllMice,valid_bouts,_ = TD0_first_Rvisit(sub_fits,fit_group,fit_group_data)

# Plotting to compare simulated and actual trajectory lengths
sim_lengths_all = {}
real_lengths_all = {}
# Log-likelihood of a chooser picking uniformly among 3 options, summed over all of a mouse's
# valid bouts (presumably matching the scope of the fitted LL it is compared with - confirm)
rand_LL = {}
with open(fit_group_data,'rb') as f:
    TrajS = pickle.load(f)
end_nodes = set(lv6_nodes)  # end nodes are excluded from the count of random choices
plt.figure()
for mouseID in np.arange(10):
    # Bouts with at least 3 recorded steps
    valid_boutID = np.where(TrajS[mouseID,:,2]!=InvalidState)[0]
    real_lengths = []
    sim_lengths = []
    rand_LL[mouseID] = 0.0
    for boutID, bout in enumerate(valid_boutID):
        end = np.where(TrajS[mouseID,bout]==InvalidState)[0][0]
        valid_traj = TrajS[mouseID,bout,0:end]
        n_choices = sum(1 for val in valid_traj if val not in end_nodes)
        # Accumulate across bouts, using the exact probability 1/3 of a uniform 3-way choice
        rand_LL[mouseID] += np.log(1/3) * n_choices
        real_lengths.append(len(valid_traj))
        sim_lengths.append(len(state_hist_AllMice[mouseID][boutID]))
    real_lengths_all[mouseID] = real_lengths
    sim_lengths_all[mouseID] = sim_lengths
        
    plt.plot(np.arange(len(real_lengths_all[mouseID])), real_lengths_all[mouseID], 'r*', label='real')
    plt.plot(np.arange(len(real_lengths_all[mouseID])), sim_lengths_all[mouseID], 'g*', label='sim')
    if mouseID == 0:
        plt.legend()
plt.xlabel('Number of bouts till first reward')
plt.ylabel('Number of decisions/steps in bout')
plt.title('Number of steps in simulated trajectories')

In [None]:
# Save plots generated from model fit results
import os

# Relative to the repository root, like the other figure paths in this notebook
# (replaces an absolute path that only existed on one machine)
save_dir = 'figures/TD0_firstreward/'
os.makedirs(save_dir, exist_ok=True)
for mouseID, nickname in enumerate(RewNames):
    state_hist_cell, state_hist_xy = nodes2cell(state_hist_AllMice[mouseID])
    # sub_fits entries are [alpha_fit, beta_fit, gamma_fit, LL]
    alpha_fit, beta_fit, gamma_fit, subject_LL = sub_fits[mouseID][:4]
    figtitle = ('Simulated trajectory for '+nickname+' with '+str(valid_bouts[mouseID])
                +' valid bouts \n alpha: '+str(np.round(alpha_fit,2))+' beta: '+str(np.round(beta_fit,2))
                +' gamma: '+str(np.round(gamma_fit,2))
                + ', subject LL - random LL: ' + str(np.round(subject_LL - rand_LL[mouseID],2)))
    plot_trajectory('all', save_dir, nickname, figtitle)  # enter a single episode ID or enter 'all'