In [1]:
import time
import numpy as np
import tensorflow as tf
import math
from scipy.stats import multivariate_normal as normal
import matplotlib.pyplot as plt

tf.keras.backend.set_floatx('float64')

class Solver(object):
    def __init__(self,):
        self.valid_size = 512
        self.batch_size = 64
        self.num_iterations = 10000
        self.logging_frequency = 100
        self.lr_values = [5e-2, 5e-3, 1e-3]
        
        self.lr_boundaries = [5000, 8000]
        self.config = Config()

        self.model = WholeNet()
        self.y_init = self.model.y_init
        lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(self.lr_boundaries, self.lr_values)
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, epsilon=1e-8)

    def train(self):
        """Training the model"""
        start_time = time.time()
        training_history = []
        dW = self.config.sample(self.valid_size)
        valid_data = dW
        for step in range(self.num_iterations+1):
            if step % self.logging_frequency == 0:
                loss, cost = self.model(valid_data, training=True)
                y_init = self.y_init.numpy()[0][0]
                elapsed_time = time.time() - start_time
                training_history.append([step, cost, y_init, loss])
                print("step: %5u, loss: %.4e, Y0: %.4e, cost: %.4e,  elapsed time: %3u" % (step, loss, y_init, cost, elapsed_time))
            self.train_step(self.config.sample(self.batch_size))
        print('Y0_true: %.4e' % y_init)
        self.training_history = training_history

    @tf.function
    def train_step(self, train_data):
        """Updating the gradients"""
        with tf.GradientTape(persistent=True) as tape:
            loss, cost = self.model(train_data, training = True)
        grad = tape.gradient(loss, self.model.trainable_variables)
        del tape
        self.optimizer.apply_gradients(zip(grad, self.model.trainable_variables))

class WholeNet(tf.keras.Model):
    """Building the neural network architecture"""
    def __init__(self):
        super(WholeNet, self).__init__()
        self.config = Config()
        self.y_init = tf.Variable(tf.random.normal([1, self.config.dim_y], mean=0, stddev=1, dtype=tf.dtypes.float64))
        self.z_net = [FNNet() for _ in range(self.config.num_time_interval)]

    def call(self, dw, training):
        x_init = tf.ones([1, self.config.dim_x], dtype=tf.dtypes.float64) * 1.0
        time_stamp = np.arange(0, self.config.num_time_interval) * self.config.delta_t
        all_one_vec = tf.ones([tf.shape(dw)[0], 1], dtype=tf.dtypes.float64)
        x = tf.matmul(all_one_vec, x_init)
        y = tf.matmul(all_one_vec, self.y_init)
        l = 0.0 # The cost functional
        for t in range(0, self.config.num_time_interval):
            data = time_stamp[t], x, y
            z = self.z_net[t](data, training=True)
            u = self.config.u_fn(time_stamp[t], x, y, z)

            l = l + self.config.f_fn(time_stamp[t], x, u) * self.config.delta_t

            b_ = self.config.b_fn(time_stamp[t], x, y, z)
            sigma_ = self.config.sigma_fn(time_stamp[t], x, y, z)
            f_ = self.config.Hx_fn(time_stamp[t], x, y, z)

            x = x + b_ * self.config.delta_t + sigma_ * dw[:, :, t]
            y = y - f_ * self.config.delta_t + z * dw[:, :, t]

        delta = y + self.config.hx_tf(self.config.total_T, x)
        loss = tf.reduce_mean(tf.reduce_sum(delta**2, 1, keepdims=True))

        l = l + self.config.h_fn(self.config.total_T, x)
        cost = tf.reduce_mean(l)

        return loss, cost

