# wave with periodic bc with the scheme from Palma


In [15]:
using ComponentArrays
using Distributions
using GLMakie
using Lux
using LuxCUDA
using OptimizationOptimJL
using Random
using UnPack
using Zygote
using CairoMakie  # Backend que funciona en notebooks
using StatsBase
using Revise
import NaNMath



In [16]:
includet("neural_tools.jl")

In [17]:
# -------------------------------------------------------------------
# Configuration
# -------------------------------------------------------------------
config = Dict{Symbol,Any}(
    # --- network architecture ---
    :N_input => 2,        # network inputs: [x; t]
    :N_neurons => 20,
    :N_layers => 3,
    :N_output => 1,
    # --- sampling domain ---
    :N_points => 15000,   # number of collocation points (x, t)
    :xmin => 0.0,
    :xmax => 5.0,         # = L (spatial domain)
    :tmin => 0.0,         # t_min
    :tmax => 10.0,        # t_max
    # --- optimizer ---
    :optimizer => BFGS(),
    :maxiters => 3000,
    # --- problem parameters ---
    # NOTE(review): :A, :x0, :x1 and :p look like the arguments of the `bump` profile
    # mentioned in the plotting cell — confirm against neural_tools.jl.
    :A => 0.010,
    :x0 => 2.0,
    :x1 => 3.0,
    :p => 2,
    :c => 1.0,            # coefficient c used by `wave_equation` (enters as c^2)
)

