# Heat equation - Retrieving operator through ODE solving

In [1]:
using LinearAlgebra
using Flux
using DiffEqFlux
using OrdinaryDiffEq
using GalacticOptim
using Plots

include("src/equations/initial_functions.jl")
include("src/equations/equations.jl")
include("src/utils/graphic_tools.jl")
include("src/utils/generators.jl")
include("src/utils/processing_tools.jl")
include("src/neural_ode/objectives.jl")
include("src/neural_ode/models.jl")

Main.Models

## Analytical solution

Compute the analytical solution $u(t, x)$ and its analytical time derivative $\partial_t u(t, x)$.

### Different methods

We use several methods to evaluate a solution to the heat equation. Starting from the well-known analytical solution $u(x,t) = \sum^{\infty}_{n=1} c_n e^{-\pi^2 n^2 t} \sin(n \pi x)$,
we also implemented other methods (explicit and implicit finite differences, finite elements, pseudo-spectral) to study their efficiency with regard to the stability of the equation.

While we started with the analytical solution to test the quality of our implementation, for the heat equation it is interesting to continue with a pseudo-spectral method using the FFT, which is numerically stable for a large range of parameters.

In [None]:
# Global simulation parameters shared by the cells below.
# Bounds of the time and space intervals.
t_max = 1.;
t_min = 0.;
x_max = 1.;
x_min = 0.;
# Number of time snapshots and of spatial grid points.
t_n = 64;
x_n = 64;

# NOTE(review): `typ` and `d` are not read by any cell visible in this notebook — confirm they are still needed.
typ = 3;
d = 1.;
# `k` and `kappa` are forwarded to the Generator / Equations / InitialFunctions helpers below.
k = 1.;
kappa = 0.005;
# Length of the spatial domain; read as a global (together with `x_min`) by `create_solution`.
L = x_max - x_min
# `n` and `c` are passed to `InitialFunctions.heat_analytical_init`
# (presumably mode numbers and their coefficients — confirm against that helper).
c = [0.7, 0.3, 0.4];
n = [3, 7, 10];

In [None]:
# Grid spacings for t_n / x_n uniformly spaced points, rounded to 8 decimal places.
# NOTE(review): `dt` is not used by any cell visible in this notebook.
dt = round((t_max - t_min) / (t_n - 1), digits=8);
dx = round((x_max - x_min) / (x_n - 1), digits=8);

# Uniform time and space grids.
t = LinRange(t_min, t_max, t_n);
x = LinRange(x_min, x_max, x_n);

# Build the initial data, then evolve its first row with `Equations.get_heat_fft`.
# The global `t` is rebound to the time vector returned by that solver.
u0 = InitialFunctions.heat_analytical_init(t, x, n, c, kappa);
t, u = Equations.get_heat_fft(t, dx, x_n, kappa, u0[1, :]);
GraphicTools.show_state(u, "Snapshot") # need to reverse u, odesolver switch dimensions

In [None]:
# Produce one snapshot with the project's generator and plot it for comparison with the solver output.
# NOTE(review): the meaning of the literal `4` argument is defined in src/utils/generators.jl — confirm.
ta, u_a = Generator.heat_snapshot_generator(t_max, t_min, x_max, x_min, t_n, x_n, 4, kappa, k)
GraphicTools.show_state(u_a, "Snapshot mock")

## Generate dataset

In [None]:
"""
    create_solution(c, k, ka)

Return a function `(x, t) -> value` that evaluates the sine series

    sum_i c[i] * exp(-ka * (pi * k[i] / L)^2 * t) * sqrt(2 / L) * sin(pi * k[i] * (x - x_min) / L)

`c` and `k` are iterated in lockstep (coefficient, mode index). The domain length `L` and
the left boundary `x_min` are read from global scope when the returned function is called.
"""
function create_solution(c, k, ka)
    # Normalised sine mode of index `mode` evaluated at position `pos`.
    basis(mode, pos) = sqrt(2 / L) * sin(pi * mode * (pos - x_min) / L)

    # Each term is the coefficient times its exponential decay in time times the spatial mode.
    function series(x, t)
        return sum(
            coef * exp(-ka * (pi * mode / L)^2 * t) * basis(mode, x) for
            (coef, mode) in zip(c, k)
        )
    end

    return series
end

