In [None]:
notebook_path = "Projects/QuantumFlow/notebooks"
# On Google Colab: mount Google Drive and move into this notebook's folder so
# that relative paths such as "../data" resolve. Elsewhere: stay where we are.
import os
try:
    from google.colab import drive
except ImportError:
    pass  # not running on Colab; keep the current working directory
else:
    try:
        drive.mount('/content/gdrive')
        os.chdir(os.path.join("/content/gdrive/My Drive", notebook_path))
    except Exception as err:
        # Best effort, but say so: later cells rely on relative paths.
        print("WARNING: could not switch to the notebook folder on Drive:", err)

%tensorflow_version 2.x
!pip install -q ruamel.yaml
!pip install -q tensorflow-addons

import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import time
%matplotlib inline

import ipywidgets as widgets
from IPython.display import display

import sys
sys.path.append('../')

from quantumflow.colab_utils import load_hyperparameters, test_colab_devices, get_resolver, QFDataset

has_gpu, has_tpu = test_colab_devices()
if has_gpu: print("Found GPU")
if has_tpu: print("Found TPU")

data_dir = "../data"
%load_ext tensorboard

In [None]:
data_dir = "../data"
experiment = 'ke_cnn'

# base_dir holds the local artefacts of this experiment (config file and the
# final copy of each trained run); log_dir receives TensorBoard logs and
# SavedModel exports: a GCS bucket on TPU runtimes, a local folder otherwise.
base_dir = os.path.join(data_dir, experiment)
log_dir = "gs://quantumflow/" + experiment if has_tpu else '/home/' + experiment
# exist_ok avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(base_dir, exist_ok=True)
file_hyperparams = os.path.join(base_dir, "hyperparams.config")

In [None]:
%%writefile $file_hyperparams
# Hyper-parameters for this notebook's experiment. The training cell loads one
# top-level section per entry of `training_run_names` with
# load_hyperparameters(file_hyperparams, run_name=..., globals=globals()).

# Notebook names referenced by value below. Presumably load_hyperparameters
# resolves them against the notebook's globals() (model_fn / base_model_fn are
# later called as functions) — TODO confirm in quantumflow.colab_utils.
globals: [ConvNNKineticEnergyFunctional, DerivConvNNKineticEnergyFunctional, has_tpu]

# The 'default' run. The &DEFAULT anchor (and &DEFAULT_MODEL / &DEFAULT_LOSS
# below) let additional runs reuse these settings via YAML aliases.
default: &DEFAULT
    # Dataset files relative to data_dir; the training cell appends '.pkl'.
    dataset_train: recreate/dataset_paper
    dataset_validate: recreate/dataset_validate
    # Not read directly by the visible notebook code; presumably consumed by
    # QFDataset or load_hyperparameters — TODO confirm.
    N: 1
    seed: 0
    dtype: float32

    # model_fn(params) builds the model that is trained and saved; it builds
    # the underlying Keras network with base_model_fn(params).
    model_fn: DerivConvNNKineticEnergyFunctional
    base_model_fn: ConvNNKineticEnergyFunctional

    # Conv1D stack: one layer per entry of `filters` / `kernel_size`.
    model: &DEFAULT_MODEL
        filters: [32, 32, 32, 32, 32]
        kernel_size: [100, 100, 100, 100, 100]
        padding: valid
        activation: softplus
        # Adds an L2 kernel penalty with coefficient loss_weights.regularization.
        l2_regularization: True
        # Initialise the output Dense bias to params['targets_mean']['kinetic_energy']
        # (presumably set by QFDataset(set_mean=True) on the training set).
        bias_mean_initialization: True

    # Keys of the feature / target dicts, presumably used by QFDataset to
    # select the corresponding arrays — TODO confirm.
    features: ['density']
    targets: ['kinetic_energy', 'derivative']

    # Validation metrics: tf.keras.metrics class names, one per target.
    eval_metrics:
        kinetic_energy: MeanAbsoluteError
        derivative: MeanAbsoluteError

    # Training losses: tf.keras.losses class names, one per target.
    loss: &DEFAULT_LOSS
        kinetic_energy: MeanSquaredError
        derivative: MeanSquaredError

    # Multipliers for the per-target losses. `regularization` is instead the
    # coefficient passed to tf.keras.regularizers.l2.
    loss_weights:
        regularization: 0.00002
        kinetic_energy: 1.0
        derivative: 0.2

    # The training set is re-shuffled with this buffer size at every epoch.
    shuffle: True
    shuffle_buffer_size: 100
    # Not read directly by the visible notebook code — TODO confirm where it is used.
    drop_remainder: False
    
    eval_batch_size: 250
    train_batch_size: 10

    # Total number of training epochs and the validation interval (in epochs).
    train_epochs: 10000
    train_epochs_per_eval: 1000

    # Interval for TensorBoard training summaries. NOTE(review): confirm in the
    # training cell whether it is counted in epochs or optimizer steps.
    save_summary_epochs: 100
    # A SavedModel is exported every this many epochs.
    save_model_epochs: 2000

    # Which training summaries the training cell may write to TensorBoard.
    summaries: [losses, metrics, learning_rate, examples_per_sec, global_norm]

    # Looked up in tf.keras.optimizers first, then in tensorflow_addons.optimizers.
    optimizer: RectifiedAdam
    # Wrap the optimizer in tfa.optimizers.Lookahead (extra args: lookahead_kwargs).
    optimizer_lookahead: True
    
    # Gradients are clipped with tf.clip_by_global_norm when this is > 0.
    gradient_clip_norm: 100.0

    optimizer_kwargs:
        # Name of a tf.keras.optimizers.schedules class, built from
        # learning_rate_kwargs. decay_steps presumably counts optimizer steps
        # (batches), not epochs.
        learning_rate: ExponentialDecay
        learning_rate_kwargs:
            initial_learning_rate: 0.002
            decay_steps: 10000
            decay_rate: 0.5
            staircase: True


