In [1]:
import numpy as np
import torch
from torch import nn
from torch.distributions import MultivariateNormal

from cnp.cnp import GaussianNeuralProcess
from cnp.data import LambdaIterator
from cnp.cov import MeanFieldCov, InnerProdCov, KvvCov, AddNoNoise, AddHomoNoise

import matplotlib.pyplot as plt

device = torch.device('cuda:1')

In [2]:
data_root = '/scratches/cblgpu07/em626/kernelcnp/kernelcnp/experiments/environmental/data'

np_lonlat_fine = np.load(f'{data_root}/x_context_fine.npy')
np_lonlat_coarse = np.load(f'{data_root}/x_context_coarse.npy')
np_elevation_fine = np.load(f'{data_root}/y_context_fine.npy')

np_train_reanalysis_coarse = np.load(f'{data_root}/y_context_coarse_train.npy')
np_train_lonlat_station = np.load(f'{data_root}/x_target_train.npy')
np_train_temperature_station = np.load(f'{data_root}/y_target_train.npy')

np_test_reanalysis_coarse = np.load(f'{data_root}/y_context_coarse_val.npy')
np_test_lonlat_station = np.load(f'{data_root}/x_target_val.npy')
np_test_temperature_station = np.load(f'{data_root}/y_target_val.npy')

lonlat_fine = torch.tensor(np_lonlat_fine).float()
lonlat_coarse = torch.tensor(np_lonlat_coarse).float()
np_elevation_fine = (np_elevation_fine - np.mean(np_elevation_fine)) / np.std(np_elevation_fine)**0.5
elevation_fine = torch.tensor(np_elevation_fine).float()

idx = torch.randperm(np_train_lonlat_station.shape[0])
train_lonlat_station = torch.tensor(np_train_lonlat_station).float()[idx, :][:-100]
train_temperature_station = torch.tensor(np_train_temperature_station).float()[:, idx][:, :-100]
train_reanalysis_coarse = torch.tensor(np_train_reanalysis_coarse).float()

# Compute mean and standard deviation of inputs for normalising
train_reanalysis_mean = torch.mean(train_reanalysis_coarse, dim=[0, 2, 3])[None, :, None, None]
train_reanalysis_stds = torch.var(train_reanalysis_coarse, dim=[0, 2, 3])[None, :, None, None]**0.5

train_reanalysis_coarse = (train_reanalysis_coarse - train_reanalysis_mean) / train_reanalysis_stds

# Compute mean and standard deviation of outputs for normalising
flat = torch.flatten(train_temperature_station)
train_temperature_station_mean = torch.mean(flat[~torch.isnan(flat)])
train_temperature_station_std = torch.var(flat[~torch.isnan(flat)])**0.5

train_temperature_station = (train_temperature_station - train_temperature_station_mean) / \
                            train_temperature_station_std

valid_lonlat_station = torch.tensor(np_train_lonlat_station).float()[idx, :][-100:]
valid_temperature_station = torch.tensor(np_train_temperature_station).float()[:, idx][:, -100:]
valid_reanalysis_coarse = torch.tensor(np_train_reanalysis_coarse).float()
valid_reanalysis_coarse = (valid_reanalysis_coarse - train_reanalysis_mean) / train_reanalysis_stds

valid_temperature_station = (valid_temperature_station - train_temperature_station_mean) / \
                            train_temperature_station_std

test_lonlat_station = torch.tensor(np_test_lonlat_station).float()
test_temperature_station = torch.tensor(np_test_temperature_station).float()
test_reanalysis_coarse = torch.tensor(np_test_reanalysis_coarse).float()
test_reanalysis_coarse = (test_reanalysis_coarse - train_reanalysis_mean) / train_reanalysis_stds

test_temperature_station = (test_temperature_station - train_temperature_station_mean) / \
                            train_temperature_station_std

In [3]:
print('lonlat_fine'.ljust(30), lonlat_fine.shape)
print('lonlat_coarse'.ljust(30), lonlat_coarse.shape)
print('elevation_fine'.ljust(30), elevation_fine.shape, '\n')

print('train_reanalysis_coarse'.ljust(30), train_reanalysis_coarse.shape)
print('train_lonlat_station'.ljust(30), train_lonlat_station.shape)
print('train_temperature_station'.ljust(30), train_temperature_station.shape, '\n')

print('valid_reanalysis_coarse'.ljust(30), valid_reanalysis_coarse.shape)
print('valid_lonlat_station'.ljust(30), valid_lonlat_station.shape)
print('valid_temperature_station'.ljust(30), valid_temperature_station.shape, '\n')

print('test_reanalysis_coarse'.ljust(30), test_reanalysis_coarse.shape)
print('test_lonlat_station'.ljust(30), test_lonlat_station.shape)
print('test_temperature_station'.ljust(30), test_temperature_station.shape)

lonlat_fine                    torch.Size([1200, 1200, 2])
lonlat_coarse                  torch.Size([6, 6, 2])
elevation_fine                 torch.Size([1200, 1200]) 

train_reanalysis_coarse        torch.Size([8766, 25, 6, 6])
train_lonlat_station           torch.Size([862, 2])
train_temperature_station      torch.Size([8766, 862]) 

valid_reanalysis_coarse        torch.Size([8766, 25, 6, 6])
valid_lonlat_station           torch.Size([100, 2])
valid_temperature_station      torch.Size([8766, 100]) 

test_reanalysis_coarse         torch.Size([2192, 25, 6, 6])
test_lonlat_station            torch.Size([24, 2])
test_temperature_station       torch.Size([11688, 24])


In [4]:
# reanalysis_means = torch.mean(train_reanalysis_coarse, axis=[0, 2, 3])[None, :, None, None]
# reanalysis_stds = torch.var(train_reanalysis_coarse, axis=[0, 2, 3])[None, :, None, None] ** 0.5

# temperature_means = torch.mean(train_temperature_station, axis=[0, 1])[None, None]
# temperature_stds = torch.var(train_temperature_station, axis=[0, 1])[None, None] ** 0.5

# train_reanalysis_coarse = (train_reanalysis_coarse - reanalysis_means) / reanalysis_stds
# train_temperature_station = (train_temperature_station - temperature_means) / temperature_stds

# valid_reanalysis_coarse = (valid_reanalysis_coarse - reanalysis_means) / reanalysis_stds
# valid_temperature_station = (valid_temperature_station - temperature_means) / temperature_stds

In [5]:
class Dataloader:
    
    def __init__(self,
                 lonlat_fine,
                 lonlat_coarse,
                 elevation_fine,
                 reanalysis_coarse,
                 lonlat_station,
                 temperature_station,
                 iterations_per_epoch,
                 max_num_target,
                 batch_size,
                 device):
        
        # Set data tensors
        self.lonlat_fine = lonlat_fine
        self.lonlat_coarse = lonlat_coarse
        self.elevation_fine = elevation_fine
        self.reanalysis_coarse = reanalysis_coarse
        self.lonlat_station = lonlat_station
        self.temperature_station = temperature_station
        
        # Set dataloader parameters
        self.iterations_per_epoch = iterations_per_epoch
        self.max_num_target = max_num_target
        self.num_datasets = self.reanalysis_coarse.shape[0]
        self.batch_size = batch_size
        self.device = device
        
        
    def generate_batch(self):
        
        # Draw batch indices at random - these are time indices
        idx1 = torch.randperm(self.num_datasets)[:batch_size]
        
        batch_reanalysis_coarse = self.reanalysis_coarse[idx1]
        batch_temperature_station = self.temperature_station[idx1]
        
        # Keep stations which have no nan values
        nan_mask = torch.isnan(torch.sum(batch_temperature_station, dim=0))
        
        batch_lonlat_station = self.lonlat_station[~nan_mask, :]
        batch_temperature_station = batch_temperature_station[:, ~nan_mask]
        
        # From the non-nan stations, pick **num_target** at random
        num_target = min(max_num_target, nan_mask.shape[0])
        idx2 = torch.randperm(batch_lonlat_station.shape[0])[:num_target]
        
        batch_lonlat_station = batch_lonlat_station[idx2, :]
        batch_temperature_station = batch_temperature_station[:, idx2]
        
        a = torch.cuda.memory_allocated('cuda:1')
        if False: print(f'Memory usage (before loading):'.ljust(50) + f'{a}')
        
        batch = {'lonlat_fine'         : self.lonlat_fine.to(self.device),
                 'lonlat_coarse'       : self.lonlat_coarse.to(self.device),
                 'elevation_fine'      : self.elevation_fine.to(self.device),
                 'reanalysis_coarse'   : batch_reanalysis_coarse.to(self.device),
                 'lonlat_station'      : batch_lonlat_station.to(self.device),
                 'temperature_station' : batch_temperature_station.to(self.device)}
        
        
        a = torch.cuda.memory_allocated('cuda:1')
        if False: print(f'Memory usage (after loading):'.ljust(50) + f'{a}')
        
        return batch
        
    def __iter__(self):
        return LambdaIterator(lambda: self.generate_batch(), self.iterations_per_epoch)

