# Model training 

In [1]:
import os
import gc
import math

# numpy
import numpy as np
import scipy

# from jax
import jax
from jax import jacfwd, jacrev
from jax import vmap, grad, jit, random
from jax.config import config
import jax.numpy as jnp
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves

import haiku as hk
import flax
from flax import linen as nn
import optax

import properscoring as ps

config.update("jax_enable_x64", True)

# plotting
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import ticker
from matplotlib.ticker import FormatStrFormatter
plt.style.use('default')

from matplotlib.gridspec import GridSpec
from matplotlib import colors
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap

from tqdm.notebook import tqdm
from tqdm.notebook import trange

# netCDF
import netCDF4 as nc
import pickle
import warnings
warnings.filterwarnings('ignore')

from sklearn.decomposition import PCA

In [2]:
def mse_weighted(x, y):
    nlat = x.shape[1]
    weight = np.cos(np.linspace(math.pi/2 - 1/nlat, -math.pi/2 + 1/nlat, nlat))
    weight /= np.mean(weight)
    weight = weight[None,:,None]
    
    return np.sqrt(np.mean((x - y)**2 * weight, axis = (1, 2)))

def resize_model(x, dim1, dim2): 
    x = np.moveaxis(x, 0, 2)
    x = resize(x, (dim1, dim2))
    x = np.moveaxis(x, 2, 0)
    return x

In [3]:
def scale_and_split(xtrain, xtest, model_no, nval = 200, reshape = False):
    ytrain = xtrain[:,model_no]
    ytest = xtest[:,model_no]
    xtrain = np.delete(xtrain, model_no, axis=1)
    xtest = np.delete(xtest, model_no, axis=1)

    ## dimensions
    ntrain = xtrain.shape[0]
    ntest = xtest.shape[0]
    nlat = xtrain.shape[2]
    nlon = xtrain.shape[3]

    ## rescale
    xtrain_mean = np.mean(xtrain, axis = 1)
    xtest_mean = np.mean(xtest, axis = 1)
    xtrain_sd = np.std(xtrain, axis = 1)
    xtest_sd = np.std(xtest, axis = 1)

    xtrain_scaled = (xtrain - xtrain_mean[:,None]) / xtrain_sd[:,None]
    xtest_scaled = (xtest - xtest_mean[:,None]) / xtest_sd[:,None]
    ytrain_scaled = (ytrain - xtrain_mean) / xtrain_sd
    ytest_scaled = (ytest - xtest_mean) / xtest_sd

    ## convert
    if reshape:
        xtrain_scaled = xtrain_scaled.reshape(ntrain, -1)
        xtest_scaled = xtest_scaled.reshape(ntest, -1)
        ytrain_scaled = ytrain_scaled.reshape(ntrain, -1)
        ytest_scaled = ytest_scaled.reshape(ntest, -1)

    xtrain_scaled = jnp.array(xtrain_scaled)
    xtest_scaled = jnp.array(xtest_scaled)
    ytrain_scaled = jnp.array(ytrain_scaled)
    ytest_scaled = jnp.array(ytest_scaled)

    xval_scaled = xtrain_scaled[-nval:]
    yval_scaled = ytrain_scaled[-nval:]
    xtrain_scaled = xtrain_scaled[:-nval]
    ytrain_scaled = ytrain_scaled[:-nval]
    
    return [xtrain_scaled, xval_scaled, xtest_scaled,
            ytrain_scaled, yval_scaled, ytest_scaled,
            xtrain_mean, xtest_mean, xtrain_sd, xtest_sd]

def anomalize(xtrain, xtest):
    xtrain_anom = np.zeros_like(xtrain)
    xtest_anom = np.zeros_like(xtest)

    ntrain = xtrain.shape[0]
    ntest = xtest.shape[0]

    monthly_means = np.array([np.mean(xtrain[(12*20):(12*50)][i::12], axis = 0) for i in range(12)])
    for k in trange(ntrain, leave = False):
        xtrain_anom[k] = xtrain[k] - monthly_means[k % 12]

    for k in trange(ntest, leave = False):
        xtest_anom[k] = xtest[k] - monthly_means[k % 12]
    
    return np.array(xtrain_anom), np.array(xtest_anom)

## model definitions

In [4]:
## regression
def ols(x, y):
    return jnp.matmul(jnp.matmul(jnp.linalg.inv(jnp.matmul(x.T, x)), x.T), y)

ols = jit(ols)
v_ols = vmap(vmap(ols, (2, 1)), (3, 2))
v_ols = jit(v_ols)

def reg_pred(x, beta):
    return jnp.sum(x * beta[None,:,:,:], axis = 1)

reg_pred = jit(reg_pred)

In [5]:
### delta
def delta_pred(xtest, xtrain, ytrain):
    delta = jnp.mean(xtest - xtrain, axis = 1)
    return jnp.median(ytrain + delta, axis = 0)

delta_pred = jit(vmap(delta_pred, (0, None, None)))

# def delta_var(xtest, xtrain, ytrain):
#     delta = jnp.mean(xtest - xtrain, axis = 3)
#     return jnp.std(ytrain + delta, axis = 0)

# delta_var = jit(delta_var)

In [6]:
### model init

def nngp_params(key, depth = 1):

    subkeys = random.split(key, 3 + 2 * (depth - 1))

    base_layer = jnp.array([random.uniform(subkeys[0], minval = 0.75, maxval = 1.25), 
                            random.uniform(subkeys[1], minval = 0.1, maxval = 0.5),
                            random.uniform(subkeys[2], minval = 0.2, maxval = 0.7)])
    base_layer = [tuple(base_layer)]

    return base_layer
parallel_nngp_params = vmap(nngp_params, in_axes=(0, None))


def random_params(key, depth):
    nngp_par = nngp_params(key, depth)
    trend_par = trend_params(key)
    
    return [nngp_par, trend_par]
parallel_params = vmap(random_params, in_axes=(0, None))

@jit
def random_init(params):
    return opt_init(params)
parallel_init = vmap(random_init)

@jit
def param_abs(x):
    return jnp.abs(x)

@jit
def param_select(param_list, i):
    return tree_map(lambda x: x[i], param_list)


### model specification

depth = 7

