In [24]:
import jax
import jax.numpy as jnp
from jax import jit, vmap
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
jax.config.update("jax_enable_x64", True) 



### Writhe

In [25]:
@jit
def compute_writhe_chat(positions):
    """Discrete writhe of a polygonal curve (Gauss double-integral estimate).

    Segment ``i`` runs from ``positions[i]`` to ``positions[i + 1]`` and is
    anchored at its start point ``r_i``. With ``s_i`` the (un-normalized)
    segment vector, every pair of segments ``i < j`` contributes

        ((s_i x s_j) . (r_i - r_j)) / |r_i - r_j|^3

    Each unordered pair is counted once, so the prefactor is ``1 / (2 pi)``
    rather than the ``1 / (4 pi)`` of the full double integral.

    Args:
        positions: array of shape (N, 3) holding the curve's vertices in order.

    Returns:
        Scalar JAX array with the writhe estimate.
    """
    num_points = positions.shape[0]
    # Raw segment vectors. s_i equals (unit tangent) * (segment length), so
    # using it directly is algebraically identical to normalizing and then
    # re-multiplying by the lengths. It avoids the 0/0 that made the result
    # NaN whenever a vertex was repeated (zero-length segment).
    segments = jnp.diff(positions, axis=0)

    def pairwise_writhe(i1, i2):
        diff = positions[i1] - positions[i2]
        scalar_triple = jnp.dot(jnp.cross(segments[i1], segments[i2]), diff)
        dist = jnp.linalg.norm(diff)
        # Coincident anchor points give diff == 0, which makes the numerator
        # exactly 0 as well. Treat that 0/0 term as 0 instead of NaN. The
        # "safe" denominator keeps the unselected branch of `where` finite.
        nonzero = dist > 0
        safe_dist = jnp.where(nonzero, dist, 1.0)
        return jnp.where(nonzero, scalar_triple / safe_dist ** 3, 0.0)

    # All segment pairs i1 < i2 (a curve with N points has N - 1 segments).
    i1_indices, i2_indices = jnp.triu_indices(num_points - 1, k=1)
    writhe_sum = jnp.sum(vmap(pairwise_writhe)(i1_indices, i2_indices))

    return writhe_sum / (2.0 * jnp.pi)

### Shapes

In [26]:
#Trefoil Knot
def trefoil_shape(spacings, period):
    """Sample the standard trefoil parametrisation.

    Args:
        spacings: number of sample points.
        period: the parameter t runs over [0, period * pi]. With period=2 the
            knot is traced once and the last sample repeats the first (up to
            rounding).

    Returns:
        Array of shape (spacings, 3), one (x, y, z) row per sample.
    """
    t = jnp.linspace(0, period * jnp.pi, spacings)
    coords = (
        jnp.sin(t) + 2 * jnp.sin(2 * t),
        jnp.cos(t) - 2 * jnp.cos(2 * t),
        jnp.sin(3 * t),
    )
    # Stack along axis 1 so each row is one point (same as vstack(...).T).
    return jnp.stack(coords, axis=1)

#twisted trefoil

def twisted_trefoil_shape(spacings, period):
    """Trefoil parametrisation with an extra 5 * sin(5t) term on the y axis.

    Args:
        spacings: number of sample points.
        period: the parameter t runs over [0, period * pi].

    Returns:
        Array of shape (spacings, 3), one (x, y, z) row per sample.
    """
    t = jnp.linspace(0, period * jnp.pi, spacings)
    twist = 5 * jnp.sin(5 * t)
    x_coord = jnp.sin(t) + 2 * jnp.sin(2 * t)
    y_coord = jnp.cos(t) - 2 * jnp.cos(2 * t) + twist
    z_coord = jnp.sin(3 * t)
    return jnp.stack((x_coord, y_coord, z_coord), axis=1)

#Helix
# def helix_shape(spacings, period):
#     t_hel = jnp.linspace(0, period * jnp.pi, spacings)
#     t = jnp.linspace(0, 2 * jnp.pi, n)
#     x = jnp.cos(t_hel)
#     y = jnp.sin(t_hel)
#     z = t
#     positions_helix = jnp.vstack([x,y,z])
#     positions_helix = positions_helix.T
#     return positions_helix