In [6]:
class StandardEnvUNet(nn.Module):
    
    def __init__(self,
                 in_channels,
                 latent_channels,
                 out_channels):
        
        super().__init__()
        
        self.in_channels = in_channels
        self.latent_channels = latent_channels
        self.out_channels = out_channels
        self.kernel_size = 5
        self.padding = (self.kernel_size - 1) // 2
        
        self.l1 = nn.Conv2d(in_channels=self.in_channels,
                            out_channels=self.latent_channels,
                            kernel_size=self.kernel_size,
                            padding=self.padding,
                            stride=2)
        
        self.l2 = nn.Conv2d(in_channels=self.latent_channels,
                            out_channels=2*self.latent_channels,
                            kernel_size=self.kernel_size,
                            padding=self.padding,
                            stride=2)
        
        self.l3 = nn.Conv2d(in_channels=2*self.latent_channels,
                            out_channels=4*self.latent_channels,
                            kernel_size=self.kernel_size,
                            padding=self.padding,
                            stride=2)
        
        self.l4 = nn.Conv2d(in_channels=4*self.latent_channels,
                            out_channels=4*self.latent_channels,
                            kernel_size=self.kernel_size,
                            padding=self.padding,
                            stride=2)
        
        self.l5 = nn.Conv2d(in_channels=4*self.latent_channels,
                            out_channels=8*self.latent_channels,
                            kernel_size=self.kernel_size,
                            padding=self.padding,
                            stride=3)
        
        self.l6 = nn.ConvTranspose2d(in_channels=8*self.latent_channels,
                                     out_channels=4*self.latent_channels,
                                     kernel_size=self.kernel_size,
                                     stride=3,
                                     padding=self.padding,
                                     output_padding=2)
        
        self.l7 = nn.ConvTranspose2d(in_channels=8*self.latent_channels,
                                     out_channels=4*self.latent_channels,
                                     kernel_size=self.kernel_size,
                                     stride=2,
                                     padding=self.padding,
                                     output_padding=1)
        
        self.l8 = nn.ConvTranspose2d(in_channels=8*self.latent_channels,
                                     out_channels=2*self.latent_channels,
                                     kernel_size=self.kernel_size,
                                     stride=2,
                                     padding=self.padding,
                                     output_padding=1)
        
        self.l9 = nn.ConvTranspose2d(in_channels=4*self.latent_channels,
                                     out_channels=2*self.latent_channels,
                                     kernel_size=self.kernel_size,
                                     stride=2,
                                     padding=self.padding,
                                     output_padding=1)
        
        self.l10 = nn.ConvTranspose2d(in_channels=3*self.latent_channels,
                                      out_channels=self.latent_channels,
                                      kernel_size=self.kernel_size,
                                      stride=2,
                                      padding=self.padding,
                                      output_padding=1)

        self.last_multiplier = nn.Conv2d(in_channels=self.in_channels+self.latent_channels,
                                         out_channels=self.out_channels,
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)
        
        self.activation = nn.ReLU()
        
        
    def forward(self, tensor):
        
        h1 = self.activation(self.l1(tensor))
        h2 = self.activation(self.l2(h1))
        h3 = self.activation(self.l3(h2))
        h4 = self.activation(self.l4(h3))
        h5 = self.activation(self.l5(h4))
        
        h6 = self.activation(self.l6(h5))
        h6 = torch.cat([h4, h6], dim=1)
        
        h7 = self.activation(self.l7(h6))
        h7 = torch.cat([h3, h7], dim=1)
        
        h8 = self.activation(self.l8(h7))
        h8 = torch.cat([h2, h8], dim=1)
        
        h9 = self.activation(self.l9(h8))
        h9 = torch.cat([h1, h9], dim=1)
        
        h10 = self.activation(self.l10(h9))
        h10 = torch.cat([tensor, h10], dim=1)
        
        output = self.last_multiplier(h10)
        
        return output

