In [8]:
"""In this file, we reproduce the results of the CMVAE model on the PolyMNIST dataset."""

from CMVAE_architectures import Dec, Enc, load_mmnist_classifiers
import torch

from multivae.data.datasets.mmnist import MMNISTDataset
from multivae.models.cmvae import CMVAE, CMVAEConfig
from multivae.trainers.base import BaseTrainer, BaseTrainerConfig
from multivae.trainers.base.callbacks import ProgressBarCallback, WandbCallback
from pythae.data.datasets import DatasetOutput

###### Set the paths for loading and saving ######
# The defaults below are the hard-coded paths this notebook was run with.
# Set CMVAE_DATA_PATH / CMVAE_SAVE_PATH in the environment to point the
# notebook elsewhere without editing (and committing) user-specific paths.
import os

DATA_PATH = os.environ.get("CMVAE_DATA_PATH", "/home/bigdata/siyi/data")
SAVE_PATH = os.environ.get("CMVAE_SAVE_PATH", "/home/siyi/project/mm/result/runs")

In [2]:
###### Define model configuration ########
# DATA_PATH / SAVE_PATH are configured once, in the first cell. They are
# deliberately NOT reassigned here: a second assignment in this cell would
# silently override the first cell's values whenever the notebook is re-run
# top to bottom (Restart & Run All).
modalities = ["m0", "m1", "m2", "m3", "m4"]

# CMVAE hyper-parameters for the PolyMNIST modalities. n_modalities is derived
# from `modalities` so the count and the list of names cannot disagree.
model_config = CMVAEConfig(
    n_modalities=len(modalities),
    K=1,
    decoders_dist={m: "laplace" for m in modalities},
    decoder_dist_params={m: dict(scale=0.75) for m in modalities},
    prior_and_posterior_dist="laplace_with_softmax",
    beta=2.5,
    modalities_specific_dim=32,
    latent_dim=32,
    input_dims={m: (3, 28, 28) for m in modalities},
    learn_modality_prior=True,
    number_of_clusters=40,
    loss="iwae_looser",
)

# One encoder/decoder pair per modality. Encoders are built from
# modalities_specific_dim with ndim_u=latent_dim; decoders are sized for an
# input of latent_dim + modalities_specific_dim.
encoders = {
    m: Enc(model_config.modalities_specific_dim, ndim_u=model_config.latent_dim)
    for m in modalities
}
decoders = {
    m: Dec(model_config.latent_dim + model_config.modalities_specific_dim)
    for m in modalities
}

# Fresh model instance; trained weights are loaded into it further down.
model = CMVAE(model_config, encoders, decoders)

In [3]:
# PolyMNIST (MMNIST) dataset restricted to the modalities listed in `modalities`;
# presumably the training split by default of MMNISTDataset — confirm.
# NOTE(review): this root ("/bigdata/siyi/data") differs from the DATA_PATH set
# in the first cell ("/home/bigdata/siyi/data"); confirm which one is intended
# and reuse DATA_PATH here once they agree.
train_data = MMNISTDataset(
    "/bigdata/siyi/data",
    modalities,
)

In [5]:
# Trained CMVAE checkpoint produced by the training run.
CHECKPOINT_PATH = "/home/siyi/project/mm/result/Dynamic_project/PM51/reproduce_cmvae/K__1/CMVAE_training_2025-09-01_16-26-14/final_model/model.pt"

# map_location="cpu": load tensors onto the CPU regardless of the device they
# were saved from, so this cell does not require a GPU (the model built above
# lives on the CPU).
# weights_only=True: restrict unpickling to tensors and plain containers instead
# of executing arbitrary pickled code (this is already the default in recent
# PyTorch releases).
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)

In [6]:
# Copy the checkpoint weights into the model (strict key matching by default).
load_report = model.load_state_dict(state_dict['model_state_dict'])

# The cells below only run inference, so switch layers such as dropout or
# batch-norm (if the encoders/decoders use any) to evaluation behaviour.
model.eval()

# Keep the load report as the last expression so the cell still displays which
# keys matched.
load_report

<All keys matched successfully>

In [11]:
# Encode the first training sample conditioned on modalities m1 and m2, with
# m0, m3 and m4 as the modalities to generate.
cond_mods_name = ['m1', 'm2']
gen_mods_name = ['m0', 'm3', 'm4']

# Fetch the sample once, and build the conditioning dict from `cond_mods_name`
# so its keys and the `cond_mod` argument cannot drift apart.
# unsqueeze(0) prepends a leading dimension of size 1 (a batch of one sample).
first_sample = train_data[0]
cond_mods = {m: first_sample['data'][m].unsqueeze(0) for m in cond_mods_name}

# Inference only: disable autograd so no computation graph is recorded for the
# returned latents.
with torch.no_grad():
    out = model.encode(
        inputs=DatasetOutput(data=cond_mods),
        cond_mod=cond_mods_name,
        gen_mod=gen_mods_name,
    )

In [10]:
# Display the conditional encoding result (`out`) produced by the encode cell.
# NOTE(review): the execution counts show this cell (In [10]) ran before the
# encode cell (In [11]); re-run the notebook in order so the displayed output
# matches the code above.
out

ModelOutput([('z',
              tensor([[-1.3466, -4.7154,  0.1109,  0.6420, -0.0188, -3.4930, -1.4929,  0.7362,
                        2.2687, -1.4370, -1.5727, -1.4706, -0.7566, -4.5920,  1.8497, -1.8275,
                        2.0954,  2.2184,  4.6179,  0.4313,  4.8770,  3.3431,  0.9605, -0.9490,
                        0.5719, -1.7380,  0.8560, -6.7802,  0.1557, -2.8408, -0.0209,  0.7917]],
                     grad_fn=<SubBackward0>)),
             ('one_latent_space', False),
             ('modalities_z',
              {'m0': tensor([[ 1.2942e-01, -6.9427e-01, -1.3979e+00,  2.6601e-01, -4.2426e+00,
                         1.1081e-01,  7.0461e-01, -1.4765e+00,  1.4961e+00, -6.1687e-01,
                        -6.1294e-01, -9.0049e-01, -8.0148e-02, -2.9790e-01, -2.8714e-01,
                        -1.1054e-01,  1.5925e-01, -6.0271e+00,  5.9336e-01, -6.3538e-01,
                         1.7974e-03,  3.9869e+00, -1.2088e-01,  1.7324e+00, -9.1827e-01,
                        -7.10