In [None]:
"""Scatter neutrinos using density matrices.

Note: the functions expect states to have complex datatypes.

Created 17 October 2023.
"""
import numpy as np
from numpy import pi, sin, cos, arccos, identity
from scipy.linalg import logm


def random_scatter(initial_states, n, l, omega=(0, 0), theta=0.1, Theta='random', verbose=True):
    """Scatter random pairs of particles `n` times and yield all the states
    along the way. After every scattering, vacuum oscillations are applied to
    each half of the ensemble (when the corresponding `omega` is non-zero).

    This is a generator: it yields the initial states first and then one new
    array per scattering (`n + 1` arrays in total). Every yielded array is a
    fresh copy, so arrays yielded earlier are never modified afterwards.

    :param initial_states: Sequence of `N` single-particle density matrices
     (converted with `np.array`). They must have a complex dtype, otherwise
     the complex results of a scattering cannot be stored back faithfully.
    :param n: Number of scatterings to perform.
    :param l: Forwarded to `independent_scatter` as `l`, and multiplied with
     `omega / M` to obtain the vacuum oscillation phase of one step.
    :param omega: Pair of vacuum oscillation frequencies (in units of the
     coherent flavor conversion oscillation frequency mu). The first entry is
     applied to the first `N // 2` particles, the second to the remainder.
     Each entry is divided by `M = N // 2` before use.
    :param theta: Neutrino mixing angle.
    :param Theta: Relative angle of (the momentum of) interacting neutrinos.
     Either the string ``'random'`` (angles are drawn with `random_theta`), a
     single number used for every scattering, or an array of `n` angles.
    :param verbose: If true (the default, matching the historical behavior),
     print a progress line for every scattering. Pass ``False`` for long runs.
    """
    N = len(initial_states)
    particle1, particle2 = pick_random_pairs(N, n)

    # Only compare against the sentinel when `Theta` really is a string:
    # comparing a numpy array of angles with a string is either an elementwise
    # comparison (ambiguous truth value) or a FutureWarning, depending on the
    # numpy version.
    if isinstance(Theta, str) and Theta == 'random':
        Theta = random_theta(n)
    else:
        Theta = Theta * np.ones(n)

    # Loop invariants: where to split the states to apply the two different
    # vacuum oscillations, and the per-step frequencies of the two halves.
    split_index = N // 2
    omega_per_M = np.array(omega) / split_index

    states = np.array(initial_states)
    yield states

    for i, (p1, p2, Theta_) in enumerate(zip(particle1, particle2, Theta)):
        if verbose:
            print(f"Scatter {i}")

        # Work on a copy so that previously yielded arrays stay untouched.
        states = states.copy()
        states[[p1, p2]] = independent_scatter(*states[[p1, p2]], l=l, Theta=Theta_)

        # Basic slices are views, so assigning through `states_[:]` updates
        # `states` in place.
        for omega_per_M_, states_ in zip(omega_per_M, (states[:split_index], states[split_index:])):
            if omega_per_M_:
                # Apply vacuum oscillations.
                states_[:] = propagate(states_, omega_per_M_ * l, theta)

        yield states


def independent_scatter(rho1, rho2, *args, **kwargs):
    """Scatter two independent neutrinos and return their new, separate
    density matrices.

    The two inputs are combined into a joint state with `combine_rho`,
    evolved with `scatter` (extra positional and keyword arguments are
    forwarded to it), and then reduced back to one density matrix per neutrino
    with `split_rho`.
    """
    joint = scatter(combine_rho(rho1, rho2), *args, **kwargs)
    first, second = split_rho(joint)
    return first, second


def scatter_backgrounds(rho0, rho_background, n, *args, **kwargs):
    """Scatter a neutrino of interest off of background neutrinos `n` times.

    Generator: yields the initial density matrix (a copy of `rho0`) and then
    the density matrix after each of the `n` scatterings. The relative angle
    of every scattering is drawn with `random_theta` and passed on as the
    `Theta` keyword, so `Theta` must not be supplied through `kwargs`.
    """
    angles = random_theta(n)

    current = np.array(rho0)
    yield current

    for angle in angles:
        current = scatter_background(current, rho_background, *args, Theta=angle, **kwargs)
        yield current