In [7]:
class StandardEnvUpscaleEncoder(nn.Module):
    
    def __init__(self, lengthscale):
        
        super().__init__()
        
        self.lengthscale = nn.Parameter(torch.tensor([lengthscale, lengthscale]))

        
    def convert_to_fine(self, tensor, lonlat_coarse, lonlat_fine):
        """
        Upscales **tensor** with corresponding coordinates **lonlat_coarse**
        to a finer discretisation with coordinates **lonlat_fine**.
        
        Arguments:
        
            tensor        : torch.tensor, shape (B, C, K, L)
            lonlat_coarse : torch.tensor, shape (K, L, 2)
            lonlat_fine   : torch.tensor, shape (N, M, 2)
        """
        
        # Compute differences between grid locations
        diff = lonlat_coarse[:, :, None, None, :] - \
               lonlat_fine[None, None, :, :, :]
        
        # Compute weight matrix
        quad = -0.5 * (diff / self.lengthscale[None, None, None, None, :]) ** 2
        quad = torch.sum(quad, axis=-1)
        exp = torch.exp(quad)
        exp = exp / torch.sum(exp, dim=[0, 1])[None, None, :]
        
        # Compute refined tensor
        tensor = torch.einsum('bckl, klnm -> bcnm', tensor, exp)
        
        return tensor
    
        
    def forward(self,
                lonlat_fine,
                elevation_fine,
                lonlat_coarse,
                reanalysis_coarse):
        
        # Get number of batches
        B = reanalysis_coarse.shape[0]
        elevation_fine = elevation_fine[None, None, :, :].repeat(B, 1, 1, 1)
        
        # Upscale reanalysis data to match elevation grid
        reanalysis_fine = self.convert_to_fine(tensor=reanalysis_coarse,
                                               lonlat_coarse=lonlat_coarse,
                                               lonlat_fine=lonlat_fine)
        
        # Concatenate reanalysis and elevation
        tensor = torch.cat([reanalysis_fine, 0.*elevation_fine], axis=1)
        
        return tensor

In [8]:
class StandardEnvDecoder(nn.Module):
    
    def __init__(self, lengthscale, out_channels):
        
        super().__init__()
        
        self.conv_in_channels = 26
        self.conv_latent_channels = 8
        self.conv_out_channels = 8
        self.out_channels = out_channels
        
        self.lengthscale = nn.Parameter(torch.tensor([lengthscale, lengthscale]))
        
        self.cnn = StandardEnvUNet(in_channels=self.conv_in_channels,
                                   latent_channels=self.conv_latent_channels,
                                   out_channels=self.conv_out_channels)

        self.l1 = nn.Linear(in_features=self.conv_out_channels+1,
                            out_features=self.out_channels,
                            bias=True)

        self.l2 = nn.Linear(in_features=self.out_channels,
                            out_features=self.out_channels,
                            bias=True)

        self.l3 = nn.Linear(in_features=self.out_channels,
                            out_features=self.out_channels,
                            bias=True)
        
        self.activation = nn.ReLU()
        

    def forward(self, tensor, elevation_fine, lonlat_fine, lonlat_target):
        """
        
        Arguments:
        
            tensor        : torch.tensor, shape (B, C1, K, L)
            lonlat_fine   : torch.tensor, shape (K, L, 2)
            lonlat_target : torch.tensor, shape (N, 2)
            
        Returns:
            tensor        : torch.tensor, shape (B, C2, N)
        """
        
        B = tensor.shape[0]
        elevation_fine = elevation_fine[None, None, :, :].repeat(B, 1, 1, 1)
        
        tensor = self.cnn(tensor)
        
        # Compute differences between grid locations
        diff = lonlat_fine[:, :, None, :] - \
               lonlat_target[None, None, :, :]
        
        # Compute weight matrix
        quad = -0.5 * (diff / self.lengthscale[None, None, None, :]) ** 2
        
        quad = torch.sum(quad, axis=-1)
        exp = torch.exp(quad)
        exp = exp / torch.sum(exp, dim=[0, 1])[None, None, :]
        
        # Compute refined tensor
        tensor = torch.cat([tensor, elevation_fine], dim=1)
        tensor = torch.einsum('bckl, kln -> bcn', tensor, exp)
        
        tensor = torch.permute(tensor, (0, 2, 1))
        tensor = self.l1(tensor)
        tensor = self.activation(tensor)
        tensor = self.l2(tensor) + tensor
        tensor = self.activation(tensor)
        tensor = self.l3(tensor) + tensor
        
        return tensor

