In [None]:
import numpy as np
import torch
from torch import nn

from cnp.data import LambdaIterator

In [None]:
data_root = '/scratches/cblgpu07/em626/kernelcnp/kernelcnp/experiments/environmental/data'


def _load_float_tensor(path):
    """Read the ``.npy`` file at **path** and return it as a float32 torch tensor."""
    return torch.tensor(np.load(path)).float()


# Arrays shared by the training and validation splits
lonlat_fine = _load_float_tensor(f'{data_root}/x_context_fine.npy')
lonlat_coarse = _load_float_tensor(f'{data_root}/x_context_coarse.npy')
elevation_fine = _load_float_tensor(f'{data_root}/y_context_fine.npy')

# Training split
train_reanalysis_coarse = _load_float_tensor(f'{data_root}/y_context_coarse_train.npy')
train_lonlat_station = _load_float_tensor(f'{data_root}/x_target_train.npy')
train_temperature_station = _load_float_tensor(f'{data_root}/y_target_train.npy')

# Validation split
valid_reanalysis_coarse = _load_float_tensor(f'{data_root}/y_context_coarse_val.npy')
valid_lonlat_station = _load_float_tensor(f'{data_root}/x_target_val.npy')
valid_temperature_station = _load_float_tensor(f'{data_root}/y_target_val.npy')

In [None]:
# Tensors to report, grouped as: shared grids, training split, validation split
shape_groups = (
    (('lonlat_fine', lonlat_fine),
     ('lonlat_coarse', lonlat_coarse),
     ('elevation_fine', elevation_fine)),
    (('train_reanalysis_coarse', train_reanalysis_coarse),
     ('train_lonlat_station', train_lonlat_station),
     ('train_temperature_station', train_temperature_station)),
    (('valid_reanalysis_coarse', valid_reanalysis_coarse),
     ('valid_lonlat_station', valid_lonlat_station),
     ('valid_temperature_station', valid_temperature_station)),
)

for group_number, group in enumerate(shape_groups, start=1):
    for entry_number, (label, value) in enumerate(group, start=1):

        # The last entry of every group except the final one is followed by
        # an extra newline argument, which separates the groups visually
        ends_non_final_group = (entry_number == len(group) and
                                group_number < len(shape_groups))
        spacer = ('\n',) if ends_non_final_group else ()

        print(label.ljust(30), value.shape, *spacer)

In [None]:
class Dataloader:
    """
    Generates random batches from a collection of datasets that share the
    same grids and station locations.

    Arguments:

        lonlat_fine          : tensor passed through unchanged to each batch
        lonlat_coarse        : tensor passed through unchanged to each batch
        elevation_fine       : tensor passed through unchanged to each batch
        reanalysis_coarse    : tensor indexed by dataset along dimension 0
        lonlat_station       : tensor indexed by station along dimension 0
        temperature_station  : tensor indexed by dataset along dimension 0
                               and by station along dimension 1
                               (assumed 2D, i.e. (datasets, stations) --
                               TODO confirm against the data files)
        iterations_per_epoch : number of batches, forwarded to LambdaIterator
        batch_size           : number of datasets drawn per batch
    """
    
    def __init__(self,
                 lonlat_fine,
                 lonlat_coarse,
                 elevation_fine,
                 reanalysis_coarse,
                 lonlat_station,
                 temperature_station,
                 iterations_per_epoch,
                 batch_size):
        
        # Set data tensors
        self.lonlat_fine = lonlat_fine
        self.lonlat_coarse = lonlat_coarse
        self.elevation_fine = elevation_fine
        self.reanalysis_coarse = reanalysis_coarse
        self.lonlat_station = lonlat_station
        self.temperature_station = temperature_station
        
        # Set dataloader parameters
        self.iterations_per_epoch = iterations_per_epoch
        self.num_datasets = self.reanalysis_coarse.shape[0]
        self.batch_size = batch_size
        
        
    def generate_batch(self):
        """
        Draws **batch_size** distinct datasets at random and returns them
        as a dictionary, keeping only the stations which have no nan
        temperature in the drawn batch.

        If **batch_size** exceeds the number of datasets, the batch holds
        every dataset once (the permutation is simply truncated).

        Returns:

            batch : dict with keys 'lonlat_fine', 'lonlat_coarse',
                    'elevation_fine', 'reanalysis_coarse',
                    'lonlat_station' and 'temperature_station'
        """
        
        # Draw batch indices at random, without replacement. The size must
        # come from the instance: a bare `batch_size` would read a global
        # and ignore the value given to the constructor
        idx = torch.randperm(self.num_datasets)[:self.batch_size]
        
        # Select batch indices
        batch_reanalysis_coarse = self.reanalysis_coarse[idx]
        batch_temperature_station = self.temperature_station[idx]
        
        # Keep non-nan stations: summing over the batch dimension yields nan
        # for a station as soon as one of its batch entries is nan
        nan_mask = torch.isnan(torch.sum(batch_temperature_station, dim=0))
        
        batch_lonlat_station = self.lonlat_station[~nan_mask, :]
        batch_temperature_station = batch_temperature_station[:, ~nan_mask]
        
        batch = {'lonlat_fine'         : self.lonlat_fine,
                 'lonlat_coarse'       : self.lonlat_coarse,
                 'elevation_fine'      : self.elevation_fine,
                 'reanalysis_coarse'   : batch_reanalysis_coarse,
                 'lonlat_station'      : batch_lonlat_station,
                 'temperature_station' : batch_temperature_station}
        
        return batch
        
    def __iter__(self):
        """
        Returns a LambdaIterator built from the batch generator and
        **iterations_per_epoch**.
        """
        return LambdaIterator(self.generate_batch, self.iterations_per_epoch)

In [None]:
class StandardEnvUNet(nn.Module):
    """
    U-Net style CNN on 2D grids: four stride-2 convolutions (l1-l4),
    followed by four stride-2 transposed convolutions (l5-l8), each of
    whose outputs is concatenated with the matching earlier activation,
    and a final 1x1 convolution applied to the network input concatenated
    with the last upsampled activation.

    Arguments:

        in_channels     : int, number of channels of the input
        latent_channels : int, number of channels after the first layer;
                          deeper layers use 2x, 4x and 8x this number
        out_channels    : int, number of channels of the output
    """
    
    # Each of the four stride-2 convolutions halves the grid (rounding up),
    # while each transposed convolution exactly doubles it, so the skip
    # connections only line up for sizes that are multiples of 2 ** 4
    _size_multiple = 16
    
    def __init__(self,
                 in_channels,
                 latent_channels,
                 out_channels):
        
        super().__init__()
        
        self.in_channels = in_channels
        self.latent_channels = latent_channels
        self.out_channels = out_channels
        self.kernel_size = 5
        self.padding = 2
        
        def down(num_in, num_out):
            # Convolution which halves the spatial size
            return nn.Conv2d(in_channels=num_in,
                             out_channels=num_out,
                             kernel_size=self.kernel_size,
                             padding=self.padding,
                             stride=2)
        
        def up(num_in, num_out):
            # Transposed convolution which doubles the spatial size
            return nn.ConvTranspose2d(in_channels=num_in,
                                      out_channels=num_out,
                                      kernel_size=self.kernel_size,
                                      stride=2,
                                      padding=self.padding,
                                      output_padding=1)
        
        C = self.latent_channels
        
        # Layers are created in the order l1, ..., l8, last_multiplier so
        # that parameter initialisation under a fixed seed is unchanged
        self.l1 = down(self.in_channels, C)
        self.l2 = down(C, 2 * C)
        self.l3 = down(2 * C, 4 * C)
        self.l4 = down(4 * C, 8 * C)
        
        # Input channel counts below account for the skip concatenations
        self.l5 = up(8 * C, 4 * C)
        self.l6 = up(8 * C, 2 * C)   # input: h3 (4C) + l5 output (4C)
        self.l7 = up(4 * C, 2 * C)   # input: h2 (2C) + l6 output (2C)
        self.l8 = up(3 * C, C)       # input: h1 (C)  + l7 output (2C)

        self.last_multiplier = nn.Conv2d(in_channels=self.in_channels + C,
                                         out_channels=self.out_channels,
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)
        
        self.activation = nn.ReLU()
        
        
    def forward(self, tensor):
        """
        Arguments:

            tensor : torch.tensor, shape (B, in_channels, H, W)

        Returns:

            output : torch.tensor, shape (B, out_channels, H, W)

        H and W may be arbitrary: sizes which are not multiples of 16 are
        zero-padded at the bottom/right before the network and the padding
        is cropped away again before the final 1x1 convolution. Inputs
        whose sizes are already multiples of 16 are processed unchanged.
        """
        
        height, width = tensor.shape[-2:]
        
        # Amount of padding needed to reach the next multiple of 16
        pad_h = (-height) % self._size_multiple
        pad_w = (-width) % self._size_multiple
        
        padded = tensor
        
        if pad_h or pad_w:
            # Pad order is (left, right, top, bottom) for the last two dims
            padded = nn.functional.pad(tensor, (0, pad_w, 0, pad_h))
        
        h1 = self.activation(self.l1(padded))
        h2 = self.activation(self.l2(h1))
        h3 = self.activation(self.l3(h2))
        h4 = self.activation(self.l4(h3))
        
        h5 = self.activation(self.l5(h4))
        h5 = torch.cat([h3, h5], dim=1)
        
        h6 = self.activation(self.l6(h5))
        h6 = torch.cat([h2, h6], dim=1)
        
        h7 = self.activation(self.l7(h6))
        h7 = torch.cat([h1, h7], dim=1)
        
        h8 = self.activation(self.l8(h7))
        
        # Crop back to the input size (no-op when nothing was padded) so
        # that the activation can be concatenated with the original input
        h8 = h8[..., :height, :width]
        h8 = torch.cat([tensor, h8], dim=1)
        
        output = self.last_multiplier(h8)
        
        return output

