# Load in some saved data and plot it.

In [None]:
import numpy as np

# Run identifier used in the saved-data filename.
num_neutrinos = 'Ne-60-Nx-40'  # alternative run: 'M-50'
# Neutrinos [0, split_index) along axis 1 of the saved data form the e group;
# the remaining neutrinos form the x group.
split_index = 60

# `fac` sets both the per-scatter scale `l` and the total number of scatters `n`
# (both also appear in the saved-data filename).
fac = 100
l, n = 1 / fac, 5000 * fac // 1

# omega[0] / omega[1] scale the omega term for the e / x groups; theta sets its
# direction in Bloch space.
omega, theta = np.array([0.1, 0.2]), 0.1

version = 1

# For graphing only.
rotate_Bloch_vectors = True

# Scatters between consecutive saved snapshots (also part of the filename).
sample_period = fac

In [None]:
import pickle
import os

# Earlier data folders, kept for switching between runs:
# folder = '../approach2/data/2024-1-8/before/'
# folder = '../approach2/data/2024-1-1-A/'
# folder = '../approach2/data/2024-1-1-B/'
# folder = '../approach2/data/2023-11-20/'
# folder = '../approach2/data-Martin/'
folder = '../approach2/'
filename = f'{num_neutrinos}-l-{l}-n-{n}-P-{sample_period}-v-{version}.pickle'
data_path = os.path.join(folder, filename)

# NOTE: unpickling can execute arbitrary code; only load files produced by our
# own simulation runs.
with open(data_path, 'rb') as fh:
    data = pickle.load(fh)

In [None]:
# Number of neutrinos in the run (axis 1 of `data`).
N = data.shape[1]
# Collisions a single neutrino undergoes per total collision (presumably each
# collision involves a pair of neutrinos, hence the 2 -- confirm).
n_single = 2 / N

# Population fractions: e group (first `split_index` neutrinos) and x group.
f, g = split_index / N, 1 - split_index / N

print(N, n_single, f, g)

In [None]:
# Pauli matrices sigma_x, sigma_y, sigma_z stacked along axis 0, shape (3, 2, 2).
Pauli = np.array([
    [[0, 1], [1, 0]],
    [[0, -1j], [1j, 0]],
    [[1, 0], [0, -1]]
])

def bloch_vector(rho):
    """Find the Bloch vectors of 2x2 density matrices.

    Component k is Tr(rho @ sigma_k) for the Pauli matrices in `Pauli`, so for
    rho = (I + a . sigma) / 2 this returns a.

    Args:
        rho: Array of shape (..., 2, 2); leading batch dimensions are kept.

    Returns:
        Real array of shape (..., 3).
    """
    # Broadcast each matrix against all three Pauli matrices: (..., 3, 2, 2).
    products = rho[..., np.newaxis, :, :] @ Pauli
    return np.trace(products, axis1=-1, axis2=-2).real

In [None]:
from numpy import sin, cos, pi
from scipy.integrate import solve_ivp
from scipy.spatial.transform import Rotation

# Direction of the omega term in Bloch space, set by the angle theta.
# TODO: 0 factor here
two_theta = 2 * theta
omega_hat = np.array([sin(two_theta), 0, -cos(two_theta)])
# Per-group omega vectors: omega[0] scales the e group, omega[1] the x group.
omega_e, omega_x = omega[0] * omega_hat, omega[1] * omega_hat