In [9]:
class EnvUpscaleConvGNP(GaussianNeuralProcess):
    
    def __init__(self, covariance, add_noise):
        
        self.output_dim = 1
        self.encoder_lengthscale = 1.5
        self.decoder_lengthscale = 0.05
        
        # Construct the convolutional decoder
        decoder_out_channels = self.output_dim          + \
                               covariance.num_basis_dim + \
                               covariance.extra_cov_dim + \
                               add_noise.extra_noise_dim
        
        encoder = StandardEnvUpscaleEncoder(lengthscale=self.encoder_lengthscale)
        
        decoder = StandardEnvDecoder(lengthscale=self.decoder_lengthscale,
                                     out_channels=decoder_out_channels)

        super().__init__(encoder=encoder,
                         decoder=decoder,
                         covariance=covariance,
                         add_noise=add_noise)
        
        
    def forward(self, batch):
        
        # Pass through encoder
        tensor = self.encoder(lonlat_fine=batch['lonlat_fine'],
                              elevation_fine=batch['elevation_fine'],
                              lonlat_coarse=batch['lonlat_coarse'],
                              reanalysis_coarse=batch['reanalysis_coarse'])
        
        # Pass through decoder
        tensor = self.decoder(tensor,
                              elevation_fine=batch['elevation_fine'],
                              lonlat_fine=batch['lonlat_fine'],
                              lonlat_target=batch['lonlat_station'])
        
        # Produce mean
        mean = tensor[..., 0:1]
        
        # Produce cov
        embedding = tensor[..., 1:]
        cov = self.covariance(embedding)
        cov_plus_noise = self.add_noise(cov, embedding)
        
        return mean, cov, cov_plus_noise

    
    def loss(self, batch):

        y_mean, _, y_cov = self.forward(batch)

        y_mean = y_mean.double()
        y_cov = y_cov.double()
        y_target = batch['temperature_station'].double()

        jitter = 1e-3 * torch.eye(y_cov.shape[-1], device=y_cov.device).double()
        y_cov = y_cov + jitter[None, :, :]
        
        mae = torch.mean(torch.abs(y_mean[:, :, 0] - y_target))
        
        dist = MultivariateNormal(loc=y_mean[:, :, 0],
                                  covariance_matrix=y_cov)
        nll = - torch.mean(dist.log_prob(y_target.double()))

        return nll.float(), mae.float()

In [10]:
torch.manual_seed(0)
np.random.seed(0)

In [11]:
iterations_per_epoch = 5000
batch_size = 8
max_num_target = 8

test_batch_size = 4
test_max_num_target = 24

train_data = Dataloader(lonlat_fine=lonlat_fine,
                        lonlat_coarse=lonlat_coarse,
                        elevation_fine=elevation_fine,
                        reanalysis_coarse=train_reanalysis_coarse,
                        lonlat_station=train_lonlat_station,
                        temperature_station=train_temperature_station,
                        iterations_per_epoch=iterations_per_epoch,
                        max_num_target=max_num_target,
                        batch_size=batch_size,
                        device=device)

valid_data = Dataloader(lonlat_fine=lonlat_fine,
                        lonlat_coarse=lonlat_coarse,
                        elevation_fine=elevation_fine,
                        reanalysis_coarse=valid_reanalysis_coarse,
                        lonlat_station=valid_lonlat_station,
                        temperature_station=valid_temperature_station,
                        iterations_per_epoch=iterations_per_epoch,
                        max_num_target=max_num_target,
                        batch_size=batch_size,
                        device=device)

test_data = Dataloader(lonlat_fine=lonlat_fine,
                       lonlat_coarse=lonlat_coarse,
                       elevation_fine=elevation_fine,
                       reanalysis_coarse=test_reanalysis_coarse,
                       lonlat_station=test_lonlat_station,
                       temperature_station=test_temperature_station,
                       iterations_per_epoch=iterations_per_epoch,
                       max_num_target=test_max_num_target,
                       batch_size=test_batch_size,
                       device=device)

