In [None]:
import os
import sys
import logging
# Restrict TensorFlow's native log output; this is set before TensorFlow is imported in a later cell.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
# NOTE(review): the Python-side 'tensorflow' logger is set to DEBUG while the native
# log level above is marked FATAL — confirm that this asymmetry is intended.
logging.getLogger('tensorflow').setLevel(logging.DEBUG)

# On Google Colab: mount Drive, install the extra packages and use the project copy on Drive.
# Anywhere else the `google.colab` import raises ImportError and the project is
# expected under the user's home directory instead.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    !pip install -q ruamel.yaml
    !pip install -q tensorboard-plugin-profile
    project_path = '/content/drive/MyDrive/Colab Projects/quantumflow'
except ImportError:
    project_path = os.path.expanduser('~/quantumflow')

In [None]:
os.chdir(project_path)
sys.path.append(project_path)

import tensorflow as tf
%load_ext tensorboard

import quantumflow

experiment = 'resnets'
run_name = 'derivative'

base_dir = os.path.join(project_path, "experiments", experiment)
params = quantumflow.utils.load_yaml(os.path.join(base_dir, 'hyperparams.yaml'))[run_name]
run_dir = os.path.join(base_dir, run_name)

In [None]:
# %tensorboard --logdir="$base_dir" --load_fast=false

In [None]:
def _build_dataset(config_key):
    """Instantiate the dataset configured under `params[config_key]` and call its `build()`."""
    dataset = quantumflow.instantiate(params[config_key], run_dir=run_dir)
    dataset.build()
    return dataset


# Training set first, then validation set.
dataset_train = _build_dataset('dataset_train')
dataset_validate = _build_dataset('dataset_validate')

In [None]:
# Start the TensorFlow profiler server on port 6009.
tf.profiler.experimental.server.start(6009)

# Reset Keras state and seed TensorFlow's RNG from the run's hyperparameters
# before the model is created.
tf.keras.backend.clear_session()
tf.random.set_seed(params['seed'])

model = quantumflow.instantiate(params['model'], run_dir=run_dir, dataset=dataset_train)

# Keras' `Model.summary()` prints the table itself and returns None, so passing the
# result straight to display() renders a stray "None" below the table. Only display
# the return value when the model's `summary()` actually produces one.
model_summary = model.summary()
if model_summary is not None:
    display(model_summary)

# Visualization

In [None]:
from quantumflow.utils import anim_plot
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tree
%matplotlib inline

In [None]:
# A slice (rather than a plain index) keeps a length-one leading axis on every array.
sample_slice = slice(10, 11)


def _take_sample(array):
    """Cut the selected example out of one leaf of the dataset structure."""
    return array[sample_slice]


def _pair_with_prediction(target, target_pred):
    """Return (ground truth, prediction) for the single selected example."""
    return (target[0], target_pred.numpy()[0])


sample_features = tree.map_structure(_take_sample, dataset_validate.features)
sample_targets = tree.map_structure(_take_sample, dataset_validate.targets)

# Model output for the selected example, then paired leaf-by-leaf with the targets.
sample_targets_pred = model(sample_features)
sample = tree.map_structure(_pair_with_prediction, sample_targets, sample_targets_pred)

In [None]:
# Compare ground truth and prediction for every target of the selected sample:
# targets whose squeezed shape matches the grid `dataset_validate.x` are plotted
# over that grid, everything else is printed.
for target_path, (target, target_pred) in tree.flatten_with_path_up_to(sample_targets, sample):
    # Path entries are not guaranteed to be strings (sequence positions are integers),
    # so convert them before joining instead of failing with a TypeError.
    target_name = '/'.join(str(key) for key in target_path)

    # Squeeze once and use the same squeezed array for the shape test and the plot,
    # so the target is treated exactly like the prediction below.
    target_values = np.squeeze(target)

    if target_values.shape == dataset_validate.x.shape:

        plt.figure(figsize=(20, 3))
        plt.plot(dataset_validate.x, target_values, 'k:')
        plt.plot(dataset_validate.x, np.squeeze(target_pred))
        plt.title(target_name)
        plt.show()
    else:
        print(f"{target_name}: {target_pred} ({target})")

In [None]:
import quantumflow.xdiff

# If the model wraps a `base_model`, inspect that inner model together with the
# matching sub-section of the hyperparameters; otherwise inspect the model itself.
has_base_model = hasattr(model, 'base_model')
visualize_model = model.base_model if has_base_model else model
visualize_params = params['model']['base_model'] if has_base_model else params['model']

