In [None]:
import numpy as onp
from functools import partial
import wandb
import pickle
from moviepy.editor import ImageSequenceClip
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

from keycld import models
from keycld.models import predict, predict_constraint
from keycld.data.dm import Data
from keycld.util import NumpyLoader, visualize_n_maps

In [None]:
# Set RUN_PATH to the 'Run path' shown in the Overview tab of the wandb dashboard.
# This notebook is only tested for the KeyCLD models.
RUN_PATH = '<Run path>'
MODEL_FILE = 'model.p'

run = wandb.Api().run(RUN_PATH)

# `File.download` returns an open text-mode handle to the downloaded file; close it
# immediately so the handle is not leaked, then reopen in binary mode for unpickling.
run.file(MODEL_FILE).download(replace=True).close()

# SECURITY: pickle.load can execute arbitrary code — only load models from runs you trust.
with open(MODEL_FILE, 'rb') as f:
    model = pickle.load(f)
args = run.config

args

In [None]:
# Rebuild the dataset from the run's config (environment, init mode, control).
environment, init_mode, control = args['environment'], args['init_mode'], args['control']
data = Data(environment, init_mode, control)

# NOTE(review): `dataloader` is not used by any cell visible here; kept for interactive use.
dataloader = NumpyLoader(
    data.train,
    batch_size=1,
    num_workers=12,
    shuffle=True,
)

In [None]:
# Mass matrix, evaluated once at an all-zeros input vector of length 2.
# NOTE(review): the length 2 is hard-coded — confirm it matches the input size
# `model.mass_matrix` expects for this environment.
zero_input = jnp.zeros(2)
mass_matrix_static = model.mass_matrix(zero_input)

# Print with 2 fixed decimals and without scientific notation for small values.
fixed_2dp = dict(precision=2, suppress=True, floatmode='fixed')
with onp.printoptions(**fixed_2dp):
    print(mass_matrix_static)

In [None]:
# potential energy
def calculate_potential_energy(image):
    """Return the model's potential energy for the keypoints encoded from one image.

    The image is given a leading batch axis of size one before encoding; the
    encoder's keypoint output is flattened into a single state vector and passed
    to ``model.potential_energy``. The encoder's second output is discarded.
    """
    batched_image = image[None]
    encoded_keypoints, _discarded_maps = model.encoder(batched_image)
    flat_state = encoded_keypoints.flatten()
    return model.potential_energy(flat_state)
images = data.grid['x']
positions = data.grid['positions']
# Nested vmap: evaluate the potential energy for every image over the two leading
# axes of `images` (presumably the two grid dimensions — confirm against Data.grid).
potential_energies = jax.vmap(jax.vmap(calculate_potential_energy))(images)

# Colorbar makes the energy values readable; without it the heatmap is only qualitative.
fig, ax = plt.subplots()
energy_image = ax.imshow(potential_energies)
fig.colorbar(energy_image, ax=ax, label='potential energy')
ax.set_title('Learned potential energy over the data grid')
plt.show()

### Prediction

In [None]:
i = 0  # index of the validation sequence to visualize
solver = 'dopri'
# Set to False to roll out the unconstrained ODE with `predict` instead of `predict_constraint`.
use_constraint = True

item = data.val[i]
t = item['t']
x = item['x']
action = item['action']
keypoints, keypoint_maps = model.encoder(x)


def _to_uint8_frames(*panels):
    """Concatenate image panels along axis -2 and convert them to uint8 video frames.

    Values are clipped to [0, 1] before scaling by 255: casting out-of-range floats
    (e.g. renderer outputs slightly above 1 or below 0) straight to uint8 wraps
    around instead of saturating. For values already in [0, 1] the clip is a no-op.
    """
    stacked = onp.concatenate(panels, axis=-2)
    return (onp.clip(stacked, 0.0, 1.0) * 255).astype(onp.uint8)


# Divide each keypoint map by its max over axes 1 and 2 (presumably the spatial axes)
# so every map is displayed with the same peak intensity.
keypoint_maps_n = keypoint_maps / keypoint_maps.max((1, 2), keepdims=True)
heatmaps = _to_uint8_frames(x, visualize_n_maps(keypoint_maps_n))

# Roll out the dynamics starting from the keypoints of the first two frames.
if use_constraint:
    keypoints_pred = predict_constraint(data.constraint_fn, model.ode, t, keypoints[:2], action, solver=solver)
else:
    keypoints_pred = predict(model.ode, t, keypoints[:2], action, solver=solver)
x_recon, gaussian_maps = model.renderer(keypoints_pred)

prediction = _to_uint8_frames(x, visualize_n_maps(gaussian_maps), x_recon)

In [None]:
# Frames built in the prediction cell: input images next to the normalized keypoint maps.
# `height=256` lets moviepy derive the width from the frames' aspect ratio instead of
# hard-coding it for a fixed number of square panels.
ImageSequenceClip(list(heatmaps), fps=30).resize(height=256).ipython_display()

In [None]:
# Frames built in the prediction cell: input images, rendered Gaussian keypoint maps,
# and the reconstruction rendered from the predicted keypoints.
# `height=256` keeps the aspect ratio instead of hard-coding the width as 3 * 256.
ImageSequenceClip(list(prediction), fps=30).resize(height=256).ipython_display()