In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn
import datetime
import re
from os import listdir
import sys
sys.path.insert(0, '../YOUR DIRECTORY OF MODEL FILES')
from plotting_functions import *
from arb_functions import *
from helper_functions import *
from behaviour_analyses import *
from environment_functions import *

# Location / identity of the saved simulation to analyse (fill in before running).
save_dirs = ['YOUR DIRECTORY OF SAVED DATA']

date = 'DATE OF SIMULATION'
run = 'RUN NUMBER'

# Index into the saved-data history passed to get_data (-1 as used throughout this notebook).
recent = -1
data, para, list_of_files, save_path = get_data(save_dirs, run, date, recent)

A_RNN, g2g, x_all, g_all, p_all, p_gen_all, acc_s_t_to, acc_s_t_from, positions, shinys, adj, timeseries = data
params, widths, batch_id, g_size, p_size, s_size, s_size_comp, n_freq, width, states = para

# Rate-map upsampling factor. It is 4 for every world_type, including 'tonegawa'.
# TODO(review): an earlier tonegawa-specific branch also chose 4; confirm no different value was intended.
mult = 4
smoothing = 1
cmap = 'jet'
maxmin = True

import seaborn
seaborn.set_style(style='white')
seaborn.set_style({'axes.spines.bottom': False, 'axes.spines.left': False,
                   'axes.spines.right': False, 'axes.spines.top': False})

shinys_, masks, g_lim, p_lim = sort_data(g_all, p_all, shinys, widths, mult, smoothing, params, batch_id,
                                         g_max_0=False, p_max_0=True)

In [None]:
# For every saved environment, print the 'to' and 'from' accuracy at each shiny state.
# Shiny indices that fall outside the accuracy arrays are silently skipped.
for env in range(params['n_envs_save']):
    print('env', env)
    env_shinys = shinys_[env]
    if env_shinys is None:
        continue
    for shiny in env_shinys:
        label = str(shiny)
        try:
            print('acc to ' + label, acc_s_t_to[env][shiny])
            print('acc from ' + label, acc_s_t_from[env][shiny])
        except IndexError:
            pass

In [None]:
# Rough per-environment summary of 'to' accuracy, weighted by positions[env]
# (presumably per-state visit counts — TODO confirm), and scaled by the number of
# states whose positions value exceeds 0.1.
for env in range(params['n_envs_save']):
    visits = positions[env]
    weighted_correct = np.sum(acc_s_t_to[env] * visits)
    proportion = weighted_correct / sum(visits)
    n_states_seen = np.sum(visits > 0.1)
    approx_num = proportion * n_states_seen
    print(env,
          '   Approx proportion : ', np.round(proportion, decimals=3),
          '   Approx num : ', np.round(approx_num, decimals=2))

In [None]:
# The pair of environments compared side by side in the cells below.
env0, env1 = 1, 2

In [None]:
# GRAPHS: spring-layout drawing of the adjacency graph of each selected environment.
import networkx as nx
try:
    plt.figure(figsize=(10,5))
    for i, env in enumerate([env0, env1]):
        adj_ = remove_zero_adj(adj[env])

        plt.subplot(1,2,i + 1)
        # nx.from_numpy_matrix was removed in networkx 3.0. nx.from_numpy_array
        # (available since networkx 2.0) builds the same graph from a 2D adjacency array.
        g = nx.from_numpy_array(adj_)
        pos_nodes = nx.spring_layout(g, iterations=500)
        nx.draw(g, pos=pos_nodes, node_size=50, with_labels=True)

    plt.show()
except TypeError:
    # Best-effort cell: if an adjacency cannot be converted (TypeError), drop the
    # partially drawn figure instead of failing the notebook run.
    plt.close('all')

In [None]:
# AGENT COVERAGE

In [None]:
# Map of positions[env] reshaped onto the environment layout, one panel per selected env.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, env in zip(axes, [env0, env1]):
    coverage_map = reshape_cells(positions[env], widths[batch_id[env]], params['world_type'])
    im = ax.imshow(coverage_map)
    fig.colorbar(im, ax=ax)

plt.show()

# Smallest positions value in each selected environment.
print(min(positions[env0]), min(positions[env1]))

In [None]:
# Where the agent didn't go much: states with positions value below the threshold
# are 1 (top of the colour scale), all other states are 0.
low_visit_threshold = 4
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

for ax, env in zip(axes, [env0, env1]):
    rarely_visited = (positions[env] < low_visit_threshold).astype(int)
    rarely_map = reshape_cells(rarely_visited, widths[batch_id[env]], params['world_type'])
    im = ax.imshow(rarely_map)
    fig.colorbar(im, ax=ax)

plt.show()

In [None]:
# ACCURACY MAPS - SENSORY

