# 4 UPT Tutorial on CAE ML Graph

This graph is from the SS-INR2-SOLVER

In [1]:
import dgl
import torch

# Load the DGL graph from a binary file
graph_list, _ = dgl.load_graphs("test_graph.bin")

# If there is only one graph in the file, retrieve it
graph = graph_list[0]

# Print basic information about the graph
print("Graph:", graph)
print("Number of nodes:", graph.num_nodes())
print("Number of edges:", graph.num_edges())

# If the graph contains node or edge features, access them as follows
if 'feat' in graph.ndata:
    print("Node features:", graph.ndata['feat'])
if 'feat' in graph.edata:
    print("Edge features:", graph.edata['feat'])

Graph: Graph(num_nodes=87345, num_edges=0,
      ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32), 'pos': Scheme(shape=(3,), dtype=torch.float32), 'sv': Scheme(shape=(3,), dtype=torch.float32), 'y': Scheme(shape=(2,), dtype=torch.float32), 'node_type': Scheme(shape=(1,), dtype=torch.float32)}
      edata_schemes={'edge_type': Scheme(shape=(1,), dtype=torch.float32), '_ID': Scheme(shape=(), dtype=torch.int32), 'x': Scheme(shape=(8,), dtype=torch.float32)})
Number of nodes: 87345
Number of edges: 0


In [2]:
graph.ndata['x'] = torch.concat([graph.ndata['pos'], graph.ndata['sv']],dim=1)

In [3]:
graph.ndata['x'].shape

torch.Size([87345, 6])

#### Encoder

The encoder processes input features (e.g. velocities, pressure, ...) and input positions at timestep 
 and encodes it into a latent representation 
. Input features and input positions are sparse tensors. The input features are first processed with a shallow MLP. Then, the input positions are added to the result of the MLP. This representation is then used for a message passing, where messages are only passed to selected supernodes. We randomly select a fixed number nodes from each pointcloud during dataloading which are then used as "supernodes".

In [4]:
import torch
from models.encoders.rans_perceiver import RansPerceiver as EncoderRansPerceiver

In [5]:
mesh_encoder = EncoderRansPerceiver(
    dim = 768,
    num_attn_heads = 12,
    num_output_tokens = 1024,
    add_type_token = True,
    init_weights = "truncnormal",
    input_shape = (None, 6)
)

In [6]:
batch_idx = torch.zeros(graph.ndata["x"].shape[0], dtype=torch.int64)

In [7]:
batch_idx.shape

torch.Size([87345])

In [8]:
batch_idx

tensor([0, 0, 0,  ..., 0, 0, 0])

In [9]:
mesh_embed = mesh_encoder(mesh_pos=graph.ndata["x"], batch_idx=batch_idx)

In [10]:
mesh_embed.shape

torch.Size([1, 1024, 768])

In [11]:
mesh_embed

tensor([[[-0.5807,  0.4658,  0.1716,  ..., -0.7107, -0.5791, -0.0913],
         [ 0.5517, -1.2879,  0.6687,  ..., -1.1653, -0.3798, -0.5626],
         [-0.1468, -0.9169,  1.4105,  ...,  1.0836, -0.6124,  0.4049],
         ...,
         [-0.9655,  0.8354,  1.0881,  ...,  0.3756, -1.3532, -1.2547],
         [-1.3816, -0.4939,  1.3127,  ...,  0.5757,  0.6239, -0.0216],
         [ 1.8845,  1.2737, -0.2964,  ...,  1.0076, -0.6336,  0.4241]]],
       grad_fn=<AddBackward0>)

#### Approximator

The approximator takes the latent_tokens and pushes them forward by one timestep. It simply consists of some transformer blocks.

In [12]:
from models.latent.transformer_model import TransformerModel

In [13]:
latent = TransformerModel(
    init_weights = "truncnormal",
    drop_path_rate = 0.2,
    drop_path_decay = False,
    dim = 768,
    num_attn_heads = 12,
    depth = 12,
    input_shape = mesh_encoder.output_shape
)

In [14]:
propagated = latent(mesh_embed)

In [15]:
propagated.shape

torch.Size([1, 1024, 768])

In [16]:
propagated

tensor([[[-2.4416, -0.4059,  0.1636,  ...,  0.4360,  0.7620,  2.4100],
         [-0.0073, -1.2096, -0.7920,  ...,  1.8539, -1.7077,  1.3626],
         [ 0.2752,  2.5950, -0.6274,  ...,  1.3275,  1.4254,  2.3314],
         ...,
         [ 0.7461,  0.2784,  0.9726,  ..., -0.7214,  0.8626,  0.2719],
         [-2.0826,  1.3264,  1.0127,  ...,  1.7330, -0.0511,  1.8760],
         [-0.2851,  0.2323,  0.1737,  ...,  1.5888,  0.7301,  2.9876]]],
       grad_fn=<ViewBackward0>)

#### Decoder

The decoder takes the latent_tokens and decodes them into the original space of the input data. It does this by querying the latent space at arbitrary positions. It first employs some transformer blocks, followed by a perceiver decoder block. The output positions are encoded via a shallow MLP before being used as query vector for the perceiver.

For training, the output positions need to have an associated ground truth value in the dataset as the model is trained via an mean-squared-error loss between the predictions at the output positions and the ground truth value at the output positions. For inference, the output positions can be arbitrary. Also: output positions and input positions do not have to match (also not during training).

