In [1]:
import numpy as np
import matplotlib.pyplot as plt

# Open the archive once and close it deterministically: np.load on an .npz
# returns an NpzFile that keeps the underlying file handle open until closed.
# Indexing the archive reads each array fully into memory, so the arrays
# remain valid after the `with` block exits.
with np.load("./data/sampling_points.npz") as sampling_data:
    coll_points = sampling_data["coll_points"]
    bc1_points = sampling_data["bc1_points"]
    bc2_points = sampling_data["bc2_points"]
    bc3_points = sampling_data["bc3_points"]
    bc4_points = sampling_data["bc4_points"]

In [2]:
import jax.numpy as jnp

# Convert every sampling-point set to a JAX array in one pass; the order of
# the names on the left must match the order of the tuple on the right.
coll_points, bc1_points, bc2_points, bc3_points, bc4_points = (
    jnp.array(points)
    for points in (coll_points, bc1_points, bc2_points, bc3_points, bc4_points)
)

In [6]:
import jax
import jax.numpy as jnp
from jax import random, jit, grad, value_and_grad
import optax
from functools import partial
import time

# ------------------------------
# Model definition (uses the previously written mlp_jax module)
import mlp_jax

# Network layout: 2 inputs, two hidden layers of 5 units, 1 output.
layer_sizes = [2, 5, 5, 1]

# Fixed PRNG seed so parameter initialisation is reproducible.
key = random.PRNGKey(42)
params = mlp_jax.init_mlp_params(layer_sizes, key)

# ------------------------------
# Loss function definitions (PINN example)

# PDE loss example: Laplace equation
def pde_residual(params, x):
    """Return the Laplace residual u_xx + u_yy of the network at one point.

    Args:
        params: Network parameters, passed through to ``mlp_jax.mlp_forward``.
        x: A single input point; components 0 and 1 are treated as the
            x and y coordinates.

    Returns:
        The sum of the two unmixed second derivatives of the network output
        with respect to the input coordinates, evaluated at ``x``.
    """
    def network(point):
        # jax.grad needs a scalar output, hence the squeeze.
        return mlp_jax.mlp_forward(params, point, activation="tanh").squeeze()

    first_derivatives = jax.grad(network)

    def du_dx(point):
        return first_derivatives(point)[0]

    def du_dy(point):
        return first_derivatives(point)[1]

    # Differentiate each first derivative again and keep only the matching
    # component, i.e. d/dx(du/dx) and d/dy(du/dy).
    d2u_dx2 = jax.grad(du_dx)(x)[0]
    d2u_dy2 = jax.grad(du_dy)(x)[1]

    return d2u_dx2 + d2u_dy2

# Loss terms (adjust the targets to match each boundary condition)
@jit
def total_loss(params, coll_points, bc1, bc2, bc3, bc4):
    """Return the unweighted sum of the PDE loss and four boundary losses.

    Args:
        params: Network parameters passed to ``mlp_jax.mlp_forward``.
        coll_points: Batch of collocation points; ``pde_residual`` is mapped
            over its leading axis.
        bc1, bc2, bc3, bc4: Batches of boundary points. The network output is
            driven towards 0 on ``bc1``, ``bc3`` and ``bc4`` and towards 1 on
            ``bc2``.

    Returns:
        Scalar loss: the mean squared PDE residual plus the mean squared
        boundary errors.
    """
    def predict(points):
        # Use the same activation as pde_residual so every loss term
        # evaluates the same network, regardless of mlp_forward's default.
        return mlp_jax.mlp_forward(params, points, activation="tanh")

    pde_loss = jnp.mean(jax.vmap(lambda x: pde_residual(params, x)**2)(coll_points))
    bc1_loss = jnp.mean((predict(bc1) - 0)**2)
    bc2_loss = jnp.mean((predict(bc2) - 1)**2)
    bc3_loss = jnp.mean((predict(bc3) - 0)**2)
    bc4_loss = jnp.mean((predict(bc4) - 0)**2)

    return pde_loss + bc1_loss + bc2_loss + bc3_loss + bc4_loss

# ------------------------------
# Optimizer setup (Optax is the recommended choice)

