In [7]:
import collections
import os
from os.path import join

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from lib.Sequences import Sequences
from lib.Sequences import createX
from lib.utils import read_pickle

# The Webots Python controller API is not on the default import path.
# Edit the default below to point at your Webots installation's
# `lib/controller/python` folder, or set the WEBOTS_PYTHON_PATH environment
# variable to override it without touching the notebook.
import sys
WEBOTS_PY_PATH = os.environ.get('WEBOTS_PYTHON_PATH', r'C:\Program Files\Webots\lib\controller\python')
# Only insert once, so re-running this cell in the same kernel does not keep
# prepending duplicate entries to sys.path.
if WEBOTS_PY_PATH not in sys.path:
    sys.path.insert(0, WEBOTS_PY_PATH)


# Draw the X-matrix animation from Leibold (2020)

In [2]:
# Sequence parameters; K is assumed to equal the column count of seq.X for
# this configuration (TODO confirm against lib.Sequences).
R = 5
L = 20
K = R + L - 1
seq = Sequences(R, L, False)

# Observation schedule: fids[t] lists the feature ids observed at step t.
n_steps = 40
fids = [[] for _ in range(n_steps)]
fids[0].append('1')
fids[8].append('2')
fids[11].append('3')
fids[14].append('2')
fids[19].append('4')

# Pad every frame to the same height so the animation frames line up.
# Derived from the schedule instead of hard-coding 4, so adding a new id
# cannot overflow the padded array.
n_rows = len({fid for ids in fids for fid in ids})

