In [None]:
import os, sys
os.chdir("..")

In [None]:
sys.path.append(os.getcwd())

In [None]:
# from models.Model3D import Encoder3DMesh, Decoder3DMesh
# from .PhaseModule import PhaseTensor
# from .TemporalAggregators import Mean_Aggregator, DFT_Aggregator, FCN_Aggregator

In [None]:
from easydict import EasyDict
from models.layers import ChebConv_Coma, Pool
from typing import Sequence, Union, List
from copy import copy

import torch
from torch import nn
from torch.nn import ModuleList, ModuleDict
from torch.fft import rfft

from data.SyntheticDataModules import SyntheticMeshesDataset, SyntheticMeshesDM

import numpy as np
from copy import copy
from typing import Sequence, Union, List

from IPython import embed # left there for debugging if needed

____
## 3D models

In [None]:
#TODO: Implement common parent class for encoder and decoder (GraphConvStack?), to capture common behaviour.

################# FULL AUTOENCODER #################

class Autoencoder3DMesh(nn.Module):

    '''
    Mesh autoencoder for a single (static) mesh: an Encoder3DMesh followed by a Decoder3DMesh.

    params:
      enc_config: dict of keyword arguments forwarded verbatim to Encoder3DMesh.
      dec_config: dict of keyword arguments forwarded verbatim to Decoder3DMesh.
    '''

    def __init__(self, enc_config, dec_config):

        super(Autoencoder3DMesh, self).__init__()

        self.encoder = Encoder3DMesh(**enc_config)
        self.decoder = Decoder3DMesh(**dec_config)

    def forward(self, x):
        '''
        Encode x and decode the mean of the latent distribution.

        returns:
          the decoder output computed from the encoder's "mu".
        '''

        # The encoder returns a dict {"mu": ..., "log_var": ...}. Unpacking it as a tuple
        # would yield the dictionary *keys* (strings) rather than the tensors.
        bottleneck = self.encoder(x)
        mu = bottleneck["mu"]

        # TODO: add sampling if is_variational == True and the module is in training mode.
        # For now "log_var" is ignored and the reconstruction is deterministic.
        x_hat = self.decoder(mu)
        return x_hat

################# ENCODER #################

# Names of configuration entries related to the 3D encoder (presumably used elsewhere to
# select the encoder's keyword arguments from a larger config — confirm against callers).
# NOTE(review): this list names "latent_dim_content", whereas Encoder3DMesh.__init__ takes
# "latent_dim"; a caller is expected to rename the key — confirm.
# NOTE: ENCODER_ARGS is assigned again further down in this notebook (4D models section),
# so after running all cells the later definition is the one in effect.
ENCODER_ARGS = [
    "num_features",
    "n_layers",
    "n_nodes",
    "num_conv_filters_enc",
    "cheb_polynomial_order",
    "latent_dim_content",
    "template",
    "is_variational",
    "phase_input",
    "downsample_matrices",
    "adjacency_matrices",
    "activation_layers"
]

