In [None]:
# Install the ruamel.yaml package into the runtime (-q keeps pip output quiet).
# NOTE(review): presumably required by quantumflow's YAML loading — confirm.
!pip install -q ruamel.yaml
from google.colab import drive
# Mount Google Drive at /content/drive so the project folder below is reachable.
drive.mount('/content/drive')
# Project root on the mounted drive; later cells chdir into it and add it to sys.path.
project_path = '/content/drive/MyDrive/Colab Projects/QuantumFlow'

In [None]:
import os
import sys

# Work from the project root so relative paths (such as 'experiments') resolve there.
os.chdir(project_path)

# Put the project root on the import path. The membership check keeps the cell
# idempotent: re-running it no longer appends a duplicate entry each time.
if project_path not in sys.path:
    sys.path.append(project_path)

# exist_ok=True replaces the check-then-create pattern and is a no-op on re-runs.
os.makedirs('experiments', exist_ok=True)

import tensorflow as tf

import quantumflow

# Which experiment folder and which run entry of its hyperparams.yaml to use.
experiment = 'test'
run_name = 'resnet_100'

# Directory layout: <project>/experiments/<experiment>/<run_name>
base_dir = os.path.join(project_path, "experiments", experiment)
run_dir = os.path.join(base_dir, run_name)

# The YAML file is indexed by run name; keep only the entry for the selected run.
hyperparams_file = os.path.join(base_dir, 'hyperparams.yaml')
params = quantumflow.utils.load_yaml(hyperparams_file)[run_name]

In [None]:
# %tensorboard --logdir=$base_dir

In [None]:
def _build_dataset(params_key):
    """Instantiate the dataset configured under ``params[params_key]`` and build it."""
    dataset = quantumflow.instantiate(params[params_key], run_dir=run_dir)
    dataset.build()
    return dataset


# Training set first, then validation set (same order as the configuration is read).
dataset_train = _build_dataset('dataset_train')
dataset_validate = _build_dataset('dataset_validate')

In [None]:
# Clear Keras' global session state before building a new model, so re-running
# this cell starts from a clean slate.
tf.keras.backend.clear_session()
# Seed TensorFlow's global random generator with the seed from the run's hyperparameters.
tf.random.set_seed(params['seed'])

# Build the model described under params['model']; the training dataset and the
# run directory are handed to the constructor via quantumflow.instantiate.
model = quantumflow.instantiate(params['model'], run_dir=run_dir, dataset=dataset_train)

In [None]:
# Build the optimizer from the run's hyperparameters and compile the model with
# the configured loss, optional loss weights and optional metrics.
optimizer = quantumflow.instantiate(params['optimizer'])

model.compile(
    optimizer,
    loss=params['loss'],
    loss_weights=params.get('loss_weights', None),
    metrics=params.get('metrics', None)
)


# Optionally warm-start from a checkpoint named in the hyperparameters.
# NOTE(review): this branch previously joined the checkpoint onto `data_dir`, a
# name this notebook never defines (NameError), and read the verbosity from
# params['fit_kwargs'] although the fit call below uses params['fit'].
# The checkpoint is now resolved against `run_dir` (an absolute path in the
# YAML still wins, because os.path.join discards earlier parts) — confirm that
# this is the intended base directory.
if params.get('load_checkpoint', None) is not None:
    checkpoint_path = os.path.join(run_dir, params['load_checkpoint'])
    model.load_weights(checkpoint_path)
    if params['fit'].get('verbose', 0) > 0:
        print("loading weights from ", checkpoint_path)

# No callbacks are active at the moment; the disabled block below shows the
# checkpoint / TensorBoard callbacks that used to be registered here.
callbacks = []

# Disabled legacy code kept as a string literal (never executed). It refers to
# `model_dir` and `learning_rate`, which this cell does not define.
'''

if model_dir is not None and params.get('checkpoint', False):
    checkpoint_params = params['checkpoint_kwargs'].copy()
    checkpoint_params['filepath'] = os.path.join(model_dir, checkpoint_params.pop('filename', 'weights.{epoch:05d}.hdf5'))
    checkpoint_params['verbose'] = checkpoint_params.get('verbose', min(1, params['fit_kwargs'].get('verbose', 1)))
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(**checkpoint_params))

if model_dir is not None and params.get('tensorboard', False):
    tensorboard_callback_class = params['tensorboard'] if callable(params['tensorboard']) else tf.keras.callbacks.TensorBoard
    callbacks.append(tensorboard_callback_class(log_dir=model_dir, learning_rate=learning_rate, **params['tensorboard_kwargs']))
'''


