In [1]:
import os
import gc
import math

# numpy
import numpy as np
import numpy.ma as ma

from functools import partial
from jax import vmap, grad, jit, random
from jax.config import config
# from jax.experimental import optimizers
import jax.numpy as jnp
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves

config.update("jax_enable_x64", True)

from skimage.restoration import estimate_sigma
from skimage.transform import resize

# netCDF
import netCDF4 as nc

from tqdm.notebook import tqdm
from tqdm.notebook import trange

import pickle

In [2]:
### model init

def exp_params(key):
    """Draw one random initial parameter set for the GP kernels.

    Three independent uniform draws are made, one per subkey of ``key``:
    entry 0 from [3.5, 4.5), entry 1 from [8, 12) and entry 2 from [1, 5).
    The kernels in this notebook read entry 0 as the multiplicative amplitude,
    entry 1 as the length scale and entry 2 as the value added to the
    diagonal of the training covariance.

    Returns a single-element list holding a 3-tuple of scalar arrays, so the
    values are addressed as ``params[0][j]``.
    """
    bounds = ((3.5, 4.5), (8, 12), (1, 5))
    subkeys = random.split(key, 3)

    draws = [random.uniform(subkey, minval = lo, maxval = hi)
             for subkey, (lo, hi) in zip(subkeys, bounds)]

    # Stack first so the tuple entries are scalars sliced from one array.
    return [tuple(jnp.array(draws))]

@jit
def random_init(params):
    """Forward ``params`` to ``opt_init`` and return whatever it returns.

    NOTE(review): ``opt_init`` is not defined or imported anywhere in the
    visible notebook (the ``jax.experimental.optimizers`` import is commented
    out), so calling this presumably raises ``NameError`` -- confirm or remove.
    """
    return opt_init(params)
# Batched variant: vmap maps random_init over the leading axis of its argument.
parallel_init = vmap(random_init)

def param_abs(x):
    """Return the element-wise absolute value of ``x``.

    Applied to every parameter leaf after a gradient step so the parameters
    stay non-negative.
    """
    return jnp.abs(x)
param_abs = jit(param_abs)

def param_select(param_list, i):
    """Pick entry ``i`` from every leaf of the pytree ``param_list``.

    The returned pytree has the same structure as ``param_list``, with each
    leaf replaced by ``leaf[i]``.
    """
    def _take(leaf):
        return leaf[i]

    return tree_map(_take, param_list)
param_select = jit(param_select)

In [3]:
### model specification

## exponential-kernel GP functions
@jit
def exp_kernel(params, x, y):
    """Exponential covariance between two input vectors.

    Evaluates ``params[0][0] * exp(-d / params[0][1])`` where ``d`` is the
    square root of the *mean* (not the sum) of the squared entry-wise
    differences between ``x`` and ``y``.
    """
    diff = jnp.array(x) - jnp.array(y)
    rms_dist = jnp.sqrt(jnp.mean(diff**2))

    amplitude, length_scale = params[0][0], params[0][1]
    return amplitude * jnp.exp(-rms_dist / length_scale)


# Pairwise covariance matrix: entry [i, j] is exp_kernel(params, a[i], b[j]).
exp_cov = jit(vmap(vmap(exp_kernel, in_axes = (None, None, 0)),
                   in_axes = (None, 0, None)))
# Row-matched kernel values: entry [i] is exp_kernel(params, a[i], b[i]).
exp_var = jit(vmap(exp_kernel, in_axes = (None, 0, 0)))

@jit
def exp_predict(params, xtest, xtrain, ytrain):
    """GP posterior mean at ``xtest`` under the exponential kernel.

    Builds the training covariance with ``exp_cov``, adds ``params[0][2]`` to
    its diagonal, inverts it, and projects ``ytrain`` through the
    test/train cross-covariance. ``ytrain`` must be two-dimensional; the
    result has one row per test point and one column per ``ytrain`` column.
    """
    n_obs, _ = ytrain.shape

    noise_diag = jnp.diag(jnp.repeat(params[0][2], n_obs))
    train_cov = exp_cov(params, xtrain, xtrain) + noise_diag
    train_prec = jnp.linalg.inv(train_cov)

    cross_cov = exp_cov(params, xtest, xtrain)

    weights = cross_cov @ train_prec
    return weights @ ytrain


