# INIT

In [1]:
from jax import vmap
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
import time

from equinox_module import training_MODEL, create_FNN, save_MODEL
import platform

# Record the host operating system and (lower-cased) CPU architecture.
system, machine = platform.system(), platform.machine().lower()

# Make JAX use the CPU backend as its default platform for this notebook.
jax.config.update("jax_platform_name", "cpu")

In [2]:
SEED = 111                      # PRNG seed reused by every training run below
INPUT_DIM, OUTPUT_DIM = 2, 1    # network maps (x, t) -> u
WIDTH, DEPTH = 20, 9
# Activation options: sine, cosine, relu, tanh, gelu, swish
ACTIVATION_FN = 'tanh'
NU = 0.01 / jnp.pi              # viscosity multiplying u_xx in the Burgers residual

In [3]:
def u_st(x, t):
    """Reference profile u(x, t) = -sin(pi * x).

    ``t`` is accepted for signature symmetry but ignored. This profile is the
    target for both the boundary (``g_BC``) and initial (``g_IC``) conditions.
    """
    return -jnp.sin(jnp.pi * x)

def f_physics(x, t):
    """Forcing term of the PDE: identically zero (no source).

    Note: ``loss_physics`` does not call this function.
    """
    del x, t  # unused; kept for a uniform (x, t) signature
    return 0.0
    
def g_BC(xt):
    """Boundary-condition target at point ``xt = (x, t)``, taken from ``u_st``."""
    x, t = xt[0], xt[1]
    return u_st(x, t)

def g_IC(xt):
    """Initial-condition target at point ``xt = (x, t)``, taken from ``u_st``."""
    x, t = xt[0], xt[1]
    return u_st(x, t)

def loss_physics(model, nu=None):
    """Build the pointwise residual of the viscous Burgers equation for ``model``.

    Despite its name this returns a *function*, not a scalar loss:
    ``eq(xt)`` evaluates  u_t + u * u_x - nu * u_xx  at one point
    ``xt = (x, t)``, where ``u(x, t) = model([x, t])[0]``. Callers vmap it
    over collocation points and square/average the result themselves.

    Args:
        model: callable taking a length-2 array ``[x, t]`` and returning an
            indexable output whose first entry is u(x, t).
        nu: viscosity coefficient. When ``None`` (default) the module-level
            ``NU`` is used, read at the moment ``loss_physics`` is called.

    Returns:
        Function ``eq(xt) -> scalar`` giving the PDE residual at ``xt``.
    """
    viscosity = NU if nu is None else nu

    def u(x, t):
        return model(jnp.stack([x, t]))[0]

    # Scalar derivatives of u with respect to x (argnums=0) and t (argnums=1).
    u_x = jax.grad(u, argnums=0)
    u_xx = jax.grad(u_x, argnums=0)
    u_t = jax.grad(u, argnums=1)

    def eq(xt):
        x, t = xt[0], xt[1]
        return u_t(x, t) + u(x, t) * u_x(x, t) - viscosity * u_xx(x, t)

    return eq



In [4]:
# --- First-order optimizer phase ---
# Available optimizers: adabelief, adadelta, adan, adafactor, adagrad, adam, adamw,
# adamax, adamaxw, amsgrad, lion, nadam, nadamw, novograd, radam,
# rmsprop, sgd, sm3, yogi, polyak_sgd
OPTIMIZER_NAME = 'adam'
LEARNING_RATE = 4e-4
MAXITER = 16000
PRINT_EVERY = 100

# Learning-rate schedule: 'exponential' or 'constant'.
# NOTE(review): LR_DECAY / LR_STEP presumably only matter for 'exponential'; confirm in training_MODEL.
LR_SCHEDULER = 'constant'
LR_DECAY = 0.90
LR_STEP = 1000

# --- Optional L-BFGS phase: iteration and print counts collapse to 0 unless enabled ---
LBFGS_USE = 'on'
LBFGS_MAXITER = 4000 if LBFGS_USE == 'on' else 0
LBFGS_PRINT_EVERY = 100 if LBFGS_USE == 'on' else 0

# NOTE(review): create_FNN receives DEPTH - 1; confirm which layer count it expects.
HYPER_MODEL = dict(
    input_dim=INPUT_DIM,
    output_dim=OUTPUT_DIM,
    width=WIDTH,
    depth=DEPTH - 1,
    act_func=ACTIVATION_FN,
)

HYPER_OPTIM = dict(
    MAXITER=MAXITER,
    NAME=OPTIMIZER_NAME,
    LEARNING_RATE_SCHEDULER=LR_SCHEDULER,
    LEARNING_RATE_INITIAL=LEARNING_RATE,
    LEARNING_RATE_DECAY=LR_DECAY,
    LEARNING_RATE_STEP=LR_STEP,
    PRINT_EVERY=PRINT_EVERY,
    LBFGS=dict(USE=LBFGS_USE, MAXITER=LBFGS_MAXITER, PRINT_EVERY=LBFGS_PRINT_EVERY),
)

