In [None]:
import jax
import jax.numpy as jnp
from jax import jit, vmap, grad
from jax import random
import optax
import numpy as np
import pickle
from scipy.stats import beta as beta_dist
# Shape parameters of the target Beta(alpha, beta) distribution; read as module
# globals by `predict` (median anchor) and `pi` (density). The name `beta` is free
# only because scipy's beta distribution is imported as `beta_dist`.
alpha, beta = 2.7, 3.9

In [None]:
def init_network_params(sizes, key=random.PRNGKey(4)):
    """Build random (W, b) parameters for a fully connected network.

    Args:
        sizes: layer widths, input first (e.g. [1, 10, 10, 10, 1]).
        key: JAX PRNG key. The default key is created once, when this
            function is defined, so repeated default calls return identical
            parameters.

    Returns:
        A list with one (W, b) tuple per consecutive pair in `sizes`, where W
        has shape (fan_out, fan_in) and b has shape (fan_out,). Every entry is
        drawn uniformly from [1, 3) and then scaled by 1e-2, so all initial
        parameters are strictly positive.
    """
    def _layer(fan_in, fan_out, layer_key, scale=1e-2):
        # Independent sub-keys for the weight matrix and the bias vector.
        w_key, b_key = random.split(layer_key)
        weights = random.uniform(w_key, (fan_out, fan_in), minval=1, maxval=3)
        biases = random.uniform(b_key, (fan_out,), minval=1, maxval=3)
        return scale * weights, scale * biases

    # One key per entry of `sizes`; the final key goes unused because there
    # are only len(sizes) - 1 layers.
    layer_keys = random.split(key, len(sizes))
    params = []
    for (fan_in, fan_out), layer_key in zip(zip(sizes[:-1], sizes[1:]), layer_keys):
        params.append(_layer(fan_in, fan_out, layer_key))
    return params

def predict_(NN, u):
    """Raw (un-anchored) forward pass of the MLP on a probability-level input.

    The input is flattened, mapped to logit space, passed through sigmoid
    hidden layers and then through a final affine layer.

    Args:
        NN: list of (W, b) pairs such as those built by init_network_params.
        u: array of values in (0, 1); flattened to 1-D before the first layer.

    Returns:
        The final-layer output with size-1 axes squeezed away (a scalar when
        the last layer has a single unit).
    """
    hidden_layers, (w_out, b_out) = NN[:-1], NN[-1]
    x = u.reshape((-1, ))
    # Logit transform; diverges to +/-inf at u == 1 / u == 0.
    x = jnp.log(x / (1 - x))
    for w, b in hidden_layers:
        x = jax.nn.sigmoid(jnp.dot(w, x) + b)
    return (jnp.dot(w_out, x) + b_out).squeeze()

def predict(NN, u):
    """Network output shifted so that u = 0.5 maps to the Beta median.

    The raw output at u = 0.5 is subtracted and the median of
    Beta(alpha, beta) (scipy `ppf(0.5)`, using the module-level `alpha` and
    `beta`) is added. As a result, predict(NN, 0.5) equals that median, up to
    rounding, for any parameters NN.
    """
    raw = predict_(NN, u)
    beta_median = beta_dist.ppf(0.5, alpha, beta)
    anchor = predict_(NN, np.array([0.5]))
    return raw + beta_median - anchor

# Batched forward pass: parameters are shared (None), u is mapped over axis 0.
vpredict = vmap(predict, (None, 0))
# Derivative of predict with respect to its second argument, u. grad needs a
# scalar output, which predict gives when the last layer has one unit.
gpredict = grad(predict, 1)
# Compiled, batched version of that derivative.
vgpredict = jit(vmap(gpredict, (None, 0)))

In [None]:
def pi(x):
    """Beta(alpha, beta) probability density at x (not the constant pi).

    Uses the module-level `alpha` and `beta`. The normalising constant
    B(alpha, beta) = Gamma(alpha) * Gamma(beta) / Gamma(alpha + beta) is
    evaluated by scipy on the Python-float parameters, so it is a plain
    constant when x is a JAX tracer.
    """
    from scipy.special import gamma as gamma_fn
    normaliser = (gamma_fn(alpha) * gamma_fn(beta)) / gamma_fn(alpha + beta)
    kernel = (x ** (alpha - 1)) * ((1 - x) ** (beta - 1))
    return kernel / normaliser
