### 🔧 To-Do List

1. **Refine Berry Curvature Visualizations for Publication**
   - [ ] Improve the quality and layout of the Berry curvature figures so they meet publication standards.


# Berry Curvature

In [11]:
import os

# nodal_knot is imported from ./src with src as the working directory. The
# previous working directory is restored in `finally`, so a failed import does
# not leave the kernel inside src/ (where re-running this cell would then try
# to enter src/src).
_cwd = os.getcwd()
os.chdir("src")
try:
    from nodal_knot import NodalKnot
finally:
    os.chdir(_cwd)

import numpy as np
import sympy as sp

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.cm import ScalarMappable
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is
# available in all versions and takes the same colormap name.
from matplotlib.pyplot import get_cmap

 
def k_to_zw(kx, ky, kz):
    """F: 3D Brillouin zone -> C^2.

    Maps the momentum (kx, ky, kz) to the complex pair

        z = (cos 2kz + 1/2) + i (cos kx + cos ky + cos kz - 2)
        w = sin kx + i sin ky

    Works element-wise on scalars or NumPy arrays and returns (z, w).
    """
    z = (np.cos(2*kz) + 0.5) + 1j*(np.cos(kx) + np.cos(ky) + np.cos(kz) - 2.0)
    w = np.sin(kx) + 1j*np.sin(ky)
    return z, w

def zw_to_c_hopf(z, w):
    """ f: C^2 -> C (Hopf Link)

    Returns z**2 - w**2 = (z - w)(z + w), element-wise for scalars or arrays.
    """
    z_squared = np.power(z, 2)
    w_squared = np.power(w, 2)
    return z_squared - w_squared

def zw_to_c_trefoil(z, w):
    """ f: C^2 -> C (Trefoil Knot)

    Returns z**2 - w**3, the (2, 3) torus-knot polynomial, element-wise.
    """
    return np.subtract(np.power(z, 2), np.power(w, 3))

def zw_to_c_3link(z, w):
    """ f: C^2 -> C (three-component link)

    Returns z**3 - w**2 * z = z (z - w)(z + w), element-wise. Its zero set
    near the origin is three complex lines through 0, so the associated link
    has three unknotted components, each pair linked once (the (3, 3) torus
    link). It is not the figure-eight knot, which has a single component.
    """
    return np.power(z, 3) - np.power(w, 2)*z

def zw_to_c_pqtorus(z, w, p=2, q=4):
    """ f: C^2 -> C ((p, q) torus link)

    Returns z**p - w**q, element-wise. For coprime p, q its link is the
    (p, q) torus knot; in general it has gcd(p, q) components. The defaults
    p=2, q=4 (previously hard-coded) give z**2 - w**4 = (z - w**2)(z + w**2),
    the two-component T(2, 4) torus link -- not the figure-eight knot.

    Parameters
    ----------
    z, w : complex scalar or ndarray
    p, q : int, optional
        Exponents of z and w.
    """
    return np.power(z, p) - np.power(w, q)

 
# Models sharing the Brillouin-zone map k_to_zw, constructed left to right.
# NOTE: torus_12 wraps zw_to_c_pqtorus, which evaluates z**2 - w**4; the name
# does not spell out those exponents.
trefoil, threelink, torus_12 = (
    NodalKnot(k_to_zw, zw_to_c_trefoil),
    NodalKnot(k_to_zw, zw_to_c_3link),
    NodalKnot(k_to_zw, zw_to_c_pqtorus),
)

# Plotly output renderer used for this notebook.
pio.renderers.default = "notebook_connected"


# -----------------------------------------------------------------------------
# 0) SYMBOLIC SETUP — once at import time
# -----------------------------------------------------------------------------
# Symbolic copy of k_to_zw composed with the Hopf map f = z**2 - w**2.
# Re f and Im f are simplified once and compiled into NumPy callables
# Bx_func / By_func, each taking (kx, ky, kz).
kx_s, ky_s, kz_s = sp.symbols('kx ky kz', real=True)
z_s = (sp.cos(2*kz_s) + sp.Rational(1, 2)
       + sp.I*(sp.cos(kx_s) + sp.cos(ky_s) + sp.cos(kz_s) - 2))
w_s = sp.sin(kx_s) + sp.I*sp.sin(ky_s)
f_s = z_s**2 - w_s**2

re_f_s, im_f_s = (sp.simplify(part(f_s)) for part in (sp.re, sp.im))

Bx_func, By_func = (sp.lambdify((kx_s, ky_s, kz_s), expr, 'numpy')
                    for expr in (re_f_s, im_f_s))

def compute_B_and_D(Bx_func, By_func, KX, KY, KZ, thickness):
    """
    Sample both field components on the grid and form the detuning.

    Bx_func is called before By_func, each with (KX, KY, KZ).

    Returns
    -------
    (Bx, By, D) with D = Bx**2 + By**2 - thickness**2.
    """
    grid = (KX, KY, KZ)
    Bx, By = (func(*grid) for func in (Bx_func, By_func))
    detuning = Bx**2 + By**2 - thickness**2
    return Bx, By, detuning

# -----------------------------------------------------------------------------
# 1) GRID CREATION
# -----------------------------------------------------------------------------
def make_kgrid(N, kx_min=-np.pi, kx_max=np.pi,
               ky_min=-np.pi, ky_max=np.pi,
               kz_min=0.0, kz_max=np.pi):
    """
    Build an N x N x N momentum grid with ``indexing='ij'``.

    Every axis holds N evenly spaced points including both end points; by
    default kx and ky span [-pi, pi] and kz spans [0, pi].

    Returns
    -------
    KX, KY, KZ : ndarray, shape (N, N, N)
        As produced by ``np.meshgrid(kx, ky, kz, indexing='ij')``, so that
        ``KX[i, j, k] == kx[i]`` and so on.
    """
    limits = ((kx_min, kx_max), (ky_min, ky_max), (kz_min, kz_max))
    axes_1d = [np.linspace(lo, hi, N) for lo, hi in limits]
    return np.meshgrid(*axes_1d, indexing='ij')


