In [1]:
import os
import gc
import math

# numpy
import numpy as np
import numpy.ma as ma

# from jax
import jax
from jax import vmap, grad, jit, random, nn
from jax.config import config
from jax.experimental import optimizers
import jax.numpy as jnp
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves

import haiku as hk
import optax

config.update("jax_enable_x64", True)

from skimage.restoration import estimate_sigma
from skimage.transform import resize

# netCDF
import netCDF4 as nc

from tqdm.notebook import tqdm
from tqdm.notebook import trange

import pickle



In [2]:
def _convnet_forward(x):
    """Forward pass of the fully-convolutional regression network.

    Three 3x3 convolutions with 128, 64 and 32 filters, each followed by a ReLU,
    then a final 3x3 convolution projecting to a single output channel.
    Input is presumably channels-last (batch, lat, lon, n_input_models), as built
    by the experiment loop — TODO confirm for other callers.
    """
    layers = []
    for n_filters in (128, 64, 32):
        layers.extend([hk.Conv2D(n_filters, 3), nn.relu])
    layers.append(hk.Conv2D(1, 3))
    return hk.Sequential(layers)(x)


# Pure (init, apply) pair; the network uses no randomness in apply, so the rng argument is dropped.
convnet = hk.without_apply_rng(hk.transform(_convnet_forward))

def rescale(x):
    """Min-max scale ``x`` to [0, 1] independently per index of every axis except 1 and 2.

    For a 4-D array this normalises each (axis-0, axis-3) slice over axes 1 and 2,
    exactly as before. ``keepdims`` makes the broadcasting work for 3-D input too.
    Slices whose max equals their min are mapped to 0 instead of producing NaN
    from a 0/0 division.

    Returns a new floating-point array with the same shape as ``x``.
    """
    xmin = np.min(x, axis=(1, 2), keepdims=True)
    xmax = np.max(x, axis=(1, 2), keepdims=True)
    span = xmax - xmin
    # Where the range is zero the numerator is exactly zero as well, so dividing
    # by 1 there yields 0 without a divide-by-zero warning.
    return (x - xmin) / np.where(span > 0, span, 1)

def loss(params, x, y):
    """Mean squared error between the network prediction for ``x`` and target ``y``.

    All size-1 axes are squeezed from both sides. In particular this drops the
    single output channel produced by the last conv layer, so the prediction
    lines up with a channel-less target.
    """
    prediction = jnp.squeeze(convnet.apply(params, x))
    target = jnp.squeeze(y)
    return jnp.mean((prediction - target) ** 2)


# Compiled gradient of the loss with respect to the network parameters.
grad_loss = jax.jit(jax.grad(loss))

def _update_step_impl(params, opt_state, x, y, opt_update_fn):
    """Traced body of one optimizer step: gradient, optimizer transform, apply."""
    grads = grad_loss(params, x, y)
    updates, new_opt_state = opt_update_fn(grads, opt_state)
    return optax.apply_updates(params, updates), new_opt_state


# The optimizer's update function is a *static* argument, so it is part of jit's
# cache key. A freshly constructed optimizer therefore triggers a retrace instead
# of silently reusing a trace that closed over an older optimizer.
_update_step = jit(_update_step_impl, static_argnums=(4,))


def update(params, opt_state, x, y, opt_update_fn=None):
    """Run one optimisation step on the batch ``(x, y)``.

    Args:
        params: current network parameters.
        opt_state: optimizer state matching ``opt_update_fn``.
        x, y: input batch and target batch.
        opt_update_fn: optax update function. Defaults to the global
            ``opt_update`` that is current *at call time*, as set in the
            experiment loop; it is no longer frozen at first trace.

    Returns:
        ``(new_params, new_opt_state)``.

    Raises:
        NameError: if no ``opt_update_fn`` is given and no global ``opt_update``
            has been defined yet.
    """
    if opt_update_fn is None:
        opt_update_fn = opt_update
    return _update_step(params, opt_state, x, y, opt_update_fn)

In [3]:
# One entry per climate model; each entry is unpacked below as (time, lat, lon).
# SECURITY: pickle.load can execute arbitrary code — only load trusted files.
with open('../data/saved/xhist_tas.pkl', 'rb') as fh:
    xhist = pickle.load(fh)
