# In [None]:
import networkx as nx
import random
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Point, LineString, Polygon
import ipywidgets as widgets
from IPython.display import clear_output, display
import time
from matplotlib import animation
from scipy.interpolate import interp1d
from IPython.display import HTML

# Import the TestEnvironments class from the uploaded module
from testEnvironments import TestEnvironments

# ---------- LOAD ENVIRONMENTS ----------
# Instantiate the scene provider from the testEnvironments module and fetch
# every scene it offers. scene_dict is consumed further down via .items(),
# i.e. it is used as a mapping of scene name -> scene description.
test_envs = TestEnvironments()
scene_dict = test_envs.get_all_scenes()

# Transform scenes into usable format
def convert_scene_to_env(name, scene):
    """Turn one raw scene mapping into the environment record used by the planner.

    Args:
        name: Scene identifier. Accepted so the caller can pass it along;
            the conversion itself does not read it.
        scene: Mapping of labels to geometries (presumably shapely objects,
            since ``.centroid`` is read from them). The optional ``"start"``
            and ``"goal"`` entries mark the endpoints; every other entry is
            treated as an obstacle.

    Returns:
        dict with the keys ``"map_bounds"`` (always ``(0, 23, 0, 23)``),
        ``"start"`` and ``"goal"`` (first coordinate of each region's
        centroid) and ``"obstacles"`` (list of the remaining geometries).
    """
    endpoint_keys = ("start", "goal")

    # A small disc around the origin stands in for a missing start/goal entry.
    fallback_region = Point(0, 0).buffer(0.8)
    centers = {
        key: scene.get(key, fallback_region).centroid.coords[0]
        for key in endpoint_keys
    }

    blocked = []
    for label, geometry in scene.items():
        if label in endpoint_keys:
            continue
        blocked.append(geometry)

    return {
        "map_bounds": (0, 23, 0, 23),
        "start": centers["start"],
        "goal": centers["goal"],
        "obstacles": blocked,
    }

# Planner-ready environments keyed by scene name; each value is the dict
# produced by convert_scene_to_env (map_bounds / start / goal / obstacles).
environments = {name: convert_scene_to_env(name, scene) for name, scene in scene_dict.items()}

# ---------- PARAMETERS ----------
# Initial value of the "Samples/Region" slider.
samples_per_region = 40
# Initial value of the "K Neighbors" slider.
lazy_k_neighbors = 8
# Passed to build_hierarchical_visibility_lazy_prm, which currently always
# splits the map into four quadrants regardless of this value; also used as a
# multiplier for the sample count in robot-arm mode.
num_regions = 4
# Global counter incremented by is_collision() and reset at the start of
# every interactive_plot() run; reported in the benchmark output.
collision_checks = 0

# ---------- PRM UTILITY FUNCTIONS ----------
def is_collision(p1, p2, obstacles):
    """Report whether the straight segment p1 -> p2 runs through an obstacle.

    Every call bumps the module-level ``collision_checks`` counter by one.

    Args:
        p1: First endpoint of the segment (coordinate tuple).
        p2: Second endpoint of the segment (coordinate tuple).
        obstacles: Iterable of geometries supporting shapely's binary
            predicates.

    Returns:
        True as soon as one obstacle is crossed by the segment or fully
        contains it; False otherwise.
    """
    global collision_checks
    collision_checks += 1

    segment = LineString([p1, p2])
    for obstacle in obstacles:
        if segment.crosses(obstacle):
            return True
        if segment.within(obstacle):
            return True
    return False

def sample_free(bounds, obstacles):
    """Draw a uniformly random point that no obstacle contains.

    Sampling is repeated (rejection sampling) until a free point is found,
    so the call does not return if the whole region is covered.

    Args:
        bounds: ``(minx, maxx, miny, maxy)`` of the sampling rectangle.
        obstacles: Iterable of geometries offering ``contains``.

    Returns:
        The accepted point as an ``(x, y)`` tuple.
    """
    x_lo, x_hi, y_lo, y_hi = bounds
    while True:
        candidate = (random.uniform(x_lo, x_hi), random.uniform(y_lo, y_hi))
        for obstacle in obstacles:
            if obstacle.contains(Point(*candidate)):
                break
        else:
            # No obstacle claimed the point.
            return candidate