def integrand(i, y):
    """Right-hand side of the analytic two-group ODE, in `solve_ivp`'s fun(t, y) form.

    Args:
        i: The integration (time) argument supplied by `solve_ivp`. It is not
            used in the body, so the system is autonomous.
        y: Length-6 state; `y[:3]` is the e-group Bloch vector `e` and `y[3:]`
            is the x-group Bloch vector `x`.

    Returns:
        Length-6 array `[de, dx]` of derivatives, in the same layout as `y`.

    Reads the module-level globals `f`, `g` (population fractions), `l`
    (per-scatter scale), `omega_e` and `omega_x`.
    """
    e, x = y.reshape(2, 3)
    
    # Population-weighted mean Bloch vector.
    b = f*e + g*x

    # TODO
    # random_angle_factor = 1
    # Scales the second-order (l**2) terms; presumably an average over random
    # scattering angles -- TODO confirm the 4/3 value.
    random_angle_factor = 4/3

    # Each equation has a first-order part (precession about b + omega) and a
    # second-order part scaled by random_angle_factor * l: a relaxation toward b
    # plus a double-cross-product term. The `(0)*` factor deliberately switches
    # off the omega contribution inside the second-order part.
    # TODO: 0 factors on omega 2nd order terms, and `b = x` & `e = x`
    # b = x
    de = - l * (np.cross(e, b + omega_e) + random_angle_factor * l * ((e - b) + np.cross(np.cross(e, b), b / 2 - (0)*omega_e / random_angle_factor)))
    # b = e
    dx = - l * (np.cross(x, b + omega_x) + random_angle_factor * l * ((x - b) + np.cross(np.cross(x, b), b / 2 - (0)*omega_x / random_angle_factor)))

    # Alternative/earlier formulations below are kept for reference only;
    # none of them are active.
    # de = -l * np.cross(e, x + omega_e) - l**2 * (e - x + np.cross(np.cross(e, x), x) / 2)
    # dx = -l * np.cross(x, e + omega_x) - l**2 * (x - e + np.cross(np.cross(x, e), e) / 2)

    # print(l * np.cross(e, b) / 2)
    # de += random_angle_factor * l**2 * np.cross(e, np.cross(e, b)) / 2
    # dx += random_angle_factor * l**2 * np.cross(x, np.cross(x, b)) / 2

    # print(l * np.cross(e, x) / 2)
    # de += l**2 * np.cross(e, np.cross(e, x)) / 2
    # dx += l**2 * np.cross(x, np.cross(x, e)) / 2

    # print(l * np.cross(e, de + dx) / 4)
    # de += -l * np.cross(e, de + dx) / 4
    # dx += -l * np.cross(x, de + dx) / 4

    # de += l**2 * dx / 2
    # dx += l**2 * de / 2

    # Pack back into the flat layout solve_ivp expects.
    dy = np.concatenate([de, dx])
    return dy

# Initial Bloch vectors: e group along +z, x group along -z.
y0 = np.concatenate([[0, 0, 1], [0, 0, -1]])

# Integrate over n * n_single per-neutrino collisions (n total scatters) and
# sample at integer times.
u = np.arange(n * n_single)
sol = solve_ivp(integrand, (0, n * n_single), y0, t_eval=u)
if not sol.success:
    # solve_ivp signals failure via `success`/`message` rather than raising;
    # without this check a truncated solution would be plotted silently.
    raise RuntimeError(f"solve_ivp failed: {sol.message}")

e, x = sol.y.reshape(2, 3, -1)  # Analytic solutions, each of shape (3, len(sol.t)).

# rot: rotation by +2*theta about y; it maps omega_hat onto (0, 0, -1).
# rot2: the opposite rotation (-2*theta), applied to the simulation data when plotting.
rot = Rotation.from_rotvec([0, 2*theta, 0]).as_matrix()
rot2 = Rotation.from_rotvec([0, -2*theta, 0]).as_matrix()
if rotate_Bloch_vectors:
    # Rotate every column (time sample) of the (3, T) arrays.
    e = rot @ e
    x = rot @ x

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# plt.style.use('bmh')

figsize = None
fig, ax = plt.subplots(figsize=figsize, dpi=200)

# Colors for the simulation means/std bands per group, the analytic curves,
# and the population-weighted mean.
mean_e_color = 'C0'
std_e_color = 'C0'
mean_x_color = 'C1'
std_x_color = 'C1'