## nngp functions
def nngp_kernel(params, x, y):
    x = jnp.array(x)
    y = jnp.array(y)
    
    kxx = params[0][1] + params[0][0] * (jnp.dot(x, x.T) / x.shape[0])
    kyy = params[0][1] + params[0][0] * (jnp.dot(y, y.T) / x.shape[0])
    kxy = params[0][1] + params[0][0] * (jnp.dot(x, y.T) / x.shape[0])
    
    for i in range(depth):
        
        ## kxy
        cor = jnp.clip(kxy / jnp.sqrt(kxx * kyy), -1.0 + 1e-16, 1.0 - 1e-16)
        theta = jnp.arccos(cor)
        trig = jnp.sin(theta) + (math.pi - theta) * jnp.cos(theta)
        kxy = params[0][1] + (params[0][0] / (2 * math.pi)) * jnp.sqrt(kxx * kyy) * trig
                
        ## kxx
        kxx = params[0][1] + (params[0][0] / 2) * kxx
    
        ## kyy
        kyy = params[0][1] + (params[0][0] / 2) * kyy
        
    return kxy

nngp_kernel = jit(nngp_kernel)
nngp_cov = jit(vmap(vmap(nngp_kernel, (None, None, 0)), (None, 0, None)))
nngp_var = jit(vmap(nngp_kernel, (None, 0, 0)))


def nngp_predict(params, xtest, xtrain, ytrain):
    n, p = ytrain.shape
        
    k_DD0 = nngp_cov(params, xtrain, xtrain)
    k_DD0 += jnp.diag(jnp.repeat(params[0][2], n))
    prec0 = jnp.linalg.inv(k_DD0)
    
    k_xD0 = nngp_cov(params, xtest, xtrain)
    
    proj0 = jnp.matmul(k_xD0, prec0)
    return jnp.matmul(proj0, ytrain)
nngp_predict = jit(nngp_predict)


def nngp_dist(params, xtest, xtrain, ytrain):
    n, p = ytrain.shape
        
    k_DD0 = nngp_cov(params, xtrain, xtrain)
    k_DD0 += jnp.diag(jnp.repeat(params[0][2], n))
    prec0 = jnp.linalg.inv(k_DD0)
    
    k_xD0 = nngp_cov(params, xtest, xtrain)
    proj0 = jnp.matmul(k_xD0, prec0)
    
    mu = jnp.matmul(proj0, ytrain)
    
    k_xx = nngp_var(params, xtest, xtest)
    sig = k_xx - jnp.sum(proj0 * k_xD0, axis = 1)
    
    return mu, sig
nngp_dist = jit(nngp_dist)


#### loss functions

def compute_lr(pgrad, scale = 2):
    leaves = jnp.array(tree_leaves(pgrad))
    lrs = 10**(-jnp.floor(jnp.log10(jnp.abs(leaves))) - scale)
    return lrs

def weighted_mse(k_inv, ytrain):
    return jnp.matmul(jnp.matmul(ytrain.T, k_inv), ytrain)
weighted_mse = jit(vmap(weighted_mse, (None, 1)))


def full_nll_loss(params, xtrain, ytrain):
    n, p = ytrain.shape

    k = nngp_cov(params, xtrain, xtrain)
    k += jnp.diag(jnp.repeat(params[0][2], n))
    k_inv = jnp.linalg.inv(k)
    
    mse = jnp.sum(weighted_mse(k_inv, ytrain))
    pen = p * jnp.linalg.slogdet(k)[1]
    nor = p/2 * jnp.log(2 * math.pi)
    return (0.5*mse + 0.5*pen + nor) / p
full_grad_loss = jit(grad(full_nll_loss))


def gradient_step(params, xtrain, ytrain):
        
    param, tdef = tree_flatten(params)
    pgrad = tree_flatten(full_grad_loss(params, xtrain, ytrain))[0]
    
    lrs = compute_lr(pgrad, 3)
    param = [a - lr * b for a, b, lr in zip(tree_leaves(param), pgrad, lrs)]

    params = tree_unflatten(tdef, param)
    params = tree_map(param_abs, params)
    return params
gradient_step = jit(gradient_step)
# parallel_gradient_step = vmap(gradient_step, in_axes = (0, None, None))