@jit
def exp_dist(params, xtest, xtrain, ytrain):
    """GP posterior mean and per-point variance under the exponential kernel.

    Returns ``(mu, sig)``. ``mu`` is the posterior mean (test rows by
    ``ytrain`` columns). ``sig`` has one value per test row: the prior kernel
    value ``k(x, x)`` minus the row-wise sum of ``proj * k_xD``, i.e. the
    diagonal of the posterior covariance. The diagonal term ``params[0][2]``
    is included in the training covariance but is not added to ``sig``.
    """
    n_obs, _ = ytrain.shape

    noise_diag = jnp.diag(jnp.repeat(params[0][2], n_obs))
    train_cov = exp_cov(params, xtrain, xtrain) + noise_diag
    train_prec = jnp.linalg.inv(train_cov)

    cross_cov = exp_cov(params, xtest, xtrain)
    weights = cross_cov @ train_prec

    mu = weights @ ytrain

    prior_var = exp_var(params, xtest, xtest)
    explained = jnp.sum(weights * cross_cov, axis = 1)
    sig = prior_var - explained

    return mu, sig

In [4]:
### model specification

## squared-exponential-kernel GP functions
@jit
def se_kernel(params, x, y):
    """Squared-exponential covariance between two input vectors.

    Evaluates ``params[0][0] * exp(-m / params[0][1]**2)`` where ``m`` is the
    *mean* (not the sum) of the squared entry-wise differences between ``x``
    and ``y``.
    """
    diff = jnp.array(x) - jnp.array(y)
    mean_sq_dist = jnp.mean(diff**2)

    amplitude, length_scale = params[0][0], params[0][1]
    return amplitude * jnp.exp(-mean_sq_dist / length_scale**2)


# Pairwise covariance matrix: entry [i, j] is se_kernel(params, a[i], b[j]).
se_cov = jit(vmap(vmap(se_kernel, in_axes = (None, None, 0)),
                  in_axes = (None, 0, None)))
# Row-matched kernel values: entry [i] is se_kernel(params, a[i], b[i]).
se_var = jit(vmap(se_kernel, in_axes = (None, 0, 0)))


@jit
def se_predict(params, xtest, xtrain, ytrain):
    """GP posterior mean at ``xtest`` under the squared-exponential kernel.

    Builds the training covariance with ``se_cov``, adds ``params[0][2]`` to
    its diagonal, inverts it, and projects ``ytrain`` through the
    test/train cross-covariance. ``ytrain`` must be two-dimensional; the
    result has one row per test point and one column per ``ytrain`` column.
    """
    n_obs, _ = ytrain.shape

    noise_diag = jnp.diag(jnp.repeat(params[0][2], n_obs))
    train_cov = se_cov(params, xtrain, xtrain) + noise_diag
    train_prec = jnp.linalg.inv(train_cov)

    cross_cov = se_cov(params, xtest, xtrain)

    weights = cross_cov @ train_prec
    return weights @ ytrain


