In [1]:
from pathlib import Path

import torch
from plotly.subplots import make_subplots
from sklearn.manifold import MDS

from src.data import SyntheticS2
from src.models import SphericalVAE, VariationalAutoencoder
from src.utils import plot_3d

# --- Configuration ---------------------------------------------------------
MODELS_DIR = Path("../models")  # pre-trained checkpoints, relative to this notebook
LATENT_DIM = 3                  # both models embed into R^3 so codes can be plotted directly
SEED = 0

# Seed before building data/models so Restart & Run All is reproducible.
# NOTE(review): assumes SyntheticS2 / model init draw from torch's RNG — TODO confirm.
torch.manual_seed(SEED)

# --- Data --------------------------------------------------------------------
synthetic_s2 = SyntheticS2(test=True)

X_latent = synthetic_s2.X_latent
X = synthetic_s2.X
classes = synthetic_s2.y

feature_dim = X.shape[-1]

# 2x2 grid of 3-D scatters; later cells fill the remaining panels.
fig = make_subplots(
    rows=2,
    cols=2,
    specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}],
           [{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
    subplot_titles=(
        "True latent representation",
        "MDS of synth. data",
        "Ordinary VAE",
        "Spherical VAE"),
    horizontal_spacing=0.01,
    vertical_spacing=0.01,
)


# --- Models --------------------------------------------------------------------
def load_pretrained(model, filename):
    """Cast `model` to float64, load weights from MODELS_DIR / filename, return it in eval mode.

    Args:
        model: a freshly constructed torch.nn.Module.
        filename: checkpoint file name inside MODELS_DIR (a saved state_dict).

    Returns:
        The same module, with loaded weights, switched to eval mode.
    """
    model.to(torch.double)
    # The models run on CPU in this notebook; map_location keeps a checkpoint
    # saved from a GPU run from trying to restore onto CUDA.
    state_dict = torch.load(MODELS_DIR / filename, map_location="cpu")
    model.load_state_dict(state_dict)
    # eval() disables train-only behaviour (dropout, batch-norm updates) if present.
    return model.eval()


svae = load_pretrained(SphericalVAE(feature_dim=feature_dim, latent_dim=LATENT_DIM), "svae.pt")
vae = load_pretrained(VariationalAutoencoder(feature_dim, LATENT_DIM), "vae.pt")

# Panel (1, 1): ground-truth latent coordinates.
plot_3d(*X_latent.T, classes, fig=fig, row=1, col=1)

# MDS baseline on the synthetic data

In [2]:
# Classical baseline: metric MDS from the observed features straight into R^3.
# random_state pins the SMACOF initialisation so this panel is identical on every run.
mds = MDS(n_components=3, random_state=0)
X_transformed = mds.fit_transform(X.numpy())
# Panel (1, 2): MDS embedding.
plot_3d(*X_transformed.T, classes, fig=fig, row=1, col=2)

# Spherical VAE

In [3]:
# Encode the data with the spherical VAE and plot its latent codes.
# eval() here keeps the cell correct even if run out of order; no_grad() because
# this is pure inference and no autograd graph is needed.
svae.eval()
# NOTE(review): output["z"] is presumably a posterior sample — seed so re-runs match.
torch.manual_seed(0)
with torch.no_grad():
    svae_output = svae(X)
# Panel (2, 2): spherical VAE latent space.
plot_3d(*svae_output["z"].T.detach().numpy(), classes, fig=fig, row=2, col=2)

# Ordinary VAE

In [4]:
# Encode the data with the ordinary (Gaussian) VAE and plot its latent codes.
# eval() here keeps the cell correct even if run out of order; no_grad() because
# this is pure inference and no autograd graph is needed.
vae.eval()
# NOTE(review): output["z"] is presumably a posterior sample — seed so re-runs match.
torch.manual_seed(0)
with torch.no_grad():
    vae_output = vae(X)
# Panel (2, 1): ordinary VAE latent space.
plot_3d(*vae_output["z"].T.detach().numpy(), classes, fig=fig, row=2, col=1)

# All four embeddings side by side

In [5]:
# Size the 2x2 comparison as a square and display it (update_layout returns the figure).
fig.update_layout(height=800, width=800)