In [7]:
class encoder(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(64, kernel_size=(3,3))(x)
        x = nn.selu(x)
        x = nn.Conv(64, kernel_size=(3,3))(x)
        x = nn.selu(x)
        x1 = x
        x = nn.Conv(64, kernel_size=(3,3))(x)
        x = nn.selu(x)
        x = nn.Conv(64, kernel_size=(3,3))(x)
        x = x1 + x
        x = nn.Conv(1, kernel_size=(3,3))(x)
        return x
    
def mse_loss(model, theta, x, y):
    z = model.apply(theta, x).squeeze()
    return jnp.mean((z - y)**2)
grad_mse_loss = jit(grad(mse_loss, argnums=1), static_argnames=['model'])

def optim_update(model, theta, opt_state, x, y):
    grads = grad_mse_loss(model, theta, x, y)
    updates, opt_state = opt_update(grads, opt_state)
    return optax.apply_updates(theta, updates), opt_state
optim_update = jit(optim_update, static_argnames=['model'])

In [8]:
def init_wa(x):
    return [jnp.zeros(x.shape[2:]), jnp.ones(x.shape[1:]) / x.shape[1]]

def weighted_average(theta, x):
    b = theta[0]
    w = theta[1]
    return jnp.sum(x * w[None], axis = 1) + b
    
def wa_loss(theta, x, y):
    z = weighted_average(theta, x).squeeze()
    return jnp.sqrt(jnp.mean((z - y)**2))
grad_wa_loss = jit(grad(wa_loss))

def optim_wa_update(theta, opt_state, x, y):
    grads = grad_wa_loss(theta, x, y)
    updates, opt_state = opt_update(grads, opt_state)
    return optax.apply_updates(theta, updates), opt_state
optim_wa_update = jit(optim_wa_update)

In [9]:
# nmod = 30
# ntrain = 1000
# nval = 200
# ntest = 1000
# np.random.seed(0)
# ztrain = np.random.normal(0, 1, [ntrain, nmod, 30, 50])
# ztest = np.random.normal(0, 1, [ntest, nmod, 30, 50])
# ntrain = ntrain - nval

In [10]:
# class encoder2(nn.Module):
#     @nn.compact
#     def __call__(self, x):
#         x = nn.Conv(32, kernel_size=(11,11))(x)
#         x = nn.selu(x)
#         x = nn.Conv(32, kernel_size=(7,7))(x)
# #         x = nn.selu(x)
# #         x = nn.Conv(32, kernel_size=(1,1))(x)
# #         x = nn.selu(x)
# #         x = nn.Conv(32, kernel_size=(1,1))(x)
# #         x = nn.selu(x)
# #         x = nn.Conv(32, kernel_size=(1,1))(x)
# #         x = nn.selu(x)
#         x = nn.Conv(1, kernel_size=(1,1))(x)
#         return x
    
# def mse_loss(model, theta, x, y):
#     z = model.apply(theta, x).squeeze()
#     return jnp.mean((z - y)**2)
# grad_mse_loss = jit(grad(mse_loss, argnums=1), static_argnames=['model'])

# def optim_update(model, theta, opt_state, x, y):
#     grads = grad_mse_loss(model, theta, x, y)
#     updates, opt_state = opt_update(grads, opt_state)
#     return optax.apply_updates(theta, updates), opt_state
# optim_update = jit(optim_update, static_argnames=['model'])

## testing and eval

In [11]:
def global_mmd(x, y):
    dist = jnp.max(jnp.abs(x - y), axis = (1, 2))
    dist = jnp.mean(dist)
    return 1/(1 + dist)
global_mmd = jit(vmap(global_mmd, (None, 0)))

def global_mmd_self(x, y):
    dist = jnp.max(jnp.abs(x - y), axis = (1, 2))
    dist = jnp.mean(jnp.sort(dist)[1:])
    return 1/(1 + dist)
global_mmd_self = jit(vmap(global_mmd_self, (None, 0)))

def conformal_ensemble(resval, depth_fn, alpha):
    nval = resval.shape[0]
    level = np.ceil((1 - alpha) * (nval + 1))/(nval + 1)
    dr1 = 1 - depth_fn(resval, resval)
    q = np.sort(dr1)[int(np.ceil((1 - alpha) * (nval + 1)))]
    q = (q + np.quantile(dr1, level))/2
    return resval[dr1 < q]

In [12]:
def quantile_interp(model, model_quant, ref_quant):
    return jnp.interp(model, model_quant, ref_quant)
quantile_interp = vmap(quantile_interp, (1, 1, 1))
quantile_interp = vmap(quantile_interp, (1, 1, 1))
quantile_interp = jit(quantile_interp)

def quantile_map(ref, model, n):
    ref_quant = jnp.quantile(ref, jnp.linspace(0, 1, n), axis = 0)
    model_quant = jnp.quantile(model, jnp.linspace(0, 1, n), axis = 0)
    corrected_model = quantile_interp(model, model_quant, ref_quant)
    return jnp.moveaxis(corrected_model, 2, 0)

def quantile_map(ref, model_hist, model_future, n):
    ref_quant = jnp.quantile(ref, jnp.linspace(0, 1, n), axis = 0)
    model_quant = jnp.quantile(model_hist, jnp.linspace(0, 1, n), axis = 0)
    corrected_model = quantile_interp(model_future, model_quant, ref_quant)
    return jnp.moveaxis(corrected_model, 2, 0)

def quantile_map(ref, model_hist, model_future):
    ref_quant = jnp.quantile(ref, jnp.linspace(0, 1, 20), axis = 0)
    model_quant = jnp.quantile(model_hist, jnp.linspace(0, 1, 20), axis = 0)
    corrected_model = quantile_interp(model_future, model_quant, ref_quant)
    return jnp.moveaxis(corrected_model, 2, 0)
quantile_map = jit(quantile_map)

In [13]:
def slices(nlat, nlon, seed = 0):
    np.random.seed(seed)
    w = np.random.normal(0,1, [500, nlat*nlon])
    w = w / np.sqrt(np.sum(w**2, axis = 1))[:,None]
    return w

def sw2(x, y, w):
    x = x.reshape(-1, nlat*nlon)
    y = y.reshape(-1, nlat*nlon)
    
    x = x @ w.T
    y = y @ w.T
    
    qx = jnp.quantile(x, jnp.linspace(0, 1, 50), axis = 0)
    qy = jnp.quantile(y, jnp.linspace(0, 1, 50), axis = 0)
    return jnp.mean(jnp.sqrt(jnp.mean((qx - qy)**2, axis = 0)))

sw2 = jit(sw2)


def w2(x, y):
    qx = jnp.quantile(x, jnp.linspace(0.005, 0.995, 100), axis = 0)
    qy = jnp.quantile(y, jnp.linspace(0.005, 0.995, 100), axis = 0)
    return jnp.sqrt(jnp.mean((qx - qy)**2))

w2 = vmap(w2, (1, 1))
w2 = vmap(w2, (1, 1))
w2 = jit(w2)

def pointwise_crps(ensemble, y): 
    n = ensemble.shape[0]//2
    a = 0.5 * jnp.mean(jnp.abs(ensemble[:n] - ensemble[n:]), axis = 0)
    b = jnp.mean(jnp.abs(ensemble - y), axis = 0)
    return b - a
single_ens_crps = jit(vmap(pointwise_crps, (None, 0)))

def _pit(ensemble, y):
    pit = jnp.mean(jnp.array(ensemble[None,:] < y[:,None]).squeeze(), axis = 1)
    pit_cdf = jnp.array([jnp.mean(pit < x, axis = 0) for x in np.linspace(0, 1, 50)])
    uni_cdf = jnp.linspace(0, 1, 50)[:,None,None]
    pit_score = jnp.sqrt(jnp.mean((pit_cdf - uni_cdf)**2, axis = 0))
    return pit_score

pit = jit(_pit)

def boot(x):
    n = x.shape[0]
    idx = np.random.choice(n, n, replace = True)
    return x[idx]

def shuffle(x):
    n = x.shape[0]
    idx = np.random.choice(n, n, replace = False)
    return x[idx]

def moving_average(a, n=12):
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n

def mse_weighted(x, y):
    nlat = x.shape[1]
    weight = np.cos(np.linspace(math.pi/2 - 1/nlat, -math.pi/2 + 1/nlat, nlat))
    weight /= np.mean(weight)
    weight = weight[None,:,None]
    
    return np.sqrt(np.mean((x - y)**2 * weight, axis = (1, 2)))

In [14]:
nproj = 6
nval = 200

xhist_tas = pickle.load(open('../data/xhist_tas_hr.pkl', 'rb'))
xrcp_tas = pickle.load(open('../data/xrcp_tas_hr.pkl', 'rb'))
nmod = xhist_tas.shape[1]
nlat, nlon = xhist_tas[0].shape[1:]

n = (2015 - 1940) * 12
xhist_tas = xhist_tas[-n:]

xtrain_orig = np.concatenate([xhist_tas, xrcp_tas[:108]], axis = 0)
xtest_orig = xrcp_tas[108:]

ntrain = xtrain_orig.shape[0] - nval
ntest = xtest_orig.shape[0]

save_loc = 'trained_models/tas'

xtrain_anom, xtest_anom = anomalize(xtrain_orig, xtest_orig)
xtrain_anom = jnp.array(xtrain_anom)
xtest_anom = jnp.array(xtest_anom)
del xtrain_orig, xtest_orig

  0%|          | 0/1008 [00:00<?, ?it/s]

  0%|          | 0/912 [00:00<?, ?it/s]

In [15]:
model_no = 12
_data = scale_and_split(xtrain_anom, xtest_anom, model_no = model_no, nval = nval)

xtrain = _data[0]
xval = _data[1]
xtest = _data[2]
ytrain = _data[3]
yval = _data[4]
ytest = _data[5]
xtrain_mean = _data[6]
xtest_mean = _data[7]
xtrain_sd = _data[8]
xtest_sd = _data[9]

ntrain = xtrain.shape[0]
nval = xval.shape[0]
ntest = xtest.shape[0]

del _data

weight = jnp.cos(jnp.linspace(math.pi/2 - 1/nlat, -math.pi/2 + 1/nlat, nlat))
weight = weight / jnp.mean(weight)
weight = weight[None,:,None]

#### NNGP
key = random.PRNGKey(1023)
params = nngp_params(key, depth)
for _ in trange(300, leave = False):
    params = gradient_step(params, xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))

  0%|          | 0/300 [00:00<?, ?it/s]