# MODELS

## 2000, UNIFORM

In [5]:
# Load residual / boundary / initial collocation points. The `with` block closes
# the .npz file handle once the arrays have been read into memory.
# NOTE(review): absolute, user-specific path; consider a configurable DATA_DIR.
with np.load('/Users/haydenoutlaw/Documents/Courses/SciML/ncsu-sciml/hw3/data/uniform_2000.npz') as ds:
    xt_re, xt_bc, xt_ic = ds["xt_re"], ds["xt_bc"], ds["xt_ic"]


def loss_fn(model, xt_re=xt_re, xt_bc=xt_bc, xt_ic=xt_ic):
    """PINN loss: MSE of the PDE residual + MSE of BC misfit + MSE of IC misfit.

    The collocation arrays are bound as defaults when this cell runs, so the loss
    always uses this cell's dataset rather than whatever later cells leave in the
    global ``xt_re`` / ``xt_bc`` / ``xt_ic``.
    """
    residual = jax.vmap(loss_physics(model))(xt_re)
    # Targets get a trailing axis to match the model's per-point output
    # (assumed shape (OUTPUT_DIM,) = (1,); confirm against create_FNN).
    bc = jax.vmap(model)(xt_bc) - jax.vmap(g_BC)(xt_bc)[:, None]
    ic = jax.vmap(model)(xt_ic) - jax.vmap(g_IC)(xt_ic)[:, None]
    return jnp.mean(residual**2) + jnp.mean(bc**2) + jnp.mean(ic**2)


# Every run starts from the same SEED-derived key, so runs differ only in sampling.
key = jr.PRNGKey(SEED)
key, train_key = jr.split(key, num=2)
model = create_FNN(key=train_key, **HYPER_MODEL)

start_time = time.time()
u_2000_model, u_2000_log_loss, u_2000_log_minloss = training_MODEL(model, loss_fn, HYPER_OPTIM)
end_time = time.time()
# Keep the timing per run; start_time/end_time are overwritten by the next cell.
u_2000_train_time = end_time - start_time
print(f"Wall-clock training time: {u_2000_train_time:.1f} s")

I0000 00:00:1760535922.134398 77546314 service.cc:145] XLA service 0x11374d830 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1760535922.134422 77546314 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1760535922.138289 77546314 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1760535922.138313 77546314 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

Selected Optimizer is [ adam ], Initial Learning Rate is 4.00e-04
You are using [ constant ] learning rate
-------------------------TRAINING STARTS-------------------------
-----------------------------------------------------------------
Epoch    0: loss = 4.9982e-01, minloss = 4.9982e-01, Time: 0.89s
Epoch  100: loss = 3.5443e-01, minloss = 3.5443e-01, Time: 0.88s
Epoch  200: loss = 2.6916e-01, minloss = 2.6916e-01, Time: 0.85s
Epoch  300: loss = 1.4407e-01, minloss = 1.4407e-01, Time: 0.85s
Epoch  400: loss = 1.2942e-01, minloss = 1.2942e-01, Time: 0.85s
Epoch  500: loss = 1.1088e-01, minloss = 1.1088e-01, Time: 0.85s
Epoch  600: loss = 9.5982e-02, minloss = 9.5982e-02, Time: 0.85s
Epoch  700: loss = 8.7346e-02, minloss = 8.7346e-02, Time: 0.85s
Epoch  800: loss = 8.1500e-02, minloss = 8.1385e-02, Time: 0.85s
Epoch  900: loss = 7.2335e-02, minloss = 7.2335e-02, Time: 0.85s
Epoch 1000: loss = 5.6708e-02,

## 4000, UNIFORM

In [6]:
# Load residual / boundary / initial collocation points. The `with` block closes
# the .npz file handle once the arrays have been read into memory.
# NOTE(review): absolute, user-specific path; consider a configurable DATA_DIR.
with np.load('/Users/haydenoutlaw/Documents/Courses/SciML/ncsu-sciml/hw3/data/uniform_4000.npz') as ds:
    xt_re, xt_bc, xt_ic = ds["xt_re"], ds["xt_bc"], ds["xt_ic"]