In [None]:
# Sensory prediction accuracy per state: top row 'accuracy to', bottom row
# 'accuracy from', one column per selected environment; colour scale fixed to [0, 1].
accuracy_rows = [(acc_s_t_to, 'accuracy to'), (acc_s_t_from, 'accuracy from')]
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
for col, env in enumerate([env0, env1]):
    for row, (acc_by_env, title) in enumerate(accuracy_rows):
        ax = axes[row, col]
        acc_map = reshape_cells(acc_by_env[env], widths[batch_id[env]], params['world_type'])
        im = ax.imshow(acc_map, vmax=1, vmin=0)
        ax.set_title(title)
        fig.colorbar(im, ax=ax)

plt.show()

In [None]:
# ENTORHINAL CELLS: g rate maps for the first selected environment
square_plot(
    g_all[env0],
    widths[batch_id[env0]],
    name='g0',
    maxmin=maxmin,
    shiny=shinys_[env0],
    hexy=params['world_type'],
    lims=g_lim,
    mult=mult,
    smoothing=smoothing,
    cmap=cmap,
    mask=masks[env0],
)

In [None]:
# g rate maps for the second selected environment
square_plot(
    g_all[env1],
    widths[batch_id[env1]],
    name='g1',
    maxmin=maxmin,
    shiny=shinys_[env1],
    hexy=params['world_type'],
    lims=g_lim,
    mult=mult,
    smoothing=smoothing,
    cmap=cmap,
    mask=masks[env1],
)

In [None]:
# ENTORHINAL AUTOCORRELATIONS

In [None]:
# spatial autocorrelations of g for the first selected environment
square_autocorr_plot(
    g_all[env0],
    widths[batch_id[env0]],
    name='g0_auto',
    hexy=params['world_type'],
    mult=mult,
    smoothing=smoothing,
    cmap=cmap,
    circle=True,
)

In [None]:
# spatial autocorrelations of g for the second selected environment
square_autocorr_plot(
    g_all[env1],
    widths[batch_id[env1]],
    name='g1_auto',
    hexy=params['world_type'],
    mult=mult,
    smoothing=smoothing,
    cmap=cmap,
    circle=True,
)

In [None]:
# HIPPOCAMPAL CELLS

In [None]:
# p (hippocampal) rate maps for the first selected environment
square_plot(
    p_all[env0],
    widths[batch_id[env0]],
    name='p0',
    shiny=shinys_[env0],
    hexy=params['world_type'],
    lims=p_lim,
    mult=mult,
    smoothing=smoothing,
    cmap=cmap,
    mask=masks[env0],
)

In [None]:
# p (hippocampal) rate maps for the second selected environment
square_plot(
    p_all[env1],
    widths[batch_id[env1]],
    name='p1',
    shiny=shinys_[env1],
    hexy=params['world_type'],
    lims=p_lim,
    mult=mult,
    smoothing=smoothing,
    cmap=cmap,
    mask=masks[env1],
)

In [None]:
"""
BEHAVIOURAL ANALYSES
"""

seaborn.set_style(style='white')

# --- analysis settings ---
params['acc_simu'] = 1  # accuracy of the simulated node/edge agent
recent = -1  # index into the history of saved data passed to link_inferences
filt_size = 61  # smoothing window size (must be odd)
n = 10
# Fractions 0, 0.1, ..., 1.1 of nodes visited at which link-inference accuracy is assessed.
fracs = [k / n for k in range(n + 2)]

# Steps-since-last-visit thresholds at which prediction accuracy is assessed;
# these world types use a shorter list of thresholds.
worlds_with_few_thresholds = ['family_tree', 'line_ti', 'tonegawa']
a_s = [0, 10, 20] if params['world_type'] in worlds_with_few_thresholds \
    else [0, 4, 10, 20, 40, 60, 100, 200, 300, 400, 600]

# --- load link-inference data ---
positions_link, coos, env_info, distance_info = link_inferences(
    save_path, list_of_files, widths, batch_id, params, index=recent)
n_states, wids, n_available_states, n_available_edges = env_info

# --- behavioural analyses, partitioned by environment width (same-size envs together) ---
allowed_widths = sorted(np.unique([widths[b_id] for b_id in batch_id]))
results = []
for allowed_wid in allowed_widths:
    p_cors, nodes_visited_all, edges_visited_all, time_vis_anal = analyse_link_inference(
        allowed_wid, fracs, a_s, positions_link, coos, env_info, params)
    # drop empty curves
    p_cors = [curve for curve in p_cors if len(curve) > 0]
    results.append([p_cors, nodes_visited_all, edges_visited_all, time_vis_anal])

In [None]:
# Inference of an untraversed link as a function of the proportion of nodes seen
plot_link_inference(results, allowed_widths,
                    fracs, s_size)

