# Heatmap of Average Robustness Performance and Optimal Point Analysis

## Imports

In [None]:
import Pkg; Pkg.activate(@__DIR__); Pkg.instantiate();
using PiccoloQuantumObjects
using QuantumCollocation
using ForwardDiff
using LinearAlgebra
using Plots
using SparseArrays
using NamedTrajectories
using Statistics
using CairoMakie
using Random
using ColorSchemes
using Makie
using Printf

In [None]:
# Problem parameters: single-qubit Hadamard target, X/Y/Z controllable drives
# and no drift term.
T = 20        # number of knot points in the trajectory
Δt = 0.2      # nominal timestep between knot points
U_goal = GATES.H

H_drive = [PAULIS.X, PAULIS.Y, PAULIS.Z]
sys = QuantumSystem(H_drive)

# Operators handed to the variational system alongside the drives
# (presumably the error directions whose sensitivity is penalised — confirm
# against VariationalQuantumSystem docs).
∂ₑHₐ = [PAULIS.X, PAULIS.Y, PAULIS.Z]
varsys = VariationalQuantumSystem(H_drive, ∂ₑHₐ)

piccolo_opts = PiccoloOptions(verbose=false)


## Robustness Metrics

In [None]:
"""
    width_robustness(system, traj; thresh=0.999,
                     error_ops=[PAULIS.X, PAULIS.Y, PAULIS.Z],
                     ε_max=0.5, ε_step=1e-4)

Estimate how large a static drift perturbation `ε * E` the controls in `traj`
tolerate, separately for each operator `E` in `error_ops`.

For each `E`, ε is scanned over the exact grid `0, ε_step, 2ε_step, …` strictly
below `ε_max`. The perturbed system `QuantumSystem(H_drift + ε * E, H_drives)` is
rolled out with `unitary_rollout_fidelity`. The scan stops at the first ε whose
fidelity falls below `thresh`.

# Returns
A `Vector{Float64}` with one entry per operator: the largest scanned ε whose
fidelity was still `≥ thresh`. The entry is `0.0` if the scan fails at ε = 0 or
at the first nonzero grid point. If no failure is found in the scanned range,
the entry is the last grid point below `ε_max`.

# Throws
`ArgumentError` if `ε_step` is not positive.
"""
function width_robustness(
    system::AbstractQuantumSystem,
    traj::NamedTrajectory;
    thresh::Real=0.999,
    error_ops::AbstractVector=[PAULIS.X, PAULIS.Y, PAULIS.Z],
    ε_max::Real=0.5,
    ε_step::Real=1e-4,
)
    ε_step > 0 || throw(ArgumentError("ε_step must be positive, got $ε_step"))
    drift = system.H.H_drift
    drives = system.H.H_drives

    # Integer multiples of the step avoid accumulated rounding from repeated
    # `ε += ε_step`; `takewhile` keeps the grid strictly below ε_max.
    grid = (k * ε_step for k in Iterators.countfrom(0))
    εs = collect(Iterators.takewhile(<(ε_max), grid))

    widths = Float64[]
    for err in error_ops
        # Largest ε verified to meet `thresh`. Recording it only after a passing
        # check avoids reporting a width past the first failure.
        width = 0.0
        for ε in εs
            noisy_sys = QuantumSystem(drift + ε * err, drives)
            unitary_rollout_fidelity(traj, noisy_sys) >= thresh || break
            width = Float64(ε)
        end
        push!(widths, width)
    end
    return widths
end

## Parameter Sweep and Optimization

In [None]:
# Sweep the value passed as `R_dda` (named "dda_bound" here; presumably a
# penalty weight on control second derivatives — confirm in QuantumCollocation
# docs) over a log grid. Solve `n_seeds` independently seeded problems per value.
dda_bounds = 10 .^ range(-1, 1, length=15)
n_seeds = 5

var_probs = Array{Any}(undef, n_seeds, length(dda_bounds))
fidelities = zeros(n_seeds, length(dda_bounds))
robustness_widths = zeros(n_seeds, length(dda_bounds))