In [16]:
yval_hat = nngp_predict(params, xval.reshape(nval, -1), xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))
ytest_hat = nngp_predict(params, xtest.reshape(ntest, -1), xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))
yval_hat = yval_hat.reshape(-1, nlat, nlon)
ytest_hat = ytest_hat.reshape(-1, nlat, nlon)

xtrain2 = xtrain_sd[:-nval,None] * xtrain + xtrain_mean[:-nval,None]
ytrain2 = xtrain_sd[:-nval] * ytrain + xtrain_mean[:-nval]

yval2 = xtrain_sd[-nval:] * yval + xtrain_mean[-nval:]
yval_hat = xtrain_sd[-nval:] * yval_hat + xtrain_mean[-nval:]

xtest2 = xtest_sd[:,None] * xtest + xtest_mean[:,None]
ytest2 = xtest_sd * ytest + xtest_mean
ytest_hat = xtest_sd * ytest_hat + xtest_mean

In [17]:
resval = yval2 - yval_hat
restest = ytest2 - ytest_hat

imv_ens = xtest2
imv_ens = imv_ens - ytest_hat[:,None]

qc_ens = jnp.array([quantile_map(ytrain2, xtrain2[:,j], xtest2[:,j]) for j in range(nmod)])
qc_ens = np.moveaxis(qc_ens, 0, 1)
qc_ens = qc_ens - ytest_hat[:,None]

depth_fns = [global_mmd_self]
conf_ens1 = conformal_ensemble(resval, depth_fns[0], 0.1)

w = slices(nlat, nlon)
imv_sw = sw2(imv_ens, restest, w)
qmc_sw = sw2(qc_ens, restest, w)
cnf_sw = sw2(conf_ens1, restest, w)
imv_sw, qmc_sw, cnf_sw

(Array(1.62861029, dtype=float64),
 Array(1.67079928, dtype=float64),
 Array(1.14943193, dtype=float64))

## White noise control run

In [11]:
# nmod = 30
# ntrain = 1000
# nval = 200
# ntest = 1000
# np.random.seed(0)
# ztrain = np.random.normal(0, 1, [ntrain, nmod, 30, 50])
# ztest = np.random.normal(0, 1, [ntest, nmod, 30, 50])
# ntrain = ntrain - nval

# save_loc = 'trained_models/wn'

In [12]:
# for model_no in trange(nmod):
#     xtrain, xval, xtest, ytrain, yval, ytest = scale_and_split(ztrain, ztest, model_no = model_no, nval = nval)
    
#     #### Ens average
#     yval_hat = np.mean(xval, axis = 1)
#     ytest_hat = np.mean(xtest, axis = 1)
#     np.save(save_loc + f'/yval_ens_{model_no}', yval_hat)
#     np.save(save_loc + f'/ytest_ens_{model_no}', ytest_hat)

#     #### Delta
#     yval_hat = jnp.vstack([delta_pred(xval[n:(n+10)], xtrain, ytrain) for n in trange(nval//10, leave = False)])
#     ytest_hat = jnp.vstack([delta_pred(xtest[n:(n+10)], xtrain, ytrain) for n in trange(ntest//10, leave = False)])
#     np.save(save_loc + f'/yval_delta_{model_no}', yval_hat)
#     np.save(save_loc + f'/ytest_delta_{model_no}', ytest_hat)
    
#     #### linear regression
#     beta = v_ols(xtrain, ytrain)
#     beta = jnp.nan_to_num(beta)
#     beta = jnp.moveaxis(beta, (0, 1, 2), (2, 1, 0))
#     yval_hat = reg_pred(xval, beta)
#     ytest_hat = reg_pred(xtest, beta)
#     np.save(save_loc + f'/yval_lm_{model_no}', yval_hat)
#     np.save(save_loc + f'/ytest_lm_{model_no}', ytest_hat)

#     #### NNGP
#     key = random.PRNGKey(1023)
#     params = nngp_params(key, depth)
#     for _ in trange(300, leave = False):
#         params = gradient_step(params, xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))

#     yval_hat = nngp_predict(params, xval.reshape(nval, -1), xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))
#     ytest_hat = nngp_predict(params, xtest.reshape(ntest, -1), xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))
#     np.save(save_loc + f'/yval_nngp_{model_no}', yval_hat)
#     np.save(save_loc + f'/ytest_nngp_{model_no}', ytest_hat)
    
#     #### WA
#     theta = init_wa(xtrain)
#     opt_init, opt_update = optax.adam(1e-1)
#     opt_state = opt_init(theta)

#     for i in trange(1000, leave = False):
#         theta, opt_state = optim_wa_update(theta, opt_state, xtrain, ytrain)
    
#     yval_hat = weighted_average(theta, xval)
#     ytest_hat = weighted_average(theta, xtest)
#     np.save(save_loc + f'/yval_wa_{model_no}', yval_hat)
#     np.save(save_loc + f'/ytest_wa_{model_no}', ytest_hat)
    
#     #### CNN
#     xtrain = jnp.moveaxis(xtrain, 1, 3)
#     xval = jnp.moveaxis(xval, 1, 3)
#     xtest = jnp.moveaxis(xtest, 1, 3)
#     key = random.key(0)
#     enc = encoder()
#     theta = enc.init(key, xtrain[0:5])
#     opt_init, opt_update = optax.adam(1e-3)
#     opt_state = opt_init(theta)

