# Create shared functions

Please copy all notebooks of this project into a folder named "notebooks" inside a folder of your choice on your Google Drive. Then set the "notebook_path" variable in every notebook to the path of that "notebooks" folder, relative to the root of your Google Drive (for example "Projects/QuantumFlow/notebooks"). This notebook creates the folder structure for the data and the Python modules generated by the project:

- base folder
    - notebooks
        - 0_define_helper_functions.ipynb = THIS NOTEBOOK
        - 1a_generate_datasets.ipynb
        - ...
    - data (will be created)
    - quantumflow (will be created / overwritten)


In [None]:
notebook_path = "Projects/QuantumFlow/notebooks"
import os

# On Google Colab: mount Google Drive and switch into the notebook folder.
try:
    from google.colab import drive
except ImportError:
    # Not running on Colab: keep the current working directory.
    drive = None

if drive is not None:
    try:
        drive.mount('/content/gdrive')
        os.chdir("/content/gdrive/My Drive/" + notebook_path)
    except Exception as error:
        # Still best effort, but no longer silent: with a wrong notebook_path the
        # folders below would otherwise be created relative to an unexpected directory.
        print("Warning: could not switch to the notebook folder:", error)

# Output folders are created next to the current (notebook) folder.
for folder in ('../data', '../quantumflow'):
    os.makedirs(folder, exist_ok=True)

In [None]:
%%writefile ../quantumflow/generate_potentials.py
import tensorflow as tf
import numpy as np

def tf_generate_potentials(dataset_size=2000, points=500, n_gauss=3, length=1.0,
                           a_minmax=(0.0, 3*10.0), b_minmax=(0.4, 0.6), c_minmax=(0.03, 0.1), return_x=False):
    """Build a tensor of random potentials, each a sum of ``n_gauss`` negative Gaussians.

    Every potential is ``-sum_i a_i * exp(-(x - b_i)**2 / (2 * c_i**2))`` on a grid of
    ``points`` samples in ``[0, length]``. ``a`` is drawn uniformly from ``a_minmax``;
    ``b`` and ``c`` are drawn uniformly from ``b_minmax`` / ``c_minmax`` scaled by ``length``.

    Returns the potentials tensor of shape ``(dataset_size, points)`` and, when
    ``return_x`` is true, additionally the grid tensor ``x``.
    """
    x = tf.linspace(0.0, length, points, name="x")

    # one (amplitude, center, width) triple per Gaussian and per potential
    sample_shape = (dataset_size, 1, n_gauss)
    a = tf.random_uniform(sample_shape, minval=a_minmax[0], maxval=a_minmax[1], name="a")
    b = tf.random_uniform(sample_shape, minval=b_minmax[0]*length, maxval=b_minmax[1]*length, name="b")
    c = tf.random_uniform(sample_shape, minval=c_minmax[0]*length, maxval=c_minmax[1]*length, name="c")

    # grid broadcast to (1, points, 1) against the (dataset_size, 1, n_gauss) parameters
    grid = tf.expand_dims(tf.expand_dims(x, 0), 2)
    exponent = -tf.square(grid - b)/(2*tf.square(c))
    wells = -a*tf.exp(exponent)

    potentials = tf.reduce_sum(wells, -1, name="potentials")

    if return_x:
        return potentials, x
    return potentials