# Train on the training dataset and validate on the validation dataset; every
# remaining Keras fit option comes straight from params['fit'].
model.fit(x=dataset_train.features,
          y=dataset_train.targets,
          callbacks=callbacks,
          validation_data=(dataset_validate.features, dataset_validate.targets) if dataset_validate is not None else None,
          **params['fit'])

# Disabled legacy code kept as a string literal (never executed): saving and
# exporting the trained model into `model_dir`.
'''
if model_dir is not None and params['save_model'] is True:
    model.save(os.path.join(model_dir, 'model.h5')) 

if model_dir is not None and params['export'] is True:
    export_model = getattr(model, params['export_model']) if not params.get('export_model', 'self') == 'self' else model
    tf.saved_model.save(export_model, os.path.join(model_dir, 'saved_model'))
'''


In [None]:
# Hard stop so that "Run All" does not continue into the cells below.
# NOTE(review): the original halted by raising an undefined name (NameError);
# presumably a deliberate manual breakpoint — confirm, or delete this cell.
raise RuntimeError("Execution stopped on purpose: run the cells below manually.")

# Train

In [None]:
from collections import OrderedDict

class SampleCallback(tf.keras.callbacks.Callback):
    """Keras callback that records snapshots of the model during training.

    On every epoch with ``epoch % sample_freq == 0`` the callback stores

    * at the start of the epoch: the epoch index, the model output on
      ``dataset.features``, the trainable weights (as numpy arrays) and the
      output of every layer when ``dataset.density`` is pushed through the
      layers one by one;
    * at the end of the epoch: a copy of the Keras ``logs`` dict and a few
      optimizer internals.

    The ``get_*`` methods stack one recorded quantity over all sampled epochs.

    Args:
        dataset: object exposing ``features`` (fed to the whole model) and
            ``density`` (fed to the first layer for the layer-by-layer replay).
        sample_freq: sampling period in epochs.
        merge_layers: optional mapping ``{layer.name: recorded_key}``. A layer
            listed here is called with ``[current_value, recorded[recorded_key]]``
            instead of the current value alone.
    """

    def __init__(self, dataset, sample_freq=1, merge_layers=None):
        super().__init__()
        self.dataset = dataset
        self.sample_freq = sample_freq
        self.merge_layers = merge_layers or {}

        # One entry per sampled epoch in each of these lists.
        self.epochs = []
        self.predictions = []
        self.weights = []
        self.layers = []
        self.metrics = []
        self.additional = []

    def _should_sample(self, epoch):
        """Return True when ``epoch`` falls on the sampling period."""
        return epoch % self.sample_freq == 0

    def _record_layers(self):
        """Push ``dataset.density`` through the layers and record each output.

        Returns an ``OrderedDict`` mapping ``layer.name`` (or
        ``'<sub-model name>/<layer name>'`` for layers of a nested model) to the
        layer output converted to numpy.
        """
        recorded = OrderedDict()

        def run_layer(layer, value, key):
            if layer.name in self.merge_layers:
                # Merge-type layer: second input is an earlier recorded output.
                value = layer([value, recorded[self.merge_layers[layer.name]]])
            else:
                value = layer(value)
            recorded[key] = value.numpy()
            return value

        value = self.dataset.density
        for layer in self.model.layers:
            if hasattr(layer, 'layers'):
                # Nested model: walk its layers and prefix keys with its name.
                for sub_layer in layer.layers:
                    value = run_layer(sub_layer, value, layer.name + '/' + sub_layer.name)
            else:
                value = run_layer(layer, value, layer.name)
        return recorded

    def on_epoch_begin(self, epoch, logs=None):
        """Record predictions, weights and per-layer outputs on sampled epochs."""
        if not self._should_sample(epoch):
            return

        self.epochs.append(epoch)
        self.predictions.append(self.model(self.dataset.features))
        self.weights.append({weight.name: weight.numpy() for weight in self.model.trainable_variables})
        self.layers.append(self._record_layers())

    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch's logs and optimizer internals on sampled epochs."""
        if not self._should_sample(epoch):
            return

        # Store a snapshot instead of the dict Keras passes in: that dict is
        # shared with other callbacks, which may modify it afterwards, and a
        # missing dict (None) would break get_metric later.
        self.metrics.append(dict(logs or {}))

        # NOTE(review): relies on private/legacy optimizer methods (_decayed_lr,
        # _get_hyper, get_slot) and assumes an Adam-style optimizer with 'm'/'v'
        # slots and a 'beta_1' hyperparameter — confirm for the optimizer in use.
        optimizer = self.model.optimizer
        first_variable = self.model.trainable_variables[0]
        self.additional.append({
            'learning_rate': optimizer._decayed_lr(tf.float32).numpy(),
            'adam_iterations': optimizer.iterations.numpy(),
            'adam_m_' + first_variable.name: optimizer.get_slot(first_variable, 'm').numpy(),
            'adam_v_' + first_variable.name: optimizer.get_slot(first_variable, 'v').numpy(),
            'adam_beta_1': optimizer._get_hyper('beta_1', tf.float32).numpy(),
        })

    @staticmethod
    def _stack(records, key):
        """Stack ``record[key]`` over all ``records`` along a new leading axis."""
        return tf.stack([record[key] for record in records])

    def get_metric(self, key):
        """Return the logged value ``key`` stacked over all sampled epochs."""
        return self._stack(self.metrics, key)

    def get_prediction(self, key):
        """Return model output ``key`` stacked over all sampled epochs."""
        return self._stack(self.predictions, key)

    def get_weight(self, key):
        """Return the weight named ``key`` stacked over all sampled epochs."""
        return self._stack(self.weights, key)

    def get_layer(self, key):
        """Return the recorded layer output ``key`` stacked over all sampled epochs."""
        return self._stack(self.layers, key)

    def get_additional(self, key):
        """Return the optimizer quantity ``key`` stacked over all sampled epochs."""
        return self._stack(self.additional, key)