def loss_fn(model, xt_re=xt_re, xt_bc=xt_bc, xt_ic=xt_ic):
    """PINN loss: MSE of the PDE residual + MSE of BC misfit + MSE of IC misfit.

    The collocation arrays are bound as defaults when this cell runs, so the loss
    always uses this cell's dataset rather than whatever later cells leave in the
    global ``xt_re`` / ``xt_bc`` / ``xt_ic``.
    """
    residual = jax.vmap(loss_physics(model))(xt_re)
    # Targets get a trailing axis to match the model's per-point output
    # (assumed shape (OUTPUT_DIM,) = (1,); confirm against create_FNN).
    bc = jax.vmap(model)(xt_bc) - jax.vmap(g_BC)(xt_bc)[:, None]
    ic = jax.vmap(model)(xt_ic) - jax.vmap(g_IC)(xt_ic)[:, None]
    return jnp.mean(residual**2) + jnp.mean(bc**2) + jnp.mean(ic**2)


# Every run starts from the same SEED-derived key, so runs differ only in sampling.
key = jr.PRNGKey(SEED)
key, train_key = jr.split(key, num=2)
model = create_FNN(key=train_key, **HYPER_MODEL)

start_time = time.time()
u_4000_model, u_4000_log_loss, u_4000_log_minloss = training_MODEL(model, loss_fn, HYPER_OPTIM)
end_time = time.time()
# Keep the timing per run; start_time/end_time are overwritten by the next cell.
u_4000_train_time = end_time - start_time
print(f"Wall-clock training time: {u_4000_train_time:.1f} s")

Selected Optimizer is [ adam ], Initial Learning Rate is 4.00e-04
You are using [ constant ] learning rate
-------------------------TRAINING STARTS-------------------------
-----------------------------------------------------------------
Epoch    0: loss = 5.5946e-01, minloss = 5.5946e-01, Time: 0.97s
Epoch  100: loss = 3.9042e-01, minloss = 3.9042e-01, Time: 0.82s
Epoch  200: loss = 2.9876e-01, minloss = 2.9876e-01, Time: 0.81s
Epoch  300: loss = 1.5590e-01, minloss = 1.5590e-01, Time: 0.82s
Epoch  400: loss = 1.4198e-01, minloss = 1.4198e-01, Time: 0.82s
Epoch  500: loss = 1.2247e-01, minloss = 1.2247e-01, Time: 0.81s
Epoch  600: loss = 1.0532e-01, minloss = 1.0532e-01, Time: 0.81s
Epoch  700: loss = 9.9770e-02, minloss = 9.9770e-02, Time: 0.81s
Epoch  800: loss = 9.5174e-02, minloss = 9.5174e-02, Time: 0.81s
Epoch  900: loss = 8.8723e-02, minloss = 8.8723e-02, Time: 0.81s
Epoch 1000: loss = 7.7678e-02, minloss = 7.7678e-02, Time: 0.82s
Epoch 1100: loss = 5.2649e-02, minloss = 5.264

## 6000, UNIFORM

In [7]:
# Load residual / boundary / initial collocation points. The `with` block closes
# the .npz file handle once the arrays have been read into memory.
# NOTE(review): absolute, user-specific path; consider a configurable DATA_DIR.
with np.load('/Users/haydenoutlaw/Documents/Courses/SciML/ncsu-sciml/hw3/data/uniform_6000.npz') as ds:
    xt_re, xt_bc, xt_ic = ds["xt_re"], ds["xt_bc"], ds["xt_ic"]


def loss_fn(model, xt_re=xt_re, xt_bc=xt_bc, xt_ic=xt_ic):
    """PINN loss: MSE of the PDE residual + MSE of BC misfit + MSE of IC misfit.

    The collocation arrays are bound as defaults when this cell runs, so the loss
    always uses this cell's dataset rather than whatever later cells leave in the
    global ``xt_re`` / ``xt_bc`` / ``xt_ic``.
    """
    residual = jax.vmap(loss_physics(model))(xt_re)
    # Targets get a trailing axis to match the model's per-point output
    # (assumed shape (OUTPUT_DIM,) = (1,); confirm against create_FNN).
    bc = jax.vmap(model)(xt_bc) - jax.vmap(g_BC)(xt_bc)[:, None]
    ic = jax.vmap(model)(xt_ic) - jax.vmap(g_IC)(xt_ic)[:, None]
    return jnp.mean(residual**2) + jnp.mean(bc**2) + jnp.mean(ic**2)


# Every run starts from the same SEED-derived key, so runs differ only in sampling.
key = jr.PRNGKey(SEED)
key, train_key = jr.split(key, num=2)
model = create_FNN(key=train_key, **HYPER_MODEL)

start_time = time.time()
u_6000_model, u_6000_log_loss, u_6000_log_minloss = training_MODEL(model, loss_fn, HYPER_OPTIM)
end_time = time.time()
# Keep the timing per run; start_time/end_time are overwritten by the next cell.
u_6000_train_time = end_time - start_time
print(f"Wall-clock training time: {u_6000_train_time:.1f} s")

