# Load in some saved data and plot it.

In [None]:
import numpy as np

# Run parameters. Together they also select which saved pickle is loaded.
M = 50                  # size of each particle group: data[:, :M] vs data[:, M:]; the plot label shows N = 2*M
sample_period = 1000    # scatters between consecutive saved snapshots

fac = 1000              # refinement factor: shrinks l and grows n in step
l = 1 / fac
n = 5000 * fac * 2      # total number of scatters (x-axis extent of the plot)

theta = 0.1
omega = np.array([0.1, 0.2])

In [None]:
import pickle

# Every run parameter is encoded in the file name, so changing a parameter
# above selects a different saved run.
filename = f'../approach2/M-{M}-l-{l}-n-{n}-P-{sample_period}.pickle'

# NOTE: unpickling can execute arbitrary code; only load files produced by
# our own simulation runs.
with open(filename, mode='rb') as file:
    data = pickle.load(file)

In [None]:
from numpy import sin, cos
from scipy.integrate import solve_ivp

# Unit vector at angle 2*theta in the plane of the first and third components.
two_theta = 2 * theta
omega_hat = np.array([sin(two_theta), 0, -cos(two_theta)])

# One vector per group: the shared direction scaled by omega[0] (e) and omega[1] (x).
omega_e, omega_x = (scale * omega_hat for scale in omega)

def integrand(i, y):
    """Right-hand side of the coupled e/x vector ODE, in solve_ivp's ``fun(t, y)`` form.

    Parameters
    ----------
    i : float
        Integration time, passed first by solve_ivp. Unused: the rate depends
        only on the state ``y``.
    y : ndarray, shape (6,)
        State vector: the 3-vector ``e`` followed by the 3-vector ``x``.

    Returns
    -------
    ndarray, shape (6,)
        ``de/dt`` followed by ``dx/dt``.

    Reads the module-level ``l``, ``omega_e`` and ``omega_x``.
    """
    e, x = y.reshape(2, 3)
    b = (e + x) / 2  # average of the two vectors; both rates couple to it
    random_angle_factor = 4 / 3

    def _rate(v, omega_v):
        # Identical formula for e and x; only the vector and its omega differ.
        order_l = np.cross(v, b + omega_v)
        order_l2 = (v - b) + np.cross(np.cross(v, b), b / 2 - omega_v / random_angle_factor)
        return - l * (order_l + random_angle_factor * l * order_l2)

    return np.concatenate([_rate(e, omega_e), _rate(x, omega_x)])

# Initial condition: e along +z, x along -z, stacked into one 6-vector.
e0 = [0, 0, 1]
x0 = [0, 0, -1]
y0 = np.concatenate([e0, x0])

# Integrate up to n / M, sampling at every integer time. The plot rescales
# sol.t by M so these curves share the simulation's scatter axis.
t_final = n / M
u = np.arange(t_final)
sol = solve_ivp(integrand, (0, t_final), y0, t_eval=u)

# sol.y has one row per state component: rows 0-2 are e, rows 3-5 are x.
e, x = sol.y.reshape(2, 3, -1)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# plt.style.use('bmh')

fig, ax = plt.subplots(figsize=(5, 3), dpi=200)

# Each group's mean curve and its +/-1 std band share one color.
mean_e_color = 'C0'
std_e_color = 'C0'
mean_x_color = 'C1'
std_x_color = 'C1'

analytic_color = 'black'
mean_flavor_color = 'C2'
# Explicit color for the |e| diagnostic. Left implicit, it took 'C0' from the
# color cycle: the lines above pass c=..., which does not advance the cycle.
# That made it collide with mean_e.
norm_e_color = 'C3'

std_alpha = 0.4

# Scatter count at each saved snapshot.
u2 = np.arange(len(data)) * sample_period

# 2*|data[..., 0, 0]| - 1 per particle; the first M particles form the e group,
# the rest the x group. (Assumes data is (snapshots, particles, 2, 2) — confirm.)
flavor_e = 2 * np.abs(data[:, :M, 0, 0]) - 1
flavor_x = 2 * np.abs(data[:, M:, 0, 0]) - 1
mean_e = flavor_e.mean(axis=1)
mean_x = flavor_x.mean(axis=1)
mean = (mean_e + mean_x) / 2

# TODO: purity, e.g. purity_e = np.abs(data[:, :M, 0, 1])

std_e = flavor_e.std(axis=1)
std_x = flavor_x.std(axis=1)

ax.plot(u2, mean_e, lw=2, c=mean_e_color)
ax.plot(u2, mean_x, lw=2, c=mean_x_color)
ax.plot(u2, mean, lw=2, c=mean_flavor_color)

ax.fill_between(u2, mean_e - std_e, mean_e + std_e, alpha=std_alpha, facecolor=std_e_color)
ax.fill_between(u2, mean_x - std_x, mean_x + std_x, alpha=std_alpha, facecolor=std_x_color)

# ODE solution: its time unit is rescaled by M to match the scatter axis.
t_analytic = M * sol.t
ax.plot(t_analytic, e[2, :], c=analytic_color, ls='-', lw=0.75)
ax.plot(t_analytic, x[2, :], c=analytic_color, ls='-', lw=0.75)
ax.plot(t_analytic, (e[2, :] + x[2, :])/2, c=analytic_color, ls='-', lw=0.75)

# Diagnostic: Euclidean length of the analytic e vector over time.
ax.plot(t_analytic, np.linalg.norm(e, axis=0), c=norm_e_color, lw=3)

ax.set_xlabel("$N_{sc}$ (total number of scatters)")
ax.set_ylabel(r"$a_z$ (flavor expectation value)")

text = rf"$\ell = {l}$, $\omega = ({omega[0]}, {omega[1]})$, $\theta = {theta}$, $N = {2*M}$"
ax.text(0.98, 0.98, text, transform=ax.transAxes, ha='right', va='top')

ax.set_xlim(0, n)
ax.set_ylim(-1, 1)
ax.set_yticks([-1, -0.5, 0, 0.5, 1])

# ax.set_xlim(0, 3000)
# ax.set_ylim(-0.1, 0.1)

In [None]:
# Raw per-particle trajectories of 2*|data[..., 0, 0]| - 1.
# np.abs matches the flavor_e / flavor_x computation in the main plot. Without
# it, complex-valued data would be silently cast to real by matplotlib
# (ComplexWarning) rather than plotting the same quantity.
fig, ax = plt.subplots(dpi=200)
ax.plot(u2, 2 * np.abs(data[..., 0, 0]) - 1)
ax.axis('off')
None  # last expression of the cell: keeps the notebook from echoing ax.axis(...)'s return value