In [None]:
import sys
sys.path.append('../..')
from hippocampus.environments import SimpleMDP, HexWaterMaze, TwoStepTask
from hippocampus.experiments.reliability_in_twostep import CombinedAgent
from definitions import FIGURE_FOLDER
# TODO: do this on linear track 

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.notebook import  tqdm
import pandas as pd
import os
import matplotlib

In [None]:
# Preview the current default colour cycle; the figures below index into it.
palette = sns.color_palette()
sns.palplot(palette)

In [None]:
# Single combined agent on a 5-arg SimpleMDP with reward probability .85.
# It starts with an even split between controllers (p_sr = .5) and an HPC
# learning rate of .01.
init_p_sr = .5

ag = CombinedAgent(
    env=SimpleMDP(5, reward_probability=.85),
    inv_temp=10,
)
ag.p_sr = init_p_sr
ag.HPC.learning_rate = .01

In [None]:

# Run 399 episodes (trials 1..399). After each episode the HPC learning rate
# is multiplied by .95. Per-episode results are collected in a list and turned
# into a DataFrame once: DataFrame.append was removed in pandas 2.0, and
# appending row by row is quadratic anyway.
records = []
for ep in tqdm(range(1, 400)):
    results = ag.one_episode(deterministic_policy=False)
    results['trial'] = ep
    records.append(results)
    ag.HPC.learning_rate *= .95
df = pd.DataFrame(records)

In [None]:
# Stack the per-trial omega entries into one array and draw them against trial index.
plt.plot(np.array(df['omega'].tolist()))

In [None]:
# Reliability traces with an initial value of 0 prepended, so position 0 is
# the pre-learning value and position k is trial k.
# ignore_index=True is required: matplotlib plots a Series against its index.
# Without it the prepended 0 and trial 1 would both carry label 0, so the
# curve would be shifted by one trial.
dls_reliab = pd.concat([pd.Series([0.]), df['DLS reliability']], ignore_index=True)
hpc_reliab = pd.concat([pd.Series([0.]), df['HPC reliability']], ignore_index=True)


In [None]:
# Plot DLS and HPC reliability across trials for the single agent above and
# save it as FIGURE_FOLDER/reliability.pdf. Also sets the global matplotlib
# font size to 22.
font = {'size': 22}
matplotlib.rc('font', **font)

palette = sns.color_palette()

fig, ax = plt.subplots()
ax.plot(dls_reliab, color=palette[1], linewidth=2)
ax.plot(hpc_reliab, color=palette[2], linewidth=2)

# Push the left/bottom spines 10 points outward and hide the right/top ones.
for side in ('left', 'bottom'):
    ax.spines[side].set_position(('outward', 10))
for side in ('right', 'top'):
    ax.spines[side].set_visible(False)
# Keep ticks only on the remaining left and bottom spines.
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')

ax.set_ylabel('Reliability')
ax.set_xlabel('Trial')
ax.legend(['DLS reliability', 'HPC reliability'])
ax.set_ylim([-.1, 1])
plt.tight_layout()
plt.savefig(os.path.join(FIGURE_FOLDER, 'reliability.pdf'))


In [None]:
#df.plot(x='trial', y='P(SR)')
fig, ax = plt.subplots()

font = {'size': 22}


ax.plot(pd.concat([pd.Series([init_p_sr]),  df['P(SR)']]), color=sns.color_palette()[0],linewidth=2)

# Move left and bottom spines outward by 10 points
ax.spines['left'].set_position(('outward', 10))
ax.spines['bottom'].set_position(('outward', 10))
# Hide the right and top spines
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
# Only show ticks on the left and bottom spines
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')

plt.ylabel('Pr(HPC)')
plt.xlabel('Trial')

plt.ylim([0,1])
plt.tight_layout()
plt.savefig(os.path.join(FIGURE_FOLDER, 'psr.pdf'))