In [None]:
from quantumflow.utils import anim_plot
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [None]:
# Configuration of the sampled training run.
# NOTE(review): this cell uses `data_dir`, `load_hyperparameters`,
# `file_hyperparams`, `build_model`, `QFDataset` and `train`, none of which are
# defined in the cells visible above — confirm where they come from.
run_name = 'resnet_vW_N2_100000'
load_checkpoint = 0  # epoch to resume from; 0 disables checkpoint loading below
run_epochs = 300     # number of epochs to train on top of the initial epoch
sample_freq = 2      # SampleCallback records every `sample_freq`-th epoch
model_dir = os.path.join(data_dir, experiment, run_name)

# Build the model once only to display its summary.
params = load_hyperparameters(file_hyperparams, run_name=run_name, globals=globals())
model = build_model(params)
display(model.summary())

# Hyperparameters are loaded a second time before being edited for this run.
# NOTE(review): presumably to start from a fresh copy after build_model — confirm.
params = load_hyperparameters(file_hyperparams, run_name=run_name, globals=globals())
params['fit_kwargs']['initial_epoch'] = load_checkpoint
# Checkpoint path '<experiment>/<run_name>/<filename>' with the epoch filled in;
# None when load_checkpoint is 0.
params['load_checkpoint'] = ('{}/{}/' + params['checkpoint_kwargs']['filename']).format(experiment, run_name, epoch=load_checkpoint) if load_checkpoint else None
# Final epoch index = starting epoch + number of epochs to run.
params['fit_kwargs']['epochs'] = params['fit_kwargs']['initial_epoch'] + run_epochs

# For SampleCallback: layers named on the left are called with two inputs, the
# running value and the recorded output stored under the key on the right.
merge_layers = {'add': 'model/lambda',
                'add_1': 'model/activation',
                'add_2': 'model/activation_1',
                'add_3': 'model/activation_2',
                'add_4': 'model/activation_3'}

#model_dir = '../data/pop_test/' + run_name

#params['dataset'] = {'h': 1/499}
#params['loss'] = {'kinetic_energy': params['loss']['kinetic_energy'], 'derivative': params['loss']['derivative'](params)}

# Record snapshots on a dedicated sample dataset while training.
sample_callback = SampleCallback(QFDataset(os.path.join(data_dir, 'recreate/dataset_sample.hdf5'), params), sample_freq=sample_freq, merge_layers=merge_layers)
# model_dir is looked up through globals() so that None is passed if it is not defined.
model, params = train(params, callbacks=[sample_callback], model_dir=globals().get('model_dir', None))

In [None]:
# List the names available in the first snapshot: weight names, then metric names.
for recorded in (sample_callback.weights, sample_callback.metrics):
    print(recorded[0].keys())

In [None]:
# Predicted 'kinetic_energy' output at each sampled epoch.
fig, ax = plt.subplots(figsize=(20, 3))
kinetic_energy_history = sample_callback.get_prediction('kinetic_energy')
ax.plot(sample_callback.epochs, kinetic_energy_history)
plt.show()

In [None]:
# Predicted kinetic energy density for every sampled epoch ...
predicted_kedensity = sample_callback.get_prediction('kinetic_energy_density')

# ... and the dataset's reference values, repeated once per sampled epoch.
reference_kedensity = np.expand_dims(sample_callback.dataset.kinetic_energy_density, axis=0)
reference_kedensity = np.repeat(reference_kedensity, len(predicted_kedensity), axis=0)

# Put prediction and reference side by side on a new axis 2.
kedensity = np.stack([predicted_kedensity, reference_kedensity], axis=2)
print(kedensity.shape)

# Animate index 0 along axis 1, with the prediction/reference axis moved last.
first_sample = kedensity[:, 0]
anim_plot(np.moveaxis(first_sample, 2, 1), bar='Rendering')

In [None]:
# Predicted 'derivative' output for every sampled epoch.
predicted_derivative = sample_callback.get_prediction('derivative')

# Dataset reference values, repeated once per sampled epoch.
reference_derivative = np.expand_dims(sample_callback.dataset.derivative, axis=0)
reference_derivative = np.repeat(reference_derivative, len(predicted_derivative), axis=0)

# NOTE(review): the prediction is sliced with [:, 0] before stacking while the
# reference is not — confirm the two shapes are meant to line up this way.
derivative = np.stack([predicted_derivative[:, 0], reference_derivative], axis=2)
print(derivative.shape)

# Animate index 0 along axis 1, with the prediction/reference axis moved last.
first_sample = derivative[:, 0]
anim_plot(np.moveaxis(first_sample, 2, 1), bar='Rendering')

In [None]:
# Logged 'loss' value at each sampled epoch.
fig, ax = plt.subplots(figsize=(20, 3))
loss_history = sample_callback.get_metric('loss')
ax.plot(sample_callback.epochs, loss_history)
plt.show()

In [None]:
# Keys of the per-layer recordings, taken from the first snapshot.
layer_names = sample_callback.layers[0].keys()

for layer_name in layer_names:
    print(layer_name)
    value = sample_callback.get_layer(layer_name)

    # Disabled static preview of a single frame:
    #frame = 0
    #plt.figure(figsize=(10, 1))
    #plt.plot(value[frame][0])
    #plt.title(layer_name + ' ' + str(value[frame].shape))
    #plt.show()

    rank = len(value.shape)
    if rank >= 3:
        if rank == 3:
            # Add a trailing axis so rank-3 recordings are plotted like rank-4 ones.
            value = tf.expand_dims(value, axis=-1)

        # Animate index 0 along axis 1 across the sampled epochs.
        anim_plot(value[:, 0], figsize=(10, 1), bar='Rendering')

In [None]:
# One curve per logged quantity whose name contains 'loss'.
fig, ax = plt.subplots(figsize=(20, 3))
for key in sample_callback.metrics[0].keys():
    if 'loss' in key:
        metric = sample_callback.get_metric(key)
        # Average away trailing axes until one value per sampled epoch remains.
        while len(metric.shape) > 1:
            metric = tf.reduce_mean(metric, axis=-1)

        ax.plot(sample_callback.epochs, metric, label=key)

ax.legend()
plt.show()

In [None]:
# One curve per logged quantity whose name contains 'mean'.
fig, ax = plt.subplots(figsize=(20, 3))
for key in sample_callback.metrics[0].keys():
    if 'mean' in key:
        metric = sample_callback.get_metric(key)
        # Average away trailing axes until one value per sampled epoch remains.
        while len(metric.shape) > 1:
            metric = tf.reduce_mean(metric, axis=-1)

        ax.plot(sample_callback.epochs, metric, label=key)

ax.legend()
plt.show()

In [None]:
# Mean value of every recorded weight whose name contains 'conv', per sampled epoch.
fig, ax = plt.subplots(figsize=(20, 10))
for key in sample_callback.weights[0].keys():
    if 'conv' in key:
        weight = sample_callback.get_weight(key)
        # Average away trailing axes until one value per sampled epoch remains.
        while len(weight.shape) > 1:
            weight = tf.reduce_mean(weight, axis=-1)

        ax.plot(sample_callback.epochs, weight, label=key)

ax.legend()
plt.show()

In [None]:
# Recorded values of one convolution kernel, stacked over the sampled epochs.
kernel = sample_callback.get_weight('conv1d_4/kernel:0')
anim_plot(kernel[:, :, 0], bar='Rendering')

# Change of the kernel between consecutive samples (one row fewer than `kernel`).
# NOTE(review): this is sliced with [:, 0] while the kernel above uses
# [:, :, 0] — confirm that the different axes are intended.
gradient = kernel[1:] - kernel[:-1]
anim_plot(gradient[:, 0], bar='Rendering')

# Log of the summed squared change per step. The row count is taken from the
# data itself: the previous `run_epochs//sample_freq-1` is only correct when
# run_epochs is a multiple of sample_freq and training ran to completion.
plt.figure(figsize=(20, 3))
plt.plot(np.log(tf.reduce_sum(tf.square(tf.reshape(gradient, [gradient.shape[0], -1])), axis=-1)))
plt.show()