# Create shared functions

Copy all notebooks of this project into a folder named "notebooks" inside a folder of your choice on your Google Drive. Then set the "notebook_path" variable in every notebook to the path of that "notebooks" folder, relative to the root of your Google Drive. Running this notebook creates the folder structure for the data and the code generated by the project:

- base folder
    - notebooks
        - 0_define_helper_functions.ipynb = THIS NOTEBOOK
        - 1a_generate_datasets.ipynb
        - ...
    - data (will be created)
    - quantumflow (will be created / overwritten)


In [None]:
# Set up the notebook when it runs on Google Colab: mount Google Drive and make
# the notebook folder the working directory (cwd = notebook file location).
# Outside of Colab the mount step is skipped and the current working directory
# is used as is.
import os

# change notebook_path if this notebook is in a different subfolder of Google Drive
notebook_path = "Projects/QuantumFlow/notebooks"

try:
    from google.colab import drive
    drive.mount('/content/gdrive')
    os.chdir("/content/gdrive/My Drive/" + notebook_path)
except Exception:
    # Best effort: not running on Colab, or the Drive folder is unavailable.
    # Unlike a bare `except:`, this still lets KeyboardInterrupt/SystemExit through.
    pass

# Create the output folders next to the notebook folder (no-op if they exist).
os.makedirs('../data', exist_ok=True)
os.makedirs('../quantumflow', exist_ok=True)

In [None]:
%%writefile ../quantumflow/generate_datasets.py
import tensorflow as tf
import os
from quantumflow.utils import load_hyperparameters, integrate, laplace
from quantumflow.numerov_solver import solve_schroedinger

@tf.function
def generate_potentials(return_x=False,
                        return_h=False,
                        dataset_size=100,
                        discretisation_points=500,
                        n_gauss=3,
                        interval_length=1.0,
                        a_minmax=(0.0, 3*10.0),
                        b_minmax=(0.4, 0.6),
                        c_minmax=(0.03, 0.1),
                        n_method='sum',
                        dtype='float64',
                        **kwargs):
    """Draw random potentials built from negative Gaussian curves.

    Each potential combines ``n_gauss`` curves ``-a*exp(-(x - b)**2/(2*c**2))``
    evaluated on a grid of ``discretisation_points`` points spanning
    ``[0, interval_length]``. ``a``, ``b`` and ``c`` are drawn uniformly from
    ``a_minmax``, ``b_minmax*interval_length`` and ``c_minmax*interval_length``.

    Args:
        return_x: also return the grid ``x``.
        return_h: also return the grid spacing ``h``.
        dataset_size: number of potentials to generate.
        discretisation_points: number of grid points.
        n_gauss: number of Gaussian curves per potential.
        interval_length: length of the grid interval.
        a_minmax, b_minmax, c_minmax: (min, max) bounds of the uniform draws.
        n_method: 'sum' or 'mean', how the curves are combined.
        dtype: 'double'/'float64' or 'float'/'float32'.
        **kwargs: ignored; allows passing a full hyperparameter dict.

    Returns:
        A list ``[potentials]``, extended by ``x`` and/or ``h`` when requested.

    Raises:
        ValueError: for an unknown ``dtype``.
        NotImplementedError: for an unknown ``n_method``.
    """
    if dtype in ('double', 'float64'):
        dtype = tf.float64
    elif dtype in ('float', 'float32'):
        dtype = tf.float32
    else:
        raise ValueError('unknown dtype {}'.format(dtype))

    # NOTE: the grid and the three random draws are created in the original
    # order (x, a, b, c) so that seeded runs keep producing the same data.
    x = tf.linspace(tf.constant(0.0, dtype=dtype), interval_length, discretisation_points, name="x")

    sample_shape = (dataset_size, 1, n_gauss)
    a = tf.random.uniform(sample_shape, minval=a_minmax[0], maxval=a_minmax[1],
                          dtype=dtype, name="a")
    b = tf.random.uniform(sample_shape, minval=b_minmax[0]*interval_length,
                          maxval=b_minmax[1]*interval_length, dtype=dtype, name="b")
    c = tf.random.uniform(sample_shape, minval=c_minmax[0]*interval_length,
                          maxval=c_minmax[1]*interval_length, dtype=dtype, name="c")

    # x gets a leading and a trailing axis so it broadcasts against a, b, c
    grid = tf.expand_dims(tf.expand_dims(x, 0), 2)
    exponent = -tf.square(grid - b)/(2*tf.square(c))
    curves = -a*tf.exp(exponent)

    if n_method == 'sum':
        reduce_curves = tf.reduce_sum
    elif n_method == 'mean':
        reduce_curves = tf.reduce_mean
    else:
        raise NotImplementedError('Method {} is not implemented.'.format(n_method))
    potentials = reduce_curves(curves, -1, name="potentials")

    returns = [potentials]
    if return_x:
        returns.append(x)
    if return_h:
        # discretisation interval
        returns.append(tf.cast(interval_length/(discretisation_points-1), dtype=dtype))

    return returns

def generate_datasets(data_dir, experiment, generate_names):
    """Generate and save one dataset per run name of an experiment.

    For every run name the hyperparameters are read from
    ``<data_dir>/<experiment>/hyperparams.config``, random potentials are
    generated, the Schroedinger equation is solved for them and the result is
    written with ``save_dataset`` into the experiment folder.

    Args:
        data_dir: base data directory.
        experiment: name of the experiment sub-folder.
        generate_names: a run name or a list of run names.
    """
    run_names = generate_names if isinstance(generate_names, list) else [generate_names]

    experiment_dir = os.path.join(data_dir, experiment)
    config_file = os.path.join(experiment_dir, "hyperparams.config")

    for run_name in run_names:
        params = load_hyperparameters(config_file, run_name=run_name, globals=globals())

        # start from a clean session and a fixed seed for every dataset
        tf.keras.backend.clear_session()
        tf.random.set_seed(params['seed'])
        potential, x, h = generate_potentials(return_x=True, return_h=True, **params)

        # the solver reads the grid spacing from the parameters
        params['h'] = h
        energies, wavefunctions = solve_schroedinger(potential, params)

        save_dataset(experiment_dir,
                     params['filename'],
                     params['format'],
                     x.numpy(),
                     h.numpy(),
                     potential.numpy(),
                     wavefunctions.numpy(),
                     energies.numpy())

        saved_name = params['filename'] + '.' + params['format'].replace('pickle', 'pkl')
        print("dataset", saved_name, "saved to", experiment_dir)

def save_dataset(directory, filename, format, x, h, potential, wavefunctions, energies):
    """Write a dataset to ``directory`` as a pickle or an HDF5 file.

    Args:
        directory: target folder.
        filename: file name without extension.
        format: 'pickle'/'pkl' (written as ``<filename>.pkl``) or 'hdf5'/'h5'
            (written as ``<filename>.hdf5``).
        x, h: stored under the keys 'x' and 'h' (as file attributes for HDF5).
        potential, wavefunctions, energies: stored under their own names
            (as gzip-compressed datasets for HDF5).

    Raises:
        KeyError: if ``format`` is not one of the supported formats.
    """
    if format in ['pickle', 'pkl']:
        import pickle
        with open(os.path.join(directory, filename + '.pkl'), 'wb') as f:
            pickle.dump({'x': x, 'h': h, 'potential': potential, 'wavefunctions': wavefunctions, 'energies': energies}, f)

    elif format in ['hdf5', 'h5']:
        import h5py
        with h5py.File(os.path.join(directory, filename + '.hdf5'), "w") as f:
            f.attrs['x'] = x
            f.attrs['h'] = h
            f.create_dataset('potential', data=potential, compression="gzip")
            f.create_dataset('wavefunctions', data=wavefunctions, compression="gzip")
            f.create_dataset('energies', data=energies, compression="gzip")
    else:
        # the original referenced an undefined `params` here, which raised a
        # NameError instead of the intended KeyError
        raise KeyError('Unknown format {} to save dataset.'.format(format))