def build_lazy_prm(region_bounds, num_samples, k, obstacles):
    """Build a k-nearest-neighbour roadmap whose edges are not collision-checked.

    Nodes are sampled in free space, but the edges between them are added
    without any collision test; they are meant to be validated later.

    Args:
        region_bounds: ``(minx, maxx, miny, maxy)`` rectangle to sample in.
        num_samples: Number of free-space samples to draw.
        k: Maximum number of nearest neighbours each sample is linked to.
        obstacles: Obstacle geometries used while sampling.

    Returns:
        networkx.Graph whose nodes are the sampled points and whose edges
        carry the Euclidean distance as ``weight``.
    """
    roadmap = nx.Graph()

    samples = []
    for _ in range(num_samples):
        samples.append(sample_free(region_bounds, obstacles))
    roadmap.add_nodes_from(samples)

    for current in samples:
        ranked = [
            (other, np.linalg.norm(np.subtract(current, other)))
            for other in samples
            if other != current
        ]
        ranked.sort(key=lambda pair: pair[1])
        for neighbor, dist in ranked[:k]:
            roadmap.add_edge(current, neighbor, weight=dist)

    return roadmap

def build_visibility_prm(nodes, obstacles):
    """Connect every pair of nodes that can see each other.

    Nodes are inserted one at a time; each new node is tested against all
    nodes already in the graph and linked when the straight segment between
    them is collision-free according to ``is_collision``.

    Args:
        nodes: Iterable of point tuples.
        obstacles: Obstacle geometries for the visibility test.

    Returns:
        networkx.Graph with Euclidean edge lengths stored as ``weight``.
    """
    graph = nx.Graph()
    for candidate in nodes:
        graph.add_node(candidate)
        for existing in graph.nodes:
            if existing == candidate or graph.has_edge(candidate, existing):
                continue
            if is_collision(candidate, existing, obstacles):
                continue
            separation = np.linalg.norm(np.subtract(candidate, existing))
            graph.add_edge(candidate, existing, weight=separation)
    return graph

def build_hierarchical_visibility_lazy_prm(global_bounds, num_regions, samples_per_region, k, obstacles):
    """Build per-quadrant lazy roadmaps and join them through a visibility graph.

    The map is split into four equal quadrants. A lazy k-NN roadmap is built
    in each one, a visibility graph is built over the union of all sampled
    nodes, and finally the (unchecked) local edges are merged into it.

    Args:
        global_bounds: ``(minx, maxx, miny, maxy)`` of the whole map.
        num_regions: Currently unused; the split is always four quadrants.
        samples_per_region: Samples drawn inside each quadrant.
        k: Neighbour count for the local lazy roadmaps.
        obstacles: Obstacle geometries.

    Returns:
        networkx.Graph combining visibility edges and local lazy edges.
    """
    minx, maxx, miny, maxy = global_bounds
    midx = (minx + maxx) / 2
    midy = (miny + maxy) / 2

    # Lower row first, left quadrant before right quadrant within each row.
    x_spans = ((minx, midx), (midx, maxx))
    y_spans = ((miny, midy), (midy, maxy))
    quadrants = [
        (x_lo, x_hi, y_lo, y_hi)
        for (y_lo, y_hi) in y_spans
        for (x_lo, x_hi) in x_spans
    ]

    local_graphs = []
    for quadrant in quadrants:
        local_graphs.append(build_lazy_prm(quadrant, samples_per_region, k, obstacles))

    pooled_nodes = []
    for local in local_graphs:
        pooled_nodes.extend(local.nodes)

    roadmap = build_visibility_prm(pooled_nodes, obstacles)
    for local in local_graphs:
        roadmap.add_edges_from(local.edges(data=True))
    return roadmap