# -----------------------------------------------------------------------------
# 2) FIELD, DETUNING, CONNECTION & CURVATURE
# -----------------------------------------------------------------------------
def compute_A_and_F(Bx, By, D, gamma, spacing=None):
    """
    Berry connection and curvature from the field (Bx, By) and detuning D.

    With EPS = sqrt(D) where D >= 0 ("unbroken") and EPS = i*sqrt(|D|) where
    D < 0 ("broken"), and gradients d taken along all three grid axes:

      unbroken: A = [(gamma*Bx - EPS*By) dBx + (EPS*Bx + gamma*By) dBy]
                    / (2 EPS (Bx^2 + By^2))
      broken:   A = (EPS - gamma) (Bx dBy - By dBx) / (2 EPS (Bx^2 + By^2))
                F = gamma (dBy x dBx) / (2 EPS^3);   F = 0 where D >= 0

    Parameters
    ----------
    Bx, By, D : ndarray, same 3-D shape
    gamma : float
    spacing : sequence of 3 floats, optional
        Grid steps (dkx, dky, dkz) forwarded to ``np.gradient``. ``None``
        (default, the previous behaviour) differentiates per array index,
        i.e. unit spacing. Pass the physical steps to get derivatives per unit
        k. This matters whenever the steps differ between axes (with equal N,
        make_kgrid's default kz step is half the kx/ky step), because unit
        spacing then also distorts the directions of A and F.

    Returns
    -------
    A_full, F_full : complex ndarray, shape (3, *D.shape)

    Notes
    -----
    Points with D == 0 (EPS == 0) or Bx == By == 0 make the denominators
    vanish and yield inf/nan (NumPy emits RuntimeWarnings there).
    """
    grad_args = () if spacing is None else tuple(spacing)

    # precompute eps
    mask_un = (D >= 0)
    EPS = np.zeros_like(D, dtype=np.complex128)
    EPS[mask_un] = np.sqrt(D[mask_un])
    EPS[~mask_un] = 1j*np.sqrt(np.abs(D[~mask_un]))

    # derivatives, stacked as (3, ...) = (d/dkx, d/dky, d/dkz)
    dBx = np.stack(np.gradient(Bx, *grad_args, edge_order=2), axis=0)
    dBy = np.stack(np.gradient(By, *grad_args, edge_order=2), axis=0)

    # --- A in unbroken region ---
    num_r = (gamma*Bx - EPS*By)[None]*dBx + (EPS*Bx + gamma*By)[None]*dBy
    den = (2*EPS*(Bx**2 + By**2))[None]
    A_full = np.zeros_like(dBx, dtype=np.complex128)
    A_full[:, mask_un] = num_r[:, mask_un] / den[:, mask_un]

    # --- A in broken region ---
    cross_d = Bx[None]*dBy - By[None]*dBx
    num_b = (EPS - gamma)[None] * cross_d
    A_full[:, ~mask_un] = num_b[:, ~mask_un] / den[:, ~mask_un]

    # --- F only in broken region: cross product grad(By) x grad(Bx) ---
    cross_F = np.stack([
        dBy[1]*dBx[2] - dBy[2]*dBx[1],
        dBy[2]*dBx[0] - dBy[0]*dBx[2],
        dBy[0]*dBx[1] - dBy[1]*dBx[0]
    ], axis=0)
    den3 = (2*(EPS**3))[None]
    F_full = np.zeros_like(A_full)
    F_full[:, ~mask_un] = (gamma * cross_F[:, ~mask_un]) / den3[:, ~mask_un]

    return A_full, F_full


# -----------------------------------------------------------------------------
# 3) EXTRACT MAGNITUDES & SHELL VECTORS
# -----------------------------------------------------------------------------
def extract_shell_vectors(F_full, D, shell_thickness):
    """
    Unit vectors of a 3-component field on the shell -shell_thickness < D < 0.

    Parameters
    ----------
    F_full : array_like, shape (3, *grid)
        Vector field. Only its real part is used for the directions (the
        calls in this notebook pass ``F_full.imag`` / ``A_full.imag``).
    D : ndarray, shape grid
        Detuning.
    shell_thickness : float

    Returns
    -------
    Fx, Fy, Fz : 1-D ndarray
        Unit vectors at the shell points, one entry per True element of
        ``mask_shell`` and in the same (C) order as ``KX[mask_shell]``.
        Points whose vector has zero or non-finite length get a zero vector
        instead of being dropped; dropping them used to misalign the vectors
        with the coordinates selected by ``mask_shell``.
    mask_shell : bool ndarray, shape grid
        ``(D < 0) & (D > -shell_thickness)``.
    F_mag_grid : ndarray, shape grid
        Euclidean norm of the full field (complex parts included) per point.
    """
    F_full = np.asarray(F_full)
    F_mag_grid = np.linalg.norm(F_full, axis=0)
    mask_shell = (D < 0) & (D > -shell_thickness)

    comps = [comp[mask_shell] for comp in F_full.real]
    norm = np.sqrt(comps[0]**2 + comps[1]**2 + comps[2]**2)
    ok = np.isfinite(norm) & (norm > 0)

    unit = []
    for comp in comps:
        out = np.zeros(comp.shape, dtype=float)
        out[ok] = comp[ok] / norm[ok]
        unit.append(out)
    Fx, Fy, Fz = unit
    return Fx, Fy, Fz, mask_shell, F_mag_grid


In [12]:
 



# -----------------------------------------------------------------------------
# 5) PLOTTING (reuse the computed arrays)
# -----------------------------------------------------------------------------
# A skeleton-free version of plot_ep_surface_and_vectors used to be kept here
# inside a bare triple-quoted string, i.e. as dead code that was never run.
# It differed from the live function below only by not drawing the skeleton
# graph, hard-coding max_arrows=100 and not returning the figure, so it was
# removed rather than left to drift out of sync.
# With skeleton inside (pass hopf_graph=None to omit the skeleton)
def plot_ep_surface_and_vectors(KX, KY, KZ, D,
                                vectors, mask_shell,
                                hopf_graph=None,
                                cone_color='blue',
                                SizeCone=1,
                                node_color='black',
                                edge_color='black',
                                node_size=2,
                                edge_width=1,
                                max_arrows=100,
                                iso_level=-0.1,
                                seed=None):
    """
    Show the D = iso_level isosurface, shell vectors as cones and a skeleton.

    Parameters
    ----------
    KX, KY, KZ : ndarray
        Coordinate arrays with the same shape as D.
    D : ndarray
        Detuning grid; its ``iso_level`` isosurface is drawn in grey.
    vectors : tuple (Fx, Fy, Fz)
        1-D arrays with one entry per True element of ``mask_shell``, in the
        order of ``KX[mask_shell]``.
    mask_shell : bool ndarray
        Grid points the vectors belong to.
    hopf_graph : graph or None
        Skeleton with node attribute 'o' (3 coordinates) and edge attribute
        'pts' (an (M, 3) array), in the same coordinates as KX/KY/KZ.
        ``None`` skips the skeleton traces.
    cone_color : str
        Uniform cone colour.
    SizeCone : float
        Plotly ``sizeref`` of the cones (``sizemode="absolute"``).
    node_color, edge_color, node_size, edge_width :
        Skeleton styling.
    max_arrows : int
        At most this many randomly chosen shell points get a cone.
    iso_level : float
        Value of D at which the surface is drawn (previously fixed at -0.1).
    seed : int or None
        Seed for the arrow subsampling. ``None`` keeps the old behaviour of
        drawing from NumPy's global random state; an int makes the figure
        reproducible.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure, which is also shown.

    Raises
    ------
    ValueError
        If the vector arrays do not have one entry per selected grid point.
    """
    Fx, Fy, Fz = (np.asarray(c) for c in vectors)
    x = KX[mask_shell].ravel()
    y = KY[mask_shell].ravel()
    z = KZ[mask_shell].ravel()
    # Vectors and coordinates are matched by position; a length mismatch would
    # silently attach arrows to the wrong grid points.
    if not (len(Fx) == len(Fy) == len(Fz) == len(x)):
        raise ValueError(
            f"vectors have lengths {len(Fx)}, {len(Fy)}, {len(Fz)} but "
            f"mask_shell selects {len(x)} grid points"
        )

    # EP surface @ D = iso_level
    surface = go.Isosurface(
        x=KX.ravel(), y=KY.ravel(), z=KZ.ravel(),
        value=D.ravel(),
        isomin=iso_level, isomax=iso_level, surface_count=1,
        colorscale='Greys', opacity=0.5, showscale=False,
        caps={'x_show': False, 'y_show': False, 'z_show': False}
    )

    # down-sample the arrows (reproducibly when a seed is given)
    sampler = np.random if seed is None else np.random.default_rng(seed)
    idx = sampler.choice(len(x), min(len(x), max_arrows), replace=False)

    # uniform-color cones
    cones = go.Cone(
        x=x[idx], y=y[idx], z=z[idx],
        u=Fx[idx], v=Fy[idx], w=Fz[idx],
        anchor="tail",
        sizemode="absolute",
        sizeref=SizeCone,
        showscale=False,
        colorscale=[[0, cone_color], [1, cone_color]],
        cmin=0, cmax=1
    )

    traces = [surface, cones]

    if hopf_graph is not None:
        # --- graph-edge trace: one polyline per edge, None lifts the pen ---
        edge_x, edge_y, edge_z = [], [], []
        for _u, _v, data in hopf_graph.edges(data=True):
            pts = np.asarray(data['pts'])         # (M, 3)
            edge_x += pts[:, 0].tolist() + [None]
            edge_y += pts[:, 1].tolist() + [None]
            edge_z += pts[:, 2].tolist() + [None]
        traces.append(go.Scatter3d(
            x=edge_x, y=edge_y, z=edge_z,
            mode='lines',
            line=dict(color=edge_color, width=edge_width),
            hoverinfo='none'
        ))

        # --- graph-node trace ---
        node_pos = [hopf_graph.nodes[n]['o'] for n in hopf_graph.nodes()]
        traces.append(go.Scatter3d(
            x=[p[0] for p in node_pos],
            y=[p[1] for p in node_pos],
            z=[p[2] for p in node_pos],
            mode='markers',
            marker=dict(size=node_size, color=node_color),
            hoverinfo='none'
        ))

    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(aspectmode='cube'),
        margin=dict(l=0, r=0, b=0, t=50)
    )
    fig.show()
    return fig
 
