In [1]:
import os
import gc
import math

# numpy
import numpy as np
import numpy.ma as ma

from jax import vmap, grad, jit, random
from jax.config import config
from jax.experimental import optimizers
import jax.numpy as jnp
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves

config.update("jax_enable_x64", True)

from skimage.restoration import estimate_sigma
from skimage.transform import resize

# netCDF
import netCDF4 as nc

from tqdm.notebook import tqdm
from tqdm.notebook import trange

import pickle

In [5]:
xhist = pickle.load(open('../submit/data/saved/xhist_pr.pkl', 'rb'))
xrcp = pickle.load(open('../submit/data/saved/xrcp_pr.pkl', 'rb'))

In [6]:
nval = 72
nmod = len(xhist)

In [7]:
def lr_pred(xtrain, ytrain):
    x_inv = jnp.linalg.inv(jnp.matmul(xtrain.T, xtrain))
    beta = jnp.matmul(jnp.matmul(x_inv, xtrain.T), ytrain)
    return beta

def lr_pred_error(xtrain, yhat_train, ytrain, xtest):
    x_inv = jnp.linalg.inv(jnp.matmul(xtrain.T, xtrain))
    x_scale = jnp.sqrt(1 + jnp.diag(jnp.matmul(jnp.matmul(xtest, x_inv), xtest.T)))
    return jnp.std(ytrain - yhat_train) * x_scale

lr_pred = jit(vmap(lr_pred, in_axes = (0, 0)))
pred_error = jit(vmap(lr_pred_error, in_axes = (0, 0, 0, 0)))

In [14]:
nmod = len(xhist)
xreg_list = []
xreg_var_list = []
xbar_list = []
xbar_var_list = []

### outer loop
for m1 in trange(7, nmod):

    ### training 
    ntest, nlat, nlon = xrcp[m1].shape
    nmod = len(xhist)

    x1 = []
    for m2 in range(nmod):
        if m2 != m1:
            xt1 = xhist[m2]
            xt2 = xrcp[m2][0:nval]
            xt = np.moveaxis(np.vstack([xt1, xt2]), 0, 2)
            
            x1.append(resize(xt, (nlat, nlon)))
    x1 = np.moveaxis(np.array(x1), 0, 3).reshape(nlat*nlon, -1, nmod-1)
    
    y1 = xhist[m1]
    y2 = xrcp[m1][0:nval]
    
    y1 = jnp.array(np.vstack([y1, y2]))
    y1 = np.moveaxis(y1, 0, 2).reshape(nlat*nlon, -1)
    
    y2 = xrcp[m1][nval:ntest]
    y2 = np.moveaxis(y2, 0, 2)
    
    
    ### testing
    ntest, nlat, nlon = xrcp[m1].shape

    x2 = []
    for m2 in trange(nmod, leave = False):
        if m2 != m1:
            xt = xrcp[m2][nval:ntest]
            xt = np.moveaxis(xt, 0, 2)
            xt = resize(xt, (nlat, nlon))
            
            x2.append(xt)
    x2 = np.array(x2)
    x2 = np.moveaxis(x2, 0, 3)
    x2 = x2.reshape(nlat*nlon, -1, nmod-1)
    
    
    ### fit and predict
    # regression means
    beta = lr_pred(x1, y1)
    
    y1h = np.sum(x1 * beta[:,None,:], axis = 2) 
    y2h = np.sum(x2 * beta[:,None,:], axis = 2)
    y2h = y2h.reshape(nlat, nlon, -1)
    y2h = np.moveaxis(y2h, 2, 0)
    
    y2h_se = np.zeros_like(y2h)
    
    for j in trange(10, leave = False):
        error_j = pred_error(x1, y1h, y1, x2[:,(100*j):(100*(j+1))])
        y2h_se[(100*j):(100*(j+1))] = np.moveaxis(error_j.reshape(nlat, nlon, -1), 2, 0)
    
    
    # ensemble means
    x2h = np.mean(x2, axis = 2)
    x2h = x2h.reshape(nlat, nlon, -1)
    x2h = np.moveaxis(x2h, 2, 0)
    
    x2h_se = np.std(x2, axis = 2)
    x2h_se = x2h_se.reshape(nlat, nlon, -1)
    x2h_se = np.moveaxis(x2h_se, 2, 0)

    
    ### save
    np.save(f'/submit/experiments/pred/xreg_pr_{m1}.npz', y2h)
    np.save(f'/submit/experiments/pred/xreg_pr_{m1}_var.npz', y2h_se)
    
    np.save(f'/submit/experiments/pred/xbar_pr_{m1}.npz', x2h)
    np.save(f'/submit/experiments/pred/xbar_pr_{m1}_var.npz', x2h_se)

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [7]:
%%time 

###### WEA

nmod = len(xhist)
xwea_list = []

### outer loop
for m1 in trange(nmod):
    y1 = xhist[m1]
    y1 = np.moveaxis(y1, 0, 2)

    y2 = xrcp[m1]
    y2 = np.moveaxis(y2, 0, 2)

    nlat, nlon, ntest = y2.shape
    nmod = len(xhist)

    x1 = []
    for m2 in trange(nmod, leave = False):
        if m2 != m1:
            xt = np.moveaxis(xhist[m2], 0, 2)
            x1.append(resize(xt, (nlat, nlon)))
            
    sigma_d = 5.7
    sigma_s = 5.7

    di = np.array([np.mean((y1 - xi)**2, axis = 2) for xi in x1])
    sij = np.sum(np.array([[np.mean((xi - xj)**2, axis = 2) for xi in x1] for xj in x1]), axis = 0)
    w = np.exp(-di / sigma_d) / (1 + np.exp(-sij / sigma_s))
    w /= np.sum(w, axis = 0)[None,:,:]
    
    
    x2 = []
    for m2 in trange(nmod, leave = False):
        if m2 != m1:
            xt = np.moveaxis(xrcp[m2][nval:ntest], 0, 2)
            xt = resize(xt, (nlat, nlon))
            
            x2.append(xt)
    x2 = np.array(x2)
    
    xwea_hat = np.sum(w[:,:,:,None] * x2, axis = 0)
    xwea_hat = np.moveaxis(xwea_hat, 2, 0)
    
    xwea_list.append(xwea_hat)

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

CPU times: user 13min 28s, sys: 2min 5s, total: 15min 34s
Wall time: 14min 22s


In [8]:
for t in trange(nmod):
    np.save(f'/submit/experiments/pred/xwea_pr_{t}.npz', xwea_list[t])

  0%|          | 0/16 [00:00<?, ?it/s]