In [None]:
from ColorShapeSpace_sim import CSC2Env
from CogSSM import SSM, silu
from dataclasses import dataclass
from combined_model import AgentConfig, QAgent, Deform
import torch
import torch.nn.functional as F
from train_utilities import build_value_obs, play_once, collect_batch_q, make_csc2_vector, TrainConfig
import numpy as np
from matplotlib import pyplot as plt
import matplotlib
%matplotlib inline
plt.ion()
import pickle
from scipy.ndimage import uniform_filter
import json

In [None]:
import os

# Directory holding the trained checkpoint: agent weights, deformer weights and
# the JSON training configuration.
save_dir = "../models/cogssm_value_e0400"

# map_location="cpu": the visible cells of this notebook create their tensors on
# the CPU, so the checkpoint must load even when it was saved from another device
# (without it, torch.load tries to restore tensors onto the saving device).
# NOTE(review): torch.load unpickles the file -- only open trusted checkpoints.
with open(os.path.join(save_dir, "agent.pt"), "rb") as f:
    state_dict = torch.load(f, map_location="cpu")
with open(os.path.join(save_dir, "deform.pt"), "rb") as f:
    deform_state_dict = torch.load(f, map_location="cpu")
with open(os.path.join(save_dir, "config.json"), "rb") as f:
    config = json.load(f)

# Shorthand for the config entries used throughout the notebook.
DIMS = config["dims"]
H = config["d_model"]
S = config["d_state"]
# P is fixed here rather than read from the config.
P = 1
CHUNK_SIZE = config["chunk_size"]
# The whole config dict is forwarded as keyword arguments to TrainConfig.
agent_config = TrainConfig(**config)

In [None]:
# Preview env.reward_grid (left column) and env.freq_grid (right column) for
# four seeded environments, one environment per row, and save the figure.
example_seeds = [756, 34, 2246, 4242]
fig, axs = plt.subplots(4, 2)
fig.tight_layout()
for row, seed in enumerate(example_seeds):
    env = CSC2Env(trials=CHUNK_SIZE * 15, render=False, n=36, render_reward=False, seed=seed, one_d=DIMS==1)
    r_dist = env.reward_grid
    f_dists = env.freq_grid
    for ax, grid in zip(axs[row], (r_dist, f_dists)):
        # 2-D tasks are drawn as images, 1-D tasks as curves
        if DIMS == 2:
            ax.imshow(grid)
        else:
            ax.plot(grid)
        # strip every tick, label and frame from the panel
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.axis('off')
fig.savefig("example_env_dists.jpg")
plt.show()
    

In [None]:
# Rebuild the agent and restore its trained weights.
model = QAgent(config=agent_config)
model.load_state_dict(state_dict)
# Optional ablation: uncomment to zero out model.ssm.D.
#model.ssm.D.data = model.ssm.D.data * 0

# Rebuild the percept deformer and restore its trained weights.
deform = Deform(
    channels=agent_config.percept_dim,
    deform_basis=4,
    groups=agent_config.dims,
)
deform.load_state_dict(deform_state_dict)

In [None]:
# show the agent an environment
# and track the hidden-state outputs
model.reset()
model.sequential = True
# reset() is called a second time, now with sequential mode switched on
# NOTE(review): presumably the cache layout depends on model.sequential -- confirm whether the first reset() is needed
model.reset()
# one rollout of CHUNK_SIZE*50 trials on seed 34; the value returned by play_once is converted to a NumPy array
states = play_once(agent=model, cfg=agent_config, trials=CHUNK_SIZE*50, seed=34, render=False, deformer=deform, k=4).numpy()
# snapshot of the cache after the rollout; a later cell assigns a clone of it back to model.cache
saved_cache = model.cache.clone()

In [None]:
# Enumerate every stimulus on the 36-step angular grid as a single batch.
angles = 2 * np.pi * np.arange(36) / 36
if DIMS != 2:
    # one angle per stimulus -> shape (1, 36, 1)
    all_pairwise = angles.reshape(1, -1, 1)
else:
    # every ordered pair of angles -> shape (1, 36 * 36, 2)
    grid_a, grid_b = np.meshgrid(angles, angles)
    all_pairwise = np.stack([grid_a, grid_b], axis=2).reshape(1, -1, 2)
# Convert the stimulus batch into model observations (float32).
raw_obs = build_value_obs(all_pairwise, cfg=agent_config).float()