# One robust-time list per variational operator, each at the final knot point T.
# Built from ∂ₑHₐ so it stays in sync if the operator set changes.
robust_times = [[T] for _ in ∂ₑHₐ]

for (i, dda) in enumerate(dda_bounds)
    println("Running for dda_bound = $dda")
    for seed in 1:n_seeds
        # Seed the global RNG so each (seed, dda) solve is reproducible.
        Random.seed!(seed)
        prob = UnitaryVariationalProblem(
            varsys, U_goal, T, Δt;
            piccolo_options=piccolo_opts,
            robust_times=robust_times,
            R_dda=dda,
        )
        solve!(prob)
        var_probs[seed, i] = prob
        # Use the same rollout fidelity that `width_robustness` scans, so this
        # is exactly its ε = 0 (nominal) value.
        fidelities[seed, i] = unitary_rollout_fidelity(prob.trajectory, sys)
        robustness_widths[seed, i] = mean(width_robustness(sys, prob.trajectory))
    end
end

## Heatmap of Average Robustness

In [None]:
# Mean robustness width over seeds for each dda value (a 1 × n_dda row).
avg_robustness = mean(robustness_widths, dims=1)

# Both Plots and (Cairo)Makie export `heatmap!`, which makes the unqualified
# name ambiguous in Main. Qualify every Makie call explicitly.
fig = Makie.Figure(resolution=(800, 600))
ax = Makie.Axis(
    fig[1, 1];
    xlabel="log10(dda_bound)",
    title="Average Robustness Width (mean over seeds)",
    # The y axis is a single dummy row, so its ticks carry no information.
    yticksvisible=false,
    yticklabelsvisible=false,
)
# heatmap! expects a length(x) × length(y) matrix; here length(y) == 1.
hm = Makie.heatmap!(ax, log10.(dda_bounds), 1:1, reshape(vec(avg_robustness), :, 1))
# Pass the plot so the colorbar uses the heatmap's colormap and data limits.
Makie.Colorbar(fig[1, 2], hm, label = "Average Robustness")
display(fig)

## Find Optimal Point and Plot Control Values

In [None]:
# Select the smallest dda value whose solution has nominal fidelity above
# `fidelity_thresh` and a nonzero mean robustness width. Among seeds at that
# value, take the lowest-numbered qualifying seed.
fidelity_thresh = 0.999
feasible = findall((fidelities .> fidelity_thresh) .& (robustness_widths .> 0.0))

# Defaults kept so later cells can still inspect these names when nothing qualifies.
min_dda_with_fidelity_and_robustness = Inf
optimal_seed_index = -1
optimal_dda_index = -1

if isempty(feasible)
    println("No point found that satisfies the criteria.")
else
    # `findall` returns indices column-major (seed varies fastest) and `argmin`
    # returns the first minimiser. Together they pick the lowest seed at the
    # smallest dda, whatever order `dda_bounds` is in.
    best = argmin(I -> dda_bounds[I[2]], feasible)
    optimal_seed_index, optimal_dda_index = Tuple(best)
    min_dda_with_fidelity_and_robustness = dda_bounds[optimal_dda_index]
    println("Optimal dda_bound found: ", min_dda_with_fidelity_and_robustness)
    optimal_prob = var_probs[optimal_seed_index, optimal_dda_index]

    # Plot the controls for the optimal point against the trajectory's own time
    # grid. If the solver treats timesteps as free variables (not visible here),
    # a fixed `0:Δt:(T-1)*Δt` axis would be wrong.
    controls = optimal_prob.trajectory.a
    ts = get_times(optimal_prob.trajectory)
    labels = permutedims(["a$j" for j in axes(controls, 1)])

    # `plot` is exported by both Plots and Makie; qualify to avoid ambiguity.
    p = Plots.plot(
        ts, controls';
        xlabel="Time",
        ylabel="Control Amplitude",
        title="Optimal Control Pulses",
        label=labels,
    )
    display(p)
end