def scatter_background(rho, rho_background, *args, **kwargs):
    """Scatter a neutrino of interest off of a single background neutrino.

    `rho` is the (2x2) density matrix of the neutrino of interest and
    `rho_background` the *independent* (2x2) density matrix of the background
    neutrino. The joint state is evolved with `scatter` (extra arguments are
    forwarded to it) and the background neutrino is traced out, so only the
    new density matrix of the neutrino of interest is returned.
    """
    joint = combine_rho(rho, rho_background)
    return trace_out(scatter(joint, *args, **kwargs))


def scatter(rho, Theta=pi/2, l=0.1):
    """Evolve the flavor density matrix of two neutrinos that scatter.

    `rho` is the joint two-neutrino density matrix stored as a (2, 2, 2, 2)
    array (as produced by `combine_rho`). It is transformed as
    ``U rho U^dagger`` and then rescaled to unit trace. The evolution depends
    on `Theta` and `l` only through ``exp(-2j * l * (1 - cos(Theta)))``.
    """
    phase = np.exp(-2j * l * (1 - cos(Theta)))

    # Time evolution matrix, indexed as U[out1, out2, in1, in2].
    U = np.zeros((2, 2, 2, 2), dtype=complex)
    # Equal flavors only pick up the phase.
    for flavor in (0, 1):
        U[flavor, flavor, flavor, flavor] = phase
    # Different flavors either stay as they are or are swapped.
    for a, b in ((0, 1), (1, 0)):
        U[a, b, a, b] = (phase + 1) / 2
        U[a, b, b, a] = (phase - 1) / 2

    evolved = matmul(matmul(U, rho), dagger(U))  # TODO

    # Renormalize to unit trace. This step is required in order to avoid
    # floating point error that quickly blows up (after ~50 scatters).
    norm = np.trace(np.trace(evolved, axis1=0, axis2=2))
    evolved /= norm

    return evolved


def propagate(rho, omega_t, theta):
    """Apply vacuum oscillations to `rho`, i.e. propagate a neutrino in the
    vacuum.

    The density matrix is transformed as ``U rho U^dagger`` with
    ``U = cos(omega_t / 2) * 1 + i sin(omega_t / 2) * A``, where
    ``A = [[cos 2theta, sin 2theta], [sin 2theta, -cos 2theta]]``.
    Because `@` broadcasts, `rho` may also be a stack of (2, 2) matrices.

    TODO (from the original author): Check bugs from trace normalization.
    """
    half_phase = omega_t / 2
    s2t = sin(2*theta)
    c2t = cos(2*theta)
    mixing = np.array([[c2t, s2t], [s2t, -c2t]])

    U = cos(half_phase) * identity(2) + 1j * sin(half_phase) * mixing
    U_dagger = U.conjugate().T

    return U @ rho @ U_dagger


def combine_rho(rho1, rho2):
    """Build the joint density matrix of `rho1` and `rho2` as a tensor
    product.

    For (2, 2) inputs the result is a (2, 2, 2, 2) array with
    ``rho[i, j, k, l] == rho1[i, k] * rho2[j, l]``.
    """
    outer = np.tensordot(rho1, rho2, axes=0)
    # Moving axis 1 to position 2 is the same as swapping those two
    # neighboring axes.
    return outer.swapaxes(1, 2)


def split_rho(rho):
    """Reduce a joint two-neutrino density matrix to one density matrix per
    neutrino by tracing out the other neutrino.

    Returns the tuple ``(rho1, rho2)``: `rho1` is the trace over axes 1 and 3,
    `rho2` the trace over axes 0 and 2.
    """
    joint = np.asanyarray(rho)
    return joint.trace(axis1=1, axis2=3), joint.trace(axis1=0, axis2=2)