class Encoder3DMesh(nn.Module):

    '''
    Graph-convolutional mesh encoder.

    The encoder is a stack of "layers", each one being a Chebyshev graph convolution
    (ChebConv_Coma) followed by a pooling step (Pool) and an activation function.
    The output of the stack is flattened and mapped to the latent space by a linear
    layer ("mu") and, when is_variational is True, by a second linear layer ("log_var").

    params:
      phase_input: if True, the number of input channels of the first convolution is doubled.
      num_conv_filters_enc: number of output channels of each graph convolution.
      num_features: number of input channels per node.
      cheb_polynomial_order: one Chebyshev polynomial order per convolution (indexed per layer).
      n_layers: number of layers; only used to replicate activation_layers when it is a single string.
      n_nodes: one entry per adjacency matrix, passed to ChebConv_Coma.norm as the node count.
      is_variational: whether the "log_var" head is created.
      latent_dim: size of the latent vector.
      template: accepted for interface compatibility; not used by this class.
      adjacency_matrices: sparse tensors (._indices() is called on each one).
      downsample_matrices: one pooling matrix per layer, passed as second argument to Pool.
      activation_layers: name (or list of names, one per layer) of classes in torch.nn.modules.activation.
    '''

    def __init__(self,
        phase_input: bool,
        num_conv_filters_enc: Sequence[int],
        num_features: int,
        cheb_polynomial_order: Sequence[int],
        n_layers: int,
        n_nodes: Sequence[int],
        is_variational: bool,
        latent_dim: int,
        template,
        adjacency_matrices: List[torch.Tensor],
        downsample_matrices: List[torch.Tensor],
        activation_layers="ReLU"):

        super(Encoder3DMesh, self).__init__()

        self.n_nodes = n_nodes
        self.phase_input = phase_input
        # Build a fresh list so that tuples are accepted and the caller's sequence is never modified.
        self.filters_enc = [num_features] + list(num_conv_filters_enc)
        self.K = cheb_polynomial_order

        self.matrices = {}
        A_edge_index, A_norm = self._build_adj_matrix(adjacency_matrices)

        self.matrices["A_edge_index"] = A_edge_index
        self.matrices["A_norm"] = A_norm
        # Copied because entries are replaced in place when tensors are moved to another device.
        self.matrices["downsample"] = list(downsample_matrices)

        # Size of the flattened output of the convolutional stack:
        # (rows of the last pooling matrix) x (channels of the last convolution).
        self._n_features_before_z = self.matrices["downsample"][-1].shape[0] * self.filters_enc[-1]
        self._is_variational = is_variational
        self.latent_dim = latent_dim

        self.activation_layers = [activation_layers] * n_layers if isinstance(activation_layers, str) else activation_layers
        self.layers = self._build_encoder()

        # Fully connected layers connecting the last pooling layer and the latent space layer.
        self.enc_lin_mu = torch.nn.Linear(self._n_features_before_z, self.latent_dim)

        if self._is_variational:
            self.enc_lin_var = torch.nn.Linear(self._n_features_before_z, self.latent_dim)

    def _build_encoder(self):

        '''
        Assemble the convolution / pooling / activation modules into a ModuleDict
        with one entry ("layer_<i>") per graph convolution.
        '''

        cheb_conv_layers = self._build_cheb_conv_layers(self.filters_enc, self.K)
        pool_layers = self._build_pool_layers(self.matrices["downsample"])
        activation_layers = self._build_activation_layers(self.activation_layers)

        encoder = ModuleDict()

        for i, conv in enumerate(cheb_conv_layers):
            encoder[f"layer_{i}"] = ModuleDict({
                "graph_conv": conv,
                "pool": pool_layers[i],
                "activation_function": activation_layers[i],
            })

        return encoder

    def _build_pool_layers(self, downsample_matrices: Sequence[torch.Tensor]):

        '''
        downsample_matrices: list of pooling matrices; one Pool module is created per matrix.
        '''

        return ModuleList([Pool() for _ in downsample_matrices])


    def _build_activation_layers(self, activation_type: Sequence[str]):

        '''
        activation_type: list of names of activation classes defined in torch.nn.modules.activation
        (e.g. "ReLU"); one module is instantiated per name, with default arguments.
        '''

        return ModuleList([
            getattr(torch.nn.modules.activation, name)() for name in activation_type
        ])


    def _build_cheb_conv_layers(self, n_filters, K):

        '''
        Chebyshev convolutions (encoder): layer i maps n_filters[i] -> n_filters[i+1] channels
        with polynomial order K[i].
        '''

        #TOFIX: this should be specified in the docs.
        # With phase input the first convolution receives twice as many channels.
        # n_filters is updated in place (as it always was), so self.filters_enc reflects the doubling.
        if self.phase_input:
            n_filters[0] = 2 * n_filters[0]

        return torch.nn.ModuleList([
            ChebConv_Coma(n_filters[i], n_filters[i+1], K[i])
            for i in range(len(n_filters) - 1)
        ])


    def _build_adj_matrix(self, adjacency_matrices):

        '''
        Run ChebConv_Coma.norm on the indices of each adjacency matrix.

        returns:
          two lists (edge indices, normalisation terms), with one entry per element of self.n_nodes.
        '''

        adj_edge_index, adj_norm = zip(*[
            ChebConv_Coma.norm(adjacency_matrices[i]._indices(), self.n_nodes[i])
            for i in range(len(self.n_nodes))
        ])
        return list(adj_edge_index), list(adj_norm)


    def concatenate_graph_features(self, x):

        '''
        Flatten everything but the first (batch) dimension into a vector of size _n_features_before_z.
        '''

        return x.reshape(x.shape[0], self._n_features_before_z)


    def _move_matrices_to(self, i, device):

        '''
        Make sure the matrices used by layer i live on the given device.
        '''

        for key in ("downsample", "A_edge_index", "A_norm"):
            if self.matrices[key][i].device != device:
                self.matrices[key][i] = self.matrices[key][i].to(device)


    # perform a forward pass only through the convolutional stack (not the FCN layer)
    def forward_conv_stack(self, x, preserve_graph_structure=True):

        '''
        params:
          x: input tensor; presumably (batch, nodes, features) — TODO confirm against ChebConv_Coma.
          preserve_graph_structure: if False, the output is flattened to (batch, _n_features_before_z).
        '''

        # a "layer" here is: a graph convolution + pooling operation + activation function
        for i, layer in enumerate(self.layers):

            self._move_matrices_to(i, x.device)

            modules = self.layers[layer]
            x = modules["graph_conv"](x, self.matrices["A_edge_index"][i], self.matrices["A_norm"][i])
            # Errors are deliberately not swallowed here: continuing with an unpooled tensor
            # would only fail later with a less informative message.
            x = modules["pool"](x, self.matrices["downsample"][i])
            x = modules["activation_function"](x)

        if not preserve_graph_structure:
            x = self.concatenate_graph_features(x)

        return x


    def forward(self, x):

        '''
        returns:
          dict with keys "mu" and "log_var"; "log_var" is None unless the encoder is variational.
        '''

        # The linear heads take _n_features_before_z inputs, so the graph structure must be flattened.
        x = self.forward_conv_stack(x, preserve_graph_structure=False)
        mu = self.enc_lin_mu(x)
        log_var = self.enc_lin_var(x) if self._is_variational else None
        return {"mu": mu, "log_var": log_var}

    