def plot_graph_with_edge_orientations(KX, KY, KZ,
                                      Fx, Fy, Fz,     # 1-D masked arrays
                                      mask_shell,     # the same 3-D bool mask
                                      hopf_graph,
                                      cone_color='blue',
                                      size_factor=1.0,  # global multiplier
                                      node_color='black',
                                      edge_color='black',
                                      node_size=2,
                                      edge_width=1,
                                      length_exponent=0.2):
    """
    Draw a skeleton graph and put cones on its edges, oriented by the field.

    Each edge is sampled at the points roughly 25 %, 50 % and 75 % of the way
    through its point list (by index). At each sample the vector of the
    nearest shell grid point (``KX/KY/KZ[mask_shell]``) fixes the cone
    direction. The cone length is
    ``(L / Lmax) ** length_exponent * size_factor``, where ``L`` is the
    edge's polyline length and ``Lmax`` the longest edge.

    Parameters
    ----------
    KX, KY, KZ : ndarray
        Coordinate arrays of the grid.
    Fx, Fy, Fz : 1-D ndarray
        Field components, one entry per True element of ``mask_shell`` in the
        order of ``KX[mask_shell]``.
    mask_shell : bool ndarray
        Grid points the vectors belong to.
    hopf_graph : graph
        Node attribute 'o' (3 coordinates) and edge attribute 'pts'
        ((P, 3) array), in the same coordinates as KX/KY/KZ.
    cone_color, node_color, edge_color, node_size, edge_width :
        Styling.
    size_factor : float
        Global cone-length multiplier.
    length_exponent : float
        Exponent of the length compression. The default 0.2 (a fifth root)
        is the value that was previously hard-coded; 0.5 gives a square root.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure, which is also shown. The cone trace is omitted when there
        are no edge samples or no shell points to orient them.

    Raises
    ------
    ValueError
        If Fx/Fy/Fz do not have one entry per selected grid point.
    """
    # Masked field as a point cloud. Vectors are matched to coordinates by
    # position, so validate the lengths before doing any work.
    coords = np.column_stack((KX[mask_shell].ravel(),
                              KY[mask_shell].ravel(),
                              KZ[mask_shell].ravel()))       # (M, 3)
    Fvecs = np.column_stack((Fx, Fy, Fz))                    # (M, 3)
    if len(Fvecs) != len(coords):
        raise ValueError(
            f"Fx/Fy/Fz have {len(Fvecs)} entries but mask_shell selects "
            f"{len(coords)} grid points"
        )

    # --- 1) edges: polyline trace, total length and interior samples ---
    edge_x, edge_y, edge_z = [], [], []
    samples = []      # (point, edge length) pairs
    lengths = []
    for _u, _v, data in hopf_graph.edges(data=True):
        pts = np.asarray(data['pts'])                # (P, 3)
        if len(pts) == 0:
            continue                                 # nothing to draw or sample
        edge_x += pts[:, 0].tolist() + [None]
        edge_y += pts[:, 1].tolist() + [None]
        edge_z += pts[:, 2].tolist() + [None]

        L = float(np.sum(np.linalg.norm(np.diff(pts, axis=0), axis=1)))
        lengths.append(L)

        # indices at ~25 %, 50 %, 75 % of the point list
        for i in np.linspace(0, len(pts) - 1, num=5, dtype=int)[1:-1]:
            samples.append((pts[i], L))

    graph_edges = go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        mode='lines',
        line=dict(color=edge_color, width=edge_width),
        hoverinfo='none'
    )

    # --- 2) nodes ---
    node_pos = [hopf_graph.nodes[n]['o'] for n in hopf_graph.nodes()]
    graph_nodes = go.Scatter3d(
        x=[p[0] for p in node_pos],
        y=[p[1] for p in node_pos],
        z=[p[2] for p in node_pos],
        mode='markers',
        marker=dict(size=node_size, color=node_color),
        hoverinfo='none'
    )

    traces = [graph_edges, graph_nodes]

    # --- 3) cones: direction from nearest field sample, length from L ---
    if samples and len(coords):
        # Longest edge sets the scale; fall back to 1 when every edge has
        # zero length, which would otherwise give 0/0.
        Lmax = max(lengths)
        if not Lmax > 0:
            Lmax = 1.0

        samp_pts = np.array([pt for pt, _ in samples])
        us, vs, ws = [], [], []
        for pt, L in samples:
            nearest = np.argmin(np.sum((coords - pt)**2, axis=1))
            vec = Fvecs[nearest]
            norm = np.linalg.norm(vec) or 1.0
            scale = (L / Lmax)**length_exponent * size_factor
            us.append(vec[0] / norm * scale)
            vs.append(vec[1] / norm * scale)
            ws.append(vec[2] / norm * scale)

        traces.append(go.Cone(
            x=samp_pts[:, 0], y=samp_pts[:, 1], z=samp_pts[:, 2],
            u=us, v=vs, w=ws,
            anchor="tail",
            sizemode="scaled",
            sizeref=1,            # base unit for "1.0" length; adjust if needed
            showscale=False,
            colorscale=[[0, cone_color], [1, cone_color]]
        ))

    # assemble figure
    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(aspectmode='cube'),
        margin=dict(l=0, r=0, b=0, t=50)
    )
    fig.show()
    return fig