with open('../data/saved/xrcp_tas.pkl', 'rb') as fh:
    xrcp = pickle.load(fh)

# Number of leading RCP time steps added to the training set; the remaining RCP
# steps are held out for prediction. (Presumably 72 months = 6 years — confirm.)
nval = 72
nmod = len(xhist)

In [1]:
### run experiments
# Leave-one-model-out: for each target model m1, a CNN is trained to map the
# other models' fields (regridded to m1's grid and stacked as input channels)
# onto m1's own field. Training covers the historical run plus the first `nval`
# RCP time steps; predictions are saved for the remaining RCP time steps.

nmod = len(xhist)
ntrain = xhist[0].shape[0]
# NOTE(review): model 0's RCP length is used as the test end index for every model.
ntest = xrcp[0].shape[0]


def _stack_other_models(fields_for, target, nlat, nlon):
    """Regrid every model except ``target`` to (nlat, nlon) and stack them as channels.

    fields_for: callable mapping a model index to its (time, lat, lon) array.
    Returns an array of shape (time, nlat, nlon, nmod - 1).
    """
    regridded = [
        # time is moved last so the (nlat, nlon) target applies to the spatial axes
        resize(np.moveaxis(fields_for(m2), 0, 2), (nlat, nlon))
        for m2 in range(nmod)
        if m2 != target
    ]
    # (model, lat, lon, time) -> (time, lat, lon, model)
    return np.moveaxis(np.array(regridded), (0, 3), (3, 0))


for m1 in trange(nmod):

    _, nlat, nlon = xhist[m1].shape

    #### construct training set: historical run + first nval RCP steps
    xtrain = _stack_other_models(
        lambda m: np.vstack([xhist[m], xrcp[m][0:nval]]), m1, nlat, nlon
    )
    ytrain = jnp.array(np.vstack([xhist[m1], xrcp[m1][0:nval]]))

    #### construct testing set: remaining RCP steps
    xtest = _stack_other_models(lambda m: xrcp[m][nval:ntest], m1, nlat, nlon)
    ytest = xrcp[m1][nval:ntest]

    ### train
    xtrain = jnp.array(xtrain)

    # Remove the mean over time and input models at each grid point from inputs
    # and target alike; it is added back to the predictions below.
    xmean = np.mean(xtrain, axis=(0, 3))
    xtrain = xtrain - xmean[None, :, :, None]
    ytrain = ytrain - xmean[None, :, :]
    xtest = xtest - xmean[None, :, :, None]

    key = random.PRNGKey(0)
    params = convnet.init(key, xtrain[0:1])

    bsize = 128
    nepoch = 400
    # Ceil division: `int(n / bsize) + 1` produced an empty trailing batch whenever
    # n was a multiple of bsize, and that empty step still advanced Adam's state.
    nbatch = math.ceil(ytrain.shape[0] / bsize)

    ### init opt
    # `update` uses this global `opt_update` by default.
    opt_init, opt_update = optax.adam(1e-4)
    opt_state = opt_init(params)

    for epoch in trange(nepoch):
        for i in trange(nbatch, leave=False):
            xi = xtrain[(i * bsize):((i + 1) * bsize)]
            yi = ytrain[(i * bsize):((i + 1) * bsize)]
            params, opt_state = update(params, opt_state, xi, yi)

    ### predict the held-out RCP steps in chunks
    pred_bsize = 100
    npred = math.ceil(ytest.shape[0] / pred_bsize)
    yhat = np.concatenate([
        convnet.apply(params, xtest[(i * pred_bsize):((i + 1) * pred_bsize)])
        for i in range(npred)
    ])
    # Drop only the single output channel; squeeze() would also drop a length-1 time axis.
    yhat = yhat[..., 0]
    yhat += xmean[None, :, :]

    # NOTE(review): np.save appends ".npy" to names that do not already end in it, so
    # this writes "cnn_tas_{m1}.npz.npy" (a plain .npy array, not an .npz archive).
    # The path is left unchanged so downstream readers still find the files.
    np.save(f'../experiments/tas_predictions/cnn_tas_{m1}.npz', yhat)