################# DECODER #################

# Names of configuration entries related to the 3D decoder (presumably used elsewhere to
# select the decoder's keyword arguments from a larger config — confirm against callers).
# NOTE(review): this list names "latent_dim_content", whereas Decoder3DMesh.__init__ takes
# "latent_dim"; a caller is expected to rename the key — confirm.
DECODER_ARGS = [
    "num_features",
    "n_layers",
    "n_nodes",
    "num_conv_filters_dec",
    "cheb_polynomial_order",
    "latent_dim_content",
    "is_variational",
    "upsample_matrices",
    "adjacency_matrices",
    "activation_layers",
    "template"
]

class Decoder3DMesh(nn.Module):

    '''
    Graph-convolutional mesh decoder.

    A linear layer maps the latent vector to the features of the coarsest graph; these are
    then processed by a stack of "layers", each one being an activation function followed by
    an un-pooling step (Pool) and a Chebyshev graph convolution (ChebConv_Coma).

    The filter sizes and all the matrices are given in encoder order and reversed internally.

    params:
      num_features: number of output channels per node.
      n_layers: number of layers; only used to replicate activation_layers when it is a single string.
      n_nodes: one entry per adjacency matrix, passed to ChebConv_Coma.norm as the node count.
      num_conv_filters_dec: channel sizes, in encoder order.
      cheb_polynomial_order: one Chebyshev polynomial order per convolution (indexed per layer).
      latent_dim: size of the latent vector.
      is_variational: stored in self._is_variational; not otherwise used by this class.
      template: accepted for interface compatibility; not used by this class.
      upsample_matrices: one un-pooling matrix per layer, in encoder order.
      adjacency_matrices: sparse tensors (._indices() is called on each one), in encoder order.
      activation_layers: name (or list of names, one per layer) of classes in torch.nn.modules.activation.
    '''

    def __init__(self,
        num_features: int,
        n_layers: int,
        n_nodes: Sequence[int],
        num_conv_filters_dec: Sequence[int],
        cheb_polynomial_order: Sequence[int],
        latent_dim: int,
        is_variational: bool,
        template,
        upsample_matrices: List[torch.Tensor],
        adjacency_matrices: List[torch.Tensor],
        activation_layers="ReLU"):

        super(Decoder3DMesh, self).__init__()

        self.n_nodes = n_nodes
        # Build a fresh list so that tuples are accepted and the caller's sequence is never modified.
        self.filters_dec = list(reversed([num_features] + list(num_conv_filters_dec)))

        self.K = cheb_polynomial_order

        self.matrices = {}
        A_edge_index, A_norm = self._build_adj_matrix(adjacency_matrices)
        self.matrices["A_edge_index"] = list(reversed(A_edge_index))
        self.matrices["A_norm"] = list(reversed(A_norm))
        self.matrices["upsample"] = list(reversed(upsample_matrices))

        # Size of the tensor fed to the first layer:
        # (columns of the first un-pooling matrix) x (input channels of the first convolution).
        self._n_features_before_z = self.matrices["upsample"][0].shape[1] * self.filters_dec[0]

        self._is_variational = is_variational
        self.latent_dim = latent_dim

        self.activation_layers = [activation_layers] * n_layers if isinstance(activation_layers, str) else activation_layers

        # Fully connected layer connecting the latent space layer with the first upsampling layer.
        self.dec_lin = torch.nn.Linear(self.latent_dim, self._n_features_before_z)

        self.layers = self._build_decoder()


    def _build_decoder(self):

        '''
        Assemble the activation / un-pooling / convolution modules into a ModuleDict
        with one entry ("layer_<i>") per graph convolution.
        '''

        cheb_conv_layers = self._build_cheb_conv_layers(self.filters_dec, self.K)
        pool_layers = self._build_pool_layers(self.matrices["upsample"])
        activation_layers = self._build_activation_layers(self.activation_layers)

        decoder = ModuleDict()

        for i, conv in enumerate(cheb_conv_layers):
            decoder[f"layer_{i}"] = ModuleDict({
                "activation_function": activation_layers[i],
                "pool": pool_layers[i],
                "graph_conv": conv,
            })

        return decoder


    def _build_pool_layers(self, upsample_matrices: Sequence[torch.Tensor]):

        '''
        upsample_matrices: list of un-pooling matrices; one Pool module is created per matrix.
        '''

        return ModuleList([Pool() for _ in upsample_matrices])


    def _build_activation_layers(self, activation_type: Sequence[str]):

        '''
        activation_type: list of names of activation classes defined in torch.nn.modules.activation
        (e.g. "ReLU"); one module is instantiated per name, with default arguments.
        '''

        return ModuleList([
            getattr(torch.nn.modules.activation, name)() for name in activation_type
        ])


    def _build_cheb_conv_layers(self, n_filters, K):

        '''
        Chebyshev convolutions (decoder): layer i maps n_filters[i] -> n_filters[i+1] channels
        with polynomial order K[i].
        '''

        cheb_dec = torch.nn.ModuleList([
            ChebConv_Coma(n_filters[i], n_filters[i+1], K[i])
            for i in range(len(n_filters) - 1)
        ])

        cheb_dec[-1].bias = None  # No bias for last convolution layer
        return cheb_dec


    def _build_adj_matrix(self, adjacency_matrices):

        '''
        Run ChebConv_Coma.norm on the indices of each adjacency matrix.

        returns:
          two lists (edge indices, normalisation terms), with one entry per element of self.n_nodes.
        '''

        adj_edge_index, adj_norm = zip(*[
            ChebConv_Coma.norm(adjacency_matrices[i]._indices(), self.n_nodes[i])
            for i in range(len(self.n_nodes))
        ])
        return list(adj_edge_index), list(adj_norm)


    def _move_matrices_to(self, i, device):

        '''
        Make sure the matrices used by layer i live on the given device.
        '''

        for key in ("upsample", "A_edge_index", "A_norm"):
            if self.matrices[key][i].device != device:
                self.matrices[key][i] = self.matrices[key][i].to(device)


    def forward(self, x):

        '''
        params:
          x: latent vector(s), (batch, latent_dim); any input that is not 2-dimensional
             is treated as a single sample.

        returns:
          the output of the last graph convolution.
        '''

        x = self.dec_lin(x)
        batch_size = x.shape[0] if x.dim() == 2 else 1
        # Restore the graph structure: (batch, nodes, channels of the first convolution).
        x = x.reshape(batch_size, -1, self.layers["layer_0"]["graph_conv"].in_channels)

        for i, layer in enumerate(self.layers):

            self._move_matrices_to(i, x.device)

            modules = self.layers[layer]
            x = modules["activation_function"](x)
            x = modules["pool"](x, self.matrices["upsample"][i])
            x = modules["graph_conv"](x, self.matrices["A_edge_index"][i], self.matrices["A_norm"][i])

        return x