def validate_lazy_path(G, path, obstacles):
    """Collision-check each edge of a path and prune the blocked ones from G.

    All edges of the path are examined, even after a blocked one has been
    found, so a single pass removes every invalid edge on the path.

    Args:
        G: Graph the path was computed on; modified in place.
        path: Sequence of nodes.
        obstacles: Obstacle geometries passed to ``is_collision``.

    Returns:
        True if no edge of the path collides, False otherwise.
    """
    blocked_count = 0
    for tail, head in zip(path, path[1:]):
        if not is_collision(tail, head, obstacles):
            continue
        G.remove_edge(tail, head)
        blocked_count += 1
    return blocked_count == 0

def lazy_shortest_path(G, start, goal, obstacles):
    """Search for a shortest path and validate it lazily until one survives.

    A weighted shortest path is computed; its edges are then checked with
    ``validate_lazy_path``, which deletes colliding edges from ``G``. The
    search is repeated on the pruned graph until a path passes validation
    or the goal becomes unreachable.

    Args:
        G: Roadmap graph; colliding edges are removed in place.
        start: Source node.
        goal: Target node.
        obstacles: Obstacle geometries.

    Returns:
        The validated node list, or None when no path exists.
    """
    candidate = None
    confirmed = False
    while not confirmed:
        try:
            candidate = nx.shortest_path(G, source=start, target=goal, weight='weight')
        except nx.NetworkXNoPath:
            return None
        confirmed = validate_lazy_path(G, candidate, obstacles)
    return candidate

def forward_kinematics(joint_angles, link_lengths):
    """Compute the joint positions of a planar serial arm anchored at the origin.

    Each joint angle is relative to the previous link, so the absolute
    heading is the running sum of the angles.

    Args:
        joint_angles: Relative joint angles in radians.
        link_lengths: Length of each link; paired with the angles via zip,
            so surplus entries of the longer sequence are ignored.

    Returns:
        List of ``(x, y)`` tuples starting with the base ``(0, 0)`` followed
        by the end point of every link.
    """
    joints = [(0, 0)]
    heading = 0
    for delta, reach in zip(joint_angles, link_lengths):
        heading += delta
        prev_x, prev_y = joints[-1]
        joints.append((prev_x + reach * np.cos(heading), prev_y + reach * np.sin(heading)))
    return joints

def is_robot_collision(joint_angles, link_lengths, obstacles):
    """Check whether any link of the arm hits an obstacle in the given pose.

    Args:
        joint_angles: Relative joint angles in radians.
        link_lengths: Length of each link.
        obstacles: Iterable of geometries supporting shapely's binary
            predicates.

    Returns:
        True if at least one link segment crosses an obstacle or lies
        completely within one; False otherwise.
    """
    joints = forward_kinematics(joint_angles, link_lengths)
    for near_end, far_end in zip(joints, joints[1:]):
        link = LineString([near_end, far_end])
        for obstacle in obstacles:
            if link.crosses(obstacle) or link.within(obstacle):
                return True
    return False

def sample_robot_configuration(dof):
    """Draw a random joint configuration.

    Args:
        dof: Number of joints.

    Returns:
        Tuple of ``dof`` angles, each drawn uniformly from [-pi, pi].
    """
    angles = []
    for _ in range(dof):
        angles.append(random.uniform(-np.pi, np.pi))
    return tuple(angles)

def build_robot_arm_prm(dof, num_samples, k, link_lengths, obstacles):
    """Build a k-nearest-neighbour roadmap in the arm's joint space.

    Configurations are rejection-sampled until ``num_samples`` collision-free
    poses have been collected. Each pose is then linked to its ``k`` nearest
    poses (Euclidean distance in joint space).

    Only the nodes are collision-checked here. The motion along an edge is
    NOT validated, so callers must verify edges before executing a path.

    Args:
        dof: Number of joints of the arm.
        num_samples: Number of collision-free configurations to collect.
        k: Maximum number of nearest neighbours per configuration.
        link_lengths: Length of each arm link.
        obstacles: Workspace obstacle geometries.

    Returns:
        networkx.Graph whose nodes are joint-angle tuples and whose edges
        carry the joint-space distance as ``weight``.
    """
    G = nx.Graph()
    nodes = []
    while len(nodes) < num_samples:
        config = sample_robot_configuration(dof)
        if not is_robot_collision(config, link_lengths, obstacles):
            nodes.append(config)
    G.add_nodes_from(nodes)

    for node in nodes:
        nearest = sorted(
            ((other, np.linalg.norm(np.subtract(node, other))) for other in nodes if other != node),
            key=lambda pair: pair[1]
        )[:k]
        for neighbor, dist in nearest:
            # Every entry of `nodes` already passed the collision test when it
            # was sampled, so re-testing the neighbour pose here could never
            # reject an edge; the redundant check has been dropped.
            G.add_edge(node, neighbor, weight=dist)
    return G