In [None]:
%%writefile ../quantumflow/numerov_solver.py
import tensorflow as tf
from quantumflow.utils import integrate

# recurrent tensorflow cell for solving the numerov equation recursively
class ShootingNumerovCell(tf.keras.layers.AbstractRNNCell):
    """RNN cell that advances the Numerov recursion by one grid point.

    The state holds the two previous ``k`` values and the two previous
    solution values, stacked along the last axis.
    """

    def __init__(self, shape, h, **kwargs):
        """Store the per-step shape and the pre-scaled squared step ``h**2/12``."""
        super().__init__(**kwargs)
        self._h2_scaled = 1 / 12 * h ** 2
        self.shape = shape

    @property
    def state_size(self):
        # four stacked entries: k[n-2], k[n-1], y[n-2], y[n-1]
        return self.shape + (4,)

    @property
    def output_size(self):
        return self.shape + (1,)

    def build(self, input_shape):
        # nothing to create, the cell has no weights
        self.built = True

    def call(self, inputs, states):
        """Compute the next solution value from the current ``k`` (``inputs``) and the state."""
        k_prev2, k_prev1, y_prev2, y_prev1 = tf.unstack(states[0], axis=-1)
        scale = self._h2_scaled

        numerator = 2 * (1 - 5 * scale * k_prev1) * y_prev1 - (1 + scale * k_prev2) * y_prev2
        y_next = numerator / (1 + scale * inputs)

        # shift the history window by one step
        next_state = tf.stack([k_prev1, inputs, y_prev1, y_next], axis=-1)
        return y_next, next_state

# tf function for using the shooting numerov method
#
# the numerov_init_slope is the slope of the solution at x=0
# it can be any constant > 0 because its actual value is determined when the wavefunction is normalized
#
def shooting_numerov(k_squared, params):
    """Integrate the Numerov recursion along axis 1 of ``k_squared``.

    The solution starts at zero, its second value is
    ``params['numerov_init_slope'] * params['h']``, and the remaining values
    are produced by a ``ShootingNumerovCell`` run as a Keras RNN.

    Args:
        k_squared: tensor whose axis 1 is the integration axis.
        params: dict providing 'h', 'numerov_init_slope' and 'dtype'.

    Returns:
        The two start values and the RNN outputs, concatenated along axis 1.
    """
    h = params['h']
    slope = params['numerov_init_slope']

    first_k = k_squared[:, 0]
    y_start = tf.zeros_like(first_k)
    y_first_step = slope * h * tf.ones_like(first_k)

    init_state = tf.stack([first_k, k_squared[:, 1], y_start, y_first_step], axis=-1)

    cell = ShootingNumerovCell(k_squared.shape[2:], h)
    rnn = tf.keras.layers.RNN(cell, return_sequences=True, dtype=params['dtype'])
    outputs = rnn(k_squared[:, 2:], initial_state=init_state)

    pieces = [tf.expand_dims(y_start, axis=1), tf.expand_dims(y_first_step, axis=1), outputs]
    return tf.concat(pieces, axis=1)


# returns the rearranged schroedinger equation term in the numerov equation
# k_squared = 2*m_e/h_bar**2*(E - V(x))
def numerov_k_squared(potentials, energies):
    """Return ``2*(E - V)`` for every combination of grid point and energy.

    ``energies`` gets a new axis 1 and ``potentials`` a new axis 2 that is
    tiled ``energies.shape[1]`` times, so both line up before subtracting.
    """
    energies_per_point = tf.expand_dims(energies, axis=1)
    potentials_per_energy = tf.tile(tf.expand_dims(potentials, axis=2), [1, 1, energies.shape[1]])
    return 2 * (energies_per_point - potentials_per_energy)


@tf.function
def find_split_energies(potentials, params):
    """Search, per potential, for N+1 energies whose shooting solutions have 0..N roots.

    Starting from the minimum of each potential, every energy is moved in
    steps ``E_delta`` until the number of detected roots of the shooting
    solution equals its target (column index 0..N). Rows whose root counts
    all match are excluded from further shooting runs and energy updates.

    Args:
        potentials: tensor indexed as (potential, grid point).
        params: dict providing 'n_orbitals' plus the entries read by
            ``shooting_numerov``.

    Returns:
        E_split with one row per potential and N+1 columns.
    """
    M = potentials.shape[0]
    N = params['n_orbitals']

    # Node theorem: number of roots = quantum state
    # so target root = target excited state quantum number
    target_roots = tf.tile(tf.expand_dims(tf.range(N + 1), axis=0), [M, 1])

    # lowest value of potential as lower bound
    E_split = tf.tile(tf.expand_dims(tf.reduce_min(potentials, axis=1), axis=1), [1, N + 1])

    solutions_split = tf.zeros((potentials.shape[0], potentials.shape[1], N + 1), dtype=potentials.dtype)
    not_converged = tf.ones(potentials.shape[0], dtype=tf.bool)
    # per energy: step doubling is still allowed
    search_boost = tf.ones_like(E_split, dtype=tf.bool)
    # per energy: current step, starts at +1
    E_delta = tf.ones_like(E_split)

    while tf.math.reduce_any(not_converged):
        # only shoot for the rows that have not converged yet
        V_split = numerov_k_squared(tf.boolean_mask(potentials, not_converged), tf.boolean_mask(E_split, not_converged))

        solutions_split_new = shooting_numerov(V_split, params)

        # write the new solutions back into the full-size tensor, keeping the converged rows
        partitioned_data = tf.dynamic_partition(solutions_split, tf.cast(not_converged, tf.int32) , 2)
        condition_indices = tf.dynamic_partition(tf.range(tf.shape(solutions_split)[0]), tf.cast(not_converged, tf.int32) , 2)

        solutions_split = tf.dynamic_stitch(condition_indices, [partitioned_data[0], solutions_split_new])
        # dynamic_stitch loses the static shape, restore it for the loop
        solutions_split.set_shape((potentials.shape[0], potentials.shape[1], N + 1))

        # number of detected roots per potential and energy
        roots_split = tf.reduce_sum(tf.cast(detect_roots(solutions_split), tf.int32), axis=1)

        # a row stays unconverged while any of its root counts misses the target
        not_converged = tf.logical_and(tf.logical_not(tf.reduce_all(tf.equal(roots_split, target_roots), axis=1)), not_converged)

        # +1: too few roots, -1: too many roots, 0: target met
        search_direction = tf.cast(roots_split < target_roots, potentials.dtype) - tf.cast(roots_split > target_roots, potentials.dtype)
        boost = tf.logical_and(tf.equal(search_direction, tf.sign(E_delta)), search_boost)

        # double the step while it points in the search direction and boosting is allowed
        E_delta += tf.cast(boost, potentials.dtype)*E_delta
        # the step points away from the search direction: stop boosting for good ...
        stop_boost = search_direction * tf.sign(E_delta) < 0
        search_boost &= tf.logical_not(stop_boost)
        # ... and turn the step around at half its size (E_delta -> -0.5*E_delta)
        E_delta += -1.5*E_delta*tf.cast(stop_boost, potentials.dtype)

        # only rows that are still unconverged are moved
        E_split += E_delta*tf.expand_dims(tf.cast(not_converged, potentials.dtype), axis=-1)

    return E_split