class FNNet(tf.keras.Model):
    """ Define the feedforward neural network """
    def __init__(self):
        super(FNNet, self).__init__()
        self.config = Config()
        num_hiddens = [self.config.dim_x+10, self.config.dim_x+10, self.config.dim_x+10]
        self.bn_layers = [
            tf.keras.layers.BatchNormalization(
                momentum=0.99,
                epsilon=1e-6,
                beta_initializer=tf.random_normal_initializer(0.0, stddev=0.1),
                gamma_initializer=tf.random_uniform_initializer(0.1, 0.5)
            )
            for _ in range(len(num_hiddens) + 2)]
        self.dense_layers = [tf.keras.layers.Dense(num_hiddens[i],
                                                   use_bias=False,
                                                   activation=None)
                             for i in range(len(num_hiddens))]
        # final output should be gradient of size dim_z
        self.dense_layers.append(tf.keras.layers.Dense(self.config.dim_z, activation=None))

    def call(self, inputs, training):
        """structure: bn -> (dense -> bn -> relu) * len(num_hiddens) -> dense -> bn"""
        t, x, y = inputs
        ts = tf.ones([tf.shape(x)[0], 1], dtype=tf.dtypes.float64) * t
        x = tf.concat([ts, x, y], axis=1)
        x = self.bn_layers[0](x, training=True)
        for i in range(len(self.dense_layers) - 1):
            x = self.dense_layers[i](x)
            x = self.bn_layers[i+1](x, training=True)
            x = tf.nn.relu(x)
        x = self.dense_layers[-1](x)
        x = self.bn_layers[-1](x, training=True)
        return x

class Config(object):
    """Define the configs in the systems"""
    def __init__(self):
        super(Config, self).__init__()
        self.dim_x = 20
        self.dim_y = 20
        self.dim_z = 20
        self.num_time_interval = 25
        self.total_T = 0.1
        self.delta_t = (self.total_T + 0.0) / self.num_time_interval
        self.sqrth = np.sqrt(self.delta_t)
        self.t_stamp = np.arange(0, self.num_time_interval) * self.delta_t

    def sample(self, num_sample):
        dw_sample = normal.rvs(size=[num_sample, self.num_time_interval]) * self.sqrth
        return dw_sample[:, np.newaxis, :]

    def f_fn(self, t, x, u):
        return 0.25 * tf.reduce_sum(x ** 2, 1, keepdims=True) + tf.reduce_sum(u ** 2, 1, keepdims=True)

    def h_fn(self, t, x):
        ones = tf.ones(shape=tf.stack([self.dim_x, self.dim_x]), dtype=tf.dtypes.float64)
        inputs = tf.matmul(x, ones)
        return 0.5 * tf.reduce_sum(inputs*x, 1, keepdims=True)

    def b_fn(self, t, x, y, z):
        return -0.25*x + 0.5*y + 0.5*z

    def sigma_fn(self, t, x, y, z):
        return 0.2*x + 0.5*y + 0.5*z

    def Hx_fn(self, t, x, y, z):
        return -0.5*x - 0.25*y + 0.2*z

    def hx_tf(self, t, x):
        ones = tf.ones(shape=tf.stack([self.dim_x, self.dim_x]), dtype=tf.dtypes.float64) * 1.0
        return tf.matmul(x, ones)
    
    def u_fn(self, t, x, y, z):
        return 0.5 * (y+z)

def main():
    print('Training time 1:')
    solver = Solver()
    solver.train()
    k = 10
    data = np.array(solver.training_history)
    output = np.zeros((len(data[:, 0]), 3 + k))
    output[:, 0] = data[:, 0] # step
    output[:, 1] = data[:, 2] # y_init
    output[:, 2] = data[:, 3] # loss
    output[:, 3] = data[:, 1] # cost

    for i in range(k - 1):
        print('Training time %3u:' % (i + 2))
        solver = Solver()
        solver.train()
        data = np.array(solver.training_history)
        output[:, 4 + i] = data[:, 1]

    a = ['%d', '%.5e', '%.5e']
    for i in range(k):
        a.append('%.5e')
    np.savetxt('./LQ_data_d20.csv', output, fmt=a, delimiter=',')

    print('Solving is done!')

if __name__ == '__main__':
    main()

