### First step of this notebook: linearize tracking data with the track_linearization library from the Loren Frank Lab
(https://github.com/LorenFrankLab/track_linearization/tree/master)

In [2]:
%reload_ext autoreload
%autoreload 2

1. Create the Y-Maze Track Graph

In [None]:
from track_linearization import make_track_graph, plot_track_graph
import matplotlib.pyplot as plt

# Y-maze geometry. Node 0 is the bottom of the stem, node 1 the junction
# (choice point), and nodes 2 / 3 the ends of the left / right arms.
nodes = [
    (100, 0),    # node 0: start of the stem (arm 1)
    (100, 100),  # node 1: junction / choice point
    (15, 150),   # node 2: end of the left arm (arm 2)
    (185, 150),  # node 3: end of the right arm (arm 3)
]
edges = [
    (0, 1),  # connects node 0 and node 1 (stem)
    (1, 2),  # connects node 1 and node 2 (left arm)
    (1, 3),  # connects node 1 and node 3 (right arm)
]

track_graph = make_track_graph(nodes, edges)

# Draw the graph in maze coordinates with edge indices labelled.
fig, ax = plt.subplots()
plot_track_graph(track_graph, ax=ax, draw_edge_labels=True)
ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
ax.set_xlabel("x-position")
ax.set_ylabel("y-position")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

In [None]:
from track_linearization.utils import plot_graph_as_1D

# Preview how the track graph is unrolled onto a single 1D axis.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 1))
plot_graph_as_1D(track_graph, ax=ax)

2. Simulate Tracking Data

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import multivariate_normal

fig, ax = plt.subplots()
plot_track_graph(track_graph, ax=ax, draw_edge_labels=True)
ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
ax.set_xlabel("x-position")
ax.set_ylabel("y-position")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

# Seeded generator so "Restart & Run All" reproduces the same noisy trajectory.
SIMULATION_SEED = 0
rng = np.random.default_rng(SIMULATION_SEED)

# Samples per traversal of one arm, and total number of samples
num_points = 50
num_frames = 8 * num_points  # the trajectory below consists of 8 traversals

# Start of arm 1: the bottom end of the stem (node 0 of the track graph).
# The junction of the maze is at (100, 100), not here.
start_x, start_y = 100, 0

# Arm 1: vertical stem from (100, 0) up to the junction at (100, 100)
arm1_x_forward = np.full(num_points, start_x, dtype=float)  # constant x-coordinate
arm1_y_forward = np.linspace(start_y, start_y + 100, num_points)  # evenly spaced y
arm1_x_backward = arm1_x_forward[::-1]
arm1_y_backward = arm1_y_forward[::-1]

# Arm 2: 100 units from the junction (100, 100) at 150° (up and to the left)
theta_2 = 150 * np.pi / 180
arm2_x_forward = start_x + np.linspace(0, 100, num_points) * np.cos(theta_2)
arm2_y_forward = start_y + 100 + np.linspace(0, 100, num_points) * np.sin(theta_2)
arm2_x_backward = arm2_x_forward[::-1]
arm2_y_backward = arm2_y_forward[::-1]

# Arm 3: 100 units from the junction (100, 100) at +30° (up and to the right)
theta_3 = 30 * np.pi / 180
arm3_x_forward = start_x + np.linspace(0, 100, num_points) * np.cos(theta_3)
arm3_y_forward = start_y + 100 + np.linspace(0, 100, num_points) * np.sin(theta_3)
arm3_x_backward = arm3_x_forward[::-1]
arm3_y_backward = arm3_y_forward[::-1]

# The trajectory as an ordered list of (arm id, x samples, y samples) traversals.
segments = [
    (1, arm1_x_forward, arm1_y_forward),    # up arm 1
    (2, arm2_x_forward, arm2_y_forward),    # out along arm 2
    (2, arm2_x_backward, arm2_y_backward),  # back along arm 2
    (1, arm1_x_backward, arm1_y_backward),  # back down arm 1
    (1, arm1_x_forward, arm1_y_forward),    # up arm 1 again
    (3, arm3_x_forward, arm3_y_forward),    # out along arm 3
    (3, arm3_x_backward, arm3_y_backward),  # back along arm 3
    (2, arm2_x_forward, arm2_y_forward),    # out along arm 2 (final movement)
]
trajectory_x = np.concatenate([seg_x for _, seg_x, _ in segments])
trajectory_y = np.concatenate([seg_y for _, _, seg_y in segments])
# Which arm (1, 2 or 3) each sample was generated on -- used to colour the plot.
sample_arm = np.concatenate([np.full(len(seg_x), arm_id) for arm_id, seg_x, _ in segments])

# Add isotropic Gaussian noise (variance 0.5 per axis) to mimic tracking error
noise = rng.multivariate_normal([0, 0], [[0.5, 0], [0, 0.5]], len(trajectory_x))
trajectory_x += noise[:, 0]
trajectory_y += noise[:, 1]
position = np.vstack((trajectory_x, trajectory_y)).T

# Overlay the samples on the track graph, coloured by the arm they were
# generated on (every traversal of an arm, not only the first one).
arm_palette = {1: 'blue', 2: 'green', 3: 'red'}
for arm_id, arm_color in arm_palette.items():
    on_arm = sample_arm == arm_id
    ax.scatter(position[on_arm, 0], position[on_arm, 1],
               c=arm_color, s=10, alpha=0.6, label=f'Arm {arm_id}')

ax.legend(loc='lower right', fontsize=10, markerscale=2)
ax.set_title('Y-Maze Structure')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')
ax.grid(True)

plt.show()

In [5]:
from track_linearization import get_linearized_position

# Project every 2D sample onto its nearest track edge (no HMM smoothing).
position_df = get_linearized_position(
    position=position,
    track_graph=track_graph,
    use_HMM=False,
)

In [None]:
# Linear position over time, one colour per stay on an arm, with the track
# layout drawn along the y-axis on the right.
fig, ax = plt.subplots(figsize=(12, 3))

# Consecutive sample ranges spent on one arm, in units of `num_points`
# (segment order of the simulation): arm 1, arm 2 out+back, arm 1 back+up,
# arm 3 out+back, arm 2. The last range runs to the end of the data.
visit_starts = np.array([0, 1, 3, 5, 7]) * num_points
visit_stops = list(visit_starts[1:]) + [len(position_df)]
for visit_start, visit_stop in zip(visit_starts, visit_stops):
    visit = position_df.iloc[visit_start:visit_stop]
    ax.scatter(visit.index, visit.linear_position, s=10, zorder=2, clip_on=False)
ax.plot(
    position_df.index,
    position_df.linear_position,
    color="lightgrey",
    zorder=1,
    clip_on=False,
)

# Linear position at which each track edge ends, computed from the graph's
# edge lengths instead of hard-coded multiples of 100 (the arms here are
# ~98.6 long). Assumes the default edge order (order of `edges`) and zero edge
# spacing in get_linearized_position -- TODO confirm if either is passed.
edge_ends = np.cumsum([track_graph.edges[edge]["distance"] for edge in edges])
for edge_end in edge_ends:
    ax.axhline(edge_end, color="black", zorder=0, linestyle="--", clip_on=False)
ax.set_yticks(np.round(np.concatenate(([0.0], edge_ends)), 1))
ax.set_ylim((0, edge_ends[-1]))

ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xlim((0, position_df.shape[0]))
ax.set_ylabel("Position")
ax.set_xlabel("Time")
plot_graph_as_1D(
    track_graph, ax=ax, axis="y", other_axis_start=position_df.index.max() + 1
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from scipy.stats import multivariate_normal

# Animate the trajectory produced by the "Simulate Tracking Data" cell instead
# of simulating it a second time: a second simulation draws fresh noise and
# overwrites `trajectory_x` / `trajectory_y`, so later cells would analyse
# different samples from the ones linearized into `position_df`.
trajectory_positions = np.column_stack((trajectory_x, trajectory_y))
num_frames = len(trajectory_positions)  # one animation frame per sample

# One colour per traversal (8 traversals of `num_points` samples each); any
# samples beyond the 8th traversal keep the last colour.
segment_colors = ['blue', 'green', 'purple', 'orange', 'cyan', 'red', 'magenta', 'yellow']
sample_colors = [
    segment_colors[min(i // num_points, len(segment_colors) - 1)]
    for i in range(num_frames)
]

# Create a figure and axis for the animation
fig, ax = plt.subplots()
ax.set_xlim(0, 200)
ax.set_ylim(0, 200)
ax.set_title('Y-Maze Movement Animation')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')
ax.grid(True)

# Scatter of the samples revealed so far (starts empty)
scatter = ax.scatter([], [], s=10, alpha=0.6)

def update(frame):
    """Reveal samples 0..frame (inclusive), coloured by traversal."""
    # frame + 1 so that frame 0 already shows the first sample and the last
    # frame (num_frames - 1) shows every sample, including the final one.
    n_shown = frame + 1
    scatter.set_offsets(trajectory_positions[:n_shown])
    scatter.set_color(sample_colors[:n_shown])
    return scatter,

ani = FuncAnimation(fig, update, frames=num_frames, interval=100, blit=True)

# Convert animation to HTML and display it
HTML(ani.to_jshtml())

In [7]:
def get_linear_position_arm_based(x, y, arm1_x, arm1_y, arm2_x, arm2_y, arm3_x, arm3_y):
    """Map a 2D point to a 1D position by snapping it to the nearest arm sample.

    Each arm is an ordered sequence of reference points. The point is assigned
    to the arm holding its nearest reference point (ties go to the
    lower-numbered arm). Its position is ``offset + idx / (n - 1)``, where
    ``offset`` is 0, 1 or 2 for arms 1, 2, 3, ``idx`` is the index of the
    nearest reference point, and ``n`` is the number of points on that arm.
    Arm 1 therefore spans [0, 1], arm 2 [1, 2] and arm 3 [2, 3]. Each arm's
    last reference point lands exactly on the end of its interval; a one-point
    arm maps to its offset.

    Parameters
    ----------
    x, y : float
        Coordinates of the point to linearize.
    arm1_x, arm1_y, arm2_x, arm2_y, arm3_x, arm3_y : array-like of float
        Ordered reference coordinates of each arm.

    Returns
    -------
    float
        Linear position in [0, 3].
    """
    arms = ((arm1_x, arm1_y), (arm2_x, arm2_y), (arm3_x, arm3_y))

    candidates = []
    for offset, (arm_x, arm_y) in enumerate(arms):
        arm_x = np.asarray(arm_x, dtype=float)
        arm_y = np.asarray(arm_y, dtype=float)
        distances = np.sqrt((x - arm_x)**2 + (y - arm_y)**2)
        idx = int(np.argmin(distances))
        candidates.append((distances[idx], offset, idx, len(arm_x)))

    # Smallest distance wins; on a tie the lower arm offset wins.
    _, offset, idx, n = min(candidates, key=lambda c: (c[0], c[1]))

    # Divide by (n - 1), not n, so the last sample of an arm reaches the
    # boundary with the next arm instead of stopping at (n - 1) / n.
    fraction = idx / (n - 1) if n > 1 else 0.0
    return offset + fraction

In [8]:
# Linearize every noisy trajectory sample against the noise-free forward
# reference arms; the result holds one float per sample.
linear_positions_arm_based = np.array(
    list(
        map(
            lambda px, py: get_linear_position_arm_based(
                px, py,
                arm1_x_forward, arm1_y_forward,
                arm2_x_forward, arm2_y_forward,
                arm3_x_forward, arm3_y_forward,
            ),
            trajectory_x,
            trajectory_y,
        )
    )
)

In [None]:
import matplotlib.pyplot as plt

# Plot the arm-based linear position over time on one full-width axis
# (a single panel, so no half of the figure is left blank).
fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(linear_positions_arm_based, label="Arm-Based Indexing", color="green")
ax.set_title("Linear Position (Arm-Based Indexing)")
ax.set_xlabel("Frame")
ax.set_ylabel("Linear Position")
ax.legend()

fig.tight_layout()
plt.show()

Demo for arm-based indexing linearization

Filtering (OPTIONAL): only needed for low-resolution tracking data

In [55]:
from sklearn.cluster import DBSCAN

def remove_noise(x, y, eps=2.0, min_samples=5):
    """Drop the samples that DBSCAN classifies as noise.

    Parameters
    ----------
    x, y : np.ndarray
        Sample coordinates (boolean-mask indexing is applied to them).
    eps : float
        DBSCAN neighbourhood radius.
    min_samples : int
        DBSCAN minimum neighbourhood size for a core sample.

    Returns
    -------
    (filtered_x, filtered_y, labels)
        The non-noise samples, plus the DBSCAN label of every input sample
        (-1 marks noise).
    """
    clustering = DBSCAN(eps=eps, min_samples=min_samples)
    labels = clustering.fit(np.column_stack((x, y))).labels_
    keep = labels != -1
    return x[keep], y[keep], labels

# Apply noise removal; eps and min_samples are tuned for this simulated data.
filtered_trajectory_x, filtered_trajectory_y, labels = remove_noise(
    trajectory_x, trajectory_y, eps=3.0, min_samples=2
)

In [None]:
# Side-by-side comparison: raw samples (left) vs DBSCAN-filtered samples (right).
fig, (ax_raw, ax_filtered) = plt.subplots(1, 2, figsize=(12, 6))

ax_raw.scatter(trajectory_x, trajectory_y, c='gray', s=10, alpha=0.5, label="Raw Data")
ax_raw.set_title("Raw Data with Noise")
ax_raw.set_xlabel("X")
ax_raw.set_ylabel("Y")
ax_raw.legend()

ax_filtered.scatter(filtered_trajectory_x, filtered_trajectory_y,
                    c='blue', s=10, alpha=0.7, label="Filtered Data")
ax_filtered.set_title("Filtered Data (Noise Removed)")
ax_filtered.set_xlabel("X")
ax_filtered.set_ylabel("Y")
ax_filtered.legend()

fig.tight_layout()
plt.show()

Step 2: Identify the Center Point
The center point of the Y-maze is a critical reference for labeling the arms. It can be estimated from the data (below: the sample with the smallest average distance to all other samples) or set manually from the known maze geometry.

In [None]:
from scipy.spatial import distance

# Function to find the center point
def find_center_point(x, y):
    points = np.column_stack((x, y))
    # Compute pairwise distances
    dist_matrix = distance.cdist(points, points, metric='euclidean')
    # Find the point with the smallest average distance to other points
    avg_distances = dist_matrix.mean(axis=1)
    center_idx = np.argmin(avg_distances)
    return points[center_idx]

# Data-driven estimate: the medoid of all samples. With samples spread over
# three arms it need not coincide with the junction, so it is kept only for
# comparison / as an alternative.
estimated_center = find_center_point(trajectory_x, trajectory_y)

# Set to False to use the data-driven estimate instead of the known junction.
USE_MANUAL_CENTER = True
if USE_MANUAL_CENTER:
    # Junction of the simulated Y-maze (node 1 of the track graph).
    center_x, center_y = 100, 100
else:
    center_x, center_y = estimated_center
print("Center Point:", (center_x, center_y))

In [None]:
# Overlay the chosen center on the (unfiltered) trajectory. The samples are
# the raw `trajectory_x/y`, so they are labelled as such, not "Filtered Data".
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(trajectory_x, trajectory_y, c='blue', s=10, alpha=0.7, label="Raw Trajectory")
ax.scatter(center_x, center_y, c='red', s=100, marker='x', label="Center Point")
ax.set_title("Center Point Identification")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.legend()
plt.show()

Step 3: Cluster Points into Arms
Once the center point is identified, cluster the remaining points into three groups corresponding to the three arms. You can use k-means clustering or angular separation to assign points to arms.

Option 1: K-Means Clustering
K-means clustering partitions the points into three clusters based on their spatial distribution.

In [None]:
# from sklearn.cluster import KMeans

# # Function to cluster points into arms
# def cluster_arms(x, y, center_x, center_y, n_clusters=3):
#     points = np.column_stack((x, y))
#     # Exclude the center point from clustering
#     kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(points)
#     return kmeans.labels_

# # Cluster points into arms
# arm_labels = cluster_arms(trajectory_x, trajectory_y, center_x, center_y)

Option 2: Angular Separation
If the arms are approximately straight and radiate outward from the center, you can use the angles of the points relative to the center to assign them to arms.

In [6]:
def assign_arms_by_angle(x, y, center_x, center_y, arm_angles=None):
    """Assign each sample to a maze arm from its angle around the center.

    Angles are in degrees, counter-clockwise from the +x axis, wrapped to
    [0, 360).

    Parameters
    ----------
    x, y : array-like of float
        Sample coordinates.
    center_x, center_y : float
        Maze center (junction).
    arm_angles : sequence of float, optional
        Direction of each arm as seen from the center, in degrees. When given,
        a sample gets label ``k`` if its angle is closest (on the circle) to
        ``arm_angles[k]``, i.e. the sector boundaries are the bisectors between
        neighbouring arms. When ``None`` (default) the fixed 120° sectors
        described below are used. For the simulated maze,
        ``arm_angles=(270, 150, 30)`` centres the sectors on the arms; the
        default sectors place arm 3 (30°) only 30° from a sector boundary.

    Returns
    -------
    np.ndarray of int
        With the default sectors: label 0 for [240, 360) (the downward stem,
        arm 1 in this notebook), 1 for [120, 240) (arm 2 at 150°), and 2 for
        [0, 120) (arm 3 at 30°). An angle that rounds to exactly 360 keeps
        label 0.
    """
    angles = np.arctan2(np.asarray(y) - center_y, np.asarray(x) - center_x)
    angles = np.degrees(angles) % 360  # normalize to [0, 360)

    if arm_angles is not None:
        directions = np.asarray(arm_angles, dtype=float) % 360
        # Smallest angular difference on the circle, in [0, 180].
        diff = np.abs(angles[..., np.newaxis] - directions)
        diff = np.minimum(diff, 360 - diff)
        return np.argmin(diff, axis=-1)

    # Fixed sectors (adjust to your maze geometry or pass arm_angles)
    arm_labels = np.zeros_like(angles, dtype=int)
    arm_labels[(angles >= 0) & (angles < 120)] = 2    # Arm 3 (30°)
    arm_labels[(angles >= 120) & (angles < 240)] = 1  # Arm 2 (150°)
    arm_labels[(angles >= 240) & (angles < 360)] = 0  # Arm 1 (stem, 270°)

    return arm_labels

# Label every sample with an arm index (0, 1 or 2) from its angle around the center.
arm_labels = assign_arms_by_angle(
    x=trajectory_x, y=trajectory_y, center_x=center_x, center_y=center_y
)

In [None]:
# Scatter each arm's samples in its own colour and mark the maze center.
fig, ax = plt.subplots(figsize=(8, 8))
for arm in range(3):  # three arms in the Y-maze
    in_arm = arm_labels == arm
    ax.scatter(trajectory_x[in_arm], trajectory_y[in_arm], s=10, label=f"Arm {arm+1}")
ax.scatter(center_x, center_y, c='red', s=100, marker='x', label="Center Point")
ax.set_title("Clustering Results (Arm Labeling)")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.legend()
plt.show()

Step 4: Linearize Using Arm-Based Indexing
Once the arms are labeled, you can use the arm-based indexing method to compute the linear positions.

In [15]:
def get_linear_position_arm_based_with_labels(x, y, arm_labels, center_x, center_y, n_arms=3):
    """Linearize samples by their distance rank from the center within each arm.

    The samples labelled ``k`` (``0 <= k < n_arms``) are sorted by Euclidean
    distance from ``(center_x, center_y)``. A sample with rank ``r`` among the
    ``m`` samples of its arm gets linear position ``k + r / m``, so arm ``k``
    occupies ``[k, k + 1)``. Samples with identical coordinates each receive
    their own consecutive rank.

    Parameters
    ----------
    x, y : array-like of float
        Sample coordinates.
    arm_labels : array-like of int
        Arm label of each sample (e.g. from ``assign_arms_by_angle``).
    center_x, center_y : float
        Maze center, used as the origin of every arm.
    n_arms : int, default 3
        Labels ``0 .. n_arms - 1`` are linearized; samples with any other
        label are left at 0.

    Returns
    -------
    np.ndarray of float, shape (len(x),)
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    arm_labels = np.asarray(arm_labels)

    dist_from_center = np.sqrt((x - center_x)**2 + (y - center_y)**2)
    linear_positions = np.zeros(len(x))

    for arm in range(n_arms):
        members = np.flatnonzero(arm_labels == arm)
        n_members = members.size
        if n_members == 0:
            continue
        # Inverting the sort permutation gives each member's distance rank
        # directly, avoiding a per-sample O(n) nearest-neighbour search over
        # the sorted arm whose only effect was to rediscover that rank.
        order = np.argsort(dist_from_center[members])
        ranks = np.empty(n_members)
        ranks[order] = np.arange(n_members)
        linear_positions[members] = arm + ranks / n_members

    return linear_positions

In [None]:
import matplotlib.pyplot as plt

# Linear position of every sample, using the angular arm labels computed above.
linear_positions_arm_based = get_linear_position_arm_based_with_labels(
    trajectory_x, trajectory_y, arm_labels, center_x, center_y
)

# One colour per arm label (0, 1, 2)
colors = ['blue', 'green', 'red']

# Linear position over time, colour-coded by arm
fig, ax = plt.subplots(figsize=(14, 3))
frame_index = np.arange(len(linear_positions_arm_based))
for arm in range(3):  # three arms in the Y-maze
    arm_mask = arm_labels == arm
    ax.scatter(
        frame_index[arm_mask],
        linear_positions_arm_based[arm_mask],
        c=colors[arm],
        s=10,
        label=f"Arm {arm+1}",
    )
ax.set_title("Linearized Positions Over Time (Color-Coded by Arm)")
ax.set_xlabel("Frame")
ax.set_ylabel("Linear Position")
ax.legend()
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import matplotlib as mpl

# Allow the embedded JS/HTML animation to reach 50 MB (rcParam value is in MB)
mpl.rcParams['animation.embed_limit'] = 50

# Colour of each trajectory sample, looked up from its arm label
colors = ['blue', 'green', 'red']
arm_colors = [colors[lbl] for lbl in arm_labels]

# Left: 2D position in the maze. Right: linearized position over time.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 3))

ax1.set(xlim=(0, 200), ylim=(0, 200), title="2D Position in Y-Maze", xlabel="X", ylabel="Y")
ax1.grid(True)
scatter_2d = ax1.scatter([], [], s=50, c='red', label="Current Position")
trajectory_2d = ax1.scatter(trajectory_x, trajectory_y, s=10, c=arm_colors, alpha=0.5, label="Trajectory")

ax2.set(xlim=(0, len(linear_positions_arm_based)), ylim=(-0.5, 3.5),  # y-range covers 3 arms
        title="1D Linearized Position", xlabel="Frame", ylabel="Linear Position")
ax2.grid(True)
line_1d, = ax2.plot([], [], lw=2, label="Linearized Position")
scatter_1d = ax2.scatter([], [], s=50, c='red', label="Current Position")

ax1.legend()
ax2.legend()

def init():
    """Blank all animated artists before the first frame."""
    scatter_2d.set_offsets(np.empty((0, 2)))
    line_1d.set_data([], [])
    scatter_1d.set_offsets(np.empty((0, 2)))
    return scatter_2d, line_1d, scatter_1d

def update(frame):
    """Move both markers to sample `frame` and extend the 1D trace through it."""
    upto = frame + 1
    scatter_2d.set_offsets([[trajectory_x[frame], trajectory_y[frame]]])
    line_1d.set_data(range(upto), linear_positions_arm_based[:upto])
    scatter_1d.set_offsets([[frame, linear_positions_arm_based[frame]]])
    return scatter_2d, line_1d, scatter_1d

ani = FuncAnimation(fig, update, frames=len(trajectory_x), init_func=init, interval=100, blit=True)

# Optional: save as a GIF instead of embedding
# from matplotlib.animation import PillowWriter
# ani.save("y_maze_animation.gif", writer=PillowWriter(fps=10))

# Display the animation in the notebook
HTML(ani.to_jshtml())

Labeling Functions (Practice)

In [None]:
# import numpy as np
# from sklearn.cluster import KMeans, DBSCAN
# from sklearn.decomposition import PCA
# import networkx as nx

# def label_with_kmeans(positions, num_arms, random_state=42):
#     """
#     Label maze arms using K-Means clustering.
#     :param positions: List of (x, y) coordinates representing points in the maze.
#     :param num_arms: Number of arms in the maze.
#     :param random_state: Random seed for reproducibility.
#     :return: Labels for each position.
#     """
#     kmeans = KMeans(n_clusters=num_arms, random_state=random_state)
#     labels = kmeans.fit_predict(positions)
#     return labels

# def label_with_angular_separation(positions, num_arms):
#     """
#     Label maze arms using angular separation.
#     :param positions: List of (x, y) coordinates representing points in the maze.
#     :param num_arms: Number of arms in the maze.
#     :return: Labels for each position.
#     """
#     center = np.mean(positions, axis=0)  # Assume center is the mean of all points
#     angles = np.arctan2(positions[:, 1] - center[1], positions[:, 0] - center[0])
#     angles = np.degrees(angles) % 360  # Convert to degrees and normalize to [0, 360)
    
#     sector_size = 360 / num_arms
#     labels = np.floor(angles / sector_size).astype(int)
#     return labels

# def label_with_pca(positions, num_arms):
#     """
#     Label maze arms using Principal Component Analysis (PCA).
#     :param positions: List of (x, y) coordinates representing points in the maze.
#     :param num_arms: Number of arms in the maze.
#     :return: Labels for each position.
#     """
#     pca = PCA(n_components=2)
#     transformed = pca.fit_transform(positions)
    
#     if num_arms == 2:  # For T-maze or H-maze
#         labels = (transformed[:, 0] > 0).astype(int)
#     elif num_arms == 3:  # For Y-maze
#         labels = np.digitize(transformed[:, 1], bins=[-np.inf, 0, np.inf]) - 1
#     else:
#         raise ValueError("Unsupported number of arms.")
#     return labels

# def label_with_dbscan(positions, eps=0.5, min_samples=2):
#     """
#     Label maze arms using DBSCAN clustering.
#     :param positions: List of (x, y) coordinates representing points in the maze.
#     :param eps: Maximum distance between points in a cluster.
#     :param min_samples: Minimum number of points to form a cluster.
#     :return: Labels for each position.
#     """
#     dbscan = DBSCAN(eps=eps, min_samples=min_samples)
#     labels = dbscan.fit_predict(positions)
#     return labels

# def label_with_graph(positions, threshold=1.5):
#     """
#     Label maze arms using a graph-based approach.
#     :param positions: List of (x, y) coordinates representing points in the maze.
#     :param threshold: Maximum distance to connect two points.
#     :return: Labels for each position.
#     """
#     G = nx.Graph()
#     G.add_nodes_from(range(len(positions)))
    
#     # Add edges between points within the threshold distance
#     for i in range(len(positions)):
#         for j in range(i + 1, len(positions)):
#             if np.linalg.norm(positions[i] - positions[j]) < threshold:
#                 G.add_edge(i, j)
    
#     # Find connected components
#     components = list(nx.connected_components(G))
#     labels = np.zeros(len(positions), dtype=int)
#     for idx, component in enumerate(components):
#         labels[list(component)] = idx
#     return labels

# def label_maze(positions, method, **kwargs):
#     """
#     General function to label maze arms using the specified method.
#     :param positions: List of (x, y) coordinates representing points in the maze.
#     :param method: The labeling method to use ('kmeans', 'angular', 'pca', 'dbscan', 'graph').
#     :param kwargs: Additional parameters specific to the chosen method.
#     :return: Labels for each position.
#     """
#     positions = np.array(positions)  # Ensure positions are in NumPy array format
    
#     if method == "kmeans":
#         return label_with_kmeans(positions, kwargs.get("num_arms"), kwargs.get("random_state", 42))
#     elif method == "angular":
#         return label_with_angular_separation(positions, kwargs.get("num_arms"))
#     elif method == "pca":
#         return label_with_pca(positions, kwargs.get("num_arms"))
#     elif method == "dbscan":
#         return label_with_dbscan(positions, kwargs.get("eps", 0.5), kwargs.get("min_samples", 2))
#     elif method == "graph":
#         return label_with_graph(positions, kwargs.get("threshold", 1.5))
#     else:
#         raise ValueError(f"Unsupported method: {method}")




In [None]:
# Example Usage:
# if __name__ == "__main__":
#     # Define some example positions (e.g., from tracking an animal's movement)
#     positions = [(-1, -1), (1, -1), (0, 1), (-2, -2), (2, -2), (0, 2)]

#     # Label using different methods with custom parameters
#     print("K-Means Labels:", label_maze(positions, method="kmeans", num_arms=3, random_state=42))
#     print("Angular Separation Labels:", label_maze(positions, method="angular", num_arms=3))
#     print("PCA Labels:", label_maze(positions, method="pca", num_arms=3))
#     print("DBSCAN Labels:", label_maze(positions, method="dbscan", eps=1.5, min_samples=2))
#     print("Graph-Based Labels:", label_maze(positions, method="graph", threshold=1.5))


### Graph and with/without HMM labeling and position linearization
These functions were developed on top of the track_linearization library from the Frank Lab.

In [1]:
import numpy as np
# Simulation of Y-Maze tracking data.
# Seeded generator so "Restart & Run All" reproduces the same noisy trajectory.
SIMULATION_SEED = 0
rng = np.random.default_rng(SIMULATION_SEED)

# Samples per traversal of one arm, and total number of samples
num_points = 50
num_frames = 8 * num_points  # the trajectory below consists of 8 traversals

# Start of arm 1: the bottom end of the stem. The junction is at (100, 100).
start_x, start_y = 100, 0

# Arm 1: vertical stem from (100, 0) up to the junction at (100, 100)
arm1_x_forward = np.full(num_points, start_x, dtype=float)  # constant x-coordinate
arm1_y_forward = np.linspace(start_y, start_y + 100, num_points)  # evenly spaced y
arm1_x_backward = arm1_x_forward[::-1]
arm1_y_backward = arm1_y_forward[::-1]

# Arm 2: 100 units from the junction (100, 100) at 150° (up and to the left)
theta_2 = 150 * np.pi / 180
arm2_x_forward = start_x + np.linspace(0, 100, num_points) * np.cos(theta_2)
arm2_y_forward = start_y + 100 + np.linspace(0, 100, num_points) * np.sin(theta_2)
arm2_x_backward = arm2_x_forward[::-1]
arm2_y_backward = arm2_y_forward[::-1]

# Arm 3: 100 units from the junction (100, 100) at +30° (up and to the right)
theta_3 = 30 * np.pi / 180
arm3_x_forward = start_x + np.linspace(0, 100, num_points) * np.cos(theta_3)
arm3_y_forward = start_y + 100 + np.linspace(0, 100, num_points) * np.sin(theta_3)
arm3_x_backward = arm3_x_forward[::-1]
arm3_y_backward = arm3_y_forward[::-1]

# Ordered traversals that make up the trajectory
trajectory_segments = [
    (arm1_x_forward, arm1_y_forward),    # up arm 1
    (arm2_x_forward, arm2_y_forward),    # out along arm 2
    (arm2_x_backward, arm2_y_backward),  # back along arm 2
    (arm1_x_backward, arm1_y_backward),  # back down arm 1
    (arm1_x_forward, arm1_y_forward),    # up arm 1 again
    (arm3_x_forward, arm3_y_forward),    # out along arm 3
    (arm3_x_backward, arm3_y_backward),  # back along arm 3
    (arm2_x_forward, arm2_y_forward),    # out along arm 2 (final movement)
]
trajectory_x = np.concatenate([seg_x for seg_x, _ in trajectory_segments])
trajectory_y = np.concatenate([seg_y for _, seg_y in trajectory_segments])

# Isotropic Gaussian noise (variance 0.5 per axis) to mimic tracking error
noise = rng.multivariate_normal([0, 0], [[0.5, 0], [0, 0.5]], len(trajectory_x))
trajectory_x += noise[:, 0]
trajectory_y += noise[:, 1]
position = np.vstack((trajectory_x, trajectory_y)).T

Functions for linearization, analysis, and visualization

In [19]:
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import seaborn as sns

def calculate_transitions_simple(track_graph, edges, diagonal_bias=0.1):
    """Connectivity-based transition matrix between track segments (edges).

    Two segments that share a node get weight 1; all others get 0. Every
    segment shares its nodes with itself, so the diagonal starts at 1, and
    ``diagonal_bias`` is then *added* to it so that staying on the current
    segment is favoured. Each row is finally normalized to sum to 1.

    Parameters
    ----------
    track_graph : object
        Unused; kept so the signature matches the other
        ``calculate_transitions_*`` functions.
    edges : sequence of (node, node) pairs
        Track segments, in the order used for the matrix rows/columns.
    diagonal_bias : float, default 0.1
        Extra weight added to each segment's self-transition.

    Returns
    -------
    np.ndarray, shape (len(edges), len(edges))
        Row-stochastic transition matrix.
    """
    n_edges = len(edges)
    node_sets = [set(edge) for edge in edges]
    transitions = np.zeros((n_edges, n_edges))

    for i, nodes_i in enumerate(node_sets):
        for j, nodes_j in enumerate(node_sets):
            # Allow transitions between segments that share a node
            if nodes_i & nodes_j:
                transitions[i, j] = 1.0

    # Add (not overwrite) the bias: replacing the diagonal's 1.0 with 0.1
    # would make staying *less* likely than switching, and a zero bias would
    # leave isolated segments with all-zero rows (NaN after normalizing).
    transitions[np.diag_indices(n_edges)] += diagonal_bias
    transitions /= transitions.sum(axis=1, keepdims=True)
    return transitions

def calculate_transitions_with_angles(track_graph, edges, diagonal_bias=0.1):
    """Segment transition matrix weighted by the turning angle at junctions.

    For two distinct segments sharing node ``n``, the turning angle is the
    angle between the direction of travel *into* ``n`` along the first segment
    and *out of* ``n`` along the second. It lies in [0, pi], so the weight
    ``cos(angle / 2)`` lies in [0, 1]: 1 for going straight on, ~0 for a
    U-turn. Orienting both vectors relative to the shared node makes the
    result independent of the (arbitrary) endpoint order of undirected edges
    and keeps every weight non-negative. The self-transition gets
    ``1.0 + diagonal_bias``; rows are normalised to sum to 1.

    Parameters
    ----------
    track_graph : networkx.Graph
        Nodes must carry a ``'pos'`` (x, y) attribute.
    edges : sequence of (node, node)
        Row/column ``i`` corresponds to ``edges[i]``.
    diagonal_bias : float
        Extra weight added to the self-transition (should be > -1).

    Returns
    -------
    numpy.ndarray, shape (n_edges, n_edges)
        Row-stochastic transition matrix.
    """
    n_edges = len(edges)
    transitions = np.zeros((n_edges, n_edges))

    def _pos(node):
        return np.asarray(track_graph.nodes[node]['pos'], dtype=float)

    for i, edge1 in enumerate(edges):
        for j, edge2 in enumerate(edges):
            if i == j:
                continue  # self-transition is set from diagonal_bias below
            common_nodes = set(edge1) & set(edge2)
            if not common_nodes:
                continue
            junction_node = next(iter(common_nodes))
            other1 = edge1[1] if edge1[0] == junction_node else edge1[0]
            other2 = edge2[1] if edge2[0] == junction_node else edge2[0]
            junction_pos = _pos(junction_node)
            incoming = junction_pos - _pos(other1)
            outgoing = _pos(other2) - junction_pos
            norm = np.linalg.norm(incoming) * np.linalg.norm(outgoing)
            if norm == 0:
                # Zero-length segment: no direction to compare, treat as straight.
                transitions[i, j] = 1.0
                continue
            cos_angle = np.clip(np.dot(incoming, outgoing) / norm, -1.0, 1.0)
            angle = np.arccos(cos_angle)
            # Smoother turns (smaller angle) get higher probability.
            transitions[i, j] = np.cos(angle / 2)

    # Staying put gets the maximal turn weight plus the bias, instead of the
    # bias alone, which made staying less likely than moving.
    np.fill_diagonal(transitions, 1.0 + diagonal_bias)
    transitions /= transitions.sum(axis=1, keepdims=True)
    return transitions

def calculate_transitions_with_distance(track_graph, edges, diagonal_bias=0.1):
    """Segment transition matrix favouring neighbours of similar length.

    Distinct segments sharing a node get weight ``min(len_i, len_j) /
    max(len_i, len_j)`` (1.0 when both have zero length); the self-transition
    gets ``1.0 + diagonal_bias``. Rows are normalised to sum to 1.

    Parameters
    ----------
    track_graph : networkx.Graph
        Edges must carry a ``'distance'`` attribute.
    edges : sequence of (node, node)
        Row/column ``i`` corresponds to ``edges[i]``.
    diagonal_bias : float
        Extra weight added to the self-transition (should be > -1).

    Returns
    -------
    numpy.ndarray, shape (n_edges, n_edges)
        Row-stochastic transition matrix.
    """
    n_edges = len(edges)
    lengths = [float(track_graph.edges[edge]['distance']) for edge in edges]
    transitions = np.zeros((n_edges, n_edges))
    for i, edge1 in enumerate(edges):
        for j, edge2 in enumerate(edges):
            if i == j or not (set(edge1) & set(edge2)):
                continue
            longest = max(lengths[i], lengths[j])
            # Favor transitions to similar-length segments; avoid 0/0.
            transitions[i, j] = min(lengths[i], lengths[j]) / longest if longest > 0 else 1.0
    # Add the bias on top of the maximal weight rather than replacing the
    # diagonal with it, so staying on the segment is actually favoured.
    np.fill_diagonal(transitions, 1.0 + diagonal_bias)
    transitions /= transitions.sum(axis=1, keepdims=True)
    return transitions

def calculate_transitions_with_junction_type(track_graph, edges, diagonal_bias=0.1):
    """Segment transition matrix that splits probability evenly over branches.

    Distinct segments meeting at a node of degree ``d >= 2`` get weight
    ``1 / (d - 1)``: 1.0 at an ordinary bend, 0.5 at a T/Y junction, 1/3 at
    a cross. Degree-2 nodes used to be excluded, which made moving between
    consecutive segments of an unbranched track impossible. The
    self-transition gets ``1.0 + diagonal_bias``; rows are normalised.

    Parameters
    ----------
    track_graph : networkx.Graph
        Used only for ``track_graph.degree(node)``.
    edges : sequence of (node, node)
        Row/column ``i`` corresponds to ``edges[i]``.
    diagonal_bias : float
        Extra weight added to the self-transition (should be > -1).

    Returns
    -------
    numpy.ndarray, shape (n_edges, n_edges)
        Row-stochastic transition matrix.
    """
    n_edges = len(edges)
    transitions = np.zeros((n_edges, n_edges))
    for i, edge1 in enumerate(edges):
        for j, edge2 in enumerate(edges):
            if i == j:
                continue  # self-transition is set from diagonal_bias below
            common_nodes = set(edge1) & set(edge2)
            if not common_nodes:
                continue
            n_branches = track_graph.degree(next(iter(common_nodes)))
            if n_branches >= 2:
                transitions[i, j] = 1.0 / (n_branches - 1)
    # Add the bias on top of the maximal weight rather than replacing the
    # diagonal with it, so staying on the segment is actually favoured.
    np.fill_diagonal(transitions, 1.0 + diagonal_bias)
    transitions /= transitions.sum(axis=1, keepdims=True)
    return transitions

def create_track_graph(node_positions=None, edges=None):
    """Build an undirected networkx graph describing the track layout.

    Parameters
    ----------
    node_positions : dict or None
        Maps node id -> (x, y); stored as the ``'pos'`` node attribute.
    edges : sequence of (node, node) or None
        Each edge is stored with ``edge_id`` (its index in ``edges``) and
        ``distance`` (Euclidean length between its endpoints). Looking up the
        endpoints requires ``node_positions``.

    Returns
    -------
    networkx.Graph
    """
    graph = nx.Graph()

    if node_positions is not None:
        for node_id in node_positions:
            graph.add_node(node_id, pos=node_positions[node_id])

    if edges is not None:
        for index, edge in enumerate(edges):
            u, v = edge[0], edge[1]
            length = np.linalg.norm(
                np.array(node_positions[u]) - np.array(node_positions[v])
            )
            graph.add_edge(u, v, edge_id=index, distance=length)

    return graph

def _hmm_log_emissions(track_graph, edges, position, junction_nodes, sensor_std_dev):
    """Row-normalised log emission probabilities, shape (n_time, n_edges).

    A sample's log-likelihood for a segment is Gaussian in its distance to the
    segment (projection clamped to the endpoints), plus ``log(1 + 0.2 * bonus)``
    for each junction node lying *on that segment*, where ``bonus`` is a
    Gaussian in the distance to that junction. Normalisation uses log-sum-exp,
    so samples far from every segment no longer underflow to 0/0 = NaN.
    """
    log_emissions = np.zeros((position.shape[0], len(edges)))
    for i, edge in enumerate(edges):
        start = np.asarray(track_graph.nodes[edge[0]]['pos'], dtype=float)
        end = np.asarray(track_graph.nodes[edge[1]]['pos'], dtype=float)
        segment_vector = end - start
        length = np.linalg.norm(segment_vector)
        direction = segment_vector / length if length > 0 else segment_vector
        # Project every sample onto the segment, clamped to its endpoints.
        along = np.clip((position - start) @ direction, 0, length)
        projected = start + along[:, None] * direction
        dist = np.linalg.norm(position - projected, axis=1)
        log_emissions[:, i] = -0.5 * (dist / sensor_std_dev) ** 2
        # Bonus near junctions this segment actually touches.
        for junction in junction_nodes.intersection(edge):
            junction_pos = np.asarray(track_graph.nodes[junction]['pos'], dtype=float)
            dist_to_junction = np.linalg.norm(position - junction_pos, axis=1)
            junction_bonus = np.exp(-0.5 * (dist_to_junction / sensor_std_dev) ** 2)
            log_emissions[:, i] += np.log1p(0.2 * junction_bonus)
    row_max = log_emissions.max(axis=1, keepdims=True)
    log_norm = row_max + np.log(np.exp(log_emissions - row_max).sum(axis=1, keepdims=True))
    return log_emissions - log_norm


def _viterbi_decode(log_transitions, log_emissions):
    """Most likely state sequence of an HMM with a uniform initial distribution."""
    n_time, n_states = log_emissions.shape
    scores = log_emissions[0].copy()
    backpointers = np.zeros((n_time, n_states), dtype=int)
    for t in range(1, n_time):
        # candidate[i, j]: best path ending in state i at t-1, then i -> j.
        candidate = scores[:, None] + log_transitions
        backpointers[t] = np.argmax(candidate, axis=0)
        scores = candidate.max(axis=0) + log_emissions[t]
    states = np.zeros(n_time, dtype=int)
    states[-1] = np.argmax(scores)
    for t in range(n_time - 1, 0, -1):
        states[t - 1] = backpointers[t, states[t]]
    return states


def classify_track_segments(track_graph, position, transition_type='simple',
                          route_euclidean_distance_scaling=1.0, 
                          sensor_std_dev=5.0, diagonal_bias=0.1):
    """
    Track segment classification with different transition types.

    Hidden states are the graph's edges in ``list(track_graph.edges)`` order;
    the most likely state sequence is found with the Viterbi algorithm.

    Parameters
    ----------
    track_graph : networkx.Graph
        Nodes must carry a ``'pos'`` (x, y) attribute.
    position : array-like, shape (n_time, 2)
        Tracked positions.
    transition_type : str
        Type of transition calculation: 'simple', 'angle', 'distance', or
        'junction'; any other value falls back to 'simple'.
    route_euclidean_distance_scaling : float
        Currently unused; accepted for API compatibility.
    sensor_std_dev : float
        Standard deviation of the Gaussian emission model (position units).
    diagonal_bias : float
        Forwarded to the transition-matrix function.

    Returns
    -------
    track_segment_ids : numpy.ndarray of int, shape (n_time,)
        Index into ``list(track_graph.edges)`` for every sample.
    transitions : numpy.ndarray, shape (n_edges, n_edges)
        Transition matrix used for decoding (returned for analysis).

    Raises
    ------
    ValueError
        If ``track_graph`` has no edges.
    """
    edges = list(track_graph.edges)
    if not edges:
        raise ValueError("track_graph has no edges")
    position = np.asarray(position, dtype=float)
    n_time = position.shape[0]

    # 1. Transition matrix for the selected model.
    if transition_type == 'angle':
        transitions = calculate_transitions_with_angles(track_graph, edges, diagonal_bias)
    elif transition_type == 'distance':
        transitions = calculate_transitions_with_distance(track_graph, edges, diagonal_bias)
    elif transition_type == 'junction':
        transitions = calculate_transitions_with_junction_type(track_graph, edges, diagonal_bias)
    else:  # 'simple' or default
        transitions = calculate_transitions_simple(track_graph, edges, diagonal_bias)

    if n_time == 0:
        return np.zeros(0, dtype=int), transitions

    # 2. Emissions, computed in log space.
    junction_nodes = {node for node in track_graph.nodes() if track_graph.degree(node) > 2}
    log_emissions = _hmm_log_emissions(track_graph, edges, position, junction_nodes, sensor_std_dev)

    # 3. Viterbi decoding. Zero-probability transitions become -inf, which the
    # max/argmax handle naturally; silence the expected divide-by-zero warning.
    with np.errstate(divide='ignore'):
        log_transitions = np.log(transitions)

    return _viterbi_decode(log_transitions, log_emissions), transitions

def improve_edge_detection(track_graph, position, emissions, window_size=5,
                           junction_radius=10.0):
    """Label segments from smoothed emissions, repairing impossible jumps near junctions.

    Each emission column is smoothed with a moving average of ``window_size``
    samples (clamped to the series length), and each sample takes the arg-max
    segment. Then, for every sample after the first that lies within
    ``junction_radius`` of a junction node (degree > 2), the label is checked
    against the (already corrected) previous label. If the two segments share
    no node, the label is replaced by the best-scoring segment adjacent to the
    previous one.

    Parameters
    ----------
    track_graph : networkx.Graph
        Nodes must carry a ``'pos'`` attribute.
    position : array-like, shape (n_time, 2)
    emissions : array-like, shape (n_time, n_edges)
        Per-sample segment scores, columns in ``list(track_graph.edges)`` order.
    window_size : int
        Moving-average window length in samples.
    junction_radius : float
        Distance from a junction within which jumps are repaired
        (previously hard-coded to 10).

    Returns
    -------
    numpy.ndarray of int, shape (n_time,)
    """
    position = np.asarray(position, dtype=float)
    emissions = np.asarray(emissions, dtype=float)
    n_time = position.shape[0]
    if n_time == 0:
        return np.zeros(0, dtype=int)

    edges = list(track_graph.edges)
    n_edges = len(edges)

    junction_positions = [
        np.asarray(track_graph.nodes[node]['pos'], dtype=float)
        for node in track_graph.nodes()
        if track_graph.degree(node) > 2
    ]

    segment_adjacency = np.zeros((n_edges, n_edges), dtype=bool)
    for i, edge1 in enumerate(edges):
        for j, edge2 in enumerate(edges):
            segment_adjacency[i, j] = bool(set(edge1) & set(edge2))

    # np.convolve(mode='same') returns max(len(a), len(v)) samples, so a
    # window longer than the recording would break the column assignment.
    window = max(1, min(int(window_size), n_time))
    kernel = np.ones(window) / window
    smoothed_emissions = np.zeros_like(emissions)
    for i in range(n_edges):
        smoothed_emissions[:, i] = np.convolve(emissions[:, i], kernel, mode='same')

    segment_ids = np.argmax(smoothed_emissions, axis=1)

    # Include the final sample (the old loop stopped at n_time - 2).
    for t in range(1, n_time):
        near_junction = any(
            np.linalg.norm(position[t] - junction_pos) < junction_radius
            for junction_pos in junction_positions
        )
        if not near_junction:
            continue
        prev_segment = segment_ids[t - 1]
        if not segment_adjacency[prev_segment, segment_ids[t]]:
            valid_segments = np.flatnonzero(segment_adjacency[prev_segment])
            segment_ids[t] = valid_segments[np.argmax(smoothed_emissions[t, valid_segments])]

    return segment_ids

def _calculate_linear_position(track_graph, position, track_segment_id, edge_order, edge_spacing):
    """Calculate linear position along track with correct handling of Y-junctions."""
    edges = list(track_graph.edges)
    n_time = len(position)
    
    linear_position = np.zeros(n_time)
    projected_x = np.zeros(n_time)
    projected_y = np.zeros(n_time)
    
    cumulative_distances = {}
    current_distance = 0.0
    
    if isinstance(edge_spacing, (int, float)):
        edge_spacing = [edge_spacing] * (len(edge_order) - 1)
        
    for i, edge in enumerate(edge_order):
        edge_id = track_graph.edges[edge]['edge_id']
        cumulative_distances[edge_id] = current_distance
        current_distance += track_graph.edges[edge]['distance']
        if i < len(edge_spacing):
            current_distance += edge_spacing[i]
    
    for t in range(n_time):
        segment_id = int(track_segment_id[t])
        edge = edges[segment_id]
        
        start_pos = np.array(track_graph.nodes[edge[0]]['pos'])
        end_pos = np.array(track_graph.nodes[edge[1]]['pos'])
        
        segment_vector = end_pos - start_pos
        segment_length = np.linalg.norm(segment_vector)
        
        if segment_length > 0:
            point_vector = position[t] - start_pos
            projection = np.dot(point_vector, segment_vector) / (segment_length * segment_length)
            projection = np.clip(projection, 0, 1)
            
            proj_point = start_pos + projection * segment_vector
            projected_x[t] = proj_point[0]
            projected_y[t] = proj_point[1]
            
            linear_position[t] = (
                cumulative_distances[segment_id] + 
                projection * segment_length
            )
    
    return linear_position, projected_x, projected_y

def _segment_distance_matrix(track_graph, edges, position):
    """Distance from every sample to every segment, shape (n_time, n_edges).

    Samples are projected onto each segment and clamped to its endpoints. A
    zero-length segment is treated as the single point it occupies; it used
    to get distance 0, so nearest-segment labelling always picked it.
    """
    distances = np.zeros((len(position), len(edges)))
    for i, edge in enumerate(edges):
        start = np.asarray(track_graph.nodes[edge[0]]['pos'], dtype=float)
        end = np.asarray(track_graph.nodes[edge[1]]['pos'], dtype=float)
        segment_vector = end - start
        length_sq = np.dot(segment_vector, segment_vector)
        if length_sq > 0:
            fraction = np.clip((position - start) @ segment_vector / length_sq, 0, 1)
        else:
            fraction = np.zeros(len(position))
        projected = start + fraction[:, None] * segment_vector
        distances[:, i] = np.linalg.norm(position - projected, axis=1)
    return distances


def _smooth_linear_outliers(positions, segment_ids, window_size=5, threshold=2.0):
    """Replace local outliers with the median of same-segment neighbours.

    For every sample, neighbours within ``window_size // 2`` samples that share
    its segment label form the local window (it always contains the sample
    itself). If the sample's z-score against that window exceeds
    ``threshold``, it is replaced by the window median.
    """
    smoothed = positions.copy()
    half_window = window_size // 2
    for i in range(len(positions)):
        start = max(0, i - half_window)
        end = min(len(positions), i + half_window + 1)
        local_positions = positions[start:end][segment_ids[start:end] == segment_ids[i]]
        local_std = np.std(local_positions)
        if local_std > 0:
            local_median = np.median(local_positions)
            if abs(positions[i] - local_median) / local_std > threshold:
                smoothed[i] = local_median
    return smoothed


def get_linearized_position(position, track_graph, edge_order=None, edge_spacing=0,
                          use_HMM=False, route_euclidean_distance_scaling=1.0,
                          sensor_std_dev=5.0, diagonal_bias=0.1, edge_map=None,
                          transition_type='simple', refine_edges=False):
    """
    Get linearized position along track with different transition types.

    Parameters
    ----------
    position : array-like, shape (n_time, 2)
        Tracked positions.
    track_graph : networkx.Graph
        Track layout (nodes with 'pos', edges with 'distance').
    edge_order : sequence of (node, node), optional
        Order of segments on the linear axis; defaults to ``list(track_graph.edges)``.
    edge_spacing : float or sequence of float
        Gap(s) between consecutive segments on the linear axis.
    use_HMM : bool
        If True, label segments with ``classify_track_segments`` (Viterbi);
        otherwise assign each sample to its nearest segment.
    route_euclidean_distance_scaling, sensor_std_dev, diagonal_bias : float
        Forwarded to ``classify_track_segments`` when ``use_HMM`` is True.
    edge_map : dict, optional
        ``{old_id: new_id}`` relabelling, applied simultaneously so chains
        such as ``{0: 1, 1: 2}`` do not cascade.
    transition_type : str
        Type of transition calculation: 'simple', 'angle', 'distance', or 'junction'
        (HMM only).
    refine_edges : bool
        Only with ``use_HMM``: replace the HMM labels by ``improve_edge_detection``
        on nearest-segment emissions. This used to happen unconditionally,
        which discarded the Viterbi result and made ``transition_type`` and
        ``diagonal_bias`` no-ops.

    Returns
    -------
    pandas.DataFrame
        Columns: linear_position, linear_position_smoothed, track_segment_id,
        projected_x_position, projected_y_position.
    """
    position = np.asarray(position, dtype=float)
    edges = list(track_graph.edges)
    if edge_order is None:
        edge_order = edges

    if use_HMM:
        track_segment_id, _ = classify_track_segments(
            track_graph,
            position,
            transition_type=transition_type,
            route_euclidean_distance_scaling=route_euclidean_distance_scaling,
            sensor_std_dev=sensor_std_dev,
            diagonal_bias=diagonal_bias
        )
        if refine_edges:
            # Gaussian emissions; shift each row by its max before exp so a
            # sample far from every segment cannot produce an all-zero row.
            log_lik = -0.5 * (_segment_distance_matrix(track_graph, edges, position)
                              / sensor_std_dev) ** 2
            log_lik -= log_lik.max(axis=1, keepdims=True)
            emissions = np.exp(log_lik)
            emissions /= emissions.sum(axis=1, keepdims=True)
            track_segment_id = improve_edge_detection(track_graph, position, emissions)
    else:
        distances = _segment_distance_matrix(track_graph, edges, position)
        track_segment_id = np.argmin(distances, axis=1)

    track_segment_id = np.asarray(track_segment_id)
    if edge_map is not None:
        original_ids = track_segment_id.copy()
        track_segment_id = original_ids.copy()
        for cur_edge, new_edge in edge_map.items():
            track_segment_id[original_ids == cur_edge] = new_edge

    linear_position, projected_x, projected_y = _calculate_linear_position(
        track_graph,
        position,
        track_segment_id,
        edge_order,
        edge_spacing
    )

    smoothed_linear_position = _smooth_linear_outliers(linear_position, track_segment_id)

    return pd.DataFrame({
        'linear_position': linear_position,
        'linear_position_smoothed': smoothed_linear_position,
        'track_segment_id': track_segment_id,
        'projected_x_position': projected_x,
        'projected_y_position': projected_y,
    })

def analyze_transition_matrices(track_graph, position):
    """
    Compute every transition-matrix variant for ``track_graph`` and plot them.

    Draws a 2x2 grid of heatmaps (simple, angle, distance, junction), shows
    the figure, and returns the matrices keyed by transition type.
    ``position`` is accepted for API symmetry but not used.
    """
    edge_list = list(track_graph.edges)
    matrices = {
        'simple': calculate_transitions_simple(track_graph, edge_list),
        'angle': calculate_transitions_with_angles(track_graph, edge_list),
        'distance': calculate_transitions_with_distance(track_graph, edge_list),
        'junction': calculate_transitions_with_junction_type(track_graph, edge_list),
    }
    panel_titles = {
        'simple': 'Simple Transitions',
        'angle': 'Angle-based Transitions',
        'distance': 'Distance-based Transitions',
        'junction': 'Junction-based Transitions',
    }

    plt.figure(figsize=(10, 5))
    for panel, key in enumerate(matrices, start=1):
        plt.subplot(2, 2, panel)
        plt.imshow(matrices[key])
        plt.title(panel_titles[key])
        plt.colorbar()
    plt.tight_layout()
    plt.show()

    return matrices

def analyze_and_visualize_track(track_graph, position, sensor_std_dev=5.0):
    """
    Analyzes and visualizes track segments and emission probabilities.

    For every sample and segment, computes the projection onto the segment
    (clamped to its endpoints) and a Gaussian emission score on the distance to
    it. Scores are normalised per sample and plotted next to the track layout.

    Parameters
    ----------
    track_graph : networkx.Graph
        Graph representation of the track (nodes with a 'pos' attribute).
    position : numpy.ndarray
        Array of (x, y) positions, shape (n_time, 2).
    sensor_std_dev : float
        Standard deviation for sensor noise model.

    Returns
    -------
    fig : matplotlib.figure.Figure
    emissions : numpy.ndarray, shape (n_time, n_edges)
        Per-sample normalised emission probabilities.
    projected_points : numpy.ndarray, shape (n_time, n_edges, 2)
        Projection of each sample onto each segment.
    """
    import matplotlib.pyplot as plt

    position = np.asarray(position, dtype=float)
    edges = list(track_graph.edges)
    n_edges = len(edges)
    n_time = position.shape[0]

    projected_points = np.zeros((n_time, n_edges, 2))
    log_likelihood = np.zeros((n_time, n_edges))
    for i, edge in enumerate(edges):
        node1_pos = np.asarray(track_graph.nodes[edge[0]]['pos'], dtype=float)
        node2_pos = np.asarray(track_graph.nodes[edge[1]]['pos'], dtype=float)
        segment_vector = node2_pos - node1_pos
        length_sq = np.dot(segment_vector, segment_vector)
        # Zero-length segments used to divide by zero; they now project every
        # sample onto their single point.
        if length_sq > 0:
            proj = np.clip((position - node1_pos) @ segment_vector / length_sq, 0, 1)
        else:
            proj = np.zeros(n_time)
        projected_points[:, i] = node1_pos + proj[:, None] * segment_vector
        dist = np.linalg.norm(position - projected_points[:, i], axis=1)
        log_likelihood[:, i] = -0.5 * (dist / sensor_std_dev) ** 2

    # Shift each row by its maximum before exponentiating so a sample far from
    # every segment does not underflow to an all-zero row (0/0 = NaN).
    log_likelihood -= log_likelihood.max(axis=1, keepdims=True)
    emissions = np.exp(log_likelihood)
    emissions /= emissions.sum(axis=1, keepdims=True)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 3))

    # Plot 1: track and positions.
    ax1.set_title('Track Layout and Positions')
    for edge in edges:
        node1_pos = track_graph.nodes[edge[0]]['pos']
        node2_pos = track_graph.nodes[edge[1]]['pos']
        ax1.plot([node1_pos[0], node2_pos[0]],
                 [node1_pos[1], node2_pos[1]],
                 'k-', linewidth=2, alpha=0.5)

    # Positions coloured by their most likely segment.
    max_prob_segment = np.argmax(emissions, axis=1)
    ax1.scatter(position[:, 0], position[:, 1],
                c=max_prob_segment, cmap='tab10',
                alpha=0.6)

    # Projections onto the most likely segment.
    for i in range(n_edges):
        mask = max_prob_segment == i
        if np.any(mask):
            ax1.plot(projected_points[mask, i, 0],
                     projected_points[mask, i, 1],
                     'x', alpha=0.3)

    # Plot 2: emission probabilities over time.
    ax2.set_title('Emission Probabilities Over Time')
    time = np.arange(n_time)
    for i in range(n_edges):
        ax2.plot(time, emissions[:, i],
                 label=f'Segment {i} ({edges[i]})',
                 alpha=0.7)
    ax2.set_xlabel('Time Step')
    ax2.set_ylabel('Emission Probability')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    return fig, emissions, projected_points

def enhanced_track_analysis(track_graph, position, segment_ids, sensor_std_dev=5.0):
    """
    Plot a 3x3 dashboard summarising a segment classification and return metrics.

    Parameters
    ----------
    track_graph : networkx.Graph
        Track layout (nodes with a 'pos' attribute).
    position : array-like, shape (n_time, 2)
    segment_ids : array-like of int, shape (n_time,)
        Segment label per sample (index into ``list(track_graph.edges)``);
        lists and pandas Series are accepted.
    sensor_std_dev : float
        Currently unused.

    Returns
    -------
    fig : matplotlib.figure.Figure
    metrics : dict
        'segment_counts', 'segment_transitions' (counts of label changes
        from row to column), 'segment_velocities' (mean step length per
        segment). 'distance_to_segment', 'junction_proximity' and
        'classification_confidence' are placeholders that stay zero.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd

    position = np.asarray(position, dtype=float)
    # Positional indexing below must not depend on a pandas index.
    segment_ids = np.asarray(segment_ids, dtype=int)
    edges = list(track_graph.edges)
    n_edges = len(edges)
    n_time = position.shape[0]

    metrics = {
        'segment_counts': np.bincount(segment_ids, minlength=n_edges),
        'segment_transitions': np.zeros((n_edges, n_edges)),
        'distance_to_segment': np.zeros(n_time),
        'junction_proximity': np.zeros(n_time),
        'segment_velocities': np.zeros(n_edges),
        'classification_confidence': np.zeros(n_time)
    }

    # Count label changes between consecutive samples.
    for t in range(1, n_time):
        if segment_ids[t] != segment_ids[t - 1]:
            metrics['segment_transitions'][segment_ids[t - 1], segment_ids[t]] += 1

    # Step t -> t+1 is attributed to the segment at time t.
    velocities = np.diff(position, axis=0)
    velocity_mag = np.linalg.norm(velocities, axis=1)
    for i in range(n_edges):
        step_mask = segment_ids[:-1] == i
        if np.any(step_mask):
            metrics['segment_velocities'][i] = np.mean(velocity_mag[step_mask])

    fig = plt.figure(figsize=(20, 15))
    gs = fig.add_gridspec(3, 3)

    # 1. Track Layout (Top Left)
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.set_title('Track Layout and Position Classification')
    for edge in edges:
        node1_pos = track_graph.nodes[edge[0]]['pos']
        node2_pos = track_graph.nodes[edge[1]]['pos']
        ax1.plot([node1_pos[0], node2_pos[0]],
                 [node1_pos[1], node2_pos[1]],
                 'k-', linewidth=2, alpha=0.5)
    scatter = ax1.scatter(position[:, 0], position[:, 1],
                          c=np.arange(n_time), cmap='viridis',
                          alpha=0.6, s=30)
    plt.colorbar(scatter, ax=ax1, label='Time')

    # 2. Segment Classification (Top Middle)
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.set_title('Segment Classification Over Time')
    ax2.plot(segment_ids, '-o', markersize=2, alpha=0.5)
    ax2.set_xlabel('Time Step')
    ax2.set_ylabel('Segment ID')
    ax2.grid(True)

    # 3. Transition Matrix (Top Right)
    ax3 = fig.add_subplot(gs[0, 2])
    ax3.set_title('Segment Transition Matrix')
    sns.heatmap(metrics['segment_transitions'],
                ax=ax3,
                annot=True,
                fmt='.1f',  # counts are stored as floats
                cmap='YlOrRd')
    ax3.set_xlabel('To Segment')
    ax3.set_ylabel('From Segment')

    # 4. Velocity Analysis (Middle Left)
    ax4 = fig.add_subplot(gs[1, 0])
    ax4.set_title('Velocity Magnitude Over Time')
    ax4.plot(velocity_mag, alpha=0.7)
    ax4.set_xlabel('Time Step')
    ax4.set_ylabel('Velocity Magnitude')
    ax4.grid(True)

    # 5. Segment Usage (Middle Middle)
    ax5 = fig.add_subplot(gs[1, 1])
    ax5.set_title('Segment Usage Distribution')
    ax5.bar(range(n_edges), metrics['segment_counts'])
    ax5.set_xlabel('Segment ID')
    ax5.set_ylabel('Count')
    ax5.grid(True)

    # 6. Segment Velocities (Middle Right)
    ax6 = fig.add_subplot(gs[1, 2])
    ax6.set_title('Average Velocity by Segment')
    ax6.bar(range(n_edges), metrics['segment_velocities'])
    ax6.set_xlabel('Segment ID')
    ax6.set_ylabel('Average Velocity')
    ax6.grid(True)

    # 7. Path Continuity (Bottom Left)
    ax7 = fig.add_subplot(gs[2, 0])
    ax7.set_title('Path Continuity Analysis')
    change_points = np.where(np.diff(segment_ids) != 0)[0]
    ax7.plot(segment_ids, 'b-', alpha=0.5, label='Segment ID')
    ax7.scatter(change_points, segment_ids[change_points],
                color='red', alpha=0.7, label='Segment Changes')
    ax7.legend()
    ax7.grid(True)

    # 8. Position Density (Bottom Middle)
    ax8 = fig.add_subplot(gs[2, 1])
    ax8.set_title('Position Density')
    pos_df = pd.DataFrame(position, columns=['x', 'y'])
    sns.kdeplot(data=pos_df, x='x', y='y', ax=ax8, cmap='viridis')
    for edge in edges:
        node1_pos = track_graph.nodes[edge[0]]['pos']
        node2_pos = track_graph.nodes[edge[1]]['pos']
        ax8.plot([node1_pos[0], node2_pos[0]],
                 [node1_pos[1], node2_pos[1]],
                 'r-', linewidth=2, alpha=0.5)

    # 9. Metrics Summary (Bottom Right)
    ax9 = fig.add_subplot(gs[2, 2])
    ax9.set_title('Classification Metrics')
    # np.max raises on an empty array (fewer than two samples).
    if velocity_mag.size:
        mean_velocity, max_velocity = np.mean(velocity_mag), np.max(velocity_mag)
    else:
        mean_velocity = max_velocity = 0.0
    metrics_text = (
        f'Total Transitions: {metrics["segment_transitions"].sum():.1f}\n'
        f'Average Velocity: {mean_velocity:.2f}\n'
        f'Max Velocity: {max_velocity:.2f}\n'
        f'Segment Changes: {len(change_points)}\n'
        f'Most Used Segment: {np.argmax(metrics["segment_counts"])}\n'
        f'Least Used Segment: {np.argmin(metrics["segment_counts"])}'
    )
    ax9.text(0.1, 0.5, metrics_text, fontsize=10, verticalalignment='center')
    ax9.axis('off')

    plt.tight_layout()
    return fig, metrics

def print_analysis_summary(metrics):
    """Print segment usage, transition totals and the five most common transitions.

    ``metrics`` must provide 'segment_counts', 'segment_transitions' (a 2-D
    array) and 'segment_velocities', as returned by ``enhanced_track_analysis``.
    """
    print("=== Track Analysis Summary ===")

    print("\nSegment Usage:")
    counts = metrics['segment_counts']
    total_points = sum(counts)
    for seg_id, count in enumerate(counts):
        share = count / total_points * 100
        print(f"Segment {seg_id}: {count} points ({share:.1f}%)")

    print("\nTransition Analysis:")
    transition_counts = metrics['segment_transitions']
    print(f"Total segment transitions: {transition_counts.sum():.1f}")
    print(f"Average velocity: {np.mean(metrics['segment_velocities']):.2f}")

    # Non-zero transitions in row-major order, then a stable descending sort.
    ranked = sorted(
        ((i, j, transition_counts[i, j])
         for i in range(transition_counts.shape[0])
         for j in range(transition_counts.shape[1])
         if transition_counts[i, j] > 0),
        key=lambda item: item[2],
        reverse=True,
    )
    print("\nMost Common Transitions:")
    for from_seg, to_seg, count in ranked[:5]:
        print(f"Segment {from_seg} → {to_seg}: {count:.1f} times")

def analyze_transitions(track_graph, position):
    """Plot the four transition-matrix variants in a 2x2 grid (15x10 figure).

    Nothing is returned; ``position`` is accepted for API symmetry but unused.
    """
    edge_list = list(track_graph.edges)
    panels = [
        (calculate_transitions_simple(track_graph, edge_list), 'Simple Transitions'),
        (calculate_transitions_with_angles(track_graph, edge_list), 'Angle-based Transitions'),
        (calculate_transitions_with_distance(track_graph, edge_list), 'Distance-based Transitions'),
        (calculate_transitions_with_junction_type(track_graph, edge_list), 'Junction-based Transitions'),
    ]

    plt.figure(figsize=(15, 10))
    for index, (matrix, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, index)
        plt.imshow(matrix)
        plt.title(title)
        plt.colorbar()
    plt.tight_layout()
    plt.show()

In [None]:
# T-maze layout: a stem from (100, 0) to the junction at (100, 100), with arms
# ending at (15, 150) and (185, 150).
node_positions = {0: (100, 0), 1: (100, 100), 2: (15, 150), 3: (185, 150)}
edges = [(0, 1), (1, 2), (1, 3)]

track_graph = create_track_graph(node_positions=node_positions, edges=edges)

# Linear layout: segments in this order, 1 unit apart.
edge_order = [(0, 1), (1, 2), (1, 3)]
edge_spacing = 1

# Linearise with two different HMM transition models.
result_df_simple = get_linearized_position(
    position, track_graph, edge_order=edge_order, edge_spacing=edge_spacing,
    use_HMM=True, transition_type='simple',
)
result_df_angle = get_linearized_position(
    position, track_graph, edge_order=edge_order, edge_spacing=edge_spacing,
    use_HMM=True, transition_type='angle',
)

# Side-by-side comparison of segment labels and linear position.
fig, axes = plt.subplots(1, 2, figsize=(10, 3))

axes[0].plot(result_df_simple['track_segment_id'], 'b-', label='Simple', alpha=0.7)
axes[0].plot(result_df_angle['track_segment_id'], 'r--', label='Angle-based', alpha=0.7)
axes[0].set_title('Segment Classification Comparison')
axes[0].set_xlabel('Time')
axes[0].set_ylabel('Segment ID')
axes[0].legend()
axes[0].grid(True)

axes[1].plot(result_df_simple['linear_position'], 'b-', label='Simple', alpha=0.7)
axes[1].plot(result_df_angle['linear_position'], 'r--', label='Angle-based', alpha=0.7)
axes[1].set_title('Linear Position Comparison')
axes[1].set_xlabel('Time')
axes[1].set_ylabel('Linear Position')
axes[1].legend()
axes[1].grid(True)

fig.tight_layout()
plt.show()

# Plot (and keep) the four transition-matrix variants for this graph.
transition_matrices = analyze_transition_matrices(track_graph, position)

In [None]:
# Segment labels from the HMM pipeline vs. plain nearest-segment labelling.
result_hmm = get_linearized_position(position, track_graph, use_HMM=True)
result_no_hmm = get_linearized_position(position, track_graph, use_HMM=False)

fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(result_hmm['track_segment_id'], 'b-', label='HMM Labeling')
ax.plot(result_no_hmm['track_segment_id'], 'r--', label='Simple Labeling', alpha=0.5)
ax.set_title('HMM vs Simple Segment Labeling')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Rebuild the T-maze graph (same layout as above) with a wider gap between
# segments on the linear axis.
node_positions = {
    0: (100, 0),    # stem start
    1: (100, 100),  # junction
    2: (15, 150),   # left arm end
    3: (185, 150),  # right arm end
}
edges = [(0, 1), (1, 2), (1, 3)]
track_graph = create_track_graph(node_positions=node_positions, edges=edges)
edge_order = [(0, 1), (1, 2), (1, 3)]
edge_spacing = 10

# Plot per-segment emission probabilities, then derive junction-aware labels.
fig, emissions, projected_points = analyze_and_visualize_track(track_graph, position)
improved_segments = improve_edge_detection(track_graph, position, emissions)
# For a full dashboard of these labels:
#   fig, metrics = enhanced_track_analysis(track_graph, position, improved_segments)

# Linearise with nearest-segment labelling (use_HMM=False). The HMM arguments
# below (route scaling, sensor_std_dev, diagonal_bias) only take effect when
# use_HMM=True.
result_df = get_linearized_position(
    position,
    track_graph,
    edge_order=edge_order,
    edge_spacing=edge_spacing,
    use_HMM=False,
    route_euclidean_distance_scaling=1.0,
    sensor_std_dev=5.0,
    diagonal_bias=0.1,
)
result_df.head()

In [None]:
from IPython.display import HTML
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import matplotlib.gridspec as gridspec

def animate_track_positions(position, track_graph, result_df, interval=100, tail_length=10):
    """
    Create an animated visualization showing 2D and 1D positions side by side.
    
    Parameters:
    -----------
    position : numpy.ndarray
        Array of shape (n_time, 2) containing x,y coordinates
    track_graph : networkx.Graph
        Graph representing the track layout
    result_df : pandas.DataFrame
        DataFrame containing linearized positions and segment IDs
    interval : int
        Time interval between frames in milliseconds
    tail_length : int
        Number of previous positions to show in trail
    """
    # Set higher animation embed limit (100MB)
    plt.rcParams['animation.embed_limit'] = 100_000_000
    plt.ioff()
    
    # Set up the figure
    fig = plt.figure(figsize=(12, 5))
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1])
    
    # 2D plot
    ax1 = plt.subplot(gs[0])
    ax1.set_title('2D Track Position')
    
    # 1D plot
    ax2 = plt.subplot(gs[1])
    ax2.set_title('Linear Position')
    
    # Plot track segments
    edges = list(track_graph.edges)
    segment_colors = ['b', 'g', 'r']
    
    # Draw track in 2D
    for edge_id, edge in enumerate(edges):
        node1_pos = track_graph.nodes[edge[0]]['pos']
        node2_pos = track_graph.nodes[edge[1]]['pos']
        ax1.plot([node1_pos[0], node2_pos[0]], 
                [node1_pos[1], node2_pos[1]], 
                f'{segment_colors[edge_id]}-', 
                linewidth=2, alpha=0.5,
                label=f'Segment {edge_id}')
    
    # Set axis limits
    x_min, x_max = position[:, 0].min(), position[:, 0].max()
    y_min, y_max = position[:, 1].min(), position[:, 1].max()
    padding = 20
    ax1.set_xlim(x_min - padding, x_max + padding)
    ax1.set_ylim(y_min - padding, y_max + padding)
    
    # Plot all linear positions by segment
    time_points = np.arange(len(position))
    for seg_id in range(len(edges)):
        mask = result_df['track_segment_id'] == seg_id
        if np.any(mask):
            ax2.plot(time_points[mask], 
                    result_df['linear_position'][mask],
                    f'{segment_colors[seg_id]}-', alpha=0.3)
    
    ax2.set_xlim(0, len(position))
    ax2.set_ylim(result_df['linear_position'].min() - 10,
                result_df['linear_position'].max() + 10)
    
    # Initialize plots
    point_2d = ax1.plot([], [], 'ko', markersize=8)[0]
    trail_2d = ax1.plot([], [], 'k:', alpha=0.5)[0]
    point_1d = ax2.plot([], [], 'ko', markersize=8)[0]
    trail_1d = ax2.plot([], [], 'k:', alpha=0.5)[0]
    
    ax1.grid(True)
    ax2.grid(True)
    ax1.legend()
    
    def animate(frame):
        # Update 2D position
        start_idx = max(0, frame - tail_length)
        trail_x = position[start_idx:frame+1, 0]
        trail_y = position[start_idx:frame+1, 1]
        
        point_2d.set_data([position[frame, 0]], [position[frame, 1]])
        trail_2d.set_data(trail_x, trail_y)
        
        # Update 1D position
        trail_t = np.arange(start_idx, frame+1)
        trail_pos = result_df['linear_position'].iloc[start_idx:frame+1]
        
        point_1d.set_data([frame], [result_df['linear_position'].iloc[frame]])
        trail_1d.set_data(trail_t, trail_pos)
        
        # Update colors based on current segment
        current_segment = result_df['track_segment_id'].iloc[frame]
        point_2d.set_color(segment_colors[current_segment])
        point_1d.set_color(segment_colors[current_segment])
        
        return point_2d, trail_2d, point_1d, trail_1d
    
    anim = FuncAnimation(fig, animate, 
                        frames=len(position),
                        interval=interval,
                        blit=True)
    
    plt.close()
    return HTML(anim.to_jshtml())

# Example usage (the next cell runs exactly this). Kept as comments: a bare
# string as the cell's last expression would be echoed as output by Jupyter.
#
# anim_html = animate_track_positions(
#     position,
#     track_graph,
#     result_df,
#     interval=50,  # Speed of animation (lower = faster)
#     tail_length=10  # Length of trail
# )
# display(anim_html)

In [None]:
# Render the animation in the notebook. `interval` is milliseconds per frame
# (lower plays faster); `tail_length` is how many past samples the trail shows.
anim_html = animate_track_positions(
    position,
    track_graph,
    result_df,
    interval=50,
    tail_length=10,
)
display(anim_html)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx

def plot_linear_and_2d_positions(track_graph, position, result_df, edge_order):
    """
    Plot the linearized position with colored segments and the corresponding 2D position.
    
    Parameters:
    ----------
    track_graph : networkx.Graph
        Graph representation of the 2D track.
    position : numpy.ndarray, shape (n_time, 2)
        2D position of the animal over time.
    result_df : pandas.DataFrame
        DataFrame containing 'linear_position', 'projected_x_position', 
        'projected_y_position', and 'track_segment_id'.
    edge_order : list of tuples
        Order of edges in the graph. Used to determine segment labels.
    """
    # Extract data from the result DataFrame
    linear_positions = result_df['linear_position'].values
    projected_x_positions = result_df['projected_x_position'].values
    projected_y_positions = result_df['projected_y_position'].values
    track_segment_ids = result_df['track_segment_id'].values
    n_time = len(linear_positions)
    
    # Create a mapping from segment ID to color
    colors = plt.cm.tab10.colors  # Use a colormap with distinct colors
    segment_colors = {i: colors[i % len(colors)] for i in range(len(edge_order))}
    
    # Create a figure with two subplots: linearized position on the left, 2D position on the right
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    ax1, ax2 = axes
    
    # === Left Plot: Linearized Position ===
    # Plot each segment with its corresponding color
    unique_segments = np.unique(track_segment_ids)
    for segment_id in unique_segments:
        segment_mask = track_segment_ids == segment_id
        segment_time = np.arange(n_time)[segment_mask]
        segment_linear_positions = linear_positions[segment_mask]
        
        # Plot the segment with its color
        ax1.plot(
            segment_time, segment_linear_positions,
            color=segment_colors[segment_id],
            label=f"Segment {edge_order[segment_id]}",
            alpha=0.8
        )
        ax1.scatter(
            segment_time, segment_linear_positions,
            color=segment_colors[segment_id],
            s=50
        )
    
    # Add titles, labels, and legend
    ax1.set_title("Linearized Position with Colored Segments")
    ax1.set_xlabel("Time Step")
    ax1.set_ylabel("Linear Position")
    ax1.legend(loc="upper left", bbox_to_anchor=(1.01, 1), title="Segments")
    ax1.grid(alpha=0.3)
    
    # === Right Plot: 2D Position ===
    # Extract node positions from the graph
    pos = nx.get_node_attributes(track_graph, 'pos')
    
    # Draw the track graph
    nx.draw(
        track_graph,
        pos,
        with_labels=True,
        node_size=500,
        node_color='lightblue',
        font_weight='bold',
        ax=ax2
    )
    
    # Plot the actual 2D position data
    ax2.plot(position[:, 0], position[:, 1], color='red', label='Animal Path', alpha=0.7)
    ax2.scatter(position[:, 0], position[:, 1], color='red', s=50, label='Animal Position')
    
    # Plot the projected positions with segment colors
    for segment_id in unique_segments:
        segment_mask = track_segment_ids == segment_id
        ax2.plot(
            projected_x_positions[segment_mask],
            projected_y_positions[segment_mask],
            color=segment_colors[segment_id],
            label=f"Segment {edge_order[segment_id]}",
            alpha=0.8
        )
        ax2.scatter(
            projected_x_positions[segment_mask],
            projected_y_positions[segment_mask],
            color=segment_colors[segment_id],
            s=50
        )
    
    # Add titles and labels
    ax2.set_title("2D Track Graph with Colored Segments")
    ax2.set_xlabel("X Coordinate")
    ax2.set_ylabel("Y Coordinate")
    ax2.legend(loc="upper left", bbox_to_anchor=(1.01, 1), title="Segments")
    ax2.set_aspect('equal', adjustable='box')  # Ensure equal aspect ratio
    
    # Adjust layout and show the plot
    plt.tight_layout()
    plt.show()

# Show the linearized trace next to the 2D trajectory, both coloured by segment.
plot_linear_and_2d_positions(
    track_graph=track_graph,
    position=position,
    result_df=result_df,
    edge_order=edge_order,
)

Comparing different HMM segment-labeling methods (transition types)

In [None]:
# Run the HMM labeling once per transition model ('simple' first, then 'angle').
transition_runs = {
    transition_type: get_linearized_position(
        position, track_graph,
        edge_order=edge_order,
        edge_spacing=edge_spacing,
        use_HMM=True,
        transition_type=transition_type,
    )
    for transition_type in ('simple', 'angle')
}
result_df_simple = transition_runs['simple']
result_df_angle = transition_runs['angle']

# Overlay both runs: segment IDs on the left panel, linear position on the right.
comparison_styles = [
    (result_df_simple, 'b-', 'Simple'),
    (result_df_angle, 'r--', 'Angle-based'),
]
comparison_panels = [
    ('track_segment_id', 'Segment Classification Comparison', 'Segment ID'),
    ('linear_position', 'Linear Position Comparison', 'Linear Position'),
]

fig_transitions, axes_transitions = plt.subplots(1, 2, figsize=(15, 5))
for panel_ax, (column, panel_title, panel_ylabel) in zip(axes_transitions, comparison_panels):
    for run_df, line_fmt, run_label in comparison_styles:
        panel_ax.plot(run_df[column], line_fmt, label=run_label, alpha=0.7)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Time')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.legend()
    panel_ax.grid(True)

plt.tight_layout()
plt.show()

# If you want to see the transition matrices:
transition_matrices = analyze_transition_matrices(track_graph, position)