def detect_roots(array):
    """Flag positions along axis 1 where ``array`` is zero or changes sign.

    The result has one entry less along axis 1 than ``array``: entry ``i``
    refers to the pair ``(array[:, i], array[:, i + 1])``.
    """
    current = array[:, 1:]
    previous = array[:, :-1]

    is_zero = tf.equal(current, 0)
    changes_sign = current * previous < 0
    return tf.logical_or(is_zero, changes_sign)


@tf.function
def solve_numerov(potentials, target_roots, split_energies, params):
    """Bisect the energies between neighbouring split energies.

    For every potential and target root count the energy interval
    ``[split_energies[:, n], split_energies[:, n + 1]]`` is halved until the
    midpoint no longer changes.

    Args:
        potentials: tensor indexed as (potential, grid point).
        target_roots: target root count per potential and energy.
        split_energies: interval bounds, one column more than target_roots.
        params: dict with the entries read by ``shooting_numerov``.

    Returns:
        (solutions_low, E, invalid): the shooting solutions at the final lower
        bounds, the final midpoint energies and a boolean mask over the grid
        axis derived from where the root flags of the lower-bound and
        upper-bound solutions differ.
    """
    E_low = split_energies[:, :-1]
    E_high = split_energies[:, 1:]

    # because the search interval is halved at every step
    # 32 iterations will always converge to the best numerically possible accuracy of E
    # (empirically ~25 steps)
    # NOTE(review): the loop below has no iteration cap, it runs until E stops changing

    E = 0.5 * (E_low + E_high)
    # differs from E (unless E == 0), so that the loop is entered
    E_last = E * 2
    
    while tf.reduce_any(tf.logical_not(tf.equal(E_last, E))):
        V = numerov_k_squared(potentials, E)

        solutions = shooting_numerov(V, params)
        roots = tf.reduce_sum(tf.cast(detect_roots(solutions), tf.int32), axis=1)

        # not more roots than targeted: raise the lower bound, otherwise lower the upper bound
        update_low = roots <= target_roots
        update_high = tf.logical_not(update_low)

        E_low = tf.where(update_low, E, E_low)
        E_high = tf.where(update_high, E, E_high)

        E_last = E
        E = 0.5 * (E_low + E_high)

    # NOTE(review): the root flags are cast to tf.double regardless of params['dtype']
    solutions_low = shooting_numerov(numerov_k_squared(potentials, E_low), params)
    roots_low = tf.cast(detect_roots(solutions_low), tf.double)

    solutions_high = shooting_numerov(numerov_k_squared(potentials, E_high), params)
    roots_high = tf.cast(detect_roots(solutions_high), tf.double)

    # 1 where exactly one of the two solutions has a root flag
    roots_diff = tf.abs(roots_high - roots_low)  

    # running count of differing flags along the grid, padded by one leading zero
    roots_cumsum = tf.cumsum(tf.pad(roots_diff, ((0, 0), (1, 0), (0, 0)), 'constant'), axis=1)

    # True where the running count already equals its final value
    # (presumably the tail where the shooting solution is no longer trustworthy; confirm)
    invalid = tf.equal(roots_cumsum, tf.expand_dims(roots_cumsum[:, -1], axis=1))

    return solutions_low, E, invalid


@tf.function
def solve_schroedinger(potentials, params):
    """Compute energies and normalized solutions for the first N states of each potential.

    The shooting method is run once on the potentials and once on the
    potentials reversed along the grid axis. Both solutions are joined at a
    merge index between their invalid regions, after the backward solution
    has been rescaled to match the forward one at that index.

    Args:
        potentials: tensor indexed as (potential, grid point).
        params: dict providing 'n_orbitals' and 'h' plus the entries read by
            the functions called here.

    Returns:
        (E, solutions): the mean of the forward and backward energies and the
        joined solutions, normalized so that ``integrate(solutions**2, h)`` is 1.
    """
    M = potentials.shape[0]
    G = potentials.shape[1]
    N = params['n_orbitals']
    
    E_split = find_split_energies(potentials, params)

    target_roots = tf.tile(tf.expand_dims(tf.range(N), axis=0), [M, 1])
    solutions_forward, E_forward, invalid_forward = solve_numerov(potentials, target_roots, E_split, params)
    #solutions_forward /= tf.expand_dims(tf.reduce_max(tf.abs(solutions_forward)*tf.cast(tf.logical_not(invalid_forward), tf.double), axis=1), axis=1)

    # shoot from the other end: reverse the potential, then flip the results back
    solutions_backward, E_backward, invalid_backward = solve_numerov(tf.reverse(potentials, axis=[1]), target_roots, E_split, params)
    solutions_backward = tf.reverse(solutions_backward, axis=[1])
    invalid_backward = tf.reverse(invalid_backward, axis=[1])
    #solutions_backward /= tf.expand_dims(tf.reduce_max(tf.abs(solutions_backward)*tf.cast(tf.logical_not(invalid_backward), tf.double), axis=1), axis=1)

    # merge index: computed as half of the grid points not flagged invalid by
    # either solution, shifted by the number of invalid forward points
    n_invalid_forward = tf.reduce_sum(tf.cast(invalid_forward, tf.int32), axis=1)
    n_invalid_backward = tf.reduce_sum(tf.cast(invalid_backward, tf.int32), axis=1)
    merge_index = (G - n_invalid_forward - n_invalid_backward)//2 + n_invalid_forward

    # value of each solution at its merge index
    merge_value_forward = tf.reduce_sum(tf.gather(tf.transpose(solutions_forward, perm=[0, 2, 1]), tf.expand_dims(merge_index, axis=2), batch_dims=2), axis=2)
    merge_value_backward = tf.reduce_sum(tf.gather(tf.transpose(solutions_backward, perm=[0, 2, 1]), tf.expand_dims(merge_index, axis=2), batch_dims=2), axis=2)

    # rescale the backward solution so both agree at the merge index
    factor = merge_value_forward/merge_value_backward
    solutions_backward *= tf.expand_dims(factor, axis=1)

    # forward solution before the merge index, backward solution from it on
    join_mask = tf.expand_dims(tf.expand_dims(tf.range(G), axis=0), axis=2) < tf.expand_dims(merge_index, axis=1)

    solutions = tf.where(join_mask, solutions_forward, solutions_backward)

    #normalization
    density = solutions ** 2
    norm = integrate(density, params['h'])
    solutions *= 1 / tf.sqrt(tf.expand_dims(norm, axis=1))

    E = 0.5*(E_forward + E_backward)
    
    return E, solutions