def matmul(A, B):
    """Matrix-multiply two (2, 2, 2, 2) arrays as if they were (4, 4)
    matrices: the last two axes of `A` are contracted with the first two axes
    of `B`."""
    return np.einsum('abcd,cdef->abef', A, B)


def dagger(A):
    """Hermitian conjugate of a (2, 2, 2, 2) array, treating it as a (4, 4)
    matrix: the first two axes are exchanged with the last two and every
    entry is complex conjugated."""
    transposed = np.moveaxis(A, source=(0, 1), destination=(2, 3))
    return np.conjugate(transposed)


def trace_out(rho):
    """Trace out the second neutrino of a joint density matrix, i.e. take the
    trace over axes 1 and 3 of `rho`."""
    joint = np.asanyarray(rho)
    return joint.trace(axis1=1, axis2=3)


def flavor_expval(rho):
    """Return the real part of the ``[1, 1]`` entry of the density matrix
    `rho`."""
    lower_diagonal = rho[1, 1]
    return lower_diagonal.real


def entropy(rho):
    """Von Neumann entropy ``-Tr(rho ln rho)`` of a density matrix.

    The entropy is computed from the eigenvalues ``p`` of `rho` as
    ``-sum(p ln p)`` with the convention ``0 ln 0 = 0``. Unlike a matrix
    logarithm, this is well defined for singular density matrices such as pure
    states (which is what the simulation starts from). Eigenvalues that are
    not positive (exact zeros or tiny negative round-off) contribute nothing.

    :param rho: Hermitian matrix of shape (..., d, d); only Hermitian input is
     supported because the eigenvalues are obtained with `eigvalsh`. A stack
     of matrices is allowed.
    :returns: The entropy (natural logarithm) as a float, or an array with one
     entropy per matrix for stacked input.
    """
    probabilities = np.linalg.eigvalsh(rho)

    # Replace non-positive eigenvalues by 1 inside the logarithm so that they
    # contribute p * ln(1) = 0 without triggering log-of-zero warnings.
    safe = np.where(probabilities > 0, probabilities, 1.0)

    # Adding 0.0 turns a possible -0.0 into a plain 0.0.
    return -(probabilities * np.log(safe)).sum(axis=-1) + 0.0


def pick_random_pairs(N, shape):
    """Draw pairs of random, mutually different integers in [0, `N`).

    Returns two integer arrays of the given `shape`; at every position the two
    arrays hold different values.
    """
    # The first member of each pair can be anything.
    first = np.random.randint(0, N, shape)

    # Shifting by a non-zero offset (mod N) guarantees a different value.
    offset = np.random.randint(1, N, shape)
    second = np.mod(first + offset, N)

    return first, second


def random_theta(shape):
    """Return the `theta` angles of spherically uniform random 3D directions.

    For directions distributed uniformly over the sphere, ``cos(theta)`` is
    uniform on [-1, 1), so the angles are obtained as the arccos of uniform
    random numbers.

    :param shape: Output shape, either an int or a tuple of ints.
    :returns: Array of angles in radians, within [0, pi].
    """
    # `random_sample` accepts an int or a tuple as size (`rand` would raise a
    # TypeError for a tuple) and consumes the global random stream exactly
    # like `rand` does.
    cos_theta = 2 * np.random.random_sample(shape) - 1
    return arccos(cos_theta)


In [None]:
# --- Simulation parameters ---
seed = 0  # Seed of numpy's global random generator, for reproducible runs.
M = 50  # Number of neutrinos per initial flavor (2 * M particles in total).
n = 100000  # Number of scatterings.
l = 0.05
omega = np.array([0.1, 0.2])  # Vacuum oscillation frequency of each half.
# omega = (0, 0)
Theta = 'random'
# Theta = pi/2
theta = 0.1  # Neutrino mixing angle.