def generate_potentials(dataset_size=2000, points=500, n_gauss=3, length=1.0,
                        a_minmax=(0.0, 3*10.0), b_minmax=(0.4, 0.6), c_minmax=(0.03, 0.1), return_x=False, seed=0):
    """Generate random potentials as NumPy arrays.

    Builds ``tf_generate_potentials`` in a private graph seeded with ``seed`` and
    evaluates it once. All other arguments are forwarded unchanged.

    Returns the potentials array and, when ``return_x`` is true, additionally the
    grid array ``x``.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(seed)
        potentials, x = tf_generate_potentials(dataset_size, points, n_gauss, length,
                                               a_minmax, b_minmax, c_minmax, return_x=True)
        # The session is closed on exit so repeated calls do not accumulate open sessions.
        with tf.Session(graph=graph) as sess:
            np_potentials, np_x = sess.run([potentials, x])

    if return_x:
        return np_potentials, np_x
    return np_potentials

In [None]:
%%writefile ../quantumflow/calculus_utils.py
import numpy as np


def integrate(data, h, axis=-1):
    """Integrate ``data`` along ``axis`` with the trapezoidal rule for uniform spacing ``h``.

    Raises ValueError when the axis holds fewer than two samples.
    """
    n_samples = data.shape[axis]
    if n_samples < 2:
        raise ValueError(
            "Integration failed: time-axis {} has {} elements, required: >=2".format(axis, n_samples))

    # The two end points only count half: subtract half of them from the plain sum.
    first = np.take(data, 0, axis=axis)
    last = np.take(data, -1, axis=axis)
    total = np.sum(data, axis=axis)
    return h * (total - 0.5 * (first + last))


def integrate_simpson(data, h, axis=-1):
    """Integrate ``data`` along ``axis`` with Simpson's rule for uniform spacing ``h``.

    Simpson's rule needs an odd number of samples. For an even number the last
    interval is integrated with the trapezoidal ``integrate`` and Simpson's rule
    is applied to the remaining samples. With exactly two samples the result is
    the trapezoidal value alone.

    Raises ValueError when the axis holds fewer than two samples.
    """
    n_samples = data.shape[axis]
    if n_samples < 2:
        raise ValueError(
            "Integration failed: time-axis {} has {} elements, required: >=2".format(axis, n_samples))

    tail = 0
    simpson_ready = n_samples > 2 and n_samples % 2 == 1
    if not simpson_ready:
        # trapezoid over the last interval, then drop the last sample
        tail = integrate(np.take(data, [-2, -1], axis=axis), h, axis)
        if n_samples == 2:
            return tail
        data = np.take(data, range(0, n_samples - 1), axis=axis)

    even_sum = np.sum(np.take(data, range(0, data.shape[axis], 2), axis=axis), axis=axis)
    odd_sum = np.sum(np.take(data, range(1, data.shape[axis], 2), axis=axis), axis=axis)
    first = np.take(data, 0, axis=axis)
    last = np.take(data, -1, axis=axis)

    # weights 1, 4, 2, 4, ..., 2, 4, 1: the end points sit in even_sum and are corrected to weight 1
    return tail + h / 3 * (2 * even_sum + 4 * odd_sum - first - last)

def laplace(data, h):  # time_axis=1
    """Second finite difference of a 3-D array along axis 1 with spacing ``h``.

    The first and last entries along axis 1 have no two-sided neighbours and are
    filled with zeros, so the result has the same shape as ``data``.
    """
    inv_h2 = 1 / h ** 2
    left = data[:, :-2, :]
    right = data[:, 2:, :]
    center = data[:, 1:-1, :]
    interior = inv_h2 * (left + right - 2 * center)
    return np.pad(interior, ((0, 0), (1, 1), (0, 0)), 'constant')

def normalize(function, h, axis=-1):
    """Scale ``function`` so that its Simpson integral along ``axis`` (spacing ``h``) equals one."""
    area = integrate_simpson(function, h, axis=axis)
    # restore the integrated axis with length 1 so the division broadcasts
    area_broadcast = np.expand_dims(area, axis=axis)
    return function * 1 / area_broadcast

def rbf_kernel(X, X_train, gamma):
    """Gaussian (RBF) kernel matrix ``exp(-gamma * |x - x'|^2)``.

    ``X`` has shape (n, d) and ``X_train`` shape (m, d); the result has shape (n, m).
    """
    # pairwise differences, shape (n, d, m)
    diff = X[:, :, np.newaxis] - np.transpose(X_train)[np.newaxis, :, :]
    squared_distance = np.sum(np.square(diff), 1)
    return np.exp(-gamma*squared_distance)

def predict(X, X_train, weights, gamma):
    """Kernel-model prediction: for every row of ``X`` the weighted sum of RBF kernel values."""
    kernel = rbf_kernel(X, X_train, gamma)
    weighted = weights[np.newaxis, :]*kernel
    return np.sum(weighted, 1)

def functional_derivative(X, X_train, weights, gamma, h):
    """Derivative of the kernel prediction with respect to the entries of ``X``, divided by ``h``.

    Returns an array with the shape of ``X``.
    """
    # pairwise differences, shape (n, d, m)
    diff = X[:, :, np.newaxis] - np.transpose(X_train)[np.newaxis, :, :]
    # kernel values with an axis inserted for the d input components, shape (n, 1, m)
    kernel = rbf_kernel(X, X_train, gamma)[:, np.newaxis, :]
    summand = weights[np.newaxis, :]*2*gamma*diff*kernel
    return -1/h*np.sum(summand, 2)



In [None]:
%%writefile ../quantumflow/numerov_solver.py
import tensorflow as tf
import numpy as np

from quantumflow.calculus_utils import integrate, integrate_simpson, laplace

# Recurrent TensorFlow cell that advances the Numerov recursion by one grid point.
class ShootingNumerovCell(tf.nn.rnn_cell.RNNCell):
    """One step of the Numerov recursion.

    The state stacks, along its last axis, the two previous k^2 values and the
    two previous solution values: ``[k_(n-2), k_(n-1), y_(n-2), y_(n-1)]``.
    The cell input is the k^2 value at the new grid point.
    """

    def __init__(self, h=1.0):
        super().__init__()
        # the step size enters the recursion only as h^2 / 12
        self._h2_scaled = 1 / 12 * h ** 2

    @property
    def state_size(self):
        return 4

    @property
    def output_size(self):
        return 1

    def __call__(self, inputs, state, scope=None):
        k_prev2, k_prev1, y_prev2, y_prev1 = tf.unstack(state, axis=-1)
        c = self._h2_scaled

        numerator = 2 * (1 - 5 * c * k_prev1) * y_prev1 - (1 + c * k_prev2) * y_prev2
        denominator = 1 + c * inputs
        y_next = numerator / denominator

        # shift the window: the new point becomes "previous" for the next step
        return y_next, tf.stack([k_prev1, inputs, y_prev1, y_next], axis=-1)

# TensorFlow function applying the shooting Numerov method along axis 1 of k_squared.
#
# The solution starts with y = 0 at the first grid point and y = init_factor * h at
# the second one, i.e. init_factor is the initial slope. Any constant > 0 works
# because its actual value is fixed later, when the wavefunction is normalized.
#
def shooting_numerov(k_squared, h=1, init_factor=1e-128):
    cell = ShootingNumerovCell(h=h)

    # columns 2 and 3 are only used as templates for the shape and dtype of the start values
    y_first = tf.zeros_like(k_squared[:, 2])
    y_second = init_factor * h * tf.ones_like(k_squared[:, 3])
    init_state = tf.stack([k_squared[:, 0], k_squared[:, 1], y_first, y_second], axis=-1)

    # the first two grid points are covered by the initial state
    remaining_steps = tf.unstack(k_squared, axis=1)[2:]
    outputs, _ = tf.nn.static_rnn(cell, remaining_steps, initial_state=init_state)

    output = tf.stack([init_state[:, 2], init_state[:, 3]] + outputs, axis=-1)
    return output

# Returns the rearranged Schroedinger equation term used in the Numerov equation:
# k_squared = 2*m_e/h_bar**2*(E - V(x)), here with m_e = h_bar = 1.
# potentials: (M, G), energies: (M, N) -> result: (M, G, N)
def numerov_k_squared(potentials, energies):
    energy_term = np.expand_dims(energies, axis=1)
    potential_term = np.expand_dims(potentials, axis=2)
    # broadcasting supplies the copy of each potential per energy
    return 2 * (energy_term - potential_term)


def detect_roots(array1):
    """Mark roots along axis 1: entries that are exactly zero or differ in sign from their predecessor.

    The result is a boolean array that is one entry shorter along axis 1.
    """
    current = array1[:, 1:]
    previous = array1[:, :-1]
    exact_zero = current == 0
    sign_change = current * previous < 0
    return np.logical_or(exact_zero, sign_change)


class NumerovSolver():
    """Shooting-method Numerov solver for batches of potentials on a grid of ``G`` points.

    The TensorFlow graph of the recursion is built once in the constructor and
    evaluated through ``self.sess`` for every set of trial energies. Call
    ``close()`` (or use the instance as a context manager) to release the session.
    """

    def __init__(self, G, h):
        """Build the Numerov graph for ``G`` grid points with spacing ``h`` and open a session."""
        self.K_SQUARED = tf.placeholder(tf.float64, shape=(None, G))
        self.solution = shooting_numerov(self.K_SQUARED, h=h)
        self.sess = tf.Session()
        self.h = h
        self.G = G

    def close(self):
        """Release the TensorFlow session held by this solver."""
        self.sess.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def run_numerov(self, k_squared, time_axis=-1):
        """Solve the shooting Numerov equation for an array of k_squared functions.

        ``time_axis`` is the axis along which the equation is solved. All other
        axes are flattened internally and restored in the returned array, which
        has the same shape as ``k_squared``.

        Raises IndexError when ``time_axis`` is out of range.
        """
        ndim = k_squared.ndim
        if not -ndim <= time_axis < ndim:
            raise IndexError("time_axis {} is out of range for an array with {} dimensions".format(time_axis, ndim))
        # A negative axis must be made positive first: with time_axis=-1 the slice
        # shape[time_axis + 1:] would be shape[0:], i.e. the whole shape.
        time_axis %= ndim

        grid_points = k_squared.shape[time_axis]
        batch_shape = k_squared.shape[:time_axis] + k_squared.shape[time_axis + 1:]

        flattened = np.reshape(np.moveaxis(k_squared, time_axis, -1), (-1, grid_points))
        flattened_solutions = self.sess.run(self.solution, feed_dict={self.K_SQUARED: flattened})
        solutions = np.reshape(flattened_solutions, batch_shape + (grid_points,))
        return np.moveaxis(solutions, -1, time_axis)

    def solve_numerov(self, np_potentials, target_roots, split_energies, cut_after_last_root=True, progress=None):
        """Refine the energies by bisection between neighbouring split energies.

        Neighbouring columns of ``split_energies`` form the initial brackets. In
        every step the midpoint energy is shot through the potential; the bracket
        side is chosen by comparing the number of detected roots with
        ``target_roots``. The loop ends once no midpoint changes any more.

        With ``cut_after_last_root`` the returned solutions are set to NaN from
        the last grid position at which the root patterns of the lower- and
        upper-bracket solutions differ.

        ``progress`` is an optional widget-like object whose ``value``, ``max``
        and ``description`` attributes are updated.

        Returns ``(solutions at the lower bracket energies, midpoint energies, number of steps)``.
        """
        np_E_low = split_energies[:, :-1].copy()
        np_E_high = split_energies[:, 1:].copy()

        # because the search interval is halved at every step
        # 32 iterations will always converge to the best numerically possible accuracy of E
        # (empirically ~25 steps)

        np_E = 0.5 * (np_E_low + np_E_high)
        np_E_last = np.copy(np_E) * 2

        if progress is not None:
            progress.value = 0
            progress.max = np.prod(np_E.shape)
            progress.description = 'Numerov Pass: '

        step = 0
        while np.any(np_E_last - np_E):
            np_V = numerov_k_squared(np_potentials, np_E)
            np_solutions = self.run_numerov(np_V, time_axis=1)
            np_roots = np.sum(detect_roots(np_solutions), axis=1)

            # few enough roots: the energy is a lower bound; too many roots: an upper bound
            is_lower_bound = np_roots <= target_roots
            is_upper_bound = np_roots > target_roots
            np_E_low[is_lower_bound] = np_E[is_lower_bound]
            np_E_high[is_upper_bound] = np_E[is_upper_bound]

            np_E_last = np_E
            np_E = 0.5 * (np_E_low + np_E_high)

            if progress is not None:
                progress.value = progress.max - np.sum(np_E_last - np_E != 0)
                progress.description = 'Numerov Pass: ' + str(progress.value) + '/' + str(progress.max)
            step += 1

        np_solutions_low = self.run_numerov(numerov_k_squared(np_potentials, np_E_low), time_axis=1)
        np_roots_low = 1 * detect_roots(np_solutions_low)

        np_solutions_high = self.run_numerov(numerov_k_squared(np_potentials, np_E_high), time_axis=1)
        np_roots_high = 1 * detect_roots(np_solutions_high)

        np_roots_diff = np.abs(np_roots_high - np_roots_low)
        # assert(np.all(np.sum(np_roots_diff, axis=1) == 1)) # sometimes roots are at different places!

        if cut_after_last_root:
            # the cumulative count reaches its final value at the last differing position
            np_nan_cumsum = np.cumsum(np.pad(np_roots_diff, ((0, 0), (1, 0), (0, 0)), 'constant'), axis=1)
            np_nan_index = np_nan_cumsum == np.expand_dims(np_nan_cumsum[:, -1], axis=1)

            np_solutions_low[np_nan_index] = np.nan

        return np_solutions_low, np_E, step

    def find_split_energies(self, np_potentials, N, progress=None):
        """Find, for every potential, ``N + 1`` energies whose solutions have 0, 1, ..., N roots.

        The search starts at the minimum of each potential. The energy step is
        doubled while the root count stays on the same side of its target, and is
        halved and reversed after every overshoot.

        ``progress`` is an optional widget-like object whose ``value``, ``max``
        and ``description`` attributes are updated.

        Returns ``(split energies, number of search steps)``.
        """
        M = np_potentials.shape[0]

        # Node theorem: the number of roots equals the quantum number of the state,
        # so the target root count is the quantum number of the target excited state.
        target_roots = np.repeat(np.expand_dims(np.arange(N + 1), axis=0), M, axis=0)

        # lowest value of potential as lower bound
        np_E_split = np.repeat(np.expand_dims(np.min(np_potentials, axis=1), axis=1), N + 1, axis=1)

        np_solutions_split = np.zeros((np_potentials.shape[0], np_potentials.shape[1], N + 1), dtype=np.float64)
        # builtin bool: the np.bool alias was removed in NumPy 1.24
        not_converged = np.ones(np_potentials.shape[0], dtype=bool)
        search_boost = np.ones_like(np_E_split)
        np_E_delta = np.ones_like(np_E_split)

        if progress is not None:
            progress.value = 0
            progress.max = M
            progress.description = 'Searching Roots:'

        step = 0
        while np.any(not_converged):
            np_V_split = numerov_k_squared(np_potentials[not_converged], np_E_split[not_converged])
            np_solutions_split[not_converged] = self.run_numerov(np_V_split, time_axis=1)
            np_roots_split = np.sum(detect_roots(np_solutions_split), axis=1)

            not_converged[np.all(np_roots_split == target_roots, axis=1)] = False

            # +1: too few roots, raise the energy; -1: too many roots, lower it
            search_direction = 1 * (np_roots_split < target_roots) - 1 * (np_roots_split > target_roots)
            np_E_delta[np.logical_and(search_direction == np.sign(np_E_delta), search_boost)] *= 2
            search_boost[search_direction * np.sign(np_E_delta) < 0] = 0
            np_E_delta[search_direction * np.sign(np_E_delta) < 0] *= -0.5

            np_E_split[not_converged] += np_E_delta[not_converged]

            if progress is not None:
                progress.value = progress.max - np.sum(not_converged)
                progress.description = 'Searching Roots: ' + str(progress.value) + '/' + str(progress.max)
            step += 1

        return np_E_split, step

    def solve_schroedinger(self, np_potentials, N, progress=None):
        """Compute the first ``N`` energies and normalized solutions for every potential.

        The equation is shot once from the left and once from the right (on the
        flipped potentials). The backward solutions are rescaled onto the forward
        ones, both are joined at the grid index where they differ least, and the
        result is normalized so that the Simpson integral of its square is one.

        Returns ``(energies, solutions)``; the energies are the mean of the forward
        and backward results.
        """
        M = np_potentials.shape[0]
        G = np_potentials.shape[1]

        assert (G == self.G)
        np_E_split, _ = self.find_split_energies(np_potentials, N, progress=progress)

        target_roots = np.repeat(np.expand_dims(np.arange(N), axis=0), M, axis=0)
        np_solutions_forward, np_E_forward, _ = self.solve_numerov(np_potentials, target_roots, np_E_split, progress=progress)
        np_solutions_forward /= np.expand_dims(np.nanmax(np.abs(np_solutions_forward), axis=1), axis=1)

        assert not np.any(np.all(np.isnan(np_solutions_forward), axis=1))

        np_solutions_backward, np_E_backward, _ = self.solve_numerov(np.flip(np_potentials, axis=1), target_roots, np_E_split, progress=progress)
        np_solutions_backward = np.flip(np_solutions_backward, axis=1)
        np_solutions_backward /= np.expand_dims(np.nanmax(np.abs(np_solutions_backward), axis=1), axis=1)

        assert not np.any(np.all(np.isnan(np_solutions_backward), axis=1))

        np_factor = np_solutions_forward / np_solutions_backward

        assert not np.any(np.all(np.isnan(np_factor), axis=1))

        # rescale the backward solutions onto the forward ones
        np_solutions_backward *= np.expand_dims(np.nanmedian(np_factor, axis=1), axis=1)

        join_index = np.nanargmin(np.abs(np_solutions_backward - np_solutions_forward), axis=1)

        # from the join index onwards the backward solution replaces the forward one
        join_mask = np.expand_dims(np.expand_dims(np.arange(np_solutions_backward.shape[1]), axis=0), axis=2) >= np.expand_dims(join_index, axis=1)

        np_solutions = np_solutions_forward
        np_solutions[join_mask] = np_solutions_backward[join_mask]

        # normalization
        np_norm = np_solutions ** 2
        np_norm = integrate_simpson(np_norm, self.h, axis=1)
        np_solutions *= 1 / np.sqrt(np.expand_dims(np_norm, axis=1))

        assert not np.any(np.all(np.isnan(np_solutions), axis=1))

        np_E = 0.5*(np_E_forward + np_E_backward)

        return np_E, np_solutions
    

In [None]:
%%writefile ../quantumflow/colab_train_utils.py
import numpy as np
from quantumflow.calculus_utils import integrate, integrate_simpson, laplace

def test_colab_devices():
    """Report which accelerators are visible.

    Returns ``(has_gpu, has_tpu)``: ``has_gpu`` is true when TensorFlow reports
    '/device:GPU:0', ``has_tpu`` when the COLAB_TPU_ADDR environment variable is set.
    """
    import os
    import tensorflow as tf

    gpu_found = tf.test.gpu_device_name() == '/device:GPU:0'
    tpu_found = 'COLAB_TPU_ADDR' in os.environ
    return gpu_found, tpu_found


def unpack_dataset(N, dataset):
    """Split a dataset mapping into its arrays and derive density and kinetic energy.

    The values of ``dataset`` are taken in iteration order as
    ``x, potentials, solutions, E``. The density is the sum of the squared first
    ``N`` solutions. The kinetic energy is the sum over the first ``N`` states of
    ``E - P``, where ``P`` is the Simpson integral of ``potential * solution**2``.

    Returns ``(x, potentials, solutions, E, density, kinetic_energy, dataset_size,
    discretization_points, h)``.
    """
    x, potentials, solutions, E = dataset.values()

    squared_solutions = np.square(solutions)
    density = np.sum(squared_solutions[:, :, :N], axis=-1)

    dataset_size, discretization_points, _ = solutions.shape
    # uniform grid spacing derived from the extent of x
    h = (max(x) - min(x))/(discretization_points-1)

    potential_density = np.expand_dims(potentials, axis=2)*solutions**2
    potential_energy = integrate_simpson(potential_density, h, axis=1)
    kinetic_per_state = E - potential_energy
    kinetic_energy = np.sum(kinetic_per_state[:, :N], axis=-1)

    return x, potentials, solutions, E, density, kinetic_energy, dataset_size, discretization_points, h


class InputPipeline(object):
    """Loads a pickled dataset and serves it as a ``tf.data`` pipeline.

    Features are the densities; targets are the pair (kinetic energies,
    derivatives), where the derivatives are the negated potentials.
    """

    def __init__(self, N, dataset_file, is_training=False):
        import pickle
        self.is_training = is_training

        # NOTE(review): pickle.load can execute arbitrary code - only open trusted dataset files.
        with open(dataset_file, 'rb') as f:
            unpacked = unpack_dataset(N, pickle.load(f))

        (self.x, self.potentials, _, self.energies, self.densities,
         self.kenergies, self.M, self.G, self.h) = unpacked
        self.derivatives = -self.potentials

    def input_fn(self, params):
        """Return the batched dataset; reads 'shuffle', 'shuffle_buffer_size', 'seed' (optional) and 'batch_size'."""
        import tensorflow as tf

        def as_float32_dataset(values):
            return tf.data.Dataset.from_tensor_slices(values.astype(np.float32))

        features = as_float32_dataset(self.densities)
        kenergies = as_float32_dataset(self.kenergies)
        derivatives = as_float32_dataset(self.derivatives)

        targets = tf.data.Dataset.zip((kenergies, derivatives))
        dataset = tf.data.Dataset.zip((features, targets))

        if self.is_training:
            dataset = dataset.repeat()

        if params['shuffle']:
            dataset = dataset.shuffle(buffer_size=params['shuffle_buffer_size'], seed=params.get('seed', None))

        return dataset.batch(params['batch_size'], drop_remainder=True)

    def features_shape(self):
        return self.densities.shape

    def targets_shape(self):
        return (self.kenergies.shape, self.derivatives.shape)

    def __str__(self):
        label = 'Train Dataset: ' if self.is_training else 'Dataset: '
        parts = (self.densities.shape, self.kenergies.shape, self.derivatives.shape, self.densities.dtype)
        return label + ' '.join(str(part) for part in parts)


def get_resolver():
    """Connect to the Colab TPU named by COLAB_TPU_ADDR and return its cluster resolver.

    Returns None when a KeyError occurs, in particular when COLAB_TPU_ADDR is not set.
    """
    import os
    import tensorflow as tf

    try:
        tpu_worker = 'grpc://' + os.environ['COLAB_TPU_ADDR']
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_worker)
        tf.config.experimental_connect_to_host(resolver.master())
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return resolver
    except KeyError:
        return None

def running_mean(x, N):
    """Mean over a sliding window of ``N`` consecutive values of ``x`` (one value per full window)."""
    # prepend a zero so that differences of the cumulative sum give window sums
    scaled = np.insert(x, 0, 0) / float(N)
    cumsum = np.cumsum(scaled)
    return cumsum[N:] - cumsum[:-N]


def load_hyperparameters(file_hyperparams, run_name='default', globals=None):
    """Load the hyperparameters of one run from a YAML file.

    Args:
        file_hyperparams: path of the YAML file.
        run_name: top-level key of the section to return.
        globals: optional mapping from names to objects. When given, every value
            in the section (including nested dicts) that is listed under the
            file's top-level 'globals' key is replaced by ``globals[value]``.

    Returns:
        The (possibly substituted) section of the YAML document.

    Raises:
        KeyError: if ``run_name`` is missing, if ``globals`` is given and the file
            has no 'globals' entry, or if a listed name is missing in ``globals``.
    """
    from collections import deque
    from ruamel.yaml import YAML

    # parse the file once; both the section and the 'globals' list come from this document
    with open(file_hyperparams) as f:
        config = YAML().load(f)

    if globals is not None:
        globals_list = config['globals']

    hparams = config[run_name]
    if globals is None:
        return hparams

    # breadth-first walk over the section and all nested dicts
    pending = deque([hparams])
    while pending:
        data = pending.popleft()
        for key in list(data):
            value = data[key]
            if isinstance(value, dict):
                pending.append(value)
                continue

            if value in globals_list:
                data[key] = globals[value]
    return hparams


import ipywidgets as widgets
from IPython.display import Audio, HTML, display
from matplotlib import animation, rc
import matplotlib.pyplot as plt
# Global matplotlib setting for SVG export: per the matplotlib rcParams documentation,
# 'none' keeps text as text instead of converting it to paths.
plt.rcParams['svg.fonttype'] = 'none'

def anim_plot(array, x=None, interval=100, bar="", figsize=(15, 3), **kwargs):
    """Display ``array`` as an HTML5 video animation, one entry of ``array`` per frame.

    Args:
        array: sequence of frames. A 2-D frame is drawn as one line per column,
            a 1-D frame as a single line.
        x: optional x values shared by all frames; defaults to the sample index.
        interval: delay between frames, passed to ``FuncAnimation``.
        bar: description of a progress bar shown while rendering; "" disables it.
        figsize: figure size passed to ``plt.subplots``.
        **kwargs: forwarded to ``ax.plot`` when the lines are created.
    """
    frames = len(array)
    show_bar = not bar == ""

    if show_bar:
        import ipywidgets as widgets
        widget = widgets.IntProgress(min=0, max=frames, description=bar, bar_style='success',
                                     layout=widgets.Layout(width='92%'))
        display(widget)

    fig, ax = plt.subplots(figsize=figsize)

    if x is None:
        plt_h = ax.plot(array[0], **kwargs)
    else:
        plt_h = ax.plot(x, array[0], **kwargs)

    # the y-range is fixed from the last frame, with 20% margin on both sides
    min_last = np.min(array[-1])
    max_last = np.max(array[-1])
    span_last = max_last - min_last

    ax.set_ylim([min_last - span_last*0.2, max_last + span_last*0.2])

    def init():
        return plt_h

    def animate(f):
        if show_bar:
            widget.value = f

        frame = np.asarray(array[f])
        if frame.ndim == 1:
            # a single curve per frame: treat it as one column
            frame = frame[:, np.newaxis]

        for i, h in enumerate(plt_h):
            column = frame[:, i]
            x_values = np.arange(len(column)) if x is None else x
            # set_data only takes the data; the style kwargs were already applied by ax.plot
            h.set_data(x_values, column)
        return plt_h

    # call the animator. blit=True means only re-draw the parts that have changed.
    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=frames, interval=interval,
                                   blit=True, repeat=False)

    plt.close(fig)
    rc('animation', html='html5')
    display(HTML(anim.to_html5_video(embed_limit=1024)))

    if show_bar:
        widget.close()

In [None]:
%%writefile ../quantumflow/cnn_tpu_training.py

import tensorflow as tf
import numpy as np
import os
from tensorflow.contrib import summary
from tensorflow.contrib.training.python.training import evaluation

def conv_nn(input, return_layers=False, filters=(16, 16, 16), kernel_size=(121, 121, 121), strides=(1, 1, 1), padding='valid', activation=tf.nn.softplus, **kwargs):
    """Stack of convolutions along axis 1 of ``input`` whose final output is summed to one value per sample.

    All layers but the last use ``activation`` and their entry of ``strides``;
    the last layer is a plain convolution without activation. ``kwargs`` are
    forwarded to every ``tf.layers.conv2d`` call.

    Returns the summed value and, when ``return_layers`` is true, additionally
    the list of intermediate tensors (input, every layer output, final value).
    """
    # add a channel axis and a dummy width axis so conv2d can be used with (k, 1) kernels
    value = tf.expand_dims(tf.expand_dims(input, axis=-1), axis=2)
    layers_list = [tf.reduce_sum(value, axis=2)]

    assert len(filters) == len(kernel_size)
    n_layers = len(filters)

    for idx in range(n_layers - 1):
        value = tf.layers.conv2d(value, filters=filters[idx], kernel_size=(kernel_size[idx], 1),
                                 strides=strides[idx], activation=activation, padding=padding, **kwargs)
        layers_list.append(tf.reduce_sum(value, axis=2))

    # output layer: no activation
    value = tf.layers.conv2d(value, filters=filters[-1], kernel_size=(kernel_size[-1], 1), padding=padding, **kwargs)
    value = tf.reduce_sum(value, axis=2)
    layers_list.append(value)

    # sum over the channels, then over the remaining positions
    value = tf.reduce_sum(value, axis=2)
    value = tf.reduce_sum(value, axis=1)
    layers_list.append(value)

    if return_layers:
        return value, layers_list
    return value


class SineWaveInitializer(tf.initializers.variance_scaling):
    """Variance-scaling initializer modulated by sine waves along the first kernel axis.

    Random weights are drawn for a kernel of length 1 and multiplied by
    ``sin(lin * freq) / G``. ``G`` is the requested kernel length, ``lin`` runs
    from 0 to pi over the kernel and ``freq`` is 2, 3, ... for consecutive
    entries of the last axis. The reshapes assume a 4-D kernel shape.
    """

    def __call__(self, shape, dtype=None, partition_info=None):
        # Work on a copy: the caller's shape is left untouched, and tuples or
        # TensorShape objects are accepted as well as lists.
        dims = shape.as_list() if hasattr(shape, 'as_list') else list(shape)
        G = dims[0]
        base_shape = [1] + dims[1:]
        n_out = base_shape[-1]

        weights = super().__call__(shape=base_shape, dtype=dtype, partition_info=partition_info)

        lin = tf.reshape(tf.linspace(0.0, np.pi, G), (G, 1, 1, 1))
        freq = tf.reshape(tf.range(n_out, dtype=tf.float32)+2, (1, 1, 1, n_out))
        sine = tf.sin(lin*freq)/G

        return sine*weights


def learning_rate_schedule(params, global_step):
    """Learning-rate tensor as a function of ``global_step``.

    The base schedule is a staircase exponential decay starting at
    params['learning_rate']. With params['use_learning_rate_warmup'] the rate
    stays constant during the cold epochs, then increases linearly, and switches
    to the exponential decay at epoch warmup + cold + decay epochs. The result
    never falls below params['final_learning_rate_factor'] times the initial rate.
    """
    batches_per_epoch = params['train_total_size'] / params['train_batch_size']
    current_epoch = tf.cast((tf.cast(global_step, tf.float32) / batches_per_epoch), tf.int32)

    base_lr = params['learning_rate']
    use_warmup = params['use_learning_rate_warmup']

    if use_warmup:
        warmup_decay = params['learning_rate_decay']**(
            (params['warmup_epochs'] + params['cold_epochs']) /
            params['learning_rate_decay_epochs'])
        adjusted_lr = base_lr * warmup_decay

    floor_lr = params['final_learning_rate_factor'] * base_lr

    learning_rate = tf.train.exponential_decay(
        learning_rate=base_lr,
        global_step=global_step,
        decay_steps=int(params['learning_rate_decay_epochs'] * batches_per_epoch),
        decay_rate=params['learning_rate_decay'],
        staircase=True)

    if use_warmup:
        cold_lr = 0.1 * adjusted_lr
        lr_per_epoch = tf.cast(
            0.9 * adjusted_lr / (params['warmup_epochs'] + params['learning_rate_decay_epochs'] - 1),
            tf.float32)

        epoch_offset = tf.cast(params['cold_epochs'] - 1, tf.int32)
        decay_start_epoch = (params['warmup_epochs'] + params['cold_epochs'] + params['learning_rate_decay_epochs'])

        epochs_since_offset = tf.cast(current_epoch - epoch_offset, tf.float32)
        ramp_lr = cold_lr + epochs_since_offset * lr_per_epoch

        after_cold = tf.where(current_epoch >= decay_start_epoch, learning_rate, ramp_lr)
        learning_rate = tf.where(
            current_epoch >= params['cold_epochs'],
            after_cold,
            tf.ones_like(learning_rate)*cold_lr)

    # Set a minimum boundary for the learning rate.
    return tf.maximum(learning_rate, floor_lr, name='learning_rate')


def _build_optimizer(name, learning_rate):
    """Return the tf.train optimizer selected by ``name`` ('Adam', 'sgd', 'momentum' or 'RMS')."""
    if name == 'Adam':
        print('Using Adam optimizer')
        return tf.train.AdamOptimizer(learning_rate)
    if name == 'sgd':
        print('Using SGD optimizer')
        return tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    if name == 'momentum':
        print('Using Momentum optimizer')
        return tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
    if name == 'RMS':
        print('Using RMS optimizer')
        return tf.train.RMSPropOptimizer(learning_rate)

    # Fail here with a clear message instead of continuing with an undefined optimizer.
    tf.logging.fatal('Unknown optimizer: %s', name)
    raise ValueError('Unknown optimizer: {!r}'.format(name))


def deriv_conv_nn_model_fn(features, labels, mode, params):
    """Estimator model function: ``conv_nn`` value prediction plus its input gradient.

    The derivative prediction is the gradient of the predicted value with respect
    to the features, divided by params['h']. The loss is the value MSE plus
    params['balance'] times the derivative MSE, plus params['l2_loss'] times the
    L2 norm of all 'kernel' variables when that factor is positive.

    Returns a ``tf.estimator.EstimatorSpec`` in PREDICT mode and a
    ``TPUEstimatorSpec`` otherwise.

    Raises:
        ValueError: in TRAIN mode when params['optimizer'] is not one of
            'Adam', 'sgd', 'momentum', 'RMS'.
    """
    if isinstance(features, dict):
        features = features['feature']

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    is_eval = (mode == tf.estimator.ModeKeys.EVAL)

    target_prediction = conv_nn(features, **params['kwargs'])
    derivative_prediction = 1/params['h']*tf.gradients(target_prediction, features)[0]

    predictions = {
        'value': target_prediction,
        'derivative': derivative_prediction
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
            'regression': tf.estimator.export.PredictOutput(predictions)
            })

    target, derivative = labels

    loss_y = tf.losses.mean_squared_error(target_prediction, target)
    loss_gradient = tf.losses.mean_squared_error(derivative_prediction, derivative)

    # L2 penalty over the convolution kernels only (biases are skipped)
    loss_l2 = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'kernel' in v.name])

    loss = loss_y + params['balance']*loss_gradient

    if params['l2_loss'] > 0.0:
        loss += params['l2_loss']*loss_l2

    host_call = None
    train_op = None

    if is_training:
        batches_per_epoch = params['train_total_size'] / params['train_batch_size']
        global_step = tf.train.get_or_create_global_step()
        current_epoch = tf.cast((tf.cast(global_step, tf.float32) / batches_per_epoch), tf.int32)
        learning_rate = learning_rate_schedule(params, global_step)
        # tf.summary.scalar('lr', learning_rate) doesn't work on TPU; see host_call below

        optimizer = _build_optimizer(params['optimizer'], learning_rate)

        if params['gradient_clipping']:
            optimizer = tf.contrib.estimator.clip_gradients_by_norm(optimizer, params['gradient_clip_norm'])

        if params['use_tpu']:
            optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step=global_step)

        # To log the current learning rate and epoch for Tensorboard, the
        # summary op needs to be run on the host CPU via host_call. host_call
        # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
        # dimension. These Tensors are implicitly concatenated to
        # [params['batch_size']].
        gs_t = tf.reshape(global_step, [1])
        lr_t = tf.reshape(learning_rate, [1])
        ce_t = tf.reshape(current_epoch, [1])

        if not params['skip_host_call']:
            def host_call_fn(gs, lr, ce):
                gs = gs[0]
                with summary.create_file_writer(params['model_dir']).as_default():
                    with summary.always_record_summaries():
                        summary.scalar('learning_rate', tf.reduce_mean(lr), step=gs)
                        summary.scalar('current_epoch', tf.reduce_mean(ce), step=gs)

                        return summary.all_summary_ops()

            host_call = (host_call_fn, [gs_t, lr_t, ce_t])

    eval_metrics = None
    if is_eval:
        def metric_fn(target_prediction, target, derivative_prediction, derivative):
            return {
                'value_mae': tf.metrics.mean_absolute_error(target_prediction, target),
                'derivative_mae': tf.metrics.mean_absolute_error(derivative_prediction, derivative),
            }

        eval_metrics = (metric_fn, [target_prediction, target, derivative_prediction, derivative])

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        host_call=host_call,
        eval_metrics=eval_metrics)


def train(params, resolver, dataset_train, dataset_eval):
    """Train and evaluate a TPUEstimator in alternating cycles, then export it as a SavedModel.

    Training resumes from the step encoded in the latest checkpoint name, if any.
    Each cycle trains up to the next multiple of params['train_steps_per_eval']
    and then evaluates on ``dataset_eval``. The model is exported to the
    'saved_model' folder below params['model_dir'].
    """
    tpu_config = tf.contrib.tpu.TPUConfig(iterations_per_loop=params['iterations'], num_shards=params['num_shards'])
    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=params['log_device_placement'])

    run_config = tf.contrib.tpu.RunConfig(
        cluster=resolver,
        model_dir=params['model_dir'],
        tf_random_seed=params['seed'],
        save_checkpoints_secs=params['save_checkpoints_secs'],
        keep_checkpoint_max=params['keep_checkpoint_max'],
        save_summary_steps=params['save_summary_steps'],
        session_config=session_config,
        tpu_config=tpu_config)

    model = tf.contrib.tpu.TPUEstimator(
        model_fn=params['model_fn'],
        use_tpu=params['use_tpu'],
        config=run_config,
        params=params,
        train_batch_size=params['train_batch_size'],
        eval_batch_size=params['eval_batch_size'])

    print('Training for {} steps with batch size {}, returning to CPU every {} steps\n'
          'summary every {} steps, saving every {} seconds.'.format(
              params['train_steps'], params['train_batch_size'], params['iterations'],
              params['save_summary_steps'], params['save_checkpoints_secs']))

    # checkpoint names end in "-<global step>"
    latest_checkpoint = model.latest_checkpoint()
    if latest_checkpoint is None:
        current_step = 0
    else:
        current_step = int(latest_checkpoint.split('-')[-1])

    while current_step < params['train_steps']:
        steps_per_eval = params['train_steps_per_eval']
        # steps up to the next multiple of steps_per_eval (a full cycle when already on a multiple)
        train_steps = steps_per_eval - current_step % steps_per_eval
        cycle = current_step // steps_per_eval

        print('Starting training cycle {} - training for {} steps.'.format(cycle, train_steps))
        model.train(input_fn=dataset_train.input_fn, steps=train_steps)
        current_step += train_steps

        print('Starting evaluation cycle {}.'.format(cycle))
        eval_steps = params['eval_total_size'] // params['eval_batch_size']
        eval_results = model.evaluate(input_fn=dataset_eval.input_fn, steps=eval_steps)
        print('Evaluation results: {}'.format(eval_results))

    def serving_input_receiver_fn():
        features = tf.placeholder(dtype=tf.float32, shape=[None] + list(dataset_eval.features_shape()[1:]))
        return tf.estimator.export.ServingInputReceiver(features, {'features': features})

    export_path = os.path.join(params['model_dir'], 'saved_model')
    placeholder_shape = [None] + list(dataset_eval.features_shape()[1:])
    print("Exporting model to {} with input placeholder {}".format(export_path, placeholder_shape))
    model.export_saved_model(export_path, serving_input_receiver_fn)
