In [1]:
import numpy as np
import torch
from torch import nn
from torch.distributions import MultivariateNormal

from cnp.cnp import GaussianNeuralProcess
from cnp.data import LambdaIterator
from cnp.cov import MeanFieldCov, InnerProdCov, KvvCov, AddNoNoise, AddHomoNoise

import matplotlib.pyplot as plt

device = torch.device('cuda:1')

In [2]:
data_root = '/scratches/cblgpu07/em626/kernelcnp/kernelcnp/experiments/environmental/data'

np_lonlat_fine = np.load(f'{data_root}/x_context_fine.npy')
np_lonlat_coarse = np.load(f'{data_root}/x_context_coarse.npy')
np_elevation_fine = np.load(f'{data_root}/y_context_fine.npy')

np_train_reanalysis_coarse = np.load(f'{data_root}/y_context_coarse_train.npy')
np_train_lonlat_station = np.load(f'{data_root}/x_target_train.npy')
np_train_temperature_station = np.load(f'{data_root}/y_target_train.npy')

np_test_reanalysis_coarse = np.load(f'{data_root}/y_context_coarse_val.npy')
np_test_lonlat_station = np.load(f'{data_root}/x_target_val.npy')
np_test_temperature_station = np.load(f'{data_root}/y_target_val.npy')

lonlat_fine = torch.tensor(np_lonlat_fine).float()
lonlat_coarse = torch.tensor(np_lonlat_coarse).float()
np_elevation_fine = (np_elevation_fine - np.mean(np_elevation_fine)) / np.std(np_elevation_fine)
elevation_fine = torch.tensor(np_elevation_fine).float()

idx = torch.randperm(np_train_lonlat_station.shape[0])
train_lonlat_station = torch.tensor(np_train_lonlat_station).float()[idx, :][:-1]
train_temperature_station = torch.tensor(np_train_temperature_station).float()[:, idx][:, :-1]
train_reanalysis_coarse = torch.tensor(np_train_reanalysis_coarse).float()

# Compute mean and standard deviation of inputs for normalising
train_reanalysis_mean = torch.mean(train_reanalysis_coarse, dim=[0, 2, 3])[None, :, None, None]
train_reanalysis_stds = torch.var(train_reanalysis_coarse, dim=[0, 2, 3])[None, :, None, None]**0.5

train_reanalysis_coarse = (train_reanalysis_coarse - train_reanalysis_mean) / train_reanalysis_stds

# Compute mean and standard deviation of outputs for normalising
flat = torch.flatten(train_temperature_station)
train_temperature_station_mean = torch.mean(flat[~torch.isnan(flat)])
train_temperature_station_std = torch.var(flat[~torch.isnan(flat)])**0.5

train_temperature_station = (train_temperature_station - train_temperature_station_mean) / \
                            train_temperature_station_std

valid_lonlat_station = torch.tensor(np_train_lonlat_station).float()[idx, :][-1:]
valid_temperature_station = torch.tensor(np_train_temperature_station).float()[:, idx][:, -1:]
valid_reanalysis_coarse = torch.tensor(np_train_reanalysis_coarse).float()
valid_reanalysis_coarse = (valid_reanalysis_coarse - train_reanalysis_mean) / train_reanalysis_stds

valid_temperature_station = (valid_temperature_station - train_temperature_station_mean) / \
                            train_temperature_station_std

test_lonlat_station = torch.tensor(np_test_lonlat_station).float()
test_temperature_station = torch.tensor(np_test_temperature_station).float()
test_reanalysis_coarse = torch.tensor(np_test_reanalysis_coarse).float()
test_reanalysis_coarse = (test_reanalysis_coarse - train_reanalysis_mean) / train_reanalysis_stds

test_temperature_station = (test_temperature_station - train_temperature_station_mean) / \
                            train_temperature_station_std

In [3]:
print('lonlat_fine'.ljust(30), lonlat_fine.shape)
print('lonlat_coarse'.ljust(30), lonlat_coarse.shape)
print('elevation_fine'.ljust(30), elevation_fine.shape, '\n')

print('train_reanalysis_coarse'.ljust(30), train_reanalysis_coarse.shape)
print('train_lonlat_station'.ljust(30), train_lonlat_station.shape)
print('train_temperature_station'.ljust(30), train_temperature_station.shape, '\n')

print('valid_reanalysis_coarse'.ljust(30), valid_reanalysis_coarse.shape)
print('valid_lonlat_station'.ljust(30), valid_lonlat_station.shape)
print('valid_temperature_station'.ljust(30), valid_temperature_station.shape, '\n')

