In [None]:
import os
import sys
import logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
logging.getLogger('tensorflow').setLevel(logging.DEBUG)

try:
    from google.colab import drive
    drive.mount('/content/drive')
    !pip install -q ruamel.yaml
    !pip install -q tensorboard-plugin-profile
    project_path = '/content/drive/MyDrive/Colab Projects/QuantumFlow'
except:
    project_path = os.path.expanduser('~/QuantumFlow')

In [None]:
os.chdir(project_path)
sys.path.append(project_path)

import numpy as np
import tensorflow as tf
import tree

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

import quantumflow

experiment = 'xdiff_perciever'
run_name = 'debug_x'
epoch = 1000

preview = 5

base_dir = os.path.join(project_path, "experiments", experiment)
params = quantumflow.utils.load_yaml(os.path.join(base_dir, 'hyperparams.yaml'))[run_name]
run_dir = os.path.join(base_dir, run_name)

In [None]:
# Instantiate the validation dataset from its config entry, then build it.
dataset_validate = quantumflow.instantiate(
    params['dataset_validate'],
    run_dir=run_dir,
)
dataset_validate.build()

In [None]:
# Rebuild the model from its config (the validation dataset is passed to the constructor).
model = quantumflow.instantiate(params['model'], run_dir=run_dir, dataset=dataset_validate) # TODO: fix missing imports
# Alternative: load the exported SavedModel instead of rebuilding from the config.
#model = tf.keras.models.load_model(os.path.join(run_dir, 'saved_model'))
if epoch is not None:
    # Restore the weights saved at `epoch`, using the run's checkpoint filename template.
    _ = model.load_weights(
        os.path.join(run_dir, params['checkpoint']['filename'].format(epoch=epoch))
    )
model.summary()