## 4D models

### Phase module

In [None]:
class PhaseTensor(nn.Module):

    '''
    Embeds the position along the time axis into a tensor by modulating it with the
    sine and cosine of the phase 2 * pi * t / T, where T is the size of dimension 1.

    params:
      version: "version_1" concatenates (sin * x, cos * x) along the last dimension;
               "version_2" concatenates (cos * x, sin * x), i.e. the real and imaginary
               parts of x * exp(i * phase).
    '''

    _VERSIONS = ("version_1", "version_2")

    def __init__(self, version="version_1"):

        super(PhaseTensor, self).__init__()
        self.version = version

    def phase_tensor(self, x):
        '''
        params:
          x: a batched tensor whose dimension 1 is time (N x T x M)

        returns:
          a phase-aware tensor (N x T x 2M), with the same dtype and device as x

        raises:
          ValueError if self.version is not a supported version.
        '''

        if self.version not in self._VERSIONS:
            raise ValueError(
                f"Unknown PhaseTensor version {self.version!r}; expected one of {self._VERSIONS}."
            )

        n_timeframes = x.shape[1]

        # One phase per time frame, computed in double precision on the CPU
        # and only then cast/moved to match x.
        phase = 2 * np.pi * torch.arange(n_timeframes, dtype=torch.float64) / n_timeframes

        # Shape (1, T, 1, ..., 1): broadcasts against x along every dimension except time.
        broadcast_shape = [1] * x.dim()
        broadcast_shape[1] = n_timeframes

        sin_t = torch.sin(phase).reshape(broadcast_shape).to(device=x.device, dtype=x.dtype)
        cos_t = torch.cos(phase).reshape(broadcast_shape).to(device=x.device, dtype=x.dtype)

        if self.version == "version_1":
            return torch.cat((sin_t * x, cos_t * x), dim=-1)

        # version_2: x as a phasor, x * exp(i * phase) = x * cos(phase) + i * x * sin(phase);
        # real and imaginary parts are concatenated along the last dimension.
        return torch.cat((cos_t * x, sin_t * x), dim=-1)

    def forward(self, x):
        return self.phase_tensor(x)

