## Parte 1: ecuación Eikonal (tutorial de PINNs)

Supongamos que estamos analizando una región 2D de tejido cardíaco. Algunas de sus células actúan como marcapasos y originan pulsos eléctricos que se propagan hacia el resto del tejido. Disponemos de múltiples electrodos en el dominio de observación, que permiten determinar el instante en que la onda llega a cada uno de ellos.

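Los tiempos de llegada normalizados $T(\mathbf{x})$ (con velocidad de propagación igual a 1) pueden modelarse mediante la ecuación Eikonal:

$$\sqrt{\nabla T \cdot \nabla T} = \lVert \nabla T \rVert = 1$$

Para una única fuente puntual ubicada en $\mathbf{x}_0$ (con $T(\mathbf{x}_0) = 0$), la solución es la distancia euclidiana $T(\mathbf{x}) = \lVert \mathbf{x} - \mathbf{x}_0 \rVert$. Esta es la función que se usa más abajo para generar los datos sintéticos.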




In [None]:
import jax.numpy as np
from jax import random, grad, jit, vmap
import jax
from jax.scipy.optimize import minimize
import numpy as onp
from jax.example_libraries import optimizers
from functools import partial
from matplotlib import pyplot as plt
from pyDOE2 import lhs
key = random.PRNGKey(0) # root JAX PRNG key; passed to init_params, which splits it to draw the initial weights

In [None]:
def init_params(layers, key):
  """Create Glorot-normal weights and zero biases for a fully-connected net.

  Args:
    layers: sequence of layer widths, e.g. [2, 100, 50, 20, 1].
    key: JAX PRNG key; it is split once per layer to draw that layer's weights.

  Returns:
    A two-element list ``[Ws, bs]`` where ``Ws[i]`` has shape
    ``(layers[i], layers[i + 1])`` and ``bs[i]`` has shape ``(layers[i + 1],)``.
  """
  weights, biases = [], []
  for fan_in, fan_out in zip(layers[:-1], layers[1:]):
    # Glorot/Xavier normal initialisation: std = sqrt(2 / (fan_in + fan_out)).
    scale = np.sqrt(2 / (fan_in + fan_out))
    key, subkey = random.split(key)
    weights.append(random.normal(subkey, (fan_in, fan_out)) * scale)
    biases.append(np.zeros(fan_out))
  return [weights, biases]

@jit
def forward_pass(H, params):
  """Evaluate the MLP: tanh after every hidden layer, linear output layer.

  Args:
    H: a single input vector or a batch of input row vectors.
    params: ``[Ws, bs]`` as produced by ``init_params``.

  Returns:
    The output of the final (activation-free) linear layer.
  """
  Ws, bs = params[0], params[1]
  # len(Ws) counts the linear maps (hidden + output, not the input layer);
  # every one except the last is followed by tanh.
  for W, b in zip(Ws[:-1], bs):
    H = np.tanh(np.matmul(H, W) + b)
  return np.matmul(H, Ws[-1]) + bs[-1]

def create_grads():
  """Build batched derivative functions of the network's first output.

  Both returned functions take ``(X, params)``. They map over the leading
  axis of ``X`` (one input point per row) and share ``params`` across rows.

  Returns:
    ``(dU_dx, dU_dxx)``:
      - ``dU_dx`` gives, per point, the gradient of ``u = forward_pass(x, params)[0]``
        with respect to ``x``.
      - ``dU_dxx`` gives, per point, the gradient of the first component of that
        gradient. For 2-D inputs this is ``[u_xx, u_xy]``.
  """
  def scalar_output(x, params):
    # Scalar u(x) for a single point: first component of the network output.
    return forward_pass(x, params)[0]

  first_derivative = grad(scalar_output)

  def first_gradient_component(x, params):
    return first_derivative(x, params)[0]

  second_derivative = grad(first_gradient_component)

  def batched(fn):
    # Vectorise over the rows of X (axis 0); params are not mapped.
    return vmap(fn, in_axes=(0, None), out_axes=0)

  return batched(first_derivative), batched(second_derivative)

grad_X, grad_XX = create_grads() # batched du/dx and d(u_x)/dx, both called as f(X, params)

@partial(jit, static_argnums=(0,))  # `loss` is a Python callable, so it is treated as static
def step(loss, i, opt_state, X_batch, Y_batch, X_c, X_bd, X_bn):
    """Apply one optimizer update using the gradient of ``loss``.

    Uses the module-level ``get_params`` and ``opt_update`` returned by
    ``optimizers.adam``.

    Args:
        loss: callable ``loss(params, X, Y, X_c, X_bd, X_bn)`` returning a scalar.
        i: iteration index forwarded to ``opt_update``.
        opt_state: current optimizer state.
        X_batch, Y_batch, X_c, X_bd, X_bn: data forwarded unchanged to ``loss``.

    Returns:
        The updated optimizer state.
    """
    current_params = get_params(opt_state)
    loss_gradient = grad(loss)(current_params, X_batch, Y_batch, X_c, X_bd, X_bn)
    return opt_update(i, loss_gradient, opt_state)

def train(loss, X, Y, X_c, opt_state, X_bd = None, X_bn = None, nIter = 10000):
    """Run ``nIter`` optimizer steps on fixed data and collocation points.

    Every 100 iterations the loss is re-evaluated on the current parameters,
    stored, and printed.

    Returns:
        ``(params, train_loss, val_loss)``: the final parameters, the list of
        logged loss values, and ``val_loss``, which is never filled (always ``[]``).
    """
    history = []
    for it in range(nIter):
        opt_state = step(loss, it, opt_state, X, Y, X_c, X_bd, X_bn)
        if it % 100 != 0:
            continue
        current_params = get_params(opt_state)
        value = loss(current_params, X, Y, X_c, X_bd, X_bn)
        history.append(value)
        print("it %i, train loss = %e" % (it, value))
    return get_params(opt_state), history, []

def train_bfgs(loss, params, x_train, u_train, x_c, x_bd = None, x_bn = None): # BFGS method
    """Refine ``params`` by minimising ``loss`` with JAX's BFGS optimiser.

    ``jax.scipy.optimize.minimize`` works on a single 1-D vector. The parameter
    pytree is therefore flattened, and rebuilt with its original structure for
    every loss evaluation and for the returned result.

    Args:
        loss: callable ``loss(params, x_train, u_train, x_c, x_bd, x_bn)`` returning a scalar.
        params: initial parameter pytree (e.g. ``[Ws, bs]`` from ``init_params``).
        x_train, u_train, x_c, x_bd, x_bn: data forwarded unchanged to ``loss``.

    Returns:
        The optimised parameters, with the same pytree structure as ``params``.
        The final loss value is printed.
    """
    # ravel_pytree is JAX's own flatten/unflatten pair. It is trace-safe under
    # jit and carries the tree structure and leaf shapes in the returned
    # `unravel` function. This replaces the hand-rolled numpy split/reshape,
    # which relied on a closure variable assigned after its use site was defined.
    from jax.flatten_util import ravel_pytree

    param_vector, unravel = ravel_pytree(params)

    @jit
    def objective(flat_params):
        return loss(unravel(flat_params), x_train, u_train, x_c, x_bd, x_bn)

    results = minimize(objective, param_vector, method = 'bfgs')

    print(results.fun)

    return unravel(results.x)


In [None]:
def F(x, c = np.array([0.55,0.45])):
  """Return the Euclidean distance ``||x - c||`` from a single point to the source.

  This distance solves the unit-speed Eikonal equation with a point source at
  ``c`` (by default the pacemaker position). It is used as the synthetic
  arrival time.
  """
  offset = x - c
  return np.linalg.norm(offset)

# Evaluation grid: 50 x 50 points on the unit square.
xs = ys = np.linspace(0,1,50)
Xm, Ym = np.meshgrid(xs, ys)
x_star = np.c_[Xm.ravel(), Ym.ravel()] # stack grid coordinates into a (2500, 2) array of points

origin = np.array([0.55,0.45]) # pacemaker (wave source) position; same value as F's default c
N_train = 20 # number of electrodes (training points)

x_train = lhs(2, N_train, random_state = 1234) # seeded Latin-hypercube sample of electrode positions in the domain
# Observed arrival times at the electrodes (distance to origin), shape (N_train, 1).
u_train = vmap(F, in_axes = (0, None))(x_train,origin)[:,None]

# Exact arrival times on the full grid, used as the reference solution, shape (2500, 1).
u_star = vmap(F, in_axes = (0, None))(x_star,origin)[:,None]

# Plot the exact arrival-time field with electrodes and pacemaker overlaid.
plt.figure(dpi = 150)
plt.contourf(Xm, Ym, u_star.reshape(Xm.shape))
plt.colorbar()

plt.plot(x_train[:,0],x_train[:,1], "k.", label = "posición de electrodos")
plt.plot(origin[0], origin[1], "r.", label = "marcapasos")

plt.xlabel("x")
plt.ylabel("y")
plt.title("tiempos de llegada")

plt.axis("equal")
plt.legend()

In [None]:
# Redefine the training loop so fresh random collocation points are drawn at every iteration.

def train(loss, X, Y, X_c_shape, opt_state, X_bd = None, X_bn = None, nIter = 10000):
    """Run ``nIter`` optimizer steps, resampling the collocation points each step.

    Args:
        X_c_shape: ``(n_points, dim)``. At every iteration a new Latin-hypercube
            sample of that shape is drawn and used as the collocation set.

    Every 100 iterations the loss is re-evaluated (on that iteration's
    collocation points), stored, and printed.

    Returns:
        ``(params, train_loss, val_loss)``: the final parameters, the list of
        logged loss values, and ``val_loss``, which is never filled (always ``[]``).
    """
    n_points, dim = X_c_shape[0], X_c_shape[1]
    history = []
    for it in range(nIter):
        X_c = lhs(dim, n_points)
        opt_state = step(loss, it, opt_state, X, Y, X_c, X_bd, X_bn)
        if it % 100 != 0:
            continue
        current_params = get_params(opt_state)
        value = loss(current_params, X, Y, X_c, X_bd, X_bn)
        history.append(value)
        print("it %i, train loss = %e" % (it, value))
    return get_params(opt_state), history, []


In [None]:
layers = [2, 100, 50, 20, 1] # widths: 2-D input, three tanh hidden layers, scalar output

params = init_params(layers, key)

# Adam with step size 1e-3; opt_update / get_params are the module-level names used by step() and train().
opt_init, opt_update, get_params = optimizers.adam(1e-3)
opt_state = opt_init(params)

X_c_shape = (200,2) # (number of collocation points, spatial dimension); train() resamples them every iteration

@jit
def loss(params, X, U, Xc, X_bd, X_bn):
  """PINN loss: data misfit at the electrodes plus the Eikonal residual.

  The data term compares the network at ``X`` with the measured arrival
  times ``U``. The PDE term penalises deviations of ``||grad u||`` from 1 at
  the collocation points ``Xc``. ``X_bd`` and ``X_bn`` are accepted for
  interface compatibility with ``step``/``train`` and are not used.
  """
  data_residual = forward_pass(X, params) - U
  gradient_norms = np.linalg.norm(grad_X(Xc, params), axis=1)
  eikonal_residual = gradient_norms - 1
  return np.average(data_residual ** 2) + np.average(eikonal_residual ** 2)
# Train with Adam for 10,000 steps, drawing new collocation points each iteration.
params, train_loss, val_loss = train(loss,x_train, u_train, X_c_shape, opt_state, nIter = 10000)

In [None]:
# Predicted arrival times on the full evaluation grid, shape (len(x_star), 1).
u_pred = forward_pass(x_star, params)
pred_origin = x_star[np.argmin(u_pred)] # predicted source = grid point with the earliest predicted arrival time

# Plot the predicted field with electrodes, true source, and predicted source.
plt.figure(dpi=150)
plt.contourf(Xm, Ym, u_pred.reshape(Xm.shape))

plt.plot(x_train[:,0], x_train[:,1], "k.")
plt.plot(origin[0], origin[1], "r.", label = "origen")
plt.plot(pred_origin[0], pred_origin[1], "b.", label = "origen pred.")

plt.title('tiempos de llegada predecidos')
plt.xlabel("x")
plt.ylabel("y")

plt.colorbar()
plt.axis('equal')
plt.legend()

# Euclidean distance between predicted and true pacemaker positions.
print(f"error: {onp.linalg.norm(pred_origin - origin)}")

## Transformación del código a PyTorch

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad, Variable
import numpy as np
import matplotlib.pyplot as plt
from pyDOE2 import lhs

### Función de pérdida

In [None]:
def compute_pde_loss(model, x_c):
    """Mean squared Eikonal residual ``(||grad u|| - 1)**2`` at collocation points.

    Side effect: enables ``requires_grad`` on ``x_c`` in place, so autograd can
    differentiate the model output with respect to the inputs.
    """
    x_c.requires_grad_(True)
    u = model(x_c)
    (du_dx,) = torch.autograd.grad(
        outputs=u,
        inputs=x_c,
        grad_outputs=torch.ones_like(u),
        create_graph=True,  # keep the graph so the residual itself can be backpropagated
    )
    residual = torch.linalg.norm(du_dx, dim=1) - 1
    return (residual ** 2).mean()

### Función de entrenamiento

In [None]:
def train(model, optimizer, criterion, x_train, y_train, x_c, iterations=10000):
    """Train ``model`` in place for ``iterations`` full-batch steps.

    Each step minimises
    ``criterion(model(x_train), y_train) + compute_pde_loss(model, x_c)``.
    The total loss is printed every 100 iterations. Nothing is returned.
    """
    model.train()
    for it in range(iterations):
        optimizer.zero_grad()
        # Left operand is evaluated first: data term, then PDE term.
        total = criterion(model(x_train), y_train) + compute_pde_loss(model, x_c)
        total.backward()
        optimizer.step()
        if it % 100 == 0:
            print(f"Iteration {it}, Loss: {total.item()}")


### Generación de datos
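Se muestrean posiciones de electrodos mediante Latin Hypercube Sampling y se calcula en ellas el tiempo de llegada exacto (la distancia al marcapasos). Estos pares son los datos de entrenamiento de la PINN.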

In [None]:
def generate_data(n_train=20, origin=(0.55, 0.45), n_grid=50, criterion="center", random_state=None):
    """Build the evaluation grid and the synthetic electrode measurements.

    All arguments are optional; the defaults reproduce the original behaviour.

    Args:
        n_train: number of electrodes (training points).
        origin: pacemaker position; arrival times are Euclidean distances to it.
        n_grid: number of points per axis of the evaluation grid on [0, 1]^2.
        criterion: pyDOE2 ``lhs`` sampling criterion.
        random_state: seed forwarded to ``lhs``. ``None`` (default) gives a
            different sample on every call. Use ``criterion=None,
            random_state=1234`` to reproduce the JAX section's electrodes.

    Returns:
        ``(x_train, y_train, x_star, Xm, Ym, origin)``:
          - electrode positions (float32 tensor, ``(n_train, 2)``);
          - their arrival times (float32 tensor, ``(n_train, 1)``);
          - all grid points (numpy, ``(n_grid**2, 2)``);
          - the meshgrid arrays;
          - ``origin`` as a numpy array.
    """
    # Evaluation domain.
    xs = ys = np.linspace(0, 1, n_grid)
    Xm, Ym = np.meshgrid(xs, ys)
    x_star = np.hstack([Xm.ravel()[:, None], Ym.ravel()[:, None]])  # full domain for prediction

    origin = np.asarray(origin, dtype=float)  # source of the propagation

    # Electrode positions via Latin Hypercube Sampling.
    samples = lhs(2, n_train, criterion=criterion, random_state=random_state)
    x_train = torch.tensor(samples, dtype=torch.float32)
    # Exact Eikonal solution for a point source: distance to the origin.
    y_train = torch.sqrt((x_train[:, 0] - origin[0])**2 + (x_train[:, 1] - origin[1])**2).unsqueeze(1)

    return x_train, y_train, x_star, Xm, Ym, origin

x_train, y_train, x_star, Xm, Ym, origin = generate_data() # note: rebinds the JAX-section globals of the same names


### Entrenamiento del modelo

In [None]:
def train(model, optimizer, criterion, x_train, y_train, x_c, iterations=10000):
    """Train ``model`` in place on the data misfit plus the Eikonal residual.

    This re-declares the training loop defined earlier in the notebook, with
    the same behaviour. Every 100 iterations the total loss is printed.
    Nothing is returned.
    """
    model.train()
    for it in range(iterations):
        optimizer.zero_grad()
        data_term = criterion(model(x_train), y_train)
        pde_term = compute_pde_loss(model, x_c)
        total_loss = data_term + pde_term
        total_loss.backward()
        optimizer.step()
        if it % 100 == 0:
            print(f"Iteration {it}, Loss: {total_loss.item()}")

# Collocation points for the PDE residual: 200 uniform samples in [0, 1)^2, fixed for the whole run.
x_c = torch.rand(200, 2, dtype=torch.float32)


def build_mlp(layers):
    """Build the PyTorch counterpart of the JAX ``init_params`` / ``forward_pass`` net.

    The network has linear layers of widths ``layers``, tanh after every
    hidden layer, and a linear output. Weights use Glorot/Xavier-normal
    initialisation (std = sqrt(2 / (fan_in + fan_out))) and biases start at
    zero, matching the JAX initialisation.
    """
    modules = []
    n_linear = len(layers) - 1
    for idx, (fan_in, fan_out) in enumerate(zip(layers[:-1], layers[1:])):
        linear = nn.Linear(fan_in, fan_out)
        nn.init.xavier_normal_(linear.weight)
        nn.init.zeros_(linear.bias)
        modules.append(linear)
        if idx < n_linear - 1:  # no activation after the output layer
            modules.append(nn.Tanh())
    return nn.Sequential(*modules)


# Train the model. A torch nn.Module is required here: the optimizer needs
# model.parameters() and train() calls model(x). The JAX init_params needs a
# PRNG key and returns a plain [Ws, bs] list, so it cannot be used.
model = build_mlp([2, 100, 50, 20, 1])
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
train(model, optimizer, criterion, x_train, y_train, x_c)

# Predict arrival times on the full evaluation grid, as a numpy array reshaped to Xm.shape.
u_pred = model(torch.tensor(x_star, dtype=torch.float32)).detach().numpy().reshape(Xm.shape)
# Predicted pacemaker = grid point with the earliest predicted arrival time. argmin
# flattens u_pred in the same row-major order used to build x_star from Xm.ravel().
pred_origin = x_star[np.argmin(u_pred)]

# Plot the predicted field with electrodes, true source, and predicted source.
plt.figure(dpi=150)
plt.contourf(Xm, Ym, u_pred)
plt.plot(x_train[:, 0], x_train[:, 1], "k.", label="Posición de electrodos")
plt.plot(origin[0], origin[1], "r.", label="Marcapasos")
plt.plot(pred_origin[0], pred_origin[1], "b.", label="Origen predicho")
plt.xlabel("x")
plt.ylabel("y")
plt.title("Tiempos de llegada predecidos")
plt.colorbar()
plt.legend()
plt.axis('equal')
plt.show()
# Euclidean distance between predicted and true pacemaker positions.
print(f"Error: {np.linalg.norm(pred_origin - origin)}")
