In [4]:
# Enable IPython's autoreload extension so that edits to imported local
# modules are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all imported modules before executing each cell.
%autoreload 2

In [63]:
import mlflow
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [2]:
# Load the trained PyTorch model from an MLflow model directory (path is
# relative to the notebook's working directory). The plotting cell further
# down reads `model.gprs`, indexing it by day -- presumably one fitted GPR
# per date; confirm against the training code.
model = mlflow.pytorch.load_model('parameters/20221217/best_ad_mean')

In [64]:
def sample(gpr, noise, position, jitter=1e-7):
    """Draw one sample from the predictive distribution of ``gpr``.

    The sample is built as ``mean + L @ noise`` where ``L`` is the Cholesky
    factor of the (jittered) predictive covariance at the query points.

    Parameters
    ----------
    gpr : object
        Regressor exposing ``predict(X, return_cov=True)`` that returns
        ``(mean, covariance)``. Presumably a fitted scikit-learn
        ``GaussianProcessRegressor`` -- TODO confirm.
    noise : array-like
        Noise vector with one entry per query point; it is multiplied by the
        Cholesky factor. Assumed to be standard-normal draws -- the caller in
        this notebook uses ``torch.randn``.
    position : array or tensor
        Flat container holding the two coordinate vectors concatenated one
        after the other; it is reshaped to ``(n_points, 2)`` before being
        passed to ``gpr.predict``.
    jitter : float, default 1e-7
        Value added to the covariance diagonal before factorisation. Raise it
        if ``np.linalg.cholesky`` fails with ``LinAlgError``.

    Returns
    -------
    torch.Tensor
        The sample with a leading axis of length 1 (i.e. ``(1, n_points)``
        when the predicted mean is one-dimensional).

    Raises
    ------
    numpy.linalg.LinAlgError
        If the jittered covariance is still not positive definite.
    """
    query = position.reshape(2, -1).T
    y_mean, y_cov = gpr.predict(query, return_cov=True)

    # The covariance often ends up numerically singular, so nudge the
    # diagonal. Done out of place so that the array handed back by
    # ``gpr.predict`` is never modified.
    y_cov = y_cov + jitter * np.eye(y_cov.shape[0])
    chol = np.linalg.cholesky(y_cov)

    draw = y_mean + np.dot(chol, noise)

    # Keep the leading length-1 axis that callers of this function rely on.
    return torch.tensor(np.array([draw]))


In [77]:
def plot_gpr_contourplot(gpr, gridsize=48):
    """Plot predictive mean, predictive std and one random sample of a GPR.

    Three filled-contour panels titled 'mean', 'std' and 'sample' are drawn
    on a regular ``gridsize`` x ``gridsize`` grid spanning [-16, 16] on both
    axes, each with its own colorbar. The training inputs of ``gpr`` are
    overlaid on the 'sample' panel.

    Parameters
    ----------
    gpr : object
        Regressor exposing ``predict(X, return_std=True)``,
        ``predict(X, return_cov=True)`` (used by ``sample``), ``X_train_``
        and ``y_train_``. Presumably a fitted scikit-learn
        ``GaussianProcessRegressor`` -- TODO confirm. Rows of ``X`` are
        treated as ``(lat, lon)`` pairs.
    gridsize : int, default 48
        Number of grid points per axis. The sample panel needs the full
        ``gridsize**2`` x ``gridsize**2`` covariance matrix, so cost grows
        quickly with this value.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the three panels.
    """
    fig, axes = plt.subplots(figsize=(15, 5), ncols=3)
    fig.tight_layout(pad=5.0)

    lon = torch.linspace(-16, 16, gridsize)
    lat = torch.linspace(-16, 16, gridsize)
    # 'ij' is the layout torch.meshgrid uses when `indexing` is omitted;
    # spelling it out keeps the result identical and silences the
    # deprecation warning (the keyword needs torch >= 1.10).
    LON, LAT = torch.meshgrid(lon, lat, indexing='ij')
    # Flat layout expected by `sample`: all latitudes, then all longitudes.
    posgrid = torch.cat([LAT.reshape(-1), LON.reshape(-1)])
    noise = torch.randn((gridsize ** 2,))

    mean, std = gpr.predict(posgrid.reshape(2, -1).T, return_std=True)
    mean = mean.reshape(gridsize, gridsize)
    std = std.reshape(gridsize, gridsize)
    samples = sample(gpr, noise, posgrid).reshape(gridsize, gridsize)

    # One filled-contour panel plus a matching colorbar per quantity.
    panels = (('mean', mean), ('std', std), ('sample', samples))
    for axis, (title, values) in zip(axes, panels):
        contours = axis.contourf(LON, LAT, values, 200)
        axis.set_title(title)

        divider = make_axes_locatable(axis)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(contours, cax=cax, orientation='vertical')

    # Overlay the training points on the sample panel.
    # NOTE(review): the markers are coloured on the value range of the *mean*
    # panel although they are drawn on the *sample* panel, so their colours
    # need not match that panel's colorbar -- confirm this is intended.
    cmap = plt.get_cmap('viridis')
    norm = matplotlib.colors.Normalize(vmin=mean.min(), vmax=mean.max())

    train_X = np.asarray(gpr.X_train_)
    if train_X.size:
        # Unpacking enforces exactly two input columns, ordered (lat, lon).
        train_lat, train_lon = train_X.T
        axes[2].scatter(
            train_lon,
            train_lat,
            c=np.ravel(gpr.y_train_),
            cmap=cmap,
            norm=norm,
            s=100,
            edgecolors='white',
            linewidths=1,
        )

    return fig


In [None]:
date = '2008-01-01'
# Map the calendar date to its position in the daily range covered by
# `model.gprs`. `get_loc` raises KeyError for a date outside the range,
# whereas `np.argmax` over an all-False mask would silently return 0 and
# plot the GPR of the first day instead.
idx = pd.date_range('2008-01-01', '2016-12-31').get_loc(pd.Timestamp(date))
fig = plot_gpr_contourplot(model.gprs[idx], 48)
# The target directory must already exist; savefig does not create it.
fig.savefig(f'figures/gpr/{date}.png')