In [146]:
import os, sys
sys.path.append("..")
from config.load_config import load_yaml_config, to_dict, recursive_namespace
import torch
import numpy as np
from torch import Tensor
from importlib import reload

import pytorch_lightning as pl

### Generate dataset

In [147]:
from data.synthetic.SyntheticMeshPopulation import SyntheticMeshPopulation
from data.SyntheticDataModules import SyntheticMeshesDataset
from torch.utils.data import DataLoader
from utils.helpers import get_datamodule, get_coma_args

In [148]:
# Load the experiment configuration, then derive the data module and the
# CoMA argument mapping from it.
CONFIG_PATH = "../config_files/config_folded_c_and_s.yaml"
config = load_yaml_config(CONFIG_PATH)
dm = get_datamodule(config)
coma_args = get_coma_args(config, dm)

# Override the "phase_input" entry for everything built from coma_args below.
coma_args["phase_input"] = False


Retrieving synthetic population from cached file.


### Load (or reload) modules

In [149]:
# Import each model module and reload it right away, so the names below are
# bound to the result of a fresh reload. Reload order: Model3D, Model4D,
# EncoderPLModule.
import models.Model3D as Model3D
Model3D = reload(Model3D)

import models.Model4D as Model4D
Model4D = reload(Model4D)

import models.EncoderPLModule as EncoderPLModule
EncoderPLModule = reload(EncoderPLModule)

# Short aliases for the classes used in the rest of the notebook.
# From Model4D:
EncoderTemporalSequence = Model4D.EncoderTemporalSequence
DecoderStyle = Model4D.DecoderStyle
AutoencoderTemporalSequence = Model4D.AutoencoderTemporalSequence
# From EncoderPLModule:
CineComaEncoder = EncoderPLModule.CineComaEncoder

In [150]:
# Keep only the coma_args entries whose key is listed in Model3D.ENCODER_ARGS.
enc_config = dict(
    (arg_name, arg_value)
    for arg_name, arg_value in coma_args.items()
    if arg_name in Model3D.ENCODER_ARGS
)

___

### `EncoderTemporalSequence`

In [151]:
# Build the temporal-sequence encoder with the "DFT" latent aggregation option.
# The number of timeframes is read from coma_args (the same entry the decoder
# call uses) instead of being hard-coded to 20, so the encoder cannot drift out
# of sync with the configuration.
cine_encoder = EncoderTemporalSequence(
    enc_config,
    z_aggr_function="DFT",
    n_timeframes=coma_args["n_timeframes"],
)

Let's test the model on a single datapoint:

In [152]:
# Take the first batch from the training dataloader and run its "s_t" entry
# through the encoder, keeping the "mu" entry of the encoder's output.
train_loader = dm.train_dataloader()
datapoint = next(iter(train_loader))
s_t = datapoint["s_t"]
encoder_output = cine_encoder(s_t)
mu = encoder_output["mu"]

In [157]:
# Dimensions of the encoder's "mu" output for the batch.
mu.size()

torch.Size([16, 4])

___

### `DecoderStyle`

In [153]:
# Keep only the coma_args entries whose key is listed in Model4D.DECODER_S_ARGS.
dec_s_config = {
    arg_name: coma_args[arg_name]
    for arg_name in coma_args
    if arg_name in Model4D.DECODER_S_ARGS
}

In [154]:
# Build the style decoder from the filtered decoder arguments, using the "exp"
# phase-embedding method.
decoder_s = DecoderStyle(
    dec_s_config,
    phase_embedding_method="exp",
)

In [158]:
# Display the decoder's module tree (submodules and layer sizes).
decoder_s

DecoderStyle(
  (phase_tensor): PhaseTensor()
  (decoder_3d): Decoder3DMesh(
    (dec_lin): Linear(in_features=12, out_features=2016, bias=True)
    (layers): ModuleDict(
      (layer_0): ModuleDict(
        (activation_function): ReLU()
        (pool): Pool()
        (graph_conv): ChebConv_Coma(32, 32, K=6, normalization=None)
      )
      (layer_1): ModuleDict(
        (activation_function): ReLU()
        (pool): Pool()
        (graph_conv): ChebConv_Coma(32, 16, K=6, normalization=None)
      )
      (layer_2): ModuleDict(
        (activation_function): ReLU()
        (pool): Pool()
        (graph_conv): ChebConv_Coma(16, 16, K=6, normalization=None)
      )
      (layer_3): ModuleDict(
        (activation_function): ReLU()
        (pool): Pool()
        (graph_conv): ChebConv_Coma(16, 3, K=6, normalization=None)
      )
    )
  )
)

In [155]:
# Sample one random content code and one random style code, each with a leading
# batch dimension of 1, and decode them into a mesh sequence.
# torch.rand draws uniform values in [0, 1) directly as a tensor, replacing the
# NumPy -> legacy Tensor() round trip; the seeded generator makes this cell
# reproducible across kernel restarts.
zc_dim = dec_s_config["latent_dim_content"]
zs_dim = dec_s_config["latent_dim_style"]
latent_rng = torch.Generator().manual_seed(0)
zc = torch.rand(1, zc_dim, generator=latent_rng)
zs = torch.rand(1, zs_dim, generator=latent_rng)
s_hat = decoder_s(zc, zs, coma_args["n_timeframes"])

In [156]:
# Dimensions of the decoded output for the single sampled latent pair.
s_hat.size()

torch.Size([1, 20, 1002, 3])

In [None]:
# NOTE(review): this cell was never executed and looks like leftover scratch.
# Elsewhere in this notebook DecoderStyle is constructed with a config mapping
# and a phase_embedding_method, so this bare call will presumably raise unless
# those arguments have defaults — confirm, then remove or complete this cell.
DecoderStyle()

### `DecoderTemporalSequence`

___

In [None]:
# Wrap the encoder built earlier, together with the loaded config, in the
# CineComaEncoder class from EncoderPLModule.
PLEncoder = CineComaEncoder(cine_encoder, config)
# Trainer with all-default settings.
# NOTE(review): no epoch limit, accelerator or logger is set here — confirm the
# library defaults are what is intended before running a long fit.
trainer = pl.Trainer()

In [None]:
# Fit the wrapped encoder, passing the data module `dm` as the data source.
trainer.fit(PLEncoder, dm)

___

### `AutoencoderTemporalSequence`