Selected Optimizer is [ adam ], Initial Learning Rate is 4.00e-04
You are using [ constant ] learning rate
-------------------------TRAINING STARTS-------------------------
-----------------------------------------------------------------
Epoch    0: loss = 5.1897e-01, minloss = 5.1897e-01, Time: 0.88s
Epoch  100: loss = 3.6392e-01, minloss = 3.6392e-01, Time: 1.13s
Epoch  200: loss = 2.8079e-01, minloss = 2.8079e-01, Time: 1.14s
Epoch  300: loss = 1.4789e-01, minloss = 1.4789e-01, Time: 1.15s
Epoch  400: loss = 1.3110e-01, minloss = 1.3110e-01, Time: 1.15s
Epoch  500: loss = 1.1150e-01, minloss = 1.1150e-01, Time: 1.16s
Epoch  600: loss = 9.6904e-02, minloss = 9.6904e-02, Time: 1.15s
Epoch  700: loss = 8.8281e-02, minloss = 8.8281e-02, Time: 1.15s
Epoch  800: loss = 8.2827e-02, minloss = 8.2827e-02, Time: 1.15s
Epoch  900: loss = 7.7326e-02, minloss = 7.7326e-02, Time: 1.16s
Epoch 1000: loss = 6.8460e-02, minloss = 6.8460e-02, Time: 1.15s
Epoch 1100: loss = 5.2163e-02, minloss = 5.216

## 2000, MESH

In [8]:
# Load residual / boundary / initial collocation points. The `with` block closes
# the .npz file handle once the arrays have been read into memory.
# NOTE(review): absolute, user-specific path; consider a configurable DATA_DIR.
with np.load('/Users/haydenoutlaw/Documents/Courses/SciML/ncsu-sciml/hw3/data/uniform_mesh_2000.npz') as ds:
    xt_re, xt_bc, xt_ic = ds["xt_re"], ds["xt_bc"], ds["xt_ic"]


def loss_fn(model, xt_re=xt_re, xt_bc=xt_bc, xt_ic=xt_ic):
    """PINN loss: MSE of the PDE residual + MSE of BC misfit + MSE of IC misfit.

    The collocation arrays are bound as defaults when this cell runs, so the loss
    always uses this cell's dataset rather than whatever later cells leave in the
    global ``xt_re`` / ``xt_bc`` / ``xt_ic``.
    """
    residual = jax.vmap(loss_physics(model))(xt_re)
    # Targets get a trailing axis to match the model's per-point output
    # (assumed shape (OUTPUT_DIM,) = (1,); confirm against create_FNN).
    bc = jax.vmap(model)(xt_bc) - jax.vmap(g_BC)(xt_bc)[:, None]
    ic = jax.vmap(model)(xt_ic) - jax.vmap(g_IC)(xt_ic)[:, None]
    return jnp.mean(residual**2) + jnp.mean(bc**2) + jnp.mean(ic**2)


# Every run starts from the same SEED-derived key, so runs differ only in sampling.
key = jr.PRNGKey(SEED)
key, train_key = jr.split(key, num=2)
model = create_FNN(key=train_key, **HYPER_MODEL)

start_time = time.time()
m_2000_model, m_2000_log_loss, m_2000_log_minloss = training_MODEL(model, loss_fn, HYPER_OPTIM)
end_time = time.time()
# Keep the timing per run; start_time/end_time are overwritten by the next cell.
m_2000_train_time = end_time - start_time
print(f"Wall-clock training time: {m_2000_train_time:.1f} s")

Selected Optimizer is [ adam ], Initial Learning Rate is 4.00e-04
You are using [ constant ] learning rate
-------------------------TRAINING STARTS-------------------------
-----------------------------------------------------------------
Epoch    0: loss = 4.8741e-01, minloss = 4.8741e-01, Time: 0.74s
Epoch  100: loss = 3.6114e-01, minloss = 3.6114e-01, Time: 0.84s
Epoch  200: loss = 2.5299e-01, minloss = 2.5299e-01, Time: 0.85s
Epoch  300: loss = 1.4717e-01, minloss = 1.4717e-01, Time: 0.84s
Epoch  400: loss = 1.3121e-01, minloss = 1.3121e-01, Time: 0.84s
Epoch  500: loss = 1.1963e-01, minloss = 1.1963e-01, Time: 0.83s
Epoch  600: loss = 1.0412e-01, minloss = 1.0412e-01, Time: 0.83s
Epoch  700: loss = 9.5052e-02, minloss = 9.5052e-02, Time: 0.83s
Epoch  800: loss = 8.4791e-02, minloss = 8.4791e-02, Time: 0.83s
Epoch  900: loss = 6.0666e-02, minloss = 6.0666e-02, Time: 0.83s
Epoch 1000: loss = 5.2414e-02, minloss = 4.9868e-02, Time: 0.83s
Epoch 1100: loss = 4.6755e-02, minloss = 4.446