In [None]:
"""
    syver_cond(t_max, t_min, x_max, x_min, t_n, x_n, ka, nsample)

Generate `nsample` random heat-equation trajectories from truncated sine series.

For every sample, 50 coefficients are drawn as `randn(50) ./ (1:50)` and turned into a
solution function with `create_solution(coefficients, 1:50, ka)`.

# Returns
- `tsnap`: `LinRange(t_min, t_max, t_n)` of snapshot times.
- `init`: matrix indexed `[space, sample]` with every solution evaluated at `t = 0`.
- `train`: 3-D array indexed `[space, sample, time]` with every solution on the full grid.

The spatial grid is built from `x_min`, `x_max` and `x_n`. Earlier versions silently
iterated over the global `x` and ignored these arguments.

NOTE(review): `create_solution` still reads the globals `L` and `x_min`, so the arguments
passed here should describe the same domain as those globals.
"""
function syver_cond(t_max, t_min, x_max, x_min, t_n, x_n, ka, nsample)
    tsnap = LinRange(t_min, t_max, t_n)
    # Build the spatial grid from the arguments instead of relying on the global `x`.
    xgrid = LinRange(x_min, x_max, x_n)

    nmodes = 50
    modes = 1:nmodes
    # Coefficients decay like 1/mode so that higher modes contribute less.
    coefficients = [randn(nmodes) ./ modes for _ in 1:nsample]
    solutions = [create_solution(cs, modes, ka) for cs in coefficients]

    init = [u(xi, 0.0) for xi in xgrid, u in solutions]
    train = [u(xi, ti) for xi in xgrid, u in solutions, ti in tsnap]
    return tsnap, init, train
end

In [None]:
# Evaluate the project's analytical 1-D heat helper on the (t, x) grid with `1:50` and plot the result.
# NOTE(review): an empty, untyped `[]` is passed as the fourth argument — confirm what the helper does with it.
res = InitialFunctions.analytical_heat_1d(t, x, 1:50, [], kappa);
GraphicTools.show_state(res, "")

In [None]:
# Draw a 2-sample random dataset, then compare the `get_heat_fft` solution started from the
# first sample's initial condition against that sample's analytical trajectory.
# Rebinds the globals `t`, `init_set`, `true_set` and `u`.
syver_dataset = syver_cond(t_max, t_min, x_max, x_min, t_n, x_n, kappa, 2);
t, init_set, true_set = syver_dataset;
t, u = Equations.get_heat_fft(t, dx, x_n, kappa, init_set[:, 1]);
display(
    plot(
        GraphicTools.show_state(u, ""), # need to reverse u, odesolver switch dimensions
        GraphicTools.show_state(true_set[:, 1, :], "");
    );
);

In [None]:
# Generate the training dataset with the project's generator; the call is given a .jld2 path
# and the key "training_set" (presumably where the result is stored — confirm in generators.jl).
# NOTE(review): the meaning of the literals `128` and `4` is defined in src/utils/generators.jl — confirm.
dataset = Generator.generate_heat_training_dataset(t_max, t_min, x_max, x_min, t_n, x_n, 128, 4, kappa, k, "./src/dataset/hand_analytical_heat_training_set.jld2", "training_set");
# Alternative: reload previously saved datasets. Later cells reference `hand_dataset`,
# `analytic_dataset` and `high_dataset`, which are only assigned by these commented-out lines.
# hand_dataset = Generator.read_dataset("./src/dataset/hand_analytical_heat_training_set.jld2")["training_set"];
# analytic_dataset = Generator.read_dataset("./src/dataset/odesolver_analytical_heat_training_set.jld2")["training_set"];
# high_dataset = Generator.read_dataset("./src/dataset/high_dim_training_set.jld2")["training_set"];

In [None]:
"""
    check_training_dataset(dataset)

Display the first five entries of `dataset` for visual inspection.

Each entry is unpacked into four components; only the second one is plotted with
`GraphicTools.show_state`. Returns `nothing`.
"""
function check_training_dataset(dataset)
    foreach(1:5) do idx
        # Only the second component of each entry is shown.
        _, snapshot, _, _ = dataset[idx]
        display(GraphicTools.show_state(snapshot, ""))
    end
end

check_training_dataset(dataset)

In [None]:
# syver_dataset = syver_cond(t_max, t_min, x_max, x_min, t_n, x_n, kappa, 128);
# t, init_set, true_set = syver_dataset;
# Fit a dense x_n × x_n matrix A directly: `S` integrates du/dt = A * u from the initial
# conditions, and the loss is the mean squared difference to the reference snapshots.
# NOTE: `S` is defined in a later cell ("Training with solver"); that cell must be run first.
t, init_set, true_set = ProcessingTools.process_dataset(dataset);
loss(A, u₀, uₜ, t) = sum(abs2, S(A, u₀, t) - uₜ) / prod(size(uₜ));
loss(A) = loss(A, init_set, true_set, t);
# Start from the zero operator.
A = zeros(x_n, x_n);
# Print the current loss at every iteration; returning `false` lets the optimisation continue.
callback(A, loss) = (println(loss);flush(stdout);false);
result = DiffEqFlux.sciml_train(loss, A, ADAM(0.01); cb = callback, maxiters = 100);
# The fitted operator is the minimiser stored in the result.
K = result.u;
GraphicTools.show_state(K, "")

