In [None]:
# Rewrite cos(u) in terms of sin²(u/2), e.g. cos(2θ) = 1 - 2 sin²θ

In [None]:
def cos_to_sin2_halfangle(expr: sp.Expr) -> sp.Expr:
    """Rewrite every ``cos(u)`` in *expr* as ``1 - 2*sin(u/2)**2`` and expand.

    This is the half-angle identity cos(u) = 1 - 2 sin²(u/2); no phase shift
    is introduced.

    Parameters
    ----------
    expr : sympy.Expr
        Expression to rewrite.

    Returns
    -------
    sympy.Expr
        The rewritten, expanded expression. It is returned as a SymPy object
        (not a string, as the annotation always promised) so that symbol
        assumptions such as ``real=True`` survive for later ``subs``,
        ``series`` or ``lambdify`` calls. Re-parsing a string would create
        fresh symbols without those assumptions.
    """
    u = sp.Wild('u')
    rewritten = expr.replace(sp.cos(u), 1 - 2*sp.sin(u/2)**2)
    return sp.expand(rewritten)

In [None]:
# Fold 1 - 2 sin²θ (or 2 sin²θ - 1) back into ±cos(2θ)

In [None]:
def cos_to_sin2_doubleangle(expr: sp.Expr) -> sp.Expr:
    """Fold ``1 - 2*sin(u)**2`` back into ``cos(2*u)`` and expand the result.

    This is the double-angle identity read right-to-left:
    1 - 2 sin²(u) = cos(2u).

    Only subexpressions that structurally match ``1 - 2*sin(u)**2`` (an
    addition of exactly those two terms) are folded. A sum with extra terms,
    or an already-distributed multiple such as ``y - 2*y*sin(x)**2``, is left
    unchanged by the pattern match.

    Parameters
    ----------
    expr : sympy.Expr
        Expression to rewrite.

    Returns
    -------
    sympy.Expr
        The rewritten expression after ``sp.expand``.
    """
    u = sp.Wild('u')
    folded = expr.replace(1 - 2*sp.sin(u)**2, sp.cos(2*u))
    return sp.expand(folded)

In [None]:
def cos_to_sin2_doubleanglem(expr: sp.Expr) -> sp.Expr:
    """Fold ``2*sin(u)**2 - 1`` back into ``-cos(2*u)`` and expand the result.

    This is the negated double-angle identity: 2 sin²(u) - 1 = -cos(2u).

    Only subexpressions that structurally match ``2*sin(u)**2 - 1`` (an
    addition of exactly those two terms) are folded. Sums with extra terms
    or distributed multiples are left unchanged by the pattern match.

    Parameters
    ----------
    expr : sympy.Expr
        Expression to rewrite.

    Returns
    -------
    sympy.Expr
        The rewritten expression after ``sp.expand``.
    """
    u = sp.Wild('u')
    folded = expr.replace(2*sp.sin(u)**2 - 1, -sp.cos(2*u))
    return sp.expand(folded)

In [None]:
def sum_formulas(expr: sp.Expr) -> sp.Expr:
    """Expand sines/cosines of sums and differences with angle-addition formulas.

    Applies

        sin(u ± v) = sin(u)cos(v) ± cos(u)sin(v)
        cos(u ± v) = cos(u)cos(v) ∓ sin(u)sin(v)

    repeatedly, until a pass changes nothing, so arguments with three or more
    terms are split completely. The result is then expanded.

    The result is deliberately NOT passed through ``sp.simplify``. simplify
    runs trigsimp, which recombines e.g. sin(u)cos(v) + cos(u)sin(v) back into
    sin(u + v) and would undo the expansion this function exists to perform.

    Parameters
    ----------
    expr : sympy.Expr
        Expression to expand.

    Returns
    -------
    sympy.Expr
        The expanded expression.
    """
    u = sp.Wild('u')
    v = sp.Wild('v')

    rules = [
        (sp.sin(u + v), sp.sin(u)*sp.cos(v) + sp.cos(u)*sp.sin(v)),
        (sp.sin(u - v), sp.sin(u)*sp.cos(v) - sp.cos(u)*sp.sin(v)),
        (sp.cos(u + v), sp.cos(u)*sp.cos(v) - sp.sin(u)*sp.sin(v)),
        (sp.cos(u - v), sp.cos(u)*sp.cos(v) + sp.sin(u)*sp.sin(v)),
    ]

    out = expr
    # A single `replace` pass does not revisit the freshly inserted sin/cos
    # factors, so an argument such as x + y + z may only be split once.
    # Repeat until stable; the bound guards against pathological matches.
    for _ in range(10):
        previous = out
        for lhs, rhs in rules:
            out = out.replace(lhs, rhs)
        if out == previous:
            break

    return sp.expand(out)

