In [1]:
import jax
import jax.numpy as jnp
from jax import jit, vmap, grad
from jax import random
import optax
import numpy as np
import pickle
import matplotlib.pyplot as plt

In [4]:
test_uz = np.load('/content/drive/MyDrive/No. 22 Physica D (inverse)/Abstract checkpoints (TM)/test_uz.npy')
test_zu = test_uz[:, [1, 0]]
np.save('/content/drive/MyDrive/No. 22 Physica D (forward)/Abstract checkpoints (PINN)/test_zu.npy', test_zu)

In [5]:
def init_network_params(sizes, key=random.PRNGKey(4)):
    def random_layer_params(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        return scale * random.uniform(w_key, (n, m), minval=1, maxval=3), scale * random.uniform(b_key, (n,), minval=1, maxval=3)

    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

def predict(NN, x):
    activations = x.reshape((-1, ))
    for w, b in NN[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = jax.nn.sigmoid(outputs)
    final_w, final_b = NN[-1]
    logit = jnp.dot(final_w, activations) + final_b
    logit = 1/(1+jnp.exp(-logit))
    return logit.squeeze()

vpredict = vmap(predict, in_axes=(None, 0))
gpredict = grad(predict, argnums=1)
vgpredict = vmap(gpredict, in_axes=(None, 0))

In [6]:
def pi(x):
    result = jnp.exp(-x**4)
    result = result/1.812804958472487
    result = jnp.where(result==0, 1e-10, result)
    return result
vpi = vmap(pi, in_axes=(0,))

def pi_pre(NN, x):
    result = gpredict(NN, x)
    return result

@jit
def obj(NN, x_batch):
    def obj1(NN, x):
        return pi(x) - pi_pre(NN, x)
    vobj1 = vmap(obj1, in_axes=(None, 0))
    result = vobj1(NN, x_batch)
    result = jnp.square(result)
    result = jnp.mean(result)
    return result
jgo = jit(grad(obj))

@jit
def evaluation(NN, test_zu):
    z_batch, u_batch = test_zu[:, 0], test_zu[:, 1]
    result = jnp.square(vpredict(NN, z_batch) - u_batch)
    result = jnp.mean(result)
    return result

In [None]:
LAYER_SIZES = [1, 10, 10, 10, 1]
NN = init_network_params(LAYER_SIZES, random.PRNGKey(15))

optimizer= optax.adam(0.001)
opt_state = optimizer.init(NN)
test_zu = np.load('/content/drive/MyDrive/No. 22 Physica D (forward)/Abstract checkpoints (PINN)/test_zu.npy')
lowest = evaluation(NN, test_zu)
NN_best = NN
Lloss = []
Llowest = []
for epoch in range(2000000):
    x_batch = np.random.uniform(-1.5, 1.5, (10000, 1)) #1000 refers to batch size

    grads = grad(obj)(NN, x_batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    NN = optax.apply_updates(NN, updates)

    Lloss.append(obj(NN, x_batch).item())
    Llowest.append(lowest.item())
    print(f"Epoch: {epoch}, loss:{obj(NN, x_batch)}, {lowest}")

    if  evaluation(NN, test_zu)<lowest:
        lowest = evaluation(NN, test_zu)
        NN_best = NN

    if epoch % 10000 == 0:
        np.save(f'/content/drive/MyDrive/No. 22 Physica D (forward)/Abstract checkpoints (PINN)/Lloss.npy', Lloss)
        np.save(f'/content/drive/MyDrive/No. 22 Physica D (forward)/Abstract checkpoints (PINN)/Llowest.npy', Llowest)
        with open(f'/content/drive/MyDrive/No. 22 Physica D (forward)/Abstract checkpoints (PINN)/NN (best).pkl', 'wb') as f:
            pickle.dump(NN_best, f)
        with open(f'/content/drive/MyDrive/No. 22 Physica D (forward)/Abstract checkpoints (PINN)/NN_{epoch}.pkl', 'wb') as f:
            pickle.dump(NN, f)

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
Epoch: 1850647, loss:2.86108331692958e-07, 5.494405819206349e-09
Epoch: 1850648, loss:2.3684664540724043e-07, 5.494405819206349e-09
Epoch: 1850649, loss:1.7149640996194648e-07, 5.494405819206349e-09
Epoch: 1850650, loss:1.0188622212581322e-07, 5.494405819206349e-09
Epoch: 1850651, loss:5.267450831070164e-08, 5.494405819206349e-09
Epoch: 1850652, loss:1.9767192327435623e-08, 5.494405819206349e-09
Epoch: 1850653, loss:6.40966968390444e-09, 5.494405819206349e-09
Epoch: 1850654, loss:8.271054063868633e-09, 5.494405819206349e-09
Epoch: 1850655, loss:1.960720119598136e-08, 5.494405819206349e-09
Epoch: 1850656, loss:3.2617595735473515e-08, 5.494405819206349e-09
Epoch: 1850657, loss:4.546004106487089e-08, 5.494405819206349e-09
Epoch: 1850658, loss:4.512491358354964e-08, 5.494405819206349e-09
Epoch: 1850659, loss:4.570227218891887e-08, 5.494405819206349e-09
Epoch: 1850660, loss:4.0665749878598945e-08, 5.494405819206349e-09
Epoch: 1850661, loss:3.67445061