In [1]:
import flax
import jax
import jax.numpy as jnp
import tensorflow as tf
tf.config.experimental.set_visible_devices([], "GPU")
import tensorflow_datasets as tfds
import sys
from flax import optim
sys.path.append('/workspace/yquai/BVEX/DL/')
from bvex_dl import *
from namelist_dl import *
sys.path.append('/workspace/yquai/BVEX/DL/DL_Model/CNN/')
sys.path.append('/workspace/yquai/BVEX/DL/DL_Model/')
from model import PeriodicCNN
from flax import serialization
from copy import deepcopy
import flax.linen as nn
import optax
from flax.training import train_state
import time as Time

In [2]:
schedule = optax.exponential_decay(init_value=0.00001, transition_steps=10,decay_rate=0.8,transition_begin=0,staircase=True, end_value=0.00001)


def create_train_state(rng, learning_rate, pretrained):
    """Creates initial `TrainState`."""
    cnn = PeriodicCNN()
    
    if pretrained == None:
        params = cnn.init(rng, jnp.ones([1, 64, 64, 3]))['params']
    else:
        params = pretrained
        
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx)

def get_datasets(batch_size):
    ds_builder = tfds.builder('observation_history')
#    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train[:95%]', batch_size=batch_size, shuffle_files=True))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='train[95%:]', batch_size=batch_size, shuffle_files=True))
    return train_ds, test_ds

In [3]:

@jax.jit
def train_step(state, batch):
    def accumulate_mse(params):
        
        batch['states'] = jnp.transpose(batch['states'], axes=[0,1,3,2,4])
        
        stateNow = batch['states'][:,0,:,:,:]
        qNow = batch['states'][:,0,:,:,0]
        tNow = batch['time'][:, 0]
        
        qNew, tNew= vetdrk4(qNow, tNow)
        pNew, _, _ = laplacian(qNew)
        fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
        stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
        
        del fNew
        del pNew
        
        correction = PeriodicCNN().apply({'params': params}, stateNew).squeeze()
        

        del stateNew
        
        qNewCorrected = qNew + correction.squeeze()
                
        del qNew
        
        for i in range(1, 40):
            
            qNew, tNew = vetdrk4(qNewCorrected, tNew)
            pNew, _, _ = laplacian(qNew)
            fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
            stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
            
            del pNew
            del fNew
            
            correction = PeriodicCNN().apply({'params': params}, stateNew).squeeze()

            del stateNew
            
            qNewCorrected = qNew + correction.squeeze()
            pNewCorrected, _, _ = laplacian(qNewCorrected)
            
            del qNew
            
        statesCorrected = jnp.stack((qNewCorrected, pNewCorrected), axis=-1)      
        
        loss = jnp.mean(optax.l2_loss(statesCorrected, batch['states'][:, 1, :, :, :2]))

        for i in range(40, 80):
            
            qNew, tNew = vetdrk4(qNewCorrected, tNew)
            pNew, _, _ = laplacian(qNew)
            fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
            stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
            
            del pNew
            del fNew
            
            correction = PeriodicCNN().apply({'params': params}, stateNew).squeeze()

            del stateNew
            
            qNewCorrected = qNew + correction.squeeze()
            pNewCorrected, _, _ = laplacian(qNewCorrected)
            
            del qNew
            
        statesCorrected = jnp.stack((qNewCorrected, pNewCorrected), axis=-1)      
        
        loss += jnp.mean(optax.l2_loss(statesCorrected, batch['states'][:, 2, :, :, :2]))
        
        
        return loss, statesCorrected


    grad_fn = jax.value_and_grad(accumulate_mse, has_aux=True)
    (loss, final_state), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    
    metrics = {
        "obs_loss": loss,
    }

    
    return state, metrics