Training time 1:
step:     0, loss: 7.2093e+03, Y0: 7.8070e-01, cost: 1.8573e+02,  elapsed time:   2
step:   100, loss: 6.7642e+03, Y0: 7.8070e-01, cost: 1.7397e+02,  elapsed time:  18
step:   200, loss: 6.8023e+03, Y0: 7.8070e-01, cost: 1.7442e+02,  elapsed time:  19
step:   300, loss: 6.7764e+03, Y0: 7.8070e-01, cost: 1.7352e+02,  elapsed time:  21
step:   400, loss: 6.7343e+03, Y0: 7.8070e-01, cost: 1.7249e+02,  elapsed time:  22
step:   500, loss: 6.6643e+03, Y0: 7.8070e-01, cost: 1.7060e+02,  elapsed time:  23
step:   600, loss: 6.6505e+03, Y0: 7.8070e-01, cost: 1.7044e+02,  elapsed time:  25
step:   700, loss: 6.7177e+03, Y0: 7.8070e-01, cost: 1.7189e+02,  elapsed time:  26
step:   800, loss: 6.7137e+03, Y0: 7.8070e-01, cost: 1.7199e+02,  elapsed time:  28
step:   900, loss: 6.7850e+03, Y0: 7.8070e-01, cost: 1.7350e+02,  elapsed time:  29
step:  1000, loss: 6.7818e+03, Y0: 7.8070e-01, cost: 1.7354e+02,  elapsed time:  30
step:  1100, loss: 6.6767e+03, Y0: 7.8070e-01, cost: 1.7111

step:  9800, loss: 6.7310e+03, Y0: 7.8070e-01, cost: 1.7217e+02,  elapsed time: 149
step:  9900, loss: 6.7314e+03, Y0: 7.8070e-01, cost: 1.7218e+02,  elapsed time: 151
step: 10000, loss: 6.7310e+03, Y0: 7.8070e-01, cost: 1.7217e+02,  elapsed time: 152
Y0_true: 7.8070e-01
Training time   2:
step:     0, loss: 7.9132e+03, Y0: -1.7870e+00, cost: 1.9539e+02,  elapsed time:   2
step:   100, loss: 7.1716e+03, Y0: -1.7870e+00, cost: 1.7638e+02,  elapsed time:  19
step:   200, loss: 7.1350e+03, Y0: -1.7870e+00, cost: 1.7465e+02,  elapsed time:  20
step:   300, loss: 6.9007e+03, Y0: -1.7870e+00, cost: 1.6911e+02,  elapsed time:  22
step:   400, loss: 6.9607e+03, Y0: -1.7870e+00, cost: 1.7025e+02,  elapsed time:  23
step:   500, loss: 6.9471e+03, Y0: -1.7870e+00, cost: 1.6985e+02,  elapsed time:  24
step:   600, loss: 7.0000e+03, Y0: -1.7870e+00, cost: 1.7109e+02,  elapsed time:  26
step:   700, loss: 7.0090e+03, Y0: -1.7870e+00, cost: 1.7141e+02,  elapsed time:  27
step:   800, loss: 6.9958e+03

step:  9300, loss: 6.9611e+03, Y0: -1.7870e+00, cost: 1.7004e+02,  elapsed time: 147
step:  9400, loss: 6.9615e+03, Y0: -1.7870e+00, cost: 1.7005e+02,  elapsed time: 148
step:  9500, loss: 6.9616e+03, Y0: -1.7870e+00, cost: 1.7005e+02,  elapsed time: 150
step:  9600, loss: 6.9605e+03, Y0: -1.7870e+00, cost: 1.7002e+02,  elapsed time: 151
step:  9700, loss: 6.9625e+03, Y0: -1.7870e+00, cost: 1.7007e+02,  elapsed time: 153
step:  9800, loss: 6.9631e+03, Y0: -1.7870e+00, cost: 1.7009e+02,  elapsed time: 154
step:  9900, loss: 6.9619e+03, Y0: -1.7870e+00, cost: 1.7005e+02,  elapsed time: 156
step: 10000, loss: 6.9600e+03, Y0: -1.7870e+00, cost: 1.7001e+02,  elapsed time: 157
Y0_true: -1.7870e+00
Training time   3:
step:     0, loss: 7.1713e+03, Y0: 1.5110e+00, cost: 1.8542e+02,  elapsed time:   2
step:   100, loss: 6.8673e+03, Y0: 1.5110e+00, cost: 1.7715e+02,  elapsed time:  19
step:   200, loss: 6.7788e+03, Y0: 1.5110e+00, cost: 1.7437e+02,  elapsed time:  20
step:   300, loss: 6.7218e+0