In [None]:
%%writefile ../quantumflow/utils.py
import os
import tensorflow as tf
import numpy as np

def integrate(y, h):
    """Approximate the integral of ``y`` along axis 1 with the trapezoidal rule (step ``h``)."""
    left = y[:, :-1]
    right = y[:, 1:]
    interval_means = (left + right)/2.
    return h*tf.reduce_sum(interval_means, axis=1, name='trapezoidal_integral_approx')

def laplace(data, h):  # time_axis=1
    """Three-point second derivative of 3-D ``data`` along axis 1, zero-padded at both ends."""
    before = data[:, :-2, :]
    center = data[:, 1:-1, :]
    after = data[:, 2:, :]

    interior = 1 / h ** 2 * (before + after - 2 * center)
    # pad one zero on each side of axis 1 to restore the input length
    return tf.pad(interior, ((0, 0), (1, 1), (0, 0)), 'constant')


def derivative_five_point(density, h):
    """First derivative of ``density`` along axis 1.

    Uses the five-point stencil where four neighbours exist and a central
    two-point difference for the second and the second-to-last grid point.
    The result has two entries less along axis 1 than the input (the first
    and the last grid point are dropped).
    """
    second_point = 1/(2*h)*(density[:, 2:3] - density[:, 0:1])
    interior = 1/(12*h)*(-density[:, 4:] + 8*density[:, 3:-1] - 8*density[:, 1:-3] + density[:, 0:-4])
    second_to_last_point = 1/(2*h)*(density[:, -1:] - density[:, -3:-2])

    return tf.concat([second_point, interior, second_to_last_point], axis=1)

def laplace_five_point(density, h):
    """Five-point second derivative of ``density`` along axis 1.

    Only grid points with two neighbours on each side are evaluated, so the
    result has four entries less along axis 1 than the input.
    """
    prefactor = 1/(12*h**2)
    stencil = (-density[:, 4:]
               + 16*density[:, 3:-1]
               - 30*density[:, 2:-2]
               + 16*density[:, 1:-3]
               - density[:, 0:-4])
    return prefactor*stencil


def weizsaecker_functional(density, h):
    """Integrate the von Weizsaecker kinetic energy density ``1/8 * (dn/dx)**2 / n``.

    The energy density is evaluated on the interior grid points and linearly
    extrapolated to the first and the last grid point before integrating.
    """
    density_slope = derivative_five_point(density, h)
    one_over_density = 1/density[:, 1:-1]

    interior = 1/8*density_slope**2*one_over_density

    # linear extrapolation to the two boundary points
    first_point = 2*interior[:, 0:1] - interior[:, 1:2]
    last_point = 2*interior[:, -1:] - interior[:, -2:-1]
    energy_density = tf.concat([first_point, interior, last_point], axis=1)

    return integrate(energy_density, h)

def weizsaecker_functional_derivative(density, h):
    """Evaluate ``1/8 * (n'/n)**2 - 1/4 * n''/n`` on the grid of ``density``.

    The expression is computed where the five-point Laplacian is available
    and linearly extrapolated to the two outermost grid points on each side.
    """
    # drop the outer entries so the slope lines up with the five-point Laplacian
    density_slope = derivative_five_point(density, h)[:, 1:-1]
    density_curvature = laplace_five_point(density, h)
    one_over_density = 1/density[:, 2:-2]

    interior = 1/8*(density_slope*one_over_density)**2 - 1/4*density_curvature*one_over_density

    # linear extrapolation, two points on the left and two on the right
    left_outer = 3*interior[:, 0:1] - 2*interior[:, 1:2]
    left_inner = 2*interior[:, 0:1] - interior[:, 1:2]
    right_inner = 2*interior[:, -1:] - interior[:, -2:-1]
    right_outer = 3*interior[:, -1:] - 2*interior[:, -2:-1]

    return tf.concat([left_outer, left_inner, interior, right_inner, right_outer], axis=1)


def load_hyperparameters(file_hyperparams, run_name='default', globals=None):
    """Load the hyperparameters of one run from a YAML file.

    Args:
        file_hyperparams: path of the YAML file.
        run_name: top-level key of the run to load.
        globals: optional mapping from names to objects. If given, the file
            must contain a top-level 'globals' entry; every non-dict value of
            the run (at any dict nesting depth) that is contained in that
            entry is replaced by ``globals[value]``.

    Returns:
        The (possibly nested) mapping of the run, with names resolved when
        ``globals`` is given.

    Raises:
        KeyError: if ``run_name`` (or 'globals' when resolving names) is not
            in the file, or a listed name is missing from ``globals``.
    """
    from ruamel.yaml import YAML

    # parse the file once; the original opened and parsed it twice
    with open(file_hyperparams) as f:
        config = YAML().load(f)

    if globals is None:
        return config[run_name]

    globals_list = config['globals']
    hparams = config[run_name]

    # walk all nested dicts; the visiting order does not affect the result
    pending = [hparams]
    while pending:
        data = pending.pop()
        for key in data:
            value = data[key]
            if isinstance(value, dict):
                pending.append(value)
            elif value in globals_list:
                data[key] = globals[value]
    return hparams


import ipywidgets as widgets
from IPython.display import HTML, display
from matplotlib import animation, rc
import matplotlib.pyplot as plt
# matplotlib setting for SVG output: 'none' keeps text as text instead of
# converting the characters to paths
plt.rcParams['svg.fonttype'] = 'none'

def anim_plot(array, x=None, interval=100, bar="", figsize=(15, 3), **kwargs):
    """Display an animated line plot as an embedded HTML5 video.

    Args:
        array: sequence of frames. A frame is either 2-D (one line per column)
            or 1-D (a single line).
        x: optional x values shared by all frames; defaults to the sample index.
        interval: delay between frames in milliseconds.
        bar: description of a progress bar shown while rendering; "" disables it.
        figsize: figure size passed to ``plt.subplots``.
        **kwargs: forwarded to ``ax.plot`` when the lines are created.
    """
    frames = len(array)

    widget = None
    if not bar == "":
        import ipywidgets as widgets
        widget = widgets.IntProgress(min=0, max=frames, description=bar, bar_style='success',
                                     layout=widgets.Layout(width='92%'))
        display(widget)

    try:
        fig, ax = plt.subplots(figsize=figsize)

        if x is None:
            plt_h = ax.plot(array[0], **kwargs)
        else:
            plt_h = ax.plot(x, array[0], **kwargs)

        # the y range is fixed to the last frame plus a 20% margin on both sides
        min_last = np.min(array[-1])
        max_last = np.max(array[-1])
        span_last = max_last - min_last

        ax.set_ylim([min_last - span_last*0.2, max_last + span_last*0.2])

        def init():
            return plt_h

        def animate(f):
            if widget is not None:
                widget.value = f

            frame = np.asarray(array[f])
            if frame.ndim == 1:
                # a 1-D frame is a single line; treat it as one column
                frame = frame[:, np.newaxis]

            for i, line in enumerate(plt_h):
                column = frame[:, i]
                # Line2D.set_data accepts no keyword arguments; the plot style
                # has already been applied by ax.plot above
                if x is None:
                    line.set_data(np.arange(len(column)), column)
                else:
                    line.set_data(x, column)
            return plt_h

        # call the animator. blit=True means only re-draw the parts that have changed.
        anim = animation.FuncAnimation(fig, animate, init_func=init, frames=frames, interval=interval,
                                       blit=True, repeat=False)

        plt.close(fig)
        rc('animation', html='html5')
        display(HTML(anim.to_html5_video(embed_limit=1024)))
    finally:
        # remove the progress bar even if rendering fails
        if widget is not None:
            widget.close()


