In [2]:
using Lux, Random,TaylorDiff, ComponentArrays, Optimisers,Zygote, Plots,ForwardDiff, Statistics

# Seeded Mersenne Twister generator (seed 1) so runs are reproducible.
rng = MersenneTwister(1)

# Define the model
# Three dense layers: 1 -> 50 -> 50 -> 1, with tanh on the first two.
model = Chain(
    Dense(1 => 50, tanh),
    Dense(50 => 50, tanh),
    Dense(50 => 1),
)

# Create the initial parameters and layer states from a fixed-seed Xoshiro RNG,
# then wrap the parameters in a ComponentArray.
ps, st = Lux.setup(Xoshiro(0), model)
ps = ComponentArray(ps)

# Function to evaluate the model
"""
    trial(model, x, ps, st)

Apply `model` to `x` with parameters `ps` and state `st` through `Lux.apply` and
return only the model output; the state returned by `Lux.apply` is dropped.
"""
trial(model, x, ps, st) = first(Lux.apply(model, x, ps, st))

# Convenience wrapper: evaluates the global `model` at `x`, reading the globals
# `ps` and `st` as they are bound at call time.
function f(x)
    return trial(model, x, ps, st)
end

# Build the training data: 200 Float32 samples on [-1, 1] laid out as a 1×200
# matrix, with cos(x) as the regression target.
x = reshape(collect(range(-1.0f0, 1.0f0; length=200)), 1, :)
y = @. Float32(cos(x))
data = (x, y)

# Adam optimiser rule with a Float32 learning rate of 0.01.
opt = Optimisers.Adam(0.01f0)

# NOTE: this repeats the `trial` method already defined earlier in this file
# with the same signature; evaluating it simply overwrites that method.
"""
    trial(model, x, ps, st)

Apply `model` to `x` with parameters `ps` and state `st` through `Lux.apply` and
return only the model output; the state returned by `Lux.apply` is dropped.
"""
trial(model, x, ps, st) = first(Lux.apply(model, x, ps, st))

# Define the loss function
"""
    loss_function(model, ps, st, data)

Mean-squared-error objective in the form expected by
`Lux.Training.compute_gradients`.

# Arguments
- `model`: the Lux model to evaluate.
- `ps`: model parameters.
- `st`: model state.
- `data`: a tuple `(x, y)` of inputs and targets.

# Returns
A 3-tuple `(loss, st_new, stats)`: the scalar loss, the state returned by the
forward pass, and an empty tuple of statistics.
"""
function loss_function(model, ps, st, data)
    x, y = data
    # Keep the state produced by the forward pass so stateful layers would be
    # propagated correctly by the training API.
    y_pred, st_new = Lux.apply(model, x, ps, st)
    # NOTE: no derivative term enters the loss. If a derivative-based residual is
    # wanted, compute it here (e.g. with `TaylorDiff.derivative`) and add it to
    # `loss`; computing it without using it only costs time.
    loss = mean(abs2, y_pred .- y)
    return loss, st_new, ()
end

# Re-initialise parameters and states from the seeded `rng`; these bindings
# replace the `ps`/`st` created earlier in the file.
ps, st = Lux.setup(rng, model)

# Smoke test: evaluate the objective once. The returned value is not used.
loss_function(model, ps, st, data)

# Bundle model, parameters, states and optimiser into a training state.
tstate = Lux.Training.TrainState(model, ps, st, opt)

# One trial gradient evaluation before the training loop. The gradients are
# deliberately NOT applied here: `apply_gradients` returns a new train state
# rather than updating `tstate`, so calling it and discarding the result
# would have no effect.
grads, loss, stats, ts = Lux.Training.compute_gradients(
    AutoZygote(), loss_function, data, tstate
)

# Training loop
"""
    train_model(tstate, data; epochs=5000, log_every=100)

Run `epochs` full-batch optimisation steps on `data`, starting from `tstate`.

Each step computes gradients of `loss_function` with `AutoZygote()` and applies
them to obtain the next train state. The loss is printed every `log_every`
epochs.

Keeping the loop inside a function avoids assigning to globals from a
top-level `for` loop, which only works under the REPL/notebook soft-scope
rule and fails with `UndefVarError` when the file is run as a script.

# Returns
A tuple `(tstate, loss)`: the final train state and the loss from the last
epoch (`nothing` if `epochs < 1`).
"""
function train_model(tstate, data; epochs=5000, log_every=100)
    loss = nothing
    for epoch in 1:epochs
        grads, loss, _, tstate = Lux.Training.compute_gradients(
            AutoZygote(), loss_function, data, tstate
        )
        tstate = Lux.Training.apply_gradients(tstate, grads)
        if epoch % log_every == 0
            println("Epoch: $epoch, Loss: $loss")
        end
    end
    return tstate, loss
end

epochs = 5000
tstate, loss = train_model(tstate, data; epochs=epochs);

Epoch: 100, Loss: 9.2041686e-5
Epoch: 200, Loss: 2.3943021e-5
Epoch: 300, Loss: 1.3304576e-5
Epoch: 400, Loss: 1.006307e-5
Epoch: 500, Loss: 8.013504e-6
Epoch: 600, Loss: 6.537302e-6
Epoch: 700, Loss: 5.4696675e-6
Epoch: 800, Loss: 4.67683e-6
Epoch: 900, Loss: 4.060914e-6
Epoch: 1000, Loss: 3.5580517e-6
Epoch: 1100, Loss: 3.130429e-6
Epoch: 1200, Loss: 8.34266e-6
Epoch: 1300, Loss: 3.2771097e-6
Epoch: 1400, Loss: 2.8803854e-6
Epoch: 1500, Loss: 2.5371392e-6
Epoch: 1600, Loss: 2.233566e-6
Epoch: 1700, Loss: 8.884234e-6
Epoch: 1800, Loss: 2.218052e-6
Epoch: 1900, Loss: 1.9036665e-6
Epoch: 2000, Loss: 1.6403837e-6
Epoch: 2100, Loss: 1.4173472e-6
Epoch: 2200, Loss: 1.2291441e-6
Epoch: 2300, Loss: 1.5624073e-6
Epoch: 2400, Loss: 9.771035e-7
Epoch: 2500, Loss: 8.555381e-7
Epoch: 2600, Loss: 7.5751376e-7
Epoch: 2700, Loss: 6.766129e-7
Epoch: 2800, Loss: 0.0004099977
Epoch: 2900, Loss: 6.218015e-7
Epoch: 3000, Loss: 5.331904e-7
Epoch: 3100, Loss: 4.841846e-7
Epoch: 3200, Loss: 4.453135e-7
Epoc