step:  8900, loss: 6.7015e+03, Y0: 1.5110e+00, cost: 1.7229e+02,  elapsed time: 145
step:  9000, loss: 6.7039e+03, Y0: 1.5110e+00, cost: 1.7234e+02,  elapsed time: 146
step:  9100, loss: 6.7033e+03, Y0: 1.5110e+00, cost: 1.7232e+02,  elapsed time: 147
step:  9200, loss: 6.7040e+03, Y0: 1.5110e+00, cost: 1.7234e+02,  elapsed time: 149
step:  9300, loss: 6.7036e+03, Y0: 1.5110e+00, cost: 1.7233e+02,  elapsed time: 150
step:  9400, loss: 6.7039e+03, Y0: 1.5110e+00, cost: 1.7234e+02,  elapsed time: 152
step:  9500, loss: 6.7037e+03, Y0: 1.5110e+00, cost: 1.7234e+02,  elapsed time: 153
step:  9600, loss: 6.7046e+03, Y0: 1.5110e+00, cost: 1.7235e+02,  elapsed time: 155
step:  9700, loss: 6.7044e+03, Y0: 1.5110e+00, cost: 1.7235e+02,  elapsed time: 156
step:  9800, loss: 6.7035e+03, Y0: 1.5110e+00, cost: 1.7233e+02,  elapsed time: 158
step:  9900, loss: 6.7036e+03, Y0: 1.5110e+00, cost: 1.7233e+02,  elapsed time: 159
step: 10000, loss: 6.7014e+03, Y0: 1.5110e+00, cost: 1.7228e+02,  elapsed ti

step:  8500, loss: 6.8906e+03, Y0: -1.5862e+00, cost: 1.7330e+02,  elapsed time: 140
step:  8600, loss: 6.8853e+03, Y0: -1.5862e+00, cost: 1.7319e+02,  elapsed time: 141
step:  8700, loss: 6.8865e+03, Y0: -1.5862e+00, cost: 1.7321e+02,  elapsed time: 143
step:  8800, loss: 6.8867e+03, Y0: -1.5862e+00, cost: 1.7321e+02,  elapsed time: 144
step:  8900, loss: 6.8867e+03, Y0: -1.5862e+00, cost: 1.7321e+02,  elapsed time: 145
step:  9000, loss: 6.8834e+03, Y0: -1.5862e+00, cost: 1.7314e+02,  elapsed time: 147
step:  9100, loss: 6.8836e+03, Y0: -1.5862e+00, cost: 1.7314e+02,  elapsed time: 148
step:  9200, loss: 6.8859e+03, Y0: -1.5862e+00, cost: 1.7319e+02,  elapsed time: 150
step:  9300, loss: 6.8862e+03, Y0: -1.5862e+00, cost: 1.7319e+02,  elapsed time: 151
step:  9400, loss: 6.8878e+03, Y0: -1.5862e+00, cost: 1.7323e+02,  elapsed time: 153
step:  9500, loss: 6.8881e+03, Y0: -1.5862e+00, cost: 1.7324e+02,  elapsed time: 154
step:  9600, loss: 6.8891e+03, Y0: -1.5862e+00, cost: 1.7325e+02,

step:  8100, loss: 6.4256e+03, Y0: 9.0146e-01, cost: 1.6358e+02,  elapsed time: 142
step:  8200, loss: 6.4247e+03, Y0: 9.0146e-01, cost: 1.6356e+02,  elapsed time: 143
step:  8300, loss: 6.4224e+03, Y0: 9.0146e-01, cost: 1.6351e+02,  elapsed time: 144
step:  8400, loss: 6.4223e+03, Y0: 9.0146e-01, cost: 1.6350e+02,  elapsed time: 146
step:  8500, loss: 6.4216e+03, Y0: 9.0146e-01, cost: 1.6349e+02,  elapsed time: 147
step:  8600, loss: 6.4181e+03, Y0: 9.0146e-01, cost: 1.6341e+02,  elapsed time: 149
step:  8700, loss: 6.4178e+03, Y0: 9.0146e-01, cost: 1.6340e+02,  elapsed time: 150
step:  8800, loss: 6.4166e+03, Y0: 9.0146e-01, cost: 1.6338e+02,  elapsed time: 152
step:  8900, loss: 6.4148e+03, Y0: 9.0146e-01, cost: 1.6334e+02,  elapsed time: 153
step:  9000, loss: 6.4167e+03, Y0: 9.0146e-01, cost: 1.6338e+02,  elapsed time: 155
step:  9100, loss: 6.4191e+03, Y0: 9.0146e-01, cost: 1.6344e+02,  elapsed time: 156
step:  9200, loss: 6.4174e+03, Y0: 9.0146e-01, cost: 1.6339e+02,  elapsed ti