## Training with NeuralODE object

In [None]:
# Training callback for losses that return `(loss, prediction)`: replaces the cell output
# with the current loss value. Returning `false` tells the optimiser to keep iterating.
function callback(theta, loss, u)
    IJulia.clear_output(true)
    display(loss)
    return false
end

In [None]:
"""
    heat_training(net, epochs, u0, u_true, t)

Train `net` as the right-hand side of a `NeuralODE` so that the trajectory started from
`u0` and saved at the times `t` matches `u_true`.

The objective is `Objectives.mseloss` between the prediction (perturbed by Gaussian noise
of scale `1e-8`) and `u_true`, plus the L2 penalty `1e-2 * sum(abs2, theta) / length(theta)`
on the parameters. Optimisation uses ADAM with learning rate `0.01` for `epochs`
iterations through `DiffEqFlux.sciml_train`, reporting to the global `callback`.

# Returns
The result object returned by `DiffEqFlux.sciml_train`.
"""
function heat_training(net, epochs, u0, u_true, t)
    optimizer = DiffEqFlux.ADAM(0.01, (0.9, 0.999), 1.0e-8)

    tspan = (t[1], t[end])
    neural_ode = NeuralODE(net, tspan, Tsit5(), saveat=t)

    function predict_neural_ode(theta)
        return Array(neural_ode(u0, theta))
    end

    # Returns the data misfit together with the prediction, so the callback can receive both.
    function loss(theta)
        u_pred = predict_neural_ode(theta)
        noise = 1e-8 .* randn(size(u_pred))
        l = Objectives.mseloss(u_pred + noise, u_true)
        return l, u_pred
    end

    # `loss` returns a tuple, so unpack it before adding the penalty. The previous
    # `loss(K) + penalty` tried to add a Tuple to a Number and threw a MethodError.
    # The prediction is passed through unchanged for the 3-argument callback.
    function lossL2(theta)
        l, u_pred = loss(theta)
        penalty = 1e-2 * sum(abs2, theta) / length(theta)
        return l + penalty, u_pred
    end

    result = DiffEqFlux.sciml_train(
        lossL2, neural_ode.p, optimizer; cb=callback, maxiters=epochs
    )
    return result
end

In [None]:
# Train the NeuralODE model on the random sine-series dataset for 100 iterations.
t, init_set, true_set = syver_dataset;
#t, init_set, true_set = ProcessingTools.process_dataset(analytic_dataset);
net = Models.HeatModel(x_n);
result = heat_training(net, 100, init_set, true_set, t);

### Operator reconstructed

Visualize the reconstructed operator and check how well it predicts the solution for an unseen sample.

In [None]:
# The optimisation result wraps the trained parameters; take its minimiser `.u` (as the
# other cells of this notebook do) instead of reshaping the result object itself.
# Presumably the parameters form a flat vector of length x_n^2 — confirm against Models.HeatModel.
K = reshape(result.u, (x_n, x_n))
# Flip the rows for display.
GraphicTools.show_state(reverse(K; dims = 1), "Operator K")

In [None]:
# Print the size of the trajectory array (third component of the dataset).
# NOTE(review): this rebinds the globals `a`, `b` and `c`; `c` previously held the
# coefficient vector defined in the parameter cell.
a, b, c = syver_dataset
print(size(c))

In [None]:
# t, init, train = syver_dataset;
u0 = init_set[:, 50];
u = true_set[:, 50, :];
#t, u0, u = Generator.get_heat_batch(t_max, t_min, x_max, x_min, t_n, x_n, 1, kappa, k);


u_pred = Array(S(K, u0, t));

plot(
    GraphicTools.show_state(u, ""),
    GraphicTools.show_state(u_pred, "");
    layout = (1, 2),
)

## Training with solver

In [None]:
# Plot the first initial condition of `init_set` over the spatial grid.
plot(x, init_set[:, 1];)

In [None]:
# NOTE(review): `hand_dataset` is only assigned in a commented-out line of the dataset
# generation cell; it must be loaded before this cell can run.
t, u0, u_true = ProcessingTools.process_dataset(hand_dataset)
# Plot the first initial condition of that dataset over the spatial grid.
plot(x, u0[:, 1];)

