In [None]:
import jax.numpy as jnp
import pymdp.jax
from fast_structure_learning import *

In [None]:
# Input video used for structure learning.
path_to_file = "dove.mp4"
frames = read_frames_from_mp4(path_to_file)

# map_rgb_2_discrete returns a pair: a 5-tuple of discretisation outputs and the patch indices.
discretised, patch_indices = map_rgb_2_discrete(frames, tile_diameter=32, n_bins=16)
observations, locations_matrix, group_indices, sv_discrete_axis, V_per_patch = discretised

# Convert the observations to a JAX array before handing them to structure learning.
observations = jnp.asarray(observations)

# Learn the model hierarchy (at most 8 levels). `agents` and `LB` are indexed/iterated
# further down; presumably both hold one entry per level -- TODO confirm.
agents, RG, LB = spm_mb_structure_learning(observations, locations_matrix, max_levels=8)

In [None]:
import matplotlib.pyplot as plt

# Overlay the locations stored in each entry of LB on top of the first video frame.
plt.imshow(frames[0])
colors = ['r', 'b']
# The loop variable deliberately has its own name: `locations_matrix` (returned by
# map_rgb_2_discrete) is passed to map_discrete_2_rgb later in the notebook and must
# not be overwritten by the last entry of LB.
for level, level_locations in enumerate(LB):
    # Cycle through the palette so more than len(colors) levels does not raise IndexError.
    plt.scatter(level_locations[:, 0], level_locations[:, 1], c=colors[level % len(colors)])
plt.show()

In [None]:
# One-hot row vector of shape (1, 8) with all mass on index 5; used below as the
# belief `D` fed to the last agent in `agents`.
# NOTE(review): 8 presumably equals that agent's number of states -- confirm.
D = one_hot = jnp.zeros((1, 8)).at[0, 5].set(1.0)

In [None]:
from pymdp.jax.control import compute_expected_obs
from jax import vmap
from functools import partial

# Last agent in the list (presumably the top of the hierarchy -- TODO confirm).
top_agent = agents[-1]

# Pre-bind that agent's A_dependencies so only beliefs and A tensors remain as arguments.
expected_obs = partial(compute_expected_obs, A_dependencies=top_agent.A_dependencies)

# Map over the leading axis of D and of every A tensor to get the expected observations.
batched_expected_obs = vmap(expected_obs)
qo = batched_expected_obs([D], top_agent.A)

In [None]:
import jax.tree_util as jtu

# Expected-observation function bound to the second-to-last agent's A_dependencies.
expected_obs2 = partial(compute_expected_obs, A_dependencies=agents[-2].A_dependencies)

# Split the previous prediction into an initial state "D" (even entries) and a path "E" (odd entries).
D2 = qo[::2]
# TODO also transition using this policy matrix?
# Each path entry is reduced to its argmax index, with a leading axis of size 1 added.
E2 = jtu.tree_map(lambda x: jnp.expand_dims(jnp.argmax(x), [0]), qo[1::2])

# D2 gets an extra leading axis before being passed in.
# NOTE(review): presumably this yields the beliefs one step ahead under path E2 -- confirm against pymdp.
qs_next, _ = agents[-2].infer_empirical_prior(E2, jtu.tree_map(lambda x: jnp.expand_dims(x, 0), D2))

# Stack D2 and qs_next along the leading axis so qo is generated for both in a single vmapped call.
stacked = [jnp.concatenate([x, y], axis=0) for x, y in zip(D2, qs_next)]

# Broadcast every A tensor to the stacked batch size. The size is read from the data and the
# trailing shape is kept as-is, instead of hard-coding a batch of 2 and rank-3 tensors.
n_stacked = stacked[0].shape[0]
A_stacked = jtu.tree_map(lambda x: jnp.broadcast_to(x, (n_stacked,) + tuple(x.shape[1:])), agents[-2].A)
qo_stacked = vmap(expected_obs2)(stacked, A_stacked)


In [None]:
# Stack the per-entry predictions of qo_stacked into a single array (new leading axis).
obs = jnp.array(qo_stacked)

# Bind every argument except the observations so the decoder can be vmapped.
map_discrete_2_rgb_fn = partial(map_discrete_2_rgb, locations_matrix=locations_matrix, group_indices=group_indices, sv_discrete_axis=sv_discrete_axis, V_per_patch=V_per_patch, patch_indices=patch_indices, image_shape=frames.shape[-3:])
# vmap over the batch dimension (axis 1 of obs); results are collected along axis 0.
imgs = vmap(map_discrete_2_rgb_fn, in_axes=1, out_axes=0)(obs)
# Collapse all leading axes into one frame axis. -1 infers the frame count from the data
# rather than hard-coding 4, so a different batch size or frames-per-prediction still works.
imgs = imgs.reshape((-1,) + tuple(imgs.shape[-3:]))

In [None]:
import numpy as onp

def to_img(img):
    """Convert a 3-D image array into a displayable uint8 image.

    The axes are permuted with (1, 2, 0) (presumably channel-first to
    channel-last -- TODO confirm), the values are divided by 255 and clipped
    to [0, 1], then scaled back by 255 and cast to uint8.

    Args:
        img: 3-D array whose values are expected on a 0..255 scale; values
            outside that range are clipped.

    Returns:
        A uint8 array with the first axis of `img` moved to the last position.
    """
    channels_last = jnp.transpose(img, (1, 2, 0))
    # Normalise to the unit interval first so out-of-range values are clipped.
    unit_range = jnp.clip(channels_last / 255, 0, 1)
    return (255 * unit_range).astype(onp.uint8)

# Show the first decoded frame as a quick visual check.
first_frame = to_img(imgs[0])
plt.imshow(first_frame)


In [None]:
!pip install -q mediapy

In [None]:
import mediapy

# Convert every frame along the first axis of `imgs` to a displayable uint8 image.
ims = [to_img(frame) for frame in imgs]

# Show the frames as a GIF at 1 frame per second, with mediapy's show/save directory set to ".".
with mediapy.set_show_save_dir("."):
    mediapy.show_videos({"predictions": ims}, fps=1, codec='gif')