# Density evaluated element-wise over axis 0 of a batch of x values.
vpi = vmap(pi, in_axes=(0,))

@jit
def obj(NN, test_uz):
    """Mean squared residual of d predict / d u against 1 / pi(z).

    For each row (u, z) of `test_uz`, the derivative of the anchored network
    at u is compared with the reciprocal density at z, and the squared
    differences are averaged.

    Args:
        NN: network parameters.
        test_uz: 2-column array whose rows are (u, z) pairs (the training loop
            passes `train_uz` here).

    Returns:
        Scalar mean of (gpredict(NN, u) - 1 / pi(z)) ** 2 over the rows.
    """
    def residual_sq(params, pair):
        u, z = pair
        slope = gpredict(params, u)
        target = 1 / pi(z)
        return jnp.square(slope - target)

    per_row = vmap(residual_sq, in_axes=(None, 0))(NN, test_uz)
    return jnp.mean(per_row)

@jit
def constraint(NN, test_uz):
    """Monotonicity penalty: total shortfall of d predict / d u below epsilon.

    Only the u column (test_uz[:, 0]) is used. Each sample contributes
    max(epsilon - gpredict(NN, u), 0), so the result is 0 exactly when the
    derivative is at least epsilon at every u, and positive otherwise.

    Args:
        NN: network parameters.
        test_uz: array whose column 0 holds the u values (the training loop
            passes `train_uz` here).

    Returns:
        Scalar sum of the per-sample violations.
    """
    u_batch = test_uz[:, 0]
    # Small positive margin so a derivative of exactly 0 still counts as a violation.
    epsilon = 1e-12

    def g(NN, u):
        return jnp.maximum(-gpredict(NN, u)+epsilon, 0)

    vg = vmap(g, in_axes=(None, 0))
    return jnp.sum(vg(NN, u_batch))

# Compiled gradients of the objective and of the penalty with respect to the
# network parameters (argument 0).
gobj = jit(grad(obj, argnums=0))
gconst = jit(grad(constraint, argnums=0))

@jit
def evaluation(NN, test_uz):
    """Mean squared error between predict(NN, u) and z.

    Args:
        NN: network parameters.
        test_uz: 2-column array; column 0 holds u, column 1 the matching z.

    Returns:
        Scalar mean of (predict(NN, u) - z) ** 2 over the rows.
    """
    u_col = test_uz[:, 0]
    z_col = test_uz[:, 1]
    errors = vpredict(NN, u_col) - z_col
    return jnp.mean(jnp.square(errors))

In [None]:
# def gen_uz_batch(N=10000):
#     import jax.numpy as jnp
#     import jax.random as random
#     from jax.scipy.stats import beta

#     key = random.PRNGKey(0)
#     a = 2.7  # Alpha parameter
#     b = 3.9  # Beta parameter
#     z = random.beta(key, a=a, b=b, shape=(N,))
#     u = beta.cdf(z, a=a, b=b)
#     uz_batch = jnp.stack((u, z), axis=-1)
#     return uz_batch
# train_uz = gen_uz_batch(N=10000)
# np.save('/content/drive/MyDrive/No. 22 Physica D (inverse)/Beta checkpoints (PINN)/train_uz.npy', train_uz)

In [19]:
# Train the network: Adam on the residual `obj`. Before every objective step,
# repair steps on the monotonicity penalty `constraint` run until it is exactly zero.
# Loss histories and parameters are checkpointed under CKPT_DIR every LOG_EVERY epochs.
LAYER_SIZES = [1, 10, 10, 10, 1]
NN = init_network_params(LAYER_SIZES, random.PRNGKey(15))

CKPT_DIR = '/content/drive/MyDrive/No. 22 Physica D (inverse)/Beta checkpoints (PINN)'
N_EPOCHS = 2000000
LOG_EVERY = 10000
# Upper bound on repair steps per epoch; without it an unsatisfiable constraint
# would spin the `while` loop below forever.
MAX_CONSTRAINT_STEPS = 100000