@jit
def se_dist(params, xtest, xtrain, ytrain):
    """GP posterior mean and per-point variance, squared-exponential kernel.

    Returns ``(mu, sig)``. ``mu`` is the posterior mean (test rows by
    ``ytrain`` columns). ``sig`` has one value per test row: the prior kernel
    value ``k(x, x)`` minus the row-wise sum of ``proj * k_xD``, i.e. the
    diagonal of the posterior covariance. The diagonal term ``params[0][2]``
    is included in the training covariance but is not added to ``sig``.
    """
    n_obs, _ = ytrain.shape

    noise_diag = jnp.diag(jnp.repeat(params[0][2], n_obs))
    train_cov = se_cov(params, xtrain, xtrain) + noise_diag
    train_prec = jnp.linalg.inv(train_cov)

    cross_cov = se_cov(params, xtest, xtrain)
    weights = cross_cov @ train_prec

    mu = weights @ ytrain

    prior_var = se_var(params, xtest, xtest)
    explained = jnp.sum(weights * cross_cov, axis = 1)
    sig = prior_var - explained

    return mu, sig

In [5]:
#### loss functions

def compute_lr(pgrad, scale = 2):
    """Per-parameter step sizes scaled to each gradient's order of magnitude.

    For every leaf ``g`` of ``pgrad`` the step size is
    ``10 ** (-floor(log10(|g|)) - scale)``, so ``lr * g`` has magnitude
    between ``10 ** -scale`` and ``10 ** (1 - scale)`` regardless of how large
    or small the gradient is.

    Parameters
    ----------
    pgrad : pytree of scalar gradients (stacked into one array internally).
    scale : number of decades by which the normalised step is shrunk.

    Returns
    -------
    Array with one step size per leaf, in ``tree_leaves`` order. A gradient
    that is exactly zero gets step size 0. Without that guard ``log10(0)``
    yields ``-inf``, the step size becomes ``+inf`` and the caller's update
    ``a - lr * g`` evaluates ``inf * 0 = nan``.
    """
    magnitudes = jnp.abs(jnp.array(tree_leaves(pgrad)))
    is_zero = magnitudes == 0

    # Swap zeros for 1 so log10 stays finite; those entries are masked below.
    safe_magnitudes = jnp.where(is_zero, 1.0, magnitudes)
    lrs = 10**(-jnp.floor(jnp.log10(safe_magnitudes)) - scale)

    return jnp.where(is_zero, 0.0, lrs)

def weighted_mse(k_inv, ytrain):
    """Quadratic form ``y^T K^{-1} y`` for a single response vector."""
    left = ytrain.T @ k_inv
    return left @ ytrain
# Map over the columns of ytrain (axis 1) with a shared k_inv, giving one
# quadratic form per column.
weighted_mse = jit(vmap(weighted_mse, in_axes = (None, 1)))


@jit
def nll_loss_exp(params, xtrain, ytrain):
    """Negative log-likelihood-style loss for the exponential-kernel GP.

    With ``K = exp_cov(xtrain, xtrain) + params[0][2] * I`` and ``p`` the
    number of ``ytrain`` columns, returns

        (sum_j y_j^T K^{-1} y_j + p * log|det K| + p * log(2 * pi)) / (2 * p)

    where ``y_j`` are the columns of ``ytrain``.
    """
    n_obs, n_out = ytrain.shape

    noise_diag = jnp.diag(jnp.repeat(params[0][2], n_obs))
    cov = exp_cov(params, xtrain, xtrain) + noise_diag
    cov_inv = jnp.linalg.inv(cov)

    quad_term = jnp.sum(weighted_mse(cov_inv, ytrain))
    _, logabsdet = jnp.linalg.slogdet(cov)
    logdet_term = n_out * logabsdet
    const_term = n_out * jnp.log(2 * math.pi)

    return (quad_term + logdet_term + const_term) / (2 * n_out)


# Gradient of the loss with respect to its first argument (params).
grad_loss_exp = jit(grad(nll_loss_exp))

