### Convolutional Block Attention Module

In [1]:
import jax
from jax import lax, random, numpy as jnp
import flax
import numpy as np
from flax import linen as nn
from flax import optim
from jax.numpy.fft import fft2, ifft2
import tensorflow_datasets as tfds
import tensorflow as tf
tf.config.experimental.set_visible_devices([], "GPU")
from flax import serialization
import sys
from model import ResNet_CBAM
sys.path.append('../..')
from bvex_dl import *
from namelist_dl import *
from model import *
tf.random.set_seed(2021)

In [2]:
@jax.vmap
def MSE(x, y):
    return jnp.mean((x-y)**2)


def get_initial_params(key, model):
    init_val = jnp.ones((2,64,64,3), jnp.float32)
    initial_params = model().init(key, init_val)['params']
    return initial_params


def create_optimizer(params, learning_rate):
    optimizer_def = optim.Adam(learning_rate=learning_rate)
    optimizer = optimizer_def.create(params)
    return optimizer


def get_datasets(batch_size):
    ds_builder = tfds.builder('highres_forcing')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train[:95%]', batch_size=batch_size, shuffle_files=True))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='train[5%:]', batch_size=batch_size, shuffle_files=True))
    return train_ds, test_ds


In [3]:
@jax.jit
def train_step(optimizer, batch):

    def accumulate_mse(params):
        
        batch['states'] = jnp.transpose(batch['states'], axes=[0,1,3,2,4])
        
        stateNow = batch['states'][:,0,:,:,:]
        qNow = batch['states'][:,0,:,:,0]
        tNow = batch['time'][:, 0]
        
        qNew, tNew= vetdrk4(qNow, tNow)
        pNew, _, _ = laplacian(qNew)
        fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
        stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
        
        del fNew
        del pNew
        
        correction = model().apply({'params': params}, stateNew)
        
        del stateNew
        
        qNewCorrected = qNew + correction.squeeze()
        pNewCorrected, _, _ = laplacian(qNewCorrected)
                
        del qNew
        
        statesCorrected = jnp.stack((qNewCorrected, pNewCorrected), axis=-1)
        
        loss = MSE(statesCorrected, batch['states'][:, 1, :, :, :2])
        
        del statesCorrected
        
        for i in range(1, n_look_ahead - 1):
            
            qNew, tNew = vetdrk4(qNewCorrected, tNew)
            pNew, _, _ = laplacian(qNew)
            fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
            stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
            
            del pNew
            del fNew
            
            correction = model().apply({'params': params}, stateNew)
            
            del stateNew
            
            qNewCorrected = qNew + correction.squeeze()
            pNewCorrected, _, _ = laplacian(qNewCorrected)
            
            del qNew
            
            statesCorrected = jnp.stack((qNewCorrected, pNewCorrected), axis=-1)

            loss += weight ** i * MSE(statesCorrected, batch['states'][:, i+1, :, :, :2])
            
        loss = jnp.mean(loss)
        
        # regularization
        return loss, statesCorrected


    grad_fn = jax.value_and_grad(accumulate_mse, has_aux=True)
    (total_mse, final_state), grad = grad_fn(optimizer.target)
    optimizer = optimizer.apply_gradient(grad)
    final_mse = MSE(final_state, batch['states'][:, -1, :, :, :2])
    
    metrics = {
        "total_mse": total_mse,
        "final_mse": final_mse
    }
    
    del final_state
    del total_mse
    del final_mse
    
    return optimizer, metrics