from scipy.interpolate import interp1d
from IPython.display import HTML

def animate_robot_arm_path(path, link_lengths, obstacles, environment_name):
    """Render a joint-space path as an inline animation of the arm.

    The path is resampled to 60 frames spaced evenly along its joint-space
    arc length (linear interpolation per joint). Each frame is drawn through
    ``forward_kinematics`` on top of the polygonal obstacles.

    Args:
        path: Sequence of joint configurations (one tuple per waypoint).
        link_lengths: Length of each arm link.
        obstacles: Workspace geometries; only Polygon instances are drawn.
        environment_name: Label shown in the plot title.

    Returns:
        None. The animation is displayed as JavaScript HTML; a message is
        printed instead when the path has fewer than two waypoints.
    """
    plt.close('all')

    if len(path) < 2:
        print("Path too short to animate.")
        return

    waypoints = np.array(path)
    frame_count = 60  # raise for a slower, smoother animation

    # Arc length of the path in joint space, measured at every waypoint.
    segment_lengths = [
        np.linalg.norm(current - previous)
        for previous, current in zip(waypoints[:-1], waypoints[1:])
    ]
    arc_length = np.cumsum([0] + segment_lengths)
    sample_positions = np.linspace(0, arc_length[-1], frame_count)

    # Resample each joint independently along the shared arc-length axis.
    joint_tracks = []
    for joint_idx in range(waypoints.shape[1]):
        track = interp1d(arc_length, waypoints[:, joint_idx], kind='linear')
        joint_tracks.append(track(sample_positions))
    frames = np.column_stack(joint_tracks)

    fig, ax = plt.subplots(figsize=(6, 6))

    for obstacle in obstacles:
        if not isinstance(obstacle, Polygon):
            continue
        outline_x, outline_y = obstacle.exterior.xy
        ax.fill(outline_x, outline_y, color='gray')

    ax.set_xlim(0, 23)
    ax.set_ylim(0, 23)
    ax.set_title(f"Robot Arm Animation: {environment_name}")
    ax.grid(True)

    arm_line, = ax.plot([], [], 'o-', lw=4, color='deeppink')
    artists = [arm_line]

    def init():
        for artist in artists:
            artist.set_data([], [])
        return artists

    def update(frame):
        joint_positions = forward_kinematics(frames[frame], link_lengths)
        xs, ys = zip(*joint_positions)
        arm_line.set_data(xs, ys)
        return artists

    ani = animation.FuncAnimation(
        fig, update, frames=frame_count,
        init_func=init, blit=True, interval=1000 / 60, repeat=False
    )
    display(HTML(ani.to_jshtml()))


