In [1]:
# GPU memory allocation 
%env XLA_PYTHON_CLIENT_MEM_FRACTION=.80

import jax
from jax import numpy as jnp
import numpy as np

import flax
from flax import serialization
from flax import linen as nn
from flax.training import train_state
import optax

import tensorflow_datasets as tfds
import tensorflow as tf
tf.config.experimental.set_visible_devices([], "GPU")

import sys
sys.path.append('../..')

from bvex_dl import *
from namelist_dl import *
import time as Time
from model import PeriodicCNN

env: XLA_PYTHON_CLIENT_MEM_FRACTION=.80


In [3]:
def create_train_state(rng, learning_rate):
    """Create initial `TrainState`."""
    cnn = PeriodicCNN()
    params = cnn.init(rng, jnp.ones([1, 64, 64, 3]))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx)


def get_datasets(batch_size):
    ds_builder = tfds.builder('highres_forcing_long')
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train[:95%]', batch_size=batch_size, shuffle_files=True))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='train[95%:]', batch_size=batch_size, shuffle_files=True))
    return train_ds, test_ds

In [4]:
@jax.jit
def train_step(state, batch):
    def accumulate_mse(params):
        
        batch['states'] = jnp.transpose(batch['states'], axes=[0,1,3,2,4])
        
        # obtain the initial state
        stateNow = batch['states'][:,0,:,:,:]
        qNow = batch['states'][:,0,:,:,0]
        tNow = batch['time'][:, 0]
        
        qNew, tNew= vetdrk4(qNow, tNow)
        pNew, _, _ = laplacian(qNew)
        fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
        stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
        
        correction = PeriodicCNN().apply({'params': params}, stateNew).squeeze()
        
        qNewCorrected = qNew + correction.squeeze()
        pNewCorrected, _, _ = laplacian(qNewCorrected)
        
        statesCorrected = jnp.stack((qNewCorrected, pNewCorrected), axis=-1)
        
        loss = jnp.mean(optax.l2_loss(statesCorrected, batch['states'][:, 1, :, :, :2]))
                       
        for i in range(1, n_look_ahead - 1):
            
            qNew, tNew = vetdrk4(qNewCorrected, tNew)
            pNew, _, _ = laplacian(qNew)
            fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
            stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
            
            correction = PeriodicCNN().apply({'params': params}, stateNew).squeeze()
            
            qNewCorrected = qNew + correction.squeeze()
            pNewCorrected, _, _ = laplacian(qNewCorrected)

            statesCorrected = jnp.stack((qNewCorrected, pNewCorrected), axis=-1)
            
            loss += weight ** i * jnp.mean(optax.l2_loss(statesCorrected, batch['states'][:, i+1, :, :, :2]))
        
        return loss, statesCorrected


    grad_fn = jax.value_and_grad(accumulate_mse, has_aux=True)
    (total_mse, final_state), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    final_mse = jnp.mean(optax.l2_loss(final_state, batch['states'][:, n_look_ahead, :, :, :2]))
    
    metrics = {
        "total_mse": total_mse,
        "final_mse": final_mse
    }
    
    return state, metrics

In [5]:
@jax.jit
def eval_step(params, batch):
    def accumulate_mse(params):
        
        batch['states'] = jnp.transpose(batch['states'], axes=[0,1,3,2,4])
        
        stateNow = batch['states'][:,0,:,:,:]
        qNow = batch['states'][:,0,:,:,0]
        tNow = batch['time'][:, 0]
        
        qNew, tNew= vetdrk4(qNow, tNow)
        pNew, _, _ = laplacian(qNew)
        fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
        stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
        
        correction = PeriodicCNN().apply({'params': params}, stateNew).squeeze()
        
        qNewCorrected = qNew + correction.squeeze()
        pNewCorrected, _, _ = laplacian(qNewCorrected)
                
        statesCorrected = jnp.stack((qNewCorrected, pNewCorrected), axis=-1)
        
        loss = jnp.mean(optax.l2_loss(statesCorrected, batch['states'][:, 1, :, :, :2]))
        
        for i in range(1, n_look_ahead - 1):
            
            qNew, tNew = vetdrk4(qNewCorrected, tNew)
            pNew, _, _ = laplacian(qNew)
            fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
            stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
        
            correction = PeriodicCNN().apply({'params': params}, stateNew).squeeze()
            
            qNewCorrected = qNew + correction.squeeze()
            pNewCorrected, _, _ = laplacian(qNewCorrected)
            
            statesCorrected = jnp.stack((qNewCorrected, pNewCorrected), axis=-1)
            
            loss += weight ** i * jnp.mean(optax.l2_loss(statesCorrected, batch['states'][:, i+1, :, :, :2]))
            
        
        return loss, statesCorrected


    total_mse, final_states = accumulate_mse(params)
    final_mse = jnp.mean(optax.l2_loss(final_states, batch['states'][:, 16, :, :, :2]))
    
    del final_states
    
    
    metrics = {
        "total_mse": total_mse,
        "final_mse": final_mse

    }
    
    del total_mse
    del final_mse
    
    return metrics