Dict{Symbol, Any} with 16 entries:
  :tmax      => 10.0
  :maxiters  => 3000
  :p         => 2
  :N_input   => 2
  :c         => 1.0
  :N_layers  => 3
  :N_points  => 15000
  :N_neurons => 20
  :xmin      => 0.0
  :x0        => 2.0
  :xmax      => 5.0
  :A         => 0.01
  :N_output  => 1
  :tmin      => 0.0
  :x1        => 3.0
  :optimizer => BFGS{InitialStatic{Float64}, HagerZhang{Float64, RefValue{Bool}…

In [18]:
# -------------------------------------------------------------------
# Equation (1D wave): u_tt - c^2 u_xx = 0
# -------------------------------------------------------------------
"""
    wave_equation(∂2u_∂x2, ∂2u_∂t2, c)

Return the residual `u_tt - c^2 * u_xx` of the 1D wave equation, evaluated
element-wise on the given second derivatives.
"""
function wave_equation(∂2u_∂x2, ∂2u_∂t2, c)
    speed_sq = c^2
    return ∂2u_∂t2 .- speed_sq .* ∂2u_∂x2
end

# -------------------------------------------------------------------
# Loss: PDE residual only (hard enforcement already fixes IC/periodicity)
# -------------------------------------------------------------------
"""
    loss_function(input, NN, Θ, st)

Return `log10` of the mean squared wave-equation residual over all columns of `input`.

Row 1 of `input` is taken as `x` and row 2 as `t`. The second derivatives come from
`calculate_derivatives_Dirichlet`, and the coefficient `c` is read from the global `config`.

NOTE(review): the header mentions periodicity while the helper is named `..._Dirichlet` —
confirm which boundary condition the ansatz actually enforces.
"""
function loss_function(input, NN, Θ, st)
    # Slice with ranges (1:1, 2:2) so that x and t stay 1×N matrices.
    x = input[1:1, :]
    t = input[2:2, :]
    _, ∂xx, ∂tt = calculate_derivatives_Dirichlet(x, t, NN, Θ, st)
    residual = wave_equation(∂xx, ∂tt, config[:c])
    mse = sum(abs2, residual) / length(residual)
    return log10(mse)
end

# -------------------------------------------------------------------
# Callback
# -------------------------------------------------------------------
"""
    callback(p, l, losses)

Append the current loss `l` to `losses`, print it, and return `false`.

`p` is accepted for signature compatibility and is not used. Per the Optimization.jl
callback convention, returning `false` lets the solver keep iterating.
"""
function callback(p, l, losses)
    push!(losses, l)
    print(stdout, "Current loss: ", l, "\n")
    return false
end

callback (generic function with 1 method)

In [19]:
# -------------------------------------------------------------------
# Training setup
# -------------------------------------------------------------------
# Loss history; `callback` pushes one value into it each time it is invoked.
losses = Float64[]

# Model, parameters and state built from `config`.
# NOTE(review): create_neural_network / generate_input_x_t are not defined in this
# notebook — presumably they come from neural_tools.jl; confirm there.
NN, Θ, st = create_neural_network(config)
input = generate_input_x_t(config)

# Echo type and size of the sampled points (recorded output: 2×15000 Matrix{Float64}).
@show typeof(input) size(input)



typeof(input) = Matrix{Float64}
size(input) = (2, 15000)


(2, 15000)

In [20]:


# Evaluate the ansatz and its second derivatives on every sampled point
# (row 1 of `input` is x, row 2 is t), using the current parameters Θ.
#calculate_Dirichlet_f(input[1:1, :], input[2:2, :], NN, Θ, st)
f, ∂2f_∂x2, ∂2f_∂t2 = calculate_derivatives_Dirichlet(input[1:1, :], input[2:2, :], NN, Θ, st)

# Display f (a 1×15000 Matrix{Float64} according to the recorded output).
f

1×15000 Matrix{Float64}:
 -5.68943  -8.56678  -1.34227  -1.9844  …  -7.99418  0.453512  -0.97215

In [None]:


# Objective in the (parameters, data) form expected by Optimization.jl; the model `NN`
# and its state `st` are captured from the enclosing (global) scope.
optf = OptimizationFunction(AutoZygote()) do θ, data
    loss_function(data, NN, θ, st)
end
optprob = OptimizationProblem(optf, Θ, input)

# Run the configured optimizer, recording every reported loss in `losses`.
optresult = solve(
    optprob,
    config[:optimizer];
    maxiters = config[:maxiters],
    callback = (p, l) -> callback(p, l, losses),
)

# Move the optimized parameters to the CPU if applicable
Θ = cpu_device()(optresult.u)


Current loss: 2.0541126835089845
Current loss: 2.0267088179692228
Current loss: 1.9901749672686546
Current loss: 1.978092625397449
Current loss: 1.967268328987639
Current loss: 1.9442313692692512
Current loss: 1.9180677164813456
Current loss: 1.8952345261370735
Current loss: 1.8840532547254984
Current loss: 1.8612961254012512
Current loss: 1.81498519162908
Current loss: 1.7808586484267472
Current loss: 1.7614413861772382
Current loss: 1.7218294927029159
Current loss: 1.6897392068455768
Current loss: 1.626089310050856
Current loss: 1.5846027617008203
Current loss: 1.553390887952458
Current loss: 1.502407374176138
Current loss: 1.4562174831870722
Current loss: 1.4271906077963004
Current loss: 1.4071231296383833
Current loss: 1.3632669238912964
Current loss: 1.3172570707112028
Current loss: 1.2816046260706442
Current loss: 1.235005128761297
Current loss: 1.1625740518608891
Current loss: 1.0810351214900324
Current loss: 1.0464586249155363
Current loss: 0.9755813454270339
Current loss: 0.95

In [None]:
# Plot the training history.
# `losses` already holds log10(MSE) values (that is what `loss_function` returns), so
# they are plotted as-is: taking `log10` again would throw a DomainError as soon as a
# value becomes negative (MSE < 1).
# The title is passed through the `axis` keyword; a bare `ax = (title = "...")` only
# assigns a string and never reaches the plot.
fig, ax = lines(
    losses;
    label = "Loss",
    axis = (; title = "loss vs iterations", xlabel = "iteration", ylabel = "log10(MSE)"),
)
axislegend(ax)
fig

In [None]:
# Evaluate the ansatz on all sampled points (row 1 of `input` = x, row 2 = t) with the
# current parameters Θ; the notebook displays the returned value.
calculate_Dirichlet_f(input[1:1,:], input[2:2,:], NN, Θ, st)

In [None]:
# Plot the network solution at a fixed time `t` against a reference profile.
t = 1.0
# 1×200 row of equally spaced x values, and a matching row filled with the fixed time.
xs = reshape(collect(range(config[:xmin], config[:xmax]; length = 200)), 1, :)
t_fix = fill(t, size(xs))
# Evaluate one (x, t) column at a time and keep the first entry of each result.
sol_p = [calculate_Dirichlet_f(xs[:, i], t_fix[:, i], NN, Θ, st)[1] for i in axes(xs, 2)]

# The legend label is derived from `t` so it cannot disagree with the plotted time
# (it used to read "t=0" while t = 1.0). The title goes through the `axis` keyword;
# a bare `ax = (title = "...")` only assigns a string and never reaches the plot.
fig, ax = lines(xs[1, :], sol_p; label = "t = $(t)", axis = (; title = "solution"))
#lines!(ax, xs[1,:], bump.(xs[1,:], config[:x0], config[:x1], config[:p], config[:A]), label = "initial condition")
# Reference curve: -(x - xmin)^4 * (x - xmax)^4.
lines!(
    ax,
    xs[1, :],
    -(xs[1, :] .- config[:xmin]) .^ 4 .* (xs[1, :] .- config[:xmax]) .^ 4;
    label = "initial condition",
)
axislegend(ax)
fig