In [37]:
import os, sys
sys.path.append("..")
from config.load_config import load_yaml_config, to_dict, recursive_namespace
import torch
import numpy as np
from torch import Tensor
from importlib import reload

import pytorch_lightning as pl

# Import each project module and immediately reload it, so on-disk edits are
# picked up without restarting the kernel. Model3D is reloaded before Model4D;
# NOTE(review): presumably Model4D depends on Model3D, so this order lets its
# re-execution see the fresh Model3D — confirm.
import models.Model3D as Model3D
Model3D = reload(Model3D)

import models.Model4D as Model4D
Model4D = reload(Model4D)

import models.lightning.EncoderLightningModule as EncoderLightningModule
EncoderLightningModule = reload(EncoderLightningModule)

# Short aliases for the classes exercised in the cells below.
EncoderTemporalSequence = Model4D.EncoderTemporalSequence
CineComaEncoder = EncoderLightningModule.CineComaEncoder
(DecoderStyle, DecoderTemporalSequence, AutoencoderTemporalSequence) = (
    Model4D.DecoderStyle,
    Model4D.DecoderTemporalSequence,
    Model4D.AutoencoderTemporalSequence,
)

## Generate dataset

In [38]:
from data.synthetic.SyntheticMeshPopulation import SyntheticMeshPopulation
from data.SyntheticDataModules import SyntheticMeshesDataset
from torch.utils.data import DataLoader
from utils.helpers import get_datamodule, get_coma_args

# Load the folded content/style experiment config, build the synthetic data
# module from it, and derive the CoMA model arguments.
CONFIG_PATH = "../config_files/config_folded_c_and_s.yaml"
config = load_yaml_config(CONFIG_PATH)
dm = get_datamodule(config)
coma_args = get_coma_args(config, dm)
# The models built in this notebook are run without a phase input.
coma_args.update(phase_input=False)

Retrieving synthetic population from cached file.


___

### `EncoderTemporalSequence`

*Status: tested and working*

In [5]:
# Keep only the coma_args entries accepted by the temporal encoder.
enc_config = {k: v for k, v in coma_args.items() if k in Model4D.ENCODER_ARGS}
# Take the timeframe count from the config (as the decoder/autoencoder cells do)
# instead of a hard-coded 20 that can silently drift from the data.
# NOTE(review): the autoencoder cell spells the aggregation "dft"; both spellings
# ran, so the option is presumably case-insensitive — confirm and unify.
cine_encoder = EncoderTemporalSequence(
    enc_config,
    z_aggr_function="DFT",
    n_timeframes=coma_args["n_timeframes"],
)

Let's test the model on a single batch from the training data loader:

In [6]:
# Pull the first batch from the training loader and inspect the shape sequence.
# Per the recorded output this is (16, 20, 1002, 3); presumably
# (batch, timeframe, vertex, xyz) — TODO confirm axis meanings.
train_loader = dm.train_dataloader()
datapoint = next(iter(train_loader))
s_t = datapoint["s_t"]
s_t.shape

torch.Size([16, 20, 1002, 3])

In [7]:
# Run the encoder on the batch and keep only the "mu" entry of its output dict.
encoder_out = cine_encoder(s_t)
mu = encoder_out["mu"]
mu.shape

torch.Size([16, 8])

___

### `DecoderStyle`

*Status: tested and working*

In [15]:
# Split coma_args into the argument subsets accepted by the content (c) and
# style (s) decoders.
dec_c_config = dict((key, val) for key, val in coma_args.items() if key in Model4D.DECODER_C_ARGS)
dec_s_config = dict((key, val) for key, val in coma_args.items() if key in Model4D.DECODER_S_ARGS)

In [16]:
# Style decoder using the "exp" phase embedding.
decoder_s = DecoderStyle(
    dec_s_config,
    phase_embedding_method="exp",
)

In [17]:
# Random content (zc) and style (zs) latent codes, one row per batch element,
# with widths taken from the style-decoder config.
zc_dim = dec_s_config["latent_dim_content"]
zs_dim = dec_s_config["latent_dim_style"]
# A locally seeded generator makes these latents reproducible under
# Restart & Run All without reseeding the global torch/numpy RNGs.
# torch.rand gives float32 uniform [0, 1) values directly, matching the old
# numpy-float64 -> Tensor conversion without the round-trip.
latent_rng = torch.Generator().manual_seed(0)
zc = torch.rand(config.batch_size, zc_dim, generator=latent_rng)
zs = torch.rand(config.batch_size, zs_dim, generator=latent_rng)

In [18]:
# Decode the random latents over the configured number of timeframes.
# NOTE(review): DecoderStyle's return type is not visible here; presumably a tensor.
n_timeframes = coma_args["n_timeframes"]
s_hat = decoder_s(zc, zs, n_timeframes)

### `DecoderTemporalSequence`

In [19]:
# Temporal decoder combining the content and style decoder configs.
decoder_4d = DecoderTemporalSequence(
    dec_c_config,
    dec_s_config,
    phase_embedding_method="exp",
)