def calculate_density_and_energies(potential, wavefunctions, energies, N, h):
    """Compute density and energies of the N lowest orbitals with NumPy.

    Args:
        potential: array indexed as (sample, grid point).
        wavefunctions: array indexed as (sample, grid point, orbital).
        energies: array indexed as (sample, orbital).
        N: number of orbitals to include; must not exceed the number of
            orbitals in ``wavefunctions``.
        h: grid spacing.

    Returns:
        Tuple ``(density, energy, potential_energy, kinetic_energy,
        potential_energy_density, kinetic_energy_density)``. The densities are
        indexed as (sample, grid point), the energies by sample only.
    """
    assert(N <= wavefunctions.shape[2])

    def trapezoid(values):
        # trapezoidal rule along the grid axis: full sum minus half of both end points
        return h * (np.sum(values, axis=1) - 0.5 * (np.take(values, 0, axis=1) + np.take(values, -1, axis=1)))

    density = np.sum(np.square(wavefunctions)[:, :, :N], axis=-1)

    # five-point second derivative on the interior grid points ...
    lpwf = 1/(12*h**2)*(-wavefunctions[:, 4:] + 16*wavefunctions[:, 3:-1] - 30*wavefunctions[:, 2:-2] + 16*wavefunctions[:, 1:-3] - wavefunctions[:, 0:-4])
    # ... linearly extrapolated to the two outermost points on each side.
    # np.concatenate (instead of tf.concat) keeps the whole function in NumPy.
    laplace_wavefunctions = np.concatenate([3*lpwf[:, 0:1] - 2*lpwf[:, 1:2],
                                            2*lpwf[:, 0:1] - lpwf[:, 1:2],
                                            lpwf,
                                            2*lpwf[:, -1:] - lpwf[:, -2:-1],
                                            3*lpwf[:, -1:] - 2*lpwf[:, -2:-1]], axis=1)

    kinetic_energy_densities = -0.5*wavefunctions*laplace_wavefunctions
    potential_energy_densities = np.expand_dims(potential, axis=2)*wavefunctions**2

    potential_energies = trapezoid(potential_energy_densities)
    kinetic_energies = trapezoid(kinetic_energy_densities)

    energy = np.sum(energies[:, :N], axis=-1)
    potential_energy = np.sum(potential_energies[:, :N], axis=-1)
    kinetic_energy = np.sum(kinetic_energies[:, :N], axis=-1)

    kinetic_energy_density = np.sum(kinetic_energy_densities[:, :, :N], axis=-1)
    potential_energy_density = np.sum(potential_energy_densities[:, :, :N], axis=-1)

    return density, energy, potential_energy, kinetic_energy, potential_energy_density, kinetic_energy_density


def calculate_system_properties(potential, wavefunctions, energies, N, h):
    """Extend ``calculate_density_and_energies`` by derivative and von Weizsaecker terms.

    Returns:
        The six values of ``calculate_density_and_energies`` followed by
        ``derivative`` (the negated potential), ``vW_kinetic_energy`` and
        ``vW_derivative`` (von Weizsaecker functional and its derivative,
        evaluated on the density).
    """
    assert(N <= wavefunctions.shape[2])

    base_properties = calculate_density_and_energies(potential, wavefunctions, energies, N, h)
    density = base_properties[0]

    derivative = -potential
    vW_kinetic_energy = weizsaecker_functional(density, h).numpy()
    vW_derivative = weizsaecker_functional_derivative(density, h).numpy()

    return (*base_properties, derivative, vW_kinetic_energy, vW_derivative)


class QFDataset():
    """Dataset of potentials and derived properties loaded from a pickle or HDF5 file.

    After construction the instance exposes ``x``, ``h``, ``potential``,
    ``density``, ``energy``, ``potential_energy``, ``kinetic_energy``,
    ``potential_energy_density``, ``kinetic_energy_density``, ``derivative``,
    ``vW_kinetic_energy`` and ``vW_derivative`` as arrays of ``self.dtype``,
    plus the dicts ``features`` and ``targets`` selected via ``params``.
    """

    # keys written by save_dataset, in the order they are unpacked
    _FILE_KEYS = ('x', 'h', 'potential', 'wavefunctions', 'energies')

    # names that may be requested as feature or target
    _SELECTABLE = ('density', 'derivative', 'potential', 'kinetic_energy', 'kinetic_energy_density')

    def __init__(self, dataset_file, params):
        """Load ``dataset_file`` and compute the derived properties.

        Args:
            dataset_file: path ending in .pkl/.pickle or .hdf5/.h5.
            params: dict with 'N' (number of orbitals or 'all') and 'dtype';
                optional 'subtract_von_weizsaecker', 'von_weizsaecker_factor',
                'features' and 'targets'.

        Raises:
            NotImplementedError: for an unsupported file extension.
            ImportError: if float64 is requested but the file holds float32 data.
            ValueError: for an unknown dtype.
            KeyError: for an unknown feature/target name.
        """
        extension = dataset_file.split('.')[-1]

        if extension in ['pkl', 'pickle']:
            import pickle
            # NOTE: unpickling can execute arbitrary code; only load trusted files.
            with open(dataset_file, 'rb') as f:
                raw = pickle.load(f)
            if all(key in raw for key in self._FILE_KEYS):
                # read by key so the result does not depend on the dict order
                x, h, potential, wavefunctions, energies = (raw[key] for key in self._FILE_KEYS)
            else:
                # unexpected keys: fall back to the positional layout
                x, h, potential, wavefunctions, energies = raw.values()

        elif extension in ['hdf5', 'h5']:
            import h5py
            with h5py.File(dataset_file, 'r') as f:
                x = f.attrs['x']
                h = f.attrs['h']
                potential = f['potential'][()]
                wavefunctions = f['wavefunctions'][()]
                energies = f['energies'][()]
        else:
            raise NotImplementedError('File extension missing or not supported.')

        if params['N'] == 'all':
            # one copy of the dataset per particle number, stacked along axis 0
            # NOTE(review): self.potential is not stacked accordingly, so it has
            # fewer rows than the derived arrays in this mode; confirm intent
            all_data = [calculate_system_properties(potential, wavefunctions, energies, N, h) for N in range(1, energies.shape[1]+1)]
            density, energy, potential_energy, kinetic_energy, potential_energy_density, kinetic_energy_density, derivative, vW_kinetic_energy, vW_derivative = \
                [np.concatenate([data[j] for data in all_data], axis=0) for j in range(len(all_data[0]))]
        else:
            density, energy, potential_energy, kinetic_energy, potential_energy_density, kinetic_energy_density, derivative, vW_kinetic_energy, vW_derivative = \
                calculate_system_properties(potential, wavefunctions, energies, params['N'], h)

        self.dataset_size, self.discretisation_points = density.shape

        if params.get('subtract_von_weizsaecker', False):
            factor = params.get('von_weizsaecker_factor', 1.0)
            kinetic_energy -= factor*vW_kinetic_energy
            derivative -= factor*vW_derivative

        if params['dtype'] == 'double' or params['dtype'] == 'float64':
            if potential.dtype == np.float32:
                raise ImportError("requested dtype={}, but dataset is saved with dtype={}, which is less precise.".format(params['dtype'], potential.dtype))
            self.dtype = np.float64
        elif params['dtype'] == 'float' or params['dtype'] == 'float32':
            self.dtype = np.float32
        else:
            raise ValueError('unknown dtype {}'.format(params['dtype']))

        self.x = x.astype(self.dtype)
        self.h = h.astype(self.dtype)
        self.potential = potential.astype(self.dtype)
        self.density = density.astype(self.dtype)
        self.energy = energy.astype(self.dtype)
        self.potential_energy = potential_energy.astype(self.dtype)
        self.kinetic_energy = kinetic_energy.astype(self.dtype)
        self.potential_energy_density = potential_energy_density.astype(self.dtype)
        self.kinetic_energy_density = kinetic_energy_density.astype(self.dtype)
        self.derivative = derivative.astype(self.dtype)
        self.vW_kinetic_energy = vW_kinetic_energy.astype(self.dtype)
        self.vW_derivative = vW_derivative.astype(self.dtype)

        # always defined, so get_params() also works without a selection
        self.features = {}
        self.targets = {}

        if not 'features' in params or not 'targets' in params:
            return

        def add_by_name(dictionary, name):
            if name not in self._SELECTABLE:
                raise KeyError('feature/target {} does not exist or is not implemented.'.format(name))
            dictionary[name] = getattr(self, name)

        for feature in params['features']:
            add_by_name(self.features, feature)

        for target in params['targets']:
            add_by_name(self.targets, target)

    def get_params(self, shapes=True, h=True, mean=False):
        """Return dataset-derived parameters.

        Args:
            shapes: include 'features_shape' and 'targets_shape' (per-sample shapes).
            h: include the grid spacing 'h'.
            mean: include 'features_mean' and 'targets_mean' (mean over axis 0).

        Returns:
            dict with the requested entries.
        """
        params = {}
        if h:
            params['h'] = self.h

        if shapes:
            params['features_shape'] = {name: feature.shape[1:] for name, feature in self.features.items()}
            params['targets_shape'] = {name: target.shape[1:] for name, target in self.targets.items()}

        if mean:
            params['features_mean'] = {name: np.mean(feature, axis=0) for name, feature in self.features.items()}
            params['targets_mean'] = {name: np.mean(target, axis=0) for name, target in self.targets.items()}

        return params

