In [4]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import DropPath
import numpy as np


In [5]:
surface = '/home/rwkv/RWKV-TS/WeatherBench/input_surface.npy'
upper = '/home/rwkv/RWKV-TS/WeatherBench/input_upper.npy'

import numpy as np
surface = np.load(surface)
upper = np.load(upper)
# add time dimension
surface = np.expand_dims(surface, axis=0)
upper = np.expand_dims(upper, axis=0)
print(surface.shape, upper.shape)

(1, 4, 721, 1440) (1, 5, 13, 721, 1440)


In [6]:
land_masks = np.load('/home/rwkv/RWKV-TS/WeatherBench/constant_masks/land_mask.npy')[180:308,440:568]
soil_types = np.load('/home/rwkv/RWKV-TS/WeatherBench/constant_masks/soil_type.npy')[180:308,440:568]
topography = np.load('/home/rwkv/RWKV-TS/WeatherBench/constant_masks/topography.npy')[180:308,440:568]

print(land_masks.shape, soil_types.shape, topography.shape)

(128, 128) (128, 128) (128, 128)


# Our finall data shape will be (B,T,5,13,128,128) and (B,T,4,128,128)

In [None]:
soil_types

In [5]:

##self defined 
# def drop_path(x, drop_prob: float = 0., training: bool = False):
#     if drop_prob == 0. or not training:
#         return x
#     keep_prob = 1 - drop_prob
#     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  ##(B,1,1,1,1...)
#     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
#     # print(random_tensor)
#     random_tensor.floor_()  
#     # print(random_tensor)
#     output = x.div(keep_prob) * random_tensor  # maintain E
#     return output


# class DropPath(nn.Module):
#     def __init__(self, drop_prob=None):
#         super(DropPath, self).__init__()
#         self.drop_prob = drop_prob

#     def forward(self, x):
#         return drop_path(x, self.drop_prob, self.training)



def Inference(input, input_surface, forecast_range):
  '''Inference code, describing the algorithm of inference using models with different lead times. 
  PanguModel24, PanguModel6, PanguModel3 and PanguModel1 share the same training algorithm but differ in lead times.
  Args:
    input: input tensor, need to be normalized to N(0, 1) in practice
    input_surface: target tensor, need to be normalized to N(0, 1) in practice
    forecast_range: iteration numbers when roll out the forecast model
  '''

  # Load 4 pre-trained models with different lead times
  PanguModel24 = LoadModel(ModelPath24)
  PanguModel6 = LoadModel(ModelPath6)
  PanguModel3 = LoadModel(ModelPath3)
  PanguModel1 = LoadModel(ModelPath1)

  # Load mean and std of the weather data
  weather_mean, weather_std, weather_surface_mean, weather_surface_std = LoadStatic()

  # Store initial input for different models
  input_24, input_surface_24 = input, input_surface
  input_6, input_surface_6 = input, input_surface
  input_3, input_surface_3 = input, input_surface

  # Using a list to store output
  output_list = []

  # Note: the following code is implemented for fast inference of [1,forecast_range]-hour forecasts -- if only one lead time is requested, the inference can be much faster.
  for i in range(forecast_range):
    # switch to the 24-hour model if the forecast time is 24 hours, 48 hours, ..., 24*N hours
    if (i+1) % 24 == 0:
      # Switch the input back to the stored input
      input, input_surface = input_24, input_surface_24

      # Call the model pretrained for 24 hours forecast
      output, output_surface = PanguModel24(input, input_surface)

      # Restore from uniformed output
      output = output * weather_std + weather_mean
      output_surface = output_surface * weather_surface_std + weather_surface_mean

      # Stored the output for next round forecast
      input_24, input_surface_24 = output, output_surface
      input_6, input_surface_6 = output, output_surface
      input_3, input_surface_3 = output, output_surface

    # switch to the 6-hour model if the forecast time is 30 hours, 36 hours, ..., 24*N + 6/12/18 hours
    elif (i+1) % 6 == 0:
      # Switch the input back to the stored input
      input, input_surface = input_6, input_surface_6

      # Call the model pretrained for 6 hours forecast
      output, output_surface = PanguModel6(input, input_surface)

      # Restore from uniformed output
      output = output * weather_std + weather_mean
      output_surface = output_surface * weather_surface_std + weather_surface_mean
      
      # Stored the output for next round forecast
      input_6, input_surface_6 = output, output_surface
      input_3, input_surface_3 = output, output_surface

    # switch to the 3-hour model if the forecast time is 3 hours, 9 hours, ..., 6*N + 3 hours
    elif (i+1) % 3 ==0:
      # Switch the input back to the stored input
      input, input_surface = input_3, input_surface_3

      # Call the model pretrained for 3 hours forecast
      output, output_surface = PanguModel3(input, input_surface)

      # Restore from uniformed output
      output = output * weather_std + weather_mean
      output_surface = output_surface * weather_surface_std + weather_surface_mean
      
      # Stored the output for next round forecast
      input_3, input_surface_3 = output, output_surface

    # switch to the 1-hour model
    else:
      # Call the model pretrained for 1 hours forecast
      output, output_surface = PanguModel1(input, input_surface)

      # Restore from uniformed output
      output = output * weather_std + weather_mean
      output_surface = output_surface * weather_surface_std + weather_surface_mean

    # Stored the output for next round forecast
    input, input_surface = output, output_surface

    # Save the output
    output_list.append((output, output_surface))
  return output_list


In [7]:



def Train(path='PanguModel.pth'):
  '''Training code'''
  # Initialize the model, for some APIs some adaptation is needed to fit hardwares
  model = PanguModel()
  optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=3e-6)
  # Train single Pangu-Weather model
  epochs = 100
  for i in range(epochs):

    # dataset_length is the length of your training data, e.g., the sample between 1979 and 2017
    for step in range(upper.shape[0]):
      # Load weather data at time t as the input; load weather data at time t+1/3/6/24 as the output
      # Note the data need to be randomly shuffled
      # Note the input and target need to be normalized, see Inference() for details

      input, input_surface, target, target_surface = LoadData(step)

      # Call the model and get the output
      output, output_surface = model(input, input_surface)

      # We use the MAE loss to train the model
      # The weight of surface loss is 0.25
      # Different weight can be applied for differen fields if needed
      loss = torch.abs(output - target) + torch.abs(output_surface - target_surface) * 0.25
      optimizer.zero_grad()
      # loss = TensorAbs(output-target) + TensorAbs(output_surface-target_surface) * 0.25

      # Call the backward algorithm and calculate the gratitude of parameters
      loss.backward()
      optimizer.step()

  # Save the model at the end of the training stage
  # SaveModel()
  ModelPath = path
  torch.save(model.state_dict(), ModelPath)
  
class PanguModel(nn.Module):
  def __init__(self):
    super(PanguModel, self).__init__()
    # Drop path rate is linearly increased as the depth increases
    # drop_path_list = LinearSpace(0, 0.2, 8)
    drop_path_list = torch.linspace(0, 0.2, 8)
    # Patch embedding
    self._input_layer = PatchEmbedding((2, 4, 4), 192)

    # Four basic layers
    self.layer1 = EarthSpecificLayer(2, 192, drop_path_list[:2], 6,down=False)
    self.layer2 = EarthSpecificLayer(6, 384, drop_path_list[:6], 12, down=True)
    self.layer3 = EarthSpecificLayer(6, 384, drop_path_list[:6], 12,down=True)
    self.layer4 = EarthSpecificLayer(2, 192, drop_path_list[:2], 6, down=False)

    # Upsample and downsample
    self.upsample = UpSample(384, 192)
    self.downsample = DownSample(192)

    # Patch Recovery
    self._output_layer = PatchRecovery(384,patch_size=(2, 4, 4))
    
  def forward(self, input, input_surface):
    '''Backbone architecture'''
    # Embed the input fields into patches
    x = self._input_layer(input, input_surface)

    # Encoder, composed of two layers
    # Layer 1, shape (8, 360, 181, C), C = 192 as in the original paper
    x = self.layer1(x, 7, 32, 32) 
    print('layer1start',x.shape)
    # Store the tensor for skip-connection
    skip = x

    # Downsample from (8, 360, 181) to (8, 180, 91)
    x = self.downsample(x, 7, 32, 32)
    print('layer2start',x.shape)
    # Layer 2, shape (8, 180, 91, 2C), C = 192 as in the original paper
    x = self.layer2(x, 7, 16, 16) 

    # Decoder, composed of two layers
    # Layer 3, shape (8, 180, 91, 2C), C = 192 as in the original paper
    x = self.layer3(x, 7, 16, 16) 

    # Upsample from (8, 180, 91) to (8, 360, 181)
    x = self.upsample(x)

    # Layer 4, shape (8, 360, 181, 2C), C = 192 as in the original paper
    x = self.layer4(x, 7, 32, 32) 

    # Skip connect, in last dimension(C from 192 to 384)
    # x = Concatenate(skip, x)
    x = torch.cat((x, skip), dim=-1)

    # Recover the output fields from patches
    output, output_surface = self._output_layer(x,7,32,32)
    return output, output_surface

class PatchEmbedding(nn.Module):
  def __init__(self, patch_size, dim):
    '''Patch embedding operation'''
    super(PatchEmbedding, self).__init__()
    # Here we use convolution to partition data into cubes
    self.conv = nn.Conv3d(5, dim, kernel_size=patch_size, stride=patch_size)
    self.conv_surface = nn.Conv2d(4, dim, kernel_size=patch_size[1:], stride=patch_size[1:])

    # Load constant masks from the disc
    # self.land_mask, self.soil_type, self.topography = LoadConstantMask()
    self.land_mask, self.soil_type, self.topography = torch.tensor(land_masks), torch.tensor(soil_types), torch.tensor(topography)
    
  def forward(self, input, input_surface):
    # Zero-pad the input
    # input = Pad3D(input)
    input = nn.ZeroPad3d(padding=(0, 0, 0, 0, 0, 0))(input)
    # print(input.shape)
    # input_surface = Pad2D(input_surface)
    input_surface = nn.ZeroPad2d(padding=(0, 0))(input_surface)
    # print(input_surface.shape)
    # Apply a linear projection for patch_size[0]*patch_size[1]*patch_size[2] patches, patch_size = (2, 4, 4) as in the original paper
    input = self.conv(input)
    # print('input:',input.shape)
    # Add three constant fields to the surface fields
    # input_surface =  Concatenate(input_surface, self.land_mask, self.soil_type, self.topography)
    land_mask_expanded = self.land_mask.unsqueeze(0).unsqueeze(0)  # (1, 1, 721, 1440)
    soil_type_expanded = self.soil_type.unsqueeze(0).unsqueeze(0)  
    topography_expanded = self.topography.unsqueeze(0).unsqueeze(0)  
    # Add the expanded fields to input_surface
    input_surface = input_surface + land_mask_expanded + soil_type_expanded + topography_expanded
    # Apply a linear projection for patch_size[1]*patch_size[2] patches
    # print(input_surface.shape)
    input_surface = self.conv_surface(input_surface)
    # print('surface:',input_surface.shape)
    # Concatenate the input in the pressure level, i.e., in Z dimension
    # x = Concatenate(input, input_surface)
    
    x = torch.cat((input, input_surface.unsqueeze(2)), dim=2)
    # print(x.shape)
    #x (B, C, Z, H, W)
    # Reshape x for calculation of linear projections
    # x = TransposeDimensions(x, (0, 2, 3, 4, 1))
    x = x.permute(0, 2, 3, 4, 1) #channel first to channel last
    # print(x.shape)
    x = x.reshape(x.shape[0], 7*32*32, x.shape[-1])
    return x
    
 
class PatchRecovery(nn.Module):
  def __init__(self, dim,patch_size):
    '''Patch recovery operation'''
    super(PatchRecovery, self).__init__()
    # Hear we use two transposed convolutions to recover data
    self.conv = nn.ConvTranspose3d(dim, 5, kernel_size=patch_size, stride=patch_size)
    self.conv_surface = nn.ConvTranspose2d(dim, 4, kernel_size=patch_size[1:], stride=patch_size[1:])
    
  def forward(self, x, Z, H, W):
    # The inverse operation of the patch embedding operation, patch_size = (2, 4, 4) as in the original paper
    # Reshape x back to three dimensions
    x = x.permute(0, 2, 1)
    # x = reshape(x, target_shape=(x.shape[0], x.shape[1], Z, H, W))
    x = x.reshape(x.shape[0], x.shape[1], Z, H, W)

    # Call the transposed convolution
    output = self.conv(x[:, :, 1:, :, :])
    output_surface = self.conv_surface(x[:, :, 0, :, :])

    # Crop the output to remove zero-paddings
    # output = Crop3D(output)
    # output_surface = Crop2D(output_surface)
    return output, output_surface


class DownSample(nn.Module):
  def __init__(self, dim):
    '''Down-sampling operation'''
    super(DownSample, self).__init__()
    # A linear function and a layer normalization
    self.linear = nn.Linear(4*dim, 2*dim, bias=False)
    self.norm = nn.LayerNorm(4*dim)
  
  def forward(self, x, Z, H, W):
    # Reshape x to three dimensions for downsampling
    # x = reshape(x, target_shape=(x.shape[0], Z, H, W, x.shape[-1]))

    x = x.reshape(x.shape[0], Z, H, W, x.shape[-1])
    # Padding the input to facilitate downsampling
    # x = Pad3D(x)
    print("down x", x.shape)
    # Reorganize x to reduce the resolution: simply change the order and downsample from (8, 360, 182) to (8, 180, 91)
    # Z, H, W = x.shape[-3,-2,-1]
    # Reshape x to facilitate downsampling
    # x = reshape(x, target_shape=(x.shape[0], Z, H//2, 2, W//2, 2, x.shape[-1]))
    x = x.reshape(x.shape[0], Z, H//2, 2, W//2, 2, x.shape[-1])
    # Change the order of x
    # x = TransposeDimensions(x, (0,1,2,4,3,5,6))
    x = x.permute(0, 1, 2, 4, 3, 5, 6)
    # Reshape to get a tensor of resolution (8, 180, 91)
    # x = reshape(x, target_shape=(x.shape[0], Z*(H//2)*(W//2), 4 * x.shape[-1]))
    x = x.reshape(x.shape[0], Z*(H//2)*(W//2), 4 * x.shape[-1])
    # Call the layer normalization
    x = self.norm(x)

    # Decrease the channels of the data to reduce computation cost
    x = self.linear(x)
    return x
class UpSample(nn.Module):
  def __init__(self, input_dim, output_dim):
    '''Up-sampling operation'''
    super(UpSample, self).__init__()
    # Linear layers without bias to increase channels of the data
    self.linear1 = nn.Linear(input_dim, output_dim*4, bias=False)

    # Linear layers without bias to mix the data up
    self.linear2 = nn.Linear(output_dim, output_dim, bias=False)

    # Normalization
    self.norm = nn.LayerNorm(output_dim)
  
  def forward(self, x):
    # Call the linear functions to increase channels of the data
    x = self.linear1(x)

    # Reorganize x to increase the resolution: simply change the order and upsample from (8, 180, 91) to (8, 360, 182)
    # Reshape x to facilitate upsampling.
    # x = reshape(x, target_shape=(x.shape[0], 8, 180, 91, 2, 2, x.shape[-1]//4))
    x = x.reshape(x.shape[0], 7, 16,16, 2, 2, x.shape[-1]//4)
    # Change the order of x
    # x = TransposeDimensions(x, (0,1,2,4,3,5,6))
    x = x.permute(0, 1, 2, 4, 3, 5, 6)
    # Reshape to get Tensor with a resolution of (8, 360, 182)
    # x = reshape(x, target_shape=(x.shape[0], 8, 360, 182, x.shape[-1]))
    x = x.reshape(x.shape[0], 7, 32, 32, x.shape[-1])
    # Crop the output to the input shape of the network
    # x = Crop3D(x)

    # Reshape x back
    # x = reshape(x, target_shape=(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[-1]))
    x = torch.reshape(x, (x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[-1]))
    # Call the layer normalization
    x = self.norm(x)

    # Mixup normalized tensors
    x = self.linear2(x)
    return x
  
class EarthSpecificLayer(nn.Module):
  def __init__(self, depth, dim, drop_path_ratio_list, heads,down):
    '''Basic layer of our network, contains 2 or 6 blocks'''
    super(EarthSpecificLayer, self).__init__()
    self.depth = depth
    self.blocks = []

    # Construct basic blocks
    for i in range(depth):
      self.blocks.append(EarthSpecificBlock(dim, drop_path_ratio_list[i], heads,down))
      
  def forward(self, x, Z, H, W):
    for i in range(self.depth):
      # Roll the input every two blocks
      if i % 2 == 0:
        self.blocks[i](x, Z, H, W, roll=False)
      else:
        self.blocks[i](x, Z, H, W, roll=True)
    return x

class EarthSpecificBlock(nn.Module):
  def __init__(self, dim, drop_path_ratio, heads,down):
    '''
    3D transformer block with Earth-Specific bias and window attention, 
    see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention.
    The major difference is that we expand the dimensions to 3 and replace the relative position bias with Earth-Specific bias.
    '''
    super(EarthSpecificBlock, self).__init__()
    # Define the window size of the neural network 
    self.window_size = (1, 4, 8)

    # Initialize serveral operations
    self.drop_path = DropPath(drop_path_ratio)
    # self.norm1 = LayerNorm(dim)
    self.norm1 = nn.LayerNorm(dim)
    self.norm2 = nn.LayerNorm(dim)
    # self.norm2 = LayerNorm(dim)
    self.linear = Mlp(dim, 0)
    self.attention = EarthAttention3D(dim, heads, 0, self.window_size,down)

  def roll3D(self,x, shift):
    assert len(shift) == 3, "Shift must specify three dimensions: (Z, H, W)"
    Z_shift, H_shift, W_shift = shift

    # Roll along each dimension
    if Z_shift != 0:
        x = torch.roll(x, shifts=Z_shift, dims=1)  
    if H_shift != 0:
        x = torch.roll(x, shifts=H_shift, dims=2)  
    if W_shift != 0:
        x = torch.roll(x, shifts=W_shift, dims=3)  
    return x

  def gen_mask(self, x):

    B, Z, H, W, C = x.shape
    Z_win, H_win, W_win = self.window_size

    # Calculate the number of windows along each dimension
    num_Z = Z // Z_win
    num_H = H // H_win
    num_W = W // W_win
    num_windows = num_Z * num_H * num_W
    num_tokens = Z_win * H_win * W_win

    # Initialize 
    mask = torch.zeros((B, num_windows, num_tokens, num_tokens), device=x.device)

    # block non-adjacent positions
    for z in range(Z_win):
        for h in range(H_win):
            for w in range(W_win):
                token_id_1 = z * H_win * W_win + h * W_win + w
                for z2 in range(Z_win):
                    for h2 in range(H_win):
                        for w2 in range(W_win):
                            token_id_2 = z2 * H_win * W_win + h2 * W_win + w2
                            if abs(z - z2) > 1 or abs(h - h2) > 1 or abs(w - w2) > 1:
                                mask[:, :, token_id_1, token_id_2] = -1000  # Block non-adjacent tokens
    return mask


  def forward(self, x, Z, H, W, roll):
    # Save the shortcut for skip-connection
    shortcut = x

    # Reshape input to three dimensions to calculate window attention
    # reshape(x, target_shape=(x.shape[0], Z, H, W, x.shape[2]))
    x = x.reshape(x.shape[0], Z, H, W, x.shape[2])
    # x = pad3D(x)

    # Store the shape of the input for restoration
    ori_shape = x.shape

    if roll:
      # Roll x for half of the window for 3 dimensions
      x = self.roll3D(x, shift=[self.window_size[0]//2, self.window_size[1]//2, self.window_size[2]//2])
      # Generate mask of attention masks
      # If two pixels are not adjacent, then mask the attention between them
      # Your can set the matrix element to -1000 when it is not adjacent, then add it to the attention
      mask = self.gen_mask(x)
    else:
      # e.g., zero matrix when you add mask to attention
      mask = torch.zeros_like(x)
      
    print('x_layer',x.shape)
    # Reorganize data to calculate window attention
    # x_window = reshape(x, target_shape=(x.shape[0], Z//window_size[0], window_size[0], H // window_size[1], window_size[1], W // window_size[2], window_size[2], x.shape[-1]))
    x_window = x.reshape(x.shape[0], Z//self.window_size[0], self.window_size[0], H // self.window_size[1], self.window_size[1], W // self.window_size[2], self.window_size[2], x.shape[-1])

    # x_window = TransposeDimensions(x_window, (0, 1, 3, 5, 2, 4, 6, 7))
    x_window = x_window.permute(0, 1, 3, 5, 2, 4, 6, 7)
    # Get data stacked in 3D cubes, which will further be used to calculated attention among each cube
    # x_window = reshape(x_window, target_shape=(-1, window_size[0]* window_size[1]*window_size[2], x.shape[-1]))
    x_window = x_window.reshape(-1,self.window_size[0]*self.window_size[1]*self.window_size[2], x.shape[-1])
    # Apply 3D window attention with Earth-Specific bias
    print('window',x_window.shape)
    x_window = self.attention(x_window, mask)

    # Reorganize data to original shapes
    # x = reshape(x_window, target_shape=((-1, Z // window_size[0], H // window_size[1], W // window_size[2], window_size[0], window_size[1], window_size[2], x_window.shape[-1])))
    x = x_window.reshape(-1,Z//self.window_size[0],H//self.window_size[1],W//self.window_size[2],self.window_size[0],self.window_size[1],self.window_size[2],x_window.shape[-1])
    # x = TransposeDimensions(x, (0, 1, 4, 2, 5, 3, 6, 7))
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7)
    # Reshape the tensor back to its original shape
    # x = reshape(x_window, target_shape=ori_shape)
    x = x.reshape(ori_shape)
    if roll:
      # Roll x back for half of the window
      x = self.roll3D(x, shift=[-self.window_size[0]//2, -self.window_size[1]//2, -self.window_size[2]//2])

    # # Crop the zero-padding
    # x = Crop3D(x)

    # Reshape the tensor back to the input shape
    # x = reshape(x, target_shape=(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[4]))
    x = x.reshape(x.shape[0], x.shape[1]* x.shape[2]* x.shape[3], x.shape[4])
    # Main calculation stages
    x = shortcut + self.drop_path(self.norm1(x))
    x = x + self.drop_path(self.norm2(self.linear(x)))
    return x
    
class EarthAttention3D(nn.Module):
  def __init__(self, dim, heads, dropout_rate, window_size, down):
    super(EarthAttention3D, self).__init__()
    '''
    3D window attention with the Earth-Specific bias, 
    see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention.
    '''
    # Initialize several operations
    self.linear1 = nn.Linear(dim,3*dim, bias=True) #creat qkv
    self.linear2 = nn.Linear(dim, dim)
    self.softmax = nn.Softmax(dim=-1)
    self.dropout = nn.Dropout(dropout_rate)

    # Store several attributes
    self.head_number = heads
    self.dim = dim
    self.scale = (dim//heads)**-0.5
    self.window_size = window_size

    # input_shape is current shape of the self.forward function
    # You can run your code to record it, modify the code and rerun it
    # Record the number of different window types

    # self.type_of_windows = (input_shape[0]//window_size[0])*(input_shape[1]//window_size[1]), (126//window_size[2])*(50//window_size[1])
    self.type_of_windows = 56 if down else 224

    # For each type of window, we will construct a set of parameters according to the paper
    # self.earth_specific_bias = ConstructTensor(shape=((2 * window_size[2] - 1) * window_size[1] * window_size[1] * window_size[0] * window_size[0], self.type_of_windows, heads))
    # self.earth_specific_bias = Parameters(self.earth_specific_bias)
    self.earth_specific_bias = nn.Parameter(
            torch.randn(
                (2 * window_size[2] - 1) * window_size[1] * window_size[1] * window_size[0] * window_size[0],
                self.type_of_windows,
                heads
            )
        )
    # Initialize the tensors using Truncated normal distribution
    # TruncatedNormalInit(self.earth_specific_bias, std=0.02) 
    nn.init.trunc_normal_(self.earth_specific_bias, std=0.02)
    # Construct position index to reuse self.earth_specific_bias
    self.position_index = self._construct_index()
    
  def _construct_index(self):
    ''' This function construct the position index to reuse symmetrical parameters of the position bias'''
    coords_zi = torch.arange(self.window_size[0])
    coords_zj = -coords_zi * self.window_size[0]

    coords_hi = torch.arange(self.window_size[1])
    coords_hj = -coords_hi * self.window_size[1]

    coords_w = torch.arange(self.window_size[2])

    coords_1 = torch.stack(torch.meshgrid(coords_zi, coords_hi, coords_w, indexing='ij'), dim=-1).view(-1, 3)
    coords_2 = torch.stack(torch.meshgrid(coords_zj, coords_hj, coords_w, indexing='ij'), dim=-1).view(-1, 3)

    coords = coords_1[:, None, :] - coords_2[None, :, :]
    coords = coords.transpose(0, 1)

    coords[:, :, 2] += self.window_size[2] - 1
    coords[:, :, 1] *= 2 * self.window_size[2] - 1
    coords[:, :, 0] *= (2 * self.window_size[2] - 1) * self.window_size[1] * self.window_size[1]

    position_index = coords.sum(dim=-1).view(-1)
    return position_index

  def forward(self, x, mask):
    # Linear layer to create query, key and value
    original_shape = x.shape
    x = self.linear1(x)

    # Record the original shape of the input
    

    # reshape the data to calculate multi-head attention
    # qkv = reshape(x, target_shape=(x.shape[0], x.shape[1], 3, self.head_number, self.dim // self.head_number)) 
    qkv = x.reshape(x.shape[0], x.shape[1], 3, self.head_number, self.dim // self.head_number)
    # query, key, value = TransposeDimensions(qkv, (2, 0, 3, 1, 4))
    query, key, value = qkv.permute(2, 0, 3, 1, 4)
    # Scale the attention
    query = query * self.scale
    print('query',query.shape)
    # Calculated the attention, a learnable bias is added to fix the nonuniformity of the grid.
    # attention = query @ key.T # @ denotes matrix multiplication
    attention = torch.matmul(query, key.transpose(-2, -1))

    # self.earth_specific_bias is a set of neural network parameters to optimize. 
    EarthSpecificBias = self.earth_specific_bias[self.position_index]

    # Reshape the learnable bias to the same shape as the attention matrix
    # EarthSpecificBias = reshape(EarthSpecificBias, target_shape=(self.window_size[0]*self.window_size[1]*self.window_size[2], self.window_size[0]*self.window_size[1]*self.window_size[2], self.type_of_windows, self.head_number))
    EarthSpecificBias = EarthSpecificBias.reshape(self.window_size[0]*self.window_size[1]*self.window_size[2], self.window_size[0]*self.window_size[1]*self.window_size[2], self.type_of_windows, self.head_number)
    EarthSpecificBias = EarthSpecificBias.permute(2,3,0,1)
    # EarthSpecificBias = TransposeDimensions(EarthSpecificBias, (2, 3, 0, 1))

    # EarthSpecificBias = reshape(EarthSpecificBias, target_shape = [1]+EarthSpecificBias.shape)
    EarthSpecificBias = EarthSpecificBias.unsqueeze(0)
    # Add the Earth-Specific bias to the attention matrix
    print('attention',attention.shape,'EarthSpecificBias',EarthSpecificBias.shape)
    attention = attention + EarthSpecificBias
    print('attention',attention.shape)
    print(mask.shape)

    # Mask the attention between non-adjacent pixels, e.g., simply add -100 to the masked element.
    # attention = self.mask_attention(attention, mask)
    # attention = self.softmax(attention)
    # attention = self.dropout(attention)

    print('attention',attention.shape)
    # Calculated the tensor after spatial mixing.
    # x = attention @ value.T # @ denote matrix multiplication
    x = torch.matmul(attention, value.unsqueeze(0))
    print('x',x.shape)
    # Reshape tensor to the original shape
    # x = TransposeDimensions(x, (0, 2, 1))
    # [1, 126, 6, 50, 32]
    x = x.reshape(x.shape[1],x.shape[-2],-1)
    # x = x.permute(0,2,1)
    # x = reshape(x, target_shape = original_shape)
    print(original_shape)
    x = x.reshape(original_shape)
    # Linear layer to post-process operated tensor
    x = self.linear2(x)
    x = self.dropout(x)
    return x


  def mask_attention(self, attention, mask):
        if mask is not None:
            attention += mask.unsqueeze(1).unsqueeze(2) * -100
        return attention
  
class Mlp(nn.Module):
  def __init__(self, dim, dropout_rate):
    super(Mlp, self).__init__()
    self.linear1 = nn.Linear(dim, dim * 4)
    self.linear2 = nn.Linear(dim * 4, dim)
    self.activation = nn.GELU()
    self.drop = nn.Dropout(dropout_rate)
    
  def forward(self, x):
    x = self.linear1(x)
    x = self.activation(x)
    x = self.drop(x)
    x = self.linear2(x)
    x = self.drop(x)
    return x
  
from perlin_numpy import generate_fractal_noise_2d
def PerlinNoise():
  '''Generate random Perlin noise: we follow https://github.com/pvigier/perlin-numpy/ to calculate the perlin noise.'''
  # Define number of noise
  octaves = 3
  # Define the scaling factor of noise
  noise_scale = 0.2
  # Define the number of periods of noise along the axis
  period_number = 12
  # The size of an input slice
  H, W = 128, 128
  # Scaling factor between two octaves
  persistence = 0.5
  # perlin_noise = noise_scale*GenerateFractalNoise((H, W), (period_number, period_number), octaves, persistence)
  perlin_noise = noise_scale*generate_fractal_noise_2d((H,W), (period_number, period_number), octaves, persistence)
  return perlin_noise


In [12]:

# class PatchEmbedding(nn.Module):
#   def __init__(self, patch_size, dim):
#     '''Patch embedding operation'''
#     super(PatchEmbedding, self).__init__()
#     # Here we use convolution to partition data into cubes
#     self.conv = nn.Conv3d(5, dim, kernel_size=patch_size, stride=patch_size)
#     self.conv_surface = nn.Conv2d(4, dim, kernel_size=patch_size[1:], stride=patch_size[1:])

#     # Load constant masks from the disc
#     # self.land_mask, self.soil_type, self.topography = LoadConstantMask()
#     self.land_mask, self.soil_type, self.topography = torch.tensor(land_masks), torch.tensor(soil_types), torch.tensor(topography)
    
#   def forward(self, input, input_surface):
#     # Zero-pad the input
#     # input = Pad3D(input)
#     input = nn.ZeroPad3d(padding=(0, 0, 0, 0, 0, 0))(input)
#     print(input.shape)
#     # input_surface = Pad2D(input_surface)
#     input_surface = nn.ZeroPad2d(padding=(0, 0))(input_surface)
#     # print(input_surface.shape)
#     # Apply a linear projection for patch_size[0]*patch_size[1]*patch_size[2] patches, patch_size = (2, 4, 4) as in the original paper
#     input = self.conv(input)
#     print('input:',input.shape)
#     # Add three constant fields to the surface fields
#     # input_surface =  Concatenate(input_surface, self.land_mask, self.soil_type, self.topography)
#     land_mask_expanded = self.land_mask.unsqueeze(0).unsqueeze(0)  # (1, 1, 721, 1440)
#     soil_type_expanded = self.soil_type.unsqueeze(0).unsqueeze(0)  
#     topography_expanded = self.topography.unsqueeze(0).unsqueeze(0)  
#     # Add the expanded fields to input_surface
#     input_surface = input_surface + land_mask_expanded + soil_type_expanded + topography_expanded
#     # Apply a linear projection for patch_size[1]*patch_size[2] patches
#     print(input_surface.shape)
#     input_surface = self.conv_surface(input_surface)
#     print('surface:',input_surface.shape)
#     # Concatenate the input in the pressure level, i.e., in Z dimension
#     # x = Concatenate(input, input_surface)
    
#     x = torch.cat((input_surface.unsqueeze(2),input), dim=2)
#     print(x.shape)
#     #x (B, C, Z, H, W)
#     # Reshape x for calculation of linear projections
#     # x = TransposeDimensions(x, (0, 2, 3, 4, 1))
#     x = x.permute(0, 2, 3, 4, 1) #channel first to channel last
#     print(x.shape)
#     x = x.reshape(x.shape[0], 7*32*32, x.shape[-1])
#     return x
    
  
# class PatchRecovery(nn.Module):
#   def __init__(self, dim,patch_size):
#     '''Patch recovery operation'''
#     super(PatchRecovery, self).__init__()
#     # Hear we use two transposed convolutions to recover data
#     self.conv = nn.ConvTranspose3d(dim, 5, kernel_size=patch_size, stride=patch_size)
#     self.conv_surface = nn.ConvTranspose2d(dim, 4, kernel_size=patch_size[1:], stride=patch_size[1:])
    
#   def forward(self, x, Z, H, W):
#     # The inverse operation of the patch embedding operation, patch_size = (2, 4, 4) as in the original paper
#     # Reshape x back to three dimensions
#     x = x.permute(0, 2, 1)
#     # x = reshape(x, target_shape=(x.shape[0], x.shape[1], Z, H, W))
#     x = x.reshape(x.shape[0], x.shape[1], Z, H, W)

#     # Call the transposed convolution
#     output = self.conv(x[:, :, 1:, :, :])
#     output_surface = self.conv_surface(x[:, :, 0, :, :])

#     # Crop the output to remove zero-paddings
#     # output = Crop3D(output)
#     # output_surface = Crop2D(output_surface)
#     return output, output_surface



# class DownSample(nn.Module):
#   def __init__(self, dim):
#     '''Down-sampling operation'''
#     super(DownSample, self).__init__()
#     # A linear function and a layer normalization
#     self.linear = nn.Linear(4*dim, 2*dim, bias=False)
#     self.norm = nn.LayerNorm(4*dim)
  
#   def forward(self, x, Z, H, W):
#     # Reshape x to three dimensions for downsampling
#     # x = reshape(x, target_shape=(x.shape[0], Z, H, W, x.shape[-1]))

#     x = x.reshape(x.shape[0], Z, H, W, x.shape[-1])
#     # Padding the input to facilitate downsampling
#     # x = Pad3D(x)
#     print("down x", x.shape)
#     # Reorganize x to reduce the resolution: simply change the order and downsample from (8, 360, 182) to (8, 180, 91)
#     # Z, H, W = x.shape[-3,-2,-1]
#     # Reshape x to facilitate downsampling
#     # x = reshape(x, target_shape=(x.shape[0], Z, H//2, 2, W//2, 2, x.shape[-1]))
#     x = x.reshape(x.shape[0], Z, H//2, 2, W//2, 2, x.shape[-1])
#     # Change the order of x
#     # x = TransposeDimensions(x, (0,1,2,4,3,5,6))
#     x = x.permute(0, 1, 2, 4, 3, 5, 6)
#     # Reshape to get a tensor of resolution (8, 180, 91)
#     # x = reshape(x, target_shape=(x.shape[0], Z*(H//2)*(W//2), 4 * x.shape[-1]))
#     x = x.reshape(x.shape[0], Z*(H//2)*(W//2), 4 * x.shape[-1])
#     # Call the layer normalization
#     x = self.norm(x)

#     # Decrease the channels of the data to reduce computation cost
#     x = self.linear(x)
#     return x

# class UpSample(nn.Module):
#   def __init__(self, input_dim, output_dim):
#     '''Up-sampling operation'''
#     super(UpSample, self).__init__()
#     # Linear layers without bias to increase channels of the data
#     self.linear1 = nn.Linear(input_dim, output_dim*4, bias=False)

#     # Linear layers without bias to mix the data up
#     self.linear2 = nn.Linear(output_dim, output_dim, bias=False)

#     # Normalization
#     self.norm = nn.LayerNorm(output_dim)
  
#   def forward(self, x):
#     # Call the linear functions to increase channels of the data
#     x = self.linear1(x)

#     # Reorganize x to increase the resolution: simply change the order and upsample from (8, 180, 91) to (8, 360, 182)
#     # Reshape x to facilitate upsampling.
#     # x = reshape(x, target_shape=(x.shape[0], 8, 180, 91, 2, 2, x.shape[-1]//4))
#     x = x.reshape(x.shape[0], 7, 15,15, 2, 2, x.shape[-1]//4)
#     # Change the order of x
#     # x = TransposeDimensions(x, (0,1,2,4,3,5,6))
#     x = x.permute(0, 1, 2, 4, 3, 5, 6)
#     # Reshape to get Tensor with a resolution of (8, 360, 182)
#     # x = reshape(x, target_shape=(x.shape[0], 8, 360, 182, x.shape[-1]))
#     x = x.reshape(x.shape[0], 7, 30, 30, x.shape[-1])
#     # Crop the output to the input shape of the network
#     # x = Crop3D(x)

#     # Reshape x back
#     # x = reshape(x, target_shape=(x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[-1]))
#     x = torch.reshape(x, (x.shape[0], x.shape[1]*x.shape[2]*x.shape[3], x.shape[-1]))
#     # Call the layer normalization
#     x = self.norm(x)

#     # Mixup normalized tensors
#     x = self.linear2(x)
#     return x

In [None]:
# embedding = PatchEmbedding((2, 4, 4), 192)
# recovery = PatchRecovery(384,(2, 4, 4))
# down = DownSample(192)
# up = UpSample(384,192)
# drop_path_list = torch.linspace(0, 0.2, 8)

# layer1 = EarthSpecificLayer(2, 192, drop_path_list[:2], 6)
# upper = torch.rand(1,5,12,128,128,dtype=torch.float32)
# surface = torch.rand(1,4,128,128,dtype=torch.float32)

# x= embedding(upper,surface)
# print("x:",x.shape)
# skip = x
# x = layer1(x, 7, 32, 32) 
# x = down(x,7,32,32)
# print("down x:",x.shape)
# x = up(x)
# print("up x:",x.shape)
# x = torch.cat([x,skip],dim=-1)
# x = recovery(x,7,30,30)
# print(x[0].shape,x[1].shape)




TypeError: EarthSpecificLayer.__init__() missing 1 required positional argument: 'down'

In [8]:
upper = torch.rand(1,5,12,128,128,dtype=torch.float32)
surface = torch.rand(1,4,128,128,dtype=torch.float32)

model = PanguModel()
print(model(upper,surface)[0].shape,model(upper,surface)[1].shape)

x_layer torch.Size([1, 7, 32, 32, 192])
window torch.Size([224, 32, 192])
query torch.Size([224, 6, 32, 32])
attention torch.Size([224, 6, 32, 32]) EarthSpecificBias torch.Size([1, 224, 6, 32, 32])
attention torch.Size([1, 224, 6, 32, 32])
torch.Size([1, 7, 32, 32, 192])
attention torch.Size([1, 224, 6, 32, 32])
x torch.Size([1, 224, 6, 32, 32])
torch.Size([224, 32, 192])
x_layer torch.Size([1, 7, 32, 32, 192])
window torch.Size([224, 32, 192])
query torch.Size([224, 6, 32, 32])
attention torch.Size([224, 6, 32, 32]) EarthSpecificBias torch.Size([1, 224, 6, 32, 32])
attention torch.Size([1, 224, 6, 32, 32])
torch.Size([1, 224, 32, 32])
attention torch.Size([1, 224, 6, 32, 32])
x torch.Size([1, 224, 6, 32, 32])
torch.Size([224, 32, 192])
layer1start torch.Size([1, 7168, 192])
down x torch.Size([1, 7, 32, 32, 192])
layer2start torch.Size([1, 1792, 384])
x_layer torch.Size([1, 7, 16, 16, 384])
window torch.Size([56, 32, 384])
query torch.Size([56, 12, 32, 32])
attention torch.Size([56, 12