# The scattering partners and angles are drawn from numpy's global random
# state; seed it so that "Restart & Run All" reproduces the same figures.
np.random.seed(seed)

# First M particles start as [[1, 0], [0, 0]], the other M as [[0, 0], [0, 1]].
initial_states = np.array(M * [[[1, 0], [0, 0]]] + M * [[[0, 0], [0, 1]]], dtype=complex)

# `states` has shape (n + 1, 2 * M, 2, 2): one stack of density matrices per
# step. With the values above that is roughly 0.64 GB of complex128 data.
states = np.array(list(random_scatter(initial_states, n, l=l, omega=omega, Theta=Theta, theta=theta)))

In [None]:
# Display the simulated history: one stack of 2 * M density matrices (2x2
# each) for the initial state and after every scattering.
states

In [None]:
import matplotlib.pyplot as plt

# `states` is stored with a complex dtype. Take the real part explicitly
# instead of letting matplotlib cast complex values to real (which discards
# the imaginary part with a ComplexWarning).
# One line per particle: its [0, 0] density matrix entry, shifted by 1/2.
plt.plot(states[..., :, 0, 0].real - 0.5, alpha=0.5)
# plt.plot(states[..., :, 0, 0].real.mean(axis=1) - 0.5, c='black', lw=2)
# plt.plot(states[..., :M, 0, 0].real.mean(axis=1) - 0.5, c='black', lw=2)
# plt.plot(states[..., M:, 0, 0].real.mean(axis=1) - 0.5, c='black', lw=2)

# x = np.arange(len(states))
# plt.plot(0.5*np.exp(-(4/3)*l**2*x / M), c='red', lw=1, ls='--')
# plt.plot(-0.5*np.exp(-(4/3)*l**2*x / M), c='red', lw=1, ls='--')

plt.xlabel('Number of scatterings')
plt.ylabel(r'$\rho_{00} - 1/2$')
plt.title('Individual particles')

None

# Analytic approach

In [None]:
from scipy.integrate import solve_ivp

# Unit vector set by the mixing angle, scaled by the two vacuum frequencies.
omega_hat = np.array([sin(2*theta), 0, -cos(2*theta)])
omega_e, omega_x = omega[0] * omega_hat, omega[1] * omega_hat


def _polarization_rate(p, b, omega_vec):
    """Rate of change of the 3-vector `p`, given the mean vector `b` and the
    vacuum vector `omega_vec` (uses the notebook-level `l`)."""
    p_cross_b = np.cross(p, b)
    return - l * (p_cross_b + np.cross(p, omega_vec) + (4/3) * l * ((p - b) + np.cross(p_cross_b, b) / 2 - np.cross(p_cross_b, omega_vec)))


def integrand(i, y):
    """Right-hand side of the ODE system passed to `solve_ivp`.

    `i` is the independent variable (unused: the system is autonomous) and `y`
    packs the two 3-vectors `e` and `x` into one flat array of length 6. Both
    vectors evolve by the same rule, each with its own vacuum vector, coupled
    through their mean `b`.
    """
    e_vec, x_vec = y.reshape(2, 3)
    mean_vec = (e_vec + x_vec) / 2

    return np.concatenate([
        _polarization_rate(e_vec, mean_vec, omega_e),
        _polarization_rate(x_vec, mean_vec, omega_x),
    ])


# Start with `e` along +z and `x` along -z.
y0 = np.array([0, 0, 1, 0, 0, -1])

sol = solve_ivp(integrand, (0, n/M), y0)

# Unpack the solution into two arrays of shape (3, number of time points).
e, x = sol.y.reshape(2, 3, -1)

In [None]:
# z components of the ODE solution.
plt.plot(sol.t, e[2, :], label='e (ODE)')
plt.plot(sol.t, x[2, :], label='x (ODE)')
plt.plot(sol.t, (e[2, :] + x[2, :])/2, label='mean (ODE)')