analytic_color = 'black'
mean_flavor_color = 'C2'

std_alpha = 0.4

# Snapshot times in total number of scatters.
u2 = np.arange(len(data)) * sample_period

# Bloch vectors per (snapshot, neutrino); the first `split_index` neutrinos
# along axis 1 form the e group, the rest the x group.
bloch_e = bloch_vector(data[:, :split_index])
bloch_x = bloch_vector(data[:, split_index:])

if rotate_Bloch_vectors:
    # Apply rot2 to every Bloch 3-vector (v -> rot2 @ v).
    # NOTE(review): the analytic solution is rotated by +2*theta (`rot`) while the
    # simulation is rotated by -2*theta (`rot2`); confirm this sign difference
    # matches a real convention difference in the simulation data.
    bloch_e = bloch_e @ rot2.T
    bloch_x = bloch_x @ rot2.T

# Flavor expectation value a_z of each simulated neutrino.
flavor_e = bloch_e[..., 2]  # Simulations.
flavor_x = bloch_x[..., 2]
mean_e = flavor_e.mean(axis=1)
mean_x = flavor_x.mean(axis=1)
mean = f * mean_e + g * mean_x

std_e = flavor_e.std(axis=1)
std_x = flavor_x.std(axis=1)

ax.plot(u2, mean_e, lw=2, c=mean_e_color)
ax.plot(u2, mean_x, lw=2, c=mean_x_color)
ax.plot(u2, mean, lw=2, c=mean_flavor_color)

# Diagnostic: length of the mean e-group Bloch vector from the simulation.
ax.plot(u2, np.linalg.norm(bloch_e.mean(axis=1), axis=-1), c="red")

ax.fill_between(u2, mean_e - std_e, mean_e + std_e, alpha=std_alpha, facecolor=std_e_color)
ax.fill_between(u2, mean_x - std_x, mean_x + std_x, alpha=std_alpha, facecolor=std_x_color)

# Analytic curves; dividing by n_single converts solver time to total scatters.
ax.plot(sol.t / n_single, e[2, :], c=analytic_color, ls='-', lw=0.75)
ax.plot(sol.t / n_single, x[2, :], c=analytic_color, ls='-', lw=0.75)
ax.plot(sol.t / n_single, f * e[2, :] + g * x[2, :], c=analytic_color, ls='-', lw=0.75)

# Diagnostic: length of the analytic e-group Bloch vector.
ax.plot(sol.t / n_single, np.linalg.norm(e, axis=0), c="blue")

# Purity
# ax.plot(sol.t / n_single, (e[0]**2 + e[1]**2 + e[2]**2)**(1/2), lw=3)

ax.set_xlabel("$N_{sc}$ (total number of scatters)")
ax.set_ylabel(r"$a_z$ (flavor expectation value)")

text = rf"$\ell = {l}$, $\omega = ({omega[0]}, {omega[1]})$, $\theta = {theta}$, $N = {N}$"
ax.text(0.98, 0.98, text, transform=ax.transAxes, ha='right', va='top')

# Subtitle
# subtext = "w/ extra term"
# ax.text(0.98, 0.93, subtext, transform=ax.transAxes, ha='right', va='top')

ax.set_xlim(0, n)
ax.set_ylim(-1, 1)
ax.set_yticks([-1, -0.5, 0, 0.5, 1])

# ax.set_xlim(0, 200000)
# ax.set_ylim(mean[0] - 0.0000000001, mean[0] + 0.0000000001)
# ax.set_ylim(mean[0] - 0.0001, mean[0] + 0.0001)

In [None]:
# Raw per-neutrino flavor trajectories (e group, then x group), axes hidden.
figsize = None
fig, ax = plt.subplots(figsize=figsize, dpi=300)
for trajectories in (flavor_e, flavor_x):
    ax.plot(u2, trajectories)
# ax.set_xlim(0, 50000)
ax.axis('off')
None  # keep the notebook cell from echoing the last expression's value