step:  7700, loss: 7.0442e+03, Y0: -1.8707e-01, cost: 1.7564e+02,  elapsed time: 129
step:  7800, loss: 7.0519e+03, Y0: -1.8707e-01, cost: 1.7580e+02,  elapsed time: 130
step:  7900, loss: 7.0506e+03, Y0: -1.8707e-01, cost: 1.7578e+02,  elapsed time: 132
step:  8000, loss: 7.0566e+03, Y0: -1.8707e-01, cost: 1.7594e+02,  elapsed time: 133
step:  8100, loss: 7.0571e+03, Y0: -1.8707e-01, cost: 1.7595e+02,  elapsed time: 134
step:  8200, loss: 7.0621e+03, Y0: -1.8707e-01, cost: 1.7605e+02,  elapsed time: 136
step:  8300, loss: 7.0615e+03, Y0: -1.8707e-01, cost: 1.7604e+02,  elapsed time: 137
step:  8400, loss: 7.0595e+03, Y0: -1.8707e-01, cost: 1.7599e+02,  elapsed time: 139
step:  8500, loss: 7.0605e+03, Y0: -1.8707e-01, cost: 1.7601e+02,  elapsed time: 140
step:  8600, loss: 7.0606e+03, Y0: -1.8707e-01, cost: 1.7600e+02,  elapsed time: 141
step:  8700, loss: 7.0599e+03, Y0: -1.8707e-01, cost: 1.7598e+02,  elapsed time: 143
step:  8800, loss: 7.0618e+03, Y0: -1.8707e-01, cost: 1.7603e+02,

step:  7200, loss: 6.9672e+03, Y0: -6.2662e-01, cost: 1.7499e+02,  elapsed time: 119
step:  7300, loss: 6.9738e+03, Y0: -6.2662e-01, cost: 1.7516e+02,  elapsed time: 121
step:  7400, loss: 6.9682e+03, Y0: -6.2662e-01, cost: 1.7505e+02,  elapsed time: 122
step:  7500, loss: 6.9773e+03, Y0: -6.2662e-01, cost: 1.7522e+02,  elapsed time: 124
step:  7600, loss: 6.9744e+03, Y0: -6.2662e-01, cost: 1.7513e+02,  elapsed time: 125
step:  7700, loss: 6.9828e+03, Y0: -6.2662e-01, cost: 1.7532e+02,  elapsed time: 126
step:  7800, loss: 6.9697e+03, Y0: -6.2662e-01, cost: 1.7500e+02,  elapsed time: 128
step:  7900, loss: 6.9630e+03, Y0: -6.2662e-01, cost: 1.7487e+02,  elapsed time: 129
step:  8000, loss: 6.9785e+03, Y0: -6.2662e-01, cost: 1.7521e+02,  elapsed time: 130
step:  8100, loss: 6.9795e+03, Y0: -6.2662e-01, cost: 1.7523e+02,  elapsed time: 132
step:  8200, loss: 6.9798e+03, Y0: -6.2662e-01, cost: 1.7524e+02,  elapsed time: 133
step:  8300, loss: 6.9780e+03, Y0: -6.2662e-01, cost: 1.7521e+02,

step:  6800, loss: 6.8594e+03, Y0: 3.6919e-02, cost: 1.7191e+02,  elapsed time: 113
step:  6900, loss: 6.8604e+03, Y0: 3.6919e-02, cost: 1.7193e+02,  elapsed time: 115
step:  7000, loss: 6.8580e+03, Y0: 3.6919e-02, cost: 1.7188e+02,  elapsed time: 116
step:  7100, loss: 6.8662e+03, Y0: 3.6919e-02, cost: 1.7207e+02,  elapsed time: 118
step:  7200, loss: 6.8612e+03, Y0: 3.6919e-02, cost: 1.7195e+02,  elapsed time: 119
step:  7300, loss: 6.8547e+03, Y0: 3.6919e-02, cost: 1.7180e+02,  elapsed time: 120
step:  7400, loss: 6.8546e+03, Y0: 3.6919e-02, cost: 1.7181e+02,  elapsed time: 122
step:  7500, loss: 6.8531e+03, Y0: 3.6919e-02, cost: 1.7178e+02,  elapsed time: 123
step:  7600, loss: 6.8550e+03, Y0: 3.6919e-02, cost: 1.7183e+02,  elapsed time: 125
step:  7700, loss: 6.8284e+03, Y0: 3.6919e-02, cost: 1.7121e+02,  elapsed time: 126
step:  7800, loss: 6.8309e+03, Y0: 3.6919e-02, cost: 1.7125e+02,  elapsed time: 127
step:  7900, loss: 6.8307e+03, Y0: 3.6919e-02, cost: 1.7123e+02,  elapsed ti