In [None]:
def cos_sin_from_tan(tan_2phi, theta):
    """Reconstruct ``(cos 2φ, sin 2φ)`` from a (symbolic) value of ``tan 2φ``.

    Uses

        cos 2φ = 1 / sqrt(1 + tan² 2φ),    sin 2φ = tan 2φ / sqrt(1 + tan² 2φ),

    which, for real ``tan_2phi``, selects the principal branch with
    cos 2φ > 0 (2φ in (-π/2, π/2)).

    Parameters
    ----------
    tan_2phi : sympy.Expr or number
        Value of tan(2φ).
    theta :
        Accepted for call-site compatibility; not used.

    Returns
    -------
    tuple
        ``(cos_2phi, sin_2phi)``.
    """
    root = sp.sqrt(1 + tan_2phi**2)
    return 1 / root, tan_2phi / root

In [None]:
def substitute_trig_values(expr, theta, phi,
                            cos2theta, sin2theta,
                            cos2phi, sin2phi, cos2phipr, sin2phipr, sint, cost, sinf, cosf, sinppr, cosppr,
                            phi_prime=None):
    """Replace trig functions of θ, φ and φ' in *expr* by supplied values.

    Substitutions performed (pattern -> argument supplying its value):

        cos(2θ)  -> cos2theta    sin(2θ)  -> sin2theta
        cos(2φ)  -> cos2phi      sin(2φ)  -> sin2phi
        cos(2φ') -> cos2phipr    sin(2φ') -> sin2phipr
        sin(θ)   -> sint         cos(θ)   -> cost
        sin(φ)   -> sinf         cos(φ)   -> cosf
        sin(φ')  -> sinppr       cos(φ')  -> cosppr

    Parameters
    ----------
    expr : sympy.Expr
        Expression to substitute into.
    theta, phi : sympy.Symbol
        The angle symbols θ and φ.
    phi_prime : sympy.Symbol, optional
        The angle symbol φ'. If omitted, a global named ``phi_prime`` is
        looked up, matching the original implicit-global behavior. If no such
        global exists either, the φ' substitutions are skipped instead of
        raising NameError.

    Returns
    -------
    sympy.Expr
        ``expr`` after the substitutions.

    Notes
    -----
    ``Expr.subs`` with a dict applies the rules one after another (not
    simultaneously). A replacement value that itself contains, e.g.,
    ``sin(theta)`` may therefore be rewritten again by another rule.
    """
    if phi_prime is None:
        # The parameter shadows the global name, so read it explicitly.
        phi_prime = globals().get("phi_prime")

    substitutions = {
        sp.cos(2*theta): cos2theta,
        sp.sin(2*theta): sin2theta,
        sp.cos(2*phi): cos2phi,
        sp.sin(2*phi): sin2phi,
        sp.sin(theta): sint,
        sp.cos(theta): cost,
        sp.sin(phi): sinf,
        sp.cos(phi): cosf,
    }
    if phi_prime is not None:
        substitutions.update({
            sp.cos(2*phi_prime): cos2phipr,
            sp.sin(2*phi_prime): sin2phipr,
            sp.sin(phi_prime): sinppr,
            sp.cos(phi_prime): cosppr,
        })
    return expr.subs(substitutions)

In [None]:
# Pretty-print powers as superscripts (x**2 -> x²) for compact display