In [None]:
def ConvNNKineticEnergyFunctional(params):
    """Build a Keras CNN mapping a discretised density to a kinetic energy.

    Architecture: density -> append channel axis -> Conv1D stack (one layer per
    entry of params['model']['filters'] / params['model']['kernel_size']) ->
    Flatten -> Dense(1) -> one scalar per sample.

    Args:
        params: hyper-parameter dict. Reads 'features_shape' (presumably filled
            in by QFDataset(set_shapes=True) — TODO confirm) and 'model'. It also reads
            'loss_weights' when l2_regularization is enabled and 'targets_mean'
            when bias_mean_initialization is enabled.

    Returns:
        tf.keras.Model with inputs {'density': ...} and outputs
        {'kinetic_energy': tensor of shape (batch,)}.

    Raises:
        ValueError: if 'filters' and 'kernel_size' have different lengths.
    """
    model_params = params['model']
    filters = model_params['filters']
    kernel_sizes = model_params['kernel_size']
    if len(filters) != len(kernel_sizes):
        raise ValueError(
            "params['model']['filters'] and params['model']['kernel_size'] must have "
            "the same length, got {} and {}".format(len(filters), len(kernel_sizes)))

    density = tf.keras.layers.Input(shape=params['features_shape']['density'], name='density')
    # Conv1D expects (batch, steps, channels): append a single channel axis.
    value = tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1))(density)

    kernel_regularizer = None
    if model_params.get('l2_regularization', False):
        kernel_regularizer = tf.keras.regularizers.l2(params['loss_weights']['regularization'])

    # Keep Dense's documented default ('zeros') when no custom init is requested:
    # passing bias_initializer=None makes Layer.add_weight fall back to
    # glorot_uniform, i.e. a random output bias.
    bias_initializer = 'zeros'
    if model_params.get('bias_mean_initialization', False):
        bias_initializer = tf.constant_initializer(value=params['targets_mean']['kinetic_energy'])

    for n_filters, kernel_size in zip(filters, kernel_sizes):
        value = tf.keras.layers.Conv1D(filters=n_filters,
                                       kernel_size=kernel_size,
                                       activation=model_params['activation'],
                                       padding=model_params['padding'],
                                       kernel_regularizer=kernel_regularizer)(value)

    value = tf.keras.layers.Flatten()(value)
    value = tf.keras.layers.Dense(1, kernel_regularizer=kernel_regularizer, bias_initializer=bias_initializer)(value)
    # Collapse the trailing unit axis: (batch, 1) -> (batch,).
    kinetic_energy = tf.keras.layers.Lambda(lambda x: tf.reduce_sum(x, axis=-1), name='kinetic_energy')(value)

    return tf.keras.Model(inputs={'density': density}, outputs={'kinetic_energy': kinetic_energy})