In [17]:
from models.decoders.rans_perceiver import RansPerceiver as DecoderRansPerceiver

In [18]:
decoder = DecoderRansPerceiver(
    dim = 768,
    num_attn_heads = 12,
    init_weights = "truncnormal",
    input_shape = latent.output_shape,
    output_shape = (None, 1),
    static_ctx = {"ndim":6} # Not Sure.
)

In [19]:
pred = decoder(propagated, query_pos=graph.ndata["x"].unsqueeze(dim=0), unbatch_idx=batch_idx, unbatch_select=[0])

In [20]:
pred.shape

torch.Size([87345, 1])

In [21]:
pred

tensor([[ 0.0411],
        [ 0.0393],
        [ 0.0218],
        ...,
        [-0.0245],
        [ 0.0553],
        [ 0.0005]], grad_fn=<CatBackward0>)

In [22]:
graph.ndata["y"].shape

torch.Size([87345, 2])

### Add support for Image Output: Taking Inspitation from the Decoder Perceiver from UPT Tutorial

In [23]:
import einops
import torch
from kappamodules.layers import ContinuousSincosEmbed, LinearProjection, Sequential
from kappamodules.transformer import PerceiverBlock, Mlp, DitPerceiverBlock, DitBlock
from kappamodules.vit import VitBlock
from torch_geometric.utils import unbatch
from torch import nn
from models.base.single_model_base import SingleModelBase
import math

from functools import partial

class RansPerceiver(SingleModelBase):
    def __init__(
        self,
        dim,
        num_attn_heads,
        init_weights="xavier_uniform",
        init_last_proj_zero=False,
        use_last_norm=False,
        output_mode="sparse",  
        num_images=1,  # Number of "image-slots" in the channel dimension
        image_dims=None,  # (height, width) for image output
        last_activation=None,  # <-- New parameter
        **kwargs,
    ):
        """
        `last_activation` can be:
            - None                (no activation, unbounded)
            - "sigmoid"          (clamps output to [0,1])
            - "tanh_shift_scale" (maps output to [0,1] via 0.5 * (tanh(x) + 1))
        """
        super().__init__(**kwargs)
        self.dim = dim
        self.num_attn_heads = num_attn_heads
        self.num_images = num_images  
        self.use_last_norm = use_last_norm
        self.output_mode = output_mode
        self.image_dims = image_dims
        self.last_activation = last_activation  # <-- Store last_activation

        # Input projection
        _, input_dim = self.input_shape
        self.proj = LinearProjection(input_dim, dim, init_weights=init_weights)

        # Query tokens (positional embedding + MLP)
        self.pos_embed = ContinuousSincosEmbed(dim=dim, ndim=self.static_ctx["ndim"])
        self.query_mlp = Mlp(in_dim=dim, hidden_dim=dim, init_weights=init_weights)

        # Transformer block (latent to tokens)
        self.perceiver = PerceiverBlock(
            dim=dim,
            num_heads=num_attn_heads,
            init_last_proj_zero=init_last_proj_zero,
            init_weights=init_weights,
        )

        # Figure out final channels (e.g., if output_dim=3 (RGB) then final is 3 * num_images)
        _, output_dim = self.output_shape
        final_channels = output_dim * num_images

        # Final projection
        self.pred = LinearProjection(dim, final_channels, init_weights=init_weights)

        # Optional normalization
        self.norm = nn.LayerNorm(dim, eps=1e-6) if use_last_norm else nn.Identity()

    def forward(self, x, query_pos, unbatch_idx, unbatch_select):
        """
        x          : [batch_size, latent_seq_len, dim]
        query_pos  : [batch_size, height*width, pos_dim]
        """
        # 1) Project input
        x = self.proj(x)

        # 2) Create query embeddings
        query_pos_embed = self.pos_embed(query_pos)      # [batch_size, height*width, dim]
        query = self.query_mlp(query_pos_embed)          # [batch_size, height*width, dim]

        # 3) Perceiver decoding
        x = self.perceiver(q=query, kv=x)
        x = self.norm(x)
        x = self.pred(x)
        # Now shape: [batch_size, height*width, output_dim*num_images]

        # 4) Reshape / rearrange based on self.output_mode
        if self.output_mode == "sparse":
            # E.g. for point-cloud style outputs
            x = einops.rearrange(
                x, 
                "batch_size max_num_points (channels) -> (batch_size max_num_points) channels", 
                channels=self.output_shape[1]  # e.g. 3
            )
            unbatched = unbatch(x, batch=unbatch_idx)
            x = torch.concat([unbatched[i] for i in unbatch_select])

        elif self.output_mode == "dense_to_sparse_unpadded":
            x = einops.rearrange(x, "b seqlen c -> (b seqlen) c")

        elif self.output_mode == "image":
            # Reshape to [batch_size, num_images, output_dim, height, width]
            if self.image_dims is not None:
                height, width = self.image_dims

                # We expect exactly height*width tokens
                expected_tokens = height * width
                assert x.size(1) == expected_tokens, (
                    f"Expected {expected_tokens} tokens (height*width), got {x.size(1)}"
                )

                batch_size = x.size(0)
                output_dim = self.output_shape[1]  # e.g., 3 (RGB)
                
                # Step-by-step reshape:
                # Currently => [b, height*width, num_images*output_dim]
                # 1) [b, hw, num_images, output_dim]
                x = x.view(batch_size, height * width, self.num_images, output_dim)

                # 2) [b, num_images, output_dim, hw]
                x = x.permute(0, 2, 3, 1)

                # 3) [b, num_images, output_dim, height, width]
                x = x.view(batch_size, self.num_images, output_dim, height, width)

                # Optional final dimension reorder if you prefer [b, num_images, width, height, output_dim]:
                x = x.permute(0, 1, 4, 3, 2)
            else:
                raise ValueError("image_dims must be provided for 'image' output mode.")

        else:
            raise ValueError(f"Unsupported output mode: {self.output_mode}")

        # 5) Apply optional final activation
        if self.last_activation == "sigmoid":
            x = torch.sigmoid(x)
        elif self.last_activation == "tanh_shift_scale":
            # Map from [-1, +1] to [0, 1]
            x = 0.5 * (torch.tanh(x) + 1.0)
        elif self.last_activation is not None:
            raise ValueError(
                f"Unsupported value for last_activation='{self.last_activation}'. "
                "Use None, 'sigmoid', or 'tanh_shift_scale'."
            )

        return x