@jit
def nll_loss_se(params, xtrain, ytrain):
    """Negative log-likelihood-style loss for the squared-exponential GP.

    With ``K = se_cov(xtrain, xtrain) + params[0][2] * I`` and ``p`` the
    number of ``ytrain`` columns, returns

        (sum_j y_j^T K^{-1} y_j + p * log|det K| + p * log(2 * pi)) / (2 * p)

    where ``y_j`` are the columns of ``ytrain``.
    """
    n_obs, n_out = ytrain.shape

    noise_diag = jnp.diag(jnp.repeat(params[0][2], n_obs))
    cov = se_cov(params, xtrain, xtrain) + noise_diag
    cov_inv = jnp.linalg.inv(cov)

    quad_term = jnp.sum(weighted_mse(cov_inv, ytrain))
    _, logabsdet = jnp.linalg.slogdet(cov)
    logdet_term = n_out * logabsdet
    const_term = n_out * jnp.log(2 * math.pi)

    return (quad_term + logdet_term + const_term) / (2 * n_out)


# Gradient of the loss with respect to its first argument (params).
grad_loss_se = jit(grad(nll_loss_se))

@jit
def gradient_step_exp(params, xtrain, ytrain):
    """One magnitude-normalised gradient step for the exponential-kernel GP.

    Each parameter leaf is moved against its gradient of ``nll_loss_exp``
    using the per-leaf step size from ``compute_lr(..., 2)``. The updated
    values are then passed through ``param_abs`` and returned in the same
    pytree structure as ``params``.
    """
    leaves, treedef = tree_flatten(params)
    grads, _ = tree_flatten(grad_loss_exp(params, xtrain, ytrain))

    step_sizes = compute_lr(grads, 2)

    updated = []
    for value, g, lr in zip(leaves, grads, step_sizes):
        updated.append(value - lr * g)

    return tree_map(param_abs, tree_unflatten(treedef, updated))


@jit
def gradient_step_sqe(params, xtrain, ytrain):
    """One magnitude-normalised gradient step for the squared-exponential GP.

    Each parameter leaf is moved against its gradient of ``nll_loss_se``
    using the per-leaf step size from ``compute_lr(..., 2)``. The updated
    values are then passed through ``param_abs`` and returned in the same
    pytree structure as ``params``.
    """
    leaves, treedef = tree_flatten(params)
    grads, _ = tree_flatten(grad_loss_se(params, xtrain, ytrain))

    step_sizes = compute_lr(grads, 2)

    updated = []
    for value, g, lr in zip(leaves, grads, step_sizes):
        updated.append(value - lr * g)

    return tree_map(param_abs, tree_unflatten(treedef, updated))

In [6]:
# Load the per-model arrays. Context managers make sure both file handles are
# closed once the data has been read.
# NOTE: pickle can execute arbitrary code while loading; only open files from
# a trusted source.
with open('../submit/data/saved/xhist_pr.pkl', 'rb') as fh:
    xhist = pickle.load(fh)
with open('../submit/data/saved/xrcp_pr.pkl', 'rb') as fh:
    xrcp = pickle.load(fh)

# Number of leading xrcp time steps that the experiment cell appends to the
# training set; the remaining steps are predicted.
nval = 72
# Total number of time steps in each xrcp entry (taken from the first one).
ntest = len(xrcp[0])
# ntest = 72 + 336
# Number of models (entries in xhist).
nmod = len(xhist)

In [7]:
### run experiments
# Leave-one-model-out loop: each model m1 in turn is the regression target and
# every other model is used as a predictor. For each target, an exponential
# and a squared-exponential GP are fitted and their predictions are saved.
nmod = len(xhist)
ntrain = xhist[0].shape[0]
ntest = xrcp[0].shape[0]