# Simulation averages. `states` is complex, so take the real part explicitly
# rather than passing complex values to matplotlib. For a unit-trace density
# matrix, 2 * (rho[0, 0] - 1/2) is the z component of its polarization vector.
u = np.arange(len(states)) / M
plt.plot(u, 2 * (states[..., :M, 0, 0].real.mean(axis=1) - 0.5), c='black', lw=0.5, label='simulation')
plt.plot(u, 2 * (states[..., M:, 0, 0].real.mean(axis=1) - 0.5), c='black', lw=0.5)
plt.plot(u, 2 * (states[..., :, 0, 0].real.mean(axis=1) - 0.5), c='black', lw=0.5)

plt.xlabel('Number of scatterings / M')
plt.ylabel('z component')
plt.legend()

# plt.xlim(0, 25)
# plt.ylim(-0.2, 0.2)

None

# Discrete analytic approach (Theta=pi/2, not random)

In [None]:
from scipy.spatial.transform import Rotation

# Rotations for the (commented-out) exact vacuum-rotation alternative below.
r_e = Rotation.from_rotvec(-omega_e * l)
r_x = Rotation.from_rotvec(-omega_x * l)

print(omega_x)


def discrete_evolution(n_steps, l, omega_e, omega_x):
    """Iterate the discrete update rule for the two 3-vectors `e` and `x`.

    The loop lives in a function so that its working variables do not
    overwrite the notebook-level `e` and `x` (the ODE solution), which other
    cells plot.

    :param n_steps: Number of update steps.
    :param l: Step parameter of the update rule.
    :param omega_e: Vacuum vector applied to `e`.
    :param omega_x: Vacuum vector applied to `x`.
    :returns: Array of shape (n_steps + 1, 2, 3); entry ``[k, 0]`` is `e` and
     ``[k, 1]`` is `x` after `k` steps. Step 0 is e = +z, x = -z.
    """
    history = [np.array([[0, 0, 1], [0, 0, -1]])]

    for _ in range(n_steps):
        e, x = history[-1]
        b = (e + x) / 2

        e_cross_b = np.cross(e, b)
        # Alternative: e - l * (e_cross_b + l * (e + np.cross(e_cross_b, b) / 2))
        e = e - l * (e_cross_b + l * (e - b))

        # Note: `b` is deliberately not recomputed between the two updates.
        x_cross_b = np.cross(x, b)
        # Alternative: x - l * (x_cross_b + l * (x + np.cross(x_cross_b, b) / 2))
        x = x - l * (x_cross_b + l * (x - b))

        # First-order vacuum term. Exact-rotation alternative:
        # e = r_e.apply(e); x = r_x.apply(x)
        e = e - l * np.cross(e, omega_e)
        x = x - l * np.cross(x, omega_x)

        history.append(np.vstack([e, x]))

    return np.stack(history)


y = discrete_evolution(n // M, l, omega_e, omega_x)

In [None]:
import matplotlib.pyplot as plt

# z components of the discrete iteration.
plt.plot(y[:, 0, 2], label='e (discrete)')
plt.plot(y[:, 1, 2], label='x (discrete)')
plt.plot(y[..., 2].mean(axis=1), label='mean (discrete)')

u = np.arange(len(y))
plt.plot(u, np.exp(-l**2 * u), c='red', lw=1, ls='--', label=r'$e^{-l^2 u}$')

# Simulation averages. `states` is complex, so take the real part explicitly
# rather than passing complex values to matplotlib.
u = np.arange(len(states)) / M
plt.plot(u, 2 * (states[..., :M, 0, 0].real.mean(axis=1) - 0.5), c='black', lw=0.5, label='simulation')
plt.plot(u, 2 * (states[..., M:, 0, 0].real.mean(axis=1) - 0.5), c='black', lw=0.5)
plt.plot(u, 2 * (states[..., :, 0, 0].real.mean(axis=1) - 0.5), c='black', lw=0.5)

plt.xlabel('Number of scatterings / M')
plt.ylabel('z component')
plt.legend()

None