In [None]:
%load_ext autoreload
%autoreload 2

import re

import optuna
import torch
from sklearn.decomposition import PCA

from src.data.mocap import MotionCaptureDataset
from src.models.mocap import SphericalVAE, VariationalAutoencoder
from src.utils import plot_3d

In [None]:
# Each hyperparameter search lives in its own SQLite file under ../runs/<study_name>/.
# Open both studies first, then pull out the best trial of each.
vae_study = optuna.load_study(
    study_name="mocap-07-vae",
    storage="sqlite:///../runs/mocap-07-vae/optuna-storage.db",
)
best_vae_trial = vae_study.best_trial

svae_study = optuna.load_study(
    study_name="mocap-07-svae",
    storage="sqlite:///../runs/mocap-07-svae/optuna-storage.db",
)
best_svae_trial = svae_study.best_trial

In [None]:
# Motion-capture subject "07": keep the feature matrix and its width handy for the models below.
dataset = MotionCaptureDataset("07")
n_features, X = dataset.n_features, dataset.X

In [None]:
def layer_sizes_from_params(params):
    """Return the hidden-layer widths stored in an Optuna trial's ``params`` dict.

    Assumes the search space recorded ``n_layers`` together with
    ``layer_size_1`` ... ``layer_size_<n_layers>`` (1-based keys).
    """
    return [params[f"layer_size_{i}"] for i in range(1, params["n_layers"] + 1)]


# Read everything from the single best-trial snapshot taken when the studies were loaded.
# Re-querying ``study.best_trial`` for each field costs a storage round-trip every time.
# It could also mix fields from different trials if the study is still being optimised.
vae_params = best_vae_trial.params
vae_layer_sizes = layer_sizes_from_params(vae_params)
vae_dropout = vae_params["dropout"]
vae_number = best_vae_trial.number  # also names the checkpoint file

svae_params = best_svae_trial.params
svae_layer_sizes = layer_sizes_from_params(svae_params)
svae_dropout = svae_params["dropout"]
svae_number = best_svae_trial.number  # also names the checkpoint file

In [None]:
# Rebuild the spherical VAE with the best trial's architecture; the decoder mirrors the encoder.
svae = SphericalVAE(
    feature_dim=n_features,
    latent_dim=3,
    encoder_params={
        "layer_sizes" : svae_layer_sizes,
        "dropout": svae_dropout,
        "activation_function" : "Tanh"
    },
    decoder_params={
        "layer_sizes" : svae_layer_sizes[::-1],
        "dropout": svae_dropout,
        "activation_function" : "Tanh"
    },
)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine.
svae_state_dict = torch.load(
    f"../runs/mocap-07-svae/checkpoints/{svae_number:03}.pt",
    map_location="cpu",
)
svae.load_state_dict(svae_state_dict)
# This notebook only runs inference: disable dropout so embeddings are not randomly perturbed.
svae.eval();

In [None]:
# Rebuild the Gaussian VAE with the best trial's architecture; the decoder mirrors the encoder.
vae = VariationalAutoencoder(
    feature_dim=n_features,
    latent_dim=3,
    encoder_params={
        "layer_sizes" : vae_layer_sizes,
        "dropout": vae_dropout,
        "activation_function" : "Tanh"
    },
    decoder_params={
        "layer_sizes" : vae_layer_sizes[::-1],
        "dropout": vae_dropout,
        "activation_function" : "Tanh"
    },
)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine.
vae_state_dict = torch.load(
    f"../runs/mocap-07-vae/checkpoints/{vae_number:03}.pt",
    map_location="cpu",
)
vae.load_state_dict(vae_state_dict)
# This notebook only runs inference: disable dropout so embeddings are not randomly perturbed.
vae.eval();

In [None]:
# Class id = the digits between "_" and ":" in each label string.
# Raw string: "\d" in a plain literal is an invalid escape and warns on modern Python.
pattern = re.compile(r"_(\d+):")


def _label_to_class(label):
    """Extract the integer class id from a dataset label, failing loudly if it is missing."""
    match = pattern.search(label)
    if match is None:
        raise ValueError(f"label {label!r} has no '_<digits>:' class tag")
    return int(match.group(1))


classes = [_label_to_class(label) for label in dataset.labels]

In [None]:
# Linear baseline: the first three principal components of the raw features, coloured by class.
pca = PCA(n_components=3)
pca_coords = pca.fit_transform(X)
x, y, z = pca_coords.T
fig = plot_3d(x, y, z, classes=classes)
fig.show()

In [None]:
# Embed the whole dataset with the trained VAE and plot its 3-D latent codes.
# eval() disables dropout; no_grad() avoids building an autograd graph we never use.
vae.eval()
# NOTE(review): output["z"] may be a reparameterised sample — seed so Run All reproduces the figure.
torch.manual_seed(0)
with torch.no_grad():
    output = vae(X)
plot_3d(*output["z"].T.numpy(), classes=classes)

In [None]:
# Embed the whole dataset with the trained spherical VAE and plot its 3-D latent codes.
# eval() disables dropout; no_grad() avoids building an autograd graph we never use.
svae.eval()
# NOTE(review): output["z"] may be a sampled latent — seed so Run All reproduces the figure.
torch.manual_seed(0)
with torch.no_grad():
    output = svae(X)
plot_3d(*output["z"].T.numpy(), classes=classes)