@jax.jit
def eval_step(params, batch):
    def accumulate_mse(params):
        
        batch['states'] = jnp.transpose(batch['states'], axes=[0,1,3,2,4])
        
        stateNow = batch['states'][:,0,:,:,:]
        qNow = batch['states'][:,0,:,:,0]
        tNow = batch['time'][:, 0]
        
        qNew, tNew= vetdrk4(qNow, tNow)
        pNew, _, _ = laplacian(qNew)
        fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
        stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
        
        del fNew
        del pNew
        
        correction = PeriodicCNN().apply({'params': params}, stateNew).squeeze()
        

        del stateNew
        
        qNewCorrected = qNew + correction.squeeze()
                
        del qNew
        
        for i in range(1, 40):
            
            qNew, tNew = vetdrk4(qNewCorrected, tNew)
            pNew, _, _ = laplacian(qNew)
            fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
            stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
            
            del pNew
            del fNew
            
            correction = PeriodicCNN().apply({'params': params}, stateNew).squeeze()

            del stateNew
            
            qNewCorrected = qNew + correction.squeeze()
            pNewCorrected, _, _ = laplacian(qNewCorrected)
            
            del qNew
            
        statesCorrected = jnp.stack((qNewCorrected, pNewCorrected), axis=-1)      
        
        loss = jnp.mean(optax.l2_loss(statesCorrected, batch['states'][:, 1, :, :, :2]))
        
        for i in range(40, 80):
            
            qNew, tNew = vetdrk4(qNewCorrected, tNew)
            pNew, _, _ = laplacian(qNew)
            fNew = jax.vmap(cal_forcing)(pNew, tNew)
        
            stateNew = jnp.stack((qNew, pNew, fNew), axis=-1)
            
            del pNew
            del fNew
            
            correction = PeriodicCNN().apply({'params': params}, stateNew).squeeze()

            del stateNew
            
            qNewCorrected = qNew + correction.squeeze()
            pNewCorrected, _, _ = laplacian(qNewCorrected)
            
            del qNew
            
        statesCorrected = jnp.stack((qNewCorrected, pNewCorrected), axis=-1)      
        
        loss += jnp.mean(optax.l2_loss(statesCorrected, batch['states'][:, 2, :, :, :2]))
        

        
        return loss, statesCorrected


    total_mse, final_states = accumulate_mse(params)
    
    
    metrics = {
        "obs_loss": total_mse,
    }
    
    del total_mse
    
    return metrics

def train_epoch(state, train_ds, batch_size, epoch, rng):
    
    batch_metrics = []
    # 4 cycles per epoch
  
    
    for batch in train_ds:
        optimizer, metrics = train_step(state, batch)
        batch_metrics.append(metrics)
        
    training_batch_metrics = jax.device_get(batch_metrics)
    
    training_batch_metrics = training_batch_metrics[:-1]
    
    training_epoch_metrics = {
        k: np.mean([metrics[k] for metrics in training_batch_metrics])
        for k in training_batch_metrics[0]
    }
    
    print('Training - epoch: %d, obs loss: %.4f,' 
          % (epoch,
             training_epoch_metrics['obs_loss']))
    
    return optimizer, training_epoch_metrics


def eval_model(params, test_ds):
    
    batch_metrics = []
    for batch in test_ds:
        metrics = eval_step(params, batch)
        batch_metrics.append(metrics)
        
    testing_batch_metrics = jax.device_get(batch_metrics)
    
    testing_batch_metrics = testing_batch_metrics[:-1]
    
    testing_metrics = {
        k: np.mean([metrics[k] for metrics in testing_batch_metrics])
        for k in testing_batch_metrics[0]
    }

    
    print('Testing - epoch: %d,  obs_loss: %.4f'  
          % (epoch,
             testing_metrics['obs_loss'],
))
    
    return testing_metrics


In [4]:
model = PeriodicCNN
loaded = np.load('/workspace/yquai/BVEX/DL/DL_Model/CNN/checkpoint/CNN16_state_dict_epoch_100.npy',allow_pickle=True).item()
rng = jax.random.PRNGKey(2021)
rng, init_rng = jax.random.split(rng)
learning_rate = schedule
state = create_train_state(init_rng, learning_rate, loaded)

In [5]:
num_epochs = 100
batch_size = 4
train_ds, test_ds = get_datasets(batch_size)
weight= 0.0
n_observation = 3

In [None]:
for epoch in range(1, num_epochs + 1):
    
    rng, input_rng = jax.random.split(rng)
    start = Time.time()
    state, train_metrics = train_epoch(state, train_ds, batch_size, epoch, input_rng)
    end = Time.time()
    print(f'Time - {end - start}, LR - {schedule(epoch)}')
    eval_model(state.params, test_ds)
    dict_output = serialization.to_state_dict(state.params)
    np.save(f"checkpoint/CNN16_o_{n_observation}_epoch_{epoch}.npy", dict_output)

Training - epoch: 1, obs loss: 0.0460,
Time - 1320.9103543758392, LR - 9.999999747378752e-06