def animate_shell_and_slices_precomputed(
    KX, KY, KZ, D, flat_F, dx=0.1
):
    """
    Animate a 3D shell sample next to a 2D kz-slice of |F|, from precomputed arrays.

    Parameters
    ----------
    KX, KY, KZ : ndarray, shape (Nx, Ny, Nz)
        ``indexing='ij'`` meshgrid arrays (kx varies along axis 0, ky along
        axis 1, kz along axis 2).
    D : ndarray, shape (Nx, Ny, Nz)
        Detuning on the same grid.
    flat_F : ndarray
        |F| flattened in C order, length Nx*Ny*Nz.
    dx : float
        Shell parameter, 0 < dx < 1. Only grid points with |D| <= dx**3 are
        shown in 3D. They are grouped, outermost first, by the shells
        dx**(3**(k+1)) < |D| <= dx**(3**k), down to a final shell below
        1e-12.

    Returns
    -------
    plotly.graph_objects.Figure
        The animated figure (one frame per kz value), which is also shown.

    Raises
    ------
    ValueError
        If dx is not strictly between 0 and 1. Otherwise the radius sequence
        dx**(3**k) never drops below 1e-12 and the loop would never end.
    """
    if not 0 < dx < 1:
        raise ValueError(f"dx must satisfy 0 < dx < 1, got {dx!r}")

    grid_shape = KX.shape
    # extract the coordinate vectors once:
    kx_vals = KX[:, 0, 0]
    ky_vals = KY[0, :, 0]
    kz_vals = KZ[0, 0, :]
    pts = np.vstack((KX.ravel(), KY.ravel(), KZ.ravel())).T

    # 1) shell-masking threshold: only |D| <= dx**3 can enter a shell
    r0 = dx**3
    D_flat = np.abs(D.ravel())
    D_flat[D_flat > r0] = np.inf

    # 2) radii dx**1, dx**3, dx**9, ... until the last one is <= 1e-12
    exps = [1]
    while True:
        exps.append(exps[-1] * 3)
        if dx**exps[-1] <= 1e-12:
            break
    r_vals = [dx**e for e in exps]

    # 3) gather shell-point indices, outermost shell first
    shells = [np.where((D_flat <= hi) & (D_flat > lo))[0]
              for lo, hi in zip(r_vals[1:], r_vals[:-1])]
    shells.append(np.where(D_flat <= r_vals[-1])[0])
    sampled_idx = np.concatenate(shells)
    pts_sampled = pts[sampled_idx]
    F_sampled = flat_F[sampled_idx]

    # 4) reshape for slicing (any Nx x Ny x Nz grid, not only cubic)
    F3d = np.reshape(flat_F, grid_shape)

    # 5) build figure skeleton
    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{"type": "scene"}, {"type": "xy"}]],
        subplot_titles=["3D shell + slice‐plane", "2D cross‐section"]
    )

    scatter3d = go.Scatter3d(
        x=pts_sampled[:, 0], y=pts_sampled[:, 1], z=pts_sampled[:, 2],
        mode='markers',
        marker=dict(size=3, color=F_sampled, colorscale='Viridis',
                    colorbar=dict(title="|F|", thickness=15, x=0.45),
                    showscale=True),
        name="shell samples"
    )
    fig.add_trace(scatter3d, row=1, col=1)

    # placeholders for plane & slice (replaced by the frames)
    fig.add_trace(go.Scatter3d(x=[0], y=[0], z=[0], mode='lines',
                               line=dict(color='gray'), opacity=0.3),
                  row=1, col=1)
    fig.add_trace(go.Scatter(x=[0], y=[0], mode='markers',
                             marker=dict(size=4, color=[0], colorscale='Viridis',
                                         showscale=True,
                                         colorbar=dict(title="|F|", thickness=10))),
                  row=1, col=2)

    # 6) build frames
    px = [kx_vals.min(), kx_vals.max(), kx_vals.max(), kx_vals.min(), kx_vals.min()]
    py = [ky_vals.min(), ky_vals.min(), ky_vals.max(), ky_vals.max(), ky_vals.min()]
    frames = []
    for l, kz0 in enumerate(kz_vals):
        # plane outline at kz0
        plane = go.Scatter3d(x=px, y=py, z=[kz0]*5, mode='lines',
                             line=dict(color='gray', width=2),
                             surfaceaxis=2, opacity=0.3)

        # In an 'ij' meshgrid every point of layer l has KZ == kz0, so the
        # slice is simply that layer; no per-frame mask over the whole grid.
        slice2d = go.Scatter(
            x=KX[:, :, l].ravel(), y=KY[:, :, l].ravel(),
            mode='markers',
            marker=dict(size=5, color=F3d[:, :, l].ravel(),
                        colorscale='Viridis', showscale=False)
        )

        frames.append(go.Frame(data=[scatter3d, plane, slice2d], name=str(l)))

    fig.frames = frames

    # 7) slider & play/pause
    steps = [{
        "args": [[str(l)], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
        "label": f"{kz0:.2f}", "method": "animate"
    } for l, kz0 in enumerate(kz_vals)]
    sliders = [{
        "pad": {"t": 50}, "len": 0.8, "x": 0.15, "y": 0.05,
        "currentvalue": {"prefix": "k_z = ", "font": {"size": 16}, "xanchor": "center"},
        "steps": steps
    }]
    fig.update_layout(
        title="3D shell & animated 2D slice",
        scene=dict(xaxis_title="kₓ", yaxis_title="kᵧ", zaxis_title="k_z",
                   camera=dict(eye=dict(x=1.5, y=1.5, z=1.2))),
        xaxis2=dict(title="kₓ"),
        yaxis2=dict(title="kᵧ", scaleanchor="x2", scaleratio=1),
        updatemenus=[{
            "buttons": [
                {"args": [None, {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
                 "label": "Play", "method": "animate"},
                {"args": [[None], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}],
                 "label": "Pause", "method": "animate"}
            ],
            "direction": "left", "pad": {"r": 10, "t": 20}, "type": "buttons", "x": 0.01, "y": 0.15
        }],
        sliders=sliders,
        height=600, width=1100
    )
    fig.show()
    return fig

In [None]:
N = 100
gamma = 0.5
shell_thickness = 0.1

# grid
KX, KY, KZ = make_kgrid(N)

# fields + detuning D = Bx**2 + By**2 - gamma**2 (D = 0 where |f| = gamma)
Bx, By, D = compute_B_and_D(Bx_func, By_func, KX, KY, KZ, gamma)

# compute A_full and F_full
A_full, F_full = compute_A_and_F(Bx, By, D, gamma)

# per-point magnitudes |A| and |F| (flattened grid)
flat_A = np.linalg.norm(A_full.reshape(3, -1), axis=0)
flat_F = np.linalg.norm(F_full.reshape(3, -1), axis=0)

# Report the range of ‖F‖ itself. np.nanmin/np.nanmax on the raw complex
# F_full compare complex entries (not norms), and float() of a complex value
# discards (with a warning) or rejects the imaginary part. Non-finite norms,
# produced by the divisions in compute_A_and_F, are excluded.
finite_F = flat_F[np.isfinite(flat_F)]
if finite_F.size:
    Fmin, Fmax = float(finite_F.min()), float(finite_F.max())
else:
    Fmin = Fmax = float("nan")
print(f"‖F‖ min = {Fmin:.3e}, max = {Fmax:.3e}")

# Shell vectors. The shell mask depends only on D, so both calls select the
# same grid points; separate names keep F's magnitude grid from being
# overwritten by A's.
Fx, Fy, Fz, mask_shell, F_mag_grid = extract_shell_vectors(F_full.imag, D, shell_thickness)
Ax, Ay, Az, mask_shell_A, A_mag_grid = extract_shell_vectors(A_full.imag, D, shell_thickness)


**Note: we may want to adopt the rescaling below generally in the NodalKnot code.**

In [None]:
 

 
def rescale_graph_to_kspace(
    graph,
    kx_min, kx_max,
    ky_min, ky_max,
    kz_min, kz_max,
    N
):
    """
    Convert a skeleton graph from voxel indices to k-space coordinates, in place.

    Node attribute 'o' and edge attribute 'pts' are mapped per axis with
    ``k = k_min + index * (k_max - k_min) / (N - 1)``, so index 0 maps to
    ``k_min`` and index ``N - 1`` to ``k_max``. Coordinates are converted as
    floats. Non-integer positions (if the skeletoniser stores any, e.g.
    averaged node centres) therefore keep their sub-voxel part; the old
    ``int()`` cast truncated them. Nodes/edges lacking the attribute are
    skipped. The graph is mutated, so calling this twice rescales twice.

    Parameters
    ----------
    graph : networkx.Graph
      Node attribute 'o' as length-3 coordinates and edge attribute 'pts'
      as (M, 3) coordinates, in voxel-index units.

    kx_min, kx_max, ky_min, ky_max, kz_min, kz_max : floats
      The physical ranges in k-space.

    N : int
      Number of samples per dimension used in the skeletonization
      (i.e. hopf.pts_per_dim).

    Raises
    ------
    ValueError
      If N < 2, since no grid step can be defined.
    """
    if N < 2:
        raise ValueError(f"N must be at least 2 to define a grid step, got {N}")

    lo = np.array([kx_min, ky_min, kz_min], dtype=float)
    step = (np.array([kx_max, ky_max, kz_max], dtype=float) - lo) / (N - 1)

    # Rescale node positions
    for _, data in graph.nodes(data=True):
        o = data.get('o')
        if o is None:
            continue
        data['o'] = lo + np.asarray(o, dtype=float) * step

    # Rescale edge point lists; (M, 3) broadcasts against the (3,) lo/step
    for _, _, data in graph.edges(data=True):
        pts = data.get('pts')
        if pts is None:
            continue
        data['pts'] = lo + np.asarray(pts, dtype=float) * step


        
hopf = NodalKnot(k_to_zw, zw_to_c_hopf, pts_per_dim=400)
hopf_graph = hopf.skeleton_graph(clean=True, thickness=gamma)

# Map the skeleton from NodalKnot's voxel indices (pts_per_dim samples per
# axis) onto the same k-ranges as make_kgrid's defaults.
# NOTE(review): assumes NodalKnot samples exactly these ranges -- confirm.
# A separate name is used so the Berry-curvature grid size N is not overwritten.
n_skel = hopf.pts_per_dim
rescale_graph_to_kspace(
    hopf_graph,
    kx_min=-np.pi, kx_max=np.pi,
    ky_min=-np.pi, ky_max=np.pi,
    kz_min=0.0, kz_max=np.pi,
    N=n_skel,
)

hopf_surface = hopf.knot_surface_points(thickness=gamma, epsilon=0.001)
hopf_surface_fig = hopf.plot_3D(hopf_surface)
hopf_surface_fig.show()

In [None]:
# The helper already calls fig.show(); binding the returned figure keeps
# Jupyter from displaying it a second time as the cell's output value.
fig_F_cones = plot_ep_surface_and_vectors(
    KX, KY, KZ, D,
    (Fx, Fy, Fz),
    mask_shell,
    hopf_graph,
    cone_color='purple',
    SizeCone=1.4,
    node_color='red',
    edge_color='blue',
    node_size=5,
    edge_width=7,
    max_arrows=75
)

In [None]:
# The helper already calls fig.show(); binding the returned figure keeps
# Jupyter from displaying it a second time as the cell's output value.
fig_A_cones = plot_ep_surface_and_vectors(
    KX, KY, KZ, D,
    (Ax, Ay, Az),
    mask_shell,
    hopf_graph,
    cone_color='yellow',
    SizeCone=2,
    node_color='red',
    edge_color='blue',
    node_size=4,
    edge_width=4
)

In [None]:
# The helper already calls fig.show(); binding the returned figure keeps
# Jupyter from displaying it a second time as the cell's output value.
fig_F_edges = plot_graph_with_edge_orientations(
    KX, KY, KZ,
    Fx, Fy, Fz,
    mask_shell,
    hopf_graph,
    cone_color='purple',
    size_factor=0.4,    # smaller overall cones
    node_color='red',
    edge_color='blue',
    node_size=7,
    edge_width=4
)


In [None]:
# The helper already calls fig.show(); binding the returned figure keeps
# Jupyter from displaying it a second time as the cell's output value.
fig_anim = animate_shell_and_slices_precomputed(
    KX, KY, KZ,
    D, flat_F,
    dx=0.2
)

In [None]:
 
# 1) 1-D k-axes matching D.shape: kx, ky in [-pi, pi], kz in [0, pi]
Nx, Ny, Nz = D.shape
kx_vals, ky_vals, kz_vals = (
    np.linspace(lo, hi, n)
    for (lo, hi), n in zip(((-np.pi, np.pi), (-np.pi, np.pi), (0, np.pi)),
                           (Nx, Ny, Nz))
)

# 2) Full 3-D |F| grid (same shape as D -- not flattened, not masked):
#    sqrt(|F_x|^2 + |F_y|^2 + |F_z|^2) over the complex components.
Fmag_grid = np.sqrt(sum(np.abs(component)**2 for component in F_full))

def compute_loops(fixed_axis, fixed_val,
                  kx_vals, ky_vals, kz_vals,
                  D, Fmag_grid):
    """
    Contour one axis-aligned slice of D and collect EP and broken-region loops.

    Parameters
    ----------
    fixed_axis : {"kx", "ky", "kz"}
        Axis held fixed; the slice is the grid plane nearest to ``fixed_val``.
    fixed_val : float
    kx_vals, ky_vals, kz_vals : 1-D ndarray
        Coordinates of D's three axes, in that order.
    D, Fmag_grid : ndarray, shape (len(kx_vals), len(ky_vals), len(kz_vals))

    Returns
    -------
    D_slice : 2-D ndarray
    U, V : 2-D ndarray
        ``indexing='ij'`` meshgrid of the two free axes (in kx, ky, kz order).
    ep_loops : list of (n, 2) ndarray
        Vertices of the D = 0 contour segments.
    loops_info : list of ((n, 2) ndarray, float)
        D = -level contour segments, each with the maximum of Fmag_grid at the
        grid points nearest to its vertices (NaN if all of those are NaN).
    abs_levels : ndarray
        ``[0]`` followed by four levels evenly spaced from 10 % to 100 % of
        max |D| over the broken (D < 0) part of the slice, or just ``[0]``
        when the slice has no broken points.
    """
    axes = {"kx": kx_vals, "ky": ky_vals, "kz": kz_vals}
    i0 = np.abs(axes[fixed_axis] - fixed_val).argmin()

    free = [name for name in ("kx", "ky", "kz") if name != fixed_axis]
    u_vals, v_vals = axes[free[0]], axes[free[1]]
    U, V = np.meshgrid(u_vals, v_vals, indexing="ij")

    # Index tuple picking plane i0 along the fixed axis.
    sel = [slice(None)] * 3
    sel[("kx", "ky", "kz").index(fixed_axis)] = i0
    sel = tuple(sel)
    D_slice = D[sel]
    F_slice = Fmag_grid[sel]

    broken = (D_slice < 0)
    max_brok = np.max(np.abs(D_slice[broken])) if np.any(broken) else 0.0
    if max_brok > 0:
        lvls = np.linspace(0.1*max_brok, max_brok, 4)
        abs_levels = np.concatenate(([0.0], lvls))
    else:
        abs_levels = np.array([0.0])

    # Contours are computed on a private, throw-away figure that is always
    # closed. The old plt.contour/plt.clf pattern drew on and then cleared the
    # caller's current figure. ``allsegs`` replaces the
    # ``collections[0].get_paths()`` API deprecated in Matplotlib 3.8.
    tmp_fig, tmp_ax = plt.subplots()
    try:
        # EP loops at D = 0
        cs0 = tmp_ax.contour(U, V, D_slice, levels=[0.0], colors="none")
        ep_loops = [seg for seg in cs0.allsegs[0] if seg.size > 0]

        # Broken-region loops at D = -level
        loops_info = []
        for level in abs_levels[1:]:
            cs1 = tmp_ax.contour(U, V, D_slice, levels=[-level], colors="none")
            for verts in cs1.allsegs[0]:
                if verts.shape[0] == 0:
                    continue
                # snap each vertex to its nearest grid index on both free axes
                ui = np.abs(u_vals[:, None] - verts[:, 0]).argmin(axis=0)
                vi = np.abs(v_vals[:, None] - verts[:, 1]).argmin(axis=0)
                peakF = float(np.nanmax(F_slice[ui, vi]))
                loops_info.append((verts, peakF))
    finally:
        plt.close(tmp_fig)

    return D_slice, U, V, ep_loops, loops_info, abs_levels

# ──────────────────────────────────────────────────────────────────────────
# (C) Plot 1×3 panels and overlay edge–slice intersections
# ──────────────────────────────────────────────────────────────────────────

ix_mid = Nx // 2 - 2      # integer division (not used in this cell)
iy_mid = Ny // 2 - 1
iz_mid = Nz // 2 - 1

# Fixed coordinate of each of the three slices.
kx_fix = 0
ky_fix = 0
kz_fix = np.pi/2

slices = [("kx", kx_fix), ("ky", ky_fix), ("kz", kz_fix)]
results = [(name, val) + compute_loops(name, val, kx_vals, ky_vals, kz_vals, D, Fmag_grid)
           for name, val in slices]

# Shared log colour scale for broken-region loops (finite peaks only).
all_peaks = [peak for (_a, _v, _D, _U, _V, _ep, loops, _al) in results
             for (_verts, peak) in loops if np.isfinite(peak)]
Fmin = max(np.nanmin(all_peaks), 1e-12) if all_peaks else 1e-12
Fmax = np.nanmax(all_peaks) if all_peaks else 1.0
Fmax = max(Fmax, Fmin)    # LogNorm needs vmin <= vmax (e.g. when all peaks are 0)
norm = LogNorm(vmin=Fmin, vmax=Fmax)
# pyplot.get_cmap works on every Matplotlib version (matplotlib.cm.get_cmap
# was removed in 3.9).
cmap = plt.get_cmap("coolwarm")

# Every skeleton sample point, stacked once so the per-loop nearest-point
# search is a single vectorised distance computation. np.argmin keeps the
# first minimum, matching a scan of the edges in iteration order.
_edge_pts = [np.asarray(data["pts"], dtype=float).reshape(-1, 3)
             for _u, _v, data in hopf_graph.edges(data=True)]
all_edge_pts = np.vstack(_edge_pts) if _edge_pts else np.empty((0, 3))

fig, axes = plt.subplots(1, 3, figsize=(12, 4), constrained_layout=True)
axis_to_idx = {"kx": 0, "ky": 1, "kz": 2}

# pick a font size
fontsize = 16

for ax, (axis, val, D_slice, U, V, ep_loops, loops_info, abs_levels) in zip(axes, results):
    # 1) EP loops (D = 0): dashed black
    for verts in ep_loops:
        ax.plot(verts[:, 0], verts[:, 1], "--k", lw=2)

    # 2) Broken-region loops coloured by peak |F|; off-scale peaks
    #    (zero, NaN or inf) are drawn in grey.
    for verts, peakF in loops_info:
        on_scale = np.isfinite(peakF) and peakF > 0
        color = cmap(norm(peakF)) if on_scale else "grey"
        ax.plot(verts[:, 0], verts[:, 1], color=color, lw=2.5)

    # 3) For each EP loop: the skeleton point (3-D) nearest to the loop
    #    centroid lifted onto this slice, projected back onto the slice.
    i_fix = axis_to_idx[axis]
    free = [name for name in ("kx", "ky", "kz") if name != axis]
    u_i, v_i = axis_to_idx[free[0]], axis_to_idx[free[1]]

    selected = []
    for loop_id, verts in enumerate(ep_loops):
        cen3 = np.empty(3)
        cen3[i_fix] = val
        cen3[[u_i, v_i]] = verts.mean(axis=0)

        if all_edge_pts.shape[0] == 0:
            print(f"⚠️ Loop {loop_id}: no sample points found in graph!")
            continue
        d2 = ((all_edge_pts - cen3)**2).sum(axis=1)
        selected.append(all_edge_pts[np.argmin(d2), [u_i, v_i]])

    selected = np.asarray(selected).reshape(-1, 2)   # (n_loops, 2), may be empty

    # 3a) black crosses; skipped when this slice has no EP loops
    if len(selected):
        ax.scatter(selected[:, 0], selected[:, 1],
                   marker='x', c='black', s=30, label='nearest edge-point')

    print(f"Selected {len(selected)} representative points for "
          f"{len(ep_loops)} EP loops (3-D nearest-edge search).")

    # 4) Labels with proper subscripts ($k_x$ rather than $kx$) and styling
    ax.set_xlabel(rf"$k_{free[0][1]}$", fontsize=fontsize)
    ax.set_ylabel(rf"$k_{free[1][1]}$", fontsize=fontsize)
    ax.set_title(rf"$k_{axis[1]} = {val:.3f}$", fontsize=fontsize)

    ax.set_aspect("equal")
    ax.set_box_aspect(1)

# 5) Shared colorbar
sm = ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = fig.colorbar(
    sm,
    ax=axes,
    location="right",
    shrink=0.8
)
cbar.ax.tick_params(labelsize=fontsize)
cbar.set_label(r"max‖F‖", fontsize=fontsize, labelpad=10)

# Save this figure explicitly (plt.savefig uses whichever figure is current).
fig.savefig("hopf_intersections.pdf", bbox_inches="tight", pad_inches=0)
plt.show()

### Figure Creation

In [None]:
import numpy as np
import plotly.io as pio

# Berry-curvature cones + EP surface + skeleton (the helper also shows it once).
fig_ep = plot_ep_surface_and_vectors(
    KX, KY, KZ, D, (Fx, Fy, Fz), mask_shell, hopf_graph,
    cone_color='purple', SizeCone=1.4,
    node_color='red', edge_color='blue',
    node_size=5, edge_width=7,
    max_arrows=155,
)

# The exported figure carries no legend and no trace names.
fig_ep.update_traces(showlegend=False)
for trace in fig_ep.data:
    trace.name = ""

# Bare 3-D axes (titles only) over the fixed Brillouin-zone box, cube aspect
# and a fixed camera.
fig_ep.update_layout(
    margin=dict(l=0, r=0, t=0, b=0),
    scene=dict(
        aspectmode='cube',
        camera=dict(eye=dict(x=2, y=-1, z=0.5), up=dict(x=0, y=0, z=1)),
        xaxis=dict(title='kx', range=[-np.pi, np.pi], autorange=False,
                   showticklabels=False, showgrid=False, zeroline=False),
        yaxis=dict(title='ky', range=[-np.pi, np.pi], autorange=False,
                   showticklabels=False, showgrid=False, zeroline=False),
        zaxis=dict(title='kz', range=[0, np.pi], autorange=False,
                   showticklabels=False, showgrid=False, zeroline=False),
    ),
)

# 600×600 PDF at 2× scale.
pio.write_image(fig_ep, "ep_surface_vectors.pdf", format='pdf',
                width=600, height=600, scale=2)

# Optional inline display of the styled figure.
fig_ep.show()


In [None]:
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio

# 1) Sample the surface points
surf = hopf.knot_surface_points(thickness=gamma, epsilon=0.01)
x, y, z = surf[:, 0], surf[:, 1], surf[:, 2]

# 2) Build the base Plotly figure of the knot surface
fig_surf = go.Figure(
    data=go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        marker=dict(size=3, color='blue', opacity=0.7)
    )
)

# 3) Mesh-grids for the three slicing planes kx = 0, ky = 0, kz = pi/2, the
#    fixed values kx_fix / ky_fix / kz_fix of the 2-D cross-section figure.
#    A dedicated name keeps the global grid size N intact.
n_plane = 60  # resolution of each plane

# Plane at kx = 0 (yz-plane)
yy, zz = np.meshgrid(np.linspace(-np.pi, np.pi, n_plane),
                     np.linspace(0, np.pi, n_plane),
                     indexing='ij')
xx0 = np.zeros_like(yy)

# Plane at ky = 0 (xz-plane)
xx, zz2 = np.meshgrid(np.linspace(-np.pi, np.pi, n_plane),
                      np.linspace(0, np.pi, n_plane),
                      indexing='ij')
yy0 = np.zeros_like(xx)

# Plane at kz = π/2 (xy-plane)
xx2, yy2 = np.meshgrid(np.linspace(-np.pi, np.pi, n_plane),
                       np.linspace(-np.pi, np.pi, n_plane),
                       indexing='ij')
zz_half = (np.pi/2) * np.ones_like(xx2)

# 4) Add the three semi-transparent slice planes
fig_surf.add_trace(go.Surface(
    x=xx0, y=yy, z=zz,
    opacity=0.3,
    showscale=False,
    colorscale=[[0, 'gray'], [1, 'gray']],
    name='kx=0'
))
fig_surf.add_trace(go.Surface(
    x=xx, y=yy0, z=zz2,
    opacity=0.3,
    showscale=False,
    colorscale=[[0, 'gray'], [1, 'gray']],
    name='ky=0'
))
fig_surf.add_trace(go.Surface(
    x=xx2, y=yy2, z=zz_half,
    opacity=0.3,
    showscale=False,
    colorscale=[[0, 'gray'], [1, 'gray']],
    name='kz=π/2'
))

# 5) Tidy up axes, camera & layout
fig_surf.update_layout(
    margin=dict(l=0, r=0, t=0, b=0),
    scene=dict(
        aspectmode='cube',
        xaxis=dict(title='kx',
                   showticklabels=False, showgrid=False, zeroline=False,
                   range=[-np.pi, np.pi], autorange=False),
        yaxis=dict(title='ky',
                   showticklabels=False, showgrid=False, zeroline=False,
                   range=[-np.pi, np.pi], autorange=False),
        zaxis=dict(title='kz',
                   showticklabels=False, showgrid=False, zeroline=False,
                   range=[0, np.pi], autorange=False),
        camera=dict(
            eye=dict(x=2, y=-1, z=0.5),
            up=dict(x=0, y=0, z=1)
        )
    )
)

# 6) Export to PDF only (static export requires `pip install -U kaleido`)
pio.write_image(fig_surf, "hopf_surface.pdf",
                format='pdf', width=600, height=600, scale=2)

# 7) (Optional) Display in notebook or interactive window
fig_surf.show()


In [None]:
import numpy as np
import plotly.io as pio

# 1) Skeleton graph with cones along its edges (the helper also shows it once).
fig_graph = plot_graph_with_edge_orientations(
    KX, KY, KZ, Fx, Fy, Fz, mask_shell, hopf_graph,
    cone_color='purple', size_factor=0.4,
    node_color='red', edge_color='blue',
    node_size=7, edge_width=4,
)

# 2) No legend entries for any trace.
fig_graph.update_traces(showlegend=False)

# 3) Axis-free look: empty titles, no ticks/grid/zero lines, fixed ranges,
#    cube aspect and a fixed camera.
#    NOTE(review): the following update_layout call in this cell sets the
#    axis titles back to 'kx' / 'ky' / 'kz'.
fig_graph.update_layout(
    margin=dict(l=0, r=0, t=0, b=0),
    scene=dict(
        aspectmode='cube',
        camera=dict(eye=dict(x=2, y=-1, z=0.5), up=dict(x=0, y=0, z=1)),
        xaxis=dict(title='', range=[-np.pi, np.pi], autorange=False,
                   showticklabels=False, showgrid=False, zeroline=False),
        yaxis=dict(title='', range=[-np.pi, np.pi], autorange=False,
                   showticklabels=False, showgrid=False, zeroline=False),
        zaxis=dict(title='', range=[0, np.pi], autorange=False,
                   showticklabels=False, showgrid=False, zeroline=False),
    ),
)
# 5) Tidy up axes, camera & layout
fig_graph.update_layout(
    margin=dict(l=0, r=0, t=0, b=0),
    scene=dict(
        aspectmode='cube',
        xaxis=dict(title='kx',
                   showticklabels=False, showgrid=False, zeroline=False,
                   range=[-np.pi, np.pi], autorange=False),
        yaxis=dict(title='ky',
                   showticklabels=False, showgrid=False, zeroline=False,
                   range=[-np.pi, np.pi], autorange=False),
        zaxis=dict(title='kz',
                   showticklabels=False, showgrid=False, zeroline=False,
                   range=[0, np.pi], autorange=False),
        camera=dict(
            eye=dict(x=2, y=-1, z=0.5),
            up=dict(x=0, y=0, z=1)
        )
    )
)

# 4) Export at the same 600×600 size
pio.write_image(
    fig_graph,
    "hopf_graph_orientations.pdf",
    format='pdf',
    width=600,
    height=600,
    scale=2
)

# 5) Optional: display inline
fig_graph.show()


In [None]:
import matplotlib.pyplot as plt
from pdf2image import convert_from_path

# Desired height in inches for both figures. It is defined here so this cell
# does not depend on `x` having been set by another cell; otherwise it raises
# NameError on a fresh kernel. The value matches the other composite cells.
x = 4
# Resolution used to rasterise the PDFs and to save the composite
dpi_val = 600

# Load the first pages of the PDFs
img1 = convert_from_path('hopf_surface.pdf', dpi=dpi_val)[0]
img2 = convert_from_path('hopf_intersections.pdf', dpi=dpi_val)[0]

# Create subplots with width ratios 1:3 (total width = 4*x inches, height = x inches)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(4 * x, x),
                               gridspec_kw={'width_ratios': [1, 3]})

