In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.scipy.stats.multivariate_normal import pdf as gaussian_pdf
from jax import jacfwd, jacrev
import optax
import numpy as np
import matplotlib.pyplot as plt

import pickle

from scipy.optimize import minimize

# 1. Problem Description

In [12]:
try:
    import dolfin
except ImportError:
    !wget "https://fem-on-colab.github.io/releases/fenics-install-real.sh" -O "/tmp/fenics-install.sh" && bash "/tmp/fenics-install.sh"
    import dolfin

import fenics as fcs
import numpy as np

def M(x):
    """Forward model: transient heat equation on the unit square, sampled at five points.

    Solves u_t = kappa * laplace(u) + f with a backward-Euler step in time and
    P1 finite elements on a 20 x 20 ``UnitSquareMesh``, with u = 0 on the whole
    boundary.

    Parameters
    ----------
    x : scalar
        Diffusion coefficient; it is passed unchanged to ``fcs.Constant`` as
        ``kappa``. No sign or range check is performed here.

    Returns
    -------
    numpy.ndarray
        One row per time step (t = dt, 2*dt, ... while t < T before the step;
        5 rows with the dt and T set below) and one column per probe point, in
        the order (0.3, 0.3), (0.3, 0.7), (0.7, 0.3), (0.7, 0.7), (0.5, 0.5).
    """

    # Create mesh and function space
    nx = ny = 20
    mesh = fcs.UnitSquareMesh(nx, ny)
    V = fcs.FunctionSpace(mesh, 'P', 1)

    # Boundary condition: homogeneous Dirichlet on the entire boundary
    u_D = fcs.Constant(0.0)

    def boundary(x, on_boundary):
        return on_boundary

    bc = fcs.DirichletBC(V, u_D, boundary)

    # Initial temperature: Gaussian bump centred at (0.1, 0.9)
    u_initial = fcs.Expression('exp(-10*((x[0]-0.1)*(x[0]-0.1) + (x[1]-0.9)*(x[1]-0.9)))', degree=2)
    u_n = fcs.interpolate(u_initial, V)

    # Time step and final time
    dt = 1
    T = 5.00

    # Variational problem
    u = fcs.TrialFunction(V)
    v = fcs.TestFunction(V)

    # Heat source: grows linearly in time, centred at (0.9, 0.1)
    f = fcs.Expression('t*exp(-3*((x[0]-0.9)*(x[0]-0.9) + (x[1]-0.1)*(x[1]-0.1)))', degree=2, t=0)

    # Weak form of one backward-Euler step
    kappa = fcs.Constant(x)
    a = u*v*fcs.dx + dt*kappa*fcs.dot(fcs.grad(u), fcs.grad(v))*fcs.dx
    L = (u_n + dt*f)*v*fcs.dx

    # Probe locations do not change between time steps, so build them once
    # instead of re-allocating five arrays inside the time loop.
    probes = (
        np.array([0.3, 0.3]),
        np.array([0.3, 0.7]),
        np.array([0.7, 0.3]),
        np.array([0.7, 0.7]),
        np.array([0.5, 0.5]),
    )

    # Time stepping
    u = fcs.Function(V)
    t = 0
    d = []
    while t < T:
        t += dt
        f.t = t  # the source term is evaluated at the new time level
        fcs.solve(a == L, u, bc)
        u_n.assign(u)
        d.append([u(point) for point in probes])

    return np.array(d)