learning_rate = 1e-3

# Adam optimizer; its state is initialised from the current parameters.
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)

# ------------------------------
# Training step definition (JIT compiled)

@jit
def train_step(params, opt_state, coll_points, bc1, bc2, bc3, bc4):
    """Run one optimisation step and return the updated training state.

    Args:
        params: Current network parameters.
        opt_state: Current state of the module-level ``optimizer``.
        coll_points, bc1, bc2, bc3, bc4: Point sets forwarded unchanged to
            ``total_loss``.

    Returns:
        Tuple ``(params, opt_state, loss)`` with the updated parameters, the
        updated optimizer state, and the loss evaluated at the parameters
        that were passed in (i.e. before the update).
    """
    loss_and_grad_fn = value_and_grad(total_loss)
    current_loss, param_grads = loss_and_grad_fn(
        params, coll_points, bc1, bc2, bc3, bc4
    )
    param_updates, next_opt_state = optimizer.update(param_grads, opt_state)
    next_params = optax.apply_updates(params, param_updates)
    return next_params, next_opt_state, current_loss

# ------------------------------
# Run the training loop

epochs = 20000
loss_history = []

start_time = time.time()
print("Training start...")

for epoch in range(1, epochs + 1):
    params, opt_state, loss_val = train_step(
        params, opt_state, coll_points, bc1_points, bc2_points, bc3_points, bc4_points
    )

    # Keep the loss as a device scalar here. Calling .item() on every step
    # would force a device-to-host sync per iteration and stall JAX's
    # asynchronous dispatch.
    loss_history.append(loss_val)

    if epoch % 100 == 0:
        print(f"Epoch {epoch:05d}: Total Loss = {loss_val:.2e}")

# Fetch all recorded losses in one go and turn them into plain Python floats.
# This also waits for every queued step to finish, so the elapsed time below
# covers the actual computation.
loss_history = [float(value) for value in jax.device_get(loss_history)]

end_time = time.time()
elapsed_time = end_time - start_time
print("Training complete!\n")
print(f"Total training time: {elapsed_time:.2f} seconds")

# ------------------------------
# Save results (loss history is stored with numpy)
import os
import numpy as np

# Neither np.save nor open() creates missing directories; make sure the
# output folder exists so a finished training run is not lost to a
# FileNotFoundError.
os.makedirs("./results", exist_ok=True)
np.save("./results/loss_data_jax.npy", np.array(loss_history))

# Save the model parameters
import pickle
with open("./results/model_jax.pkl", "wb") as f:
    pickle.dump(params, f)

Training start...
Epoch 00100: Total Loss = 4.11e+00
Epoch 00200: Total Loss = 1.06e+00
Epoch 00300: Total Loss = 7.75e-01
Epoch 00400: Total Loss = 6.91e-01
Epoch 00500: Total Loss = 6.54e-01
Epoch 00600: Total Loss = 6.30e-01
Epoch 00700: Total Loss = 6.10e-01
Epoch 00800: Total Loss = 5.91e-01
Epoch 00900: Total Loss = 5.72e-01
Epoch 01000: Total Loss = 5.52e-01
Epoch 01100: Total Loss = 5.33e-01
Epoch 01200: Total Loss = 5.14e-01
Epoch 01300: Total Loss = 4.96e-01
Epoch 01400: Total Loss = 4.80e-01
Epoch 01500: Total Loss = 4.66e-01
Epoch 01600: Total Loss = 4.54e-01
Epoch 01700: Total Loss = 4.43e-01
Epoch 01800: Total Loss = 4.33e-01
Epoch 01900: Total Loss = 4.25e-01
Epoch 02000: Total Loss = 4.17e-01
Epoch 02100: Total Loss = 4.09e-01
Epoch 02200: Total Loss = 4.02e-01
Epoch 02300: Total Loss = 3.96e-01
Epoch 02400: Total Loss = 3.90e-01
Epoch 02500: Total Loss = 3.85e-01
Epoch 02600: Total Loss = 3.80e-01
Epoch 02700: Total Loss = 3.75e-01
Epoch 02800: Total Loss = 3.71e-01
Ep