## 4000, MESH

In [9]:
# Load residual / boundary / initial collocation points. The `with` block closes
# the .npz file handle once the arrays have been read into memory.
# NOTE(review): absolute, user-specific path; consider a configurable DATA_DIR.
with np.load('/Users/haydenoutlaw/Documents/Courses/SciML/ncsu-sciml/hw3/data/uniform_mesh_4000.npz') as ds:
    xt_re, xt_bc, xt_ic = ds["xt_re"], ds["xt_bc"], ds["xt_ic"]


def loss_fn(model, xt_re=xt_re, xt_bc=xt_bc, xt_ic=xt_ic):
    """PINN loss: MSE of the PDE residual + MSE of BC misfit + MSE of IC misfit.

    The collocation arrays are bound as defaults when this cell runs, so the loss
    always uses this cell's dataset rather than whatever later cells leave in the
    global ``xt_re`` / ``xt_bc`` / ``xt_ic``.
    """
    residual = jax.vmap(loss_physics(model))(xt_re)
    # Targets get a trailing axis to match the model's per-point output
    # (assumed shape (OUTPUT_DIM,) = (1,); confirm against create_FNN).
    bc = jax.vmap(model)(xt_bc) - jax.vmap(g_BC)(xt_bc)[:, None]
    ic = jax.vmap(model)(xt_ic) - jax.vmap(g_IC)(xt_ic)[:, None]
    return jnp.mean(residual**2) + jnp.mean(bc**2) + jnp.mean(ic**2)


# Every run starts from the same SEED-derived key, so runs differ only in sampling.
key = jr.PRNGKey(SEED)
key, train_key = jr.split(key, num=2)
model = create_FNN(key=train_key, **HYPER_MODEL)

start_time = time.time()
m_4000_model, m_4000_log_loss, m_4000_log_minloss = training_MODEL(model, loss_fn, HYPER_OPTIM)
end_time = time.time()
# Keep the timing per run; start_time/end_time are overwritten by the next cell.
m_4000_train_time = end_time - start_time
print(f"Wall-clock training time: {m_4000_train_time:.1f} s")

Selected Optimizer is [ adam ], Initial Learning Rate is 4.00e-04
You are using [ constant ] learning rate
-------------------------TRAINING STARTS-------------------------
-----------------------------------------------------------------
Epoch    0: loss = 4.3274e-01, minloss = 4.3274e-01, Time: 0.76s
Epoch  100: loss = 3.2423e-01, minloss = 3.2423e-01, Time: 0.81s
Epoch  200: loss = 2.4405e-01, minloss = 2.4405e-01, Time: 0.80s
Epoch  300: loss = 1.3075e-01, minloss = 1.3075e-01, Time: 0.81s
Epoch  400: loss = 1.1369e-01, minloss = 1.1369e-01, Time: 0.83s
Epoch  500: loss = 9.9317e-02, minloss = 9.9317e-02, Time: 0.81s
Epoch  600: loss = 8.7867e-02, minloss = 8.7867e-02, Time: 0.80s
Epoch  700: loss = 8.0812e-02, minloss = 8.0812e-02, Time: 0.81s
Epoch  800: loss = 6.9593e-02, minloss = 6.9593e-02, Time: 0.81s
Epoch  900: loss = 5.3115e-02, minloss = 5.2555e-02, Time: 0.80s
Epoch 1000: loss = 5.2930e-02, minloss = 4.7188e-02, Time: 0.80s
Epoch 1100: loss = 4.3324e-02, minloss = 4.332

## 6000, MESH

In [10]:
# Load residual / boundary / initial collocation points. The `with` block closes
# the .npz file handle once the arrays have been read into memory.
# NOTE(review): absolute, user-specific path; consider a configurable DATA_DIR.
with np.load('/Users/haydenoutlaw/Documents/Courses/SciML/ncsu-sciml/hw3/data/uniform_mesh_6000.npz') as ds:
    xt_re, xt_bc, xt_ic = ds["xt_re"], ds["xt_bc"], ds["xt_ic"]


def loss_fn(model, xt_re=xt_re, xt_bc=xt_bc, xt_ic=xt_ic):
    """PINN loss: MSE of the PDE residual + MSE of BC misfit + MSE of IC misfit.

    The collocation arrays are bound as defaults when this cell runs, so the loss
    always uses this cell's dataset rather than whatever later cells leave in the
    global ``xt_re`` / ``xt_bc`` / ``xt_ic``.
    """
    residual = jax.vmap(loss_physics(model))(xt_re)
    # Targets get a trailing axis to match the model's per-point output
    # (assumed shape (OUTPUT_DIM,) = (1,); confirm against create_FNN).
    bc = jax.vmap(model)(xt_bc) - jax.vmap(g_BC)(xt_bc)[:, None]
    ic = jax.vmap(model)(xt_ic) - jax.vmap(g_IC)(xt_ic)[:, None]
    return jnp.mean(residual**2) + jnp.mean(bc**2) + jnp.mean(ic**2)