class DecoderPerceiver(nn.Module):
    def __init__(
            self,
            input_dim,
            output_dim,
            ndim,
            dim,
            depth,
            num_heads,
            unbatch_mode="dense_to_sparse_unpadded",
            perc_dim=None,
            perc_num_heads=None,
            cond_dim=None,
            init_weights="truncnormal002",
            num_images=1, 
            **kwargs,
    ):
        super().__init__(**kwargs)
        perc_dim = perc_dim or dim
        perc_num_heads = perc_num_heads or num_heads
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.ndim = ndim
        self.dim = dim
        self.depth = depth
        self.num_heads = num_heads
        self.perc_dim = perc_dim
        self.perc_num_heads = perc_num_heads
        self.cond_dim = cond_dim
        self.init_weights = init_weights
        self.unbatch_mode = unbatch_mode
        self.num_images = num_images

        # input projection
        self.input_proj = LinearProjection(input_dim, dim, init_weights=init_weights, optional=True)

        # blocks
        if cond_dim is None:
            block_ctor = VitBlock
        else:
            block_ctor = partial(DitBlock, cond_dim=cond_dim)
        self.blocks = Sequential(
            *[
                block_ctor(
                    dim=dim,
                    num_heads=num_heads,
                    init_weights=init_weights,
                )
                for _ in range(depth)
            ],
        )

        # prepare perceiver
        self.pos_embed = ContinuousSincosEmbed(
            dim=perc_dim,
            ndim=ndim,
        )
        if cond_dim is None:
            block_ctor = PerceiverBlock
        else:
            block_ctor = partial(DitPerceiverBlock, cond_dim=cond_dim)

        # decoder
        self.query_proj = nn.Sequential(
            LinearProjection(perc_dim, perc_dim, init_weights=init_weights),
            nn.GELU(),
            LinearProjection(perc_dim, perc_dim, init_weights=init_weights),
        )
        self.perc = block_ctor(dim=perc_dim, kv_dim=dim, num_heads=perc_num_heads, init_weights=init_weights)
        self.pred = nn.Sequential(
            nn.LayerNorm(perc_dim, eps=1e-6),
            LinearProjection(perc_dim, output_dim*num_images, init_weights=init_weights),
        )

    def forward(self, x, output_pos, condition=None):
        # check inputs
        assert x.ndim == 3, "expected shape (batch_size, num_latent_tokens, dim)"
        assert output_pos.ndim == 3, "expected shape (batch_size, num_outputs, dim) num_outputs might be padded"
        if condition is not None:
            assert condition.ndim == 2, "expected shape (batch_size, cond_dim)"

        # pass condition to DiT blocks
        cond_kwargs = {}
        if condition is not None:
            cond_kwargs["cond"] = condition

        # input projection
        x = self.input_proj(x)

        # apply blocks
        x = self.blocks(x, **cond_kwargs)

        # create query
        query = self.pos_embed(output_pos)
        query = self.query_proj(query)

        x = self.perc(q=query, kv=x, **cond_kwargs)
        x = self.pred(x)
        if self.unbatch_mode == "dense_to_sparse_unpadded":
            # dense to sparse where no padding needs to be considered
            x = einops.rearrange(
                x,
                "batch_size seqlen dim -> (batch_size seqlen) dim",
            )
        elif self.unbatch_mode == "image":
            # rearrange to square image
            height, width = 96, 192
            x = einops.rearrange(
                x,
                "batch_size (height width) dim -> batch_size dim height width",
                height=int(height),
                width=int(width),
            )
            x = x.view(1, self.num_images, 3, height, width)
            x = x.permute(0, 1, 3, 4, 2)
        else:
            raise NotImplementedError(f"invalid unbatch_mode '{self.unbatch_mode}'")

        return x

In [24]:
decoder = RansPerceiver(
    dim = 768,
    num_attn_heads = 12,
    init_weights = "truncnormal",
    input_shape = latent.output_shape,
    output_shape = (None, 3), # 3 Channels for RGB
    static_ctx = {"ndim": 2}, # Images are 2D
    output_mode="image",  # New parameter to control output mode
    num_images=2,
    image_dims=(192,96),  # Optional: Tuple of (height, width) for non-square images
)

In [25]:
decoder

