In [None]:
import torch
import math

import torchlensmaker as tlm
from torchlensmaker.shapes.common import line_coefficients

import matplotlib.pyplot as plt
import numpy as np

def render_shape(ax, shape):
    """Draw a shape's curve on ``ax`` together with a few of its normals.

    The curve is sampled at 50 parameter values spanning ``shape.domain()``
    and drawn in steel blue. Normals are sampled at 11 parameter values and
    drawn in grey as segments from the curve point to curve point + normal.
    The axes aspect ratio is set to "equal".

    Args:
        ax: Axes-like object providing ``plot`` and ``set_aspect``.
        shape: Object providing ``domain()``, ``evaluate(t)`` and ``normal(t)``.
    """
    # The curve itself.
    curve_params = torch.linspace(*shape.domain(), 50)
    curve = shape.evaluate(curve_params).detach().numpy()
    ax.plot(curve[:, 0], curve[:, 1], color="steelblue")

    # A coarser sampling for the normal vectors.
    normal_params = torch.linspace(*shape.domain(), 11)
    origins = shape.evaluate(normal_params).detach().numpy()
    vectors = shape.normal(normal_params).detach().numpy()

    # Stop at the shorter of the two arrays, as pairing them up would.
    for index in range(min(len(origins), len(vectors))):
        origin_x, origin_y = origins[index][0], origins[index][1]
        normal_x, normal_y = vectors[index][0], vectors[index][1]
        ax.plot(
            [origin_x, origin_x + normal_x],
            [origin_y, origin_y + normal_y],
            color="grey",
        )

    ax.set_aspect("equal")


def render_lines(ax, lines):
    """Draw implicit 2D lines on ``ax``, extended beyond its current y-range.

    Each entry of ``lines`` is a triple ``(a, b, c)`` describing the line
    ``a*x + b*y + c = 0``. Lines are drawn in orange between the current
    y-limits of ``ax``, widened by 35% of the y-span on each side.

    Args:
        ax: Axes-like object providing ``get_ylim``, ``get_xlim`` and ``plot``.
        lines: Iterable of ``(a, b, c)`` coefficient triples; entries may be
            tensor rows or plain numbers.
    """
    y_low, y_high = ax.get_ylim()
    over = 0.35
    # Measure the span once, before either bound moves, so that both ends
    # receive the same margin.
    span = y_high - y_low
    y_low = y_low - over * span
    y_high = y_high + over * span

    for line in lines:
        # Plain floats keep the plotting call independent of tensor details
        # (dtype, autograd state).
        a, b, c = (float(coefficient) for coefficient in line)
        if a != 0.0:
            # Solve a*x + b*y + c = 0 for x at both ends of the y-range.
            x_low = (-c - b * y_low) / a
            x_high = (-c - b * y_high) / a
            ax.plot([x_low, x_high], [y_low, y_high], color="orange")
        elif b != 0.0:
            # Horizontal line y = -c/b: x is unconstrained, so span the
            # current x-limits instead of dividing by zero.
            y_const = -c / b
            x_left, x_right = ax.get_xlim()
            ax.plot([x_left, x_right], [y_const, y_const], color="orange")
        # a == b == 0 does not describe a line; nothing to draw.


def render_derivative(ax, shape):
    """Plot the derivative of ``shape`` as a 2D curve on ``ax``.

    ``shape.derivative`` is evaluated at 50 parameter values spanning
    ``shape.domain()``; the first column of the result is plotted against the
    second, and the axes aspect ratio is set to "equal".

    Args:
        ax: Axes-like object providing ``plot`` and ``set_aspect``.
        shape: Object providing ``domain()`` and ``derivative(t)``.
    """
    domain = shape.domain()
    params = torch.linspace(*domain, 50)
    derivative_values = shape.derivative(params).detach().numpy()

    first_component = derivative_values[:, 0]
    second_component = derivative_values[:, 1]
    ax.plot(first_component, second_component)
    ax.set_aspect("equal")
    

def render_points(ax, points):
    """Scatter 2D points on ``ax`` as green circles.

    Args:
        ax: Axes-like object providing ``scatter``.
        points: Either a ``torch.Tensor`` (detached and converted to NumPy
            first) or an object already supporting ``points[:, 0]`` /
            ``points[:, 1]`` indexing.
    """
    coords = points.detach().numpy() if isinstance(points, torch.Tensor) else points
    xs, ys = coords[:, 0], coords[:, 1]
    ax.scatter(xs, ys, color="green", marker="o")


    

In [None]:
# Shapes exercised by the visual collision test; main() passes each one to
# test_shape().
# NOTE(review): the meaning of the constructor arguments (first positional
# value, Y/CX/CY lists, second CircularArc/Parabola value) is defined in
# torchlensmaker and is not visible here — confirm before editing values.
test_shapes = [
    tlm.BezierSpline(15.0, Y=[3.0], CX=[0.2*15., 1.2*15.], CY=[6.8]),
    tlm.Line(15.0),
    tlm.PiecewiseLine(25.0, Y=[-1.0, -3.0, -3.5, -7.0]),
    tlm.CircularArc(15.0, -25.),
    tlm.CircularArc(15.0, 25.),
    tlm.Parabola(40.0, 0.002),
]

# Each row holds the (a, b, c) coefficients of the implicit line
# a*x + b*y + c = 0, which is the form render_lines() solves for x.
test_lines = torch.tensor(
    [
        [1.0, 0.5, 15.0],
        [-9.0, -1.5, -25.0],
        [3.6, -5.0, -20.0],
        [3.0, 0.0, -18.0],
        [-12.0, 5.5, -155.0],
    ]
)

def test_shape(shape):
    """Plot ``shape`` with the test lines and mark their in-domain collisions.

    Prints the shape and its domain, draws the shape and ``test_lines`` on the
    left axes, asks ``shape.collide`` for solutions, keeps those lying inside
    ``shape.domain()``, and scatters the corresponding curve points.

    Args:
        shape: Object providing ``domain()``, ``collide(lines)``,
            ``evaluate(t)`` and ``normal(t)``.
    """
    print("Testing shape", shape)
    print("domain:", shape.domain())

    # Two side-by-side axes; only the left one is drawn on while the
    # derivative rendering below stays disabled.
    figure, (shape_ax, derivative_ax) = plt.subplots(1, 2, figsize=(15, 8))
    render_shape(shape_ax, shape)
    # render_derivative(derivative_ax, shape)
    render_lines(shape_ax, test_lines)

    solutions = shape.collide(test_lines)

    # Discard solutions that fall outside the shape's domain.
    upper_bound = shape.domain()[1]
    lower_bound = shape.domain()[0]
    inside_domain = (solutions <= upper_bound) & (solutions >= lower_bound)
    solutions = solutions[inside_domain]

    collisions = shape.evaluate(solutions)
    # Normals are evaluated at the collisions too, although they are not drawn.
    normals = shape.normal(solutions)

    render_points(shape_ax, collisions)
    plt.show()
    

def main():
    """Run the visual collision test once for every entry of ``test_shapes``."""
    for candidate in test_shapes:
        test_shape(candidate)


# Execute the whole suite when the cell runs.
main()