# Every run starts from the same SEED-derived key, so runs differ only in sampling.
key = jr.PRNGKey(SEED)
key, train_key = jr.split(key, num=2)
model = create_FNN(key=train_key, **HYPER_MODEL)

start_time = time.time()
m_6000_model, m_6000_log_loss, m_6000_log_minloss = training_MODEL(model, loss_fn, HYPER_OPTIM)
end_time = time.time()
# Keep the timing per run; start_time/end_time are overwritten by the next cell.
m_6000_train_time = end_time - start_time
print(f"Wall-clock training time: {m_6000_train_time:.1f} s")

Selected Optimizer is [ adam ], Initial Learning Rate is 4.00e-04
You are using [ constant ] learning rate
-------------------------TRAINING STARTS-------------------------
-----------------------------------------------------------------
Epoch    0: loss = 5.6671e-01, minloss = 5.6671e-01, Time: 0.81s
Epoch  100: loss = 3.9074e-01, minloss = 3.9074e-01, Time: 1.13s
Epoch  200: loss = 3.1506e-01, minloss = 3.1506e-01, Time: 1.16s
Epoch  300: loss = 1.8804e-01, minloss = 1.8804e-01, Time: 1.14s
Epoch  400: loss = 1.3898e-01, minloss = 1.3898e-01, Time: 1.17s
Epoch  500: loss = 1.1875e-01, minloss = 1.1875e-01, Time: 1.13s
Epoch  600: loss = 1.0232e-01, minloss = 1.0232e-01, Time: 1.12s
Epoch  700: loss = 9.3240e-02, minloss = 9.3240e-02, Time: 1.14s
Epoch  800: loss = 8.5934e-02, minloss = 8.5934e-02, Time: 1.13s
Epoch  900: loss = 7.4546e-02, minloss = 7.4322e-02, Time: 1.13s
Epoch 1000: loss = 4.9797e-02, minloss = 4.9797e-02, Time: 1.14s
Epoch 1100: loss = 3.9175e-02, minloss = 3.917

## 2000, LHS

In [11]:
# Load residual / boundary / initial collocation points. The `with` block closes
# the .npz file handle once the arrays have been read into memory.
# NOTE(review): absolute, user-specific path; consider a configurable DATA_DIR.
with np.load('/Users/haydenoutlaw/Documents/Courses/SciML/ncsu-sciml/hw3/data/LHS_2000.npz') as ds:
    xt_re, xt_bc, xt_ic = ds["xt_re"], ds["xt_bc"], ds["xt_ic"]


def loss_fn(model, xt_re=xt_re, xt_bc=xt_bc, xt_ic=xt_ic):
    """PINN loss: MSE of the PDE residual + MSE of BC misfit + MSE of IC misfit.

    The collocation arrays are bound as defaults when this cell runs, so the loss
    always uses this cell's dataset rather than whatever later cells leave in the
    global ``xt_re`` / ``xt_bc`` / ``xt_ic``.
    """
    residual = jax.vmap(loss_physics(model))(xt_re)
    # Targets get a trailing axis to match the model's per-point output
    # (assumed shape (OUTPUT_DIM,) = (1,); confirm against create_FNN).
    bc = jax.vmap(model)(xt_bc) - jax.vmap(g_BC)(xt_bc)[:, None]
    ic = jax.vmap(model)(xt_ic) - jax.vmap(g_IC)(xt_ic)[:, None]
    return jnp.mean(residual**2) + jnp.mean(bc**2) + jnp.mean(ic**2)


# Every run starts from the same SEED-derived key, so runs differ only in sampling.
key = jr.PRNGKey(SEED)
key, train_key = jr.split(key, num=2)
model = create_FNN(key=train_key, **HYPER_MODEL)

start_time = time.time()
l_2000_model, l_2000_log_loss, l_2000_log_minloss = training_MODEL(model, loss_fn, HYPER_OPTIM)
end_time = time.time()
# Keep the timing per run; start_time/end_time are overwritten by the next cell.
l_2000_train_time = end_time - start_time
print(f"Wall-clock training time: {l_2000_train_time:.1f} s")