os.makedirs('Xgif', exist_ok=True)
for i, ids in enumerate(fids):
    seq.step(ids)
    out = np.asarray(seq.X)
    X_frame = np.zeros((n_rows, K))
    X_frame[:len(out), :] = out

    fig, ax = plt.subplots(figsize=(8, 1.8))
    ax.imshow(X_frame)
    # Minor ticks at cell boundaries so the grid outlines each matrix entry.
    ax.set_xticks(np.arange(K + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
    ax.grid(which='minor')
    ax.set_title(f'X(t={i+1})')
    fig.tight_layout()
    fig.savefig(join('Xgif', f'{i}.png'), dpi=100)
    # Close this specific figure to avoid accumulating open figures in the loop.
    plt.close(fig)



# Draw X for my simulation
The cells below may fail because the pickled data they load has since changed. They are kept here to show how the animation in my slides (the evolution of matrix $X$ as feature nodes are observed by the camera) was produced.

In [3]:
# Load data =============================================
# Configuration: which project / checkpoints to load and where plots are written.
debug_plot_tag = False
project_name = 'RegressedToTrueState'
project_dir, plot_dir = join('data', project_name), join('plots', project_name)
os.makedirs(plot_dir, exist_ok=True)
chpt_name = 'PPO$'  # '$' is a placeholder for the checkpoint number
max_chpt_num = 1

# Concatenate the trajectory columns and the per-step fsigma of every episode
# in checkpoints 1..max_chpt_num (checkpoint order, then episode order).
trajkeys = 't', 'x', 'y', 'a', 'sid', 'r', 'terminated', 'truncated'
alltrajdict = {column: [] for column in trajkeys}
allfsigmalist = []
for i in range(1, max_chpt_num + 1):
    traj_data_pth = join(project_dir, f"{chpt_name.replace('$', str(i))}_trajdata.pickle")
    # NOTE(review): read_pickle presumably unpickles the file — only load trusted data.
    trajdict_list = read_pickle(traj_data_pth)  # one dictionary per episode
    for trajdict in trajdict_list:
        for key in trajkeys:
            alltrajdict[key].extend(trajdict[key])
        allfsigmalist.extend(trajdict['fsigma'])
alltrajdf = pd.DataFrame(alltrajdict)

# Load the hipposlam components saved alongside the last checkpoint.
load_hipposlam_pth = join(project_dir, f"{chpt_name.replace('$', str(max_chpt_num))}_hipposlam.pickle")
hipposlam = read_pickle(load_hipposlam_pth)
hipposeq, hippomap, hippoteach, fpos = (
    hipposlam[part] for part in ('hipposeq', 'hippomap', 'hippoteach', 'fpos')
)
R, L, F, K = hipposeq.R, hipposeq.L, hipposeq.num_f, hipposeq.X_Ncol
stored_f = hipposeq.stored_f
# Inverse lookup: stored feature id -> feature key.
id2fkey_dict = dict(zip(stored_f.values(), stored_f.keys()))

# Sanity checks: sequence and map components must agree, and every trajectory
# step must have a matching fsigma entry.
assert R == hippomap.R
assert F == hippomap.current_F
assert K == hippomap.K
assert len(alltrajdf) == len(allfsigmalist)

ModuleNotFoundError: No module named 'hipposlam.Sequences'; 'hipposlam' is not a package

In [76]:
# Row order for displaying X: stored features whose key ends in 'f', then 'c',
# then 't' (each group in stored_f's iteration order). sepf / sepc are the
# group boundaries, used to draw separator lines between the groups.
inds_f, inds_c, inds_t = (
    [fid for fkey, fid in stored_f.items() if fkey[-1] == suffix]
    for suffix in ('f', 'c', 't')
)

inds = [*inds_f, *inds_c, *inds_t]
sepf = len(inds_f)
sepc = sepf + len(inds_c)

In [95]:
# Output folders. Frames are written to traj_dir; Xgif_dir is created but not
# written to by this cell.
Xgif_dir = join('gif', 'Xgif_sim')
traj_dir = join('gif', 'traj')
os.makedirs(Xgif_dir, exist_ok=True)
os.makedirs(traj_dir, exist_ok=True)

# NOTE(review): xdict / ydict / adict and the preallocated X are never filled
# in this cell; they are kept only in case later cells rely on them.
xdict = collections.defaultdict(list)
ydict = collections.defaultdict(list)
adict = collections.defaultdict(list)

# Render frames for steps t_start <= t < t_end. t_end is capped by the number
# of recorded fsigma entries so a shorter dataset does not raise IndexError.
t_start = 11
t_end = min(100, len(allfsigmalist))
trange = np.arange(t_start, t_end)
X = np.zeros((trange.shape[0], F, K))

# Fixed limits so the trajectory panel does not rescale between frames.
xmin, xmax = alltrajdf['x'].min(), alltrajdf['x'].max()
ymin, ymax = alltrajdf['y'].min(), alltrajdf['y'].max()

# Feature-node positions are identical in every frame; unpack them once.
fpos_x = [pos[0] for pos in fpos.values()]
fpos_y = [pos[1] for pos in fpos.values()]

for i, t in enumerate(trange):
    # Path from t_start up to and including step t; .iloc makes the slice
    # explicitly positional.
    x = alltrajdf['x'].iloc[t_start:t + 1]
    y = alltrajdf['y'].iloc[t_start:t + 1]

    fig = plt.figure(figsize=(7, 8), facecolor='w')
    ax_traj = fig.add_axes([0.05, 0.4, 0.4, 0.35])
    # NOTE(review): this rectangle extends past the right edge of the figure
    # (0.50 + 0.6 > 1), so a wide X image may be clipped — confirm the layout.
    ax_x = fig.add_axes([0.50, 0.05, 0.6, 0.9])

    # Left panel: trajectory so far (red) and feature-node positions (green).
    ax_traj.plot(x, y, color='r')
    ax_traj.scatter(fpos_x, fpos_y, marker='o', color='g')
    ax_traj.set_xlim(xmin, xmax)
    ax_traj.set_ylim(ymin, ymax)

    # Right panel: X at step t with rows reordered into the f / c / t groups,
    # separated by red horizontal lines.
    fsigma = allfsigmalist[t]
    Xslice = createX(R, F, K, stored_f, fsigma)
    X_sorted = Xslice[inds]
    ax_x.imshow(X_sorted)
    ax_x.axhline(sepf - 0.5, color='r')
    ax_x.axhline(sepc - 0.5, color='r')
    # Minor ticks at cell boundaries; rows follow the displayed matrix height.
    ax_x.set_xticks(np.arange(K + 1) - 0.5, minor=True)
    ax_x.set_yticks(np.arange(X_sorted.shape[0] + 1) - 0.5, minor=True)
    ax_x.grid(which='minor', linewidth=1)
    ax_x.set_title(f'X(t={t})')

    fig.savefig(join(traj_dir, f'{i}.png'), dpi=100)
    # Close this specific figure to avoid accumulating open figures in the loop.
    plt.close(fig)