print('test_reanalysis_coarse'.ljust(30), test_reanalysis_coarse.shape)
print('test_lonlat_station'.ljust(30), test_lonlat_station.shape)
print('test_temperature_station'.ljust(30), test_temperature_station.shape)

lonlat_fine                    torch.Size([1200, 1200, 2])
lonlat_coarse                  torch.Size([6, 6, 2])
elevation_fine                 torch.Size([1200, 1200]) 

train_reanalysis_coarse        torch.Size([8766, 25, 6, 6])
train_lonlat_station           torch.Size([961, 2])
train_temperature_station      torch.Size([8766, 961]) 

valid_reanalysis_coarse        torch.Size([8766, 25, 6, 6])
valid_lonlat_station           torch.Size([1, 2])
valid_temperature_station      torch.Size([8766, 1]) 

test_reanalysis_coarse         torch.Size([2192, 25, 6, 6])
test_lonlat_station            torch.Size([24, 2])
test_temperature_station       torch.Size([11688, 24])


In [4]:
def get_elevation(lonlat_fine, lonlat_station, elevation_fine):
    
    diff = lonlat_fine[:, :, None, :] - lonlat_station[None, None, :, :]
    dist = torch.sum(diff**2, dim=-1)**0.5
    dist = -torch.permute(dist, (2, 0, 1))
    
    n = torch.tensor(dist.shape[0])
    d = torch.tensor(dist.shape[1])
    m = dist.view(n, -1).argmax(1)
    idx = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
    
    return torch.tensor([elevation_fine[i, j] for i, j in idx])

In [5]:
train_station_elevation = get_elevation(lonlat_fine=lonlat_fine,
                                        lonlat_station=train_lonlat_station,
                                        elevation_fine=elevation_fine)
train_station_elevation = train_station_elevation.float()

valid_station_elevation = get_elevation(lonlat_fine=lonlat_fine,
                                        lonlat_station=valid_lonlat_station,
                                        elevation_fine=elevation_fine)
valid_station_elevation = valid_station_elevation.float()

test_station_elevation = get_elevation(lonlat_fine=lonlat_fine,
                                       lonlat_station=test_lonlat_station,
                                       elevation_fine=elevation_fine)
test_station_elevation = test_station_elevation.float()

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)


In [6]:
class Dataloader:
    
    def __init__(self,
                 lonlat_fine,
                 lonlat_coarse,
                 elevation_fine,
                 reanalysis_coarse,
                 lonlat_station,
                 elevation_station,
                 temperature_station,
                 iterations_per_epoch,
                 max_num_target,
                 batch_size,
                 device):
        
        # Set data tensors
        self.lonlat_fine = lonlat_fine
        self.lonlat_coarse = lonlat_coarse
        self.elevation_fine = elevation_fine
        self.reanalysis_coarse = reanalysis_coarse
        self.lonlat_station = lonlat_station
        self.elevation_station = elevation_station
        self.temperature_station = temperature_station
        
        # Set dataloader parameters
        self.iterations_per_epoch = iterations_per_epoch
        self.max_num_target = max_num_target
        self.num_datasets = self.reanalysis_coarse.shape[0]
        self.batch_size = batch_size
        self.device = device
        
        
    def generate_batch(self):
        
        # Draw batch indices at random - these are time indices
        idx1 = torch.randperm(self.num_datasets)[:batch_size]
        
        batch_reanalysis_coarse = self.reanalysis_coarse[idx1]
        batch_temperature_station = self.temperature_station[idx1]
        
        # Keep stations which have no nan values
        nan_mask = torch.isnan(torch.sum(batch_temperature_station, dim=0))
        
        batch_lonlat_station = self.lonlat_station[~nan_mask, :]
        batch_elevation_station = self.elevation_station[~nan_mask]
        batch_temperature_station = batch_temperature_station[:, ~nan_mask]
        
        # From the non-nan stations, pick **num_target** at random
        num_target = min(max_num_target, nan_mask.shape[0])
        idx2 = torch.randperm(batch_lonlat_station.shape[0])[:num_target]
        
        batch_lonlat_station = batch_lonlat_station[idx2, :]
        batch_elevation_station = batch_elevation_station[idx2]
        batch_temperature_station = batch_temperature_station[:, idx2]
        
        a = torch.cuda.memory_allocated('cuda:1')
        if False: print(f'Memory usage (before loading):'.ljust(50) + f'{a}')
        
        batch = {'lonlat_fine'         : self.lonlat_fine.to(self.device),
                 'lonlat_coarse'       : self.lonlat_coarse.to(self.device),
                 'elevation_fine'      : self.elevation_fine.to(self.device),
                 'reanalysis_coarse'   : batch_reanalysis_coarse.to(self.device),
                 'lonlat_station'      : batch_lonlat_station.to(self.device),
                 'elevation_station'   : batch_elevation_station.to(self.device),
                 'temperature_station' : batch_temperature_station.to(self.device)}
        
        
        a = torch.cuda.memory_allocated('cuda:1')
        if False: print(f'Memory usage (after loading):'.ljust(50) + f'{a}')
        
        return batch
        
    def __iter__(self):
        return LambdaIterator(lambda: self.generate_batch(), self.iterations_per_epoch)