class DerivConvNNKineticEnergyFunctional(tf.Module):
    """Kinetic-energy model that also returns its derivative w.r.t. the density.

    Wraps the Keras model built by ``params['base_model_fn'](params)``. On every
    call it differentiates the predicted ``kinetic_energy`` with respect to the
    input ``density`` using a ``tf.GradientTape``. It divides the gradient by
    ``params['h']``, presumably the grid spacing set by QFDataset(set_h=True);
    TODO confirm.
    """

    def __init__(self, params):
        super().__init__()

        # Keras model mapping {'density': ...} -> {'kinetic_energy': ...}.
        self.model = params['base_model_fn'](params)
        self.h = params['h']

    @tf.function
    def __call__(self, inputs, training=False):
        """Predict the kinetic energy and its derivative.

        Args:
            inputs: either a dict with a 'density' entry or the density tensor itself.
            training: forwarded to the wrapped Keras model.

        Returns:
            dict with the wrapped model's outputs ('kinetic_energy') plus
            'derivative' (same shape as the density).
        """
        density = inputs['density'] if isinstance(inputs, dict) else inputs

        with tf.GradientTape() as tape:
            # density is an input tensor, not a variable, so it must be watched explicitly.
            tape.watch(density)
            # Copy so the wrapped model's output structure is not mutated below.
            predictions = dict(self.model({'density': density}, training=training))
            kinetic_energy = predictions['kinetic_energy']

        # kinetic_energy has one entry per sample; the gradient of the (implicitly
        # summed) batch equals the per-sample gradient as long as the base model
        # treats batch rows independently (true for the Conv1D/Dense network).
        predictions['derivative'] = tape.gradient(kinetic_energy, density) / self.h
        return predictions

    @property
    def losses(self):
        """Regularization losses of the wrapped Keras model."""
        return self.model.losses

    @property
    def trainable_weights(self):
        """Trainable variables of the wrapped Keras model."""
        return self.model.trainable_weights

    def signatures(self, dataset_train):
        """Concrete functions passed to tf.saved_model.save.

        'serving_default' takes the density tensor directly and 'serving_dict'
        takes {'density': ...}. Both are batched over
        ``dataset_train.discretisation_points`` with ``dataset_train.dtype``.
        """
        spec = tf.TensorSpec([None, dataset_train.discretisation_points], dataset_train.dtype, name='density')
        return {'serving_default': self.__call__.get_concrete_function(spec),
                'serving_dict': self.__call__.get_concrete_function({'density': spec})
                }
                                    

In [None]:
# Progress bar shown above the training output; the training cell sets its
# `max`, `description` and `value` while it runs.
progress = widgets.IntProgress(
    value=0,
    max=0,
    description='...',
    bar_style='info',
    layout=widgets.Layout(width='92%'),
)
display(progress)

In [None]:
training_run_names = ['default']

remote_file_hyperparams = os.path.join(log_dir, "hyperparams.config")
if log_dir.startswith('gs://'):
    !gsutil -q cp $file_hyperparams $remote_file_hyperparams
else:
    if not os.path.exists(log_dir): os.makedirs(log_dir)
    !cp $file_hyperparams $remote_file_hyperparams