RansPerceiver(
  (proj): LinearProjection(
    (proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (pos_embed): ContinuousSincosEmbed(dim=768)
  (query_mlp): Mlp(
    (fc1): Linear(in_features=768, out_features=768, bias=True)
    (act): GELU(approximate='none')
    (fc2): Linear(in_features=768, out_features=768, bias=True)
  )
  (perceiver): PerceiverBlock(
    (norm1q): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (norm1kv): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (attn): PerceiverAttention1d(
      (kv): Linear(in_features=768, out_features=1536, bias=True)
      (q): Linear(in_features=768, out_features=768, bias=True)
      (proj): Linear(in_features=768, out_features=768, bias=True)
    )
    (drop_path1): DropPath(drop_prob=0.000)
    (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (mlp): Mlp(
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (act): GELU(approximate='none')
      (fc2): Line

### Suppose I have a Single Image (192,96)

In [26]:
# Generate positions for a single image: 192 x 96 = 18,432 tokens
height, width = 192, 96
base_grid = torch.stack(torch.meshgrid(
    [torch.arange(height), torch.arange(width)], indexing="ij"
))  # [2, 192, 96]
base_grid = einops.rearrange(base_grid, "c h w -> (h w) c")  # [18,432, 2]

# shape => [1, 18432, 2]
output_pos = base_grid.unsqueeze(0).float()

print(output_pos.shape)

torch.Size([1, 18432, 2])


In [27]:
pred = decoder(propagated, query_pos=output_pos, unbatch_idx=batch_idx, unbatch_select=[0])

In [28]:
pred.shape

torch.Size([1, 2, 96, 192, 3])

In [29]:
graph.ndata["y"].shape

torch.Size([87345, 2])

### Alternative Decoders

In [None]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

# A basic SIREN layer
class SIRENLayer(nn.Module):
    def __init__(self, in_features, out_features, w0=30.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features)
        # Initialize following SIREN paper recommendations
        # e.g., uniform_(-1 / in_features, 1 / in_features), etc.
        self.w0 = w0

    def forward(self, x):
        # x shape: [batch, ..., in_features]
        return torch.sin(self.w0 * self.linear(x))

class ModulatedSIREN(nn.Module):
    def __init__(
        self,
        num_tokens=256,
        dim=384,
        hidden_dim=256,
        depth=5, 
        out_frames=22,
        out_channels=3,
        height=192,
        width=96
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.dim = dim
        self.out_frames = out_frames
        self.height = height
        self.width = width

        # SIREN layers
        self.layers = nn.ModuleList()
        self.layers.append(SIRENLayer(3, hidden_dim))  # (x, y, frame) -> hidden
        for _ in range(depth - 2):
            self.layers.append(SIRENLayer(hidden_dim, hidden_dim))
        self.final_layer = nn.Linear(hidden_dim, out_channels)

        # An MLP that predicts scale & shift for each hidden layer from the latent
        # Suppose we have 2 parameters (scale, shift) for each hidden layer => 2*(depth-1)*hidden_dim is the total
        # But simpler: produce a scale, shift per layer (not per channel). Up to you how fine-grained you want it.
        self.modulation_mlp = nn.Sequential(
            nn.Linear(num_tokens * dim, 128),
            nn.ReLU(),
            nn.Linear(128, 2*(depth - 1))  # scale, shift for each of the hidden layers (except final)
        )

    def forward(self, latent):
        """
        latent: [batch_size, num_tokens, dim]
        Output: [batch_size, out_frames, height, width, out_channels]
        """
        bsz = latent.size(0)

        # Summarize the latent
        # e.g., flatten -> MLP -> scale/shift parameters
        latent_flat = latent.reshape(bsz, -1)  # [batch_size, num_tokens * dim]
        mod_params = self.modulation_mlp(latent_flat)  # [batch_size, 2*(depth-1)]
        # Reshape: [batch_size, depth-1, 2] => each layer has scale & shift
        mod_params = mod_params.view(bsz, -1, 2)

        # Prepare the coordinate grid
        # We'll create a meshgrid of shape (height*width*out_frames) x 3 => (x, y, t)
        # Normalize coordinates to [-1,1] or something similar.
        xs = torch.linspace(-1, 1, steps=self.width, device=latent.device)
        ys = torch.linspace(-1, 1, steps=self.height, device=latent.device)
        ts = torch.linspace(-1, 1, steps=self.out_frames, device=latent.device)

        # Create a 3D meshgrid
        Y, X, T = torch.meshgrid(ys, xs, ts, indexing='ij')  # shape = [H, W, F]
        coords = torch.stack([X, Y, T], dim=-1)  # [H, W, F, 3]
        coords = coords.reshape(-1, 3)          # [H*W*F, 3]

        # We'll process all coords in a batch
        # Expand to [batch_size, H*W*F, 3]
        coords = coords.unsqueeze(0).expand(bsz, -1, -1)

        # Forward pass through SIREN
        x = coords  # rename for clarity
        for layer_idx, layer in enumerate(self.layers):
            x = layer(x)  # shape: [batch_size, HWF, hidden_dim] or [batch_size, HWF, out_features]
            if layer_idx < len(self.layers) - 1:  # skip final layer
                # Apply modulation (scale & shift) for this layer
                scale = mod_params[:, layer_idx, 0].view(bsz, 1)  # [batch, 1]
                shift = mod_params[:, layer_idx, 1].view(bsz, 1)  # [batch, 1]
                x = x * (1 + scale) + shift

        # Final layer to get RGB
        x = self.final_layer(x)  # [batch_size, HWF, out_channels]

        # Reshape to [batch_size, H, W, F, C]
        x = x.view(bsz, self.out_frames, self.height, self.width, -1)

        return x  # shape: [batch_size, H, W, F, 3]

class SimpleTransformerBlock(nn.Module):
    """A basic Transformer block with multi-head self-attention + MLP."""
    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # Self-attention
        x2, _ = self.attn(x, x, x)
        x = x + x2
        x = self.norm1(x)
        # MLP
        x2 = self.mlp(x)
        x = x + x2
        x = self.norm2(x)
        return x

class VisionTransformerDecoder(nn.Module):
    def __init__(
        self,
        num_tokens=256,
        dim=384,
        depth=4,
        num_heads=8,
        patch_size=16,
        out_frames=22,
        out_channels=3,
        out_height=192,
        out_width=96
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.dim = dim
        
        # Positional embeddings for the decoder tokens (optional)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))

        # A stack of Transformer blocks
        self.transformer = nn.ModuleList([
            SimpleTransformerBlock(dim, num_heads=num_heads) for _ in range(depth)
        ])

        # Suppose we decode into a patch grid for each frame, e.g. 12 x 6 patches = 72 patches,
        # for each of the 22 frames, total 72 * 22 = 1584 tokens. That might be large.
        # Simplify: let's decode fewer patches, then upscale with a small CNN.

        # Create a learnable "query" token for each patch. 
        # Suppose we have total_patches = frames * patches_per_frame
        self.patches_per_frame = (out_height // patch_size) * (out_width // patch_size)
        self.total_output_tokens = out_frames * self.patches_per_frame

        self.query_tokens = nn.Parameter(
            torch.randn(1, self.total_output_tokens, dim)
        )

        # Final linear to go from dim -> patch_size*patch_size*out_channels
        self.patch_projection = nn.Linear(dim, patch_size * patch_size * out_channels)

        # A small upscaling CNN if we want to refine the patches into full resolution.
        self.refine_cnn = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

        self.out_frames = out_frames
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.out_height = out_height
        self.out_width = out_width

    def forward(self, x):
        """
        x: [batch_size, num_tokens, dim]
        Output: [batch_size, out_frames, out_height, out_width, out_channels]
        """
        bsz = x.size(0)

        # Add positional embedding
        x = x + self.pos_embedding[:, : self.num_tokens, :]

        # Pass through Transformer blocks
        for block in self.transformer:
            x = block(x)   # [batch_size, num_tokens, dim]

        # Expand query tokens for each batch
        query_tokens = self.query_tokens.repeat(bsz, 1, 1)  # [batch_size, total_output_tokens, dim]
        
        # Cross-attention: let the query tokens attend to the output x (which is the "memory")
        # For simplicity, just do self-attention with the combined tokens. 
        # Or you could implement cross-attention specifically. Here is a simple approach:
        combined = torch.cat([x, query_tokens], dim=1)  # shape: [B, 256 + total_output_tokens, dim]
        for block in self.transformer:
            combined = block(combined)

        # The last part of combined are the query tokens
        out_tokens = combined[:, self.num_tokens:, :]  # [B, total_output_tokens, dim]

        # Convert tokens -> patches
        patches = self.patch_projection(out_tokens)  # [B, total_output_tokens, patch_size^2 * out_channels]
        patches = patches.view(bsz, 
                               self.total_output_tokens,
                               self.out_channels,
                               self.patch_size,
                               self.patch_size)  # [B, T, C, pH, pW]

        # Rearrange patches into (frames, height, width)
        # We know total_output_tokens = out_frames * patches_per_frame
        # patches_per_frame = (out_height // patch_size) * (out_width // patch_size)
        patches = patches.view(bsz, 
                               self.out_frames,
                               self.patches_per_frame,
                               self.out_channels,
                               self.patch_size,
                               self.patch_size)  # [B, F, P, C, pH, pW]

        # Now we need to tile these patches along height/width
        # Suppose we laid them out row by row:
        # e.g., out_height // patch_size rows, out_width // patch_size cols
        row_count = self.out_height // self.patch_size
        col_count = self.out_width // self.patch_size
        
        # Reshape to a grid
        patches = patches.view(bsz, 
                               self.out_frames,
                               row_count,
                               col_count,
                               self.out_channels,
                               self.patch_size,
                               self.patch_size)

        # Combine row, col into a single H, W
        patches = patches.permute(0, 1, 4, 2, 5, 3, 6)  # [B, F, C, row_count, pH, col_count, pW]
        patches = patches.reshape(
            bsz, 
            self.out_frames,
            self.out_channels,
            row_count * self.patch_size,
            col_count * self.patch_size
        )  # -> [B, F, C, out_height, out_width]

        # Optional refine
        # Flatten frames into batch dimension, run refine, then reshape back
        patches = patches.view(bsz * self.out_frames, self.out_channels, self.out_height, self.out_width)
        patches = self.refine_cnn(patches)
        patches = patches.view(bsz, self.out_frames, self.out_channels, self.out_height, self.out_width)

        # Permute to [B, F, H, W, C]
        patches = patches.permute(0, 1, 3, 4, 2).contiguous()
        return patches

class SimpleCNNDecoder(nn.Module):
    def __init__(
        self,
        num_tokens=256,
        dim=384,
        hidden_dim=1024,
        out_frames=22,
        out_channels=3,
        out_height=192,
        out_width=96
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.dim = dim
        self.out_frames = out_frames
        self.out_channels = out_channels
        self.out_height = out_height
        self.out_width = out_width
        
        # Example: Project from (num_tokens, dim) -> Flatten -> FC -> feature map
        # Let's create a seed feature map of shape (512, 12, 6) as an example
        self.initial_fc = nn.Linear(num_tokens * dim, 512 * 12 * 6)
        
        # Build upsampling decoder layers
        # Each layer roughly doubles the height/width until we reach (192,96).
        # We also have out_frames * out_channels = 66 for the final.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), # -> (256, 24, 12)
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # -> (128, 48, 24)
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # -> (64, 96, 48)
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.ConvTranspose2d(64, out_frames * out_channels, 
                               kernel_size=4, stride=2, padding=1),          # -> (66, 192, 96)
        )

    def forward(self, x):
        """
        x: [batch_size, num_tokens, dim]
        We want output: [batch_size, out_frames, out_height, out_width, out_channels]
        """
        batch_size = x.shape[0]
        
        # Flatten
        x = x.reshape(batch_size, -1)  # shape: [batch_size, num_tokens * dim]
        
        # Map to initial feature map
        x = self.initial_fc(x)  # [batch_size, 512*12*6]
        x = x.view(batch_size, 512, 12, 6)  # [batch_size, 512, 12, 6]
        
        # Decode
        x = self.decoder(x)  # [batch_size, out_frames * out_channels, out_height, out_width]
        
        # Reshape: [batch_size, out_frames, out_channels, out_height, out_width]
        x = x.view(batch_size, self.out_frames, self.out_channels, 
                   self.out_height, self.out_width)
        
        # Optionally reorder dims to [batch_size, out_frames, out_height, out_width, out_channels]
        x = x.permute(0, 1, 3, 4, 2)
        return x


In [None]:
decoder = SimpleCNNDecoder(
    num_tokens=1024,
    dim=768,
    hidden_dim=1024,
    out_frames=22,
    out_channels=3,
    out_height=192,
    out_width=96
)

In [None]:
decoder

In [None]:
propagated.shape

In [None]:
pred = decoder(propagated)

In [None]:
pred.shape

In [None]:
decoder = VisionTransformerDecoder(
    num_tokens=1024,
    dim=768,
    depth=4,
    num_heads=8,
    patch_size=16,
    out_frames=22,
    out_channels=3,
    out_height=192,
    out_width=96
)

In [None]:
decoder

In [None]:
pred = decoder(propagated)

In [None]:
pred.shape

In [None]:
decoder = ModulatedSIREN(
        num_tokens=1024,
        dim=768,
        hidden_dim=1024,
        depth=4, 
        out_frames=22,
        out_channels=3,
        height=96,
        width=192
    )

In [None]:
decoder

In [None]:
pred = decoder(propagated)

In [None]:
pred.shape

In [None]:
propagated.shape

### Switch to 3D Query POS

In [None]:
import torch
import einops

def generate_query_positions(num_images, height, width, device=None):
    """
    Returns a tensor of shape [num_images * height * width, 3],
    where each row is (image_idx, y, x).
    Then you'll typically unsqueeze(0) to get [1, num_images*height*width, 3]
    if you want a single batch dimension.
    """
    # We'll collect a list of (image_idx, y, x) for each image
    coords_list = []
    for img_idx in range(num_images):
        # 1) Create a meshgrid for a single image: (y,x) => shape [2, height, width]
        base_grid = torch.stack(
            torch.meshgrid([
                torch.arange(height, device=device),
                torch.arange(width, device=device)
            ], indexing="ij")
        )  # shape [2, height, width]
        # rearrange to [height*width, 2]
        base_grid = einops.rearrange(base_grid, "c h w -> (h w) c")  
        
        # 2) Create an image_idx vector for these pixels
        image_idx = torch.full((base_grid.shape[0], 1), float(img_idx), device=device)
        
        # 3) Concatenate: [height*width, 3] => (image_idx, y, x)
        coords = torch.cat([image_idx, base_grid], dim=-1)
        coords_list.append(coords)
    
    # Combine all images into one big array: [num_images * height*width, 3]
    all_coords = torch.cat(coords_list, dim=0)
    return all_coords


In [None]:
num_images = 2
height, width = 192, 96

# shape => [2 * 192 * 96, 3]
output_pos = generate_query_positions(num_images, height, width, device="cpu")

# Typically, you'd add a batch dimension => [1, (2*192*96), 3]
output_pos = output_pos.unsqueeze(0)  
print("output_pos.shape =", output_pos.shape)
# e.g. [1, 36864, 3] if num_images=2

In [None]:
import einops
import torch
from kappamodules.layers import ContinuousSincosEmbed, LinearProjection
from kappamodules.transformer import PerceiverBlock, Mlp
from torch_geometric.utils import unbatch
from torch import nn
from models.base.single_model_base import SingleModelBase
import math

class RansPerceiver(SingleModelBase):
    def __init__(
        self,
        dim,
        num_attn_heads,
        init_weights="xavier_uniform",
        init_last_proj_zero=False,
        use_last_norm=False,
        output_mode="sparse",  
        num_images=1,  # Number of "image-slots" in the channel dimension
        image_dims=None,  # (height, width) for image output
        last_activation=None,  # <-- New parameter
        **kwargs,
    ):
        """
        `last_activation` can be:
            - None                (no activation, unbounded)
            - "sigmoid"          (clamps output to [0,1])
            - "tanh_shift_scale" (maps output to [0,1] via 0.5 * (tanh(x) + 1))
        """
        super().__init__(**kwargs)
        self.dim = dim
        self.num_attn_heads = num_attn_heads
        self.num_images = num_images  
        self.use_last_norm = use_last_norm
        self.output_mode = output_mode
        self.image_dims = image_dims
        self.last_activation = last_activation  # <-- Store last_activation

        # Input projection
        _, input_dim = self.input_shape
        self.proj = LinearProjection(input_dim, dim, init_weights=init_weights)

        # Query tokens (positional embedding + MLP)
        self.pos_embed = ContinuousSincosEmbed(dim=dim, ndim=self.static_ctx["ndim"]+1)
        self.query_mlp = Mlp(in_dim=dim, hidden_dim=dim, init_weights=init_weights)

        # Transformer block (latent to tokens)
        self.perceiver = PerceiverBlock(
            dim=dim,
            num_heads=num_attn_heads,
            init_last_proj_zero=init_last_proj_zero,
            init_weights=init_weights,
        )

        # Figure out final channels (e.g., if output_dim=3 (RGB) then final is 3 * num_images)
        _, output_dim = self.output_shape
        # final_channels = output_dim * num_images
        final_channels = output_dim * 1

        # Final projection
        self.pred = LinearProjection(dim, final_channels, init_weights=init_weights)

        # Optional normalization
        self.norm = nn.LayerNorm(dim, eps=1e-6) if use_last_norm else nn.Identity()

    def forward(self, x, query_pos, unbatch_idx, unbatch_select):
        """
        x          : [batch_size, latent_seq_len, dim]
        query_pos  : [batch_size, height*width, pos_dim]
        """
        # 1) Project input
        x = self.proj(x)

        # 2) Create query embeddings
        query_pos_embed = self.pos_embed(query_pos)      # [batch_size, height*width, dim]
        query = self.query_mlp(query_pos_embed)          # [batch_size, height*width, dim]

        # 3) Perceiver decoding
        x = self.perceiver(q=query, kv=x)
        x = self.norm(x)
        x = self.pred(x)  
        # Now shape: [batch_size, height*width, output_dim*num_images]

        # 4) Reshape / rearrange based on self.output_mode
        if self.output_mode == "sparse":
            # E.g. for point-cloud style outputs
            x = einops.rearrange(
                x, 
                "batch_size max_num_points (channels) -> (batch_size max_num_points) channels", 
                channels=self.output_shape[1]  # e.g. 3
            )
            unbatched = unbatch(x, batch=unbatch_idx)
            x = torch.concat([unbatched[i] for i in unbatch_select])

        elif self.output_mode == "dense_to_sparse_unpadded":
            x = einops.rearrange(x, "b seqlen c -> (b seqlen) c")

        elif self.output_mode == "image":
            # Reshape to [batch_size, num_images, output_dim, height, width]
            if self.image_dims is not None:
                height, width = self.image_dims

                # We expect exactly height*width tokens
                expected_tokens = height * width * num_images
                assert x.size(1) == expected_tokens, (
                    f"Expected {expected_tokens} tokens (height*width), got {x.size(1)}"
                )

                batch_size = x.size(0)
                output_dim = self.output_shape[1]  # e.g., 3 (RGB)

                print(x.shape)
                
                # Step-by-step reshape:
                # Currently => [b, height*width, num_images*output_dim]
                # 1) [b, hw, num_images, output_dim]
                x = x.view(batch_size, height * width, self.num_images, output_dim)

                # 2) [b, num_images, output_dim, hw]
                x = x.permute(0, 2, 3, 1)

                # 3) [b, num_images, output_dim, height, width]
                x = x.view(batch_size, self.num_images, output_dim, height, width)

                # Optional final dimension reorder if you prefer [b, num_images, width, height, output_dim]:
                x = x.permute(0, 1, 4, 3, 2)
            else:
                raise ValueError("image_dims must be provided for 'image' output mode.")

        else:
            raise ValueError(f"Unsupported output mode: {self.output_mode}")

        # 5) Apply optional final activation
        if self.last_activation == "sigmoid":
            x = torch.sigmoid(x)
        elif self.last_activation == "tanh_shift_scale":
            # Map from [-1, +1] to [0, 1]
            x = 0.5 * (torch.tanh(x) + 1.0)
        elif self.last_activation is not None:
            raise ValueError(
                f"Unsupported value for last_activation='{self.last_activation}'. "
                "Use None, 'sigmoid', or 'tanh_shift_scale'."
            )

        return x


In [None]:
decoder = RansPerceiver(
    dim = 768,
    num_attn_heads = 12,
    init_weights = "truncnormal",
    input_shape = latent.output_shape,
    output_shape = (None, 3), # 3 Channels for RGB
    static_ctx = {"ndim": 2}, # Images are 2D
    output_mode="image",  # New parameter to control output mode
    num_images=2,
    image_dims=(192,96),  # Optional: Tuple of (height, width) for non-square images
)

In [None]:
pred = decoder(propagated, query_pos=output_pos, unbatch_idx=batch_idx, unbatch_select=[0])

In [None]:
pred.shape

### Try with the Decoder Perceiver Class

In [None]:
from functools import partial

import einops
import torch
from kappamodules.layers import ContinuousSincosEmbed, LinearProjection, Sequential
from kappamodules.transformer import PerceiverBlock, DitPerceiverBlock, DitBlock
from kappamodules.vit import VitBlock
from torch import nn
import math


class DecoderPerceiver(nn.Module):
    def __init__(
            self,
            input_dim,
            output_dim,
            ndim,
            dim,
            depth,
            num_heads,
            unbatch_mode="dense_to_sparse_unpadded",
            perc_dim=None,
            perc_num_heads=None,
            cond_dim=None,
            init_weights="truncnormal002",
            num_images=1, 
            **kwargs,
    ):
        super().__init__(**kwargs)
        perc_dim = perc_dim or dim
        perc_num_heads = perc_num_heads or num_heads
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.ndim = ndim
        self.dim = dim
        self.depth = depth
        self.num_heads = num_heads
        self.perc_dim = perc_dim
        self.perc_num_heads = perc_num_heads
        self.cond_dim = cond_dim
        self.init_weights = init_weights
        self.unbatch_mode = unbatch_mode
        self.num_images = num_images

        # input projection
        self.input_proj = LinearProjection(input_dim, dim, init_weights=init_weights, optional=True)

        # blocks
        if cond_dim is None:
            block_ctor = VitBlock
        else:
            block_ctor = partial(DitBlock, cond_dim=cond_dim)
        self.blocks = Sequential(
            *[
                block_ctor(
                    dim=dim,
                    num_heads=num_heads,
                    init_weights=init_weights,
                )
                for _ in range(depth)
            ],
        )

        # prepare perceiver
        self.pos_embed = ContinuousSincosEmbed(
            dim=perc_dim,
            ndim=ndim,
        )
        if cond_dim is None:
            block_ctor = PerceiverBlock
        else:
            block_ctor = partial(DitPerceiverBlock, cond_dim=cond_dim)

        # decoder
        self.query_proj = nn.Sequential(
            LinearProjection(perc_dim, perc_dim, init_weights=init_weights),
            nn.GELU(),
            LinearProjection(perc_dim, perc_dim, init_weights=init_weights),
        )
        self.perc = block_ctor(dim=perc_dim, kv_dim=dim, num_heads=perc_num_heads, init_weights=init_weights)
        self.pred = nn.Sequential(
            nn.LayerNorm(perc_dim, eps=1e-6),
            LinearProjection(perc_dim, output_dim*num_images, init_weights=init_weights),
        )

    def forward(self, x, output_pos, condition=None):
        # check inputs
        assert x.ndim == 3, "expected shape (batch_size, num_latent_tokens, dim)"
        assert output_pos.ndim == 3, "expected shape (batch_size, num_outputs, dim) num_outputs might be padded"
        if condition is not None:
            assert condition.ndim == 2, "expected shape (batch_size, cond_dim)"

        # pass condition to DiT blocks
        cond_kwargs = {}
        if condition is not None:
            cond_kwargs["cond"] = condition

        # input projection
        x = self.input_proj(x)

        # apply blocks
        x = self.blocks(x, **cond_kwargs)

        # create query
        query = self.pos_embed(output_pos)
        query = self.query_proj(query)

        x = self.perc(q=query, kv=x, **cond_kwargs)
        x = self.pred(x)
        if self.unbatch_mode == "dense_to_sparse_unpadded":
            # dense to sparse where no padding needs to be considered
            x = einops.rearrange(
                x,
                "batch_size seqlen dim -> (batch_size seqlen) dim",
            )
        elif self.unbatch_mode == "image":
            # rearrange to square image
            height, width = 96, 192
            x = einops.rearrange(
                x,
                "batch_size (height width) dim -> batch_size dim height width",
                height=int(height),
                width=int(width),
            )
            x = x.view(1, self.num_images, 3, height, width)
            x = x.permute(0, 1, 3, 4, 2)
        else:
            raise NotImplementedError(f"invalid unbatch_mode '{self.unbatch_mode}'")

        return x

In [None]:
decoder=DecoderPerceiver(
        # tell the decoder the dimension of the input (dim of approximator)
        input_dim=768,
        # 3 channels for RGB
        output_dim=3,
        # images have 2D coordinates
        ndim=2,
        # as in ViT-T
        dim= 768,
        num_heads=12,
        # ViT-T has 12 blocks -> parameters are split evenly among encoder/approximator/decoder
        depth=4,
        # reshape to image after decoding
        unbatch_mode="image",
        num_images=2
    )

In [None]:
decoder

In [None]:
# Generate positions for a single image: 192 x 96 = 18,432 tokens
height, width = 192, 96
base_grid = torch.stack(torch.meshgrid(
    [torch.arange(height), torch.arange(width)], indexing="ij"
))  # [2, 192, 96]
base_grid = einops.rearrange(base_grid, "c h w -> (h w) c")  # [18,432, 2]

# shape => [1, 18432, 2]
output_pos = base_grid.unsqueeze(0).float()

print(output_pos.shape)

In [None]:
pred = decoder(propagated, output_pos=output_pos)

In [None]:
pred.shape

In [None]:
torch.Size([1, 2, 96, 192, 3])