In [7]:
class StandardResNet(nn.Module):
    
    def __init__(self,
                 in_channels,
                 latent_channels,
                 out_channels):
        
        super().__init__()
        
        self.in_channels = in_channels
        self.latent_channels = latent_channels
        self.out_channels = out_channels
        self.kernel_size = 3
        self.padding = (self.kernel_size - 1) // 2
        
        self.l1 = nn.Conv2d(in_channels=self.in_channels,
                            out_channels=self.latent_channels,
                            kernel_size=self.kernel_size,
                            padding=self.padding,
                            stride=1)
        
        self.l2 = nn.Conv2d(in_channels=self.latent_channels,
                            out_channels=self.latent_channels,
                            kernel_size=self.kernel_size,
                            padding=self.padding,
                            stride=1)
        
        self.l3 = nn.Conv2d(in_channels=self.latent_channels,
                            out_channels=self.latent_channels,
                            kernel_size=self.kernel_size,
                            padding=self.padding,
                            stride=1)
        
        self.l4 = nn.Conv2d(in_channels=self.latent_channels,
                            out_channels=self.latent_channels,
                            kernel_size=self.kernel_size,
                            padding=self.padding,
                            stride=1)

        self.l5 = nn.Conv2d(in_channels=self.latent_channels,
                            out_channels=self.out_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0)
        
        self.activation = nn.ReLU()
        
        
    def forward(self, tensor):
        
        tensor = self.activation(self.l1(tensor))
        
        tensor = self.activation(self.l2(tensor)) + tensor
        tensor = self.activation(self.l3(tensor)) + tensor
        tensor = self.activation(self.l4(tensor)) + tensor
        
        tensor = self.l5(tensor)
        
        return tensor

In [8]:
class StandardResNetEncoder(nn.Module):
    
    def __init__(self):
        
        super().__init__()
        
        self.conv_in_channels = 25
        self.conv_latent_channels = 128
        self.conv_out_channels = 5
        
        self.cnn = StandardResNet(in_channels=self.conv_in_channels,
                                  latent_channels=self.conv_latent_channels,
                                  out_channels=self.conv_out_channels)
    
        
    def forward(self, reanalysis_coarse):
        return self.cnn(reanalysis_coarse)

In [9]:
class StandardEnvDecoder(nn.Module):
    
    def __init__(self, lengthscale, out_channels):
        
        super().__init__()
        
        self.out_channels = out_channels
        
        self.log_lengthscale = nn.Parameter(torch.tensor([np.log(lengthscale),
                                                          np.log(lengthscale)]).float())
        
        self.activation = nn.ReLU()

        self.l1 = nn.Linear(in_features=6,
                            out_features=64,
                            bias=True)

        self.l2 = nn.Linear(in_features=64,
                            out_features=64,
                            bias=True)

        self.l3 = nn.Linear(in_features=64,
                            out_features=64,
                            bias=True)

        self.l4 = nn.Linear(in_features=64,
                            out_features=self.out_channels,
                            bias=True)
        
    @property
    def lengthscale(self):
        return torch.exp(self.log_lengthscale)
        

    def forward(self, tensor, elevation_station, lonlat_coarse, lonlat_target):
        """
        
        Arguments:
            tensor        : torch.tensor, shape (B, C1, K, L)
            lonlat_fine   : torch.tensor, shape (K, L, 2)
            lonlat_target : torch.tensor, shape (N, 2)
            
        Returns:
            tensor        : torch.tensor, shape (B, C2, N)
        """
        
        # Compute differences between grid locations
        diff = lonlat_coarse[:, :, None, :] - \
               lonlat_target[None, None, :, :]
        
        # Compute weight matrix
        quad = -0.5 * (diff / self.lengthscale[None, None, None, :]) ** 2
        quad = torch.sum(quad, axis=-1)
        
        exp = torch.exp(quad)
        exp = exp / torch.sum(exp, dim=[0, 1])[None, None, :]
        
        # Compute refined tensor
        tensor = torch.einsum('bckl, kln -> bcn', tensor, exp)
        
        elevation_station = elevation_station[None, None, :].repeat(tensor.shape[0], 1, 1)
        tensor = torch.cat([tensor, elevation_station], dim=1)
        
        tensor = torch.permute(tensor, (0, 2, 1))
        
        tensor = self.l1(tensor)
        tensor = self.activation(tensor)
        
        tensor = self.l2(tensor)
        tensor = self.activation(tensor)
        
        tensor = self.l3(tensor)
        tensor = self.activation(tensor)
        
        tensor = self.l4(tensor)
        
        return tensor