# Display images
ax1.imshow(img1)
ax2.imshow(img2)

# Remove axes for a clean look
for ax in (ax1, ax2):
    ax.axis('off')

plt.tight_layout()

# Explicit '.png' extension: without it matplotlib falls back to
# rcParams['savefig.format']. Save at the same dpi and tight bounding box as
# the other composite figures, so the 600-dpi rasters are not downsampled to
# the default figure dpi.
fig.savefig(
    'surfaces_horizontal.png',
    dpi=dpi_val,
    bbox_inches='tight',
    pad_inches=0
)
plt.show()


In [None]:
import matplotlib.pyplot as plt
from pdf2image import convert_from_path
from PIL import ImageOps

# Base ("unit") size of the composite figure, in inches
x = 4
# One resolution for both rasterising the PDFs and saving the composite
dpi_val = 600


def _trim(image, left, top, right, bottom):
    """Return *image* with the given pixel margins cut from each edge."""
    width, height = image.size
    return image.crop((left, top, width - right, height - bottom))


# Rasterise page 1 of each PDF and cut away its empty borders.
# Margins are pixel counts at 600 dpi, written as multiples of 3.
img1 = _trim(convert_from_path('hopf_surface.pdf', dpi=dpi_val)[0],
             250 * 3, 450 * 3, 200 * 3, 200 * 3)