### Temporal Aggregators

In [None]:
class Mean_Aggregator(nn.Module):

    '''
    Collapses dimension 1 of the input by averaging over it.
    '''

    def forward(self, x):
        averaged = x.mean(dim=1)
        return averaged


class FCN_Aggregator(nn.Module):

    '''
    Merges dimensions 1 and 2 of the input into a single axis and applies
    one linear layer (features_in -> features_out) to the result.
    '''

    def __init__(self, features_in, features_out):
        super().__init__()
        self.fcn = torch.nn.Linear(features_in, features_out)

    def forward(self, x):
        batch_size, dim_1, dim_2 = x.shape[0], x.shape[1], x.shape[2]
        flattened = x.reshape(batch_size, dim_1 * dim_2)
        return self.fcn(flattened)


class DFT_Aggregator(nn.Module):

    '''
    Aggregates over dimension 1 in the frequency domain: the real FFT is taken along
    dimension 1, the frequency and feature axes are merged, real and imaginary parts
    are concatenated, and one linear layer (features_in -> features_out) is applied.
    '''

    def __init__(self, features_in, features_out):
        super().__init__()
        self.fcn = torch.nn.Linear(features_in, features_out)

    def forward(self, x):

        spectrum = rfft(x, dim=1)

        # Concatenate features in the frequency domain
        batch_size, n_freqs, n_feats = spectrum.shape[0], spectrum.shape[1], spectrum.shape[2]
        spectrum = spectrum.reshape(batch_size, n_freqs * n_feats)

        real_and_imag = torch.cat((spectrum.real, spectrum.imag), dim=-1)
        return self.fcn(real_and_imag)

In [None]:
# Axis indices; the code in this notebook indexes time as dimension 1 of its inputs.
# Presumably the full layout is (batch, time, node, feature) — TODO confirm.
BATCH_DIMENSION = 0
TIME_DIMENSION = 1
NODE_DIMENSION = 2
FEATURE_DIMENSION = 3

# Configuration entries shared by the encoder and both decoders.
COMMON_ARGS = [
    "num_features",
    "n_layers",
    "n_nodes",
    "cheb_polynomial_order",
    "is_variational",
    "adjacency_matrices",
    "activation_layers",
    "template",
]

# NOTE: this re-assigns ENCODER_ARGS, replacing the list defined in the "3D models" section.
ENCODER_ARGS = copy(COMMON_ARGS)
ENCODER_ARGS.extend([
  "phase_input",
  "downsample_matrices",
  "num_conv_filters_enc",
  "latent_dim_content",
  "latent_dim_style"
])

# Entries for the content decoder (the "_c" suffix).
DECODER_C_ARGS = copy(COMMON_ARGS)
DECODER_C_ARGS.extend([
  "upsample_matrices",
  "num_conv_filters_dec_c",
  "latent_dim_content"
])

# Entries for the style decoder (the "_s" suffix); it additionally needs the style
# latent size and the number of time frames.
DECODER_S_ARGS = copy(COMMON_ARGS)
DECODER_S_ARGS.extend([
    "upsample_matrices",
    "num_conv_filters_dec_s",
    "latent_dim_content",
    "latent_dim_style",
    "n_timeframes"
])

In [None]:
def _steal_attributes_from_child(self, child: str, attributes: Union[List[str], str]):

    '''
       Make attributes from an object's child visible from the (parent) object's namespace
    '''

    child = getattr(self, child)

    if isinstance(attributes, str):
        attributes = [attributes]

    for attribute in attributes:
        setattr(self, attribute, getattr(child, attribute))
    return self