def run_experiment(experiment, run_name, data_dir='../data'):
    """Train one run of an experiment.

    Executes ``<data_dir>/<experiment>/model.py`` in this module's global
    namespace, loads the hyperparameters of ``run_name`` from
    ``hyperparams.config`` in the same folder and calls ``train`` with
    ``<data_dir>/<experiment>/<run_name>`` as model directory.
    """
    base_dir = os.path.join(data_dir, experiment)
    model_dir = os.path.join(base_dir, run_name)

    file_model = os.path.join(base_dir, "model.py")
    # NOTE: exec runs the model file with full privileges; only use trusted experiments.
    # The file is closed before its code runs (the original leaked the handle).
    with open(file_model) as f:
        model_source = f.read()
    exec(model_source, globals())

    file_hyperparams = os.path.join(base_dir, "hyperparams.config")
    params = load_hyperparameters(file_hyperparams, run_name=run_name, globals=globals())

    train(params, model_dir, data_dir)

def run_multiple(experiment, run_name, data_dir='../data'):
    """Train one model per combination of the integer ranges of a run.

    The run's hyperparameters may contain the dicts 'int_min' and 'int_max'
    with identical keys in identical order. Every key is swept from its
    minimum to its maximum (inclusive); all combinations are trained, each in
    its own sub-folder of ``<data_dir>/<experiment>/<run_name>`` named after
    the swept values. Without both dicts nothing is trained.
    """
    import copy

    base_dir = os.path.join(data_dir, experiment)
    model_dir = os.path.join(base_dir, run_name)

    file_model = os.path.join(base_dir, "model.py")
    # NOTE: exec runs the model file with full privileges; only use trusted experiments.
    with open(file_model) as f:
        model_source = f.read()
    exec(model_source, globals())

    file_hyperparams = os.path.join(base_dir, "hyperparams.config")
    params = load_hyperparameters(file_hyperparams, run_name=run_name, globals=globals())

    def apply_configuration(hparams, configuration):
        # overwrite every matching key at any dict nesting depth,
        # leaving the range definitions themselves untouched
        pending = [hparams]
        while pending:
            data = pending.pop()
            for key in data:
                if key in ['int_min', 'int_max']:
                    continue

                if isinstance(data[key], dict):
                    pending.append(data[key])
                    continue

                if key in configuration:
                    data[key] = configuration[key]
        return hparams

    def extend_configurations(configurations_out, run_appendices_out, configurations_in, run_appendices_in):
        # cartesian product of the configurations collected so far with the new ones
        if len(configurations_out) == 0:
            return configurations_in, run_appendices_in

        merged_configurations = []
        merged_run_appendices = []
        for configuration_out, run_appendix_out in zip(configurations_out, run_appendices_out):
            for configuration_in, run_appendix_in in zip(configurations_in, run_appendices_in):
                # dict.update() returns None, so the merged copy has to be
                # built first and appended afterwards
                merged = configuration_out.copy()
                merged.update(configuration_in)
                merged_configurations.append(merged)
                merged_run_appendices.append(run_appendix_out + '_' + run_appendix_in)

        return merged_configurations, merged_run_appendices

    configurations = []
    run_appendices = []

    if 'int_min' in params and 'int_max' in params:
        for (int_key, int_min), (int_key_max, int_max) in zip(params['int_min'].items(), params['int_max'].items()):
            assert int_key == int_key_max
            int_configurations = []
            int_run_appendices = []
            for int_value in range(int_min, int_max+1):
                int_configurations.append({int_key: int_value})
                # zero-pad to the width of the maximum so the folders sort correctly
                int_run_appendices.append(('{}{:0' + str(len(str(int_max))) + 'd}').format(int_key, int_value)) # TODO: support negative values
            configurations, run_appendices = extend_configurations(configurations, run_appendices, int_configurations, int_run_appendices)

    for configuration, run_appendix in zip(configurations, run_appendices):
        train(apply_configuration(copy.deepcopy(params), configuration), os.path.join(model_dir, run_appendix), data_dir)


def build_model(params, data_dir='../data', dataset_train=None):
    """Instantiate ``params['model']`` in a fresh Keras session.

    If no training dataset is passed, it is loaded from
    ``params['dataset_train']`` and its parameters are stored in
    ``params['dataset']`` before the model is built.
    """
    if dataset_train is None:
        dataset_path = os.path.join(data_dir, params['dataset_train'])
        dataset_train = QFDataset(dataset_path, params)
        params['dataset'] = dataset_train.get_params(shapes=True, h=True, mean=True)

    tf.keras.backend.clear_session()

    model_factory = params['model']
    return model_factory(params)