img2 = _trim(convert_from_path('hopf_intersections.pdf', dpi=dpi_val)[0],
             0, 0, 0, 10 * 3)

# Stack the two images: surface on top (3/4 of the height), intersections below (1/4)
fig, (ax1, ax2) = plt.subplots(
    nrows=2,
    ncols=1,
    figsize=(x, x),
    dpi=dpi_val,
    gridspec_kw={'height_ratios': [3, 1]},
)
for ax, picture in ((ax1, img1), (ax2, img2)):
    ax.imshow(picture)
    ax.axis('off')

fig.tight_layout()

# Write a high-resolution PDF with no padding around the tight bounding box
fig.savefig("surfaces_vertical.pdf", format='pdf', dpi=dpi_val,
            bbox_inches='tight', pad_inches=0)
plt.show()


In [None]:
import matplotlib.pyplot as plt
from pdf2image import convert_from_path
from PIL import ImageOps

# “Unit” size in inches
x = 4
dpi_val = 600

# Crop margins (left, top, right, bottom), in pixels at a reference dpi, then
# rescaled to dpi_val. At dpi_val = 600 this gives exactly the former
# hard-coded "N*3" pixel values, and the cropped region stays the same if
# dpi_val is changed.
# NOTE(review): the 200-dpi reference is inferred from the *3 factor; confirm.
CROP_REF_DPI = 200
crop_margins1 = (0, 500, 0, 200)   # ep_surface_vectors.pdf (adjust as needed)
crop_margins2 = (0, 500, 0, 200)   # hopf_graph_orientations.pdf (adjust as needed)