class AutoencoderTemporalSequence(nn.Module):

    '''
    Autoencoder for sequences of meshes: an EncoderTemporalSequence followed by a
    DecoderTemporalSequence.

    params:
      enc_config: configuration of the encoder.
      dec_c_config: configuration of the content decoder; its "template" entry is kept as self.template_mesh.
      dec_s_config: configuration of the style decoder.
      z_aggr_function: name of the temporal aggregation used by the encoder.
      n_timeframes: number of time frames, forwarded to the encoder.
      phase_embedding_method: name of the phase embedding used by the decoder.
    '''

    def __init__(self, enc_config, dec_c_config, dec_s_config, z_aggr_function="dft", n_timeframes=None, phase_embedding_method="exp"):

        super().__init__()

        self.encoder = EncoderTemporalSequence(enc_config, z_aggr_function, n_timeframes=n_timeframes)
        self.decoder = DecoderTemporalSequence(dec_c_config, dec_s_config, phase_embedding_method)

        # Mirror the flag of the underlying 3D encoder.
        mesh_encoder = self.encoder.encoder_3d_mesh
        self._is_variational = mesh_encoder._is_variational

        self.template_mesh = dec_c_config["template"]


    def forward(self, s_t):
        '''
        returns:
          a tuple (encoder output, first decoder output, second decoder output).
        '''

        latent = self.encoder(s_t)
        avg_s, shat_t = self.decoder(latent)
        return latent, avg_s, shat_t


    def set_mode(self, mode: str):
        '''
        params:
          mode: "training" or "testing"
        '''
        self._mode = mode


##########################################################################################

