# FuXi architecture, model parameters, and I/O size debug

In [1]:
import os
import yaml

import torch

from credit.models import load_model
from credit.models.unet import SegmentationModel
from credit.models.crossformer import CrossFormer
from credit.models.fuxi import Fuxi
from credit.parser import credit_main_parser

In [2]:
# Directory holding the YAML config files; test_fuxi() joins it with
# 'fuxi_1deg_test.yml'. This is a user-specific absolute path.
CONFIG_FILE_DIR = '/glade/u/home/ksha/miles-physics/config/'

## FuXi unit test

In [3]:
def test_fuxi():
    """Smoke-test the FuXi model built from ``fuxi_1deg_test.yml``.

    The config is read from ``CONFIG_FILE_DIR``, passed through
    ``credit_main_parser`` and handed to ``load_model``. A random input tensor
    is pushed through the model once, and the test asserts the model type,
    the output shape and that the output contains no NaN.
    """
    # Read the YAML config and let the parser fill in / validate its fields.
    # NOTE(review): yaml.FullLoader is kept unchanged; only use it on trusted files.
    config_path = os.path.join(CONFIG_FILE_DIR, 'fuxi_1deg_test.yml')
    with open(config_path) as config_file:
        raw_conf = yaml.load(config_file, Loader=yaml.FullLoader)
    conf = credit_main_parser(raw_conf)

    # Tensor-size settings all live under the "model" section.
    model_conf = conf["model"]
    image_height = model_conf["image_height"]
    image_width = model_conf["image_width"]
    channels = model_conf["channels"]
    levels = model_conf["levels"]
    surface_channels = model_conf["surface_channels"]
    input_only_channels = model_conf["input_only_channels"]
    output_only_channels = model_conf["output_only_channels"]
    frames = model_conf["frames"]

    # Per-level variables are stacked over levels; surface and
    # input-only / output-only variables are appended on top.
    upper_air_channels = channels * levels
    in_channels = upper_air_channels + surface_channels + input_only_channels
    out_channels = upper_air_channels + surface_channels + output_only_channels

    # Random input laid out as (batch, variables, time, lat, lon).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_tensor = torch.randn(
        1, in_channels, frames, image_height, image_width
    ).to(device)

    model = load_model(conf).to(device)
    assert isinstance(model, Fuxi)

    # One forward pass; the output is expected to carry a single time step.
    y_pred = model(input_tensor)
    expected_shape = torch.Size([1, out_channels, 1, image_height, image_width])
    assert y_pred.shape == expected_shape
    assert not torch.isnan(y_pred).any()

In [4]:
# Run the FuXi smoke test; a failed check raises AssertionError.
test_fuxi()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


## FuXi dev section

In [5]:
import torch
from torch import nn
from torch.nn import functional as F
from timm.layers.helpers import to_2tuple
from timm.models.swin_transformer_v2 import SwinTransformerV2Stage
import logging

from credit.postblock import PostBlock
from credit.models.base_model import BaseModel
from credit.boundary_padding import TensorPadding

logger = logging.getLogger(__name__)


def apply_spectral_norm(model):
    """
    Wrap the weights of selected sub-modules of ``model`` with spectral norm.

    Every ``nn.Conv2d``, ``nn.Linear`` and ``nn.ConvTranspose2d`` found by
    ``model.modules()`` is passed to ``nn.utils.spectral_norm``, which modifies
    the module in place. Other layer types (``nn.Conv3d`` included) are left
    untouched.

    Args:
        model (nn.Module): module whose sub-modules are modified in place.

    Returns:
        None
    """
    normalized_types = (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)
    for layer in model.modules():
        if not isinstance(layer, normalized_types):
            continue
        nn.utils.spectral_norm(layer)


def get_pad3d(input_resolution, window_size):
    """
    Compute the padding that makes each axis of a 3-D grid a multiple of the window size.

    For every axis the missing amount is split in two; when it is odd, the
    extra element goes to the trailing side (right / bottom / back). An axis
    that is already divisible by its window gets no padding.

    Args:
        input_resolution (tuple[int]): (Pl, Lat, Lon)
        window_size (tuple[int]): (Pl, Lat, Lon)

    Returns:
        padding (tuple[int]): (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
    """

    def _split_padding(size, window):
        # Amount needed to reach the next multiple of `window`, split (leading, trailing)
        remainder = size % window
        if not remainder:
            return 0, 0
        missing = window - remainder
        leading = missing // 2
        return leading, missing - leading

    Pl, Lat, Lon = input_resolution
    win_pl, win_lat, win_lon = window_size

    front, back = _split_padding(Pl, win_pl)
    top, bottom = _split_padding(Lat, win_lat)
    left, right = _split_padding(Lon, win_lon)

    # Ordered innermost axis first (Lon, then Lat, then Pl)
    return (left, right, top, bottom, front, back)


def get_pad2d(input_resolution, window_size):
    """
    2-D variant of get_pad3d: padding that makes (Lat, Lon) divisible by the window size.

    Args:
        input_resolution (tuple[int]): Lat, Lon
        window_size (tuple[int]): Lat, Lon

    Returns:
        padding (tuple[int]): (padding_left, padding_right, padding_top, padding_bottom)
    """
    # Prepend a dummy leading axis whose size equals its window (2 and 2),
    # so it contributes no padding and the 3-D routine can be reused.
    resolution_3d = [2, *input_resolution]
    window_3d = [2, *window_size]
    left, right, top, bottom, _, _ = get_pad3d(resolution_3d, window_3d)
    return (left, right, top, bottom)


class CubeEmbedding(nn.Module):
    """
    Patch embedding over (T, Lat, Lon) using a strided Conv3d, followed by an
    optional normalization over the embedding dimension.

    Args:
        img_size: T, Lat, Lon
        patch_size: T, Lat, Lon
        in_chans: number of input channels of the Conv3d projection
        embed_dim: number of output (embedding) channels
        norm_layer: callable building the norm from ``embed_dim``; ``None`` disables it
    """

    def __init__(
        self, img_size, patch_size, in_chans, embed_dim, norm_layer=nn.LayerNorm
    ):
        super().__init__()

        # size of the (un-patched) input
        self.img_size = img_size

        # patch counts along (T, Lat, Lon); floor division drops any partial patch
        self.patches_resolution = [
            img_size[axis] // patch_size[axis] for axis in range(3)
        ]

        # width of the embedding produced for every patch
        self.embed_dim = embed_dim

        # non-overlapping patching: kernel and stride both equal the patch size
        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

        # optional normalization over the embedding dimension
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x: torch.Tensor):
        # x is 5-D with channels on dim 1, e.g. [Batch, 67, 2, 640, 1280];
        # only the batch size is needed below
        B, _, _, _, _ = x.shape

        # patch + embed, e.g. -> [B, 1024, 1, 40, 80]
        embedded = self.proj(x)

        # flatten (T, Lat, Lon) and move channels last for the norm,
        # e.g. -> [B, 3200, 1024]
        tokens = embedded.reshape(B, self.embed_dim, -1).transpose(1, 2)

        if self.norm is not None:
            tokens = self.norm(tokens)

        # back to channel-first and restore the patch grid,
        # e.g. -> [B, 1024, 1, 40, 80]
        return tokens.transpose(1, 2).reshape(
            B, self.embed_dim, *self.patches_resolution
        )


class DownBlock(nn.Module):
    """Stride-2 convolutional down-sampling followed by a residual conv stack.

    Args:
        in_chans: channels of the input feature map.
        out_chans: channels produced by the down-sampling conv and kept afterwards.
        num_groups: number of groups for every GroupNorm in the residual stack.
        num_residuals: how many (Conv2d, GroupNorm, SiLU) triplets to stack.
    """

    def __init__(
        self, in_chans: int, out_chans: int, num_groups: int, num_residuals: int = 2
    ):
        super().__init__()

        # 3x3 conv with stride 2 reduces the spatial size
        self.conv = nn.Conv2d(
            in_chans, out_chans, kernel_size=(3, 3), stride=2, padding=1
        )

        # residual branch: repeated size-preserving conv -> norm -> activation
        layers = []
        for _ in range(num_residuals):
            layers.extend(
                [
                    nn.Conv2d(out_chans, out_chans, kernel_size=3, stride=1, padding=1),
                    nn.GroupNorm(num_groups, out_chans),
                    nn.SiLU(),
                ]
            )
        self.b = nn.Sequential(*layers)

    def forward(self, x):
        # the down-sampled map feeds the residual branch and also serves as the skip
        downsampled = self.conv(x)
        return self.b(downsampled) + downsampled


class UpBlock(nn.Module):
    """Transposed-conv up-sampling followed by a residual conv stack.

    Args:
        in_chans: channels of the input feature map.
        out_chans: channels produced by the up-sampling conv and kept afterwards.
        num_groups: number of groups for every GroupNorm in the residual stack.
        num_residuals: how many (Conv2d, GroupNorm, SiLU) triplets to stack.
    """

    def __init__(self, in_chans, out_chans, num_groups, num_residuals=2):
        super().__init__()

        # up-sampling: 2x2 transposed conv with stride 2 doubles the spatial size
        self.conv = nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2)

        # residual branch: repeated size-preserving conv -> norm -> activation
        layers = []
        for _ in range(num_residuals):
            layers.extend(
                [
                    nn.Conv2d(out_chans, out_chans, kernel_size=3, stride=1, padding=1),
                    nn.GroupNorm(num_groups, out_chans),
                    nn.SiLU(),
                ]
            )
        self.b = nn.Sequential(*layers)

    def forward(self, x):
        # the up-sampled map feeds the residual branch and also serves as the skip
        upsampled = self.conv(x)
        return self.b(upsampled) + upsampled


class UTransformer(nn.Module):
    """U-Transformer: DownBlock -> zero-pad -> SwinTransformerV2Stage -> crop -> concat skip -> UpBlock.

    Args:
        embed_dim (int): Patch embedding dimension.
        num_groups (int | tuple[int]): GroupNorm groups; the first entry is used
            by the down block, the second by the up block.
        input_resolution (tuple[int]): Lat, Lon as seen by the Swin stage before padding.
        num_heads (int): Number of attention heads.
        window_size (int | tuple[int]): Window size; only its first entry is
            passed to the Swin stage.
        depth (int): Number of blocks.
        drop_path: forwarded to SwinTransformerV2Stage as ``drop_path``.
    """

    def __init__(
        self,
        embed_dim,
        num_groups,
        input_resolution,
        num_heads,
        window_size,
        depth,
        drop_path,
    ):
        super().__init__()
        # accept either a scalar or a pair for both settings
        groups_down, groups_up = to_2tuple(num_groups)
        window_pair = to_2tuple(window_size)

        # zero padding that makes the grid divisible by the window size
        self.padding = get_pad2d(input_resolution, window_pair)
        pad_l, pad_r, pad_t, pad_b = self.padding
        self.pad = nn.ZeroPad2d(self.padding)

        # resolution seen by the Swin stage, i.e. after padding
        padded_resolution = [
            input_resolution[0] + pad_t + pad_b,
            input_resolution[1] + pad_l + pad_r,
        ]

        # down-sampling block
        self.down = DownBlock(embed_dim, embed_dim, groups_down)

        # Swin stage; it receives a scalar window size
        self.layer = SwinTransformerV2Stage(
            embed_dim,
            embed_dim,
            padded_resolution,
            depth,
            num_heads,
            window_pair[0],
            drop_path=drop_path,
        )

        # up-sampling block; input channels are doubled by the skip concat
        self.up = UpBlock(embed_dim * 2, embed_dim, groups_up)

    def forward(self, x):
        # input must be 4-D: (B, C, Lat, Lon)
        _, _, _, _ = x.shape
        pad_l, pad_r, pad_t, pad_b = self.padding

        skip = self.down(x)

        # pad so the window partition fits
        padded = self.pad(skip)
        padded_lat, padded_lon = padded.shape[2], padded.shape[3]

        # the Swin stage works channel-last: B Lat Lon C
        attended = self.layer(padded.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        # crop the padding away again
        cropped = attended[
            :,
            :,
            pad_t : padded_lat - pad_b,
            pad_l : padded_lon - pad_r,
        ]

        # concat skip and attended features: B 2*C Lat Lon
        merged = torch.cat([skip, cropped], dim=1)
        return self.up(merged)


class Fuxi(BaseModel):
    """
    FuXi model: CubeEmbedding -> UTransformer -> Linear head that maps every
    embedded patch back to grid points, with optional boundary padding,
    bilinear interpolation to the original grid and a post block.

    Args:
        image_height (int, optional): number of Lat grid points of the input (before padding).
        patch_height (int, optional): patch size along Lat.
        image_width (int, optional): number of Lon grid points of the input (before padding).
        patch_width (int, optional): patch size along Lon.
        levels (int, optional): number of levels; multiplied by ``channels``.
        frames (int, optional): size of the time dimension of the input.
        frame_patch_size (int, optional): patch size along time.
        dim (int, optional): number of embed channels.
        num_groups (Sequence[int] | int, optional): number of groups to separate the channels into.
        channels (int, optional): number of variables per level.
        surface_channels (int, optional): number of surface variables.
        input_only_channels (int, optional): extra channels counted for the input only.
        output_only_channels (int, optional): extra channels counted for the output only.
        num_heads (int, optional): Number of attention heads.
        depth (int, optional): number of blocks, passed to UTransformer.
        window_size (int | tuple[int], optional): Local window size.
        use_spectral_norm (bool, optional): if True, call ``apply_spectral_norm``
            on the model (which is first moved to CUDA when available).
        interp (bool, optional): if True, bilinearly interpolate the output to
            (image_height, image_width).
        drop_path (optional): passed to UTransformer.
        padding_conf (dict, optional): must contain "activate"; when active also
            "pad_lat" and "pad_lon" (two entries each). The whole dict is
            forwarded to ``TensorPadding``. ``None`` disables padding.
        post_conf (dict, optional): must contain "activate"; forwarded to
            ``PostBlock`` when active. ``None`` disables the post block.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        image_height=640,  # 640
        patch_height=16,
        image_width=1280,  # 1280
        patch_width=16,
        levels=15,
        frames=2,
        frame_patch_size=2,
        dim=1536,
        num_groups=32,
        channels=4,
        surface_channels=7,
        input_only_channels=0,
        output_only_channels=0,
        num_heads=8,
        depth=48,
        window_size=7,
        use_spectral_norm=True,
        interp=True,
        drop_path=0,
        padding_conf=None,
        post_conf=None,
        **kwargs,
    ):
        super().__init__()

        # feature switches; a missing config dict means "feature off"
        self.use_interp = interp
        self.use_spectral_norm = use_spectral_norm
        if padding_conf is None:
            padding_conf = {"activate": False}
        self.use_padding = padding_conf["activate"]
        if post_conf is None:
            post_conf = {"activate": False}
        self.use_post_block = post_conf["activate"]

        # input tensor size (time, lat, lon)
        # with padding, the model is built for the padded grid while the
        # original size is remembered for the final interpolation
        if self.use_padding:
            pad_lat = padding_conf["pad_lat"]
            pad_lon = padding_conf["pad_lon"]
            image_height_pad = image_height + pad_lat[0] + pad_lat[1]
            image_width_pad = image_width + pad_lon[0] + pad_lon[1]
            img_size = (frames, image_height_pad, image_width_pad)
            self.img_size_original = (frames, image_height, image_width)
        else:
            img_size = (frames, image_height, image_width)
            self.img_size_original = img_size

        # the size of embedded patches
        patch_size = (frame_patch_size, patch_height, patch_width)

        # number of channels = levels * variables per level + surface variables
        # (+ input-only channels on the input side, + output-only channels on the output side)

        in_chans = channels * levels + surface_channels + input_only_channels
        out_chans = channels * levels + surface_channels + output_only_channels

        # input resolution = number of embedded patches / 2
        # divide by two because "u_transformer" has a down-sampling block
        # NOTE(review): round() on a true division can disagree with the floor
        # division used in CubeEmbedding and with the size produced by the
        # stride-2 conv in DownBlock when the patch count is odd — confirm that
        # the configs in use never hit that case.

        input_resolution = (
            round(img_size[1] / patch_size[1] / 2),
            round(img_size[2] / patch_size[2] / 2),
        )
        # FuXi cube embedding layer
        self.cube_embedding = CubeEmbedding(img_size, patch_size, in_chans, dim)

        # Downsampling --> SwinTransformerV2 stacks --> Upsampling
        self.u_transformer = UTransformer(
            dim, num_groups, 
            input_resolution, 
            num_heads, 
            window_size, 
            depth=depth,
            drop_path=drop_path
        )

        # dense layer applied on the channel dimension
        # out_chans * patch area because the dense layer expands every
        # embedded patch back to its patch_lat x patch_lon grid points
        self.fc = nn.Linear(dim, out_chans * patch_size[1] * patch_size[2])

        # Hyperparameters
        self.patch_size = patch_size
        self.input_resolution = input_resolution
        self.out_chans = out_chans
        self.img_size = img_size

        self.channels = channels
        self.surface_channels = surface_channels
        self.levels = levels

        if self.use_padding:
            self.padding_opt = TensorPadding(**padding_conf)

        # spectral norm is applied here, before the post block is created,
        # so the post block's layers are not affected by it
        if self.use_spectral_norm:
            logger.info("Adding spectral norm to all conv and linear layers")
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # Move the model to the device
            self.to(device)
            apply_spectral_norm(self)

        if self.use_post_block:
            self.postblock = PostBlock(post_conf)

    def forward(self, x: torch.Tensor):
        """Run one forward pass.

        Args:
            x: 5-D tensor laid out as (Batch, Variables, Time, Lat, Lon).

        Returns:
            The prediction with a singleton time axis inserted at dim 2, or,
            when the post block is active, whatever ``self.postblock`` returns
            for ``{"y_pred": prediction, "x": detached copy of the input}``.
        """
        # copy tensor to feed into postblock later
        x_copy = None
        if self.use_post_block:
            x_copy = x.clone().detach()

        if self.use_padding:
            x = self.padding_opt.pad(x)

        # Tensor dims: Batch, Variables, Time, Lat grids, Lon grids
        B, _, _, _, _ = x.shape

        _, patch_lat, patch_lon = self.patch_size

        # Get the number of patches after embedding
        # (input_resolution holds half of it, see __init__)
        Lat, Lon = self.input_resolution
        Lat, Lon = Lat * 2, Lon * 2

        # Cube Embedding and squeeze the time dimension
        # (the model produces a single forecast lead time only)

        # x: input size = (Batch, Variables, Time, Lat grids, Lon grids)
        x = self.cube_embedding(x).squeeze(2)  # B C Lat Lon
        # x: output size = (Batch, Embedded dimension, number of patches, number of patches)
        # assumes the embedded time axis has size 1 so squeeze(2) removes it — TODO confirm

        # u_transformer stage
        # the size of x is expected to stay unchanged
        x = self.u_transformer(x)

        # recover embeddings to lat/lon grids with dense layer and reshape operation.
        x = self.fc(x.permute(0, 2, 3, 1))  # B Lat Lon C
        x = x.reshape(B, Lat, Lon, patch_lat, patch_lon, self.out_chans).permute(
            0, 1, 3, 2, 4, 5
        )
        # B, lat, patch_lat, lon, patch_lon, C
        x = x.reshape(B, Lat * patch_lat, Lon * patch_lon, self.out_chans)
        x = x.permute(0, 3, 1, 2)  # B C Lat Lon

        if self.use_padding:
            x = self.padding_opt.unpad(x)

        # resize to the original (un-padded) Lat/Lon grid
        if self.use_interp:
            img_size = list(self.img_size_original)
            x = F.interpolate(x, size=img_size[1:], mode="bilinear")

        # re-insert a singleton time axis: B C 1 Lat Lon
        x = x.unsqueeze(2)

        if self.use_post_block:
            x = {
                "y_pred": x,
                "x": x_copy,
            }
            x = self.postblock(x)

        return x

## FuXi param check

In [6]:
# old rollout config
#config_name = '/glade/u/home/ksha/miles-credit/config/example_physics_single.yml'

# Config file to inspect
config_name = '/glade/work/ksha/CREDIT_runs/fuxi_dry_tune/model_single.yml'

# Parse the YAML file into `conf`
with open(config_name) as config_file:
    conf = yaml.safe_load(config_file)

In [7]:
# Pass the raw config through the CREDIT parser; its result replaces `conf`
conf = credit_main_parser(conf)

In [8]:
# conf['model']['post_conf']['activate'] = False
# conf['model']['interp'] = False
# conf['model']['padding_conf']['pad_lat'] = [21, 22]
# conf['model']['padding_conf']['pad_lon'] = [44, 44]

In [1]:
# Tensor-size settings from the parsed config
image_height = conf['model']['image_height']
image_width = conf['model']['image_width']
levels = conf['model']['levels']
frames = conf['model']['frames']
channels = conf['model']['channels']
surface_channels = conf['model']['surface_channels']
input_only_channels = conf['model']['input_only_channels']
output_only_channels = conf['model']['output_only_channels']

# Use the GPU when one is present and fall back to CPU otherwise
# (same device selection as in test_fuxi), instead of hard-coding "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ============================================================= #
# build the model
model = Fuxi(**conf['model']).to(device)

# ============================================================= #
# test the model

# pass an input tensor to test the graph
# layout: (batch, variables, time, lat, lon)
in_channels = channels * levels + surface_channels + input_only_channels
input_tensor = torch.randn(
    1, in_channels, frames, image_height, image_width
).to(device)

# the tensor already lives on `device`, so no second transfer is needed
y_pred = model(input_tensor)

print(f'Input shape: {input_tensor.shape}')
print(f"Predicted shape: {y_pred.shape}")

## FuXi I/O size and padding

**Auto-detect pad size**

In [82]:
import math

In [118]:
def compute_padding_sizes(image_height, image_width, window_size, patch_height, patch_width, base_val=40, N_lat_add=0, N_lon_add=0):
    """
    Computes the padding sizes pad_lat and pad_lon for given image dimensions,
    window size, and patch sizes, such that after padding the input resolution
    used in the model (padded size / patch size / 2) is a multiple of window_size.

    The resolution is rounded up to the next multiple of window_size, and
    N_lat_add / N_lon_add further windows are added on top. The resulting total
    padding per axis is split over the two sides by ``distribute_padding``,
    which caps the second side at base_val.

    Args:
        image_height (int): Original image height.
        image_width (int): Original image width.
        window_size (int): Window size used in the model.
        patch_height (int): Patch height.
        patch_width (int): Patch width.
        base_val (int): Cap passed to ``distribute_padding`` for each axis.
        N_lat_add (int): Additional window sizes to add to required input resolution for latitude.
        N_lon_add (int): Additional window sizes to add to required input resolution for longitude.

    Returns:
        tuple: pad_lat (list[int]), pad_lon (list[int])
               Padding sizes for latitude and longitude ([top, bottom], [left, right]).
               ``(None, None)`` is returned when either total padding would be
               negative (possible with negative N_lat_add / N_lon_add).
    """

    def _total_padding(image_size, patch_size, n_extra_windows):
        # Resolution seen by the model: patches along the axis, halved by the down block
        resolution = image_size / patch_size / 2
        # Smallest window count covering that resolution, plus the requested extras
        n_windows = math.ceil(resolution / window_size) + n_extra_windows
        # Image size that yields exactly n_windows * window_size
        padded_size = n_windows * window_size * patch_size * 2
        return int(padded_size - image_size)

    pad_lat_total = _total_padding(image_height, patch_height, N_lat_add)
    pad_lon_total = _total_padding(image_width, patch_width, N_lon_add)

    # A negative total means the requested grid is smaller than the image
    if pad_lat_total < 0 or pad_lon_total < 0:
        return None, None

    # Split each total over the two sides of its axis
    pad_lat = distribute_padding(pad_lat_total, base_val)
    pad_lon = distribute_padding(pad_lon_total, base_val)

    return pad_lat, pad_lon

def distribute_padding(total_padding, base_val):
    """
    Split a total padding amount over the two sides of one axis.

    The total is halved, with the second side receiving the extra element when
    the total is odd. If the second side would exceed ``base_val`` it is capped
    at ``base_val`` and the first side takes the rest, so the two entries
    always add up to ``total_padding``.

    Args:
        total_padding (int): Total padding to distribute.
        base_val (int): Upper bound applied to the second side.

    Returns:
        list[int]: [pad_first, pad_second].
    """
    if total_padding == 0:
        return [0, 0]

    # Even split; the second side is never smaller than the first
    pad_first = total_padding // 2
    pad_second = total_padding - pad_first

    # Because pad_second >= pad_first, checking pad_second alone covers
    # every case in which a side exceeds base_val.
    if pad_second > base_val:
        pad_second = base_val
        pad_first = total_padding - pad_second

    return [pad_first, pad_second]

In [122]:
# Example: 721 x 1440 grid, window 7, 4 x 4 patches, two extra windows per axis
compute_padding_sizes(721, 1440, 7, 4, 4, base_val=200, N_lat_add=2, N_lon_add=2)

([59, 60], [64, 64])