In [None]:
class StandardEnvUpscaleEncoder(nn.Module):
    """
    Interpolates coarse-grid data onto a fine grid with exponentiated
    quadratic weights and stacks the result with the fine-grid elevation.

    Arguments:

        lengthscale : initial value shared by both entries of the learnable
                      lengthscale (one entry per coordinate)
    """
    
    def __init__(self, lengthscale):
        
        super().__init__()
        
        # One learnable lengthscale per coordinate, both starting equal
        initial_scales = torch.tensor([lengthscale] * 2)
        self.lengthscale = nn.Parameter(initial_scales)

        
    def convert_to_fine(self, tensor, lonlat_coarse, lonlat_fine):
        """
        Maps **tensor**, defined at the locations **lonlat_coarse**, to the
        locations **lonlat_fine**, by summing the coarse values with
        (unnormalised) exponentiated quadratic weights.
        
        Arguments:
        
            tensor        : torch.tensor, shape (B, C, K, L)
            lonlat_coarse : torch.tensor, shape (K, L, 2)
            lonlat_fine   : torch.tensor, shape (N, M, 2)

        Returns:

            tensor        : torch.tensor, shape (B, C, N, M)
        """
        
        # Pairwise offsets between coarse and fine locations, (K, L, N, M, 2)
        coarse_points = lonlat_coarse[:, :, None, None, :]
        fine_points = lonlat_fine[None, None, :, :, :]
        
        # The lengthscale broadcasts along the trailing coordinate dimension
        scaled_offsets = (coarse_points - fine_points) / self.lengthscale
        
        # Weight of every coarse location for every fine one, (K, L, N, M)
        weights = torch.exp(torch.sum(-0.5 * scaled_offsets ** 2, dim=-1))
        
        # Weighted sum over the coarse grid
        return torch.einsum('bckl, klnm -> bcnm', tensor, weights)
    
        
    def forward(self,
                lonlat_fine,
                elevation_fine,
                lonlat_coarse,
                reanalysis_coarse):
        """
        Arguments:

            lonlat_fine       : torch.tensor, shape (N, M, 2)
            elevation_fine    : torch.tensor, shape (N, M)
            lonlat_coarse     : torch.tensor, shape (K, L, 2)
            reanalysis_coarse : torch.tensor, shape (B, C, K, L)

        Returns:

            tensor            : torch.tensor, shape (B, C + 1, N, M)
        """
        
        num_batches = reanalysis_coarse.shape[0]
        
        # Bring the reanalysis data onto the elevation grid
        reanalysis_fine = self.convert_to_fine(tensor=reanalysis_coarse,
                                               lonlat_coarse=lonlat_coarse,
                                               lonlat_fine=lonlat_fine)
        
        # Share the single elevation map across the batch, as one channel
        elevation_channel = elevation_fine[None, None, :, :]
        elevation_channel = elevation_channel.expand(num_batches, -1, -1, -1)
        
        # Elevation is appended after the reanalysis channels
        return torch.cat([reanalysis_fine, elevation_channel], dim=1)