for m1 in trange(nmod):

    # xhist[m1] is three-dimensional; the trailing two axes are presumably
    # latitude and longitude -- TODO confirm.
    _, nlat, nlon = xhist[m1].shape

    #### construct training set
    # For every other model: flatten each time step to a row, then stack the
    # ntrain historical rows on top of the first nval rows of xrcp.
    xtrain = []
    for m2 in trange(nmod, leave = False):
        if m1 != m2:
            x1 = xhist[m2].reshape(ntrain, -1)
            x2 = xrcp[m2][0:nval].reshape(nval, -1)

            xtrain.append(np.vstack([x1, x2]))

    # Per-row mean of each predictor model (one series per model), taken
    # before the models are concatenated column-wise into a single matrix.
    xmean = np.array([np.mean(f, axis = 1) for f in xtrain])
    xtrain = jnp.hstack(xtrain)

    # Target model, flattened and stacked the same way as the predictors.
    y1 = xhist[m1].reshape(ntrain, -1)
    y2 = xrcp[m1][0:nval].reshape(nval, -1)

    ytrain = jnp.array(np.vstack([y1, y2]))
    ymean = jnp.mean(ytrain, axis = 1)

    # Least-squares fit (normal equations, no intercept column) of the target's
    # per-row mean on the predictors' per-row means; the fitted values are
    # subtracted from every column of ytrain so the GP models the residual.
    x = jnp.vstack([xmean]).T
    beta = jnp.linalg.inv(x.T @ x) @ x.T @ ymean
    ytrain = ytrain - (x @ beta)[:,None]


    #### test
    # Predictor rows for the remaining xrcp time steps (nval .. ntest).
    xtest = []
    for m2 in range(nmod):
        if m1 != m2:
            xtest.append(xrcp[m2][nval:ntest].reshape(ntest-nval, -1))

    ## center and join data
    # From here on xmean and x describe the test period; together with the
    # beta fitted above they give the mean component added back to predictions.
    xmean = np.array([np.mean(f, axis = 1) for f in xtest])
    xtest = jnp.hstack(xtest)
    # NOTE(review): ytest is assigned but not used anywhere in this cell.
    ytest = np.array(xrcp[m1][nval:ntest])
    x = jnp.vstack([xmean]).T


    ## randomize ensemble init
    # The key is fixed and shared, so both kernels start from the same
    # parameter values and every m1 iteration uses the same initialisation.
    key = random.PRNGKey(1023)
    params_exp = exp_params(key)
    params_sqe = exp_params(key)

    ## fit ensemble
    # 300 gradient steps for each kernel on the same training data.
    for _ in trange(300):
        params_exp = gradient_step_exp(params_exp, xtrain, ytrain)
        params_sqe = gradient_step_sqe(params_sqe, xtrain, ytrain)

    # NOTE(review): scale is only referenced by the commented-out line below.
    scale = np.cos(np.linspace(math.pi/2, -math.pi/2, nlat))

    # Posterior mean plus the regression mean, reshaped back to
    # (time, nlat, nlon).
    exp_hat, exp_predvar = exp_dist(params_exp, xtest, xtrain, ytrain)
    exp_hat += (x @ beta)[:,None]
    exp_hat = exp_hat.reshape(-1, nlat, nlon)
    # exp_hat *= scale[:,]
    # Despite the name this is a standard deviation: the square root of the
    # posterior variance plus the fitted diagonal term params[0][2].
    exp_predvar = np.sqrt(exp_predvar + params_exp[0][2])

    sqe_hat, sqe_predvar = se_dist(params_sqe, xtest, xtrain, ytrain)
    sqe_hat += (x @ beta)[:,None]
    sqe_hat = sqe_hat.reshape(-1, nlat, nlon)
    # Also a standard deviation (see the exponential-kernel case above).
    sqe_predvar = np.sqrt(sqe_predvar + params_sqe[0][2])

    # NOTE(review): np.save writes the .npy format and appends '.npy' when the
    # name does not already end with it, so these files end up as '*.npz.npy'.
    np.save(f'../submit/experiments/pred/gpr_exp_pr_{m1}.npz', exp_hat)
    np.save(f'../submit/experiments/pred/gpr_exp_pr_var_{m1}.npz', exp_predvar)

    np.save(f'../submit/experiments/pred/gpr_se_pr_{m1}.npz', sqe_hat)
    np.save(f'../submit/experiments/pred/gpr_se_pr_var_{m1}.npz', sqe_predvar)

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]