In [None]:
def predict(model, features, batch_size=None):
    """Run `model` on `features` and return its outputs as numpy arrays.

    Args:
        model: callable mapping a (nested) structure of feature arrays to a
            (nested) structure of outputs.
        features: nested structure of arrays that share the same leading
            (sample) axis.
        batch_size: if None, call `model` once on all of `features`. Otherwise
            call it on consecutive chunks of `batch_size` samples along the
            leading axis, print coarse progress, and concatenate the chunk
            outputs along axis 0.

    Returns:
        The model outputs, with the same structure, each leaf converted to numpy.
    """
    if batch_size is None:
        return tree.map_structure(lambda out: out.numpy(), model(features))

    num_samples = tree.flatten(features)[0].shape[0]
    steps = -(-num_samples // batch_size)  # ceil division
    # Roughly 100 progress ticks, but at least 1: with fewer than 100 steps the
    # old `steps // 100` was 0 and `i % print_steps` raised ZeroDivisionError.
    print_steps = max(1, steps // 100)
    print('/', steps)
    outputs = []
    for i in range(steps):
        if i % print_steps == 0:
            print(i, end=' ')
        features_batch = tree.map_structure(lambda inp: inp[i * batch_size:(i + 1) * batch_size], features)
        # Convert each batch to numpy right away so earlier batch outputs (e.g.
        # device tensors) are not all held until the final concatenation.
        outputs.append(tree.map_structure(np.asarray, model(features_batch)))
    print()
    return tree.map_structure(lambda *outs: np.concatenate(outs, axis=0), *outputs)

In [None]:
# Predict on the whole validation set, batching when the config sets `max_batch_size`.
targets_pred = predict(
    model,
    dataset_validate.features,
    batch_size=params['dataset_validate'].get('max_batch_size'),
)

In [None]:
# Predicted kinetic energies of the first `preview` validation samples.
targets_pred['kinetic_energy'][0:preview]

In [None]:
# Reference kinetic energies of the same `preview` samples, for comparison.
dataset_validate.targets['kinetic_energy'][0:preview]

In [None]:
# Per-sample signed error (prediction - target); the targets are cut to the
# number of predictions so the two arrays line up.
kinetic_energy_err = (
    targets_pred['kinetic_energy']
    - dataset_validate.targets['kinetic_energy'][:len(targets_pred['kinetic_energy'])]
)

In [None]:
# Mean absolute kinetic-energy error, converted from Hartree to kcal/mol.
kcalmol_per_hartree = 627.5094738898777
np.abs(kinetic_energy_err).mean() * kcalmol_per_hartree

In [None]:
# Reference kinetic energy density of the first `preview` samples (one curve each).
plt.figure(figsize=(20, 3))
plt.plot(dataset_validate.x, dataset_validate.targets['kinetic_energy_density'][:preview, :].transpose())
plt.title('Target kinetic energy density')
plt.xlabel('x')
plt.show()

In [None]:
# Predicted kinetic energy density of the same `preview` samples.
plt.figure(figsize=(20, 3))
plt.plot(dataset_validate.x, targets_pred['kinetic_energy_density'][:preview, :].transpose())
plt.title('Predicted kinetic energy density')
plt.xlabel('x')
plt.show()

In [None]:
# Residual (target - prediction) of the kinetic energy density for the `preview` samples.
plt.figure(figsize=(20, 3))
plt.plot(
    dataset_validate.x,
    (dataset_validate.targets['kinetic_energy_density'][:preview, :]
     - targets_pred['kinetic_energy_density'][:preview, :]).transpose(),
)
plt.title('Target - predicted kinetic energy density')
plt.xlabel('x')
plt.show()

In [None]:
# Peak GPU memory used by TensorFlow so far, in GiB. Evaluates to None when no GPU
# is visible, where querying 'GPU:0' would fail and break Run All on CPU machines.
(tf.config.experimental.get_memory_info('GPU:0')['peak'] / 1024**3
 if tf.config.list_physical_devices('GPU') else None)

### Test: inspect a single validation sample layer by layer

In [None]:
from quantumflow.utils import anim_plot
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tree
%matplotlib inline

In [None]:
# Pick one validation sample. Slicing (rather than indexing) keeps a leading batch
# axis of size 1 so the model can be called on it directly.
sample_index = 10
sample_features = tree.map_structure(lambda feature: feature[sample_index:sample_index + 1], dataset_validate.features)
sample_targets = tree.map_structure(lambda target: target[sample_index:sample_index + 1], dataset_validate.targets)
sample_targets_pred = model(sample_features)
# Pair every target leaf with its prediction, batch axis dropped: (target, prediction).
sample = tree.map_structure(lambda target, target_pred: (target[0], target_pred.numpy()[0]), sample_targets, sample_targets_pred)

In [None]:
# For every target leaf: plot target (black) vs. prediction when it lives on the
# x grid, otherwise print "name: prediction (target)".
for target_path, (target, target_pred) in tree.flatten_with_path_up_to(sample_targets, sample):
    # str() so integer components in the path (list/tuple targets) don't break the join.
    target_name = '/'.join(map(str, target_path))
    # Plot the squeezed array that was shape-checked, so e.g. a (1, N) target still
    # lines up with x instead of failing in plt.plot.
    target_curve = np.squeeze(target)

    if target_curve.shape == dataset_validate.x.shape:
        plt.figure(figsize=(20, 3))
        plt.plot(dataset_validate.x, target_curve, 'k')
        plt.plot(dataset_validate.x, np.squeeze(target_pred))
        plt.title(target_name)
        plt.show()
    else:
        print(f"{target_name}: {target_pred} ({target})")

In [None]:
# Replay the model's forward pass on `sample_features` by hand, plotting the
# intermediate activations of every encoder layer.
# NOTE(review): this duplicates the internals of model.layers[0..3] (x_token,
# enc_layers, cross_enc_layers, mha, ffn, ...) and must be kept in sync with the model.

# Dropout mode for the replay. False matches `model(sample_features)` (inference);
# the previous hard-coded training=True made the replay stochastic and diverge from
# the real output. Set to True to inspect training-time behaviour deliberately.
debug_training = False


def plot_channels(tensor, label, with_stats=True):
    """Plot tensor[0, :, i, :] for every index i along axis 2, one figure each.

    The title is `label`, followed by the mean and std of the plotted slice when
    `with_stats` is True.
    """
    for i in range(tensor.shape[2]):
        channel = tensor[0, :, i, :]
        plt.figure(figsize=(20, 3))
        plt.plot(channel)
        if with_stats:
            plt.title(f'{label} {np.mean(channel):.3f} {np.std(channel):.3f}')
        else:
            plt.title(label)
        plt.show()


def plot_attention_maps(attention):
    """Show attention[0, :, i, 0, :] as an image for every index i along axis 2."""
    for i in range(attention.shape[2]):
        attention_map = attention[0, :, i, 0, :]
        plt.figure(figsize=(20, 3))
        plt.imshow(attention_map, norm=matplotlib.colors.Normalize(vmin=0, vmax=0.01, clip=False), aspect=1.0)
        plt.show()
        print(f'Attention Map {np.std(attention_map):.3f}')


def residual_tail(layer, latents, attn_output):
    """Shared second half of the self- and cross-attention encoder layers.

    Applies dropout1 + residual, then layernorm2 -> ffn -> dropout2 + residual,
    plotting the attention output and both residual sums. Returns the new latents.
    """
    plot_channels(attn_output, 'Attention Output')
    attn_output = layer.dropout1(attn_output, training=debug_training)
    latents = latents + attn_output
    plot_channels(latents, 'Skip Attn Output')

    lat = layer.layernorm2(latents)  # (..., latent_size, d_model)
    ffn_output = layer.ffn[1](layer.ffn[0](lat))
    ffn_output = layer.dropout2(ffn_output, training=debug_training)
    latents = latents + ffn_output
    plot_channels(latents, 'Skip FFN Output')
    return latents


embedded = model.layers[0](sample_features)
x, x_inputs, inputs = model.layers[1](embedded['density'])
x = model.layers[2](x)

# The layer owning x_token / enc_layers / cross_enc_layers / final_layer used below.
core = model.layers[3]

x_token = core.x_token # (d_model)
for dim in tf.unstack(tf.shape(x))[:-2]:
    x_token = tf.repeat(tf.expand_dims(x_token, axis=-3), dim, axis=-3) # (..., latent_size, d_model)
x_token = tf.repeat(x_token, tf.shape(x)[-2], axis=0)

scale = params['model']['scale']
xdiff = quantumflow.xdiff.get_xdiff(x, x) / scale # (..., latent_size, latent_size)
xdiff_cross = quantumflow.xdiff.get_xdiff(x, x_inputs) / scale # (..., latent_size, input_size)

num_repeats = params['model']['num_repeats']
num_layers = params['model']['num_layers']

# Execution order of the core's sub-layers: per repeat, the self-attention
# encoders then one cross-attention encoder; then a final stack of encoders,
# the layer norm, the pre-final layers and the final layer.
layers = []
for r in range(num_repeats):
    layers.extend(core.enc_layers[r][i] for i in range(num_layers))
    layers.append(core.cross_enc_layers[r])
layers.extend(core.enc_layers[num_repeats][i] for i in range(num_layers))
layers.append(core.layernorm)
layers.extend(core.pre_final_layers)
layers.append(core.final_layer)

latents = x_token
plot_channels(latents, 'Latents')

for layer in layers:
    print(layer.name)
    if 'encoder' not in layer.name:
        latents = layer(latents)
    elif 'cross' in layer.name:
        # Cross-attention: latents attend to the (projected) inputs.
        inp = inputs
        for input_layer in layer.input_layers:
            inp = input_layer(inp)

        lat = layer.layernorm1a(latents)
        inp = layer.layernorm1b(inp)
        plot_channels(lat, 'Normalized Latents', with_stats=False)

        plt.figure(figsize=(20, 3))
        plt.plot(inp[0, 0, :, :])
        plt.title('Normalized Inputs')
        plt.show()

        attn_output, attention = layer.mha(lat, inp, inp, xdiff_cross, mask=None)  # (..., latent_size, d_model)
        plot_attention_maps(attention)
        latents = residual_tail(layer, latents, attn_output)
    else:
        # Self-attention among the latents.
        lat = layer.layernorm1(latents)
        plot_channels(lat, 'Normalized Latents', with_stats=False)

        attn_output, attention = layer.mha(lat, lat, lat, xdiff, mask=None)
        if attention.shape[-1] > 1:
            plot_attention_maps(attention)
        else:
            # Last attention axis has size 1: print the weights instead of an image.
            print(attention[0, :, :, 0, 0].numpy())
        latents = residual_tail(layer, latents, attn_output)

kinetic_energy_density = tf.reduce_sum(latents[..., 0], axis=-1)

In [None]:
value = None  # eager output of the most recently evaluated layer
tensors = {}  # symbolic output tensor name -> eager value computed for sample_features


def plot_layer(layer):
    """Evaluate `layer` eagerly on `sample_features`, storing its output in `tensors`.

    Nested models are walked recursively; input layers are fed from
    `sample_features` by layer name; any other layer is called on the stored
    values of its symbolic inputs. Objects that are not layers are ignored.
    Always returns None, so `tree.traverse` keeps descending.
    """
    global value

    if isinstance(layer, tf.keras.Model):
        tree.traverse(plot_layer, layer.layers)
        return

    if isinstance(layer, tf.keras.layers.InputLayer):
        value = sample_features[layer.name]
    elif isinstance(layer, tf.keras.layers.Layer):
        inbound = layer.input
        if isinstance(inbound, list):
            value = layer([tensors[symbolic.name] for symbolic in inbound])
        else:
            value = layer(tensors[inbound.name])
    else:
        return

    tensors[layer.output.name] = value

_ = tree.traverse(plot_layer, model.layers)

In [None]:
# Plot every rank-3 layer output (first batch element); otherwise print the whole
# tensor when it has fewer than 100 elements, or just its shape.
for layer_name, activation in tensors.items():
    if len(activation.shape) != 3:
        print(layer_name, activation if np.prod(activation.shape) < 100 else activation.shape)
        continue
    plt.figure(figsize=(20, 3))
    plt.plot(activation[0])
    plt.title(layer_name)
    plt.show()