optimizer = optax.adam(0.001)
# NOTE(review): a single Adam state is shared by the constraint and objective
# updates, so constraint gradients also enter the objective's moment estimates.
opt_state = optimizer.init(NN)
bug = 0  # never incremented in this cell; kept so the log line keeps its format
co = 0   # cumulative number of constraint-repair steps
test_uz = np.load(f'{CKPT_DIR}/test_uz.npy')
train_uz = np.load(f'{CKPT_DIR}/train_uz.npy')
lowest = evaluation(NN, test_uz)  # best test MSE seen so far
Lloss = []    # training objective after each epoch's update
Llowest = []  # best test MSE after each epoch
# apply_updates returns new arrays, so holding a reference is a safe snapshot.
NN_best = NN
for epoch in range(N_EPOCHS):
    repair_steps = 0
    while constraint(NN, train_uz).item() > 0:
        if repair_steps >= MAX_CONSTRAINT_STEPS:
            raise RuntimeError(
                f'constraint still violated after {MAX_CONSTRAINT_STEPS} '
                f'repair steps in epoch {epoch}')
        grads = gconst(NN, train_uz)
        updates, opt_state = optimizer.update(grads, opt_state)
        NN = optax.apply_updates(NN, updates)
        co += 1
        repair_steps += 1

    grads = gobj(NN, train_uz)
    updates, opt_state = optimizer.update(grads, opt_state)
    NN = optax.apply_updates(NN, updates)

    # Evaluate each metric once per epoch and reuse the values below.
    loss = obj(NN, train_uz).item()
    current = evaluation(NN, test_uz)
    if current < lowest:
        lowest = current
        NN_best = NN
    Lloss.append(loss)
    # Appended after the best-model check so Lloss[i] and Llowest[i] describe the same epoch.
    Llowest.append(lowest.item())

    if epoch % LOG_EVERY == 0:
        print(f"Epoch: {epoch}, loss:{loss:.7f}, const:{constraint(NN, train_uz).item():.4f}, lowest:{lowest:.4f}, co:{co}, bug:{bug}")
        np.save(f'{CKPT_DIR}/Lloss.npy', Lloss)
        np.save(f'{CKPT_DIR}/Llowest.npy', Llowest)
        # Checkpoints are pickles: only unpickle them from a location you trust.
        with open(f'{CKPT_DIR}/NN (best).pkl', 'wb') as f:
            pickle.dump(NN_best, f)
        with open(f'{CKPT_DIR}/NN_{epoch}.pkl', 'wb') as f:
            pickle.dump(NN, f)

Epoch: 0, loss:7.2738004, const:0.0000, lowest:0.0314, co:0, bug:0
Epoch: 10000, loss:0.0000060, const:0.0000, lowest:0.0000, co:0, bug:0
Epoch: 20000, loss:0.0000007, const:0.0000, lowest:0.0000, co:0, bug:0
Epoch: 30000, loss:0.0000003, const:0.0000, lowest:0.0000, co:0, bug:0
Epoch: 40000, loss:0.0000001, const:0.0000, lowest:0.0000, co:0, bug:0
Epoch: 50000, loss:0.0000001, const:0.0000, lowest:0.0000, co:0, bug:0
Epoch: 60000, loss:0.0000001, const:0.0000, lowest:0.0000, co:0, bug:0
Epoch: 70000, loss:0.0000001, const:0.0000, lowest:0.0000, co:0, bug:0
Epoch: 80000, loss:0.0000001, const:0.0000, lowest:0.0000, co:0, bug:0
Epoch: 90000, loss:0.0000001, const:0.0000, lowest:0.0000, co:0, bug:0
Epoch: 100000, loss:0.0000001, const:0.0000, lowest:0.0000, co:0, bug:0
Epoch: 110000, loss:0.0000009, const:0.0000, lowest:0.0000, co:0, bug:0
Epoch: 120000, loss:0.0000001, const:0.0000, lowest:0.0000, co:0, bug:0
Epoch: 130000, loss:0.0000038, const:0.0000, lowest:0.0000, co:0, bug:0
Epoch: