In [1]:
import os
import gc
import math

# numpy
import numpy as np
import numpy.ma as ma

from jax import vmap, grad, jit, random
from jax.config import config
from jax.experimental import optimizers
import jax.numpy as jnp
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves

config.update("jax_enable_x64", True)

from skimage.restoration import estimate_sigma
from skimage.transform import resize

### CRPS
import properscoring as ps

### SSIM and PSNR
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# plotting
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import ticker
from matplotlib.ticker import FormatStrFormatter
plt.style.use('default')

from matplotlib.gridspec import GridSpec
from matplotlib import colors
from cartopy import config
import cartopy.crs as ccrs
import cartopy.feature as cfeature

# netCDF
import netCDF4 as nc

from tqdm.notebook import tqdm
from tqdm.notebook import trange

import pickle

In [2]:
# Load the historical and RCP model series; later cells index them per model
# (xhist[i], xrcp[k]) for i/k in range(nmod).
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
# `with` closes the file handles, which the original open(...) calls leaked.
with open('../data/saved/xhist_pr.pkl', 'rb') as fh:
    xhist = pickle.load(fh)
with open('../data/saved/xrcp_pr.pkl', 'rb') as fh:
    xrcp = pickle.load(fh)

In [3]:
#### construct training set
# For every model: append the first `nval` RCP steps to the historical series,
# regrid onto the (nlat, nlon) grid and cache the result as ./pr_rescaled/model_XX.npy.
nmod = 16
nval = 9
nfit = 789
nlat = 721
nlon = 1440

# np.save does not create missing directories.
os.makedirs('./pr_rescaled', exist_ok = True)

for i in trange(nmod):

    # Join along the leading (presumably time) axis.
    x = np.vstack([xhist[i], xrcp[i][0:nval]])
    # With a 2-D output shape, skimage's resize keeps the trailing axis, so
    # move the leading axis last, regrid the two spatial axes, then move it back.
    x = np.moveaxis(x, 0, 2)
    x = resize(x, (nlat, nlon))
    x = np.moveaxis(x, 2, 0)

    np.save(f'./pr_rescaled/model_{i:02}', x)

  0%|          | 0/16 [00:00<?, ?it/s]

In [2]:
def lm_beta(x, y):
    """Ordinary least-squares coefficients of ``y`` regressed on ``x``.

    No intercept column is added; include one in ``x`` if needed.

    Parameters
    ----------
    x : 2-D array, shape (n, p)
        Design matrix (in this notebook: time steps x models for one grid cell).
    y : array, shape (n,) or (n, k)
        Response(s).

    Returns
    -------
    Array of shape (p,) or (p, k) solving ``(x.T @ x) beta = x.T @ y``.
    """
    xtx = jnp.matmul(x.T, x)
    xty = jnp.matmul(x.T, y)
    # Solving the normal equations directly is cheaper and numerically more
    # stable than forming inv(x.T @ x) and multiplying.
    return jnp.linalg.solve(xtx, xty)

lm_beta = jit(lm_beta)

def prediction_variance(xtrain, xtest, sigma_field):
    """Per-test-row prediction spread ``sigma * sqrt(1 + x_t^T (X^T X)^{-1} x_t)``.

    Despite the name, the result scales with ``sigma`` (not ``sigma**2``) and
    includes the square root, i.e. it is a standard-deviation-like quantity.

    Parameters
    ----------
    xtrain : (n_train, p) training design matrix for one grid cell.
    xtest : (n_test, p) test design matrix for the same cell.
    sigma_field : scalar residual spread for the cell.

    Returns
    -------
    (n_test,) array.
    """
    xtx = jnp.matmul(xtrain.T, xtrain)
    # solve() gives (X^T X)^{-1} xtest^T without an explicit inverse; the
    # row-wise product then yields diag(xtest (X^T X)^{-1} xtest^T) without
    # building the full (n_test, n_test) matrix.
    leverage = jnp.sum(xtest * jnp.linalg.solve(xtx, xtest.T).T, axis = 1)
    return sigma_field * jnp.sqrt(1 + leverage)

# Map over grid cells. Outer vmap: axis 2 of xtrain/xtest and axis 1 of sigma.
# Inner vmap: axis 1 of the remaining xtrain/xtest and axis 0 of sigma.
# For (time, lat, lon, model) inputs and a (lat, lon) sigma, the output is
# (lon, lat, n_test).
prediction_variance = jit(vmap(vmap(prediction_variance, (1, 1, 0)), (2, 2, 1)))