def train(params, model_dir=None, data_dir='../data', callbacks=None):
    """Build, compile and fit the model described by ``params``.

    Args:
        params: hyperparameter dict. Read here: 'dataset_train', optional
            'dataset_validate', optional 'seed', 'optimizer',
            'optimizer_kwargs' (with 'learning_rate' and, for schedules,
            'learning_rate_kwargs'), 'loss', optional 'loss_weights' and
            'metrics', optional 'load_checkpoint', 'fit_kwargs', optional
            'checkpoint'/'checkpoint_kwargs', 'tensorboard'/'tensorboard_kwargs',
            'save_model', 'export' and 'export_model'.
        model_dir: output folder for checkpoints, logs and saved models;
            ``None`` disables all of them.
        data_dir: base folder of the dataset and checkpoint paths.
        callbacks: optional list of extra Keras callbacks (not modified).

    Returns:
        (model, params); ``params['dataset']`` is set as a side effect.
    """
    dataset_train = QFDataset(os.path.join(data_dir, params['dataset_train']), params)
    dataset_validate = QFDataset(os.path.join(data_dir, params['dataset_validate']), params) if 'dataset_validate' in params else None
    params['dataset'] = dataset_train.get_params(shapes=True, h=True, mean=True)

    tf.keras.backend.clear_session()
    if 'seed' in params:
        tf.random.set_seed(params['seed'])

    model = build_model(params, data_dir=data_dir, dataset_train=dataset_train)

    # Resolve the learning rate: a schedule given by name or class is
    # instantiated with 'learning_rate_kwargs'; anything else (float, int,
    # schedule instance) is passed to the optimizer unchanged.
    optimizer_kwargs = params['optimizer_kwargs'].copy()
    schedule_kwargs = optimizer_kwargs.pop('learning_rate_kwargs', None)
    learning_rate = optimizer_kwargs['learning_rate']

    if isinstance(learning_rate, str):
        schedule_class = getattr(tf.keras.optimizers.schedules, learning_rate)
    elif isinstance(learning_rate, type) and issubclass(learning_rate, tf.keras.optimizers.schedules.LearningRateSchedule):
        schedule_class = learning_rate
    else:
        schedule_class = None

    if schedule_class is not None:
        learning_rate = schedule_class(**(schedule_kwargs or {}))
        optimizer_kwargs['learning_rate'] = learning_rate

    optimizer_class = getattr(tf.keras.optimizers, params['optimizer']) if isinstance(params['optimizer'], str) else params['optimizer']
    optimizer = optimizer_class(**optimizer_kwargs)
    model.compile(optimizer, loss=params['loss'], loss_weights=params.get('loss_weights', None), metrics=params.get('metrics', None))

    if params.get('load_checkpoint', None) is not None:
        checkpoint_path = os.path.join(data_dir, params['load_checkpoint'])
        model.load_weights(checkpoint_path)
        if params['fit_kwargs'].get('verbose', 0) > 0:
            print("loading weights from ", checkpoint_path)

    # work on a copy so the caller's list is not modified
    callbacks = list(callbacks) if callbacks is not None else []

    if model_dir is not None and params.get('checkpoint', False):
        checkpoint_params = params.get('checkpoint_kwargs', {}).copy()
        checkpoint_params['filepath'] = os.path.join(model_dir, checkpoint_params.pop('filename', 'weights.{epoch:05d}.hdf5'))
        checkpoint_params['verbose'] = checkpoint_params.get('verbose', min(1, params['fit_kwargs'].get('verbose', 1)))
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(**checkpoint_params))

    if model_dir is not None and params.get('tensorboard', False):
        tensorboard_kwargs = dict(params.get('tensorboard_kwargs', {}))
        if callable(params['tensorboard']):
            # custom callback classes additionally receive the learning rate
            callbacks.append(params['tensorboard'](log_dir=model_dir, learning_rate=learning_rate, **tensorboard_kwargs))
        else:
            # the stock TensorBoard callback has no `learning_rate` argument
            callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=model_dir, **tensorboard_kwargs))

    model.fit(x=dataset_train.features,
              y=dataset_train.targets,
              callbacks=callbacks,
              validation_data=(dataset_validate.features, dataset_validate.targets) if dataset_validate is not None else None,
              **params['fit_kwargs'])

    if model_dir is not None and params.get('save_model', False) is True:
        model.save(os.path.join(model_dir, 'model.h5'))

    if model_dir is not None and params.get('export', False) is True:
        export_model = getattr(model, params['export_model']) if not params.get('export_model', 'self') == 'self' else model
        tf.saved_model.save(export_model, os.path.join(model_dir, 'saved_model'))

    return model, params

In [None]:
%%writefile ../quantumflow/keras_utils.py
import tensorflow as tf

class KineticEnergyFunctionalDerivativeModel(tf.keras.Model):
    """Keras model returning a base model's output together with its input gradient.

    The wrapped base model is built by calling
    ``params['model_kwargs']['base_model'](params)``. ``call`` evaluates it on
    the flattened input and additionally returns the gradient of its output
    with respect to that input, scaled by ``1/h`` where ``h`` is
    ``params['dataset']['h']``.

    ``summary``, ``save``, ``save_weights`` and ``load_weights`` are delegated
    to the wrapped base model.
    """

    def __init__(self, params):
        """Build the base model and store the scaling constant ``h``.

        Args:
            params: dict providing ``params['model_kwargs']['base_model']``
                (a callable taking ``params``), ``params['dataset']['h']``
                and ``params['dtype']``.
        """
        super().__init__()
        self.model = params['model_kwargs']['base_model'](params)
        self.h = tf.constant(params['dataset']['h'], dtype=params['dtype'])

        self.output_names = sorted(['derivative'] + self.model.output_names)
        self.input_names = self.model.input_names

    @tf.function
    def call(self, density):
        """Return ``(derivative, kinetic_energy)`` for the given input.

        ``kinetic_energy`` is the base model's output and ``derivative`` is
        ``1/h`` times its gradient with respect to the flattened input.
        """
        density = tf.nest.flatten(density)

        with tf.GradientTape() as tape:
            # Inputs are plain tensors, not variables, so they have to be
            # watched explicitly for the tape to record them.
            tape.watch(density)
            kinetic_energy = self.model(density)

        derivative = tf.identity(1/self.h*tape.gradient(kinetic_energy, density), name='derivative')
        return derivative, kinetic_energy

    def fit(self, y=None, validation_data=None, **kwargs):
        """Train the model, accepting dict-valued targets.

        Dict targets (in ``y`` and in the second element of
        ``validation_data``) are flattened with ``tf.nest.flatten`` before
        being handed to ``tf.keras.Model.fit``.

        Returns:
            Whatever ``tf.keras.Model.fit`` returns (the training history).
        """
        if isinstance(y, dict):
            y = tf.nest.flatten(y)

        if isinstance(validation_data, (tuple, list)) and isinstance(validation_data[1], dict):
            validation_data = (validation_data[0], tf.nest.flatten(validation_data[1]))

        # Propagate the result so callers can inspect the training history.
        return super().fit(y=y, validation_data=validation_data, **kwargs)

    def _set_output_attrs(self, outputs):
        """Re-apply the combined output names after the base class set its own.

        NOTE(review): this overrides a private Keras hook; confirm it still
        exists in the TensorFlow version in use.
        """
        super()._set_output_attrs(outputs)
        self.output_names = sorted(['derivative'] + self.model.output_names)

    def summary(self, *args, **kwargs):
        """Print the summary of the wrapped base model."""
        return self.model.summary(*args, **kwargs)

    def save(self, *args, **kwargs):
        """Save the wrapped base model (not this wrapper)."""
        return self.model.save(*args, **kwargs)

    def save_weights(self, *args, **kwargs):
        """Save the base model's weights with this wrapper's optimizer attached.

        The optimizer is lent to the base model only for the duration of the
        call (presumably so its state is stored with the weights -- confirm)
        and is detached again even if saving fails.
        """
        self.model.optimizer = self.optimizer
        try:
            return self.model.save_weights(*args, **kwargs)
        finally:
            self.model.optimizer = None

    def load_weights(self, *args, **kwargs):
        """Load weights into the base model and take back its optimizer.

        The wrapper's optimizer is lent to the base model while loading; on
        success the base model's optimizer is copied back to the wrapper. The
        base model's optimizer is reset to ``None`` in every case.
        """
        self.model.optimizer = self.optimizer
        try:
            returns = self.model.load_weights(*args, **kwargs)
            self.optimizer = self.model.optimizer
        finally:
            self.model.optimizer = None
        return returns