In [None]:
# Simulated node agent and edge agent, for comparison with the model
plot_sim_corrects(results, allowed_widths, filt_size,
                  wids, n_available_states, n_available_edges,
                  params)

In [None]:
# TEM accuracy for each environment size
plot_acc_vs_sum_nodes_edges(results, allowed_widths, coos, filt_size,
                            wids, n_available_states, n_available_edges)

In [None]:
# Prediction accuracy as a function of the number of steps since that node was last visited
plot_acc_vs_steps_since_visited(results, allowed_widths,
                                a_s, s_size)

In [None]:
# HARLOW LIKE LEARNING CURVES
# Each entry of allowed_recents is a group of saved-data indices (stages of training).
# Curves from every index in a group are pooled, separately for each width in
# allowed_widths, so results_mega[i] holds exactly one [curves] entry per allowed
# width. This is the layout the plotting cell consumes via zip(..., allowed_widths).

results_mega = []
coos_mega = []
allowed_widths = [9]
allowed_recents = [[1,2], [7, 8, 9] , [11, 12, 13, 14], [18, 19,20, 21], [-1]]  # for family_tree

# Alternative groupings:
#[[1,2], [7, 8, 9] , [11, 12, 13, 14], [18, 19,20, 21], [-3,-2,-1]]  # for family_tree
#[[1,2], [5, 6, 7, 8] , [9, 10, 11, 12, 13, 14], [50, 55, 60, 70], [-3,-2,-1]]  # for tran_inf

for hyper_recent in allowed_recents:
    # Pooled curves for this group, keyed by environment width. Previously each recent
    # got its own entry, and the plotting zip silently kept only the first recent's.
    n_visited_by_wid = {wid: [] for wid in allowed_widths}
    for recent in hyper_recent:

        positions_link, coos, env_info, distance_info = link_inferences(save_path, list_of_files, widths, batch_id,
                                                                        params, index=recent)

        n_states, wids, n_available_states, n_available_edges = env_info

        for allowed_wid in allowed_widths:
            for i in range(params['batch_size']):
                if wids[i] != allowed_wid:
                    continue
                # check_visited reads the per-environment sizes from params
                params['states'] = n_states[i]
                params['n_available_states'] = n_available_states[i]
                params['edges'] = n_available_edges[i]

                node_visited, edge_visited, nodes, edges, num_visited, edge_stay_still = \
                    check_visited(positions_link[i].astype(int), params)
                # all new links, but visited nodes!
                n_visited_by_wid[allowed_wid].append(n_visited_correct(coos[i], num_visited, edge_visited))

    results_ = [[n_visited_by_wid[wid]] for wid in allowed_widths]
    results_mega.append(results_)
    # NOTE(review): only the coos of the last recent in the group is kept
    coos_mega.append(coos)


In [None]:
f = plt.figure()

n_visits_lim = 20  # upper limit for the x-axis (number of visits to a node)

# One curve per group of saves: mean +/- SEM probability of correctly inferring a
# new link, as a function of how many times its node had been visited.
smallest_min = 10000
for i, recent in enumerate(allowed_recents):
    for res, wid in zip(results_mega[i], allowed_widths):
        pc, smallest = sort_out_summin(res[0])
        smallest_min = np.minimum(smallest, smallest_min)

        mean_pc = np.nanmean(pc, 0)
        # NOTE(review): `sp`, `linewidth`, `labelsize` and `fontsize` are not defined in
        # the visible cells; presumably they come from the star imports — confirm.
        sem_pc = sp.stats.sem(pc, 0, nan_policy='omit')
        # x is derived from the curve itself so the line and its band always align.
        visits_axis = np.arange(len(mean_pc))
        # Label the mean line and colour its SEM band to match it explicitly, rather
        # than relying on matplotlib's separate line and patch colour cycles staying in step.
        mean_line, = plt.plot(visits_axis, mean_pc, label=str(recent))
        plt.fill_between(visits_axis, mean_pc + sem_pc, mean_pc - sem_pc,
                         color=mean_line.get_color(), alpha=0.2)

smallest_min = np.minimum(smallest_min, n_visits_lim)
# horizontal reference at 1/s_size (presumably chance level over sensory outcomes)
plt.plot((0, smallest_min-1), (1/s_size, 1/s_size), 'k--', linewidth=linewidth)
plt.xlim(0, smallest_min-1)
plt.ylim(0, 1.1)
plt.tick_params(axis='both', which='major', labelsize=labelsize)
plt.tick_params(axis='both', which='minor', labelsize=labelsize)
plt.xlabel('# times node visited', fontsize=fontsize)
plt.ylabel('Correct inference of link', fontsize=fontsize)
plt.legend()
plt.show()