In [3]:
import os, sys
sys.path.append("..")
from config.load_config import load_yaml_config, to_dict, recursive_namespace
import torch
import numpy as np
from torch import Tensor
from importlib import reload

import pytorch_lightning as pl

import models.Model3D as Model3D; Model3D = reload(module=Model3D)
import models.Model4D as Model4D; Model4D = reload(module=Model4D)
import models.lightning.EncoderLightningModule as encoder_lightning; EncoderLightningModule = reload(module=EncoderLightningModule)
import models.lightning.DecoderLightningModule as decoder_lightning; decoder_lightning = reload(module=decoder_lightning)

EncoderTemporalSequence = Model4D.EncoderTemporalSequence
DecoderTemporalSequence = Model4D.DecoderTemporalSequence
AutoencoderTemporalSequence = Model4D.AutoencoderTemporalSequence
DecoderStyle = Model4D.DecoderStyle
TemporalEncoderLightning = encoder_lightning.TemporalEncoderLightning
TemporalDecoderLightning = decoder_lightning.TemporalDecoderLightning

## Generate dataset

In [4]:
from data.synthetic.SyntheticMeshPopulation import SyntheticMeshPopulation
from data.SyntheticDataModules import SyntheticMeshesDataset
from torch.utils.data import DataLoader
from utils.helpers import get_datamodule, get_coma_args

config = load_yaml_config("../config_files/config_folded_c_and_s.yaml")
dm = get_datamodule(config)
coma_args = get_coma_args(config, dm)
coma_args["phase_input"] = False

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 640/640 [00:32<00:00, 19.65it/s]


___

### `EncoderTemporalSequence`

*Status: tested and working*

In [None]:
enc_config = {k: v for k,v in coma_args.items() if k in Model4D.ENCODER_ARGS}
cine_encoder = EncoderTemporalSequence(enc_config, z_aggr_function="DFT", n_timeframes=20)

Let's test the model on a single datapoint:

In [None]:
datapoint = next(iter(dm.train_dataloader()))
s_t = datapoint["s_t"]
s_t.shape

In [None]:
mu = cine_encoder(s_t)["mu"]
mu.shape

___

### `DecoderStyle`

*Status: tested and working*

In [None]:
dec_c_config = {k: v for k,v in coma_args.items() if k in Model4D.DECODER_C_ARGS}
dec_s_config = {k: v for k,v in coma_args.items() if k in Model4D.DECODER_S_ARGS}

In [None]:
decoder_s = DecoderStyle(dec_s_config, phase_embedding_method="exp")

In [None]:
zc_dim = dec_s_config["latent_dim_content"]
zs_dim = dec_s_config["latent_dim_style"]
zc = Tensor(np.random.random((config.batch_size, zc_dim)))
zs = Tensor(np.random.random((config.batch_size, zs_dim)))

In [None]:
s_hat = decoder_s(zc, zs, coma_args["n_timeframes"])

### `DecoderTemporalSequence`

In [None]:
decoder_4d = DecoderTemporalSequence(dec_c_config, dec_s_config, phase_embedding_method="exp")

___

### `AutoencoderTemporalSequence`

In [None]:
ae_config = [enc_config, dec_c_config, dec_s_config]

In [None]:
autoencoder_4d = AutoencoderTemporalSequence(*ae_config, z_aggr_function="dft", n_timeframes=coma_args["n_timeframes"])

In [None]:
avg_shat, shat_t = autoencoder_4d(s_t)

In [None]:
print(avg_shat.shape)
print(shat_t.shape)

___

# PyTorch Lightning modules

In [5]:
coma_args

{'num_features': 3,
 'n_layers': 4,
 'num_conv_filters_enc': [16, 16, 32, 32],
 'num_conv_filters_dec_c': [16, 16, 32, 32],
 'num_conv_filters_dec_s': [16, 16, 32, 32],
 'cheb_polynomial_order': [6, 6, 6, 6],
 'latent_dim_content': 9,
 'latent_dim_style': 27,
 'is_variational': True,
 'mode': 'testing',
 'n_timeframes': 20,
 'phase_input': False,
 'z_aggr_function': 'FCN',
 'downsample_matrices': [tensor(indices=tensor([[   0,    1,    2,  ...,  498,  499,  500],
                         [   4,    8,    9,  ...,  998, 1000, 1001]]),
         values=tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                        1., 1., 1., 1., 1., 1., 1., 1., 1., 1

Encoder

Decoder

In [None]:
dec_c_config["latent_dim_content"] = dec_s_config["latent_dim_content"] = 9
dec_s_config["latent_dim_style"] = 36

In [None]:
decoder_4d = DecoderTemporalSequence(dec_c_config, dec_s_config, phase_embedding_method="exp")
dec_pl = TemporalDecoderLightning(decoder_4d, config)

In [None]:
trainer = pl.Trainer()
trainer.fit(dec_pl, dm)

Full autoencoder

In [None]:
# PLEncoder = CineComaEncoder(cine_encoder, config)
# trainer = pl.Trainer()
# trainer.fit(PLEncoder, dm)