In [None]:
# Right-hand side of the linear ODE du/dt = K * u, in the out-of-place `(u, p, t)` form;
# the parameter slot carries the operator `K` and the time `t` is unused.
f(u, K, t) = K * u

In [None]:
"""
    S(net, u0, t)

Integrate `du/dt = f(u, net, t)` from a copy of `u0` over `(t[1], t[end])`, using `Tsit5`
with relative and absolute tolerances of `1e-8`, and saving the state at the times `t`.

Returns the solution object produced by `solve`.
"""
function S(net, u0, t)
    # `u0` is copied before being handed to the problem; `net` is passed as the ODE parameters.
    problem = ODEProblem(ODEFunction(f), copy(u0), (first(t), last(t)), net)
    return solve(problem, Tsit5(); saveat=t, reltol=1e-8, abstol=1e-8)
end

In [None]:
# Training callback for scalar losses: print the current loss, flush stdout so it shows up
# immediately, and return `false` so the optimiser keeps iterating.
function callback(A, loss)
    println(loss)
    flush(stdout)
    return false
end

"""
    heat_training_2(A, epochs, u0, u_true, tsnap)

Fit the operator `A` so that the solution of the ODE solved by `S`, started from `u0` and
saved at `tsnap`, matches `u_true` under `Objectives.mseloss`.

Runs ADAM with learning rate `0.01` for `epochs` iterations through
`DiffEqFlux.sciml_train`, reporting to the global `callback`, and returns the result object.
"""
function heat_training_2(A, epochs, u0, u_true, tsnap)
    # Misfit between the trajectory produced by a candidate operator and the reference data.
    misfit(operator) = Objectives.mseloss(Array(S(operator, u0, tsnap)), u_true)

    return DiffEqFlux.sciml_train(misfit, A, ADAM(0.01); cb=callback, maxiters=epochs)
end

In [None]:
# Initial guess for the operator: an x_n × x_n zero matrix (copied before each training run below).
net = zeros(x_n, x_n);

In [None]:
# Fit the operator on `analytic_dataset` for 100 iterations.
# NOTE(review): `analytic_dataset` is only assigned in a commented-out line of the dataset
# generation cell; it must be loaded before this cell can run.
t, u0, u_true = ProcessingTools.process_dataset(analytic_dataset);
result_2 = heat_training_2(copy(net), 100, u0, u_true, t);

In [None]:
# Fit the operator on the random sine-series dataset for 100 iterations.
t2, init, train = syver_dataset;
result_3 = heat_training_2(copy(net), 100, init, train, t2);

In [None]:
# Fit the operator on `high_dataset` for 100 iterations.
# NOTE(review): `high_dataset` is only assigned in a commented-out line of the dataset
# generation cell; it must be loaded before this cell can run.
t3, u03, u_true3 = ProcessingTools.process_dataset(high_dataset);
result_4 = heat_training_2(copy(net), 100, u03, u_true3, t3);

In [None]:
# Extract the fitted operators (the minimisers of the three training runs).
K2 = result_2.u
K3 = result_3.u;
K4 = result_4.u;

# Show the three operators side by side, each with its rows flipped for display.
display(
    plot(
        GraphicTools.show_state(reverse(K2; dims = 1), ""),
        GraphicTools.show_state(reverse(K3; dims = 1), ""),
        GraphicTools.show_state(reverse(K4; dims = 1), "");
        layout = (1, 3),
    ),
);  

In [None]:
# t, u0, u_true = ProcessingTools.process_dataset(dataset);

# Draw a fresh batch from the generator and predict it with each fitted operator.
t, u0, u = Generator.get_heat_batch(t_max, t_min, x_max, x_min, t_n, x_n, 3, kappa, k);
# u_pred = Array(S(Afit, u0[, t));
u_pred_2 = Array(S(K2, u0, t));
u_pred_3 = Array(S(K3, u0, t));
u_pred_4 = Array(S(K4, u0, t));
# neural_ode = NeuralODE(net, (t[1], t[end]), Tsit5(), saveat=t)
# u_pred = Array(neural_ode(u0, result))

# Reference data next to the three predictions. The last panel now plots `u_pred_4`;
# it previously repeated `u_pred_3` under the "prediction 4" title.
display(
    plot(
        GraphicTools.show_state(u, "data"),
        #GraphicTools.show_state(u_pred, "prediction 1"),
        GraphicTools.show_state(u_pred_2, "prediction 2"),
        GraphicTools.show_state(u_pred_3, "prediction 3"),
        GraphicTools.show_state(u_pred_4, "prediction 4");
        layout = (2, 2),
    ),
);