step:  6400, loss: 6.9890e+03, Y0: -6.6560e-01, cost: 1.7210e+02,  elapsed time: 108
step:  6500, loss: 6.9894e+03, Y0: -6.6560e-01, cost: 1.7210e+02,  elapsed time: 109
step:  6600, loss: 6.9819e+03, Y0: -6.6560e-01, cost: 1.7192e+02,  elapsed time: 111
step:  6700, loss: 6.9846e+03, Y0: -6.6560e-01, cost: 1.7193e+02,  elapsed time: 112
step:  6800, loss: 6.9899e+03, Y0: -6.6560e-01, cost: 1.7205e+02,  elapsed time: 113
step:  6900, loss: 6.9817e+03, Y0: -6.6560e-01, cost: 1.7188e+02,  elapsed time: 115
step:  7000, loss: 6.9765e+03, Y0: -6.6560e-01, cost: 1.7177e+02,  elapsed time: 116
step:  7100, loss: 6.9645e+03, Y0: -6.6560e-01, cost: 1.7150e+02,  elapsed time: 118
step:  7200, loss: 6.9630e+03, Y0: -6.6560e-01, cost: 1.7149e+02,  elapsed time: 119
step:  7300, loss: 6.9599e+03, Y0: -6.6560e-01, cost: 1.7137e+02,  elapsed time: 120
step:  7400, loss: 6.9539e+03, Y0: -6.6560e-01, cost: 1.7123e+02,  elapsed time: 122
step:  7500, loss: 6.9538e+03, Y0: -6.6560e-01, cost: 1.7124e+02,

step:  6000, loss: 7.0528e+03, Y0: 8.2611e-01, cost: 1.7438e+02,  elapsed time: 101
step:  6100, loss: 7.0522e+03, Y0: 8.2611e-01, cost: 1.7434e+02,  elapsed time: 102
step:  6200, loss: 7.0507e+03, Y0: 8.2611e-01, cost: 1.7432e+02,  elapsed time: 103
step:  6300, loss: 7.0577e+03, Y0: 8.2611e-01, cost: 1.7443e+02,  elapsed time: 105
step:  6400, loss: 7.0648e+03, Y0: 8.2611e-01, cost: 1.7457e+02,  elapsed time: 106
step:  6500, loss: 7.0717e+03, Y0: 8.2611e-01, cost: 1.7473e+02,  elapsed time: 108
step:  6600, loss: 7.0829e+03, Y0: 8.2611e-01, cost: 1.7494e+02,  elapsed time: 109
step:  6700, loss: 7.0781e+03, Y0: 8.2611e-01, cost: 1.7484e+02,  elapsed time: 110
step:  6800, loss: 7.0665e+03, Y0: 8.2611e-01, cost: 1.7458e+02,  elapsed time: 112
step:  6900, loss: 7.0667e+03, Y0: 8.2611e-01, cost: 1.7458e+02,  elapsed time: 113
step:  7000, loss: 7.0633e+03, Y0: 8.2611e-01, cost: 1.7450e+02,  elapsed time: 115
step:  7100, loss: 7.0810e+03, Y0: 8.2611e-01, cost: 1.7489e+02,  elapsed ti

In [None]:
Results = pd.read_csv('./LQ_data_001_d100.csv', index_col = 0, header = None)
Algo1 = pd.concat([Results[i] for i in range(3,13)], axis = 0 , ignore_index= True)
Algo1.plot()
plt.title("Evolution du cost")
plt.show()