In [None]:
"""Scatter neutrinos using density matrices.

Created 17 October 2023.
"""
import numpy as np
from numpy import pi, cos, arccos
from scipy.linalg import logm


def random_scatter(initial_states, n, *args, theta=None, **kwargs):
    """Scatter random pairs of particles `n` times and yield every state.

    Each event picks two distinct particles uniformly at random (via
    `pick_random_pairs`) and replaces their density matrices with the
    result of `independent_scatter`.

    Parameters
    ----------
    initial_states : array_like
        Stack of single-particle density matrices; the first axis indexes
        particles.
    n : int
        Number of scattering events.
    theta : float or None, optional
        Relative angle used for every event. If None (default), a fresh
        spherically uniform random angle is drawn per event with
        `random_theta`. Pass ``theta=pi`` for a fixed relative angle of pi.
    *args, **kwargs
        Forwarded to `independent_scatter` (and from there to `scatter`),
        e.g. ``omega0_t``.

    Yields
    ------
    numpy.ndarray
        Complex array of all particle states: first the initial states,
        then a fresh copy after each scattering event.
    """
    n_particles = len(initial_states)
    particle1, particle2 = pick_random_pairs(n_particles, n)
    if theta is None:
        relative_angles = random_theta(n)
    else:
        relative_angles = np.full(n, theta)

    # Scattering produces complex density matrices; a real-valued array would
    # discard their imaginary parts when they are assigned back into it.
    states = np.array(initial_states, dtype=complex)
    yield states

    for p1, p2, angle in zip(particle1, particle2, relative_angles):
        # Copy so that arrays already yielded to the caller stay unchanged.
        states = states.copy()
        states[[p1, p2]] = independent_scatter(
            *states[[p1, p2]], *args, theta=angle, **kwargs)
        yield states


def independent_scatter(rho1, rho2, *args, **kwargs):
    """Scatter two uncorrelated neutrinos and return their new reduced states.

    The pair is combined into a product state, evolved with `scatter` (which
    receives `*args` and `**kwargs`), and split back into one density matrix
    per neutrino by tracing out the other. Any correlations created by the
    scattering are therefore discarded.

    Returns the tuple ``(rho1, rho2)`` after scattering.
    """
    return split_rho(scatter(combine_rho(rho1, rho2), *args, **kwargs))


def scatter_backgrounds(rho0, rho_background, n, *args, **kwargs):
    """Repeatedly scatter one neutrino of interest off background neutrinos.

    Every one of the `n` events uses the same background density matrix
    `rho_background` and a new spherically uniform random angle. Yields the
    initial state followed by the state after each event; extra arguments
    are forwarded to `scatter_background`.
    """
    angles = random_theta(n)

    current = np.array(rho0)
    yield current

    for angle in angles:
        current = scatter_background(current, rho_background, *args,
                                     theta=angle, **kwargs)
        yield current


def scatter_background(rho, rho_background, *args, **kwargs):
    """Return the new density matrix of a neutrino after one background hit.

    `rho` (the neutrino of interest) and `rho_background` are combined into
    an uncorrelated two-neutrino state, evolved with `scatter` (extra
    arguments are forwarded), and the background neutrino is traced out.
    """
    evolved_pair = scatter(combine_rho(rho, rho_background), *args, **kwargs)
    # Keep only the first (neutrino of interest) subsystem.
    return trace_out(evolved_pair)


def scatter(rho, theta=pi/2, omega0_t=0.1):
    """Evolve the flavor density matrix of two neutrinos that scatter.

    Parameters
    ----------
    rho : array_like, shape (2, 2, 2, 2)
        Two-neutrino density matrix in the layout produced by `combine_rho`:
        ``rho[i, j, k, l]`` is the 4x4 matrix element with row ``(i, j)`` and
        column ``(k, l)``.
    theta : float, optional
        Relative scattering angle.
    omega0_t : float, optional
        Phase parameter; the evolution uses
        ``phase = exp(-2j * omega0_t * (1 - cos(theta)))``.

    Returns
    -------
    numpy.ndarray
        Complex (2, 2, 2, 2) array ``U rho U^dagger``, renormalized to unit
        trace.
    """
    phase = np.exp(-2j * omega0_t * (1 - cos(theta)))

    # Time evolution operator as a (2, 2, 2, 2) array. In the basis
    # |00>, |01>, |10>, |11> it multiplies the exchange-symmetric states by
    # `phase` and leaves the antisymmetric state (|01> - |10>)/sqrt(2)
    # unchanged, so it is unitary.
    N = 2         # number of neutrinos
    N_states = 2  # flavor states per neutrino
    U = np.zeros(2 * N * [N_states], dtype=complex)
    U[0, 0, 0, 0] = U[1, 1, 1, 1] = phase
    U[0, 1, 0, 1] = U[1, 0, 1, 0] = (phase + 1) / 2
    U[0, 1, 1, 0] = U[1, 0, 0, 1] = (phase - 1) / 2

    rho = matmul(matmul(U, rho), dagger(U))

    # U is unitary, so the trace is preserved up to round-off. The trace of a
    # Hermitian matrix is real: renormalize by its real part, because
    # dividing by a complex number would break the Hermiticity of rho.
    trace = np.trace(np.trace(rho, axis1=0, axis2=2)).real
    rho /= trace

    return rho


def combine_rho(rho1, rho2):
    """Return the tensor product of `rho1` and `rho2` as a 4-index array.

    Element ``[i, j, k, l]`` equals ``rho1[i, k] * rho2[j, l]``: the matrix
    element with row ``(i, j)`` and column ``(k, l)``. For 2x2 inputs,
    reshaping the result to (4, 4) gives ``np.kron(rho1, rho2)``.
    """
    return np.einsum('ik,jl->ijkl', rho1, rho2)


def split_rho(rho):
    """Return both reduced density matrices ``(rho1, rho2)`` of a pair state.

    `rho` uses the `combine_rho` layout. ``rho1`` sums over the second
    neutrino's indices (axes 1 and 3); ``rho2`` sums over the first
    neutrino's indices (axes 0 and 2).
    """
    first_only = np.einsum('ijkj->ik', rho)
    second_only = np.einsum('ijil->jl', rho)
    return first_only, second_only


def matmul(A, B):
    """Return the product of two (2, 2, 2, 2) arrays viewed as (4, 4) matrices.

    The first index pair of each array labels rows and the second labels
    columns, so the column pair of `A` is contracted with the row pair of `B`.
    """
    return np.einsum('ijkl,klmn->ijmn', A, B)


def dagger(A):
    """Return the conjugate transpose of a (2, 2, 2, 2) array seen as (4, 4).

    The row index pair (axes 0, 1) and column index pair (axes 2, 3) are
    swapped and every element is complex-conjugated.
    """
    swapped = np.transpose(A, (2, 3, 0, 1))
    return np.conjugate(swapped)


def trace_out(rho):
    """Return the reduced density matrix of the first neutrino.

    Sums over the second neutrino's indices (axes 1 and 3) of a
    `combine_rho`-layout array.
    """
    return np.einsum('ijkj->ik', rho)


def flavor_expval(rho):
    """Return the real part of element ``[1, 1]`` of `rho`.

    For a single-neutrino density matrix this is the population of flavor
    basis state 1.
    """
    return np.real(rho[1, 1])


def entropy(rho):
    """Return the von Neumann entropy ``-Tr(rho log rho)`` (natural log).

    Computed from the eigenvalues ``p`` of the density matrix as
    ``-sum(p * log(p))`` with the standard convention ``0 log 0 = 0``.
    Pure and other rank-deficient states, such as ``[[1, 0], [0, 0]]``, have
    zero eigenvalues, for which the matrix logarithm is undefined. Working
    with eigenvalues avoids that problem. Eigenvalues that round-off leaves
    at or slightly below zero are skipped.

    `rho` is treated as Hermitian: only its lower triangle is read.
    """
    p = np.linalg.eigvalsh(rho)
    p = p[p > 0]
    return -np.sum(p * np.log(p))


def pick_random_pairs(N, shape):
    """Draw pairs of distinct random integers in [0, `N`).

    Returns two integer arrays of the given `shape`. Element-wise, the first
    is uniform on [0, N), and the second is uniform over the remaining N - 1
    values, so the two never coincide. `N` must be at least 2; otherwise
    `np.random.randint` raises.
    """
    first = np.random.randint(0, N, shape)
    # Shifting by a nonzero offset modulo N always lands on a different value.
    offset = np.random.randint(1, N, shape)
    return first, (first + offset) % N


def random_theta(shape):
    """Draw polar angles of random directions uniform on the unit sphere.

    ``cos(theta)`` is sampled uniformly from [-1, 1), so the angles lie in
    (0, pi]. `shape` is passed to `np.random.rand`, i.e. it is the integer
    number of angles to draw.
    """
    return arccos(2*np.random.rand(shape) - 1)


In [None]:
M = 20  # number of particles prepared in each of the two basis states
# First M particles in |0><0|, the remaining M in |1><1|.
initial_states = np.concatenate([
    np.repeat([[[1, 0], [0, 0]]], M, axis=0),
    np.repeat([[[0, 0], [0, 1]]], M, axis=0),
]).astype(complex)

In [None]:
# Full history: one snapshot of all particles per step, including the initial one.
states = np.stack([*random_scatter(initial_states, n=5000, omega0_t=0.1)])

In [None]:
import matplotlib.pyplot as plt

# 0.5 minus the real (0, 0) element of every particle's density matrix.
flavor = states[:, :, 0, 0]
plt.plot(0.5 - flavor.real)
plt.xlim(0, 5000)

# Reference curve: exponential decay from 0.5 with an e-folding of 500 steps.
x = np.arange(len(flavor))
y = 0.5 * np.exp(-x / 500)
plt.plot(x, y, c='black', lw=3, ls=':')
plt.yscale('log')

In [None]:
# An independent second run; plot the phase and magnitude of each particle's
# density-matrix trace at every step.
states2 = np.stack([*random_scatter(initial_states, n=5000, omega0_t=0.1)])
flavor2 = states2[..., 0, 0]
trace2 = np.trace(states2, axis1=-2, axis2=-1)

plt.plot(np.angle(trace2))
plt.plot(np.abs(trace2))
plt.xlim(0, 5000)

In [None]:
plt.plot(np.imag(trace2))  # imaginary part of each particle's trace per step

In [None]:
rho1 = [[0.5, -0.5], [-0.5, 0.5]]  # pure state (|0> - |1>)/sqrt(2)
rho2 = [[1/2, 0], [0, 1/2]]        # maximally mixed state
# One scattering with the default angle and strength; show both reduced states.
split_rho(scatter(combine_rho(rho1, rho2)))

In [None]:
# The two-neutrino product state as a (2, 2, 2, 2) array.
rhofull = combine_rho(rho1=rho1, rho2=rho2)
rhofull

In [None]:
# Every particle's density matrix after 59 scattering events.
states2[59, ...]

In [None]:
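# Sanity check: scatter one fixed two-neutrino state repeatedly and inspect
# the resulting single-neutrino density matrices.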

In [None]:
# Start from |01><01|: the 4x4 basis index is 2*i + j, so diagonal entry 1
# has the first neutrino in state 0 and the second in state 1.
states3 = [np.diag([0, 1, 0, 0]).reshape(2, 2, 2, 2)]
for _ in range(100):
    states3.append(scatter(states3[-1]))
# Reduce each two-neutrino state to its pair of single-neutrino matrices.
states3 = np.array(list(map(split_rho, states3)))

In [None]:
plt.plot(np.sum(states3[:, 1], axis=(-1, -2)))  # sum of all entries, neutrino 2

In [None]:
np.sum(states3[40][0])  # sum of all entries of neutrino 1's matrix at step 40