In [None]:
def pretty_s(expr):
    """Render *expr* as a compact string with Unicode superscript exponents.

    ``str(expr)`` is post-processed as follows:

    * Integer powers become superscripts: ``x**2`` -> ``x²``,
      ``x**12`` -> ``x¹²``, ``x**(-1)`` -> ``x⁻¹``.
    * Any other power keeps a visible operator: ``x**(1/2)`` -> ``x^(1/2)``.
      Previously the ``**`` was silently dropped, so ``x**7`` read as ``x7``.
    * Multiplication signs are removed: ``2*x`` -> ``2x``.

    Parameters
    ----------
    expr : object
        Anything convertible with ``str`` (typically a SymPy expression).

    Returns
    -------
    str
        The prettified text.
    """
    import re

    superscripts = str.maketrans("0123456789-", "⁰¹²³⁴⁵⁶⁷⁸⁹⁻")

    def to_superscript(match):
        # group(1): parenthesized negative exponent, group(2): plain integer.
        exponent = match.group(1) or match.group(2)
        return exponent.translate(superscripts)

    # The lookahead keeps float exponents like "2.5" from being half-converted.
    text = re.sub(r"\*\*(?:\((-\d+)\)|(\d+))(?![\d.])", to_superscript, str(expr))
    return text.replace("**", "^").replace("*", "")

In [None]:
# Given tan(2θ), list every possible (sin θ, cos θ) branch

In [None]:
def sin_theta_from_tan2_all(tan2theta, theta):
    """Print every (sin θ, cos θ) pair consistent with a given tan(2θ).

    Let r = |cos 2θ| = 1/sqrt(1 + tan² 2θ). The half-angle identities give

        cos 2θ = -r:  sin²θ = (1 + r)/2,  cos²θ = (1 - r)/2   (branches 1, 2)
        cos 2θ = +r:  sin²θ = (1 - r)/2,  cos²θ = (1 + r)/2   (branches 3, 4)

    The signs are not independent. tan 2θ = 2 sinθ cosθ / cos 2θ must keep
    the sign of ``tan2theta``, so sinθ·cosθ has the sign of tan 2θ when
    cos 2θ > 0, and the opposite sign when cos 2θ < 0. The ``sign(tan2theta)``
    factors enforce this. They multiply the component sqrt((1 - r)/2), which
    vanishes when tan 2θ = 0, so that case is still handled correctly.
    (Pairing same-sign sin/cos in every branch, as before, produces two
    branches whose tan 2θ has the wrong sign.)

    Assumes ``tan2theta`` is real. The branches are printed; nothing is
    returned.

    Parameters
    ----------
    tan2theta : sympy.Expr or number
        Value of tan(2θ).
    theta : sympy.Symbol
        Symbol used to label the printed equations.
    """
    r = sp.sqrt(1 / (1 + tan2theta**2))
    sgn = sp.sign(tan2theta)
    big = sp.sqrt((1 + r) / 2)
    small = sp.sqrt((1 - r) / 2)

    branches = [
        (big, -sgn*small),    # cos 2θ = -r, sin θ > 0
        (-big, sgn*small),    # cos 2θ = -r, sin θ < 0
        (sgn*small, big),     # cos 2θ = +r, cos θ > 0
        (-sgn*small, -big),   # cos 2θ = +r, cos θ < 0
    ]

    for i, (s, c) in enumerate(branches, 1):
        print(f"\nBranch {i}:")
        sp.pprint(sp.Eq(sp.sin(theta), sp.simplify(s)))
        sp.pprint(sp.Eq(sp.cos(theta), sp.simplify(c)))

In [None]:
# Given the dimension, Hamiltonian, diagonalizing matrix, mixing parameters and time span, compute and plot the oscillation probabilities

