In [2]:
import os
import gc
import math

# numpy
import numpy as np
import numpy.ma as ma

# from jax
import jax
from jax import vmap, grad, jit, random, nn
from jax.config import config
# from jax.experimental import optimizers
import jax.numpy as jnp
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves

import haiku as hk
import optax

config.update("jax_enable_x64", True)

from skimage.restoration import estimate_sigma
from skimage.transform import resize


# netCDF
import netCDF4 as nc

from tqdm.notebook import tqdm
from tqdm.notebook import trange

import pickle

  PyTreeDef = type(jax.tree_structure(None))


In [None]:
@jit
def delta_pred(xtest, xtrain, ytrain):
    """Median delta-method prediction of the target field for one test time step.

    For every training step t, the mean (over axis 3) of ``xtest - xtrain[t]`` is
    added to ``ytrain[t]``, giving one candidate field per training step; the
    median of those candidates over axis 0 is returned.

    Shapes as built by the experiment cell (assumed here, not enforced):
    xtest (nlat, nlon, nsrc), xtrain (ntrain, nlat, nlon, nsrc),
    ytrain (ntrain, nlat, nlon) -> result (nlat, nlon).
    """
    shift = (xtest - xtrain).mean(axis=3)   # average change across axis 3 (source models in the experiment loop)
    candidates = ytrain + shift             # one shifted target field per training step
    return jnp.median(candidates, axis=0)

@jit
def delta_var(xtest, xtrain, ytrain):
    """Spread of the delta-method candidate fields for one test time step.

    Candidates are ``ytrain[t] + mean_axis3(xtest - xtrain[t])`` for every training
    step t (same construction as ``delta_pred``).

    NOTE(review): despite the name, this returns the *standard deviation*
    (population, ddof=0) over axis 0 — not the variance.
    """
    candidates = ytrain + (xtest - xtrain).mean(axis=3)
    return candidates.std(axis=0)

In [3]:
# Load the per-model historical and RCP fields.
# Use context managers so the file handles are closed after loading.
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
with open('../data/saved/xhist_tas.pkl', 'rb') as fh:
    xhist = pickle.load(fh)
with open('../data/saved/xrcp_tas.pkl', 'rb') as fh:
    xrcp = pickle.load(fh)

# Number of leading RCP time steps that the experiment cell appends to the
# training set; RCP steps from `nval` onward form the test set.
# NOTE(review): time units of these steps are not visible here — confirm.
nval = 72
nmod = len(xhist)  # number of models

In [6]:
### run experiments
# Leave-one-model-out delta emulation: each model m1 in turn is the target and is
# predicted from the remaining models (regridded onto m1's grid).

OUT_DIR = '../experiments/tas_predictions'
os.makedirs(OUT_DIR, exist_ok=True)  # np.save fails if the directory is missing

nmod = len(xhist)
assert nmod >= 2, "leave-one-model-out needs at least two models"
ntrain = xhist[0].shape[0]
# NOTE(review): assumes every model's RCP series is as long as model 0's.
ntest = xrcp[0].shape[0]

sgpr_list = []  # NOTE(review): unused in this cell; kept in case a later cell reads it.


def _stack_regridded(series_list, nlat, nlon):
    """Regrid each (time, lat, lon) array to (nlat, nlon) and stack them on a new last axis.

    Returns an array of shape (time, nlat, nlon, len(series_list)).
    """
    # resize() works on (lat, lon, time), keeping the trailing time axis as channels.
    regridded = [resize(np.moveaxis(series, 0, 2), (nlat, nlon)) for series in series_list]
    # (nsrc, nlat, nlon, time) -> (time, nlat, nlon, nsrc)
    return np.moveaxis(np.array(regridded), (0, 3), (3, 0))


for m1 in trange(nmod):
    _, nlat, nlon = xhist[m1].shape
    others = [m2 for m2 in range(nmod) if m2 != m1]

    #### construct training set: full historical run + first `nval` RCP steps
    xtrain = jnp.array(_stack_regridded(
        [np.vstack([xhist[m2], xrcp[m2][0:nval]]) for m2 in others], nlat, nlon))
    ytrain = jnp.array(np.vstack([xhist[m1], xrcp[m1][0:nval]]))

    #### construct testing set: remaining RCP steps
    xtest = jnp.array(_stack_regridded(
        [xrcp[m2][nval:ntest] for m2 in others], nlat, nlon))
    ytest = xrcp[m1][nval:ntest]

    ### predict each test time step independently
    yhat = np.zeros_like(ytest)
    yhat_var = np.zeros_like(ytest)
    for i in trange(yhat.shape[0], desc=f'model {m1}', leave=False):
        yhat[i] = delta_pred(xtest[i], xtrain, ytrain)
        yhat_var[i] = delta_var(xtest[i], xtrain, ytrain)

    # NOTE: np.save appends '.npy' to names that don't already end in '.npy', so these
    # land on disk as e.g. 'del_tas_0.npz.npy' (plain .npy, not .npz archives).
    # Paths are kept unchanged so existing loaders keep working.
    np.save(f'{OUT_DIR}/del_tas_{m1}.npz', yhat)
    # Fix: this previously saved `yhat` again instead of the spread estimate.
    np.save(f'{OUT_DIR}/del_tas_{m1}_var.npz', yhat_var)

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]

  0%|          | 0/960 [00:00<?, ?it/s]