In [None]:
import h5py
import numpy as np

In [None]:
# Open the train / validation / test HDF5 splits read-only. The handles are not
# closed here: `trainset` is read again in a later cell.
trainset = h5py.File("../data/train.h5", "r")
validateset = h5py.File("../data/val.h5", "r")
testset = h5py.File("../data/test.h5", "r")

# Load each split's "images" dataset fully into memory, divided by 10000.
train_images, validate_images, test_images = (
    split["images"][:] / 10000
    for split in (trainset, validateset, testset)
)

In [None]:
# List every dataset in the training file together with its shape.
# `Dataset.shape` comes from HDF5 metadata, so unlike `trainset[feature][:].shape`
# this does not read each full array into memory just to report its size.
for feature, dataset in trainset.items():
    print(f"{feature}:\t{dataset.shape}")

agbd:	(25036,)
cloud:	(25036, 15, 15, 1)
images:	(25036, 15, 15, 12)
lat:	(25036, 15, 15, 1)
lon:	(25036, 15, 15, 1)
scl:	(25036, 15, 15, 1)


## Image Regression - Transformers

In [None]:
import torch
import torch.nn as nn

<torch._C.Generator at 0x7f6350e4dfd0>

In [None]:
class ViT(nn.Module):
    """Vision Transformer encoder over per-image patch tokens.

    The constructor pre-processes the whole image array once. It scales the
    values, cuts every image into non-overlapping patches, adds a sinusoidal
    position embedding and reserves slot 0 of each sequence for a class token.
    The resulting tensor is stored in ``self.input``, with shape
    (n_images, 1 + n_patches, patch_len). ``forward`` expects tensors laid out
    like that (e.g. ``self.input`` or a slice of it along axis 0).

    Args:
        input: array of shape (n_images, h_image, w_image, n_channels).
        patch_shape: (h_patch, w_patch); must tile the image exactly.
        n_layers: number of times the (single, weight-shared) encoder block
            is applied in ``forward``.
        dropout: dropout probability inside the MLP (0.5 matches the
            ``nn.Dropout()`` default used previously).
        scale: divisor applied to ``input`` before patching.
            NOTE(review): the data-loading cell already divides the images by
            10000; pass ``scale=1`` when feeding those arrays, otherwise they
            are scaled twice.
    """

    def __init__(self, input, patch_shape=(5, 5), n_layers=10, dropout=0.5, scale=10000):
        super().__init__()

        _, _, _, n_channels = input.shape
        images = torch.from_numpy(input / scale)

        # (n_images, n_patches, patch_len) with patch_len = n_channels * h_patch * w_patch
        images = self.patched(images, patch_shape)
        n_images, _, patch_len = images.shape
        dtype = images.dtype
        self.n_layers = n_layers

        # One (n_patches, patch_len) embedding, broadcast over all images.
        images = images + self.pos_emb(images.shape[1:]).to(dtype)

        # Registered on the module so it shows up in .parameters() and is trained.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, patch_len, dtype=dtype))
        # Slot 0 is only a placeholder here (detached, no autograd graph kept);
        # forward() injects the *current* value of cls_token into that slot.
        placeholder = self.cls_token.detach().expand(n_images, -1, -1)
        self.input = torch.cat([placeholder, images], dim=1)

        # patch_len is a multiple of n_channels, so n_channels heads always divide it.
        n_heads = n_channels
        # Pre-norm layers normalise each token over its feature dimension only.
        self.layer_norm = nn.LayerNorm(patch_len, dtype=dtype)
        self.msa = nn.MultiheadAttention(patch_len, n_heads, batch_first=True, dtype=dtype)
        self.mlp_norm = nn.LayerNorm(patch_len, dtype=dtype)
        self.mlp = nn.Sequential(
            nn.Linear(patch_len, patch_len, dtype=dtype),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(patch_len, patch_len, dtype=dtype),
            nn.Dropout(dropout),
        )

    def forward(self, msa_input):
        """Run the transformer encoder and return one class-token embedding per image.

        Args:
            msa_input: tokens shaped (batch, 1 + n_patches, patch_len), laid out
                like ``self.input``. Slot 0 is replaced by the learned class token.

        Returns:
            Tensor of shape (batch, patch_len): the class-token output per image.
            NOTE(review): no regression head is applied here; TODO add e.g.
            ``nn.Linear(patch_len, 1)`` if this is meant to predict a scalar.
        """
        cls = self.cls_token.expand(msa_input.shape[0], -1, -1)
        x = torch.cat([cls, msa_input[:, 1:]], dim=1)

        # The same encoder block (shared weights) is applied n_layers times.
        for _ in range(self.n_layers):
            normed = self.layer_norm(x)
            # MultiheadAttention returns (output, weights); only the output is needed.
            attn_output, _ = self.msa(normed, normed, normed, need_weights=False)
            x = x + attn_output
            x = x + self.mlp(self.mlp_norm(x))

        return x[:, 0]

    def patched(self, images, patch_shape):
        """Split images into flattened, non-overlapping patches.

        Args:
            images: tensor (n_images, h_image, w_image, n_channels).
            patch_shape: (h_patch, w_patch).

        Returns:
            Tensor (n_images, n_patches, n_channels * h_patch * w_patch). Patches
            are ordered row-major over the image; each patch is flattened
            channel-first, then patch row, then patch column.

        Raises:
            ValueError: if the patch does not tile the image exactly.
        """
        n_images, h_image, w_image, n_channels = images.shape
        h_patch, w_patch = patch_shape
        if h_image % h_patch or w_image % w_patch:
            raise ValueError(
                f"patch shape {tuple(patch_shape)} does not tile image of size {(h_image, w_image)}"
            )

        return (
            # Split height and width into (n_blocks, block_size) segments.
            images.reshape(n_images, h_image // h_patch, h_patch, w_image // w_patch, w_patch, n_channels)
            # -> (n_images, h_blocks, w_blocks, n_channels, h_patch, w_patch)
            .permute(0, 1, 3, 5, 2, 4)
            # Merge the block grid into one patch axis and flatten each patch.
            .reshape(n_images, -1, n_channels * h_patch * w_patch)
        )

    def pos_emb(self, patch_shape):
        """Sinusoidal position embedding of shape ``patch_shape`` = (n_positions, d_patch).

        Column 2k holds sin(pos / 10000**(2k / d_patch)) and column 2k+1 holds
        cos(pos / 10000**(2k / d_patch)), i.e. the formula from "Attention Is
        All You Need". Returned in torch's default dtype.
        """
        n_positions, d_patch = patch_shape
        positions = torch.arange(n_positions, dtype=torch.float64).unsqueeze(1)
        columns = torch.arange(d_patch, dtype=torch.float64)
        # Each (even, odd) column pair shares the same frequency.
        angles = positions / 10000 ** ((columns - columns % 2) / d_patch)
        pe = torch.where(columns % 2 == 0, torch.sin(angles), torch.cos(angles))
        return pe.to(torch.get_default_dtype())