In [10]:
class EnvResNetConvGNP(GaussianNeuralProcess):
    
    def __init__(self, covariance, add_noise):
        
        self.output_dim = 1
        self.decoder_lengthscale = 1.
        
        # Construct the convolutional decoder
        decoder_out_channels = self.output_dim          + \
                               covariance.num_basis_dim + \
                               covariance.extra_cov_dim + \
                               add_noise.extra_noise_dim
        
        encoder = StandardResNetEncoder()
        
        decoder = StandardEnvDecoder(lengthscale=self.decoder_lengthscale,
                                     out_channels=decoder_out_channels)

        super().__init__(encoder=encoder,
                         decoder=decoder,
                         covariance=covariance,
                         add_noise=add_noise)
        
        
    def forward(self, batch):
        
        # Pass through encoder
        tensor = self.encoder(reanalysis_coarse=batch['reanalysis_coarse'])
        
        # Pass through decoder
        tensor = self.decoder(tensor,
                              elevation_station=batch['elevation_station'],
                              lonlat_coarse=batch['lonlat_coarse'],
                              lonlat_target=batch['lonlat_station'])
        
        # Produce mean
        mean = tensor[..., 0:1]
        
        # Produce cov
        embedding = tensor[..., 1:]
        cov = self.covariance(embedding)
        cov_plus_noise = self.add_noise(cov, embedding)
        
        return mean, cov, cov_plus_noise

    
    def loss(self, batch):

        y_mean, _, y_cov = self.forward(batch)

        y_mean = y_mean.double()
        y_cov = y_cov.double()
        y_target = batch['temperature_station'].double()

        jitter = 1e-3 * torch.eye(y_cov.shape[-1], device=y_cov.device).double()
        y_cov = y_cov + jitter[None, :, :]
        
        mae = torch.mean(torch.abs(y_mean[:, :, 0] - y_target))
        
        dist = MultivariateNormal(loc=y_mean[:, :, 0], covariance_matrix=y_cov)
        nll = - torch.mean(dist.log_prob(y_target.double()))

        return nll.float(), mae.float()

In [11]:
torch.manual_seed(0)
np.random.seed(0)

In [12]:
iterations_per_epoch = 5000
batch_size = 64
max_num_target = 128

test_batch_size = 64
test_max_num_target = 64

train_data = Dataloader(lonlat_fine=lonlat_fine,
                        lonlat_coarse=lonlat_coarse,
                        elevation_fine=elevation_fine,
                        reanalysis_coarse=train_reanalysis_coarse,
                        lonlat_station=train_lonlat_station,
                        elevation_station=train_station_elevation,
                        temperature_station=train_temperature_station,
                        iterations_per_epoch=iterations_per_epoch,
                        max_num_target=max_num_target,
                        batch_size=batch_size,
                        device=device)

valid_data = Dataloader(lonlat_fine=lonlat_fine,
                        lonlat_coarse=lonlat_coarse,
                        elevation_fine=elevation_fine,
                        reanalysis_coarse=valid_reanalysis_coarse,
                        lonlat_station=valid_lonlat_station,
                        elevation_station=valid_station_elevation,
                        temperature_station=valid_temperature_station,
                        iterations_per_epoch=iterations_per_epoch,
                        max_num_target=max_num_target,
                        batch_size=batch_size,
                        device=device)