In [None]:
def Neutron_Osc_Prob(n, H, S, eta_n, eps_n, t_n ):
    """Compute, display and plot oscillation probabilities of an n-state system.

    The evolution operator is built in the eigenbasis given by ``S``:

        H_D  = S^† H S                  (diagonal when S diagonalizes H)
        U(t) = S exp(-i (H_D - m I) t) S^†

    The probabilities are P_ij(t) = |U_ij(t)|². The shift by ``m I`` only
    contributes a global phase, so it does not change any probability.

    Parameters
    ----------
    n : int
        Dimension of the state space (size of ``H`` and ``S``).
    H : sympy.Matrix
        n×n Hamiltonian. Presumably written in the symbols ``m``, ``ε``, ``η``
        created below with ``real=True``; symbols with another name or other
        assumptions are not replaced by the numeric values when plotting.
    S : sympy.Matrix
        n×n unitary matrix whose columns diagonalize ``H``.
    eta_n, eps_n : float
        Numeric values for ``η`` and ``ε`` used in the plot (``m`` is 1.0).
    t_n : float
        End of the plotted time interval ``[0, t_n]``.

    Returns
    -------
    sympy.Matrix
        ``H_D - m I``.

    Side effects
    ------------
    Displays H, S, U(t), every P_ij (rewritten with
    ``cos_to_sin2_halfangle``) and the row sum s = Σ_j P_0j via IPython
    ``display``. Also shows a matplotlib plot of the non-zero row-0
    probabilities together with s(t).
    """
    m, t, eps, eta = sp.symbols('m t ε η', real=True)
    identity = sp.eye(n)

    # Conjugate transpose, consistent with U(t) = S ... S^† below. For a
    # complex unitary S, S.T @ H @ S would not be the eigen-decomposition.
    H_D = sp.simplify(S.H @ H @ S)
    H_D_R = H_D - m*identity
    S_diag = sp.simplify(S*sp.exp(-sp.I*H_D_R*t)*S.H)

    # P_ij = |U_ij|², expanded into explicitly real trigonometric form.
    P_n = sp.zeros(n)
    for i in range(n):
        for j in range(n):
            amp = S_diag[i, j]
            P_n[i, j] = sp.simplify(sp.expand_complex(amp*amp.conjugate().rewrite(sp.sin)))

    # Probability-conservation check: equals 1 when U(t) is unitary.
    s = sp.simplify(sum(P_n[0, j] for j in range(n)))

    p_n_sin = P_n.applyfunc(cos_to_sin2_halfangle)
    display(Math(r"H=" + sp.latex(H)))
    display(Math(r"S=" + sp.latex(S)))
    display(Math(r"S_{diag}=" + sp.latex(S_diag)))
    for i in range(n):
        for j in range(n):
            display(Math(rf"P{n}_{{{i}{j}}}=" + sp.latex(p_n_sin[i, j])))

    params = {
        eps: eps_n,
        eta: eta_n,
        m: 1.0,
    }

    # Row 0: probabilities for a system that starts in state 0 (n).
    P_dict = {r"$P(n\to n)$": P_n[0, 0]}
    if n > 1:
        P_dict[r"$P(n\to \bar n)$"] = P_n[0, 1]
    for k in range(2, n):
        # Even k -> n_{k//2}, odd k -> \bar n_{k//2}; braces keep
        # multi-digit subscripts intact.
        if k % 2 == 0:
            label = rf"$P(n\to n_{{{k // 2}}})$"
        else:
            label = rf"$P(n\to \bar n_{{{k // 2}}})$"
        P_dict[label] = P_n[0, k]

    P_funcs = {
        label: sp.lambdify(t, expr.subs(params), "numpy")
        for label, expr in P_dict.items()
    }
    t_vals = np.linspace(0, t_n, 1000)
    plt.figure(figsize=(8, 5))

    for label, f in P_funcs.items():
        # lambdify returns a plain scalar for t-independent entries;
        # broadcast so constant curves can be plotted too.
        y = np.broadcast_to(np.asarray(f(t_vals), dtype=float), t_vals.shape)
        if np.allclose(y, 0.0):
            continue
        plt.plot(t_vals, y, label=label)

    s_num_expr = sp.simplify(sp.trigsimp(s).subs(params))
    s_func = sp.lambdify(t, s_num_expr, "numpy")
    y = np.broadcast_to(np.asarray(s_func(t_vals), dtype=float), t_vals.shape)
    plt.plot(t_vals, y, 'k--', linewidth=2, label=r"$s(t)=\sum P$")

    plt.xlabel(r"$t$", fontsize=16)
    plt.ylabel(r"$P(t)$", fontsize=16)
    plt.legend(fontsize=16, loc="upper right")
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    display(Math(r"s=" + sp.latex(s)))
    return H_D_R