Selected Optimizer is [ adam ], Initial Learning Rate is 4.00e-04
You are using [ constant ] learning rate
-------------------------TRAINING STARTS-------------------------
-----------------------------------------------------------------
Epoch    0: loss = 5.5695e-01, minloss = 5.5695e-01, Time: 0.80s
Epoch  100: loss = 3.9291e-01, minloss = 3.9291e-01, Time: 0.86s
Epoch  200: loss = 2.8823e-01, minloss = 2.8823e-01, Time: 0.85s
Epoch  300: loss = 1.5618e-01, minloss = 1.5618e-01, Time: 0.85s
Epoch  400: loss = 1.3546e-01, minloss = 1.3546e-01, Time: 0.85s
Epoch  500: loss = 1.2423e-01, minloss = 1.2423e-01, Time: 0.85s
Epoch  600: loss = 1.1104e-01, minloss = 1.1104e-01, Time: 0.85s
Epoch  700: loss = 9.7774e-02, minloss = 9.7774e-02, Time: 0.85s
Epoch  800: loss = 8.8640e-02, minloss = 8.8640e-02, Time: 0.85s
Epoch  900: loss = 7.3783e-02, minloss = 7.3783e-02, Time: 0.85s
Epoch 1000: loss = 4.8099e-02, minloss = 4.8099e-02, Time: 0.85s
Epoch 1100: loss = 3.5991e-02, minloss = 3.599

## 4000, LHS

In [12]:
# Load residual / boundary / initial collocation points. The `with` block closes
# the .npz file handle once the arrays have been read into memory.
# NOTE(review): absolute, user-specific path; consider a configurable DATA_DIR.
with np.load('/Users/haydenoutlaw/Documents/Courses/SciML/ncsu-sciml/hw3/data/LHS_4000.npz') as ds:
    xt_re, xt_bc, xt_ic = ds["xt_re"], ds["xt_bc"], ds["xt_ic"]


def loss_fn(model, xt_re=xt_re, xt_bc=xt_bc, xt_ic=xt_ic):
    """PINN loss: MSE of the PDE residual + MSE of BC misfit + MSE of IC misfit.

    The collocation arrays are bound as defaults when this cell runs, so the loss
    always uses this cell's dataset rather than whatever later cells leave in the
    global ``xt_re`` / ``xt_bc`` / ``xt_ic``.
    """
    residual = jax.vmap(loss_physics(model))(xt_re)
    # Targets get a trailing axis to match the model's per-point output
    # (assumed shape (OUTPUT_DIM,) = (1,); confirm against create_FNN).
    bc = jax.vmap(model)(xt_bc) - jax.vmap(g_BC)(xt_bc)[:, None]
    ic = jax.vmap(model)(xt_ic) - jax.vmap(g_IC)(xt_ic)[:, None]
    return jnp.mean(residual**2) + jnp.mean(bc**2) + jnp.mean(ic**2)


# Every run starts from the same SEED-derived key, so runs differ only in sampling.
key = jr.PRNGKey(SEED)
key, train_key = jr.split(key, num=2)
model = create_FNN(key=train_key, **HYPER_MODEL)

start_time = time.time()
l_4000_model, l_4000_log_loss, l_4000_log_minloss = training_MODEL(model, loss_fn, HYPER_OPTIM)
end_time = time.time()
# Keep the timing per run; start_time/end_time are overwritten by the next cell.
l_4000_train_time = end_time - start_time
print(f"Wall-clock training time: {l_4000_train_time:.1f} s")

Selected Optimizer is [ adam ], Initial Learning Rate is 4.00e-04
You are using [ constant ] learning rate
-------------------------TRAINING STARTS-------------------------
-----------------------------------------------------------------
Epoch    0: loss = 4.7875e-01, minloss = 4.7875e-01, Time: 1.29s
Epoch  100: loss = 3.3970e-01, minloss = 3.3970e-01, Time: 0.82s
Epoch  200: loss = 2.5497e-01, minloss = 2.5497e-01, Time: 0.83s
Epoch  300: loss = 1.3691e-01, minloss = 1.3691e-01, Time: 0.82s
Epoch  400: loss = 1.2278e-01, minloss = 1.2278e-01, Time: 0.83s
Epoch  500: loss = 1.0721e-01, minloss = 1.0721e-01, Time: 0.83s
Epoch  600: loss = 9.5762e-02, minloss = 9.5762e-02, Time: 0.87s
Epoch  700: loss = 8.8312e-02, minloss = 8.8312e-02, Time: 0.84s
Epoch  800: loss = 8.0403e-02, minloss = 8.0403e-02, Time: 0.83s
Epoch  900: loss = 6.7882e-02, minloss = 6.7882e-02, Time: 0.83s
Epoch 1000: loss = 5.2646e-02, minloss = 5.2646e-02, Time: 0.83s
Epoch 1100: loss = 4.3341e-02, minloss = 4.221

## 6000, LHS

In [13]:
# Load residual / boundary / initial collocation points. The `with` block closes
# the .npz file handle once the arrays have been read into memory.
# NOTE(review): absolute, user-specific path; consider a configurable DATA_DIR.
with np.load('/Users/haydenoutlaw/Documents/Courses/SciML/ncsu-sciml/hw3/data/LHS_6000.npz') as ds:
    xt_re, xt_bc, xt_ic = ds["xt_re"], ds["xt_bc"], ds["xt_ic"]


