In [1]:
from pathlib import Path
import os
from IPython.display import HTML
import wandb
import torch
from mp_transformer.models.transformer import MovementPrimitiveTransformer
from mp_transformer.config import CONFIG
from mp_transformer.train import setup
from mp_transformer.utils import save_side_by_side_strip

In [2]:
current_dir = Path.cwd().parts[-1]
if current_dir == "demo":
    os.chdir("..")
!pwd

/data/daniel/git/mp-transformer


In [3]:
run = wandb.init(project="mp-transformer")
artifact = run.use_artifact("tcs-mr/mp-transformer/model:v5", type='model')
artifact_dir = artifact.download()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdaniel-a[0m ([33mtcs-mr[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact model:v5, 86.25MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.1


In [4]:
print(artifact_dir)

./artifacts/model:v5


In [6]:
model, train_dataset, val_dataset = setup(CONFIG)
artifact_dir = "./artifacts/model:v5"
model = model.load_from_checkpoint(Path(artifact_dir, "model.ckpt"), config=CONFIG)
run.finish()

In [10]:
item = val_dataset[-1]
poses, timestamps = item["poses"], item["timestamps"]
# poses = torch.stack([poses[0, :], poses[-1, :]])
# timestamps = torch.stack([timestamps[0], timestamps[-1]])
# ys, timestamps = item["poses"], item["timestamps"]
# poses[48:80, :] = 0
# timestamps[48:80] = 0
# poses[0:, :] = 0
# timestamps[0:] = 0
# item["poses"] = poses
# item["timestamps"] = timestamps
y_hat = model.infer(poses, timestamps)

In [11]:
# print(item["poses"], item["timestamps"])
save_side_by_side_strip(item, model, CONFIG["num_primitives"])

Video saved to tmp/comp_vid.mp4
Video saved to tmp/comp_vid0.mp4
Video saved to tmp/comp_vid1.mp4
Video saved to tmp/comp_vid2.mp4
Video saved to tmp/comp_vid3.mp4
Video saved to tmp/comp_vid4.mp4
Video saved to tmp/comp_vid5.mp4
Moviepy - Building video tmp/comp_strip.mp4.
Moviepy - Writing video tmp/comp_strip.mp4



                                                               

Moviepy - Done !
Moviepy - video ready tmp/comp_strip.mp4


In [12]:
HTML("""
<video width="320" height="240" controls>
  <source src="../tmp/comp_strip.mp4" type="video/mp4">
</video>
""")

In [13]:
init_latents = torch.zeros(1, 6, 128)
timestamps = timestamps.unsqueeze(0)

In [14]:
sampled_latents = torch.randn_like(init_latents)
out = model.decoder(timestamps, sampled_latents)
recons_sequence = out["recons_sequence"]
recons_sequence = recons_sequence.squeeze(0).detach().numpy()
recons_sequence.shape, type(recons_sequence)

((128, 3), numpy.ndarray)

In [15]:
import numpy as np
import imageio
from mp_transformer.datasets.toy_dataset import unnormalize_pose
from mp_transformer.utils.generate_toy_data import BONE_LENGTHS, render_image

In [16]:
imgs = []
for rec in recons_sequence:
    rec = unnormalize_pose(rec)
    img = render_image(rec, BONE_LENGTHS)
    imgs.append(img)

In [17]:
output_file = f"tmp/gen_vid.mp4"
with imageio.get_writer(output_file, fps=20) as writer:
    for img in imgs:
        img_array = np.array(img)  # Convert PIL Image object to NumPy array
        writer.append_data(img_array)

print(f"Video saved to {output_file}")

Video saved to tmp/gen_vid.mp4


In [18]:
HTML("""
<video width="320" height="240" controls>
  <source src="../tmp/gen_vid.mp4" type="video/mp4">
</video>
""")

In [26]:
def interpolate_sequences(model, joint_angles_before, joint_angles_after, timestamps_before, timestamps_after, num_interpolations=10):
    """
    Given a VAE-Transformer model, joint angles and timestamps before and after the missing part, interpolates between these sequences.

    Args:
        model (nn.Module): The VAE-Transformer model.
        joint_angles_before (torch.Tensor): A tensor of joint angles before the missing part.
        joint_angles_after (torch.Tensor): A tensor of joint angles after the missing part.
        timestamps_before (torch.Tensor): A tensor of timestamps before the missing part.
        timestamps_after (torch.Tensor): A tensor of timestamps after the missing part.
        num_interpolations (int): The number of interpolation steps.

    Returns:
        list: A list of interpolated sequences.
    """
    joint_angles_before, joint_angles_after = joint_angles_before.unsqueeze(0), joint_angles_after.unsqueeze(0)
    timestamps_before, timestamps_after = timestamps_before.unsqueeze(0), timestamps_after.unsqueeze(0)
    print(joint_angles_before.shape, joint_angles_after.shape, timestamps_before.shape, timestamps_after.shape)
    # Encode the sequences before and after the missing part
    encoder_outputs_before = model.encoder(joint_angles_before, timestamps_before)
    encoder_outputs_after = model.encoder(joint_angles_after, timestamps_after)

    # Get the latent primitives
    latents_before = encoder_outputs_before["latent_primitives"]
    # TODO: latents_middle !!!
    # TODO: or single latents and sample middle part before
    # TODO: or average latents_before and latents_after, since time is encoded implicitly
    latents_after = encoder_outputs_after["latent_primitives"]
    print(f"{latents_before.shape=}, {latents_after.shape=}")

    interpolated_sequences = []

    timestamps = torch.linspace(0, 1, 128).unsqueeze(0)
    # TODO: How to interpolate between the latent primitives?
    for i in range(num_interpolations):
        alpha = i / (num_interpolations - 1)  # linear interpolation coefficient
        interpolated_latents = alpha * latents_before + (1 - alpha) * latents_after

        # Decode the interpolated latents
        decoder_outputs = model.decoder(timestamps, interpolated_latents)  # assuming timestamps_before and timestamps_after are the same

        # Get the reconstructed sequence
        recons_sequence = decoder_outputs["recons_sequence"]

        interpolated_sequences.append(recons_sequence)

    return interpolated_sequences


In [22]:
poses_before = poses[:59, :]
poses_after = poses[70:, :]
poses_before.shape, poses_after.shape

(torch.Size([59, 3]), torch.Size([58, 3]))

In [23]:
timestamps = item["timestamps"]
timestamps_before = timestamps[:59]
timestamps_after = timestamps[70:]
timestamps_before.shape, timestamps_after.shape

(torch.Size([59]), torch.Size([58]))

In [27]:
interpolate_sequences = interpolate_sequences(model, poses_before, poses_after, timestamps_before, timestamps_after, num_interpolations=10)

torch.Size([1, 59, 3]) torch.Size([1, 58, 3]) torch.Size([1, 59]) torch.Size([1, 58])
pose_embeddings.shape=torch.Size([1, 59, 128])
self.positional_encoding(timestamps).shape=torch.Size([1, 59, 128])
pose_embeddings.shape=torch.Size([1, 58, 128])
self.positional_encoding(timestamps).shape=torch.Size([1, 58, 128])
latents_before.shape=torch.Size([1, 6, 128]), latents_after.shape=torch.Size([1, 6, 128])


In [25]:
torch.stack(interpolate_sequences).shape

torch.Size([10, 1, 128, 3])

In [13]:
imgs = []
for rec in interpolate_sequences:
    rec = unnormalize_pose(rec)
    img = render_image(rec, BONE_LENGTHS)
    imgs.append(img)

Exception: Number of angles and bone_lengths should be the same but: 1 is not 3

In [13]:
poses, timestamps = item["poses"], item["timestamps"]
poses[48:80, :] = 0
timestamps[48:80] = 0
poses, timestamps

(tensor([[0.0711, 0.6657, 0.3384],
         [0.0488, 0.6581, 0.3320],
         [0.0276, 0.6545, 0.3330],
         [0.0076, 0.6548, 0.3411],
         [0.9891, 0.6581, 0.3555],
         [0.9723, 0.6630, 0.3746],
         [0.9573, 0.6678, 0.3969],
         [0.9443, 0.6716, 0.4202],
         [0.9336, 0.6740, 0.4424],
         [0.9252, 0.6754, 0.4617],
         [0.9193, 0.6766, 0.4765],
         [0.9161, 0.6783, 0.4862],
         [0.9154, 0.6811, 0.4908],
         [0.9175, 0.6846, 0.4914],
         [0.9224, 0.6880, 0.4894],
         [0.9300, 0.6899, 0.4863],
         [0.9402, 0.6892, 0.4833],
         [0.9531, 0.6850, 0.4810],
         [0.9685, 0.6773, 0.4790],
         [0.9862, 0.6671, 0.4766],
         [0.0062, 0.6554, 0.4725],
         [0.0281, 0.6435, 0.4659],
         [0.0518, 0.6322, 0.4564],
         [0.0771, 0.6213, 0.4442],
         [0.1036, 0.6102, 0.4303],
         [0.1312, 0.5979, 0.4158],
         [0.1595, 0.5833, 0.4019],
         [0.1883, 0.5663, 0.3897],
         [0.2173, 0.

In [14]:
# poses, timestamps = poses.unsqueeze(0), timestamps.unsqueeze(0)
y_hat = model.infer(poses, timestamps)

timestamps=tensor([[0.0000, 0.0079, 0.0157, 0.0236, 0.0315, 0.0394, 0.0472, 0.0551, 0.0630,
         0.0709, 0.0787, 0.0866, 0.0945, 0.1024, 0.1102, 0.1181, 0.1260, 0.1339,
         0.1417, 0.1496, 0.1575, 0.1654, 0.1732, 0.1811, 0.1890, 0.1969, 0.2047,
         0.2126, 0.2205, 0.2283, 0.2362, 0.2441, 0.2520, 0.2598, 0.2677, 0.2756,
         0.2835, 0.2913, 0.2992, 0.3071, 0.3150, 0.3228, 0.3307, 0.3386, 0.3465,
         0.3543, 0.3622, 0.3701, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6299,
         0.6378, 0.6457, 0.6535, 0.6614, 0.6693, 0.6772, 0.6850, 0.6929, 0.7008,
         0.7087, 0.7165, 0.7244, 0.7323, 0.7402, 0.7480, 0.7559, 0.7638, 0.7717,
         0.7795, 0.7874, 0.7953, 0.8031, 0.8110, 0.8189, 0.8268, 0.8346, 0.8425,
         0.8504, 