# Train from Scratch

### Load Stage Hyperparameters

In [1]:
import yaml
import re

def _lookup_reference(path, context):
    """Return the value stored at the dotted ``path`` in ``context``, or None if absent."""
    value = context
    for key in path.split('.'):
        # A path that descends into a scalar/list cannot be followed further.
        if not isinstance(value, dict):
            return None
        value = value.get(key)
        if value is None:
            return None
    return value


def resolve_references(data, context):
    """
    Recursively resolve ``${dotted.key}`` references in the YAML data.

    A string that consists of exactly one placeholder is replaced by the
    referenced value itself, so its type (bool, int, float, ...) is preserved.
    Placeholders embedded in a longer string are substituted textually and the
    result stays a string. Placeholders that cannot be resolved are left
    untouched.

    Args:
        data: The YAML data (dict, list, str or any scalar).
        context (dict): The context dictionary with variable definitions.

    Returns:
        The YAML data with resolved references and preserved types.
    """
    if isinstance(data, dict):
        return {k: resolve_references(v, context) for k, v in data.items()}
    if isinstance(data, list):
        return [resolve_references(item, context) for item in data]
    if not isinstance(data, str):
        return data

    pattern = r'\$\{([^}]+)\}'

    # Whole-string placeholder: hand back the referenced value unchanged so
    # that bools stay bools (bool is a subclass of int, so it must not be
    # round-tripped through int()/str()).
    whole = re.fullmatch(pattern, data)
    if whole is not None:
        value = _lookup_reference(whole.group(1), context)
        return data if value is None else value

    def _substitute(match):
        value = _lookup_reference(match.group(1), context)
        # Keep the original "${...}" text when the reference is unknown.
        return match.group(0) if value is None else str(value)

    return re.sub(pattern, _substitute, data)

def load_yaml_with_interpolation(file_path):
    """
    Parse a YAML file and resolve its ``${...}`` references against itself.

    Args:
        file_path (str): The path to the YAML file.

    Returns:
        dict: The parsed YAML data with interpolated variables, or None when
        the file is not valid YAML (the parser error is printed).
    """
    try:
        with open(file_path, 'r') as handle:
            parsed = yaml.safe_load(handle)
    except yaml.YAMLError as e:
        print(f"Error loading YAML file: {e}")
        return None
    # The document doubles as its own lookup context for references.
    return resolve_references(parsed, parsed)

# Load the stage hyperparameters from the YAML config, with ${...} references resolved
# (None is returned if the file is not valid YAML).
stage_hp = load_yaml_with_interpolation("yamls/shapenetcar/upt/dim768_seq1024sdf512_cnext_lr5e4_sd02_reprcnn_grn_grid32.yaml")

### Load Model

In [2]:
import torch
from models import model_from_kwargs
from models.base.composite_model_base import CompositeModelBase
from utils.factory import create

from models.encoders.rans_grid_convnext import RansGridConvnext
from models.encoders.rans_perceiver import RansPerceiver as EncoderRansPerceiver
from models.latent.transformer_model import TransformerModel
from models.decoders.rans_perceiver import RansPerceiver as DecoderRansPerceiver

class RansSimformerNognnSdfModel_CAEML(CompositeModelBase):
    """Composite model chaining a grid encoder, a mesh encoder, a latent model and a decoder.

    The four submodels are passed in already constructed and are stored as-is.
    """

    def __init__(
            self,
            grid_encoder,
            mesh_encoder,
            latent,
            decoder,
            **kwargs,
    ):
        """Store the prebuilt submodels.

        Args:
            grid_encoder: Module called as ``grid_encoder(sdf)``.
            mesh_encoder: Module called as ``mesh_encoder(mesh_pos=..., batch_idx=...)``.
            latent: Module applied to the concatenated encoder outputs.
            decoder: Module called with the latent output plus the query arguments.
            **kwargs: Forwarded unchanged to ``CompositeModelBase.__init__``.
        """
        super().__init__(**kwargs)
        # The submodels arrive fully constructed, so no shared constructor
        # kwargs need to be assembled here.
        self.grid_encoder = grid_encoder
        self.mesh_encoder = mesh_encoder
        self.latent = latent
        self.decoder = decoder

    @property
    def submodels(self):
        """Return the four submodels keyed by their attribute names."""
        return dict(
            grid_encoder=self.grid_encoder,
            mesh_encoder=self.mesh_encoder,
            latent=self.latent,
            decoder=self.decoder,
        )

    # noinspection PyMethodOverriding
    def forward(self, mesh_pos, sdf, query_pos, batch_idx, unbatch_idx, unbatch_select):
        """Encode the inputs, run the latent model and decode at the query positions.

        Args:
            mesh_pos: Passed to the mesh encoder together with ``batch_idx``.
            sdf: Passed to the grid encoder.
            query_pos: Passed to the decoder.
            batch_idx: Passed to the mesh encoder.
            unbatch_idx: Passed to the decoder.
            unbatch_select: Passed to the decoder.

        Returns:
            dict: ``{"x_hat": <decoder output>}``.
        """
        outputs = {}

        # encode data
        grid_embed = self.grid_encoder(sdf)
        mesh_embed = self.mesh_encoder(mesh_pos=mesh_pos, batch_idx=batch_idx)
        # join both embeddings along dim 1 (presumably the token axis — TODO confirm)
        embed = torch.concat([grid_embed, mesh_embed], dim=1)

        # propagate
        propagated = self.latent(embed)

        # decode
        x_hat = self.decoder(
            propagated,
            query_pos=query_pos,
            unbatch_idx=unbatch_idx,
            unbatch_select=unbatch_select,
        )
        outputs["x_hat"] = x_hat

        return outputs
        
