# GraphCastEncoderEmbedder

First used in [1] and defined in [2].

```python
"""GraphCast feature embedder for gird node features, multimesh node features,
    grid2mesh edge features, and multimesh edge features."""
```

- [1] https://vscode.dev/github/NVIDIA/modulus/blob/main/modulus/models/graphcast/graph_cast_net.py#L411
- [2] https://vscode.dev/github/NVIDIA/modulus/blob/main/modulus/models/gnn_layers/embedder.py#L25

In [1]:
%run review/__common.py
%load_ext autoreload
%autoreload 2



In [2]:
# Invoked as:
# self.encoder_embedder = GraphCastEncoderEmbedder(
#     input_dim_grid_nodes=input_dim_grid_nodes, # 31
#     input_dim_mesh_nodes=input_dim_mesh_nodes, # 3 (dafault)
#     input_dim_edges=input_dim_edges, # 4 (default)
#     output_dim=hidden_dim, # 64
#     hidden_dim=hidden_dim, # 64
#     hidden_layers=hidden_layers, # 1 (default)
#     activation_fn=activation_fn, # get_activation("silu")
#     norm_type=norm_type, # TELayerNorm
#     recompute_activation=recompute_activation, # True
# )

# Defined as:
# class GraphCastEncoderEmbedder(nn.Module):
#     def __init__(
#         self,
#         input_dim_grid_nodes: int = 474,
#         input_dim_mesh_nodes: int = 3,
#         input_dim_edges: int = 4,
#         output_dim: int = 512,
#         hidden_dim: int = 512,
#         hidden_layers: int = 1,
#         activation_fn: nn.Module = nn.SiLU(),
#         norm_type: str = "LayerNorm",
#         recompute_activation: bool = False,
#     ):
#         ...

from modulus.models.layers import get_activation

input_dim_grid_nodes = 31
input_dim_mesh_nodes = 3
input_dim_edges = 4
output_dim = 64
hidden_dim = 64
hidden_layers = 1
activation_fn = get_activation("silu")
norm_type = "TELayerNorm"
recompute_activation = True

In [3]:
from modulus.models.gnn_layers.mesh_graph_mlp import MeshGraphMLP

grid_node_mlp = MeshGraphMLP(
    input_dim=input_dim_grid_nodes,
    output_dim=output_dim,
    hidden_dim=hidden_dim,
    hidden_layers=hidden_layers,
    activation_fn=activation_fn,
    norm_type=norm_type,
    recompute_activation=recompute_activation,
)

In [4]:
grid_node_mlp

MeshGraphMLP(
  (model): Sequential(
    (0): Linear(in_features=31, out_features=64, bias=True)
    (1): SiLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): LayerNorm()
  )
)

In [5]:
var_in = torch.randn(3, 32)
var_in, var_in.shape

(tensor([[-1.3673e+00,  4.9746e-01,  6.1957e-01, -6.8527e-02,  3.3726e-01,
          -4.6997e-01,  1.2477e+00, -4.0895e-01,  7.9581e-02,  1.1556e+00,
          -1.8389e+00, -6.2278e-01,  1.1993e+00, -5.0889e-01, -6.8851e-01,
          -3.0719e-01,  3.1220e-01,  5.4997e-01, -1.1058e+00, -3.4738e-01,
          -4.3666e-01, -1.7158e+00,  1.7773e+00, -4.9188e-01,  3.5039e-01,
          -1.2541e+00, -6.1864e-01, -1.0331e-01, -9.5205e-01, -3.2043e-01,
          -7.0157e-01, -1.3125e+00],
         [ 1.6349e+00,  3.1402e-01, -2.7442e-01, -8.7179e-01, -4.3892e-01,
          -1.2839e+00, -6.3945e-01, -1.0386e+00, -1.3345e+00,  3.2789e-01,
          -1.8072e+00, -8.4228e-01, -2.1458e+00, -1.1730e-01, -1.1539e+00,
          -4.0861e-01, -5.6240e-01, -9.9082e-01,  1.0773e+00, -7.0471e-01,
          -2.7601e+00, -9.9589e-02,  1.6313e+00, -1.0782e-01, -2.3176e+00,
           1.1752e+00, -5.8979e-01, -5.1697e-01,  3.0977e-01, -1.2914e+00,
          -1.0224e+00, -8.3342e-01],
         [ 1.5534e-01,  1.

In [7]:
grid_node_mlp(var_in), grid_node_mlp(var_in).shape

RuntimeError: mat1 and mat2 shapes cannot be multiplied (3x32 and 31x64)