In [None]:
# Repeat the learning simulation for several independent agents and store each
# agent's trajectories as one row of a (n_agents, n_trials) matrix.
# Column 0 holds the pre-learning value; column k holds trial k.
# After the loop, `ag` and `df` refer to the last agent, and later cells rely on that.
n_agents = 10
n_trials = 300  # 1 initial value + (n_trials - 1) simulated episodes per agent
dls_reliab_M = np.zeros((n_agents, n_trials))
hpc_reliab_M = np.zeros((n_agents, n_trials))
psr_M = np.zeros((n_agents, n_trials))

for ia in tqdm(range(n_agents)):

    # NOTE(review): unlike the single-agent run, this uses reward_probability=.8,
    # the default inv_temp and no learning-rate decay. Confirm this is intended.
    ag = CombinedAgent(env=SimpleMDP(5, reward_probability=.8))

    init_p_sr = .5
    ag.p_sr = init_p_sr
    ag.HPC.learning_rate = .01

    # Build the frame once from collected records (DataFrame.append was removed in pandas 2.0).
    records = []
    for ep in tqdm(range(1, n_trials), leave=False):
        results = ag.one_episode()
        results['trial'] = ep
        records.append(results)
    df = pd.DataFrame(records)

    dls_reliab_M[ia, :] = pd.concat([pd.Series([0.]), df['DLS reliability']], ignore_index=True)
    hpc_reliab_M[ia, :] = pd.concat([pd.Series([0.]), df['HPC reliability']], ignore_index=True)
    psr_M[ia, :] = pd.concat([pd.Series([init_p_sr]), df['P(SR)']], ignore_index=True)


In [None]:
# Agent-averaged DLS (first line) and HPC (second line) reliability per trial.
fig, ax = plt.subplots()
for reliab_matrix in (dls_reliab_M, hpc_reliab_M):
    ax.plot(reliab_matrix.mean(axis=0))
ax.set_ylim([-.1, 1])


In [None]:
# Same agent-averaged reliability curves, zoomed in on the first ~150 trials.
fig, ax = plt.subplots()
for reliab_matrix in (dls_reliab_M, hpc_reliab_M):
    ax.plot(reliab_matrix.mean(axis=0))
ax.set_ylim([-.1, 1])
ax.set_xlim([-10, 150])

In [None]:
# Omega column of the last simulated agent's results.
omegas = df.loc[:, 'omega']

In [None]:
# One row per omega component and one column per trial.
# Stacking gives (n_trials, n_components); transposing puts each component's
# trace in its own row. The previous concatenate + reshape(nr_states, -1)
# read the trial-major data row-major, so each "row" mixed consecutive
# trials' components.
# NOTE(review): assumes each df['omega'] entry has length ag.env.nr_states.
alloms = np.vstack(df['omega'].to_list()).T

In [None]:
# Trace of row 4 of alloms across trials.
plt.plot(alloms[4, :])

In [None]:
# omega_dls against trial number for the last simulated agent, y-axis clamped to [0, 1].
omega_dls_ax = df.plot(x='trial', y='omega_dls')
omega_dls_ax.set_ylim([0, 1])

In [None]:
# omgK holds element K of the omega vector for every recorded trial.
# This previously used a hard-coded 272 rows, which silently dropped trials
# from the 299-row frame (and would raise KeyError on shorter runs).
# NOTE(review): assumes every df['omega'] entry has at least 9 elements.
omega_rows = df['omega'].to_list()
omg0, omg1, omg2, omg3, omg4, omg5, omg6, omg7, omg8 = (
    [omega[state] for omega in omega_rows] for state in range(9)
)





In [None]:
# Overlay the nine per-element omega traces on the current axes, in order 0..8.
for omega_trace in (omg0, omg1, omg2, omg3, omg4, omg5, omg6, omg7, omg8):
    plt.plot(omega_trace)




In [None]:
# Stack the nine equal-length per-element traces into a (9, n_trials) array.
alloms = np.vstack((omg0, omg1, omg2, omg3, omg4, omg5, omg6, omg7, omg8))

In [None]:
# Compare the row-averaged omega trace with omega_dls, y-axis clamped to [0, 1].
plt.plot(np.mean(alloms, axis=0))
plt.plot(df['omega_dls'])
plt.ylim([0, 1])