In [3]:
# Ensemble / grid sizes shared by the fitting cells below.
nmod, nval, nfit = 16, 9, 789
nlat, nlon = 721, 1440

# Observed field; only its first `nfit` leading-axis steps are used for fitting.
rea = np.load('../data/saved/pr_obs.npz.npy')[:nfit]

# Estimate betas (per-grid-cell least-squares weights of the models against observations)

In [4]:
nmod = 16
beta = []

# np.save does not create missing directories.
os.makedirs('./pr_model_coef', exist_ok = True)

# Fit every grid cell of one latitude row in a single call instead of
# dispatching the jitted solver once per pixel.
lm_beta_row = jit(vmap(lm_beta))

for s in trange(2):
    for t in trange(2):
        # Tile (s, t): s selects the latitude half (s == 1 also takes the
        # 721st row), t selects the longitude half.
        lat_slice = slice(360 * s, 360 * (s + 1) + s)
        lon_slice = slice(720 * t, 720 * (t + 1))

        #### load training set
        x = []
        for i in trange(nmod):
            xi = np.load(f'./pr_rescaled/model_{i:02}.npy', mmap_mode = 'r')
            x.append(np.array(xi[:, lat_slice, lon_slice]))

        # (time, lat, lon, model) -> (lat, lon, time, model)
        x2 = np.stack(x, axis = 3)
        x2 = np.moveaxis(x2, (1, 2), (0, 1))
        del x

        nfit = x2.shape[2]

        # (time, lat, lon) -> (lat, lon, time, 1)
        y2 = rea[0:nfit, lat_slice, lon_slice]
        y2 = np.moveaxis(y2, 0, 2)
        y2 = y2.reshape(y2.shape + (1,))

        beta_st = np.zeros((360 + s, 720, nmod, 1))
        for i in trange(360 + s):
            beta_st[i] = lm_beta_row(jnp.array(x2[i]), jnp.array(y2[i]))

        # Tile index 2*s + t matches the stitching order used when the tiles
        # are reloaded. The original used the leftover row counter `i`, so
        # tiles overwrote each other as beta_359 / beta_360.
        tile_idx = 2 * s + t
        np.save(f'./pr_model_coef/beta_{tile_idx}', beta_st)
        beta.append(beta_st)

        # Release this tile's arrays before the next tile is loaded.
        del x2, y2

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/360 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/360 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/361 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/361 [00:00<?, ?it/s]

In [5]:
# Write each coefficient tile to disk under its position in `beta`.
for tile_idx in range(len(beta)):
    np.save(f'./pr_model_coef/beta_{tile_idx}', beta[tile_idx])

In [9]:
nmod = 16
nval = 9
nfit = 789
nlat = 721
nlon = 1440


beta = [np.load(f'./pr_model_coef/beta_{i}.npy') for i in trange(4)]
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
# `with` closes the handle that the original open(...) call leaked.
with open('../data/saved/xrcp_pr.pkl', 'rb') as fh:
    xrcp = pickle.load(fh)
# Read-only memory map: slices taken later are read from disk on demand.
rea = np.load('../data/saved/pr_obs.npz.npy', mmap_mode = 'r')

  0%|          | 0/4 [00:00<?, ?it/s]

In [10]:
# Stitch the four coefficient tiles into one grid: beta[0] | beta[1] on top,
# beta[2] | beta[3] below, then drop the trailing singleton axis.
beta = np.concatenate(
    [np.concatenate(beta[0:2], axis = 1),
     np.concatenate(beta[2:4], axis = 1)],
    axis = 0,
)[..., 0]

In [None]:
# Reload the coefficient tiles and stitch them into one (nlat, nlon, nmod) array:
# tiles 0 | 1 form the top half, tiles 2 | 3 the bottom half.
beta = [np.load(f'./pr_model_coef/beta_{i}.npy') for i in trange(4)]
beta = np.concatenate(
    [np.concatenate([beta[0], beta[1]], axis = 1),
     np.concatenate([beta[2], beta[3]], axis = 1)],
    axis = 0,
).reshape(nlat, nlon, nmod)

# LM predictions (per-grid-cell linear-model predictions and their spread)

In [None]:
lm_pred = []
lm_predvar = []

ntest = 70   # RCP steps predicted, starting right after the nval used in training

m = 0
for s in trange(4):
    for t in trange(4):

        # The last latitude band also takes the 721st row.
        leap = 1 if s == 3 else 0
        lat_slice = slice(180 * s, 180 * (s + 1) + leap)
        lon_slice = slice(360 * t, 360 * (t + 1))

        #### load training set: (time, lat, lon, model)
        xtrain = []
        for i in trange(nmod):
            xtrain_i = np.load(f'./pr_rescaled/model_{i:02}.npy', mmap_mode = 'r')
            xtrain.append(np.array(xtrain_i[:, lat_slice, lon_slice]))
        xtrain = np.stack(xtrain, axis = 3)

        #### test set: regrid the RCP steps exactly like the training data
        xtest = []
        for k in trange(nmod, leave = False):
            xtest_k = xrcp[k][nval:(nval + ntest)]
            xtest_k = np.moveaxis(xtest_k, 0, 2)
            xtest_k = resize(xtest_k, (nlat, nlon))
            xtest_k = np.moveaxis(xtest_k, 2, 0)
            xtest.append(xtest_k[:, lat_slice, lon_slice])
        xtest = np.stack(xtest, axis = 3)

        xtrain = jnp.array(xtrain)
        xtest = jnp.array(xtest)

        beta_k = beta[lat_slice, lon_slice]

        ### fit training to compute the residual spread sigma
        yhat = np.sum(xtrain * beta_k[None, :, :, :], axis = 3)
        ytrain = np.array(rea[0:nfit, lat_slice, lon_slice])
        sigma = np.std(ytrain - yhat, axis = 0)

        ### predict the test period. The original appended the in-sample
        ### training fit `yhat`, which did not match the test-period lm_predvar.
        ypred = np.sum(xtest * beta_k[None, :, :, :], axis = 3)

        ### prediction spread; (lon, lat, time) -> (time, lat, lon)
        pred_var = prediction_variance(xtrain, xtest, sigma)
        pred_var = np.moveaxis(pred_var, (2, 1, 0), (0, 1, 2))

        lm_pred.append(ypred)
        lm_predvar.append(pred_var)

        ### cleanup
        m = m + 1
        del xtrain, xtest

In [None]:
def stitch_tiles(tiles, n_lat_tiles = 4, n_lon_tiles = 4):
    """Reassemble (time, lat_tile, lon_tile) tiles appended row-major.

    The latitude band is the outer loop and the longitude band the inner one.
    Longitude bands are joined along axis 2, then latitude bands along axis 1.
    """
    assert len(tiles) == n_lat_tiles * n_lon_tiles, "unexpected number of tiles"
    rows = [np.concatenate(tiles[r * n_lon_tiles:(r + 1) * n_lon_tiles], axis = 2)
            for r in range(n_lat_tiles)]
    return np.concatenate(rows, axis = 1)

# The tiles are spatial, so joining them along axis 0 (time), as before, was
# wrong; it also raised because the last band has 181 latitude rows, not 180.
lm_pred = stitch_tiles(lm_pred)
lm_predvar = stitch_tiles(lm_predvar)
assert lm_pred.shape[1:] == (nlat, nlon) and lm_predvar.shape[1:] == (nlat, nlon)

# np.save does not create missing directories.
os.makedirs('./pr_predictions', exist_ok = True)
np.save('./pr_predictions/lm_pred', lm_pred)
np.save('./pr_predictions/lm_predvar', lm_predvar)

# EA predictions (multi-model ensemble mean and standard deviation)

In [None]:
chunk_len = 50   # RCP steps regridded per pass (presumably to bound peak memory)
n_chunks = 21

ens_pred = []
ens_predvar = []
for i in trange(n_chunks):

    # Regrid this chunk of every model's RCP series onto the (nlat, nlon) grid.
    xtest = []
    for k in trange(nmod, leave = False):
        xtest_k = xrcp[k][(i * chunk_len):((i + 1) * chunk_len)]
        xtest_k = np.moveaxis(xtest_k, 0, 2)
        xtest_k = resize(xtest_k, (nlat, nlon))
        xtest_k = np.moveaxis(xtest_k, 2, 0)
        xtest.append(xtest_k)

    # Stack models on the last axis; keep the ensemble mean and standard
    # deviation. Note that "predvar" holds a standard deviation, not a variance.
    xtest = np.stack(xtest, axis = 3)
    ens_pred.append(np.mean(xtest, axis = 3))
    ens_predvar.append(np.std(xtest, axis = 3))

    # Drop this chunk before the next one is built, so two chunks never coexist.
    del xtest

In [None]:
# Each chunk covers consecutive leading-axis steps, so joining along axis 0
# reassembles the full series.
ens_pred, ens_predvar = (np.concatenate(chunks, axis = 0) for chunks in (ens_pred, ens_predvar))

np.save('./pr_predictions/ens_pred', ens_pred)
np.save('./pr_predictions/ens_predvar', ens_predvar)