In [1]:
# Directory that is added to LOAD_PATH so `using` can find the local modules
# loaded below (presumably Utils, Solve and Analysis live there — confirm).
module_path = "../"
# Membership test instead of `findfirst(...) == nothing`: same effect, clearer,
# and avoids comparing against `nothing` with `==`. Re-running the cell does
# not add a duplicate entry.
if module_path ∉ LOAD_PATH
    push!(LOAD_PATH, module_path)
end
using Revise

In [None]:
using Utils
using Solve
using Analysis
using Plots
using DSP
using Optim
using FFTW
using Statistics

In [None]:
# Settings shared by the simulation cells below.
B0 = 0.05                       # B0 field strength, in Gauss
initial_phases = (π / 2, π / 2) # starting phases passed to run_simulations
t_samp = 1e-3                   # spacing of the saved samples (the saveat step)

"""
    phase_diff(sol)

Return the phase difference (neutron minus helium-3), in radians, at every saved
time point of `sol`, unwrapped so that it is continuous in time.

The neutron components are taken from `sol.u[1:3, :, 1]` and the helium-3
components from `sol.u[4:6, :, 1]`; `planephase` computes the angle in the x-y
plane for each. Only ensemble member 1 is used, because the solutions produced
in this notebook contain a single run.
"""
function phase_diff(sol)
    pn = planephase(sol.u[1:3, :, 1])
    p3 = planephase(sol.u[4:6, :, 1])
    d = mod2pi.(pn .- p3)  # wrapped into [0, 2π)

    isempty(d) && return d

    # Unwrap: a jump of more than π between consecutive samples is treated as a
    # crossing of the 0/2π boundary, so this and every later sample are shifted
    # by ∓2π. The shift is accumulated in `offset` and applied once per sample,
    # which keeps this O(n) instead of rewriting d[i:end] on every jump.
    offset = zero(eltype(d))
    prev = first(d)  # previous *wrapped* sample
    for i in Iterators.drop(eachindex(d), 1)
        raw = d[i]
        delta = raw - prev
        if delta > pi
            offset -= 2 * pi
        elseif delta < -pi
            offset += 2 * pi
        end
        prev = raw
        d[i] = raw + offset
    end
    return d
end

"""
    plot_phase_diff(sol)

Plot `phase_diff(sol)` against the solution's time points `sol.t`.
"""
plot_phase_diff(sol) = plot(sol.t, phase_diff(sol);
                            xlabel="Time [s]",
                            ylabel="(neutron - he3) phase difference [rad]")

In [None]:
"""
    solve_f_crit(B1, sim_time)

Run one simulation lasting `sim_time` with the given constant `B1`, using the
notebook-level `B0` and `initial_phases`, and save a sample every `t_samp`.
Returns the solution object from `run_simulations`.
"""
function solve_f_crit(B1, sim_time)
    sample_times = 0:t_samp:sim_time
    return run_simulations(sim_time, 1;
                           B0=B0,
                           B1=B1,
                           initial_phases=initial_phases,
                           saveat=sample_times)
end

"""
    plot_f_crit(B1, sim_time)

Simulate with `solve_f_crit(B1, sim_time)` and plot the resulting phase difference.
"""
plot_f_crit(B1, sim_time) = plot_phase_diff(solve_f_crit(B1, sim_time))

In [None]:
# Try a few different values of B1: edit the first argument and re-run this cell.
# The second argument is the simulated time (10 s here).
plot_f_crit(0.4, 10)

In [None]:
"""
    objective(sol, target, weight)

Objective to minimize: the weighted sum of squared residuals between
`phase_diff(sol)` and `target`, i.e. `sum(weight .* (phase_diff(sol) .- target).^2)`.
`target` and `weight` are broadcast element-wise against the phase difference.
"""
function objective(sol, target, weight)
    residual = phase_diff(sol) .- target
    return sum(abs2.(residual) .* weight)
end
"""
    min_f_crit(B1, sim_time)

Objective used to pick B1: simulate with `solve_f_crit` and score the phase
difference against a target of zero at every saved sample, with unit weights.
"""
function min_f_crit(B1, sim_time)
    sol = solve_f_crit(B1, sim_time)
    nsamples = length(sol.t)
    return objective(sol, zeros(nsamples), ones(nsamples))
end

In [None]:
# Search for the B1 minimizing `min_f_crit` (10 s simulations), restricted to
# the interval [0.4, 0.41].
opt_res = optimize(B1 -> min_f_crit(B1, 10), 0.4, 0.41)

In [None]:
# B1 value at which the bounded search in `opt_res` found the smallest
# objective (sum of squared phase differences from zero).
B1_optimum = Optim.minimizer(opt_res)

In [None]:
# With B1 at the optimum, the phase difference should stay near zero (the
# target used by `min_f_crit`), so this should be a flat-ish line.
plot_f_crit(B1_optimum, 10)

In [None]:
"""
    modulation(t, f_mod, A, B, x, y)

Pulsed modulation envelope with Gaussian-shaped pulses: 1 minus a dip of
height `A` centred at `period/4`, plus a bump of height `B` centred at
`3*period/4`, where `period = 1/f_mod`. `x` and `y` are the Gaussian standard
deviations (widths) of the `A` and `B` pulses. `t` is not folded into one period
here; callers are expected to pass a time within `[0, period)`.
"""
function modulation(t, f_mod, A, B, x, y)
    T = inv(f_mod)
    dip = A * exp(-(t - T / 4)^2 / (2 * x^2))
    bump = B * exp(-(t - 3 * T / 4)^2 / (2 * y^2))
    return 1 - dip + bump
end

"""
    B1_templator(B1, w, f_mod, A, B, x, y)

Return a function of time `t` giving a cosine of amplitude `B1` and angular
frequency `w` (`B1*cos(w*t)`), multiplied by the pulsed `modulation` envelope.
The envelope is evaluated at `t` folded into one modulation period
(`t % (1/f_mod)`), so the pulse pattern repeats every `1/f_mod`; this folding
assumes `t >= 0`.

`x` and `y` are forwarded to `modulation` as the widths of the `A` and `B`
pulses, so callers' pulse-duration choices actually take effect.
"""
function B1_templator(B1, w, f_mod, A, B, x, y)
    period = 1 / f_mod
    return t -> B1 * cos(w * t) * modulation(t % period, f_mod, A, B, x, y)
end

In [None]:
# Preview one modulation cycle: with f_mod = 1 the range 0..1 s covers exactly
# one period. Pulse sizes A = B = 0.2, pulse widths 0.02.
# NOTE(review): `crit_params` is not defined in this notebook; presumably it is
# exported by one of the `using`'d modules — confirm.
tmod = 0:1e-4:1
func_mod = B1_templator(B1_optimum, crit_params["w"], 1, 0.2, 0.2, 0.02, 0.02)
plot(tmod, func_mod.(tmod); xlabel="Time [s]", ylabel="Bx [Gauss]")

In [None]:
"""
    solve_f_mod(A, B, f_mod, pulse_duration, sim_time)

Simulate `sim_time` with the constant B1 switched off (`B1=0`); the x-field is
instead the pulsed waveform from `B1_templator`, built from `B1_optimum` and
`crit_params["w"]` with pulse sizes `A` and `B`. `pulse_duration` is passed as
both pulse-width arguments. `f_mod` is the modulation frequency: each period
holds two pulses, so e.g. `f_mod = 2` gives 4 pulses every second. The spins
start π/3 radians apart.
"""
function solve_f_mod(A, B, f_mod, pulse_duration, sim_time)
    field = B1_templator(B1_optimum, crit_params["w"], f_mod, A, B,
                         pulse_duration, pulse_duration)
    # Presumably run_simulations draws one Bx function per run from this
    # iterator; the same waveform is repeated for every run.
    return run_simulations(sim_time, 1;
                           B0=B0,
                           B1=0,
                           initial_phases=(pi / 6, -pi / 6),
                           Bxfunc=Iterators.repeated(field),
                           saveat=0:t_samp:sim_time)
end

"""
    plot_f_mod(A, B, sim_time; f_mod=1, pulse_duration=0.02)

Simulate the pulsed scheme with `solve_f_mod` and plot the resulting phase difference.
"""
plot_f_mod(A, B, sim_time; f_mod=1, pulse_duration=0.02) =
    plot_phase_diff(solve_f_mod(A, B, f_mod, pulse_duration, sim_time))

In [None]:
# Try out some values of the pulse sizes A and B; negative values are OK too.
# For this exercise we would ideally like the phase difference to alternate
# between +pi/3 and -pi/3 (pi/3 is about 1.05 rad).
plot_f_mod(0.2, 0.5, 10)

In [None]:
"""
    min_f_mod(A, B, sim_time; f_mod=1, pulse_duration=0.02)

Objective for the pulsed scheme. It simulates with `solve_f_mod` and scores the
phase difference against an ideal square wave of frequency `f_mod` and
amplitude π/3: `+π/3` where `cos(2π f_mod t) > 0` and `-π/3` where it is
negative. Samples within `2*pulse_duration` of either pulse centre
(`period/4` or `3*period/4`) get zero weight, since the angle during a pulse
is not constrained.
"""
function min_f_mod(A, B, sim_time; f_mod=1, pulse_duration=0.02)
    sol = solve_f_mod(A, B, f_mod, pulse_duration, sim_time)
    period = 1 / f_mod
    t_in_period = sol.t .% period
    guard = 2 * pulse_duration

    target = sign.(cos.((2 * pi * f_mod) .* sol.t)) .* pi / 3

    clear_of_late_pulse = abs.(t_in_period .- 3 * period / 4) .> guard
    clear_of_early_pulse = abs.(t_in_period .- period / 4) .> guard

    return objective(sol, target, clear_of_late_pulse .& clear_of_early_pulse)
end

In [None]:
# Optimize the pulse sizes [A, B], starting from x0 and using 10 s simulations.
# The optimizer is given a time limit of 300 s.
x0 = [0.2, 0.2]
results = optimize(p -> min_f_mod(p[1], p[2], 10), x0,
                   Optim.Options(time_limit=300.0))

In [None]:
# Best pulse sizes found by the optimizer; the minimizer is the vector [A, B].
A_optimum, B_optimum = Optim.minimizer(results)

In [None]:
# Phase difference with the optimized pulse sizes; compare it against the
# +/- pi/3 square-wave target used in `min_f_mod`.
plot_f_mod(A_optimum, B_optimum, 10)