import time

class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
    """TensorBoard callback with throttled metric logging and extra scalars.

    Compared to the parent callback it
      * logs epoch metrics only every ``metrics_freq`` epochs, or whenever the
        logs contain validation metrics,
      * adds ``epochs_per_second`` and (optionally) ``learning_rate`` scalars,
      * regroups ``*_loss`` and ``*_mean_absolute_error`` metrics under
        ``loss/`` and ``mean_absolute_error/``.
    """

    def __init__(self, *args, metrics_freq=0, learning_rate=None, **kwargs):
        """Create the callback.

        Args:
            *args, **kwargs: forwarded to ``tf.keras.callbacks.TensorBoard``.
            metrics_freq: log metrics every this many epochs; a falsy value
                disables frequency-based logging.
            learning_rate: a float or a ``LearningRateSchedule`` to log as
                ``learning_rate``; anything else is ignored.
        """
        super().__init__(*args, **kwargs)
        self.last_time = None
        self.last_step = None
        self.learning_rate = learning_rate
        self.metrics_freq = metrics_freq

    def on_epoch_end(self, epoch, logs=None):
        """Runs metrics and histogram summaries at epoch end."""
        if logs is None:
            # The signature allows logs to be omitted; treat it as "no metrics".
            logs = {}

        # Validation metrics are prefixed with 'val_'; a substring test would
        # also match training metrics such as 'interval_loss'.
        has_validation = any(key.startswith('val_') for key in logs)
        if (self.metrics_freq and epoch % self.metrics_freq == 0) or has_validation:
            self._log_metrics(logs, prefix='epoch_', step=epoch)

        if self.histogram_freq and epoch % self.histogram_freq == 0:
            self._log_weights(epoch)

        if self.embeddings_freq and epoch % self.embeddings_freq == 0:
            self._log_embeddings(epoch)

    def _log_metrics(self, logs, prefix, step):
        """Add throughput / learning-rate scalars, rename keys and log them.

        NOTE: the extra scalars are written into the ``logs`` dict passed in,
        so the caller's dict is mutated. ``prefix`` is accepted for interface
        compatibility but an empty prefix is forwarded to the parent.
        """
        if logs is None:
            logs = {}

        now = time.time()
        if self.last_time is not None:
            elapsed = now - self.last_time
            # Skip the rate if the clock did not advance, rather than
            # raising ZeroDivisionError.
            if elapsed > 0:
                logs['epochs_per_second'] = (step - self.last_step)/elapsed
        self.last_time = now
        self.last_step = step

        if isinstance(self.learning_rate, float):
            logs['learning_rate'] = self.learning_rate
        elif isinstance(self.learning_rate, tf.keras.optimizers.schedules.LearningRateSchedule):
            logs['learning_rate'] = self.learning_rate(self.model.optimizer.iterations)

        def rename_key(key):
            # Strip the validation prefix, regroup the metric, then put the
            # prefix back in front of the regrouped name.
            is_validation = key.startswith('val_')
            if is_validation:
                key = key[len('val_'):]
            if '_loss' in key:
                key = 'loss/' + key.replace('_loss', '')
            if '_mean_absolute_error' in key:
                key = 'mean_absolute_error/' + key.replace('_mean_absolute_error', '')
            if is_validation:
                key = 'val_' + key
            return key

        logs = {rename_key(key): value for key, value in logs.items()}
        super()._log_metrics(logs, '', step)


class WarmupExponentialDecay(tf.keras.optimizers.schedules.ExponentialDecay):
    """Exponential decay preceded by a constant "cold" phase and a linear warmup.

    With ``lr0 = initial_learning_rate`` the schedule is
      * ``step <= cold_steps``: ``lr0 * cold_factor``,
      * ``cold_steps < step <= cold_steps + warmup_steps``: linear ramp from
        ``lr0 * cold_factor`` to ``lr0``,
      * afterwards: the parent's exponential decay evaluated at
        ``step - cold_steps - warmup_steps``, never below
        ``final_learning_rate``.
    """

    def __init__(self, warmup_steps=None, cold_steps=None, cold_factor=0.1, final_learning_rate=0.0, **kwargs):
        """Create the schedule.

        Args:
            warmup_steps: length of the linear warmup; ``None`` means 0.
            cold_steps: length of the constant cold phase; ``None`` means 0.
            cold_factor: fraction of the initial learning rate used while cold.
            final_learning_rate: lower bound applied during the decay phase.
            **kwargs: forwarded to ``ExponentialDecay``.
        """
        super().__init__(**kwargs)
        self.warmup_steps = warmup_steps
        self.cold_steps = cold_steps
        self.cold_factor = cold_factor
        self.final_learning_rate = final_learning_rate

    @tf.function
    def __call__(self, step):
        """Return the learning rate for the given ``step``."""
        # The constructor defaults are None; treat them as "no such phase"
        # instead of failing on None arithmetic below.
        cold_steps = 0 if self.cold_steps is None else self.cold_steps
        warmup_steps = 0 if self.warmup_steps is None else self.warmup_steps

        cold_rate = self.initial_learning_rate*self.cold_factor
        # With warmup_steps == 0 this divides by zero, but that branch is then
        # never selected by the tf.where below.
        warmup_rate = self.initial_learning_rate*(
            self.cold_factor
            + tf.cast(step - cold_steps, tf.float32)*(1 - self.cold_factor)/tf.cast(warmup_steps, tf.float32))
        decayed_rate = tf.maximum(super().__call__(step - cold_steps - warmup_steps),
                                  self.final_learning_rate)

        return tf.where(step <= cold_steps + warmup_steps,
                        tf.where(step <= cold_steps, cold_rate, warmup_rate),
                        decayed_rate)

    def get_config(self):
        """Return the parent's config extended by this schedule's arguments."""
        config = super().get_config()
        config.update({
            'warmup_steps': self.warmup_steps,
            'cold_steps': self.cold_steps,
            'cold_factor': self.cold_factor,
            'final_learning_rate': self.final_learning_rate,
        })
        return config