In [None]:
# get and plot deformation
# leading percept_dim columns of the raw observation
raw_percept = raw_obs[:, :agent_config.percept_dim]
rp = raw_percept.detach().cpu().numpy()
# first point cloud: columns 2 and 3 of the raw percept
# NOTE(review): the choice of columns 2 and 3 is hard-coded -- confirm which percept features they hold
plt.scatter(rp[:, 2], rp[:, 3])
deform_percept = deform(raw_percept)
# copy the observation and overwrite its percept columns with the deformed ones
obs = raw_obs.clone()
obs[:, :agent_config.percept_dim] = deform_percept
# add a leading axis; later cells pass obs to model.forward and index obs[0]
obs = obs.unsqueeze(0)
dp = deform_percept.detach().cpu().numpy()
# second point cloud: the same two columns after deformation, drawn before plt.show() so both share one figure
plt.scatter(dp[:, 2], dp[:, 3])
plt.show()

In [None]:
# get cog model value space
model.sequential = True
# forward pass over the full stimulus batch; the second argument is an all-zero
# tensor shaped like the first two axes of obs, and k is the number of stimuli
# NOTE(review): the zero tensor looks like a reward/feedback input -- confirm against QAgent.forward
_, v_pred = model.forward(obs, torch.zeros(obs.shape[:2]), k=all_pairwise.shape[1])
v_pred = v_pred.detach().to("cpu")
# put back a clone of the cache snapshot taken after the rollout
# NOTE(review): presumably needed because forward() updates model.cache -- confirm
model.cache = saved_cache.clone()

In [None]:
# Plot the predicted values over the stimulus grid.
# The reshaped data gets its own name instead of overwriting v_pred: the original
# rebinding made the cell non-idempotent (re-running it in the 2-D case
# transposed the heatmap again on every execution).
if DIMS == 2:
    # 36 x 36 grid of stimuli, transposed for display
    v_grid = v_pred.reshape(36, 36).T
    plt.imshow(v_grid.numpy())
else:
    v_grid = v_pred.flatten()
    plt.plot(v_grid)
# The value range is the same whichever layout is plotted.
print(v_pred.min(), v_pred.max())

In [None]:
#plot deformation



In [None]:
# Flatten the recorded states to (CHUNK_SIZE*50, H*S*P), smooth each column with
# a width-100 moving average along axis 0, drop the last 200 rows and plot.
n_units = H * S * P
flat_states = states.reshape(CHUNK_SIZE*50, n_units)
smoothed = uniform_filter(flat_states, 100, axes=(0,), mode="constant")
unit_labels = [str(i) for i in range(n_units)]
plt.plot(smoothed[:-200], label=unit_labels)
plt.legend()

In [None]:
# We can recover the state spaces by forcing individual state entries:
# for every (i, j) the SSM state is set to zero everywhere except entry (i, j),
# and the model output with an all-zero state is subtracted as a baseline.
try:
    del v_pred
except NameError:
    # v_pred does not exist in this kernel (or was already deleted) -- nothing to drop
    pass
# squeeze=False keeps axs two-dimensional even when H == 1, so axs[i, j] below
# never fails with "too many indices"
fig, axs = plt.subplots(H, max(2, S), squeeze=False)
h = H
s = S
# One forced state per stimulus. Sized from the stimulus batch instead of a
# hard-coded 1296 (= 36 * 36) so the 1-D task, which has 36 stimuli, also works.
n_stim = all_pairwise.shape[1]
for i in range(h):
    for j in range(s):
        state = torch.zeros(n_stim, 1, h, s, device="cpu")
        #obs = build_value_obs(all_pairwise, torch.zeros(1, device="cpu")).unsqueeze(0)
        model.sequential = True
        # baseline: all-zero SSM state and zeroed conv state
        model.cache.ssm_state = state.clone()
        model.cache.conv_state = model.cache.conv_state * 0.
        _, bias = model.forward(obs, torch.zeros(obs.shape[:2]), k=all_pairwise.shape[1])
        # probe: identical, except that state entry (i, j) is set to one
        state[:, :, i, j] = 1
        model.cache.ssm_state = state.clone()
        model.cache.conv_state = model.cache.conv_state * 0.
        _, v_pred = model.forward(obs, torch.zeros(obs.shape[:2]), k=all_pairwise.shape[1])
        # contribution of entry (i, j) alone
        v_pred = v_pred.detach().to("cpu") - bias.detach().to("cpu")
        if DIMS == 2:
            v_pred = v_pred.reshape(36, 36).T
            res = v_pred.numpy()
            axs[i, j].imshow(res)
        else:
            v_pred = v_pred.flatten()
            res = v_pred.numpy()
            axs[i , j].plot(res)
        print(res.min(), res.max())
# NOTE(review): model.cache is left in the last forced state; the restore below is disabled.
#model.cache = saved_cache.clone()


In [None]:
import copy

# manually multiply b(x)u(x) to get true filters
# obs = build_value_obs(all_pairwise, torch.zeros(1, device="cpu")).unsqueeze(0)
# Work on a private copy of the conv layer: assigning padding on
# model.ssm.conv1d itself would leave the loaded model modified for every cell
# executed afterwards (hidden notebook state).
conv = copy.deepcopy(model.ssm.conv1d)
conv.padding = 0
# weight matrix of the SSM input projection, applied manually in the next cell
linear_weights = model.ssm.in_proj.weight