___

### `AutoencoderTemporalSequence`

In [19]:
# Configs unpacked positionally into AutoencoderTemporalSequence, in this order:
# encoder, content decoder, style decoder.
ae_config = [
    enc_config,
    dec_c_config,
    dec_s_config,
]

In [20]:
# Full temporal autoencoder with DFT latent aggregation over the configured timeframes.
autoencoder_4d = AutoencoderTemporalSequence(
    *ae_config,
    z_aggr_function="dft",
    n_timeframes=coma_args["n_timeframes"],
)

In [21]:
# Forward pass; the model returns a pair: judging by the names and recorded
# shapes, an averaged reconstruction and a per-timeframe reconstruction.
ae_output = autoencoder_4d(s_t)
avg_shat, shat_t = ae_output

In [22]:
# Show the reconstruction shapes and pin them against the input so a shape
# regression fails loudly instead of relying on eyeballing the printout.
print(avg_shat.shape)
print(shat_t.shape)
# Per-timeframe reconstruction has exactly the input's shape.
assert shat_t.shape == s_t.shape, (shat_t.shape, s_t.shape)
# Averaged reconstruction drops the timeframe axis (dim 1), as in the recorded output.
assert avg_shat.shape == (s_t.shape[0], *s_t.shape[2:]), (avg_shat.shape, s_t.shape)

torch.Size([16, 1002, 3])
torch.Size([16, 20, 1002, 3])


___

## PyTorch Lightning modules

### Encoder

### Decoder

In [58]:
# Import and reload the decoder Lightning module so on-disk edits take effect
# without a kernel restart, then alias the temporal decoder wrapper.
import models.lightning.DecoderLightningModule as decoder_lightning
decoder_lightning = reload(decoder_lightning)

TemporalDecoderLightning = decoder_lightning.TemporalDecoderLightning

In [63]:
# Latent sizes for the Lightning decoder run. The content size must agree
# between the content and style decoder configs, so it is a single constant.
LATENT_DIM_CONTENT = 9
LATENT_DIM_STYLE = 36
# Build new dicts rather than mutating in place: the previous config dicts are
# also referenced by `ae_config`, and in-place edits would silently change them
# there as well (hidden state on re-run).
dec_c_config = {**dec_c_config, "latent_dim_content": LATENT_DIM_CONTENT}
dec_s_config = {
    **dec_s_config,
    "latent_dim_content": LATENT_DIM_CONTENT,
    "latent_dim_style": LATENT_DIM_STYLE,
}

In [64]:
# Rebuild the temporal decoder from the current decoder configs and wrap it in
# its Lightning module for training.
decoder_4d = DecoderTemporalSequence(
    dec_c_config,
    dec_s_config,
    phase_embedding_method="exp",
)
dec_pl = TemporalDecoderLightning(decoder_4d, config)

In [65]:
# Train the decoder-only Lightning module on the synthetic data module.
# NOTE(review): no max_epochs is set, so Lightning's default epoch budget applies.
trainer = pl.Trainer()
trainer.fit(dec_pl, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_deprecation(

  | Name  | Type                    | Params
--------------------------------------------------
0 | model | DecoderTemporalSequence | 207 K 
--------------------------------------------------
207 K     Trainable params
0         Non-trainable params
207 K     Total params
0.831     Total estimated model params size (MB)


A Jupyter Widget
Python 3.9.12 | packaged by conda-forge | (main, Mar 24 2022, 23:25:59) 
Type 'copyright', 'credits' or 'license' for more information
IPython 8.2.0 -- An enhanced Interactive Python. Type '?' for help.



In [1]:  quit()



Python 3.9.12 | packaged by conda-forge | (main, Mar 24 2022, 23:25:59) 
Type 'copyright', 'credits' or 'license' for more information
IPython 8.2.0 -- An enhanced Interactive Python. Type '?' for help.



In [1]:  quit()



A Jupyter Widget


  rank_zero_warn(


Python 3.9.12 | packaged by conda-forge | (main, Mar 24 2022, 23:25:59) 
Type 'copyright', 'credits' or 'license' for more information
IPython 8.2.0 -- An enhanced Interactive Python. Type '?' for help.



In [1]:  quit()





  rank_zero_deprecation(


Python 3.9.12 | packaged by conda-forge | (main, Mar 24 2022, 23:25:59) 
Type 'copyright', 'credits' or 'license' for more information
IPython 8.2.0 -- An enhanced Interactive Python. Type '?' for help.



  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


### Full autoencoder

In [17]:
# Encoder-only Lightning training. This was previously left as commented-out
# code; it is now behind an explicit flag, disabled by default. Set
# TRAIN_ENCODER = True to run it.
TRAIN_ENCODER = False
if TRAIN_ENCODER:
    PLEncoder = CineComaEncoder(cine_encoder, config)
    trainer = pl.Trainer()
    trainer.fit(PLEncoder, dm)