In [1]:
import os, sys
sys.path.append("..")
from config.load_config import load_yaml_config, to_dict, recursive_namespace
import torch
import numpy as np
from torch import Tensor
from importlib import reload

import pytorch_lightning as pl

# Import each project model module and immediately force-reload it, so edits to the
# .py files are picked up without restarting the kernel. Order is deliberate:
# Model3D first (presumably a dependency of Model4D — confirm), then Model4D, then
# EncoderPLModule, each reloaded right after its own import.
import models.Model3D as Model3D
Model3D = reload(module=Model3D)

import models.Model4D as Model4D
Model4D = reload(module=Model4D)

import models.EncoderPLModule as EncoderPLModule
EncoderPLModule = reload(module=EncoderPLModule)

# Short aliases for the classes exercised in the rest of the notebook
# (looked up on the freshly reloaded modules).
EncoderTemporalSequence = Model4D.EncoderTemporalSequence
DecoderStyle = Model4D.DecoderStyle
DecoderTemporalSequence = Model4D.DecoderTemporalSequence
AutoencoderTemporalSequence = Model4D.AutoencoderTemporalSequence
CineComaEncoder = EncoderPLModule.CineComaEncoder

### Generate (or load cached) dataset

In [18]:
from data.synthetic.SyntheticMeshPopulation import SyntheticMeshPopulation
from data.SyntheticDataModules import SyntheticMeshesDataset
from torch.utils.data import DataLoader
from utils.helpers import get_datamodule, get_coma_args

# Path is relative to the notebook's working directory.
CONFIG_PATH = "../config_files/config_folded_c_and_s.yaml"
config = load_yaml_config(CONFIG_PATH)

dm = get_datamodule(config)

# Model constructor arguments derived from the config and the datamodule; the
# encoder/decoder cells below filter this dict down to what each model accepts.
coma_args = get_coma_args(config, dm)
# Override: the models in this notebook are built without phase input.
coma_args["phase_input"] = False

Retrieving synthetic population from cached file.


___

### `EncoderTemporalSequence`

*Status: tested and working*

In [5]:
# Keep only the coma_args entries that the encoder constructor accepts.
enc_config = {k: v for k, v in coma_args.items() if k in Model4D.ENCODER_ARGS}

# n_timeframes comes from the config rather than a hard-coded 20, so the encoder
# stays in sync with the data and with the decoder/autoencoder cells, which already
# use coma_args["n_timeframes"].
# NOTE(review): the autoencoder cell spells the aggregation "dft" while this one uses
# "DFT"; both ran, so the model presumably matches case-insensitively — confirm and
# pick one spelling.
cine_encoder = EncoderTemporalSequence(
    enc_config,
    z_aggr_function="DFT",
    n_timeframes=coma_args["n_timeframes"],
)

Let's test the encoder on a single batch from the training dataloader:

In [6]:
# Pull the first batch from the training dataloader and inspect its "s_t" tensor.
train_loader = dm.train_dataloader()
datapoint = next(iter(train_loader))
s_t = datapoint["s_t"]
s_t.shape

torch.Size([16, 20, 1002, 3])

In [7]:
# The encoder returns a dict; keep only its "mu" entry and show its shape.
encoder_output = cine_encoder(s_t)
mu = encoder_output["mu"]
mu.shape

torch.Size([16, 8])

___

### `DecoderStyle`

*Status: tested and working*

In [8]:
# Split coma_args into the argument subsets accepted by each decoder.
dec_c_config = {
    name: value
    for name, value in coma_args.items()
    if name in Model4D.DECODER_C_ARGS
}
dec_s_config = {
    name: value
    for name, value in coma_args.items()
    if name in Model4D.DECODER_S_ARGS
}

In [9]:
# DecoderStyle built from its filtered config, using the "exp" phase-embedding method.
decoder_s = DecoderStyle(
    dec_s_config,
    phase_embedding_method="exp",
)

In [10]:
# Random latent codes to feed the decoder: zc (content) and zs (style), one row per
# batch element, with widths taken from the style-decoder config.
zc_dim = dec_s_config["latent_dim_content"]
zs_dim = dec_s_config["latent_dim_style"]

# Draw directly with torch from a seeded generator so the decoder cells are
# reproducible on Restart & Run All. torch.rand gives float32 samples in [0, 1),
# the same distribution and dtype as the previous numpy -> Tensor conversion,
# without the float64 intermediate.
latent_rng = torch.Generator().manual_seed(0)
zc = torch.rand(config.batch_size, zc_dim, generator=latent_rng)
zs = torch.rand(config.batch_size, zs_dim, generator=latent_rng)

In [11]:
# Decode the random latents, passing the number of timeframes from the config.
n_timeframes = coma_args["n_timeframes"]
s_hat = decoder_s(zc, zs, n_timeframes)

### `DecoderTemporalSequence`

*Status: instantiated only; forward pass not yet tested*

In [12]:
# Temporal-sequence decoder built from both decoder configs ("exp" phase embedding).
decoder_4d = DecoderTemporalSequence(
    dec_c_config,
    dec_s_config,
    phase_embedding_method="exp",
)

___

### `AutoencoderTemporalSequence`

In [19]:
# Positional constructor arguments for AutoencoderTemporalSequence, in order:
# encoder config, DECODER_C_ARGS config, DECODER_S_ARGS config.
ae_config = [
    enc_config,
    dec_c_config,
    dec_s_config,
]

In [20]:
# Full autoencoder; z_aggr_function presumably selects how latents are aggregated
# across time (here "dft") — confirm against Model4D.
autoencoder_4d = AutoencoderTemporalSequence(
    *ae_config,
    z_aggr_function="dft",
    n_timeframes=coma_args["n_timeframes"],
)

In [21]:
# Forward pass on the batch; the autoencoder returns a pair of outputs.
reconstruction = autoencoder_4d(s_t)
avg_shat, shat_t = reconstruction

In [22]:
print(avg_shat.shape)
print(shat_t.shape)

# Turn the visual check into a real one (it fails loudly on Restart & Run All if the
# model changes):
# - the per-frame output has exactly the input's shape;
# - avg_shat has the input's shape with dim 1 (the timeframe axis) removed.
assert shat_t.shape == s_t.shape, (shat_t.shape, s_t.shape)
assert tuple(avg_shat.shape) == tuple(s_t.shape[:1]) + tuple(s_t.shape[2:]), (
    avg_shat.shape,
    s_t.shape,
)

torch.Size([16, 1002, 3])
torch.Size([16, 20, 1002, 3])


___

# PyTorch Lightning module

*Status: not run — the training cell below is commented out*

In [17]:
# Disabled: this cell does not run anything as written.
# Uncomment to wrap `cine_encoder` in the `CineComaEncoder` Lightning module and fit it on `dm`.
# NOTE(review): `pl.Trainer()` is constructed with all library defaults (no epoch limit,
# accelerator, or logger set here) — confirm the intended settings before enabling.
# PLEncoder = CineComaEncoder(cine_encoder, config)
# trainer = pl.Trainer()
# trainer.fit(PLEncoder, dm)