In [None]:
# Selects which of the two manual reconstructions below is executed.
sep_readout = True

# NOTE(review): initial_weights and in_bias are not defined in any visible cell of
# this notebook, so on a fresh kernel this cell raises NameError. Presumably they
# are the weight and bias of the model's first input layer -- confirm and define
# them before this cell.
if sep_readout:
    # two linear maps applied in sequence to every stimulus row of obs[0]
    inner = obs[0] @ initial_weights.T + in_bias
    inner = inner @ linear_weights.T
    # first H channels become z (through silu), with a new trailing axis of length 1
    z = silu(inner[:, :H].unsqueeze(2))
    # pair every projected value with a zero along a new last axis of length 2
    r_sig = torch.zeros_like(inner)
    inner = torch.stack([inner, r_sig], dim=2)
    # channels H .. -P are passed through the conv
    xBC = conv(inner[:, H:-P, :])
else:
    inner = obs[0] @ initial_weights.T + in_bias
    inner = inner @ linear_weights.T
    r_sig = torch.zeros_like(inner)
    inner = torch.stack([inner, r_sig], dim=2)
    # here all channels are passed through the conv
    xB = conv(inner[:, :, :])
    # NOTE(review): this slices the LAST axis of xB from H onward, whereas x below
    # is taken from axis 1 -- confirm which axis was intended
    B = xB[..., H:]
    # C is a copy of B in this branch
    C = B.clone()
    xBC = torch.cat([xB, C], dim=-1)
    x = xBC[:, :H]
    # z is a copy of x in this branch
    z = x.clone()

In [None]:
# Slice xBC along axis 1 into its x, b and c channel groups (H, S and S wide),
# swap the last two axes of b and c, and form the two elementwise products.
x = silu(xBC[:, :H])
b = xBC[:, H:H + S].transpose(1, 2)
c = xBC[:, H + S:H + 2 * S].transpose(1, 2)
BX = b * x
CZ = c * z

In [None]:
# Lay BX out on the stimulus grid and draw one panel per (i, j) state entry.
if DIMS==2:
    BX = BX.reshape(36, 36, H, S)
else:
    BX = BX.reshape(36, H, S)
# squeeze=False keeps axs two-dimensional even when H == 1, so axs[i, j] is
# always valid (plt.subplots otherwise collapses a single row to a 1-D array)
fig, axs = plt.subplots(H, S + 1, squeeze=False)
for i in range(H):
    for j in range(S):
        bx = BX[..., i, j].detach().cpu().numpy()
        if DIMS==1:
            # 1-D task: overlay the BX and CZ curves for this entry
            cz = CZ[..., i, j].detach().cpu().numpy()
            axs[i, j].plot(bx)
            axs[i, j].plot(cz)
        else:
            # 2-D task: only BX is shown, as an image
            axs[i, j].imshow(bx)
plt.show()

In [None]:
# NumPy copies of x, b, c and z with all length-1 axes removed, for plotting.
nx, nb, nc, nz = (
    tensor.detach().cpu().numpy().squeeze()
    for tensor in (x, b, c, z)
)

In [None]:
# 2x2 overview: x and b on the top row, z and c on the bottom row.
fig, axs = plt.subplots(2, 2)
(ax_x, ax_b), (ax_z, ax_c) = axs
ax_x.plot(nx)
ax_b.plot(nb)
ax_z.plot(nz)
ax_c.plot(nc)

In [None]:
# Elementwise product of BX and CZ, scaled by n, plotted per stimulus.
n = 1
# An earlier cell rebinds BX to a (36, 36, H, S) grid when DIMS == 2, which no
# longer broadcasts against CZ. Reshaping BX to CZ's shape undoes that layout
# and is a no-op when the shapes already agree.
y = n * BX.reshape(CZ.shape) * CZ
plt.plot(y.detach().cpu().numpy().squeeze())

In [None]:
# Standalone toy figure: the sum of three 2-D sinusoids on a 500 x 500 grid.
# NOTE(review): this rebinds x, which an earlier cell uses for a torch tensor;
# re-running those earlier cells after this one needs x to be recomputed first.
# 500 samples with a spacing of 2*pi/100
x = np.arange(500) * np.pi * 2 / 100
#x = np.cos(x)
# column vector: m is (500, 1) and m.T is (1, 500), so the sums below broadcast to (500, 500)
m = x[:, None] 
# each term is a sinusoid of a linear combination of the two grid coordinates plus a phase offset
m1 = np.cos(-2 * m + 2 * m.T - 2)
m2 = np.sin(3*m + 1 * m.T + 1)
m3 = np.cos(-2 * m + -2 * m.T - .5)
plt.imshow(m1 + m2 + m3)