def loss_fn(model, xt_re=xt_re, xt_bc=xt_bc, xt_ic=xt_ic):
    """PINN loss: MSE of the PDE residual + MSE of BC misfit + MSE of IC misfit.

    The collocation arrays are bound as defaults when this cell runs, so the loss
    always uses this cell's dataset rather than whatever later cells leave in the
    global ``xt_re`` / ``xt_bc`` / ``xt_ic``.
    """
    residual = jax.vmap(loss_physics(model))(xt_re)
    # Targets get a trailing axis to match the model's per-point output
    # (assumed shape (OUTPUT_DIM,) = (1,); confirm against create_FNN).
    bc = jax.vmap(model)(xt_bc) - jax.vmap(g_BC)(xt_bc)[:, None]
    ic = jax.vmap(model)(xt_ic) - jax.vmap(g_IC)(xt_ic)[:, None]
    return jnp.mean(residual**2) + jnp.mean(bc**2) + jnp.mean(ic**2)


# Every run starts from the same SEED-derived key, so runs differ only in sampling.
key = jr.PRNGKey(SEED)
key, train_key = jr.split(key, num=2)
model = create_FNN(key=train_key, **HYPER_MODEL)

start_time = time.time()
l_6000_model, l_6000_log_loss, l_6000_log_minloss = training_MODEL(model, loss_fn, HYPER_OPTIM)
end_time = time.time()
# Keep the timing per run; start_time/end_time are overwritten by the next cell.
l_6000_train_time = end_time - start_time
print(f"Wall-clock training time: {l_6000_train_time:.1f} s")

Selected Optimizer is [ adam ], Initial Learning Rate is 4.00e-04
You are using [ constant ] learning rate
-------------------------TRAINING STARTS-------------------------
-----------------------------------------------------------------
Epoch    0: loss = 4.7854e-01, minloss = 4.7854e-01, Time: 0.87s
Epoch  100: loss = 3.4689e-01, minloss = 3.4689e-01, Time: 1.17s
Epoch  200: loss = 2.6812e-01, minloss = 2.6812e-01, Time: 1.15s
Epoch  300: loss = 1.4248e-01, minloss = 1.4248e-01, Time: 1.14s
Epoch  400: loss = 1.2680e-01, minloss = 1.2680e-01, Time: 1.16s
Epoch  500: loss = 1.1091e-01, minloss = 1.1091e-01, Time: 1.17s
Epoch  600: loss = 9.8112e-02, minloss = 9.8112e-02, Time: 1.15s
Epoch  700: loss = 8.5764e-02, minloss = 8.5764e-02, Time: 1.15s
Epoch  800: loss = 8.0413e-02, minloss = 8.0413e-02, Time: 1.15s
Epoch  900: loss = 7.3805e-02, minloss = 7.3805e-02, Time: 1.14s
Epoch 1000: loss = 5.6813e-02, minloss = 5.6813e-02, Time: 1.13s
Epoch 1100: loss = 4.1277e-02, minloss = 4.127

# EXPORT

In [14]:
# Export the loss and min-loss histories returned by training_MODEL for every run.
# Run prefixes: u_ = uniform sampling, m_ = uniform mesh sampling,
# l_ = Latin hypercube sampling; the number is the dataset size.
runs = {
    "u_2000": (u_2000_log_loss, u_2000_log_minloss),
    "u_4000": (u_4000_log_loss, u_4000_log_minloss),
    "u_6000": (u_6000_log_loss, u_6000_log_minloss),
    "m_2000": (m_2000_log_loss, m_2000_log_minloss),
    "m_4000": (m_4000_log_loss, m_4000_log_minloss),
    "m_6000": (m_6000_log_loss, m_6000_log_minloss),
    "l_2000": (l_2000_log_loss, l_2000_log_minloss),
    "l_4000": (l_4000_log_loss, l_4000_log_minloss),
    "l_6000": (l_6000_log_loss, l_6000_log_minloss),
}

# Keys become "<run>_log_loss" / "<run>_log_minloss", in run order.
export_data = {
    f"{run_name}_{kind}": np.array(history)
    for run_name, histories in runs.items()
    for kind, history in zip(("log_loss", "log_minloss"), histories)
}

# A second-resolution timestamp keeps separate exports from clobbering each other.
timestamp = time.strftime("%Y%m%d_%H%M%S")
export_filename = f"learning_trajectories_{timestamp}.npz"

np.savez(export_filename, **export_data)
print(f"✅ Saved all learning trajectories to {export_filename}")


✅ Saved all learning trajectories to learning_trajectories_20251015_102417.npz