#     for i in trange(2000, leave = False):
#         idx = np.random.choice(ntrain, 32)
#         theta, opt_state = optim_update(enc, theta, opt_state, xtrain[idx], ytrain[idx])

#     yval_hat = enc.apply(theta, xval).squeeze()
#     ytest_hat = np.vstack([enc.apply(theta, xtest[(30*i):(30*(i+1))]) for i in range(ntest//30 + 1)]).squeeze()
#     np.save(save_loc + f'/yval_cnn1_{model_no}', yval_hat)
#     np.save(save_loc + f'/ytest_cnn1_{model_no}', ytest_hat)

## Temperature (TAS)

In [9]:
nproj = 5
nval = 200

xhist_tas = pickle.load(open('../data/xhist_tas_hr.pkl', 'rb'))
xrcp_tas = pickle.load(open('../data/xrcp_tas_hr.pkl', 'rb'))

# xhist_tas = np.moveaxis(np.array(xhist_tas), 0, 1)
# xrcp_tas = np.moveaxis(np.array(xrcp_tas), 0, 1)
nmod = xhist_tas.shape[1]
nlat = xhist_tas.shape[2]
nlon = xhist_tas.shape[3]

n = (2015 - 1940) * 12
xhist_tas = xhist_tas[-n:]

xtrain_orig = np.concatenate([xhist_tas, xrcp_tas[:108]], axis = 0)
xtest_orig = xrcp_tas[108:]

ntrain = xtrain_orig.shape[0] - nval
ntest = xtest_orig.shape[0]

save_loc = 'trained_models/tas'

In [11]:
xtrain_anom, xtest_anom = anomalize(xtrain_orig, xtest_orig)

  0%|          | 0/1008 [00:00<?, ?it/s]

  0%|          | 0/912 [00:00<?, ?it/s]