In [None]:
# Diagonalize a Hermitian matrix with a unitary transformation

In [None]:
def hermitian_inner(u: sp.Matrix, v: sp.Matrix):
    """Return the Hermitian inner product ⟨u, v⟩ = u^† v as a scalar.

    The FIRST argument is the one that is complex-conjugated. Both arguments
    are expected to be column vectors, so the 1×1 product is unwrapped.
    """
    u_dagger = u.conjugate().T
    product = u_dagger * v
    return product[0]

def hermitian_gs(vectors, tol_simplify=True):
    """Orthonormalize column vectors with Gram–Schmidt under ⟨u, v⟩ = u^† v.

    Each vector is projected against the already-accepted basis vectors one
    at a time, using the running residual (modified Gram–Schmidt), and then
    normalized.

    Parameters
    ----------
    vectors : iterable of sympy.Matrix
        Column vectors of shape (n, 1), processed in order.
    tol_simplify : bool, default True
        When True, ``sp.simplify`` is applied to each squared norm and to
        each normalized vector. Despite the name, no numeric tolerance is
        involved.

    Returns
    -------
    list of sympy.Matrix
        The orthonormal columns. A vector whose residual squared norm
        compares equal to 0 (linearly dependent on earlier ones) is dropped.
        For symbolic input this is best-effort: an expression that is zero
        but does not reduce to a literal 0 is not detected.
    """
    def maybe_simplify(obj):
        return sp.simplify(obj) if tol_simplify else obj

    basis = []
    for vec in vectors:
        residual = sp.Matrix(vec)  # work on a copy
        for q in basis:
            residual = residual - hermitian_inner(q, residual) * q
        norm_sq = maybe_simplify(hermitian_inner(residual, residual))
        if norm_sq == 0:
            continue
        basis.append(maybe_simplify(residual / sp.sqrt(norm_sq)))
    return basis

def unitary_diagonalizer(H: sp.Matrix):
    """Diagonalize a Hermitian matrix by a unitary change of basis.

    Returns ``(U, D)`` where

    * ``U`` is unitary (``U.H*U = I``), built from eigenvectors of ``H``;
    * ``D`` is diagonal (``U.H*H*U = D``), with the eigenvalue of column k of
      ``U`` in position (k, k).

    Each eigenspace is orthonormalized on its own with ``hermitian_gs``, so
    eigenvectors belonging to different eigenvalues are never mixed. For a
    Hermitian ``H`` they are already mutually orthogonal.

    Raises
    ------
    ValueError
        If ``H`` is not square, or if fewer than n orthonormal eigenvectors
        are obtained (this can happen with symbolic entries).

    A non-Hermitian ``H`` only triggers a printed warning. In that case the
    returned ``U`` is generally not unitary.
    """
    n_rows, n_cols = H.shape
    n = n_rows
    if n_rows != n_cols:
        raise ValueError("H must be square.")

    # Best-effort Hermiticity check.
    if sp.simplify(H - H.H) != sp.zeros(n):
        print("Warning: H does not simplify to Hermitian (H != H.H). Proceeding anyway.")

    def as_column(vec):
        vec = sp.Matrix(vec)
        if vec.shape == (n,):
            return vec.reshape(n, 1)
        if vec.shape == (1, n):
            return vec.T
        return vec

    columns = []
    eigenvalues = []
    # eigenvects() yields (eigenvalue, algebraic multiplicity, basis vectors).
    for lam, _multiplicity, eigenbasis in H.eigenvects():
        orthonormal = hermitian_gs([as_column(vec) for vec in eigenbasis])
        # With symbolic input the basis may be smaller than the multiplicity;
        # whatever was obtained is used, and the size check below reports it.
        columns.extend(orthonormal)
        eigenvalues.extend([lam] * len(orthonormal))

    if len(columns) != n:
        raise ValueError(
            f"Did not obtain a full eigenbasis (got {len(columns)} vectors, expected {n}). "
            "This can happen with symbolic parameters; consider substituting numeric values or using .eigenvects() differently."
        )

    U = sp.simplify(sp.Matrix.hstack(*columns))
    D = sp.simplify(sp.diag(*eigenvalues))
    return U, D