##twisted_Helix
def twisted_helix_shape(spacings, period):
    """Unit-radius helix whose height is perturbed by 5 * sin(5t).

    Args:
        spacings: number of sample points.
        period: the parameter t runs over [0, period * pi].

    Returns:
        Array of shape (spacings, 3), one (x, y, z) row per sample.
    """
    t = jnp.linspace(0, period * jnp.pi, spacings)
    height = t + 5 * jnp.sin(5 * t)
    return jnp.stack((jnp.cos(t), jnp.sin(t), height), axis=1)

### Plot writhe vs number of segments

In [27]:
def _prefix_writhe(positions):
    """Writhe of every growing prefix of ``positions``.

    Returns two lists: the segment counts ``[1, ..., N - 1]`` and, for each
    count ``k``, the writhe of the sub-curve ``positions[:k + 1]`` as a float.
    """
    num_points = positions.shape[0]
    segment_counts = []
    writhe_values = []
    # positions[:k + 1] holds k + 1 points, i.e. exactly k segments. Stop at
    # k = N - 1: k = N would slice past the end, re-use the full curve and
    # plot its writhe twice.
    # Each prefix has a new array shape, so the jitted compute_writhe_chat
    # is traced/compiled once per k.
    for k in range(1, num_points):
        writhe = compute_writhe_chat(positions[:k + 1])
        segment_counts.append(k)
        writhe_values.append(float(writhe))
    return segment_counts, writhe_values


def plot_wr_vs_segments(positions, fig_shape: str):
    """Plot the writhe of growing sub-curves against their number of segments.

    For k = 1 .. N - 1 the writhe of ``positions[:k + 1]`` (k segments) is
    computed with ``compute_writhe_chat`` and shown as a line plot.

    Args:
        positions: array of shape (N, 3) with the curve's vertices in order.
        fig_shape: name of the curve, inserted into the figure title.
    """
    segment_counts, writhe_values = _prefix_writhe(positions)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=segment_counts,
        y=writhe_values,
        mode='lines+markers',
        name='Writhe vs Segment Count'
    ))

    fig.update_layout(
        title='Writhe vs Number of Segments ({})'.format(fig_shape),
        xaxis_title='Number of Segments',
        yaxis_title='Writhe',
        template='plotly_white'
    )

    fig.show()


In [28]:
# Sample the three test curves at 100 points with period=2 (t in [0, 2*pi]).
trefoil_pos, twis_trefoil_pos, twis_helix_pos = (
    make_shape(100, 2)
    for make_shape in (trefoil_shape, twisted_trefoil_shape, twisted_helix_shape)
)

plot_wr_vs_segments(trefoil_pos, "Trefoil")

# Plot 1: Writhe vs number of segments (growing sub-curve)

In [29]:
# Writhe of each growing prefix trefoil_pos[:k + 1], which has exactly k segments.
segment_counts = []
writhe_values = []
trefoil_pos = trefoil_shape(100, 2)
N = trefoil_pos.shape[0]

# Stop at k = N - 1: k = N would slice past the end, repeat the full curve
# and plot its writhe value twice.
for k in range(1, N):
    sub_positions = trefoil_pos[:k + 1]
    writhe = compute_writhe_chat(sub_positions)
    segment_counts.append(k)
    writhe_values.append(writhe.item())

# plotly plot: 3D curve (left) next to the writhe curve (right)
fig = make_subplots(rows=1, cols=2,
                    specs=[[{"type": "scatter3d"}, {"type": "xy"}]],
                    subplot_titles=["Trefoil Knot", "Writhe vs Segment Count"])

# Vertex i closes the prefix with i segments, so this label matches the x value
# of the corresponding writhe point. Using segment_counts here would label
# vertex i as i + 1.
hover_labels = [f"Segment: {i}" for i in range(len(trefoil_pos))]

# 3D Knot
fig.add_trace(go.Scatter3d(
    x=trefoil_pos[:, 0],
    y=trefoil_pos[:, 1],
    z=trefoil_pos[:, 2],
    mode='lines+markers',
    marker=dict(size=2),
    line=dict(width=2),
    name='Trefoil Curve',
    text=hover_labels,
    hoverinfo='text'
), row=1, col=1)

# Writhe curve
fig.add_trace(go.Scatter(
    x=segment_counts,
    y=writhe_values,
    mode='lines+markers',
    name='Writhe'
), row=1, col=2)