class EncoderTemporalSequence(nn.Module):

    '''
    Encoder for sequences of meshes.

    Every time frame is encoded independently with the same Encoder3DMesh, and the
    per-frame latent vectors are then merged into a single one by a temporal aggregator.

    params:
      encoder_config: keyword arguments for Encoder3DMesh, except that the latent size is given
                      as "latent_dim_content" and "latent_dim_style" (their sum is the latent size).
                      The dict passed by the caller is not modified.
      z_aggr_function: name of the temporal aggregator: "mean", "fcn"/"fully_connected" or
                       "dft"/"discrete_fourier_transform" (case-insensitive).
      phase_embedding: stored in self.phase_embedding; it must not be None if the aggregator is "mean".
      n_timeframes: number of time frames; required by the "fcn" and "dft" aggregators.
    '''

    def __init__(self, encoder_config, z_aggr_function, phase_embedding=None, n_timeframes=None):

        super(EncoderTemporalSequence, self).__init__()
        encoder_config = copy(encoder_config)
        encoder_config["latent_dim"] = encoder_config.pop("latent_dim_content") + encoder_config.pop("latent_dim_style")

        self.latent_dim = encoder_config["latent_dim"]
        self.encoder_3d_mesh = Encoder3DMesh(**encoder_config)

        _steal_attributes_from_child(self, child="encoder_3d_mesh", attributes=["matrices"])

        # Must be set before building the aggregator, which checks it.
        self.phase_embedding = phase_embedding
        self.z_aggr_function = self._get_z_aggr_function(z_aggr_function, n_timeframes)


    def _get_z_aggr_function(self, z_aggr_function, n_timeframes=None):

        '''
        Build the temporal aggregator module from its name.

        raises:
          ValueError if the name is not recognised, if "mean" is requested without a phase
          embedding, or if n_timeframes is missing for an aggregator that needs it.
        '''

        name = z_aggr_function.lower()

        if name == "mean":
            if self.phase_embedding is None:
                raise ValueError("The temporal aggregation cannot be the mean if phase information is not embedded into the input meshes.")
            return Mean_Aggregator()

        if name in {"fcn", "fully_connected"}:
            self._check_n_timeframes(z_aggr_function, n_timeframes)
            self.n_timeframes = n_timeframes
            return FCN_Aggregator(
                features_in=n_timeframes * self.latent_dim,
                features_out=(self.latent_dim)
            )

        if name in {"dft", "discrete_fourier_transform"}:
            self._check_n_timeframes(z_aggr_function, n_timeframes)
            self.n_timeframes = n_timeframes
            # The real FFT of n_timeframes samples has n_timeframes // 2 + 1 components,
            # each contributing a real and an imaginary part.
            return DFT_Aggregator(
                features_in=(n_timeframes // 2 + 1) * 2 * (self.latent_dim),
                features_out=(self.latent_dim)
            )

        raise ValueError(f"Temporal aggregation function {z_aggr_function} has not been recognised.")


    @staticmethod
    def _check_n_timeframes(z_aggr_function, n_timeframes):

        '''
        Fail early, with a clear message, when an aggregator that needs n_timeframes does not get it.
        '''

        if n_timeframes is None:
            raise ValueError(f"n_timeframes must be provided for the temporal aggregation function {z_aggr_function}.")


    def set_mode(self, mode: str):
        '''
        params:
          mode: "training" or "testing"
        '''
        self._mode = mode


    def encoder(self, x):

        '''
        params:
          x: sequence of meshes, with time along dimension 1.

        returns:
          dict with keys "mu" and "log_var", both aggregated over time;
          "log_var" is None if the 3D encoder does not produce it.
        '''

        self.n_timeframes = x.shape[1]

        # Iterate through time points
        bottleneck_t = [ self.encoder_3d_mesh(x[:, t, ...]) for t in range(self.n_timeframes) ]

        mu = self._stack_over_time([ bottleneck["mu"] for bottleneck in bottleneck_t ])
        mu = self.z_aggr_function(mu)

        # If one element (and therefore all elements) are None, replace the whole thing with None
        log_var = None
        if bottleneck_t[0]["log_var"] is not None:
            log_var_t = self._stack_over_time([ bottleneck["log_var"] for bottleneck in bottleneck_t ])
            log_var = self.z_aggr_function(log_var_t)

        bottleneck = {"mu": mu, "log_var": log_var}
        return bottleneck


    def _stack_over_time(self, z_t):

        '''
        Turn a list of per-frame latent tensors into one (batch, time, latent_dim) tensor.
        '''

        # Stacking along a new dimension 1 keeps each sample's frames together. Concatenating
        # along the batch dimension and then reshaping would interleave samples and frames
        # whenever the batch holds more than one sample.
        return torch.stack(z_t, dim=1).reshape(-1, self.n_timeframes, self.latent_dim)


    def forward(self, x):
        return self.encoder(x)

    
##########################################################################################

class DecoderStyle(nn.Module):

    '''
    Decoder of the time-dependent part of a mesh sequence.

    The style code is replicated along time and passed through a phase embedding, which
    doubles its size; for every time frame the content code and the phase-embedded style code
    are concatenated and decoded with one shared Decoder3DMesh.

    params:
      decoder_config: keyword arguments for Decoder3DMesh, except that the latent size is given as
                      "latent_dim_content" and "latent_dim_style", the filters as
                      "num_conv_filters_dec_s", and it may carry "n_timeframes".
                      The dict passed by the caller is not modified.
      phase_embedding_method: name of the phase embedding (see _get_phase_embedding).
      n_timeframes: number of time frames; used only if decoder_config does not provide
                    a (non-None) "n_timeframes" entry.
    '''

    def __init__(self, decoder_config: dict, phase_embedding_method: str, n_timeframes: Union[int, None]=None):

        super(DecoderStyle, self).__init__()

        decoder_config = copy(decoder_config)

        # The entry in the configuration takes precedence; the argument is the fallback.
        configured_n_timeframes = decoder_config.pop("n_timeframes", None)
        self.n_timeframes = configured_n_timeframes if configured_n_timeframes is not None else n_timeframes
        self.phase_embedding = self._get_phase_embedding(phase_embedding_method, self.n_timeframes)

        # The style code enters the 3D decoder phase-embedded, i.e. with twice its size.
        decoder_config["latent_dim"] = decoder_config.pop("latent_dim_content") + 2 * decoder_config.pop("latent_dim_style")
        decoder_config["num_conv_filters_dec"] = decoder_config.pop("num_conv_filters_dec_s")

        self.decoder_3d = Decoder3DMesh(**decoder_config)


    def  _get_phase_embedding(self, phase_embedding_method, n_timeframes):

        '''
        Build the phase embedding module from its name (case-insensitive).

        raises:
          NotImplementedError for the "dft" and "concatenation" families;
          ValueError if the name is not recognised.
        '''

        method = phase_embedding_method.lower()

        if method in ["inverse_dft", "dft"]:
            raise NotImplementedError

        elif method in ["concatenation", "concat"]:
            raise NotImplementedError

        elif method in ["exponential_v1", "exp_v1", "exp"]:
            return PhaseTensor(version="version_1")

        elif method in ["exponential_v2", "exp_v2"]:
            return PhaseTensor(version="version_2")

        else:
            raise ValueError(f"Method of phase embedding {phase_embedding_method} has not been recognised.")


    def _process_one_timeframe(self, z_c, phased_z_s, t):

        '''
        Decode time frame t.

        returns:
          the output of the 3D decoder with a singleton time dimension inserted at position 1.
        '''

        z_s_t = phased_z_s[:, t, ...]
        z = torch.cat([z_c, z_s_t], dim=-1)
        s_t = self.decoder_3d(z)
        s_t = s_t.unsqueeze(1)
        return s_t


    def forward(self, z_c, z_s, n_timeframes=None):

        '''
        params:
          z_c: content code.
          z_s: style code (without time dimension).
          n_timeframes: number of frames to decode; defaults to self.n_timeframes.
                        The phases are always computed for a cycle of self.n_timeframes frames,
                        so this value must not exceed self.n_timeframes.

        returns:
          the decoded frames, concatenated along dimension 1.
        '''

        if n_timeframes is None:
            n_timeframes = self.n_timeframes

        # Replicate the style code along a new time axis, leaving every other axis untouched.
        repeats = [1, self.n_timeframes] + [1] * (z_s.dim() - 1)
        phased_z_s = z_s.unsqueeze(TIME_DIMENSION).repeat(*repeats)
        phased_z_s = self.phase_embedding(phased_z_s)

        s_out = [ self._process_one_timeframe(z_c, phased_z_s, t) for t in range(n_timeframes) ]
        s_out = torch.cat(s_out, dim=1)

        return s_out



class DecoderTemporalSequence(nn.Module):

    def __init__(self, decoder_c_config, decoder_s_config, phase_embedding_method, n_timeframes=None):

        super(DecoderTemporalSequence, self).__init__()

        decoder_c_config = copy(decoder_c_config)
        decoder_c_config["num_conv_filters_dec"] = decoder_c_config.pop("num_conv_filters_dec_c")
        decoder_c_config["latent_dim"] = decoder_c_config.pop("latent_dim_content")

        self.template_mesh = decoder_c_config["template"]
        self.latent_dim_content = decoder_c_config["latent_dim"]
        self.latent_dim_style = decoder_s_config["latent_dim_style"]

        self.decoder_content = Decoder3DMesh(**decoder_c_config)
        self.decoder_style = DecoderStyle(decoder_s_config, phase_embedding_method, n_timeframes)

        self = _steal_attributes_from_child(self, child="decoder_content", attributes=["matrices"])


    def set_mode(self, mode: str):
        '''
        params:
          mode: "training" or "testing"
        '''
        self._mode = mode


    def forward(self, z):

        bottleneck = self._partition_z(z["mu"], z["log_var"])
        z_c, z_s = bottleneck["mu_c"], bottleneck["mu_s"]
        avg_shape = self.decoder_content(z_c)
        def_field_t = self.decoder_style(z_c, z_s, self.decoder_style.n_timeframes)
        shape_t = avg_shape.unsqueeze(TIME_DIMENSION) + def_field_t
        return avg_shape, shape_t


    def _partition_z(self, mu, log_var=None):

        bottleneck = {
            "mu_c": mu[:, :self.latent_dim_content],
            "mu_s": mu[:, self.latent_dim_content:]
        }

        if log_var is not None:
            bottleneck.update({
                "log_var_c": log_var[:, :self.latent_dim_content],
                "log_var_s": log_var[:, self.latent_dim_content:]
            })

        return bottleneck


### Data modules

In [None]:
from config.load_config import load_yaml_config
config = load_yaml_config("config_files/config_folded_c_and_s.yaml")

In [None]:
# params = { 
#   "N": 100, "T": 20, "mesh_resolution": 10,
#   "l_max": 2, "freq_max": 2, 
#   "amplitude_static_max": 0.3, "amplitude_dynamic_max": 0.1, 
#   "random_seed": 144
# }

# preproc_params = EasyDict({"center_around_mean": False})

# Build the synthetic dataset from the parameters in the YAML configuration, wrap it in
# its data module and call setup() (presumably this creates the data splits — confirm
# in data.SyntheticDataModules).
mesh_ds = SyntheticMeshesDataset(config.dataset.parameters, config.dataset.preprocessing)
mesh_dm = SyntheticMeshesDM(mesh_ds)
mesh_dm.setup()

In [None]:
len(mesh_ds)

---
# <center> --- Tests --- </center>

In [None]:
from utils.helpers import get_coma_args

In [None]:
coma_args = get_coma_args(config, mesh_dm)

In [None]:
# Keyword arguments for a stand-alone Encoder3DMesh with 4 layers: one entry per layer in
# "num_conv_filters_enc" and "cheb_polynomial_order". The graph-related entries (node counts,
# template, adjacency and downsampling matrices) come from coma_args.
# "activation_layers" is omitted, so the default ("ReLU") is used.
enc_params = {
    "phase_input" : False, 
    "num_conv_filters_enc" : [16, 16, 16, 16], 
    "num_features" : 3,
    "cheb_polynomial_order" : [6, 6, 6, 6],
    "n_layers" : 4,
    "n_nodes" : coma_args.n_nodes,
    "is_variational" : True,
    "latent_dim" : 16,
    "template": coma_args.template,
    "adjacency_matrices": coma_args.adjacency_matrices,
    "downsample_matrices": coma_args.downsample_matrices,
}

encoder = Encoder3DMesh(**enc_params)

In [None]:
# Smoke test: take the first batch of the training dataloader (wrapped in an EasyDict for
# attribute access) and run its "s_t" entry through the convolutional stack only, asking
# for the flattened output.
# NOTE(review): Encoder3DMesh processes one mesh per call, while the "4D models" code
# indexes time as dimension 1; if s_t carries a time dimension, a single frame
# (e.g. x.s_t[:, 0]) is presumably what should be passed here — confirm.
x = EasyDict(next(iter(mesh_dm.train_dataloader())))

encoder.forward_conv_stack(x.s_t, preserve_graph_structure=False)