## Load Model from Scratch

In [72]:
from models import model_from_kwargs

Load the Nested and Connected YAML

In [73]:
import yaml
import re

def resolve_references(data, context):
    """
    Recursively resolve references in the YAML data using the context dictionary.

    Args:
        data (dict): The YAML data.
        context (dict): The context dictionary with variable definitions.

    Returns:
        dict: The YAML data with resolved references.
    """
    if isinstance(data, dict):
        return {k: resolve_references(v, context) for k, v in data.items()}
    elif isinstance(data, list):
        return [resolve_references(item, context) for item in data]
    elif isinstance(data, str):
        # Find all placeholders in the format ${...}
        matches = re.findall(r'\$\{([^}]+)\}', data)
        for match in matches:
            # Replace the placeholder with the corresponding value from the context
            keys = match.split('.')
            value = context
            for key in keys:
                value = value.get(key)
                if value is None:
                    break
            if value is not None:
                data = data.replace(f"${{{match}}}", str(value))
        return data
    else:
        return data

def load_yaml_with_interpolation(file_path):
    """
    Load a YAML file with variable interpolation into a nested dictionary.

    Args:
        file_path (str): The path to the YAML file.

    Returns:
        dict: The YAML data with interpolated variables.
    """
    with open(file_path, 'r') as file:
        try:
            data = yaml.safe_load(file)
            # Resolve references in the YAML data
            data = resolve_references(data, data)
        except yaml.YAMLError as e:
            print(f"Error loading YAML file: {e}")
            return None
    return data

In [74]:
stage_hp = load_yaml_with_interpolation("yamls/shapenetcar/upt/dim768_seq1024sdf512_cnext_lr5e4_sd02_reprcnn_grn_grid32.yaml")

In [75]:
model = model_from_kwargs(
        **stage_hp["model"],
        input_shape=(None, 3),#trainer.input_shape,
        output_shape=(None, 1),#trainer.output_shape,
        update_counter=None,#trainer.update_counter,
        path_provider=None,#path_provider,
        data_container=None,#data_container,
    )

AttributeError: 'NoneType' object has no attribute 'get_dataset'

In [76]:
from models.composite.rans_simformer_nognn_sdf_model import *

In [78]:
RansSimformerNognnSdfModel(
    stage_hp["model"]['grid_encoder'],
    stage_hp["model"]['mesh_encoder'],
    stage_hp["model"]['latent'],
    stage_hp["model"]['decoder'],
)

AttributeError: 'NoneType' object has no attribute 'get_dataset'

In [79]:
stage_hp["model"]['grid_encoder']

{'kind': 'encoders.rans_grid_convnext',
 'patch_size': 2,
 'kernel_size': 3,
 'depthwise': False,
 'global_response_norm': True,
 'depths': [2, 2, 2],
 'dims': [192, 384, 768],
 'upsample_size': 64,
 'upsample_mode': 'nearest',
 'optim': "{'kind': 'adamw', 'lr': '${vars.lr}', 'weight_decay': 0.05, 'schedule': {'template': '${yaml:schedules/wupcos_epoch}', 'template.vars.end_epoch': 50}}"}

In [81]:
import torch
from models import model_from_kwargs
from models.base.composite_model_base import CompositeModelBase
from utils.factory import create

class RansSimformerNognnSdfModel_CAEML(CompositeModelBase):
    def __init__(
            self,
            grid_encoder,
            mesh_encoder,
            latent,
            decoder,
            resolution,
            concat_pos_to_sdf,
            **kwargs,
    ):
        super().__init__(**kwargs)
        common_kwargs = dict(
            update_counter=self.update_counter,
            path_provider=self.path_provider,
            dynamic_ctx=self.dynamic_ctx,
            static_ctx=self.static_ctx,
            data_container=self.data_container,
        )
        # grid_encoder
        self.grid_encoder = create(
            grid_encoder,
            model_from_kwargs,
            resolution=resolution,
            concat_pos_to_sdf=concat_pos_to_sdf,
        )
        """
        # mesh_encoder
        self.mesh_encoder = create(
            mesh_encoder,
            model_from_kwargs,
            input_shape=self.input_shape,
            **common_kwargs,
        )
        # latent
        self.latent = create(
            latent,
            model_from_kwargs,
            input_shape=self.mesh_encoder.output_shape,
            **common_kwargs,
        )
        # decoder
        self.decoder = create(
            decoder,
            model_from_kwargs,
            **common_kwargs,
            input_shape=self.latent.output_shape,
            output_shape=self.output_shape,
        )
        """

    @property
    def submodels(self):
        return dict(
            grid_encoder=self.grid_encoder,
            mesh_encoder=self.mesh_encoder,
            latent=self.latent,
            decoder=self.decoder,
        )

    # noinspection PyMethodOverriding
    def forward(self, mesh_pos, sdf, query_pos, batch_idx, unbatch_idx, unbatch_select):
        outputs = {}

        # encode data
        grid_embed = self.grid_encoder(sdf)
        mesh_embed = self.mesh_encoder(mesh_pos=mesh_pos, batch_idx=batch_idx)
        embed = torch.concat([grid_embed, mesh_embed], dim=1)

        # propagate
        propagated = self.latent(embed)

        # decode
        x_hat = self.decoder(
            propagated,
            query_pos=query_pos,
            unbatch_idx=unbatch_idx,
            unbatch_select=unbatch_select,
        )
        outputs["x_hat"] = x_hat

        return outputs

In [82]:
model = RansSimformerNognnSdfModel_CAEML(
    stage_hp["model"]['grid_encoder'],
    stage_hp["model"]['mesh_encoder'],
    stage_hp["model"]['latent'],
    stage_hp["model"]['decoder'],
    resolution=(32, 32, 32),
    concat_pos_to_sdf=True
)

In [84]:
model.input_shape

### Build Each Block from Scratch

In [95]:
from models.encoders.rans_grid_convnext import RansGridConvnext
from models.encoders.rans_perceiver import RansPerceiver as EncoderRansPerceiver
from models.latent.transformer_model import TransformerModel
from models.decoders.rans_perceiver import RansPerceiver as DecoderRansPerceiver

In [96]:
stage_hp

{'wandb': 'cvsim',
 'name': 'snc-all-sdfpos-e1000-subsam1-lr5e4-sdfperconly-seqlen1024-sdf512-cnext-dim768-sd02unif-reprcnn-grn-grid32',
 'stage_name': 'stage1',
 'vars': {'lr': 0.0005,
  'batch_size': 1,
  'max_batch_size': 16,
  'epochs': 100,
  'grid_resolution': 32,
  'optim': {'kind': 'adamw',
   'lr': '0.0005',
   'weight_decay': 0.05,
   'schedule': {'template': '${yaml:schedules/wupcos_epoch}',
    'template.vars.end_epoch': 50}}},
 'datasets': {'train': {'kind': 'shapenet_car',
   'split': 'train',
   'grid_resolution': '32',
   'concat_pos_to_sdf': True,
   'collators': [{'kind': 'rans_simformer_nognn_collator'}]},
  'test': {'kind': 'shapenet_car',
   'split': 'test',
   'grid_resolution': '32',
   'concat_pos_to_sdf': True,
   'collators': [{'kind': 'rans_simformer_nognn_collator'}]}},
 'model': {'kind': 'rans_simformer_nognn_sdf_model',
  'grid_encoder': {'kind': 'encoders.rans_grid_convnext',
   'patch_size': 2,
   'kernel_size': 3,
   'depthwise': False,
   'global_respo

In [97]:
"""
When additionally using SDF features as input, the SDF features are encoded by a shallow ConvNeXt
V2 [113] that processes the SDF features into 512 (8x8x8) tokens, which are concatenated to the
perceiver tokens. To distinguish between the two token types, a learnable vector per type is added to
each of the tokens.
"""

"""
ConvNext model class for image processing with configurable architecture.

Parameters:
----------
patch_size : int
    Size of each image patch.
input_dim : int
    Number of input channels (e.g., 3 for RGB images).
dims : list[int]
    Dimensionality (number of channels) for each stage.
depths : list[int]
    Number of layers in each stage.
ndim : int
    Spatial dimensionality of the data (2 for 2D images).
drop_path_rate : float
    Probability of dropping paths in the network for regularization.
drop_path_decay : float
    Rate at which drop path probability decays across layers.
kernel_size : int
    Size of the convolutional kernel.
depthwise : bool
    Whether to use depthwise separable convolutions for efficiency.
global_response_norm : bool
    Whether to apply global response normalization for feature scaling.
"""
    
grid_encoder = RansGridConvnext(
    patch_size = 2,
    kernel_size = 3,
    depthwise = False,
    global_response_norm = True,
    depths = [ 2, 2, 2 ],
    dims = [ 192, 384, 768 ],
    upsample_size = 64,
    upsample_mode = "nearest",
    resolution = (32, 32, 32), # This is because they are separate from the Input Positions I guess.
    concat_pos_to_sdf = True
)

In [98]:
grid_encoder

RansGridConvnext(
  (model): ConvNext(
    (stem): Sequential(
      (0): Conv3d(4, 192, kernel_size=(2, 2, 2), stride=(2, 2, 2))
      (1): LayerNorm3d(
        (layer): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      )
    )
    (stages): ModuleList(
      (0): ConvNextStage(
        (downsampling): Identity()
        (blocks): Sequential(
          (0): ConvNextBlock(
            (drop_path): DropPath(drop_prob=0.000)
            (conv): Conv3d(192, 192, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (norm): LayerNorm3d(
              (layer): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
            )
            (mlp): Mlp(
              (fc1): Conv3d(192, 768, kernel_size=(1, 1, 1), stride=(1, 1, 1))
              (act): GELU(approximate='none')
              (grn): GlobalResponseNorm()
              (fc2): Conv3d(768, 192, kernel_size=(1, 1, 1), stride=(1, 1, 1))
            )
          )
          (1): ConvNextBlock(
            (dro

In [99]:
mesh_encoder = EncoderRansPerceiver(
    dim = 768,
    num_attn_heads = 12,
    num_output_tokens = 1024,
    add_type_token = True,
    init_weights = "truncnormal",
    input_shape = (None, 3)
)

In [113]:
mesh_encoder

RansPerceiver(
  (pos_embed): ContinuousSincosEmbed(dim=768)
  (mlp): Mlp(
    (fc1): Linear(in_features=768, out_features=3072, bias=True)
    (act): GELU(approximate='none')
    (fc2): Linear(in_features=3072, out_features=768, bias=True)
  )
  (block): PerceiverPoolingBlock(
    (perceiver): PerceiverBlock(
      (norm1q): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (norm1kv): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): PerceiverAttention1d(
        (kv): Linear(in_features=768, out_features=1536, bias=True)
        (q): Linear(in_features=768, out_features=768, bias=True)
        (proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (drop_path1): DropPath(drop_prob=0.000)
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, 

In [101]:
mesh_encoder.output_shape

(1024, 768)

In [102]:
latent = TransformerModel(
    init_weights = "truncnormal",
    drop_path_rate = 0.2,
    drop_path_decay = False,
    dim = 768,
    num_attn_heads = 12,
    depth = 12,
    input_shape = mesh_encoder.output_shape
)

In [103]:
latent

TransformerModel(
  (input_proj): LinearProjection(
    (proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (blocks): ModuleList(
    (0-11): 12 x PrenormBlock(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): DotProductAttention1d(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (ls1): Identity()
      (drop_path1): DropPath(drop_prob=0.200)
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
      )
      (ls2): Identity()
      (drop_path2): DropPath(drop_prob=0.200)
    )
  )
)

In [105]:
latent.output_shape

(1024, 768)

In [110]:
decoder = DecoderRansPerceiver(
    dim = 768,
    num_attn_heads = 12,
    init_weights = "truncnormal",
    input_shape = latent.output_shape,
    output_shape = (None, 1),
    static_ctx = {"ndim":3} # Not Sure.
)

In [111]:
decoder

RansPerceiver(
  (proj): LinearProjection(
    (proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (pos_embed): ContinuousSincosEmbed(dim=768)
  (query_mlp): Mlp(
    (fc1): Linear(in_features=768, out_features=768, bias=True)
    (act): GELU(approximate='none')
    (fc2): Linear(in_features=768, out_features=768, bias=True)
  )
  (perceiver): PerceiverBlock(
    (norm1q): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (norm1kv): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (attn): PerceiverAttention1d(
      (kv): Linear(in_features=768, out_features=1536, bias=True)
      (q): Linear(in_features=768, out_features=768, bias=True)
      (proj): Linear(in_features=768, out_features=768, bias=True)
    )
    (drop_path1): DropPath(drop_prob=0.000)
    (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (mlp): Mlp(
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (act): GELU(approximate='none')
      (fc2): Line

Now piece them up in a class

In [114]:
import torch
from models import model_from_kwargs
from models.base.composite_model_base import CompositeModelBase
from utils.factory import create

from models.encoders.rans_grid_convnext import RansGridConvnext
from models.encoders.rans_perceiver import RansPerceiver as EncoderRansPerceiver
from models.latent.transformer_model import TransformerModel
from models.decoders.rans_perceiver import RansPerceiver as DecoderRansPerceiver

class RansSimformerNognnSdfModel_CAEML(CompositeModelBase):
    def __init__(
            self,
            grid_encoder,
            mesh_encoder,
            latent,
            decoder,
            **kwargs,
    ):
        super().__init__(**kwargs)
        common_kwargs = dict(
            update_counter=self.update_counter,
            path_provider=self.path_provider,
            dynamic_ctx=self.dynamic_ctx,
            static_ctx=self.static_ctx,
            data_container=self.data_container,
        )
        # grid_encoder
        self.grid_encoder = grid_encoder
        # mesh_encoder
        self.mesh_encoder = mesh_encoder
        # latent
        self.latent = latent
        # decoder
        self.decoder = decoder

    @property
    def submodels(self):
        return dict(
            grid_encoder=self.grid_encoder,
            mesh_encoder=self.mesh_encoder,
            latent=self.latent,
            decoder=self.decoder,
        )

    # noinspection PyMethodOverriding
    def forward(self, mesh_pos, sdf, query_pos, batch_idx, unbatch_idx, unbatch_select):
        outputs = {}

        # encode data
        grid_embed = self.grid_encoder(sdf)
        mesh_embed = self.mesh_encoder(mesh_pos=mesh_pos, batch_idx=batch_idx)
        embed = torch.concat([grid_embed, mesh_embed], dim=1)

        # propagate
        propagated = self.latent(embed)

        # decode
        x_hat = self.decoder(
            propagated,
            query_pos=query_pos,
            unbatch_idx=unbatch_idx,
            unbatch_select=unbatch_select,
        )
        outputs["x_hat"] = x_hat

        return outputs
        
grid_encoder = RansGridConvnext(
    patch_size = 2,
    kernel_size = 3,
    depthwise = False,
    global_response_norm = True,
    depths = [ 2, 2, 2 ],
    dims = [ 192, 384, 768 ],
    upsample_size = 64,
    upsample_mode = "nearest",
    resolution = (32, 32, 32), # This is because they are separate from the Input Positions I guess.
    concat_pos_to_sdf = True
)

mesh_encoder = EncoderRansPerceiver(
    dim = 768,
    num_attn_heads = 12,
    num_output_tokens = 1024,
    add_type_token = True,
    init_weights = "truncnormal",
    input_shape = (None, 3)
)

latent = TransformerModel(
    init_weights = "truncnormal",
    drop_path_rate = 0.2,
    drop_path_decay = False,
    dim = 768,
    num_attn_heads = 12,
    depth = 12,
    input_shape = mesh_encoder.output_shape
)

decoder = DecoderRansPerceiver(
    dim = 768,
    num_attn_heads = 12,
    init_weights = "truncnormal",
    input_shape = latent.output_shape,
    output_shape = (None, 1),
    static_ctx = {"ndim":3} # Not Sure.
)


model = RansSimformerNognnSdfModel_CAEML(
    grid_encoder,
    mesh_encoder,
    latent,
    decoder,
)

In [115]:
model

RansSimformerNognnSdfModel_CAEML(
  (grid_encoder): RansGridConvnext(
    (model): ConvNext(
      (stem): Sequential(
        (0): Conv3d(4, 192, kernel_size=(2, 2, 2), stride=(2, 2, 2))
        (1): LayerNorm3d(
          (layer): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        )
      )
      (stages): ModuleList(
        (0): ConvNextStage(
          (downsampling): Identity()
          (blocks): Sequential(
            (0): ConvNextBlock(
              (drop_path): DropPath(drop_prob=0.000)
              (conv): Conv3d(192, 192, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (norm): LayerNorm3d(
                (layer): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
              )
              (mlp): Mlp(
                (fc1): Conv3d(192, 768, kernel_size=(1, 1, 1), stride=(1, 1, 1))
                (act): GELU(approximate='none')
                (grn): GlobalResponseNorm()
                (fc2): Conv3d(768, 192, kernel_size=(1, 1