# Start point green
fig.add_trace(go.Scatter3d(
    x=[trefoil_pos[0, 0]],
    y=[trefoil_pos[0, 1]],
    z=[trefoil_pos[0, 2]],
    mode="markers",
    marker=dict(size=5, color="green"),
    name="trefoil start"
), row=1, col=1)

# End point red. trefoil_shape(..., 2) closes the curve, so the last sample
# repeats the first (up to rounding); mark the last distinct vertex (-2).
fig.add_trace(go.Scatter3d(
    x=[trefoil_pos[-2, 0]],
    y=[trefoil_pos[-2, 1]],
    z=[trefoil_pos[-2, 2]],
    mode="markers",
    marker=dict(size=5, color="red"),
    name="trefoil end"
), row=1, col=1)

fig.update_layout(
    height=600,
    width=1000,
    title_text="Trefoil Knot and Writhe",
    template='plotly_white'
)

fig.show()
# TODO(review): the original note here read "virdis"; presumably a Viridis
# colourscale for the 3D curve was intended. Confirm before implementing.


# Plot 2: Writhe vs sampling resolution (number of points)

In [40]:

# Compute writhe for increasing resolutions (number of sample points)
resolutions = list(range(10, 101, 10))
writhe_values = []

for n in resolutions:
    positions = trefoil_shape(n, period=2)
    writhe = compute_writhe_chat(positions)
    writhe_values.append(writhe.item())

# The loop leaves `positions` at the finest resolution (resolutions[-1] points).
trefoil_pos = positions

# Plot using Plotly subplots
fig = make_subplots(
    rows=1, cols=2,
    specs=[[{"type": "scatter3d"}, {"type": "xy"}]],
    subplot_titles=["Trefoil Knot ", "Writhe vs Point Count"]
)

# 3D Trefoil, hover label = vertex index
hover_labels = [f"Segment: {i}" for i in range(len(trefoil_pos))]

fig.add_trace(go.Scatter3d(
    x=trefoil_pos[:, 0],
    y=trefoil_pos[:, 1],
    z=trefoil_pos[:, 2],
    mode="lines+markers",
    marker=dict(size=2),
    line=dict(width=2),
    text=hover_labels,
    hoverinfo='text',
    name="Trefoil "
), row=1, col=1)

# Start and near-end markers. trefoil_shape(..., period=2) closes the curve, so
# the last sample repeats the first (up to rounding); mark index -2 instead.
fig.add_trace(go.Scatter3d(
    x=[trefoil_pos[0, 0]],
    y=[trefoil_pos[0, 1]],
    z=[trefoil_pos[0, 2]],
    mode="markers",
    marker=dict(size=5, color="green"),
    name="Start Point"
), row=1, col=1)

fig.add_trace(go.Scatter3d(
    x=[trefoil_pos[-2, 0]],
    y=[trefoil_pos[-2, 1]],
    z=[trefoil_pos[-2, 2]],
    mode="markers",
    marker=dict(size=5, color="red"),
    name="Before End"
), row=1, col=1)

# Writhe vs Resolution (right subplot)
fig.add_trace(go.Scatter(
    x=resolutions,
    y=writhe_values,
    mode='lines+markers',
    name='Writhe'
), row=1, col=2)

# Label the 2D panel's axes so the figure can be read on its own.
fig.update_xaxes(title_text="Number of Points", row=1, col=2)
fig.update_yaxes(title_text="Writhe", row=1, col=2)

fig.update_layout(
    height=600,
    width=1000,
    title_text="Trefoil Knot and Writhe vs Resolution",
    template='plotly_white'
)

fig.show()


### Writhe vs figure scale

In [31]:

positions = trefoil_shape(100, 2)

# Scale factors: 50 evenly spaced values in [-4, 4]. 0 is not on this grid
# (it would fall at index 24.5); a = 0 would collapse the curve to one point.
# a < 0 maps every point through the origin (r -> -|a| r), which is an
# orientation-reversing map in 3D.
a_values = jnp.linspace(-4, 4, 50)
writhe_values = []

for a in a_values:
    scaled_positions = a * positions
    writhe = compute_writhe_chat(scaled_positions)
    writhe_values.append(writhe.item())

# Plot
fig = go.Figure()
fig.add_trace(go.Scatter(
    x=a_values,
    y=writhe_values,
    mode='lines+markers',
    name='Writhe vs Scaling'
))