def _crop_margins(img, margins, dpi):
    """Return *img* with (left, top, right, bottom) reference-dpi margins removed."""
    scale = dpi / CROP_REF_DPI
    left, top, right, bottom = (int(round(m * scale)) for m in margins)
    w, h = img.size
    return img.crop((left, top, w - right, h - bottom))


# Convert PDFs → PIL Images at dpi_val, then crop their margins
img1 = _crop_margins(convert_from_path('ep_surface_vectors.pdf', dpi=dpi_val)[0],
                     crop_margins1, dpi_val)
img2 = _crop_margins(convert_from_path('hopf_graph_orientations.pdf', dpi=dpi_val)[0],
                     crop_margins2, dpi_val)

# Two stacked panels of equal height inside an x-by-x inch figure:
# each panel is x wide and x/2 tall (total height = x).
fig, (ax1, ax2) = plt.subplots(
    2, 1,
    figsize=(x, x),
    dpi=dpi_val,
    gridspec_kw={'height_ratios': [1, 1]}
)

# Display
ax1.imshow(img1)
ax2.imshow(img2)
for ax in (ax1, ax2):
    ax.axis('off')

plt.tight_layout()

# Save at high DPI with a tight, unpadded bounding box
fig.savefig(
    "surfaces_vertical_equal.pdf",
    format='pdf',
    dpi=dpi_val,
    bbox_inches='tight',
    pad_inches=0
)
plt.show()