# ---------- MAIN EXECUTION FUNCTION ----------
def plot_prm(G, obstacles, start, goal, path=None, environment_name=""):
    """Draw the roadmap, the obstacles, the endpoints and an optional path.

    Args:
        G: networkx graph whose nodes are ``(x, y)`` tuples.
        obstacles: Workspace geometries; only Polygon instances are drawn.
        start: ``(x, y)`` of the start position.
        goal: ``(x, y)`` of the goal position.
        path: Optional sequence of ``(x, y)`` nodes to highlight.
        environment_name: Label shown in the plot title.

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    plt.figure(figsize=(6, 6))

    # Plot obstacles. Only polygons expose an exterior ring; other geometry
    # types are skipped, the same way animate_robot_arm_path handles them.
    for obs in obstacles:
        if not isinstance(obs, Polygon):
            continue
        x, y = obs.exterior.xy
        plt.fill(x, y, color='gray')

    # Plot PRM graph
    for u, v in G.edges():
        plt.plot([u[0], v[0]], [u[1], v[1]], color='lightblue', linewidth=0.5)

    # Plot nodes. zip(*...) over an empty node set cannot be unpacked into
    # two sequences, so an empty graph simply draws no nodes.
    if G.number_of_nodes() > 0:
        x_nodes, y_nodes = zip(*G.nodes)
        plt.scatter(x_nodes, y_nodes, color='blue', s=10, label='PRM Nodes')

    # Plot start and goal
    plt.plot(start[0], start[1], 'go', markersize=10, label='Start')
    plt.plot(goal[0], goal[1], 'ro', markersize=10, label='Goal')

    # Plot path if it exists
    if path:
        path_x, path_y = zip(*path)
        plt.plot(path_x, path_y, color='orange', linewidth=2, linestyle='--', label='Path')

    plt.title(f"PRM Graph: {environment_name}")
    plt.xlim(0, 23)
    plt.ylim(0, 23)
    plt.grid(True)
    plt.legend()
    plt.show()


def _arm_motion_collides(config_a, config_b, link_lengths, obstacles, max_step=0.1):
    """Return True if the straight joint-space motion a -> b hits an obstacle.

    The motion is sampled so that no joint moves more than ``max_step``
    radians between two consecutive samples; every sample (both end poses
    included) is tested with ``is_robot_collision``. Each call counts as one
    collision check in the global ``collision_checks`` counter.
    """
    global collision_checks
    collision_checks += 1

    pose_a = np.asarray(config_a, dtype=float)
    pose_b = np.asarray(config_b, dtype=float)
    largest_change = float(np.max(np.abs(pose_b - pose_a))) if pose_a.size else 0.0
    steps = max(1, int(np.ceil(largest_change / max_step)))
    for fraction in np.linspace(0.0, 1.0, steps + 1):
        pose = pose_a + fraction * (pose_b - pose_a)
        if is_robot_collision(pose, link_lengths, obstacles):
            return True
    return False


def _lazy_arm_shortest_path(G, start, goal, link_lengths, obstacles):
    """Lazy shortest-path search on a joint-space roadmap.

    Repeatedly computes the weighted shortest path, removes every edge whose
    arm motion collides, and retries until a path survives validation.
    Returns the node list, or None when start and goal are disconnected.
    """
    while True:
        try:
            path = nx.shortest_path(G, source=start, target=goal, weight='weight')
        except nx.NetworkXNoPath:
            return None
        blocked = [
            (u, v)
            for u, v in zip(path, path[1:])
            if _arm_motion_collides(u, v, link_lengths, obstacles)
        ]
        if not blocked:
            return path
        G.remove_edges_from(blocked)


def interactive_plot(samples_per_region, lazy_k_neighbors, environment_name, dof):
    """Run the planner for the selected settings and show the results.

    For ``dof == 1`` a point robot is planned with the hierarchical
    visibility/lazy PRM; the roadmap is plotted and the path animated. For
    any other ``dof`` a planar arm with ``dof`` links of length 3.0 is
    planned between two random collision-free configurations in joint space.

    All output is written into the module-level ``output_area`` widget, and
    the global ``collision_checks`` counter is reset at the start of a run.

    Args:
        samples_per_region: Samples per map quadrant (point mode); multiplied
            by ``num_regions`` to obtain the sample count in arm mode.
        lazy_k_neighbors: Number of nearest neighbours per roadmap node.
        environment_name: Key into the module-level ``environments`` dict.
        dof: Degrees of freedom; 1 selects the point-robot planner.
    """
    global collision_checks
    collision_checks = 0

    output_area.clear_output(wait=True)
    plt.close('all')

    with output_area:
        start_time = time.time()

        environment = environments[environment_name]
        obstacles = environment["obstacles"]
        map_bounds = environment["map_bounds"]
        start = environment["start"]
        goal = environment["goal"]

        print(f"Running planner for DOF: {dof}")

        if dof == 1:
            # Classic point PRM
            G = build_hierarchical_visibility_lazy_prm(map_bounds, num_regions, samples_per_region, lazy_k_neighbors, obstacles)
            G.add_node(start)
            G.add_node(goal)

            # Link start and goal to every roadmap node they can see.
            for node in list(G.nodes):
                if node in [start, goal]:
                    continue
                if not is_collision(node, start, obstacles):
                    G.add_edge(node, start, weight=np.linalg.norm(np.subtract(node, start)))
                if not is_collision(node, goal, obstacles):
                    G.add_edge(node, goal, weight=np.linalg.norm(np.subtract(node, goal)))

            path = lazy_shortest_path(G, start, goal, obstacles)
            planning_time = time.time() - start_time
            path_length = sum(np.linalg.norm(np.subtract(u, v)) for u, v in zip(path[:-1], path[1:])) if path else 0

            print("\n----------- Benchmark Results -----------")
            print(f"Number of nodes: {G.number_of_nodes()}")
            print(f"Number of edges: {G.number_of_edges()}")
            print(f"Collision checks performed: {collision_checks}")
            if path:
                print(f"Path found with length: {path_length:.2f}")
            else:
                print("No valid path found.")
            print(f"Planning time: {planning_time:.2f} seconds")

            plot_prm(G, obstacles, start, goal, path, environment_name)

            if path:
                animate_path(path, obstacles, start, goal, map_bounds, environment_name)
            else:
                print("No path to animate.")
        else:
            # Robot arm mode
            link_lengths = [3.0] * dof  # example link lengths
            G = build_robot_arm_prm(dof, samples_per_region * num_regions, lazy_k_neighbors, link_lengths, obstacles)

            start_config = sample_robot_configuration(dof)
            goal_config = sample_robot_configuration(dof)

            # Resample until both end poses are collision-free.
            while is_robot_collision(start_config, link_lengths, obstacles):
                start_config = sample_robot_configuration(dof)
            while is_robot_collision(goal_config, link_lengths, obstacles):
                goal_config = sample_robot_configuration(dof)

            G.add_node(start_config)
            G.add_node(goal_config)

            # Link start/goal to all roadmap poses within a joint-space radius
            # of 1.5. Both poses are known to be free from the loops above,
            # so no per-node re-check is needed; the motions themselves are
            # validated lazily during the path search.
            for node in list(G.nodes):
                if node in [start_config, goal_config]:
                    continue
                dist_start = np.linalg.norm(np.subtract(node, start_config))
                dist_goal = np.linalg.norm(np.subtract(node, goal_config))
                if dist_start < 1.5:
                    G.add_edge(node, start_config, weight=dist_start)
                if dist_goal < 1.5:
                    G.add_edge(node, goal_config, weight=dist_goal)

            # Edges connect joint-angle tuples, so they must be validated by
            # moving the arm along them rather than by treating the angles as
            # workspace coordinates of a line segment.
            path = _lazy_arm_shortest_path(G, start_config, goal_config, link_lengths, obstacles)
            planning_time = time.time() - start_time

            print("\n----------- Benchmark Results (Robot Arm) -----------")
            print(f"DOF: {dof}")
            print(f"Nodes: {G.number_of_nodes()}")
            print(f"Edges: {G.number_of_edges()}")
            print(f"Collision checks: {collision_checks}")
            if path:
                print(f"Path found of length: {len(path)}")
            else:
                print("No path found.")
            print(f"Planning time: {planning_time:.2f} seconds")

            if path:
                animate_robot_arm_path(path, link_lengths, obstacles, environment_name)
            else:
                print("No path to animate.")


# ---------- UI CONTROLS ----------
# Scene selector; the options are the keys of the `environments` dict and the
# first one is preselected.
environment_dropdown = widgets.Dropdown(
    options=list(environments.keys()),
    value=list(environments.keys())[0],
    description='Environment:'
)

# Samples per region (10..100), starting at the module-level default.
samples_slider = widgets.IntSlider(
    value=samples_per_region,
    min=10,
    max=100,
    step=1,
    description='Samples/Region:'
)

# Nearest-neighbour count per roadmap node (3..20).
k_slider = widgets.IntSlider(
    value=lazy_k_neighbors,
    min=3,
    max=20,
    step=1,
    description='K Neighbors:'
)

# Degrees of freedom (1..12); interactive_plot treats 1 as the point-robot
# planner and every other value as a robot arm with that many links.
dof_slider = widgets.IntSlider(
    value=1,
    min=1,
    max=12,
    step=1,
    description='DOF:'
)

run_button = widgets.Button(
    description='Run PRM',
    button_style='success'
)

# Output widget that interactive_plot writes all text and figures into.
output_area = widgets.Output()

def run_clicked(_):
    """Button callback: clear the output area and run the planner with the current widget values."""
    output_area.clear_output(wait=True)
    interactive_plot(samples_slider.value, k_slider.value, environment_dropdown.value, dof_slider.value)

run_button.on_click(run_clicked)

# Stack all controls vertically above the output area and show them.
ui = widgets.VBox([environment_dropdown, samples_slider, k_slider, dof_slider, run_button, output_area])
display(ui)

#------------------------Animation----------------------------
from matplotlib import animation
from scipy.interpolate import interp1d
from IPython.display import HTML

def animate_path(path, obstacles, start, goal, map_bounds, environment_name):
    """Render a point-robot path as an inline animation.

    The path is resampled to 60 positions spaced evenly along its length
    (linear interpolation), and a marker is moved along them on top of the
    obstacles, the endpoints and the dashed path.

    Args:
        path: Sequence of ``(x, y)`` waypoints.
        obstacles: Workspace geometries; only Polygon instances are drawn.
        start: ``(x, y)`` of the start position.
        goal: ``(x, y)`` of the goal position.
        map_bounds: ``(minx, maxx, miny, maxy)`` used as axis limits.
        environment_name: Label shown in the plot title.

    Returns:
        None. The animation is displayed as JavaScript HTML; a message is
        printed instead when the path has fewer than two waypoints.
    """
    plt.close('all')

    if len(path) < 2:
        print("Path too short to animate.")
        return

    path = np.array(path)

    # Cumulative travelled distance at every waypoint.
    dists = np.cumsum([0] + [np.linalg.norm(path[i] - path[i-1]) for i in range(1, len(path))])
    total_dist = dists[-1]
    interp_dist = np.linspace(0, total_dist, 60)

    fx = interp1d(dists, path[:, 0], kind='linear')
    fy = interp1d(dists, path[:, 1], kind='linear')
    interpolated_path = np.column_stack((fx(interp_dist), fy(interp_dist)))

    fig, ax = plt.subplots(figsize=(6, 6))

    # Only polygons expose an exterior ring; other geometry types are
    # skipped, the same way animate_robot_arm_path handles them.
    for obs in obstacles:
        if not isinstance(obs, Polygon):
            continue
        x, y = obs.exterior.xy
        ax.fill(x, y, color='gray')

    ax.set_xlim(map_bounds[0], map_bounds[1])
    ax.set_ylim(map_bounds[2], map_bounds[3])
    ax.set_title(f"Path Animation: {environment_name}")
    ax.grid(True)

    ax.plot(*start, 'go', markersize=10, label='Start')
    ax.plot(*goal, 'ro', markersize=10, label='Goal')
    ax.plot(path[:, 0], path[:, 1], linestyle='--', color='orange', linewidth=1.5)

    pink_dot, = ax.plot([], [], 'o', color='deeppink', markersize=8, label='Agent')
    ax.legend()

    def init():
        pink_dot.set_data([], [])
        return (pink_dot,)

    def update(frame):
        x, y = interpolated_path[frame]
        # set_data expects sequences, hence the single-element lists.
        pink_dot.set_data([x], [y])
        return (pink_dot,)

    ani = animation.FuncAnimation(
        fig, update, frames=len(interpolated_path),
        init_func=init, interval=100, blit=True, repeat=False
    )

    display(HTML(ani.to_jshtml()))
