In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:

class PatchEmbedding2D(nn.Module):
    """Cut an image into non-overlapping patches and embed each patch.

    A convolution whose kernel size equals its stride (``patch_size``) maps
    every patch to an ``embed_dim`` vector; the resulting patch grid is then
    flattened into a token sequence.

    Args:
        input_dim: number of input channels.
        embed_dim: size of each patch embedding.
        patch_size: patch size, used as both kernel size and stride.
    """

    def __init__(self, input_dim, embed_dim, patch_size):
        super().__init__()
        conv_kwargs = dict(kernel_size=patch_size, stride=patch_size)
        self.projection = nn.Conv2d(input_dim, embed_dim, **conv_kwargs)
        # Merges the two spatial axes (H', W') into one token axis
        self.flatten = nn.Flatten(start_dim=2)

    def forward(self, x):
        """Map (B, input_dim, H, W) images to (B, H'*W', embed_dim) tokens."""
        patch_grid = self.projection(x)        # (B, embed_dim, H', W')
        tokens = self.flatten(patch_grid)      # (B, embed_dim, H'*W')
        return tokens.transpose(1, 2)          # (B, H'*W', embed_dim)

In [None]:
class PatchRecovery2D(nn.Module):
    """Counterpart of PatchEmbedding2D: turn a token sequence back into an image.

    A transposed convolution whose kernel size equals its stride
    (``patch_size``) expands every token back into a patch of pixels.

    Args:
        output_dim: number of output channels.
        embed_dim: size of each token embedding.
        patch_size: patch size, used as both kernel size and stride.
    """

    def __init__(self, output_dim, embed_dim, patch_size):
        super().__init__()
        self.reconstruction = nn.ConvTranspose2d(
            embed_dim, output_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x, spatial_shape):
        """Map (B, H'*W', embed_dim) tokens to a (B, output_dim, ., .) image.

        ``spatial_shape`` is the (H', W') size of the patch grid; the number
        of tokens must equal H' * W'.
        """
        batch, _, channels = x.shape
        grid_h, grid_w = spatial_shape
        # Tokens -> channel-first patch grid, as ConvTranspose2d expects
        grid = x.transpose(1, 2).reshape(batch, channels, grid_h, grid_w)
        return self.reconstruction(grid)

In [None]:
class EarthSpecificBlock(nn.Module):
    """3D window-attention transformer block (Pangu-Weather style).

    The token sequence is laid out on a (Z, H, W) grid, zero-padded up to a
    multiple of the window size, optionally rolled by half a window (shifted
    windows), split into non-overlapping 3D windows, passed through
    ``EarthAttention3D``, and reassembled. A residual connection and an MLP
    follow.

    NOTE(review): ``DropPath``, ``LayerNorm``, ``MLP``, ``EarthAttention3D``,
    ``gen_mask`` and ``no_mask`` are not defined in this part of the notebook;
    they must be provided elsewhere before ``__init__``/``forward`` can run.

    Args:
        dim: channel dimension C of the tokens.
        drop_path_ratio: drop-path rate passed to ``DropPath``.
        heads: number of attention heads passed to ``EarthAttention3D``.
        window_size: (Z, H, W) window size; defaults to (2, 6, 12).
    """

    def __init__(self, dim, drop_path_ratio, heads, window_size=(2, 6, 12)):
        super().__init__()
        # Window size (Z, H, W) used to partition the padded 3D grid
        self.window_size = tuple(window_size)

        # Initialize several operations
        self.drop_path = DropPath(drop_rate=drop_path_ratio)
        self.norm1 = LayerNorm(dim)
        self.norm2 = LayerNorm(dim)
        self.linear = MLP(dim, 0)
        self.attention = EarthAttention3D(dim, heads, 0, self.window_size)

    @staticmethod
    def _pad_to_windows(x, window_size):
        """Zero-pad dims 1-3 of a (B, Z, H, W, C) tensor, at their end, to multiples of window_size."""
        _, Z, H, W, _ = x.shape
        pad_z = (-Z) % window_size[0]
        pad_h = (-H) % window_size[1]
        pad_w = (-W) % window_size[2]
        # F.pad takes (before, after) pairs starting from the LAST dim (C, left unpadded)
        return F.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, pad_z))

    @staticmethod
    def _window_partition(x, window_size):
        """Split a (B, Z, H, W, C) grid into (num_windows * B, wz*wh*ww, C) windows."""
        B, Z, H, W, C = x.shape
        wz, wh, ww = window_size
        x = x.reshape(B, Z // wz, wz, H // wh, wh, W // ww, ww, C)
        # Bring the window-index axes forward, the in-window axes after them
        x = x.permute(0, 1, 3, 5, 2, 4, 6, 7)
        return x.reshape(-1, wz * wh * ww, C)

    @staticmethod
    def _window_reverse(windows, window_size, grid_shape):
        """Inverse of _window_partition: reassemble windows into a tensor of grid_shape."""
        _, Z, H, W, _ = grid_shape
        wz, wh, ww = window_size
        x = windows.reshape(-1, Z // wz, H // wh, W // ww, wz, wh, ww, windows.shape[-1])
        x = x.permute(0, 1, 4, 2, 5, 3, 6, 7)
        return x.reshape(grid_shape)

    def forward(self, x, Z, H, W, roll):
        """Apply (shifted-)window attention followed by the MLP.

        Args:
            x: tokens of shape (B, Z*H*W, C).
            Z, H, W: sizes of the 3D grid the tokens are laid out on.
            roll: if True, shift the grid by half a window before attention
                and shift it back afterwards (shifted-window variant).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Save the shortcut for skip-connection
        shortcut = x
        B, _, C = x.shape

        # Lay the tokens out on a 3D grid to calculate window attention
        x = x.reshape(B, Z, H, W, C)

        # Zero-pad so that every grid dimension is divisible by the window size
        x = self._pad_to_windows(x, self.window_size)

        # Padded shape; window counts and the final restoration use it
        padded_shape = x.shape

        shifts = tuple(w // 2 for w in self.window_size)
        if roll:
            # Roll x by half a window along Z, H and W
            x = torch.roll(x, shifts=shifts, dims=(1, 2, 3))
            # Attention mask: pixels that are not adjacent must not attend to
            # each other (e.g. add -1000 to those attention logits)
            mask = gen_mask(x)
        else:
            # e.g. a zero matrix when the mask is added to the attention
            mask = no_mask

        # Stack the data into 3D cubes and apply window attention with
        # Earth-specific bias to each cube
        x_window = self._window_partition(x, self.window_size)
        x_window = self.attention(x_window, mask)

        # Reassemble the windows into the padded grid
        x = self._window_reverse(x_window, self.window_size, padded_shape)

        if roll:
            # Exact inverse of the forward roll (also correct for odd window sizes)
            x = torch.roll(x, shifts=tuple(-s for s in shifts), dims=(1, 2, 3))

        # Crop the zero-padding
        x = x[:, :Z, :H, :W, :]

        # Back to the input token layout (B, Z*H*W, C)
        x = x.reshape(B, Z * H * W, C)

        # Main calculation stages
        x = shortcut + self.drop_path(self.norm1(x))
        x = x + self.drop_path(self.norm2(self.linear(x)))
        return x