In [4]:
@jax.jit
def eval_step(params, batch):
    def accumulate_mse(params):
        
        batch['states'] = jnp.transpose(batch['states'], axes=[0,1,3,2,4])
        
        stateNow = batch['states'][:,0,:,:,:]
        qNow = batch['states'][:,0,:,:,0]
        tNow = batch['time'][:, 0]
        
        qNew, tNew= vetdrk4(qNow, tNow)
        pNew, _, _ = laplacian(qNew)
        fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
        stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
        
        del fNew
        del pNew
        
        correction = model().apply({'params': params}, stateNew)
        
        del stateNew
        
        qNewCorrected = qNew + correction.squeeze()
        pNewCorrected, _, _ = laplacian(qNewCorrected)
                
        del qNew
        
        statesCorrected = jnp.stack((qNewCorrected, pNewCorrected), axis=-1)
        
        loss = MSE(statesCorrected, batch['states'][:, 1, :, :, :2])
        
        del statesCorrected
        
        for i in range(1, n_look_ahead - 1):
            
            qNew, tNew = vetdrk4(qNewCorrected, tNew)
            pNew, _, _ = laplacian(qNew)
            fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
            stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
            
            del pNew
            del fNew
            
            correction = model().apply({'params': params}, stateNew)
            
            del stateNew
            
            qNewCorrected = qNew + correction.squeeze()
            pNewCorrected, _, _ = laplacian(qNewCorrected)
            
            del qNew
            
            statesCorrected = jnp.stack((qNewCorrected, pNewCorrected), axis=-1)

            loss += weight ** i * MSE(statesCorrected, batch['states'][:, i+1, :, :, :2])
            
        loss = jnp.mean(loss)
        
        return loss, statesCorrected


    total_mse, final_states = accumulate_mse(params)
    final_mse = MSE(final_states, batch['states'][:, -1, :, :, :2])
    
    del final_states
    
    
    metrics = {
        "total_mse": total_mse,
        "final_mse": final_mse

    }
    
    del total_mse
    del final_mse
    
    return metrics

In [5]:
def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
    
    batch_metrics = []
    # 4 cycles per epoch
  
    
    for batch in train_ds:
        optimizer, metrics = train_step(optimizer, batch)
        batch_metrics.append(metrics)
        
    training_batch_metrics = jax.device_get(batch_metrics)
    
    training_batch_metrics = training_batch_metrics[:-1]
    
    training_epoch_metrics = {
        k: np.mean([metrics[k] for metrics in training_batch_metrics])
        for k in training_batch_metrics[0]
    }
    
    print('Training - epoch: %d, final step mse: %.4f, total mse: %.4f' 
          % (epoch,
             training_epoch_metrics['final_mse'],
             training_epoch_metrics['total_mse']))
    
    return optimizer, training_epoch_metrics


def eval_model(model, test_ds):
    
    batch_metrics = []
    for batch in test_ds:
        metrics = eval_step(model, batch)
        batch_metrics.append(metrics)
        
    testing_batch_metrics = jax.device_get(batch_metrics)
    
    testing_batch_metrics = testing_batch_metrics[:-1]
    
    testing_metrics = {
        k: np.mean([metrics[k] for metrics in testing_batch_metrics])
        for k in testing_batch_metrics[0]
    }

    
    print('Testing - epoch: %d,  final step mse: %.4f, total mse: %.4f'  
          % (epoch,
             testing_metrics['final_mse'],
             testing_metrics['total_mse']))
    
    return testing_metrics



In [6]:
model = ResNet_CBAM
rng = jax.random.PRNGKey(2021)
rng, init_rng = jax.random.split(rng)
params = get_initial_params(init_rng, model)
learning_rate = 0.0002
num_epochs = 100
batch_size = 4
optimizer = create_optimizer(params, learning_rate=learning_rate)
train_ds, test_ds = get_datasets(batch_size)
weight = 1
for epoch in range(1, num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    optimizer, train_metrics = train_epoch(optimizer, train_ds, batch_size, epoch, input_rng)
    eval_model(optimizer.target, test_ds)
    dict_output = serialization.to_state_dict(optimizer.target)
    np.save(f"checkpoint/CBAM_state_dict_epoch_{epoch}.npy", dict_output)

Training - epoch: 1, final step mse: 0.0182, total mse: 0.2298
Testing - epoch: 1,  final step mse: 0.0142, total mse: 0.1839
Training - epoch: 2, final step mse: 0.0133, total mse: 0.1723
Testing - epoch: 2,  final step mse: 0.0123, total mse: 0.1611
Training - epoch: 3, final step mse: 0.0117, total mse: 0.1538
Testing - epoch: 3,  final step mse: 0.0111, total mse: 0.1466
Training - epoch: 4, final step mse: 0.0107, total mse: 0.1411
Testing - epoch: 4,  final step mse: 0.0101, total mse: 0.1349
Training - epoch: 5, final step mse: 0.0099, total mse: 0.1313
Testing - epoch: 5,  final step mse: 0.0092, total mse: 0.1245
Training - epoch: 6, final step mse: 0.0093, total mse: 0.1238
Testing - epoch: 6,  final step mse: 0.0088, total mse: 0.1192
Training - epoch: 7, final step mse: 0.0088, total mse: 0.1178
Testing - epoch: 7,  final step mse: 0.0085, total mse: 0.1144
Training - epoch: 8, final step mse: 0.0084, total mse: 0.1129
Testing - epoch: 8,  final step mse: 0.0081, total mse: