In [None]:
import os 
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

import sys
sys.path.append("..")

from make_dir import mkdir
from load_yaml import get_yaml

import models.uq_parity_net as solutions
import equations.uq_parity_eqn as equation

from uq_parity_dataset import Sampler
import parity_solver as solver 

import matplotlib.pyplot as plt
import time


In [None]:
# load config
current_path = os.path.abspath(".")
yaml_path = os.path.join(current_path, "uq_parity.yaml")
Config = get_yaml(yaml_path)

# load reference data
absolute_path = os.path.abspath("..")
ref_path = os.path.join(absolute_path, "data/uq_simulation.npz")
# For .npz archives np.load returns an NpzFile that keeps the underlying file
# open until it is closed; the context manager releases it once
# "macro_frames" has been read. The mean is taken over axis 0 -- presumably
# the random (UQ) sample axis; confirm against the generator of the .npz file.
with np.load(ref_path) as ref_data:
    rho_mean = torch.mean(torch.Tensor(ref_data["macro_frames"]), dim=0)

# last row of rho_mean as a column vector; passed as the reference to the
# training step to compute the density error
ref_rho = rho_mean[-1, :].to('cpu').reshape((-1, 1)) # shape: (100, 1)


In [None]:
# input sizes of the three networks, read from the YAML config:
# rho sees time, space and uq coordinates; r and j additionally see velocity
time_dimension = Config["physical_config"]["time_dimension"]
space_dimension = Config["physical_config"]["space_dimension"]
velocity_dimension = Config["physical_config"]["velocity_dimension"]
uq_dimension = Config["physical_config"]["uq_dimension"]
rho_d_in = time_dimension + space_dimension + uq_dimension
layers_rho = Config["model_config"]["units_rho"]
rj_d_in = time_dimension + space_dimension + velocity_dimension  + uq_dimension
layers_r = Config["model_config"]["units_r"]
layers_j = Config["model_config"]["units_j"]

# build neural networks for rho, r and j.
# The classes are looked up by name on the `solutions` module instead of
# eval()-ing a string assembled from the config: identical result for a valid
# `neural_network_type`, a clear AttributeError for an invalid one, and no way
# for the YAML file to inject arbitrary code.
network_type = Config["model_config"]["neural_network_type"]
Model_rho = getattr(solutions, "Model_rho_{}".format(network_type))
Model_r = getattr(solutions, "Model_r_{}".format(network_type))
Model_j = getattr(solutions, "Model_j_{}".format(network_type))

model_rho = Model_rho(input_size = rho_d_in, layers = layers_rho, output_size = 1)
model_r = Model_r(input_size = rj_d_in, layers = layers_r, output_size = 1)
model_j = Model_j(input_size = rj_d_in, layers = layers_j, output_size = 1)

# first configured GPU when CUDA is available, otherwise the CPU
device_ids = Config["model_config"]["device_ids"]
device = torch.device("cuda:{:d}".format(device_ids[0]) if torch.cuda.is_available() else "cpu")

# wrap in DataParallel over the configured device ids when several GPUs are visible
if torch.cuda.device_count() > 1:
    model_rho = nn.DataParallel(model_rho, device_ids = device_ids)
    model_r = nn.DataParallel(model_r, device_ids = device_ids)
    model_j = nn.DataParallel(model_j, device_ids = device_ids)

model_rho.to(device)
model_r.to(device)
model_j.to(device)


In [None]:
# number of trainable tensors' elements in each of the three networks
rho_param_num, r_param_num, j_param_num = (
    sum(param.numel() for param in net.parameters())
    for net in (model_rho, model_r, model_j)
)
print("Number of parameters for networks rho, r and j is: {:6d}, {:6d} and {:6d}. ".format(
    rho_param_num, r_param_num, j_param_num))


In [None]:
# Apply the Xavier initialisation helper from models.uq_parity_net to every network.
for net in (model_rho, model_r, model_j):
    solutions.Xavier_initi(net)


In [None]:
# One Adam optimizer trains all three networks; each network is its own
# parameter group, all sharing the configured learning rate.
model_config = Config["model_config"]
optimizer = optim.Adam(
    [{'params': net.parameters()} for net in (model_rho, model_r, model_j)],
    lr=model_config["lr"],
)

# Learning-rate decay: multiply by `decay_rate` every `stage_num` scheduler steps.
scheduler = lr_scheduler.StepLR(
    optimizer, step_size=model_config["stage_num"], gamma=model_config["decay_rate"])


In [None]:
Sol = model_rho, model_r, model_j
eqn = equation.Parity(config = Config, sol = Sol)

Iter = Config["model_config"]["iterations"]
regularizers = Config["model_config"]["regularizers"]

# Per-iteration histories are collected in lists and stacked once after the
# loop; concatenating onto a NumPy array inside the loop re-copied the whole
# history every iteration (quadratic in the number of iterations).
loss_history, error_history = [], []

mkdir(file_dir = "./model_saved")
mkdir(file_dir = "./record")
mkdir(file_dir = "./figure")

time_start = time.time()
print('Begin training.')
print('')
for it in range(Iter):

    # a new Sampler is built every iteration (presumably to draw fresh training points)
    sampler = Sampler(Config)
    trainloader = [sampler.interior(), sampler.boundary(), sampler.initial()]

    risk, errors = solver.train_step(sol = Sol,
                                     trainloader = trainloader,
                                     equation = eqn,
                                     regularizers = regularizers,
                                     optimizer = optimizer,
                                     scheduler = scheduler,
                                     ref = ref_rho)

    loss = risk["total_loss"]
    res_parity_1_eqn = risk["parity_1"]
    res_parity_2_eqn = risk["parity_2"]
    res_claw_eqn = risk["conservation"]
    res_constraint_eqn = risk["soft_constraint"]
    res_bc_f = risk["bc_f"]
    res_ic_rho = risk["ic_rho"]
    res_ic_f = risk["ic_f"]

    # one row per iteration
    error = np.array(errors["error"], dtype=float).reshape(1, -1)
    loss_history.append(float(loss))
    error_history.append(error)

    # current learning rate (read directly; building state_dict() every step is wasted work)
    lr = optimizer.param_groups[0]['lr']

    if it % 10 == 0:

        print("[Iter: {:6d}/{:6d} - lr: {:.2e} and Loss: {:.2e}]".format(it + 1, Iter, lr, loss))
        # error[0, 0]: converting an ndim>0 array with float() is deprecated in NumPy >= 1.25
        print("[Error for density: {:.2e}]".format(float(error[0, 0])))
        print("[Eqn parity_1: {:.2e}, parity_2: {:.2e}, claw: {:.2e}, constraint: {:.2e}]".format(res_parity_1_eqn, res_parity_2_eqn, res_claw_eqn, res_constraint_eqn))
        print("[Boundary: {:.2e}, Initial - rho {:.2e}, f {:.2e}]".format(res_bc_f, res_ic_rho, res_ic_f))

    # early stop once the density error drops below the tolerance
    if np.max(error) < 5e-2:
        print("Iteration step: ", it)
        break

# loss_record: (n_iter, 1); error_record: (n_iter, n_err)
loss_record = np.array(loss_history, dtype=float).reshape(-1, 1)
error_record = np.concatenate(error_history, axis=0) if error_history else np.empty((0, 1))

np.savez("./record/result.npz",
         loss=loss_record,
         error=error_record[:, 0])

solutions.save_param(model_rho, path = './model_saved/model_rho_params.pkl')
solutions.save_param(model_r, path = './model_saved/model_r_params.pkl')
solutions.save_param(model_j, path = './model_saved/model_j_params.pkl')

print("")
print("Finished training.")
time_end = time.time()
print("Total time is: {:.2e}".format(time_end - time_start), "seconds")


In [None]:
# Reload the network weights from the files written by the training cell.
param_files = (
    (model_rho, './model_saved/model_rho_params.pkl'),
    (model_r, './model_saved/model_r_params.pkl'),
    (model_j, './model_saved/model_j_params.pkl'),
)
for net, param_path in param_files:
    solutions.load_param(net, param_path)

# bundle the loaded networks
Sol = [model_rho, model_r, model_j]



In [None]:
# unpack the three loaded networks from Sol
[model_rho, model_r, model_j] = Sol

In [None]:
# reference densities: last row and 501st-from-last row of rho_mean, plotted
# against the model evaluated at t_max and t_max / 2 respectively
ref_rho_1 = rho_mean[-1, :].to('cpu').reshape((-1, 1)) # shape: (100, 1)
ref_rho_2 = rho_mean[-501, :].to('cpu').reshape((-1, 1)) # shape: (100, 1)

num_sim = 10**3   # samples of z per spatial point
nx = 100          # number of spatial grid points
eval_seed = 0     # fixes the sampled z so the figure is reproducible
uq_dimension = Config["physical_config"]["uq_dimension"]
xmin, xmax = Config["physical_config"]["x_range"]
zmin, zmax = Config["physical_config"]["z_range"]
tmax = Config["physical_config"]["t_range"][-1]
ref_x = (torch.linspace(xmin, xmax, nx).reshape(-1, 1)).to(device)
ref_xx = ref_x.to("cpu")

ref_t_1 = tmax * torch.ones((nx, 1)).to(device)
ref_t_2 = tmax * torch.ones((nx, 1)).to(device) / 2

# z ~ U[zmin, zmax], drawn from a private generator so the global torch RNG
# state is left untouched
z_generator = torch.Generator().manual_seed(eval_seed)
ref_z = zmin + torch.rand((nx, num_sim, uq_dimension), generator=z_generator).to(device) * (zmax - zmin)
ref_x = ref_x[:,None,:] * torch.ones((ref_z.shape[0], ref_z.shape[1], 1)).to(device)

ref_t1 = ref_t_1[:,None,:] * torch.ones((ref_z.shape[0], ref_z.shape[1], 1)).to(device)
ref_t2 = ref_t_2[:,None,:] * torch.ones((ref_z.shape[0], ref_z.shape[1], 1)).to(device)


def mean_density(model, z, t, x):
    """Average t * model(cat[z, t, x]) over the sample axis (-2).

    z, t, x are (nx, num_sim, ·) tensors concatenated on the last axis in the
    order z, t, x. Returns a (nx, 1) NumPy array on the CPU. Runs under
    torch.no_grad(): this is inference only, so no autograd graph is kept for
    the nx * num_sim evaluation points.
    """
    inputs = torch.cat([z, t, x], -1)
    with torch.no_grad():
        values = t * model(inputs)
    return torch.mean(values.to("cpu"), dim=-2).numpy()


density_approx_1 = mean_density(model_rho, ref_z, ref_t1, ref_x) # shape: (nx, 1)
density_approx_2 = mean_density(model_rho, ref_z, ref_t2, ref_x) # shape: (nx, 1)

# "seaborn-dark" was renamed "seaborn-v0_8-dark" in Matplotlib 3.6 and the old
# name was later removed; use whichever this installation provides.
plot_style = "seaborn-v0_8-dark" if "seaborn-v0_8-dark" in plt.style.available else "seaborn-dark"
plt.style.use(plot_style)
fig = plt.figure()
# model curves skip the first two grid points (reason not documented here)
plt.plot(ref_xx, ref_rho_2, color = 'b', linewidth = 1.0, markersize = 5, label = 'Ref(t = 0.05)')
plt.plot(ref_xx[2:], density_approx_2[2:], color = 'g', marker = 'x', linewidth = 0.0, markersize = 8, markevery = 4, label = 'APNNs(t = 0.05)')
plt.plot(ref_xx, ref_rho_1, color = 'r', linewidth = 1.0, markersize = 5, label = 'Ref(t = 0.1)')
plt.plot(ref_xx[2:], density_approx_1[2:], color = 'k', marker = '*', linewidth = 0.0, markersize = 8, markevery = 4, label = 'APNNs(t = 0.1)')
plt.xlabel(r"x")
plt.ylabel(r"$\rho$")
plt.grid()
plt.legend()
plt.savefig('./figure/uq_1e-5.eps')
plt.savefig('./figure/uq_1e-5.pdf')
plt.show()

In [None]:
# # jupyter notebook to python
# try:   
#     !jupyter nbconvert --to python uq_parity_example.ipynb
# except:
#     pass