In [None]:
num_basis_dim = 128

# covariance = KvvCov(num_basis_dim)
# add_noise = AddHomoNoise()

covariance = MeanFieldCov(num_basis_dim=1)
add_noise = AddNoNoise()

model = EnvUpscaleConvGNP(covariance=covariance,
                          add_noise=add_noise)
model.to(device)

optimizer = torch.optim.Adam(lr=5e-4, params=model.parameters()) #, weight_decay=1e-4)

mae_scale = train_temperature_station_std

for i, (train_batch, valid_batch, test_batch) in enumerate(zip(train_data, valid_data, test_data)):
    
    optimizer.zero_grad()
    
    loss, mae = model.loss(train_batch)
    loss.backward()
    optimizer.step()
    
    if i % 20 == 0:
    
        with torch.no_grad():
            valid_loss, valid_mae = model.loss(valid_batch)
            test_loss, test_mae = model.loss(test_batch)
            
        encoder_scale = model.encoder.lengthscale.detach().cpu().numpy()
        decoder_scale = model.decoder.lengthscale.detach().cpu().numpy()
            
        print(f'Loss: {loss:4.2f}, {valid_loss:4.2f}, {test_loss:4.2f} '
              f'MAE: {mae*mae_scale:4.2f}, {valid_mae*mae_scale:4.2f}, {test_mae*mae_scale:4.2f} '
              f'Encoder scale: {encoder_scale[0]:.2f}, {encoder_scale[1]:.2f} '
              f'Decoder scale: {decoder_scale[0]:.2f}, {decoder_scale[1]:.2f} ')

Loss: 16.59, 24.00, 23.01 MAE: 10.27, 18.75, 13.81 Encoder scale: 1.50, 1.50 Decoder scale: 0.05, 0.05 
Loss: 18.44, 18.48, 20.52 MAE: 9.42, 9.01, 10.40 Encoder scale: 1.51, 1.50 Decoder scale: 0.05, 0.05 
Loss: 14.25, 21.66, 20.15 MAE: 7.70, 12.30, 8.87 Encoder scale: 1.51, 1.49 Decoder scale: 0.05, 0.05 
Loss: 13.43, 22.02, 11.88 MAE: 7.34, 11.66, 4.78 Encoder scale: 1.51, 1.49 Decoder scale: 0.05, 0.05 
Loss: 14.47, 14.58, 18.79 MAE: 7.06, 8.26, 9.38 Encoder scale: 1.51, 1.49 Decoder scale: 0.05, 0.05 
Loss: 14.60, 13.29, 11.75 MAE: 7.14, 6.98, 6.02 Encoder scale: 1.51, 1.49 Decoder scale: 0.05, 0.05 
Loss: 12.08, 10.70, 23.87 MAE: 5.17, 4.38, 12.05 Encoder scale: 1.51, 1.49 Decoder scale: 0.05, 0.05 
Loss: 16.50, 13.46, 17.47 MAE: 9.57, 5.99, 8.84 Encoder scale: 1.51, 1.49 Decoder scale: 0.05, 0.05 
Loss: 12.34, 12.76, 15.39 MAE: 6.11, 6.04, 7.82 Encoder scale: 1.51, 1.49 Decoder scale: 0.05, 0.05 
Loss: 19.21, 14.30, 14.12 MAE: 11.68, 6.71, 7.33 Encoder scale: 1.51, 1.49 Decoder s

In [None]:
# plt.figure(figsize=(16, 20))
# for i in range(80):
    
#     plt.subplot(10, 8, i+1)
#     temp = train_temperature_station[:, i].numpy()
#     temp = temp[~np.isnan(temp)]
    
#     plt.hist(temp)
    
# plt.tight_layout()
# plt.show()

In [None]:
plt.scatter(train_lonlat_station[:, 0],
            train_lonlat_station[:, 1],
            zorder=2)

plt.scatter(train_lonlat_station[:, 0],
            train_lonlat_station[:, 1],
            zorder=2)

plt.scatter(valid_lonlat_station[:, 0],
            valid_lonlat_station[:, 1],
            zorder=3)

plt.contourf(lonlat_fine[:, :, 0],
             lonlat_fine[:, :, 1],
             elevation_fine, origin='lower',
             alpha=0.5,
             cmap='coolwarm',
             zorder=1)

plt.show()