test_data = Dataloader(lonlat_fine=lonlat_fine,
                       lonlat_coarse=lonlat_coarse,
                       elevation_fine=elevation_fine,
                       reanalysis_coarse=test_reanalysis_coarse,
                       lonlat_station=test_lonlat_station,
                       elevation_station=test_station_elevation,
                       temperature_station=test_temperature_station,
                       iterations_per_epoch=iterations_per_epoch,
                       max_num_target=test_max_num_target,
                       batch_size=test_batch_size,
                       device=device)

In [None]:
num_basis_dim = 128

covariance = KvvCov(num_basis_dim)
add_noise = AddHomoNoise()

# covariance = MeanFieldCov(num_basis_dim=1)
# add_noise = AddNoNoise()

model = EnvResNetConvGNP(covariance=covariance,
                         add_noise=add_noise)
model.to(device)

optimizer = torch.optim.Adam(lr=1e-4, params=model.parameters(), weight_decay=1e-3)

mae_scale = train_temperature_station_std

for i, (train_batch, valid_batch, test_batch) in enumerate(zip(train_data, valid_data, test_data)):
    
    optimizer.zero_grad()
    
    loss, mae = model.loss(train_batch)
    loss.backward()
    optimizer.step()
    
    if i % 20 == 0:
    
        with torch.no_grad():
            valid_loss, valid_mae = 0, 0 # model.loss(valid_batch)
            test_loss, test_mae = model.loss(test_batch)
            
        decoder_scale = model.decoder.lengthscale.detach().cpu().numpy()
            
        print(f'Loss: {loss:4.2f}, {valid_loss:4.2f}, {test_loss:4.2f} '
              f'MAE: {mae*mae_scale:4.2f}, {valid_mae*mae_scale:4.2f}, {test_mae*mae_scale:4.2f} '
              f'Decoder scale: {decoder_scale[0]:.2f}, {decoder_scale[1]:.2f} ')

Loss: 179.98, 0.00, 39.24 MAE: 7.50, 0.00, 8.74 Decoder scale: 1.00, 1.00 
Loss: 129.89, 0.00, 29.46 MAE: 7.50, 0.00, 7.66 Decoder scale: 1.00, 1.00 
Loss: 129.48, 0.00, 30.04 MAE: 7.49, 0.00, 8.39 Decoder scale: 1.00, 1.00 
Loss: 126.39, 0.00, 27.99 MAE: 6.66, 0.00, 7.45 Decoder scale: 1.00, 1.00 
Loss: 125.50, 0.00, 29.27 MAE: 7.62, 0.00, 8.50 Decoder scale: 1.00, 1.00 
Loss: 125.09, 0.00, 28.62 MAE: 6.75, 0.00, 7.93 Decoder scale: 1.01, 1.00 
Loss: 124.50, 0.00, 28.58 MAE: 6.58, 0.00, 8.39 Decoder scale: 1.01, 1.01 
Loss: 124.40, 0.00, 28.12 MAE: 6.61, 0.00, 7.90 Decoder scale: 1.01, 1.01 
Loss: 124.00, 0.00, 27.53 MAE: 7.13, 0.00, 7.33 Decoder scale: 1.01, 1.01 
Loss: 123.20, 0.00, 27.63 MAE: 7.32, 0.00, 7.46 Decoder scale: 1.01, 1.01 
Loss: 123.70, 0.00, 27.86 MAE: 7.00, 0.00, 8.36 Decoder scale: 1.01, 1.01 
Loss: 123.49, 0.00, 27.48 MAE: 6.89, 0.00, 7.85 Decoder scale: 1.01, 1.01 
Loss: 122.44, 0.00, 27.43 MAE: 6.68, 0.00, 8.11 Decoder scale: 1.01, 1.01 
Loss: 122.53, 0.00, 27.28

In [None]:
# plt.figure(figsize=(16, 20))
# for i in range(80):
    
#     plt.subplot(10, 8, i+1)
#     temp = train_temperature_station[:, i].numpy()
#     temp = temp[~np.isnan(temp)]
    
#     plt.hist(temp)
    
# plt.tight_layout()
# plt.show()

In [None]:
plt.scatter(train_lonlat_station[:, 0],
            train_lonlat_station[:, 1],
            zorder=2)

plt.scatter(train_lonlat_station[:, 0],
            train_lonlat_station[:, 1],
            zorder=2)

plt.scatter(valid_lonlat_station[:, 0],
            valid_lonlat_station[:, 1],
            zorder=3)

plt.contourf(lonlat_fine[:, :, 0],
             lonlat_fine[:, :, 1],
             elevation_fine, origin='lower',
             alpha=0.5,
             cmap='coolwarm',
             zorder=1)

plt.show()