In [None]:
import torch
import math

import torchlensmaker as tlm
from torchlensmaker.shapes.common import line_coefficients

import matplotlib.pyplot as plt
import numpy as np

def render_shape(ax, shape):
    """Draw a shape's outline plus a sparse set of its normal vectors.

    The outline is sampled at 50 evenly spaced parameter values spanning
    ``shape.domain()`` and drawn in steelblue. At 11 evenly spaced parameter
    values, a grey segment is drawn from the point on the shape to that point
    plus ``shape.normal(t)`` (so its length is whatever ``normal`` returns).
    The axis aspect ratio is set to "equal" so the normals are not distorted.
    """
    t_min, t_max = shape.domain()

    outline_t = torch.linspace(t_min, t_max, 50)
    outline = shape.evaluate(outline_t).detach().numpy()
    ax.plot(outline[:, 0], outline[:, 1], color="steelblue")

    sample_t = torch.linspace(t_min, t_max, 11)
    origins = shape.evaluate(sample_t).detach().numpy()
    directions = shape.normal(sample_t).detach().numpy()
    for origin, direction in zip(origins, directions):
        tip = origin + direction
        ax.plot([origin[0], tip[0]], [origin[1], tip[1]], color="grey")

    ax.set_aspect("equal")


def render_derivative(ax, shape):
    """Plot the curve traced by ``shape.derivative`` over the shape's domain.

    The derivative is sampled at 50 evenly spaced parameter values; its first
    column is used as the horizontal coordinate and its second column as the
    vertical coordinate. The axis aspect ratio is set to "equal".
    """
    samples = torch.linspace(*shape.domain(), 50)
    derivative = shape.derivative(samples).detach().numpy()
    dx = derivative[:, 0]
    dy = derivative[:, 1]
    ax.plot(dx, dy)
    ax.set_aspect("equal")
    

def render_points(ax, points):
    """Scatter 2D points on ``ax`` as green circle markers.

    ``points`` may be a torch.Tensor (detached and converted to numpy first)
    or any object that already supports ``points[:, 0]`` / ``points[:, 1]``
    indexing, such as a numpy array.
    """
    array = points.detach().numpy() if isinstance(points, torch.Tensor) else points
    xs = array[:, 0]
    ys = array[:, 1]
    ax.scatter(xs, ys, color="green", marker="o")


def render_lines(ax, lines):
    """
    Render infinite 2D lines on a matplotlib axis without changing its limits.

    Parameters:
    ax (matplotlib.axes.Axes): The matplotlib axis to draw on. Its current
        x/y limits define the plotted x range and are restored afterwards.
    lines (array-like or torch.Tensor): Coefficients [a, b, c] of lines
        a*x + b*y + c = 0, either of shape (N, 3) or a single line of
        shape (3,). Tensors that require grad are detached before plotting.

    Lines whose |a| and |b| are both below 1e-6 are degenerate and skipped.
    Lines with |b| below 1e-6 are drawn as vertical lines via ax.axvline.
    All other lines are drawn in orange over 100 samples of the x range.
    """
    # detach() so coefficient tensors that are part of an autograd graph can
    # still be converted to numpy for plotting.
    lines = torch.as_tensor(lines).detach()
    # Accept a single line given as a flat [a, b, c].
    if lines.ndim == 1 and lines.numel() == 3:
        lines = lines.unsqueeze(0)

    # Get current axis limits
    x_min, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()

    # Sample x once; the numpy view is shared by every plotted line.
    x = torch.linspace(x_min, x_max, 100)
    x_np = x.numpy()

    for a, b, c in lines:
        # Handle vertical lines (avoid division by zero)
        if b.abs() < 1e-6:
            if a.abs() < 1e-6:
                continue  # Skip if both a and b are zero
            ax.axvline(x=(-c / a).item(), color='orange')
        else:
            # Solve a*x + b*y + c = 0 for y.
            y = (-a * x - c) / b
            ax.plot(x_np, y.numpy(), color='orange')

    # Reset the axis limits (plotting may have autoscaled them)
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)

    

In [None]:
# Shapes exercised by test_shape(); each one is drawn together with test_lines.
# The BezierSpline case is currently disabled.
test_shapes = [
    # tlm.BezierSpline(15.0, Y=[3.0], CX=[0.2*15., 1.2*15.], CY=[6.8]),
    tlm.Line(15.0),
    tlm.PiecewiseLine(25.0, X=[-1.0, -3.0, -3.5, -7.0]),
    tlm.CircularArc(15.0, -25.0),
    tlm.CircularArc(15.0, 25.0),
    tlm.Parabola(40.0, 0.02),
    tlm.Parabola(40.0, -0.02),
]

# Each row is [a, b, c] for the line a*x + b*y + c = 0, the convention
# render_lines uses to draw them. The set includes a horizontal line (a == 0)
# and a vertical line (b == 0).
test_lines = torch.tensor(
    [
        [0.5, 1.0, 1.0 / 15.0],
        [-1.5, -9.0, -25.0],
        [-5.0, 3.6, -20.0],
        [0.0, 3.0, -18.0],
        [3.0, 0.0, -18.0],
        [5.5, -12.0, -15.0],
    ]
)

def test_shape(shape, lines=None):
    """Plot ``shape`` with a set of lines and mark their collision points.

    Prints the shape and its domain. Draws the shape outline with normals and
    overlays each line. Then scatters the points ``shape.evaluate(t)`` for
    every solution ``t`` of ``shape.collide(lines)`` that lies inside
    ``shape.domain()``.

    Parameters:
    shape: a torchlensmaker shape exposing domain(), evaluate(), normal()
        and collide().
    lines: line coefficients [a, b, c] (a*x + b*y + c = 0) passed to
        ``shape.collide`` and drawn on the plot. Defaults to the
        module-level ``test_lines``.
    """
    if lines is None:
        lines = test_lines

    print("Testing shape", shape)
    domain = shape.domain()
    print("domain:", domain)
    t_min, t_max = domain[0], domain[1]

    fig, ax = plt.subplots(1, 1, figsize=(15, 8))
    ax.set_xlim([-10, 10])
    ax.set_ylim([-10, 10])
    ax.set_title(type(shape).__name__)
    render_shape(ax, shape)
    render_lines(ax, lines)

    sols = shape.collide(lines)

    # Keep only solutions inside the shape's parameter domain. Comparisons
    # with NaN are False, so any NaN entries are dropped as well.
    valid = (sols >= t_min) & (sols <= t_max)
    sols = sols[valid]

    if sols.numel() > 0:
        render_points(ax, shape.evaluate(sols))

    plt.show()
    # Release the figure so repeated calls (e.g. from main) don't accumulate.
    plt.close(fig)
    

def main():
    """Run the visual collision check on every shape listed in ``test_shapes``."""
    for candidate in test_shapes:
        test_shape(candidate)


main()