fig.update_layout(
    title='Writhe vs Scaling Factor (Trefoil)',
    xaxis_title='Scaling Factor ',
    yaxis_title='Writhe',
    template='plotly_white'
)

fig.show()


### Writhe vs coordinate perturbation (Gaussian noise)

In [32]:
# Monte-Carlo estimate of how isotropic Gaussian noise on the vertices changes
# the writhe of the 100-point trefoil.
n = 100
clean_positions = trefoil_shape(n, 2)
true_writhe = compute_writhe_chat(clean_positions)

noise_levels = np.linspace(0, 0.6, 10)
n_trials = 1000
# Seeded, cell-local generator so the Monte-Carlo figure is reproducible on
# re-run (the global NumPy RNG is not seeded before this cell).
rng = np.random.default_rng(0)

writhe_means = []
writhe_stds = []

for sigma in noise_levels:
    wr_values = np.array([
        compute_writhe_chat(
            clean_positions
            + rng.normal(loc=0, scale=sigma, size=clean_positions.shape)
        ).item()
        for _ in range(n_trials)
    ])
    writhe_means.append(wr_values.mean())
    writhe_stds.append(wr_values.std())

fig = go.Figure()
fig.add_trace(go.Scatter(
    x=noise_levels,
    y=writhe_means,
    error_y=dict(type='data', array=writhe_stds, visible=True),
    mode='lines+markers',
    name='Mean Writhe ± Std'
))
# Reference line: writhe of the noise-free curve.
fig.add_trace(go.Scatter(
    x=[noise_levels[0], noise_levels[-1]],
    y=[true_writhe.item(), true_writhe.item()],
    mode='lines',
    line=dict(dash='dash'),
    name='Noise-free writhe'
))

fig.update_layout(
    title='Effect of Gaussian Noise on Writhe (Trefoil)',
    xaxis_title='Noise sigma',
    # TODO: show only y in [-10, 10] (e.g. yaxis_range=[-10, 10]) if
    # large-sigma outliers dominate the plot.
    yaxis_title='Writhe',
    template='plotly_white'
)

fig.show()


### Visualize the effect of noise on the trefoil's shape and bends

In [37]:
# One noisy realisation of the trefoil, drawn in 3D.
n = 100
sigma = 0.6  # noise standard deviation
np.random.seed(42)  # fixed global seed -> same picture on every run

clean_positions = trefoil_shape(n, 2)
noise = np.random.normal(scale=sigma, size=clean_positions.shape)
noisy_positions = clean_positions + noise

hover_labels = [f"Segment: {idx}" for idx in range(len(noisy_positions))]

noisy_trace = go.Scatter3d(
    x=noisy_positions[:, 0],
    y=noisy_positions[:, 1],
    z=noisy_positions[:, 2],
    text=hover_labels,
    hoverinfo='text',
    mode='lines+markers',
    line=dict(width=2),
    marker=dict(size=2),
    name='Noisy Trefoil',
)

fig = make_subplots(
    rows=1,
    cols=1,
    specs=[[{"type": "scatter3d"}]],
    subplot_titles=[f"Noisy Trefoil (σ = {sigma})"],
)
fig.add_trace(noisy_trace, row=1, col=1)
fig.update_layout(
    title_text=f"Noisy Trefoil Knot (Gaussian Noise σ = {sigma})",
    height=600,
    width=800,
    template='plotly_white',
)

fig.show()


### Writhe vs stretching along the z axis

In [39]:
# Stretch only the z coordinate of the trefoil by each factor and record the writhe.
n = 100
positions = np.array(trefoil_shape(n, 2))  # writable NumPy copy of the samples
z_factors = np.linspace(-0.5, 0.5, 50)
writhe_values = []

for a in z_factors:
    stretched = positions.copy()
    stretched[:, 2] *= a  # x and y are left untouched
    writhe = compute_writhe_chat(stretched)
    writhe_values.append(writhe)

fig = go.Figure(go.Scatter(
    x=z_factors,
    y=writhe_values,
    mode='lines+markers',
    name='Writhe vs Z Stretch',
))
fig.update_layout(
    title='Writhe vs Z-Axis Stretching ',
    xaxis_title='Z Stretch Factor ',
    yaxis_title='Writhe',
    template='plotly_white',
)

fig.show()
# TODO: export the figure as PDF.