for run_name in training_run_names:
    params = load_hyperparameters(file_hyperparams, run_name=run_name, globals=globals())
    params['model_dir'] = remote_model_dir = os.path.join(log_dir, run_name)
    local_model_dir = os.path.join(base_dir, run_name)

    display(params)
    if os.path.exists(remote_model_dir):
        print('removing logdir', remote_model_dir)
        import shutil
        shutil.rmtree(remote_model_dir)

    if os.path.exists(local_model_dir):
        print('WARNING: model directory', local_model_dir, 'will be overwritten after training completes.')

    dataset_train = QFDataset(os.path.join(data_dir, params['dataset_train'] + '.pkl'), params, set_h=True, set_shapes=True, set_mean=True)
    dataset_eval = QFDataset(os.path.join(data_dir, params['dataset_validate'] + '.pkl'), params)

    writer = tf.summary.create_file_writer(params['model_dir'])
    eval_writer = tf.summary.create_file_writer(os.path.join(params['model_dir'], 'eval'))

    eval_metrics = {key:getattr(tf.keras.metrics, params['eval_metrics'][key])() for key in params['eval_metrics']}
    eval_losses = {key:tf.keras.metrics.Mean() for key in list(params['loss'].keys()) + ['loss', 'regularization']}

    optimizer_kwargs = params['optimizer_kwargs'].copy()
    if isinstance(params['optimizer_kwargs']['learning_rate'], str):
        optimizer_kwargs['learning_rate'] = learning_rate = getattr(tf.keras.optimizers.schedules, params['optimizer_kwargs']['learning_rate'])(**params['optimizer_kwargs']['learning_rate_kwargs'])
        del optimizer_kwargs['learning_rate_kwargs']

    try:
        optimizer = getattr(tf.keras.optimizers, params['optimizer'])(**optimizer_kwargs)
    except AttributeError:
        optimizer = getattr(tfa.optimizers, params['optimizer'])(**optimizer_kwargs)

    if params['optimizer_lookahead']:
        optimizer = tfa.optimizers.Lookahead(optimizer, **params.get('lookahead_kwargs', {}))

    loss_fn = {key:getattr(tf.keras.losses, params['loss'][key])() for key in params['loss']}
    weights = {key:params.get('loss_weights', {key: 1.0}).get(key, 1.0) for key in params['loss']}

    model = params['model_fn'](params)

    def save_model(epoch=None):
        signatures = None
        if hasattr(model, 'signatures'):
            signatures = model.signatures(dataset_train)

        if epoch is not None:
            save_dir = os.path.join(params['model_dir'], ("{:0" + str(1+int(np.log10(params['train_epochs']))) +  "d}").format(epoch))
        else:
            save_dir = os.path.join(params['model_dir'], 'saved_model')

        tf.saved_model.save(model, save_dir, signatures=signatures)

    def calc_loss(targets, predictions):
        losses = {}

        for key in targets.keys():
            losses[key] = weights[key]*loss_fn[key](targets[key], predictions[key])

        for loss_tensor in model.losses:
            losses[loss_tensor.name] = loss_tensor
        
        loss = tf.add_n(list(losses.values()))
        
        losses['loss'] = loss
        losses['regularization'] = tf.add_n(model.losses)

        return loss, losses

    @tf.function
    def eval_step(batch):
        features, targets = batch
        predictions = model(features)
        loss, losses = calc_loss(predictions, targets)
        
        for key, metric in eval_metrics.items():
            metric.update_state(targets[key], predictions[key])

        for key, loss_metric in eval_losses.items():
            loss_metric.update_state(losses[key])

    @tf.function
    def train_step(batch, step, params):
        features, targets = batch

        with tf.GradientTape() as tape:
            predictions = model(features)
            loss, losses = calc_loss(targets, predictions)

        grads = tape.gradient(loss, model.trainable_weights)

        clip_norm = params.get('gradient_clip_norm', 0.0)
        if clip_norm > 0.0:
            grads, global_norm = tf.clip_by_global_norm(grads, clip_norm)
        else:
            global_norm = None

        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        return losses, global_norm

    global_step = tf.constant(0, dtype=tf.int64)

    if progress is not None:
        progress.max = params['train_epochs']
        progress.description = 'Training'

    for epoch in range(params['train_epochs']):
        if progress is not None and epoch % (params['train_epochs']//1000) == 0:
            progress.value = epoch + 1

        batch_start_time = time.time()

        dataset = dataset_train.dataset
        if params['shuffle']:
            dataset = dataset.shuffle(params['shuffle_buffer_size'])

        for batch in dataset.batch(params['train_batch_size']):
            global_step += 1
            losses, global_norm = train_step(batch, global_step, params)

            current_time = time.time()
            global_steps_per_sec = 1/(current_time - batch_start_time)
            batch_start_time = current_time
            examples_per_sec = global_steps_per_sec*params['train_batch_size']

        if global_step % params['save_summary_epochs'] == 0:
            with writer.as_default():
                if 'learning_rate' in params['summaries'] and 'learning_rate' in globals():
                    tf.summary.scalar('learning_rate', learning_rate(global_step), step=epoch)
                    
                if 'losses' in params['summaries']:
                    tf.summary.scalar('loss', losses['loss'], step=epoch)
                    tf.summary.scalar('loss/regularization', losses['regularization'], step=epoch)

                    for key in list(params['loss'].keys()):
                        tf.summary.scalar('loss/' + key, losses[key], step=epoch)

                if 'global_norm' in params['summaries'] and params.get('gradient_clip_norm', 0.0) > 0.0:
                    tf.summary.scalar('global_norm', global_norm, step=epoch)

                tf.summary.scalar('examples/sec', examples_per_sec, step=epoch)
                writer.flush()

        if (epoch + 1) % params['train_epochs_per_eval'] == 0:
            with eval_writer.as_default():

                # reset all metric states
                for metric in eval_metrics.values():
                    metric.reset_states()

                for loss_metric in eval_losses.values():
                    loss_metric.reset_states()

                for batch in dataset_eval.dataset.batch(params['eval_batch_size']):
                    eval_step(batch)

                for key, metric in eval_metrics.items():
                    tf.summary.scalar('metrics/' + key, metric.result(), step=epoch)

                for key, loss_metric in eval_losses.items():
                    tf.summary.scalar('loss/'*(key != 'loss') + key, loss_metric.result(), step=epoch)

                eval_writer.flush()

        if (epoch + 1) % params['save_model_epochs'] == 0:
            save_model(epoch=epoch+1)

    save_model()

    if os.path.exists(local_model_dir):
        import shutil
        shutil.rmtree(local_model_dir)

    if log_dir.startswith('gs://'):
        !gsutil -m cp -r $remote_model_dir $base_dir
    else:
        !cp -r $remote_model_dir $base_dir