In [None]:
class StandardEnvDecoder(nn.Module):
    """
    Reads gridded features out at arbitrary target locations, using
    exponentiated quadratic weights with a learnable lengthscale.

    Arguments:

        lengthscale : initial value shared by both entries of the learnable
                      lengthscale (one entry per coordinate)
    """
    
    def __init__(self, lengthscale):
        
        super().__init__()
        
        # One learnable lengthscale per coordinate, both starting equal
        initial_scales = torch.tensor([lengthscale] * 2)
        self.lengthscale = nn.Parameter(initial_scales)
        

    def forward(self, tensor, lonlat_fine, lonlat_target):
        """
        Sums the grid values of **tensor** with (unnormalised) exponentiated
        quadratic weights centred on each target location.
        
        Arguments:
        
            tensor        : torch.tensor, shape (B, C, K, L)
            lonlat_fine   : torch.tensor, shape (K, L, 2)
            lonlat_target : torch.tensor, shape (N, 2)
            
        Returns:
            tensor        : torch.tensor, shape (B, C, N)
        """
        
        # Pairwise offsets between grid and target locations, (K, L, N, 2)
        grid_points = lonlat_fine[:, :, None, :]
        target_points = lonlat_target[None, None, :, :]
        
        # The lengthscale broadcasts along the trailing coordinate dimension
        scaled_offsets = (grid_points - target_points) / self.lengthscale
        
        # Weight of every grid location for every target, (K, L, N)
        weights = torch.exp(torch.sum(-0.5 * scaled_offsets ** 2, dim=-1))
        
        # Weighted sum over the grid
        return torch.einsum('bckl, kln -> bcn', tensor, weights)

In [None]:
class EnvUpscaleConvCNP(nn.Module):
    """
    Chains StandardEnvUpscaleEncoder, StandardEnvUNet and
    StandardEnvDecoder into a single model acting on batch dictionaries.

    Arguments (all optional; the defaults reproduce the values which were
    previously hard-coded):

        lengthscale          : initial lengthscale of encoder and decoder
        conv_in_channels     : number of input channels of the CNN; must
                               equal the number of channels produced by
                               the encoder (reanalysis channels plus one
                               elevation channel)
        conv_latent_channels : number of latent channels of the CNN
        out_channels         : number of output channels of the CNN
    """
    
    def __init__(self,
                 lengthscale=0.2,
                 conv_in_channels=26,
                 conv_latent_channels=8,
                 out_channels=2):
        
        super().__init__()
        
        self.lengthscale = lengthscale
        self.conv_in_channels = conv_in_channels
        self.conv_latent_channels = conv_latent_channels
        self.out_channels = out_channels
        
        # Submodules are created in the order encoder, cnn, decoder
        self.encoder = StandardEnvUpscaleEncoder(lengthscale=self.lengthscale)
        
        self.cnn = StandardEnvUNet(in_channels=self.conv_in_channels,
                                   latent_channels=self.conv_latent_channels,
                                   out_channels=self.out_channels)
        
        self.decoder = StandardEnvDecoder(lengthscale=self.lengthscale)
        
        
    def forward(self, batch):
        """
        Arguments:

            batch  : dict, must contain the keys 'lonlat_fine',
                     'elevation_fine', 'lonlat_coarse',
                     'reanalysis_coarse' and 'lonlat_station'

        Returns:

            tensor : output of the decoder, evaluated at the locations
                     batch['lonlat_station']
        """
        
        # Pass through encoder
        tensor = self.encoder(lonlat_fine=batch['lonlat_fine'],
                              elevation_fine=batch['elevation_fine'],
                              lonlat_coarse=batch['lonlat_coarse'],
                              reanalysis_coarse=batch['reanalysis_coarse'])
        
        # Pass through CNN
        tensor = self.cnn(tensor)
        
        # Pass through decoder
        tensor = self.decoder(tensor,
                              lonlat_fine=batch['lonlat_fine'],
                              lonlat_target=batch['lonlat_station'])
        
        return tensor

In [None]:
# Loader settings
iterations_per_epoch = 1
batch_size = 16

# The loader is built from the shared grids and the training split
train_loader_kwargs = dict(lonlat_fine=lonlat_fine,
                           lonlat_coarse=lonlat_coarse,
                           elevation_fine=elevation_fine,
                           reanalysis_coarse=train_reanalysis_coarse,
                           lonlat_station=train_lonlat_station,
                           temperature_station=train_temperature_station,
                           iterations_per_epoch=iterations_per_epoch,
                           batch_size=batch_size)

data = Dataloader(**train_loader_kwargs)

In [None]:
# Build the model; the constructor is called without arguments
convCNP = EnvUpscaleConvCNP()

In [None]:
# Run the model on every batch produced by the loader. `tensor` is
# overwritten on each iteration, so only the last output is kept.
for batch in data:
    
    tensor = convCNP(batch)

In [None]:
# Display the shape of the model output kept from the loop above
tensor.shape