# Run the first three layers by hand to obtain what the fourth layer receives.
latents = visualize_model.layers[0](sample_features)
x, x_inputs, inputs = visualize_model.layers[1](latents['density'])
x = visualize_model.layers[2](x)

# The fourth layer is taken apart below; it is bound to the plain variable `self`
# so that attribute accesses read like code inside that layer.
self = visualize_model.layers[3]

# Expand the layer's `x_token` once per leading dimension of `x`, then repeat it
# along axis 0.
x_token = self.x_token # (d_model)
leading_dims = tf.unstack(tf.shape(x))[:-2]
for dim in leading_dims:
    x_token = tf.repeat(tf.expand_dims(x_token, axis=-3), dim, axis=-3) # (..., latent_size, d_model)
x_token = tf.repeat(x_token, tf.shape(x)[-2], axis=0)

# xdiff tensors: one between `x` and itself, one between `x` and `x_inputs`.
xdiff_scale = visualize_params['scale']
xdiff_K = visualize_params['K']
xdiff = quantumflow.xdiff.get_xdiff(x, x, xdiff_scale, xdiff_K)
xdiff_cross = quantumflow.xdiff.get_xdiff(x, x_inputs, xdiff_scale, xdiff_K)

# Append a positional encoding of the inputs along the last axis.
input_encoding = quantumflow.xdiff.positional_encoding(inputs, visualize_params['K_input'])
inputs = tf.concat([inputs, input_encoding], axis=-1) # (..., x1_size, x2_size, x_features)

latents = x_token #self.x_token_layer(xdiff, xdiff_cross)

# Flatten the sub-layers into the order in which they are applied below:
# for every repeat its encoder layers followed by one cross-encoder layer, then one
# further group of encoder layers, the layer norm, the pre-final layers and the final layer.
num_repeats = visualize_params['num_repeats']
num_layers = visualize_params['num_layers']

layers = []
for r in range(num_repeats):
    layers.extend(self.enc_layers[r][i] for i in range(num_layers))
    layers.append(self.cross_enc_layers[r])
layers.extend(self.enc_layers[num_repeats][i] for i in range(num_layers))
layers.append(self.layernorm)
layers.extend(self.pre_final_layers)
layers.append(self.final_layer)

    
def _plot_channels(values, label, with_stats=True):
    """Draw one line plot per index of axis 2 of `values`, using entry 0 of axis 0.

    With `with_stats` the title is `label` followed by the mean and standard
    deviation of the plotted slice; otherwise the title is `label` alone.
    """
    for index in range(values.shape[2]):
        panel = values[0, :, index, :]
        plt.figure(figsize=(20, 3))
        plt.plot(panel)
        if with_stats:
            plt.title(f"{label} {np.mean(panel):.3f} {np.std(panel):.3f}")
        else:
            plt.title(label)
        plt.show()


def _show_attention_maps(attention):
    """Show one image per index of axis 2 of `attention` and print its standard deviation."""
    for index in range(attention.shape[2]):
        attention_map = attention[0, :, index, 0, :]
        plt.figure(figsize=(20, 3))
        plt.imshow(attention_map, norm=matplotlib.colors.Normalize(vmin=0, vmax=0.01, clip=False), aspect=1.0)
        plt.show()
        print(f'Attention Map {np.std(attention_map):.3f}')


# Starting point: the latents before any of the collected layers is applied.
_plot_channels(latents, "Latents")

# Apply the collected layers one by one. Layers with 'encoder' in their name are
# unrolled step by step so every intermediate tensor can be plotted; all other
# layers are simply called.
for layer in layers:
    print(layer.name)

    if 'encoder' not in layer.name:
        latents = layer(latents)
        continue

    normalized = layer.layernorm1(latents)  # (..., input_size, d_model)
    _plot_channels(normalized, 'Normalized Latents', with_stats=False)

    if 'cross' in layer.name:
        # Cross encoder: the latents attend to the (encoded) inputs.
        plt.figure(figsize=(20, 3))
        plt.plot(inputs[0, 0, :, :])
        plt.title('Normalized Inputs')
        plt.show()

        attn_output, attention = layer.mha(normalized, inputs, inputs, xdiff_cross, mask=None)  # (..., latent_size, d_model)
        _show_attention_maps(attention)
    else:
        # Plain encoder: the latents attend to themselves.
        attn_output, attention = layer.mha(normalized, normalized, normalized, xdiff, mask=None)  # (..., input_size, d_model)

        if attention.shape[-1] > 1:
            _show_attention_maps(attention)
        else:
            # A last axis of length one cannot be shown as an image; print the values.
            print(attention[0, :, :, 0, 0].numpy())

    _plot_channels(attn_output, 'Attention Output')

    # NOTE(review): dropout runs with training=True here, while `model(sample_features)`
    # is called without a training flag — confirm this is intended for the comparison
    # of `kinetic_energy_density` with the model prediction.
    attn_output = layer.dropout1(attn_output, training=True)

    # First residual connection.
    latents = latents + attn_output
    _plot_channels(latents, 'Skip Attn Output')

    # Feed-forward part followed by the second residual connection.
    normalized = layer.layernorm2(latents)  # (..., input_size, d_model)
    ffn_output = layer.ffn[1](layer.ffn[0](normalized))  # (..., input_size, d_model)
    ffn_output = layer.dropout2(ffn_output, training=True)

    latents = latents + ffn_output
    _plot_channels(latents, 'Skip FFN Output')

# Feature 0 of the final latents, summed over the last remaining axis.
kinetic_energy_density = tf.reduce_sum(latents[..., 0], axis=-1)

In [None]:
# Difference between the hand-unrolled forward pass and the model's own prediction.
manual_pass_residual = kinetic_energy_density[0] - sample_targets_pred['kinetic_energy_density'][0]
plt.plot(manual_pass_residual)

# Train

In [None]:
# `YouShallNotPass` is not defined anywhere in the visible notebook, so this line
# raises NameError and stops a "Run All" before the training cell below.
# NOTE(review): presumably a deliberate guard — confirm; run the training cell manually.
raise YouShallNotPass

In [None]:
# Compile the model with the optimizer, loss and metrics from the hyperparameters.
optimizer = quantumflow.instantiate(params['optimizer'])

model.compile(
    optimizer,
    loss=params['loss'],
    loss_weights=params.get('loss_weights', None),
    metrics=params.get('metrics', None)
)

# Optionally start from previously stored weights.
# The original code joined the checkpoint name with `data_dir`, a name that is never
# defined in this notebook, so setting `load_checkpoint` raised NameError.
# NOTE(review): the checkpoint is now resolved against `run_dir`, the directory the
# checkpoint callback below writes to — confirm this is the intended location.
# An absolute path in `load_checkpoint` is used as-is by os.path.join.
load_checkpoint = params.get('load_checkpoint', None)
if load_checkpoint is not None:
    checkpoint_path = os.path.join(run_dir, load_checkpoint)
    model.load_weights(checkpoint_path)
    if params['fit'].get('verbose', 0) > 0:
        print("loading weights from ", checkpoint_path)

callbacks = []

# Optional periodic checkpoints, written into the run directory.
if run_dir is not None and params.get('checkpoint', False):
    # Copy so that popping 'filename' does not modify the loaded hyperparameters.
    checkpoint_params = params['checkpoint'].copy()
    checkpoint_filename = checkpoint_params.pop('filename', 'weights.{epoch:05d}.hdf5')
    checkpoint_params['filepath'] = os.path.join(run_dir, checkpoint_filename)
    # Unless set explicitly, follow the verbosity of fit(), capped at 1.
    checkpoint_params['verbose'] = checkpoint_params.get('verbose', min(1, params['fit'].get('verbose', 1)))
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(**checkpoint_params))

# Optional TensorBoard callback, logging into the run directory.
if 'tensorboard' in params:
    callbacks.append(
        quantumflow.instantiate(params['tensorboard'], log_dir=run_dir, learning_rate=optimizer.learning_rate))

if dataset_validate is not None:
    validation_data = (dataset_validate.features, dataset_validate.targets)
else:
    validation_data = None

model.fit(x=dataset_train.features,
          y=dataset_train.targets,
          callbacks=callbacks,
          validation_data=validation_data,
          **params['fit'])

# 'save' is read with .get like the other optional keys, so a missing key skips the
# export instead of raising KeyError after training has already finished.
if params.get('save', False) is True:
    # 'save_model' names the attribute of `model` to export; 'self' (the default)
    # exports the model itself.
    save_model_name = params.get('save_model', 'self')
    save_model = model if save_model_name == 'self' else getattr(model, save_model_name)
    save_model.save(os.path.join(run_dir, 'saved_model'), include_optimizer=False)