In [12]:
for model_no in trange(nmod):
    _data = scale_and_split(xtrain_anom, xtest_anom, model_no = model_no, nval = nval)

    xtrain = _data[0]
    xval = _data[1]
    xtest = _data[2]
    ytrain = _data[3]
    yval = _data[4]
    ytest = _data[5]
    xtrain_mean = _data[6]
    xtest_mean = _data[7]
    xtrain_sd = _data[8]
    xtest_sd = _data[9]

    ntrain = xtrain.shape[0]
    nval = xval.shape[0]
    ntest = xtest.shape[0]

    del _data

    weight = jnp.cos(jnp.linspace(math.pi/2 - 1/nlat, -math.pi/2 + 1/nlat, nlat))
    weight = weight / jnp.mean(weight)
    weight = weight[None,:,None]
    
    #### Ens average
    yval_hat = np.mean(xval, axis = 1)
    ytest_hat = np.mean(xtest, axis = 1)
    np.save(save_loc + f'/yval_ens_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_ens_{model_no}', ytest_hat)

    #### Delta
    yval_hat = jnp.vstack([delta_pred(xval[n:(n+10)], xtrain, ytrain) for n in trange(nval//10, leave = False)])
    ytest_hat = jnp.vstack([delta_pred(xtest[n:(n+12)], xtrain, ytrain) for n in trange(ntest//12, leave = False)])
    np.save(save_loc + f'/yval_delta_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_delta_{model_no}', ytest_hat)
    
    #### linear regression
    beta = v_ols(xtrain, ytrain)
    beta = jnp.nan_to_num(beta)
    beta = jnp.moveaxis(beta, (0, 1, 2), (2, 1, 0))
    yval_hat = reg_pred(xval, beta)
    ytest_hat = reg_pred(xtest, beta)
    np.save(save_loc + f'/yval_lm_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_lm_{model_no}', ytest_hat)

    #### NNGP
    key = random.PRNGKey(1023)
    params = nngp_params(key, depth)
    for _ in trange(300, leave = False):
        params = gradient_step(params, xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))

    yval_hat = nngp_predict(params, xval.reshape(nval, -1), xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))
    ytest_hat = nngp_predict(params, xtest.reshape(ntest, -1), xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))
    yval_hat = yval_hat.reshape(-1, nlat, nlon)
    ytest_hat = ytest_hat.reshape(-1, nlat, nlon)
    np.save(save_loc + f'/yval_nngp_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_nngp_{model_no}', ytest_hat)
    
    #### WA
    theta = init_wa(xtrain)
    opt_init, opt_update = optax.adam(1e-1)
    opt_state = opt_init(theta)

    for i in trange(1000, leave = False):
        theta, opt_state = optim_wa_update(theta, opt_state, xtrain, ytrain)
    
    yval_hat = weighted_average(theta, xval)
    ytest_hat = weighted_average(theta, xtest)
    np.save(save_loc + f'/yval_wa_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_wa_{model_no}', ytest_hat)
    
    #### CNN
    xtrain = jnp.moveaxis(xtrain, 1, 3)
    xval = jnp.moveaxis(xval, 1, 3)
    xtest = jnp.moveaxis(xtest, 1, 3)
    key = random.PRNGKey(0)
    enc = encoder()
    theta = enc.init(key, xtrain[0:5])
    opt_init, opt_update = optax.adam(1e-3)
    opt_state = opt_init(theta)

    for i in trange(2000, leave = False):
        idx = np.random.choice(ntrain, 32)
        theta, opt_state = optim_update(enc, theta, opt_state, xtrain[idx], ytrain[idx])

    yval_hat = enc.apply(theta, xval).squeeze()
    ytest_hat = np.vstack([enc.apply(theta, xtest[(30*i):(30*(i+1))]) for i in range(ntest//30 + 1)]).squeeze()
    np.save(save_loc + f'/yval_cnn1_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_cnn1_{model_no}', ytest_hat)

  0%|          | 0/31 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

## Max Temp (TASMAX)

In [13]:
nproj = 5
nval = 200

xhist_tas = pickle.load(open('../data/xhist_tasmax_hr.pkl', 'rb'))
xrcp_tas = pickle.load(open('../data/xrcp_tasmax_hr.pkl', 'rb'))
nmod = xhist_tas.shape[1]
nlat = xhist_tas.shape[2]
nlon = xhist_tas.shape[3]

n = (2015 - 1940) * 12
xhist_tas = xhist_tas[-n:]

xtrain_orig = np.concatenate([xhist_tas, xrcp_tas[:108]], axis = 0)
xtest_orig = xrcp_tas[108:]

ntrain = xtrain_orig.shape[0] - nval
ntest = xtest_orig.shape[0]

save_loc = 'trained_models/tmax'

In [14]:
xtrain_anom, xtest_anom = anomalize(xtrain_orig, xtest_orig)

  0%|          | 0/1008 [00:00<?, ?it/s]

  0%|          | 0/924 [00:00<?, ?it/s]

In [15]:
for model_no in trange(nmod):
    _data = scale_and_split(xtrain_anom, xtest_anom, model_no = model_no, nval = nval)

    xtrain = _data[0]
    xval = _data[1]
    xtest = _data[2]
    ytrain = _data[3]
    yval = _data[4]
    ytest = _data[5]
    xtrain_mean = _data[6]
    xtest_mean = _data[7]
    xtrain_sd = _data[8]
    xtest_sd = _data[9]

    ntrain = xtrain.shape[0]
    nval = xval.shape[0]
    ntest = xtest.shape[0]

    del _data

    weight = jnp.cos(jnp.linspace(math.pi/2 - 1/nlat, -math.pi/2 + 1/nlat, nlat))
    weight = weight / jnp.mean(weight)
    weight = weight[None,:,None]
    
    #### Ens average
    yval_hat = np.mean(xval, axis = 1)
    ytest_hat = np.mean(xtest, axis = 1)
    np.save(save_loc + f'/yval_ens_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_ens_{model_no}', ytest_hat)

    #### Delta
    yval_hat = jnp.vstack([delta_pred(xval[n:(n+10)], xtrain, ytrain) for n in trange(nval//10, leave = False)])
    ytest_hat = jnp.vstack([delta_pred(xtest[n:(n+12)], xtrain, ytrain) for n in trange(ntest//12, leave = False)])
    np.save(save_loc + f'/yval_delta_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_delta_{model_no}', ytest_hat)
    
    #### linear regression
    beta = v_ols(xtrain, ytrain)
    beta = jnp.nan_to_num(beta)
    beta = jnp.moveaxis(beta, (0, 1, 2), (2, 1, 0))
    yval_hat = reg_pred(xval, beta)
    ytest_hat = reg_pred(xtest, beta)
    np.save(save_loc + f'/yval_lm_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_lm_{model_no}', ytest_hat)

    #### NNGP
    key = random.PRNGKey(1023)
    params = nngp_params(key, depth)
    for _ in trange(300, leave = False):
        params = gradient_step(params, xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))

    yval_hat = nngp_predict(params, xval.reshape(nval, -1), xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))
    ytest_hat = nngp_predict(params, xtest.reshape(ntest, -1), xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))
    yval_hat = yval_hat.reshape(-1, nlat, nlon)
    ytest_hat = ytest_hat.reshape(-1, nlat, nlon)
    np.save(save_loc + f'/yval_nngp_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_nngp_{model_no}', ytest_hat)
    
    #### WA
    theta = init_wa(xtrain)
    opt_init, opt_update = optax.adam(1e-1)
    opt_state = opt_init(theta)

    for i in trange(1000, leave = False):
        theta, opt_state = optim_wa_update(theta, opt_state, xtrain, ytrain)
    
    yval_hat = weighted_average(theta, xval)
    ytest_hat = weighted_average(theta, xtest)
    np.save(save_loc + f'/yval_wa_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_wa_{model_no}', ytest_hat)
    
    #### CNN
    xtrain = jnp.moveaxis(xtrain, 1, 3)
    xval = jnp.moveaxis(xval, 1, 3)
    xtest = jnp.moveaxis(xtest, 1, 3)
    key = random.PRNGKey(0)
    enc = encoder()
    theta = enc.init(key, xtrain[0:5])
    opt_init, opt_update = optax.adam(1e-3)
    opt_state = opt_init(theta)

    for i in trange(2000, leave = False):
        idx = np.random.choice(ntrain, 32)
        theta, opt_state = optim_update(enc, theta, opt_state, xtrain[idx], ytrain[idx])

    yval_hat = enc.apply(theta, xval).squeeze()
    ytest_hat = np.vstack([enc.apply(theta, xtest[(30*i):(30*(i+1))]) for i in range(ntest//30 + 1)]).squeeze()
    np.save(save_loc + f'/yval_cnn1_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_cnn1_{model_no}', ytest_hat)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/77 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

## Precipitation (PR)

In [14]:
nproj = 5
nval = 200

xhist_tas = pickle.load(open('../data/xhist_pr_hr.pkl', 'rb'))
xrcp_tas = pickle.load(open('../data/xrcp_pr_hr.pkl', 'rb'))
nmod = xhist_tas.shape[1]
nlat = xhist_tas.shape[2]
nlon = xhist_tas.shape[3]

n = (2015 - 1940) * 12
xhist_tas = xhist_tas[-n:]

xtrain_orig = np.concatenate([xhist_tas, xrcp_tas[:108]], axis = 0)
xtest_orig = xrcp_tas[108:]

ntrain = xtrain_orig.shape[0] - nval
ntest = xtest_orig.shape[0]

xtrain_orig = np.log(np.abs(xtrain_orig) + 1e-3)
xtest_orig = np.log(np.abs(xtest_orig) + 1e-3)

save_loc = 'trained_models/pr'

In [15]:
xtrain_anom, xtest_anom = anomalize(xtrain_orig, xtest_orig)

  0%|          | 0/1008 [00:00<?, ?it/s]

  0%|          | 0/912 [00:00<?, ?it/s]

In [16]:
for model_no in trange(18, nmod):
    _data = scale_and_split(xtrain_anom, xtest_anom, model_no = model_no, nval = nval)

    xtrain = _data[0]
    xval = _data[1]
    xtest = _data[2]
    ytrain = _data[3]
    yval = _data[4]
    ytest = _data[5]
    xtrain_mean = _data[6]
    xtest_mean = _data[7]
    xtrain_sd = _data[8]
    xtest_sd = _data[9]

    ntrain = xtrain.shape[0]
    nval = xval.shape[0]
    ntest = xtest.shape[0]

    del _data

    weight = jnp.cos(jnp.linspace(math.pi/2 - 1/nlat, -math.pi/2 + 1/nlat, nlat))
    weight = weight / jnp.mean(weight)
    weight = weight[None,:,None]
    
    #### Ens average
    yval_hat = np.mean(xval, axis = 1)
    ytest_hat = np.mean(xtest, axis = 1)
    np.save(save_loc + f'/yval_ens_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_ens_{model_no}', ytest_hat)

    #### Delta
    yval_hat = jnp.vstack([delta_pred(xval[n:(n+10)], xtrain, ytrain) for n in trange(nval//10, leave = False)])
    ytest_hat = jnp.vstack([delta_pred(xtest[n:(n+12)], xtrain, ytrain) for n in trange(ntest//12, leave = False)])
    np.save(save_loc + f'/yval_delta_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_delta_{model_no}', ytest_hat)
    
    #### linear regression
    beta = v_ols(xtrain, ytrain)
    beta = jnp.nan_to_num(beta)
    beta = jnp.moveaxis(beta, (0, 1, 2), (2, 1, 0))
    yval_hat = reg_pred(xval, beta)
    ytest_hat = reg_pred(xtest, beta)
    np.save(save_loc + f'/yval_lm_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_lm_{model_no}', ytest_hat)

    #### NNGP
    key = random.PRNGKey(1023)
    params = nngp_params(key, depth)
    for _ in trange(300, leave = False):
        params = gradient_step(params, xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))

    yval_hat = nngp_predict(params, xval.reshape(nval, -1), xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))
    ytest_hat = nngp_predict(params, xtest.reshape(ntest, -1), xtrain.reshape(ntrain, -1), ytrain.reshape(ntrain, -1))
    yval_hat = yval_hat.reshape(-1, nlat, nlon)
    ytest_hat = ytest_hat.reshape(-1, nlat, nlon)
    np.save(save_loc + f'/yval_nngp_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_nngp_{model_no}', ytest_hat)
    
    #### WA
    theta = init_wa(xtrain)
    opt_init, opt_update = optax.adam(1e-1)
    opt_state = opt_init(theta)

    for i in trange(1000, leave = False):
        theta, opt_state = optim_wa_update(theta, opt_state, xtrain, ytrain)
    
    yval_hat = weighted_average(theta, xval)
    ytest_hat = weighted_average(theta, xtest)
    np.save(save_loc + f'/yval_wa_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_wa_{model_no}', ytest_hat)
    
    #### CNN
    xtrain = jnp.moveaxis(xtrain, 1, 3)
    xval = jnp.moveaxis(xval, 1, 3)
    xtest = jnp.moveaxis(xtest, 1, 3)
    key = random.PRNGKey(0)
    enc = encoder()
    theta = enc.init(key, xtrain[0:5])
    opt_init, opt_update = optax.adam(1e-3)
    opt_state = opt_init(theta)

    for i in trange(2000, leave = False):
        idx = np.random.choice(ntrain, 32)
        theta, opt_state = optim_update(enc, theta, opt_state, xtrain[idx], ytrain[idx])

    yval_hat = enc.apply(theta, xval).squeeze()
    ytest_hat = np.vstack([enc.apply(theta, xtest[(30*i):(30*(i+1))]) for i in range(ntest//30 + 1)]).squeeze()
    np.save(save_loc + f'/yval_cnn1_{model_no}', yval_hat)
    np.save(save_loc + f'/ytest_cnn1_{model_no}', ytest_hat)

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

2024-06-25 09:53:08.868604: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 13.92GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-06-25 09:53:43.948242: W external/tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 13.92GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/76 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

# Quantile mapping

In [53]:
nproj = 5
nval = 200

xhist_tas = pickle.load(open('../data/xhist_tas_hr.pkl', 'rb'))
xrcp_tas = pickle.load(open('../data/xrcp_tas_hr.pkl', 'rb'))

# xhist_tas = np.moveaxis(np.array(xhist_tas), 0, 1)
# xrcp_tas = np.moveaxis(np.array(xrcp_tas), 0, 1)
nmod = xhist_tas.shape[1]

n = (2015 - 1940) * 12
xhist_tas = xhist_tas[-n:]

xtrain_orig = np.concatenate([xhist_tas, xrcp_tas[:108]], axis = 0)
xtest_orig = xrcp_tas[108:]

ntrain = xtrain_orig.shape[0] - nval
ntest = xtest_orig.shape[0]

save_loc = 'trained_models/tas'

In [238]:
def quantile_interp(model, model_quant, ref_quant):
    return jnp.interp(model, model_quant, ref_quant)
quantile_interp = vmap(quantile_interp, (1, 1, 1))
quantile_interp = vmap(quantile_interp, (1, 1, 1))
quantile_interp = jit(quantile_interp)

def quantile_map(ref, model, n):
    ref_quant = jnp.quantile(ref, jnp.linspace(0, 1, n), axis = 0)
    model_quant = jnp.quantile(model, jnp.linspace(0, 1, n), axis = 0)
    corrected_model = quantile_interp(model, model_quant, ref_quant)
    return jnp.moveaxis(corrected_model, 2, 0)

def quantile_map(ref, model_hist, model_future, n):
    ref_quant = jnp.quantile(ref, jnp.linspace(0, 1, n), axis = 0)
    model_quant = jnp.quantile(model_hist, jnp.linspace(0, 1, n), axis = 0)
    corrected_model = quantile_interp(model_future, model_quant, ref_quant)
    return jnp.moveaxis(corrected_model, 2, 0)

In [210]:
ref = xtrain_orig[:,0]
model = xtrain_orig[:,1]

n = 20
ref_quant = jnp.quantile(ref, jnp.linspace(0, 1, n), axis = 0)
model_quant = jnp.quantile(model, jnp.linspace(0, 1, n), axis = 0)

In [212]:
model_bc = quantile_map(ref, model, 20)

In [244]:
for model_no in trange(2):
    xtrain, xval, xtest, ytrain, yval, ytest = scale_and_split(xtrain_orig, xtest_orig, model_no = model_no, nval = nval)

  0%|          | 0/2 [00:00<?, ?it/s]

In [259]:
xtrain_orig.shape, xtest_orig.shape

((1008, 31, 80, 120), (912, 31, 80, 120))

In [260]:
xtest_bc = jnp.array([quantile_map(yval, xval[:,j], xtest[:,j], 20) for j in trange(nmod - 1)])
xtest_bc = jnp.moveaxis(xtest_bc, 0, 1)

  0%|          | 0/30 [00:00<?, ?it/s]

In [250]:
def pointwise_error(y):
    return y**2
pointwise_error = jit(pointwise_error)

def pointwise_sharpness(ensemble):
    return jnp.abs(jnp.max(ensemble, axis = 0) - jnp.min(ensemble, axis = 0))
single_ens_sharpness = jit(pointwise_sharpness)
multi_ens_sharpness = jit(vmap(pointwise_sharpness, 0))

def pointwise_crps(ensemble, y): 
    n = ensemble.shape[0]//2
    a = 0.5 * jnp.mean(jnp.abs(ensemble[:n] - ensemble[n:]), axis = 0)
    b = jnp.mean(jnp.abs(ensemble - y), axis = 0)
    return b - a
single_ens_crps = jit(vmap(pointwise_crps, (None, 0)))
multi_ens_crps = jit(vmap(pointwise_crps, (0, 0)))

def pointwise_pit(ensemble, y):
    pit = np.mean(np.array(ensemble[None,:] < y[:,None]).squeeze(), axis = 1)
    pit_cdf = np.array([np.mean(pit < x, axis = 0) for x in np.linspace(0, 1, 50)])
    uni_cdf = np.linspace(0, 1, 50)[:,None,None]
    pit_score = np.mean((pit_cdf - uni_cdf)**2, axis = 0)
    dis_score = (1/12 - np.var(pit_cdf, axis = 0))
    return pit_score, dis_score

def pointwise_pit(ensemble, y):
    pit = jnp.mean(jnp.array(ensemble[None,:] < y[:,None]).squeeze(), axis = 1)
    pit_cdf = jnp.array([jnp.mean(pit < x, axis = 0) for x in np.linspace(0, 1, 50)])
    uni_cdf = jnp.linspace(0, 1, 50)[:,None,None]
    pit_score = jnp.mean((pit_cdf - uni_cdf)**2, axis = 0)
    return pit_score

def metrics(ensemble, y):
    ntest = y.shape[0]
    nlat = y.shape[1]
    w = jnp.cos(jnp.linspace(math.pi/2 - 1/nlat, -math.pi/2 + 1/nlat, nlat))
    weight = w / jnp.mean(w)

    # rmse
    _rmse = pointwise_error(y)
    _rmse = jnp.sqrt(jnp.mean(_rmse * weight[None,:,None]))

    # crps
    if len(ensemble.shape) == 3:
        _crps = single_ens_crps(ensemble, y)
    else:
        _crps = multi_ens_crps(ensemble, y)

    _crps = np.mean(_crps * weight[None,:,None])
    # PIT
    _pit_score = pointwise_pit(ensemble, y)
    _pit_score = jnp.mean(_pit_score * weight[:,None])
    return _rmse, _crps, _pit_score

metrics = jit(metrics)

def sw2(x, y, w):
    x = x.reshape(-1, nlat*nlon)
    y = y.reshape(-1, nlat*nlon)
    
    x = x @ w.T
    y = y @ w.T
    
    qx = jnp.quantile(x, jnp.linspace(0, 1, 50), axis = 0)
    qy = jnp.quantile(y, jnp.linspace(0, 1, 50), axis = 0)
    return jnp.mean(jnp.sqrt(jnp.mean((qx - qy)**2, axis = 0)))

sw2 = jit(sw2)

In [248]:
metrics(xtest, ytest)

(Array(0.98708157, dtype=float64),
 Array(0.51559381, dtype=float64),
 Array(0.05693089, dtype=float64))

In [249]:
metrics(xtest_bc, ytest)

(Array(0.98708157, dtype=float64),
 Array(0.36672001, dtype=float64),
 Array(0.00603064, dtype=float64))

In [267]:
nproj = 6
nval = 200

xhist_tas = pickle.load(open('../data/xhist_tas_hr.pkl', 'rb'))
xrcp_tas = pickle.load(open('../data/xrcp_tas_hr.pkl', 'rb'))
nmod = xhist_tas.shape[1]
nlat, nlon = xhist_tas[0].shape[1:]

n = (2015 - 1940) * 12
xhist_tas = xhist_tas[-n:]

xtrain_orig = np.concatenate([xhist_tas, xrcp_tas[:108]], axis = 0)
xtest_orig = xrcp_tas[108:]

ntrain = xtrain_orig.shape[0] - nval
ntest = xtest_orig.shape[0]

save_loc = 'trained_models/tas'

# level = 1 - np.ceil((1-0.1) * (nval + 1))/(nval + 1)
alpha = 0.1
level = 1 - np.ceil((1 - alpha) * (nval + 1))/(nval + 1)

In [297]:
analysis = ['ens', 'wa', 'delta', 'lm', 'nngp', 'cnn1']
results = np.zeros([nmod, 2 * nproj, 4])

np.random.seed(0)
w = np.random.normal(0,1, [500, nlat*nlon])
w = w / np.sqrt(np.sum(w**2, axis = 1))[:,None]

model_no = 0
k = 8
xtrain, xval, xtest, ytrain, yval, ytest = scale_and_split(xtrain_orig, xtest_orig, model_no = model_no, nval = nval)
        
yval_hat = np.load(save_loc + f'/yval_{analysis[k//2]}_{model_no}.npy')
ytest_hat = np.load(save_loc + f'/ytest_{analysis[k//2]}_{model_no}.npy')

if analysis[k//2] == 'lm':
    yval_hat = np.clip(yval_hat, -3, 3)
    ytest_hat = np.clip(ytest_hat, -3, 3)

resval = yval - yval_hat.reshape(yval.shape)
restest = ytest - ytest_hat.reshape(ytest.shape)
# dr1 = global_mmd_self(resval, resval)
# q = np.quantile(dr1, level)

## IMV
imv_ens = np.delete(xtest_orig, model_no, 1)
imv_ens = imv_ens - np.mean(imv_ens, axis = 1)[:,None,:,:]

bc_ens = jnp.array([quantile_map(xtrain_orig[:,0], xtrain_orig[:,j], xtest_orig[:,j], 20) for j in trange(1, nmod)])
bc_ens = jnp.moveaxis(bc_ens, 0, 1)
bc_ens = bc_ens - np.mean(bc_ens, axis = 1)[:,None,:,:]

  0%|          | 0/30 [00:00<?, ?it/s]

In [304]:
imv_ens = xtest
imv_ens = imv_ens - np.mean(imv_ens, axis = 1)[:,None,:,:]

bc_ens = jnp.array([quantile_map(yval, xval[:,j], xtest[:,j], 20) for j in trange(1, nmod)])
bc_ens = jnp.moveaxis(bc_ens, 0, 1)
bc_ens = bc_ens - np.mean(bc_ens, axis = 1)[:,None,:,:]

  0%|          | 0/30 [00:00<?, ?it/s]

In [305]:
metrics(imv_ens, restest)

(Array(0.61676607, dtype=float64),
 Array(0.3380705, dtype=float64),
 Array(0.01735841, dtype=float64))

In [306]:
metrics(bc_ens, restest)

(Array(0.61676607, dtype=float64),
 Array(0.33083619, dtype=float64),
 Array(0.01340928, dtype=float64))