--2024-07-29 03:29:16--  https://fem-on-colab.github.io/releases/fenics-install-real.sh
Resolving fem-on-colab.github.io (fem-on-colab.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to fem-on-colab.github.io (fem-on-colab.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4161 (4.1K) [application/x-sh]
Saving to: ‘/tmp/fenics-install.sh’


2024-07-29 03:29:16 (14.4 MB/s) - ‘/tmp/fenics-install.sh’ saved [4161/4161]

+ INSTALL_PREFIX=/usr/local
++ awk -F/ '{print NF-1}'
++ echo /usr/local
+ INSTALL_PREFIX_DEPTH=2
+ PROJECT_NAME=fem-on-colab
+ SHARE_PREFIX=/usr/local/share/fem-on-colab
+ FENICS_INSTALLED=/usr/local/share/fem-on-colab/fenics.installed
+ [[ ! -f /usr/local/share/fem-on-colab/fenics.installed ]]
+ PYBIND11_INSTALL_SCRIPT_PATH=https://github.com/fem-on-colab/fem-on-colab.github.io/raw/54d8555/releases/pybind11-install.sh
+ [[ https://github.com/fem-on-colab/fem-on-colab.github.io/raw/54d8555/r

In [13]:
# Synthetic observations: the forward model evaluated at x = 4, scaled by 0.98.
d_obs = 0.98*M(4)
# Mean and standard deviation of the Gaussian prior evaluated inside updf
# (sig is squared there to form the prior covariance).
mu = 4
sig = 1

def updf(x):
    """Unnormalised posterior density at ``x``: Gaussian likelihood times Gaussian prior.

    The likelihood compares the flattened forward-model output ``M(x)`` with the
    flattened ``d_obs`` under a zero-mean Gaussian with covariance 0.00005 * I.
    The prior is Gaussian with mean ``mu`` and variance ``sig**2``. A product
    that is exactly zero is replaced by 1e-16 so callers can divide by it.
    """
    theta = jnp.array(x)

    # Likelihood of the model/observation misfit
    misfit = M(theta).reshape((-1, )) - d_obs.reshape((-1, ))
    noise_cov = 0.00005*np.diag(np.ones_like(misfit))
    like = gaussian_pdf(misfit, np.zeros_like(misfit), noise_cov)

    # Prior on the parameter
    prior_density = gaussian_pdf(theta, jnp.array([mu]), np.diag([sig**2]))

    density = like * prior_density
    return jnp.where(density==0, 1e-16, density)

# Batched wrapper over the leading axis of the input.
vupdf = vmap(updf, in_axes=(0,))

Calling FFC just-in-time (JIT) compiler, this may take some time.


Level 25:FFC:Calling FFC just-in-time (JIT) compiler, this may take some time.
INFO:FFC:Compiling element ffc_element_3801828c0f66b7190a7fd5819465b3d5b34b9149

INFO:FFC:Compiler stage 1: Analyzing element(s)
INFO:FFC:--------------------------------------
INFO:FFC:  
INFO:FFC:Compiler stage 1 finished in 0.00482917 seconds.

INFO:FFC:Compiler stage 2: Computing intermediate representation
INFO:FFC:-------------------------------------------------------
INFO:FFC:  Computing representation of 1 elements
DEBUG:FFC:  Reusing element from cache
DEBUG:FFC:  Reusing element from cache
INFO:FFC:  Computing representation of 1 dofmaps
DEBUG:FFC:  Reusing element from cache
INFO:FFC:  Computing representation of 0 coordinate mappings
INFO:FFC:  Computing representation of integrals
INFO:FFC:  Computing representation of forms
INFO:FFC:  
INFO:FFC:Compiler stage 2 finished in 0.184969 seconds.

INFO:FFC:Compiler stage 3: Optimizing intermediate representation
INFO:FFC:----------------------------

Calling FFC just-in-time (JIT) compiler, this may take some time.


Level 25:FFC:Calling FFC just-in-time (JIT) compiler, this may take some time.
INFO:FFC:Compiling form ffc_form_b83f29bb4190e202e832bb84952cd56a2e6b93dc

INFO:FFC:Compiler stage 1: Analyzing form(s)
INFO:FFC:-----------------------------------
DEBUG:FFC:  Preprocessing form using 'uflacs' representation family.
INFO:UFL_LEGACY:Adjusting missing element cell to triangle.
INFO:FFC:  
INFO:FFC:  Geometric dimension:       2
  Number of cell subdomains: 0
  Rank:                      1
  Arguments:                 '(v_0)'
  Number of coefficients:    2
  Coefficients:              '[f_8, f_11]'
  Unique elements:           'CG1(?,?), CG2(?,?), Vector<2 x CG1(?,?)>'
  Unique sub elements:       'CG1(?,?), CG2(?,?), Vector<2 x CG1(?,?)>'
  
INFO:FFC:  representation:    auto --> uflacs
INFO:FFC:  quadrature_rule:   auto --> default
INFO:FFC:  quadrature_degree: auto --> 3
INFO:FFC:  quadrature_degree: 3
INFO:FFC:  
INFO:FFC:Compiler stage 1 finished in 0.0695245 seconds.

INFO:FFC:Compiler s

Calling FFC just-in-time (JIT) compiler, this may take some time.


Level 25:FFC:Calling FFC just-in-time (JIT) compiler, this may take some time.
INFO:FFC:Compiling element ffc_element_4f750817ecc896f3bedcb4ff8c9f3352153b1b38

INFO:FFC:Compiler stage 1: Analyzing element(s)
INFO:FFC:--------------------------------------
INFO:FFC:  
INFO:FFC:Compiler stage 1 finished in 0.00301981 seconds.

INFO:FFC:Compiler stage 2: Computing intermediate representation
INFO:FFC:-------------------------------------------------------
INFO:FFC:  Computing representation of 1 elements
DEBUG:FFC:  Reusing element from cache
DEBUG:FFC:  Reusing element from cache
DEBUG:FFC:  Reusing element from cache
INFO:FFC:  Computing representation of 1 dofmaps
DEBUG:FFC:  Reusing element from cache
INFO:FFC:  Computing representation of 0 coordinate mappings
INFO:FFC:  Computing representation of integrals
INFO:FFC:  Computing representation of forms
INFO:FFC:  
INFO:FFC:Compiler stage 2 finished in 0.027472 seconds.

INFO:FFC:Compiler stage 3: Optimizing intermediate representatio

Calling FFC just-in-time (JIT) compiler, this may take some time.


Level 25:FFC:Calling FFC just-in-time (JIT) compiler, this may take some time.
INFO:FFC:Compiling element ffc_element_6ab56968c6ffa883272fd990bd40fad8bf858cca

INFO:FFC:Compiler stage 1: Analyzing element(s)
INFO:FFC:--------------------------------------
INFO:FFC:  
INFO:FFC:Compiler stage 1 finished in 0.00467801 seconds.

INFO:FFC:Compiler stage 2: Computing intermediate representation
INFO:FFC:-------------------------------------------------------
INFO:FFC:  Computing representation of 1 elements
DEBUG:FFC:  Reusing element from cache
DEBUG:FFC:  Reusing element from cache
DEBUG:FFC:  Reusing element from cache
INFO:FFC:  Computing representation of 1 dofmaps
DEBUG:FFC:  Reusing element from cache
INFO:FFC:  Computing representation of 0 coordinate mappings
INFO:FFC:  Computing representation of integrals
INFO:FFC:  Computing representation of forms
INFO:FFC:  
INFO:FFC:Compiler stage 2 finished in 0.0193169 seconds.

INFO:FFC:Compiler stage 3: Optimizing intermediate representati

Calling FFC just-in-time (JIT) compiler, this may take some time.


Level 25:FFC:Calling FFC just-in-time (JIT) compiler, this may take some time.
INFO:FFC:Compiling coordinate_mapping ffc_coordinate_mapping_3720490578293ae8ad5feabedc46584f48fda4c4

INFO:FFC:Compiler stage 1: Analyzing coordinate_mapping(s)
INFO:FFC:-------------------------------------------------
INFO:FFC:  
INFO:FFC:Compiler stage 1 finished in 0.00489879 seconds.

INFO:FFC:Compiler stage 2: Computing intermediate representation
INFO:FFC:-------------------------------------------------------
INFO:FFC:  Computing representation of 0 elements
INFO:FFC:  Computing representation of 0 dofmaps
INFO:FFC:  Computing representation of 1 coordinate mappings
DEBUG:FFC:  Reusing element from cache
INFO:FFC:  Computing representation of integrals
INFO:FFC:  Computing representation of forms
INFO:FFC:  
INFO:FFC:Compiler stage 2 finished in 0.0104163 seconds.

INFO:FFC:Compiler stage 3: Optimizing intermediate representation
INFO:FFC:--------------------------------------------------------
INFO

Calling FFC just-in-time (JIT) compiler, this may take some time.


Level 25:FFC:Calling FFC just-in-time (JIT) compiler, this may take some time.
INFO:FFC:Compiling form ffc_form_fe3dc54b4940c1c866b424707ef49f437f802a54

INFO:FFC:Compiler stage 1: Analyzing form(s)
INFO:FFC:-----------------------------------
DEBUG:FFC:  Preprocessing form using 'uflacs' representation family.
INFO:UFL_LEGACY:Adjusting missing element cell to triangle.
INFO:FFC:  
INFO:FFC:  Geometric dimension:       2
  Number of cell subdomains: 0
  Rank:                      2
  Arguments:                 '(v_0, v_1)'
  Number of coefficients:    1
  Coefficients:              '[f_12]'
  Unique elements:           'CG1(?,?), R0(?,?), Vector<2 x CG1(?,?)>'
  Unique sub elements:       'CG1(?,?), R0(?,?), Vector<2 x CG1(?,?)>'
  
INFO:FFC:  representation:    auto --> uflacs
INFO:FFC:  quadrature_rule:   auto --> default
INFO:FFC:  quadrature_degree: auto --> 2
INFO:FFC:  quadrature_degree: 2
INFO:FFC:  
INFO:FFC:Compiler stage 1 finished in 0.0586886 seconds.

INFO:FFC:Compiler sta

Calling FFC just-in-time (JIT) compiler, this may take some time.


Level 25:FFC:Calling FFC just-in-time (JIT) compiler, this may take some time.
INFO:FFC:Compiling element ffc_element_17d5bd7e022a45e71c9f390bd00f3f09885a1fd0

INFO:FFC:Compiler stage 1: Analyzing element(s)
INFO:FFC:--------------------------------------
INFO:FFC:  
INFO:FFC:Compiler stage 1 finished in 0.00365162 seconds.

INFO:FFC:Compiler stage 2: Computing intermediate representation
INFO:FFC:-------------------------------------------------------
INFO:FFC:  Computing representation of 1 elements
DEBUG:FFC:  Reusing element from cache
DEBUG:FFC:  Reusing element from cache
DEBUG:FFC:  Reusing element from cache
INFO:FFC:  Computing representation of 1 dofmaps
DEBUG:FFC:  Reusing element from cache
INFO:FFC:  Computing representation of 0 coordinate mappings
INFO:FFC:  Computing representation of integrals
INFO:FFC:  Computing representation of forms
INFO:FFC:  
INFO:FFC:Compiler stage 2 finished in 0.0141115 seconds.

INFO:FFC:Compiler stage 3: Optimizing intermediate representati

# 2. MCMC

In [14]:
def RWMCMC(iter, x_set=None):
    """Random-walk Metropolis sampler targeting the unnormalised density ``updf``.

    Each step proposes ``x_next ~ N(x_curr, I)`` and accepts it with probability
    ``min(1, updf(x_next) / updf(x_curr))``.

    Parameters
    ----------
    iter : int
        Loop bound; the loop runs over ``range(1, iter)``, i.e. ``iter - 1``
        proposals are made.
    x_set : sequence, optional
        Existing chain to resume from. Its last entry is the starting state and
        new samples are appended after it. ``None`` or an empty sequence starts
        a fresh chain at 0.1.

    Returns
    -------
    list of numpy.ndarray
        The chain, one shape-(1,) array per sample. If interrupted with
        Ctrl-C, the samples collected so far are returned.
    """
    if x_set is None or len(x_set) == 0:
        x_curr = np.array([0.1])
        x_set = []
    else:
        # Normalise every stored sample to a 1-D float array so the resumed
        # state supports .item() exactly like freshly proposed states do.
        x_set = [np.asarray(sample, dtype=float).reshape(-1) for sample in x_set]
        x_curr = x_set[-1]

    # Density at the current state. updf runs a full PDE solve, so it is
    # evaluated once per state and reused until the state changes.
    p_curr = None

    try:
        for i in range(1, iter):
            if p_curr is None:
                p_curr = updf(x_curr.item())

            x_next = np.random.multivariate_normal(x_curr, np.identity(1))
            p_next = updf(x_next.item())

            alpha = np.min([1, p_next/p_curr])
            if np.random.uniform(0, 1) < alpha:
                x_curr, p_curr = x_next, p_next

            x_set.append(x_curr)
            print(len(x_set), alpha, x_curr)

        return x_set
    except KeyboardInterrupt:
        # Deliberate: allow manual interruption without losing the chain.
        return x_set


def AMCMC(iter):
    """Adaptive Metropolis sampler targeting the unnormalised density ``updf``.

    The chain starts at 0. For the first five iterations (``i <= 4``) proposals
    are drawn from ``N(x_curr, 0.005 * I)``. Afterwards the proposal is
    ``0.95 * a + 0.05 * b`` with ``a ~ N(x_curr, 2.38**2 * COV / 2)`` and
    ``b ~ N(x_curr, 0.005 * I)``, where ``COV`` is the empirical variance of
    all samples stored so far.

    Parameters
    ----------
    iter : int
        Number of iterations; one sample is stored per iteration.

    Returns
    -------
    list of numpy.ndarray
        The chain. If interrupted with Ctrl-C, the samples collected so far
        are returned.
    """
    x_curr = np.array([0])
    x_set = []

    # Density at the current state. updf runs a full PDE solve, so it is
    # evaluated once per state and reused until the state changes.
    p_curr = None

    try:
        for i in range(0, iter):
            if p_curr is None:
                p_curr = updf(x_curr.item())

            if i<=4:
                x_next = np.random.multivariate_normal(x_curr, 0.01*(np.identity(1)/2))
            else:
                COV = np.cov(np.array(x_set).T)
                # NOTE(review): this adds two scaled draws rather than picking one
                # mixture component at random; the proposal is still symmetric
                # about x_curr. The "/2" looks like a dimension divisor although
                # the state here is 1-D -- confirm both are intended.
                x_next = (1-0.05)*np.random.multivariate_normal(x_curr, (2.38**2*COV/2)*np.identity(1)) + 0.05*np.random.multivariate_normal(x_curr, 0.01*(np.identity(1)/2))

            p_next = updf(x_next.item())
            alpha = np.min([1, p_next/p_curr])
            if np.random.uniform(0, 1) < alpha:
                x_curr, p_curr = x_next, p_next

            x_set.append(x_curr)
            print(len(x_set), alpha, x_curr)

        return x_set
    except KeyboardInterrupt:
        # Deliberate: allow manual interruption without losing the chain.
        return x_set


def EMCEEMCMC(iter):
    """Stretch-move sampler (emcee style) targeting the unnormalised density ``updf``.

    A pool ``FS`` of ``iter`` points is initialised uniformly in [-1, 1]. At
    iteration ``i`` a partner is picked at random from every pool slot except
    slot ``i``, the proposal is ``partner + z * (x_curr - partner)`` with
    ``z = (1 + r)**2 / 2`` for ``r ~ U(0, 1)``, and the resulting state is
    written back into slot ``i``.

    The acceptance probability is ``z**(d - 1) * updf(x_next) / updf(x_curr)``
    where ``d`` is the dimension of the state. The state is one-dimensional
    here, so the ``z`` factor is 1; multiplying by ``z`` itself corresponds to
    ``d = 2`` and would bias the chain.

    Parameters
    ----------
    iter : int
        Number of iterations (and size of the partner pool).

    Returns
    -------
    list of numpy.ndarray
        The chain. If interrupted with Ctrl-C, the samples collected so far
        are returned.
    """
    x_curr = np.array([0.0])
    x_set = []
    FS = np.random.uniform(-1, 1, (iter, 1))

    full_range = np.arange(iter)  # loop-invariant index pool
    dim = x_curr.size

    # Density at the current state. updf runs a full PDE solve, so it is
    # evaluated once per state and reused until the state changes.
    p_curr = None

    try:
        for i in range(0, iter):
            if p_curr is None:
                p_curr = updf(x_curr.item())

            left_range = full_range[full_range != i]
            x_next_index = np.random.choice(left_range)
            x_partner = FS[x_next_index]

            r = np.random.rand()
            z = ((1 + r * (2 - 1))**2) / 2
            x_next = x_partner + z * (x_curr - x_partner)

            p_next = updf(x_next.item())
            alpha = z**(dim - 1) * (p_next / p_curr)
            if np.random.uniform(0, 1) < alpha:
                x_curr, p_curr = x_next, p_next

            x_set.append(x_curr)
            FS[i] = x_curr
            print(len(x_set), alpha, x_curr)
        return x_set
    except KeyboardInterrupt:
        # Deliberate: allow manual interruption without losing the chain.
        return x_set

In [None]:
# x_RWMCMC = RWMCMC(100000)
# x_AMCMC = AMCMC(100000)
# x_EMCEEMCMC = EMCEEMCMC(100000)

# np.save('/content/drive/MyDrive/No. 21 MCMC+TM/E2/Data_RMCMC.npy', x_RWMCMC)
# np.save('/content/drive/MyDrive/No. 21 MCMC+TM/E2/Data_AMCMC.npy', x_AMCMC)
# np.save('/content/drive/MyDrive/No. 21 MCMC+TM/E2/Data_EMCMC.npy', x_EMCEEMCMC)

# 3. TM_polynomial

In [2]:
from itertools import product

def generate_J(n, i, p):
    """Build the multi-index set for a polynomial expansion.

    Enumerates every tuple of ``i`` non-negative integers (each at most ``p``)
    whose sum does not exceed ``p``, and right-pads each tuple with ``n - i``
    zeros.

    Returns a ``jnp`` array with one multi-index per row.
    """
    zero_tail = (0,)*(n-i)
    rows = []
    for degrees in product(range(p+1), repeat=i):
        if sum(degrees) <= p:
            rows.append(degrees + zero_tail)
    return jnp.array(rows)

# @jit
def T1(param1, J1, x1):
    """Evaluate the polynomial map: dot product of ``param1`` with the basis at ``x1``.

    Each row of ``J1`` is a multi-index; its basis function is the product over
    coordinates of a univariate Legendre polynomial of the given order. Orders
    0 to 4 select the matching polynomial; any other order falls through to the
    degree-5 polynomial.
    """
    def legendre(order, r):
        quadratic = 0.5 * (3 * jnp.power(r, 2) - 1)
        cubic = 0.5 * (5 * jnp.power(r, 3) - 3 * r)
        quartic = (1/8) * (35 * jnp.power(r, 4) - 30 * jnp.power(r, 2) + 3)
        quintic = (1/8) * (63 * jnp.power(r, 5) - 70 * jnp.power(r, 3) + 15*r)

        # Build the selection from the fallback outwards so that the lowest
        # order is tested last and therefore takes precedence.
        value = quintic
        for degree, poly in ((4, quartic), (3, cubic), (2, quadratic), (1, r), (0, 1)):
            value = jnp.where(order == degree, poly, value)
        return value

    def basis_function(multi_index, point):
        factors = vmap(legendre, in_axes=(0, 0))(multi_index, point)
        return jnp.prod(factors)

    point = jnp.array([x1]).reshape((-1, ))
    basis_values = vmap(basis_function, in_axes=(0, None))(J1, point)
    return jnp.dot(param1, basis_values)

# Derivative of the map with respect to its input, and a version batched over inputs.
gT1 = grad(T1, argnums=2)
vT1 = vmap(T1, in_axes=(None, None, 0))

In [3]:
# @jit
def obj1(param1, J1, x_batch):
    """Mean training objective of the polynomial map over a batch.

    For every sample the loss is ``0.5 * T1(x)**2 - log(gT1(x))``; the batch
    mean is returned.

    NOTE: a later cell of this notebook defines another function called
    ``obj1`` with a different signature.
    """
    def per_sample(x1):
        map_value = T1(param1, J1, x1)
        slope = gT1(param1, J1, x1)
        return 0.5*map_value**2 - jnp.log(slope)

    losses = vmap(per_sample)(x_batch)
    return losses.mean()

# @jit
def obj_test1(param1, J1, x_batch):
    """Per-sample objective of the polynomial map (no averaging).

    Returns ``0.5 * T1(x)**2 - log(gT1(x))`` for every sample in ``x_batch``,
    stacked along the leading axis.
    """
    def sample_loss(coeffs, multi_indices, point):
        quadratic_term = 0.5*T1(coeffs, multi_indices, point)**2
        log_slope = jnp.log(gT1(coeffs, multi_indices, point))
        return quadratic_term - log_slope

    batched_loss = vmap(sample_loss, in_axes=(None, None, 0))
    return batched_loss(param1, J1, x_batch)

# @jit
def constrant1(param1, J1, x_batch):
    """Total monotonicity violation of the polynomial map over a batch.

    Sums ``max(0, -gT1(x))`` over all samples: the result is 0 when every
    derivative is non-negative and positive otherwise.
    """
    slopes = vmap(gT1, in_axes=(None, None, 0))(param1, J1, x_batch)
    violations = jnp.maximum(0, -slopes)
    return jnp.sum(violations)

In [16]:
from scipy.optimize import minimize
def TM_train(param1, J1, x_set):
    """Fit the polynomial-map coefficients with SLSQP.

    Samples whose per-sample objective is NaN at the starting coefficients are
    removed first. The mean objective ``obj1`` is then minimised subject to the
    monotonicity constraint.

    Parameters
    ----------
    param1 : array-like
        Starting coefficients.
    J1 : array
        Multi-index set passed through to the objective and the constraint.
    x_set : array-like
        Training samples, one per row.

    Returns
    -------
    array-like
        The optimised coefficients if SLSQP reports success, otherwise the
        starting coefficients unchanged.
    """
    x_set = np.array(x_set)

    # Rows where the objective is NaN at the starting point are dropped.
    nan_rows = np.where(np.isnan(obj_test1(param1, J1, x_set)))[0]
    x_set1 = np.delete(x_set, nan_rows, axis=0)

    def monotonicity_margin(params, multi_indices, samples):
        # SciPy treats an 'ineq' constraint as fun(x) >= 0. constrant1 is a
        # penalty that is always >= 0, so passing it directly never constrains
        # anything. Negating it makes "feasible" mean "penalty equals zero".
        return -constrant1(params, multi_indices, samples)

    cons = {'type': 'ineq', 'fun': monotonicity_margin, 'args': (J1, x_set1)}
    solution1 = minimize(obj1, param1, args=(J1, x_set1), method='SLSQP', jac=grad(obj1), constraints=cons, options={'disp': True, 'maxiter': 1000})
    if solution1.success:
        param1 = solution1.x

    return param1


def TM_algo(order, x_batch):
    """Fit a one-dimensional polynomial map of the given order to ``x_batch``.

    Coefficients start at 0.5 and are refined by three consecutive
    ``TM_train`` passes, each restarting from the previous result.
    """
    samples = jnp.array(x_batch)
    multi_indices = generate_J(n=1, i=1, p=order)

    coeffs = 0.5*np.ones(len(multi_indices))
    for _ in range(3):
        coeffs = TM_train(coeffs, multi_indices, samples)

    return coeffs

# Fit an order-2 polynomial map on growing prefixes of the adaptive-MCMC chain
# and store each coefficient vector in its own pickle file.
Data_AMCMC = np.load('/content/drive/MyDrive/No. 21 MCMC+TM/E2/Data_AMCMC.npy')
order = 2  # polynomial order; also appears in the output file name
for i in [2500, 3000, 3500, 4000, 4500, 5000]:
    param1 = TM_algo(order, Data_AMCMC[:i])
    # Use a context manager so the file handle is closed (and flushed) promptly.
    with open(f'/content/drive/MyDrive/No. 21 MCMC+TM/E2/AMCMC_TMo{order}_{i}.pkl', 'wb') as fh:
        pickle.dump(param1, fh)

Optimization terminated successfully    (Exit mode 0)
            Current function value: -0.011614345945417881
            Iterations: 34
            Function evaluations: 63
            Gradient evaluations: 34
Optimization terminated successfully    (Exit mode 0)
            Current function value: 0.0264666099101305
            Iterations: 9
            Function evaluations: 12
            Gradient evaluations: 9
Optimization terminated successfully    (Exit mode 0)
            Current function value: 0.0264666099101305
            Iterations: 1
            Function evaluations: 12
            Gradient evaluations: 1
Optimization terminated successfully    (Exit mode 0)
            Current function value: -0.032528601586818695
            Iterations: 27
            Function evaluations: 51
            Gradient evaluations: 27
Optimization terminated successfully    (Exit mode 0)
            Current function value: 0.007653668522834778
            Iterations: 1
            Function 

# 4. TM_NN

In [None]:
from sklearn.model_selection import train_test_split

def init_network_params(sizes, key=random.PRNGKey(3)):
    """Create the initial (weights, biases) pair for every layer of an MLP.

    ``sizes`` lists the layer widths from input to output. Every weight and
    bias is drawn uniformly from [1, 3) and multiplied by 1e-2, so all initial
    parameters are positive.

    Returns a list with one ``(weights, biases)`` tuple per layer; ``weights``
    has shape ``(n_out, n_in)`` and ``biases`` has shape ``(n_out,)``.
    """
    scale = 1e-2
    layer_keys = random.split(key, len(sizes))

    params = []
    for n_in, n_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
        w_key, b_key = random.split(layer_key)
        weights = scale * random.uniform(w_key, (n_out, n_in), minval=1, maxval=3)
        biases = scale * random.uniform(b_key, (n_out,), minval=1, maxval=3)
        params.append((weights, biases))
    return params

def TM_NN1(NN1, x1):
    """Evaluate the neural-network map at a single input.

    ``NN1`` is a sequence of ``(weights, biases)`` pairs. Every layer except
    the last is followed by a sigmoid; the last layer is affine. The output is
    squeezed before being returned.
    """
    *hidden_layers, (out_w, out_b) = NN1

    hidden = x1.reshape((-1, ))
    for weight, bias in hidden_layers:
        hidden = jax.nn.sigmoid(jnp.dot(weight, hidden) + bias)

    return (jnp.dot(out_w, hidden) + out_b).squeeze()
# Derivative of the network output with respect to its input.
gTM_NN1 = grad(TM_NN1, argnums=1)

@jit
def obj1(NN1, x_batch):
    """Mean training objective of the neural-network map over a batch.

    Uses the first column of ``x_batch``. For every sample the loss is
    ``0.5 * TM_NN1(x)**2 - log(gTM_NN1(x) + 1e-12)``; the batch mean is
    returned.
    """
    def per_sample(x1):
        map_value = TM_NN1(NN1, x1)
        slope = gTM_NN1(NN1, x1)
        return 0.5*map_value**2 - jnp.log(slope+1e-12)

    losses = vmap(per_sample)(x_batch[:, 0])
    return losses.mean()

@jit
def constraint1(NN1, x_batch):
    """Total monotonicity violation of the neural-network map over a batch.

    Uses the first column of ``x_batch`` and sums ``max(0, -gTM_NN1(x))`` over
    all samples: 0 when every derivative is non-negative, positive otherwise.
    """
    slopes = vmap(gTM_NN1, in_axes=(None, 0))(NN1, x_batch[:, 0])
    violations = jnp.maximum(0, -slopes)
    return jnp.sum(violations)

In [None]:
# Train one neural-network map per training-set size and pickle the best
# parameters seen during each run.
Data_AMCMC = np.load('/content/drive/MyDrive/No. 21 MCMC+TM/E2/Data_AMCMC.npy')

# Gradient functions do not depend on the run, so build them once.
tobj1 = jit(grad(obj1))
tconstraint1 = jit(grad(constraint1))

for num_sample in [6000, 7000, 8000, 9000, 10000]:
    NN1 = init_network_params([1, 100, 100, 1])

    # The "test" batch is the full chain, so it contains the training prefix.
    x_batch_train = np.array(Data_AMCMC[:num_sample])
    x_batch_test = np.array(Data_AMCMC)
    # x_batch_train, x_batch_test = train_test_split(x_batch, test_size=0.1, random_state=41)

    optimizer= optax.adam(0.0001)
    opt_state = optimizer.init(NN1)
    lowest_value = 10000
    # Reset the best model for every run; otherwise a run that never improves
    # on lowest_value would pickle the previous run's network (or fail with a
    # NameError on the first run).
    NN1_best = NN1

    test_penalty = constraint1(NN1, x_batch_test)
    for i in range(10000):
        # While monotonicity is violated on the test batch, step on the
        # penalty; otherwise step on the training objective.
        if test_penalty>0:
            grads = tconstraint1(NN1, x_batch_test)
        else:
            grads = tobj1(NN1, x_batch_train)
        updates, opt_state = optimizer.update(grads, opt_state)
        NN1 = optax.apply_updates(NN1, updates)

        # Evaluate each metric once per iteration and reuse it below.
        train_loss = obj1(NN1, x_batch_train)
        train_penalty = constraint1(NN1, x_batch_train)
        test_loss = obj1(NN1, x_batch_test)
        test_penalty = constraint1(NN1, x_batch_test)
        print(f'Train loss: {i}, {train_loss:.4f}, {train_penalty}   '
              f'                Test loss: {test_loss:.4f}, {test_penalty}     {lowest_value:.4f}')
        if test_loss<lowest_value:
            lowest_value = test_loss
            NN1_best = NN1

    # Use a context manager so the file handle is closed (and flushed) promptly.
    with open(f'/content/drive/MyDrive/No. 21 MCMC+TM/E2/AMCMC_NNbest_{num_sample}.pkl', 'wb') as fh:
        pickle.dump(NN1_best, fh)

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
Train loss: 5000, -0.0640, 0.0                   Test loss: -0.0519, 0.0     -0.0519
Train loss: 5001, -0.0640, 0.0                   Test loss: -0.0519, 0.0     -0.0519
Train loss: 5002, -0.0641, 0.0                   Test loss: -0.0519, 0.0     -0.0519
Train loss: 5003, -0.0641, 0.0                   Test loss: -0.0520, 0.0     -0.0519
Train loss: 5004, -0.0641, 0.0                   Test loss: -0.0520, 0.0     -0.0520
Train loss: 5005, -0.0641, 0.0                   Test loss: -0.0520, 0.0     -0.0520
Train loss: 5006, -0.0641, 0.0                   Test loss: -0.0520, 0.0     -0.0520
Train loss: 5007, -0.0641, 0.0                   Test loss: -0.0520, 0.0     -0.0520
Train loss: 5008, -0.0641, 0.0                   Test loss: -0.0520, 0.0     -0.0520
Train loss: 5009, -0.0641, 0.0                   Test loss: -0.0520, 0.0     -0.0520
Train loss: 5010, -0.0641, 0.0                   Test loss: -0.0520, 0.0     -0.0520
Train loss: 5011, -0.064