In [6]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
    
    batch_metrics = []
    # 4 cycles per epoch
  
    
    for batch in train_ds:
        state, metrics = train_step(state, batch)

        batch_metrics.append(metrics)
        
    training_batch_metrics = jax.device_get(batch_metrics)
    
    training_batch_metrics = training_batch_metrics[:-1]
    
    training_epoch_metrics = {
        k: np.mean([metrics[k] for metrics in training_batch_metrics])
        for k in training_batch_metrics[0]
    }
    
    print('Training - epoch: %d, final step mse: %.4f, total mse: %.4f' 
          % (epoch,
             training_epoch_metrics['final_mse'],
             training_epoch_metrics['total_mse']))
    
    return state


def eval_model(params, test_ds):
    
    batch_metrics = []
    for batch in test_ds:
        metrics = eval_step(params, batch)
        batch_metrics.append(metrics)
        
    testing_batch_metrics = jax.device_get(batch_metrics)
    
    testing_batch_metrics = testing_batch_metrics[:-1]
    
    testing_metrics = {
        k: np.mean([metrics[k] for metrics in testing_batch_metrics])
        for k in testing_batch_metrics[0]
    }

    
    print('Testing - epoch: %d,  final step mse: %.4f, total mse: %.4f'  
          % (epoch,
             testing_metrics['final_mse'],
             testing_metrics['total_mse']))
    
    return testing_metrics

In [7]:
rng = jax.random.PRNGKey(2021)
rng, init_rng = jax.random.split(rng)

# learning rate scheduler 
schedule = optax.exponential_decay(init_value=0.001, 
                                   transition_steps=10,
                                   decay_rate=0.8,
                                   transition_begin=10,
                                   staircase=True, 
                                    end_value=0.0002)

#create train state
learning_rate=schedule
state = create_train_state(init_rng, learning_rate)

num_epochs = 100
batch_size = 8

#load data
train_ds, test_ds = get_datasets(batch_size)
weight = 1

In [8]:
for epoch in range(1, num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    start = Time.time()
    state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
    end = Time.time()
    print(f'Time - {end - start}')
    eval_model(state.params, test_ds)
    dict_output = serialization.to_state_dict(state.params)
    np.save(f"checkpoint/CNN{n_look_ahead}_state_dict_epoch_{epoch}.npy", dict_output)

Training - epoch: 1, final step mse: 0.0067, total mse: 0.0067
Time - 25.244832515716553
Testing - epoch: 1,  final step mse: 0.2234, total mse: 0.0047
Training - epoch: 2, final step mse: 0.0044, total mse: 0.0044
Time - 22.484177350997925
Testing - epoch: 2,  final step mse: 0.2240, total mse: 0.0041
Training - epoch: 3, final step mse: 0.0038, total mse: 0.0038
Time - 22.57449245452881
Testing - epoch: 3,  final step mse: 0.2241, total mse: 0.0037
Training - epoch: 4, final step mse: 0.0034, total mse: 0.0034
Time - 22.62535834312439
Testing - epoch: 4,  final step mse: 0.2251, total mse: 0.0033
Training - epoch: 5, final step mse: 0.0032, total mse: 0.0032
Time - 22.702173233032227
Testing - epoch: 5,  final step mse: 0.2260, total mse: 0.0031
Training - epoch: 6, final step mse: 0.0029, total mse: 0.0029
Time - 22.747468948364258
Testing - epoch: 6,  final step mse: 0.2254, total mse: 0.0029
Training - epoch: 7, final step mse: 0.0028, total mse: 0.0028
Time - 22.750892400741577
T