# Encoder for the SDF grid input.
grid_encoder = RansGridConvnext(
    patch_size=2,
    kernel_size=3,
    depthwise=False,
    global_response_norm=True,
    depths=[2, 2, 2],
    dims=[192, 384, 768],
    upsample_size=64,
    upsample_mode="nearest",
    # NOTE(review): original author's guess is that the grid resolution is
    # independent of the input positions — confirm.
    resolution=(32, 32, 32),
    concat_pos_to_sdf=True,
)

# Encoder for the mesh positions.
mesh_encoder = EncoderRansPerceiver(
    dim=768,
    num_attn_heads=12,
    num_output_tokens=1024,
    add_type_token=True,
    init_weights="truncnormal",
    input_shape=(None, 3),
)

# Latent transformer; its input shape is taken from the mesh encoder's output.
latent = TransformerModel(
    init_weights="truncnormal",
    drop_path_rate=0.2,
    drop_path_decay=False,
    dim=768,
    num_attn_heads=12,
    depth=12,
    input_shape=mesh_encoder.output_shape,
)

# Decoder; its input shape is taken from the latent model's output.
decoder = DecoderRansPerceiver(
    dim=768,
    num_attn_heads=12,
    init_weights="truncnormal",
    input_shape=latent.output_shape,
    output_shape=(None, 1),
    # NOTE(review): the original author was not sure about this value — confirm.
    static_ctx={"ndim": 3},
)

# Assemble the composite model from the prebuilt submodels.
model = RansSimformerNognnSdfModel_CAEML(
    grid_encoder=grid_encoder,
    mesh_encoder=mesh_encoder,
    latent=latent,
    decoder=decoder,
)

### Load Dataset

In [3]:
from datasets.shapenet_car import ShapenetCar
from utils.data_container import DataContainer

# Both splits share every setting except `split`, so build them in one pass
# (train first, then test).
datasets = {
    split: ShapenetCar(
        split=split,
        grid_resolution=32,
        standardize_query_pos=False,
        concat_pos_to_sdf=True,
        global_root='/home/ubuntu/UPT/data/shapenet_car_processed',
        local_root='/home/ubuntu/UPT/data',
        seed=None,
    )
    for split in ("train", "test")
}
train_dataset = datasets["train"]
test_dataset = datasets["test"]

# Container holding both splits under the keys "train" and "test".
data_container = DataContainer(
    train=train_dataset,
    test=test_dataset,
    num_workers=1,
    pin_memory=False,
    config_provider=None,
    seed=0,
)

### Set Up Data Loaders

In [4]:
from datasets.collators.rans_simformer_nognn_collator import RansSimformerNognnCollator
from torch.utils.data import DataLoader

# NOTE(review): the outputs recorded further down in this notebook show that
# iterating these loaders raises UseModeWrapperException ("wrap
# kappadata.KDDataset into kappadata.ModeWrapper before calling __getitem__").
# The datasets presumably need to be wrapped before being passed here; the
# required mode is not visible in this notebook — TODO confirm.

# Training loader: reshuffled every epoch.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=True,
    drop_last=True,
    collate_fn=RansSimformerNognnCollator(),
)

# Test loader: fixed order and no dropped samples, so every evaluation pass
# sees exactly the same data.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    collate_fn=RansSimformerNognnCollator(),
)

In [6]:
# Create an iterator over the training loader so single batches can be pulled manually.
data_iter = iter(train_dataloader)

In [7]:
# Fetch one batch; the recorded output below shows this currently raises
# UseModeWrapperException.
batch = next(data_iter)

UseModeWrapperException: wrap kappadata.KDDataset into kappadata.ModeWrapper before calling __getitem__

In [5]:
# Print every training batch; the recorded output below shows this currently
# raises UseModeWrapperException on the first batch.
for batch in train_dataloader:
    print(batch)

UseModeWrapperException: wrap kappadata.KDDataset into kappadata.ModeWrapper before calling __getitem__

### Set Up Optimizer

In [None]:
epochs = 1
optim = torch.optim.AdamW(params=model.parameters(), lr=5.0e-4, weight_decay=0.05)

# One scheduled learning rate per optimizer update; the first 10% are warmup.
total_updates = epochs * len(train_dataloader)
warmup_updates = int(total_updates * 0.1)
lrs = torch.concat(
    (
        # linear warmup: 0 -> base lr
        torch.linspace(0, optim.defaults["lr"], steps=warmup_updates),
        # linear decay: base lr -> 0
        torch.linspace(optim.defaults["lr"], 0, steps=total_updates - warmup_updates),
    ),
)

### Set Up Trainer

In [None]:
from trainers.rans_simformer_nognn_trainer import RansSimformerNognnTrainer

# NOTE(review): max_epochs here is 100 while the optimizer cell sets epochs = 1 —
# confirm which value is intended.
trainer = RansSimformerNognnTrainer(
    device="cuda",
    data_container=data_container,
    # loss function comes from the stage hyperparameters loaded from YAML
    loss_function=stage_hp["trainer"]["loss_function"],
    precision='bfloat16',
    max_epochs=100,
    effective_batch_size=1,
    max_batch_size=16,
)

### Train

In [None]:
# train model
# `F` and `tqdm` are used below but are not imported by any earlier cell.
import torch.nn.functional as F
from tqdm import tqdm

# `device` was never defined in this notebook; fall back to CPU when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Keyword names accepted by the model's forward() defined in the "Load Model" cell.
# NOTE(review): each batch is assumed to be a dict that holds tensors under these
# same names, plus the regression target under TARGET_KEY — the collator's output
# schema is not visible in this notebook, TODO confirm.
MODEL_INPUT_KEYS = ("mesh_pos", "sdf", "query_pos", "batch_idx", "unbatch_idx", "unbatch_select")
TARGET_KEY = "output_feat"


def _batch_loss(batch):
    """Run one forward pass and return the MSE between ``x_hat`` and the target."""
    inputs = {key: batch[key].to(device) for key in MODEL_INPUT_KEYS}
    # forward() returns a dict; the prediction is stored under "x_hat"
    outputs = model(**inputs)
    target = batch[TARGET_KEY].to(device)
    return F.mse_loss(outputs["x_hat"], target)


update = 0
pbar = tqdm(total=total_updates)
pbar.update(0)
pbar.set_description(
    f"train_loss: ??????? "
    f"rollout_loss: ???????"
)
train_losses = []
# The model class has no `rollout` method and no `rollout_dataloader` exists in
# this notebook, so these hold plain forward-pass losses on `test_dataloader`.
rollout_losses = []
rollout_loss = 0.
loss = None
for _ in range(epochs):
    # train for an epoch
    model.train()
    for batch in train_dataloader:
        # schedule learning rate (plain float rather than a 0-dim tensor)
        for param_group in optim.param_groups:
            param_group["lr"] = lrs[update].item()

        # forward pass
        loss = _batch_loss(batch)

        # backward pass
        loss.backward()

        # update step
        optim.step()
        optim.zero_grad()

        # status update
        update += 1
        pbar.update()
        pbar.set_description(
            f"train_loss: {loss.item():.6f} "
            f"rollout_loss: {rollout_loss:.6f}"
        )
        train_losses.append(loss.item())

    # evaluate on the test split after every epoch
    model.eval()
    for test_batch in test_dataloader:
        with torch.no_grad():
            rollout_loss = _batch_loss(test_batch).item()
        rollout_losses.append(rollout_loss)
        pbar.set_description(
            f"train_loss: {loss.item():.6f} "
            f"rollout_loss: {rollout_loss:.6f}"
        )
pbar.close()