In [1]:
import os, sys
sys.path.append("..")
from config.load_config import load_yaml_config, to_dict, recursive_namespace
import torch
import numpy as np
from torch import Tensor
from importlib import reload

import pytorch_lightning as pl

# Import each project module and reload it right away, so the module code is
# re-executed even if it was already imported in this kernel session.
import models.Model3D as Model3D
Model3D = reload(Model3D)

import models.Model4D as Model4D
Model4D = reload(Model4D)

import models.lightning.EncoderLightningModule as EncoderLightningModule
EncoderLightningModule = reload(EncoderLightningModule)

# Short aliases for the classes exercised in the cells below.
CineComaEncoder = EncoderLightningModule.CineComaEncoder
EncoderTemporalSequence = Model4D.EncoderTemporalSequence
DecoderStyle = Model4D.DecoderStyle
DecoderTemporalSequence = Model4D.DecoderTemporalSequence
AutoencoderTemporalSequence = Model4D.AutoencoderTemporalSequence

## Generate dataset

In [2]:
from data.synthetic.SyntheticMeshPopulation import SyntheticMeshPopulation
from data.SyntheticDataModules import SyntheticMeshesDataset
from torch.utils.data import DataLoader
from utils.helpers import get_datamodule, get_coma_args

# Load the experiment configuration from YAML (path is relative to this notebook).
config = load_yaml_config("../config_files/config_folded_c_and_s.yaml")
# Build the data module described by the config.
dm = get_datamodule(config)
# Derive the model arguments from the config and the data module.
coma_args = get_coma_args(config, dm)
# Override: phase input is switched off for every model built in this notebook.
coma_args["phase_input"] = False

Retrieving synthetic population from cached file.


___

### `EncoderTemporalSequence`

*Status: tested and working*

In [3]:
# Keep only the entries of coma_args whose key is listed in Model4D.ENCODER_ARGS.
enc_config = {
    name: value
    for name, value in coma_args.items()
    if name in Model4D.ENCODER_ARGS
}

# The sequence length is read from coma_args, as the decoder and autoencoder
# cells do, rather than hard-coding 20, so the encoder follows the configuration.
cine_encoder = EncoderTemporalSequence(
    enc_config,
    z_aggr_function="DFT",
    n_timeframes=coma_args["n_timeframes"],
)

Let's test the model on a single datapoint:

In [4]:
# Pull the first batch from the training dataloader.
datapoint = next(iter(dm.train_dataloader()))
# The batch is indexed by key; "s_t" holds the mesh sequence tensor.
# Recorded output shows shape (16, 20, 1002, 3) — presumably
# (batch, timeframes, vertices, xyz); TODO confirm axis meaning.
s_t = datapoint["s_t"]
s_t.shape

torch.Size([16, 20, 1002, 3])

In [None]:
# Run the encoder on the batch and keep the "mu" entry of its output.
# NOTE(review): this cell has no execution count in the saved notebook.
mu = cine_encoder(s_t)["mu"]
mu.shape

___

### `DecoderStyle`

*Status: tested and working*

In [5]:
# Split coma_args into the two decoder configurations by filtering on the
# key lists declared in Model4D.
dec_c_config = {
    name: value
    for name, value in coma_args.items()
    if name in Model4D.DECODER_C_ARGS
}
dec_s_config = {
    name: value
    for name, value in coma_args.items()
    if name in Model4D.DECODER_S_ARGS
}

In [6]:
# Instantiate the style decoder from its config, using the "exp" phase embedding.
decoder_s = DecoderStyle(dec_s_config, phase_embedding_method="exp")

In [7]:
# Latent sizes are read from the style-decoder configuration.
zc_dim = dec_s_config["latent_dim_content"]
zs_dim = dec_s_config["latent_dim_style"]

# Draw uniform [0, 1) latent codes directly as torch tensors. A seeded local
# generator makes the cell reproducible across kernel restarts without
# touching the global RNG state used for model initialisation.
latent_rng = torch.Generator().manual_seed(0)
zc = torch.rand(config.batch_size, zc_dim, generator=latent_rng)
zs = torch.rand(config.batch_size, zs_dim, generator=latent_rng)

In [8]:
# Decode the random content/style codes; the third argument is the number of
# timeframes taken from coma_args.
s_hat = decoder_s(zc, zs, coma_args["n_timeframes"])

### `DecoderTemporalSequence`

In [9]:
# Build the temporal-sequence decoder from the content and style decoder configs.
decoder_4d = DecoderTemporalSequence(dec_c_config, dec_s_config, phase_embedding_method="exp")

___

### `AutoencoderTemporalSequence`

In [None]:
# Positional configs for the autoencoder, in the order encoder, content decoder, style decoder.
ae_config = [enc_config, dec_c_config, dec_s_config]

In [None]:
# Build the full autoencoder from the three configs collected in ae_config.
# The aggregation name is spelled "DFT", matching the EncoderTemporalSequence
# cell above, so both models are configured with the same string.
# NOTE(review): confirm the casing AutoencoderTemporalSequence expects.
autoencoder_4d = AutoencoderTemporalSequence(
    *ae_config,
    z_aggr_function="DFT",
    n_timeframes=coma_args["n_timeframes"],
)

In [None]:
# Forward pass of the autoencoder on the sample batch; the result unpacks into two values.
# Presumably avg_shat is a time-averaged reconstruction and shat_t the per-timeframe one — TODO confirm.
avg_shat, shat_t = autoencoder_4d(s_t)

In [None]:
# Show the shapes of both autoencoder outputs, one per line.
print(avg_shat.shape, shat_t.shape, sep="\n")

___

# PyTorch Lightning modules

## Encoder

## Decoder

In [12]:
# Import the decoder Lightning module and reload it so the module code is
# re-executed even if it was already imported in this kernel session.
import models.lightning.DecoderLightningModule as decoder_lightning
decoder_lightning = reload(decoder_lightning)

TemporalDecoderLightning = decoder_lightning.TemporalDecoderLightning

In [13]:
# Override the latent sizes in place. Both decoder configs must agree on the
# content dimension. Note that this mutates the dicts created earlier, so any
# earlier cell re-run after this one will see the new values.
dec_c_config["latent_dim_content"] = 9
dec_s_config["latent_dim_content"] = 9
dec_s_config["latent_dim_style"] = 36

In [15]:
# Rebuild the temporal decoder from the current contents of the two decoder
# configs; this rebinds the decoder_4d name assigned in the
# DecoderTemporalSequence section.
decoder_4d = DecoderTemporalSequence(dec_c_config, dec_s_config, phase_embedding_method="exp")
# Wrap the decoder in its Lightning module together with the experiment config.
dec_pl = TemporalDecoderLightning(decoder_4d, config)

In [None]:
# Train the decoder Lightning module on the data module with a default-configured Trainer.
# NOTE(review): no Trainer arguments (epoch limit, accelerator, logger) are set here,
# and this cell has no execution count in the saved notebook — confirm intended settings.
trainer = pl.Trainer()
trainer.fit(dec_pl, dm)

## Full autoencoder

In [None]:
# Disabled: the lines below are left commented out and are not executed.
# NOTE(review): they wrap cine_encoder in CineComaEncoder, i.e. an encoder
# rather than a full autoencoder — confirm what this section should train.
# PLEncoder = CineComaEncoder(cine_encoder, config)
# trainer = pl.Trainer()
# trainer.fit(PLEncoder, dm)