In [None]:
import torch
import math

import torchlensmaker as tlm
from torchlensmaker.shapes.common import line_coefficients

import matplotlib.pyplot as plt
import numpy as np

from torchlensmaker.testing import render_lines, render_spline, render_collision_points

def render_shape(ax, shape):
    """Draw `shape` on the matplotlib-style axes `ax`, together with its normals.

    `tlm.BezierSpline` instances are handed off to `render_spline`, and its
    return value is passed back unchanged.

    Any other shape is handled here:
    - it is sampled at 50 evenly spaced parameters over `shape.domain()` and
      drawn as a steelblue polyline;
    - 11 grey segments are drawn, each starting at a curve point and
      extending by `shape.normal(...)` at that parameter;
    - the axes aspect ratio is set to equal.
    Returns None in that case.
    """
    if isinstance(shape, tlm.BezierSpline):
        # Splines have a dedicated renderer.
        return render_spline(ax, shape)

    t_start, t_end = shape.domain()

    curve = shape.evaluate(torch.linspace(t_start, t_end, 50)).detach().numpy()
    ax.plot(curve[:, 0], curve[:, 1], color="steelblue")

    # A coarser sampling for the normal vectors keeps the plot readable.
    t_normals = torch.linspace(t_start, t_end, 11)
    anchors = shape.evaluate(t_normals).detach().numpy()
    directions = shape.normal(t_normals).detach().numpy()
    for anchor, direction in zip(anchors, directions):
        x0, y0 = anchor[0], anchor[1]
        xs = [x0, x0 + direction[0]]
        ys = [y0, y0 + direction[1]]
        ax.plot(xs, ys, color="grey")

    ax.set_aspect("equal")


def render_derivative(ax, shape):
    """Plot `shape.derivative` sampled at 50 evenly spaced parameters.

    The parameters span `shape.domain()`. The first column of the derivative
    output is used as x and the second as y, drawn as a single line on `ax`.
    The axes aspect ratio is then set to equal.
    """
    t_start, t_end = shape.domain()
    samples = torch.linspace(t_start, t_end, 50)
    tangents = shape.derivative(samples).detach().numpy()
    dx, dy = tangents[:, 0], tangents[:, 1]
    ax.plot(dx, dy)
    ax.set_aspect("equal")

In [None]:
# Shapes exercised by `main()`: `test_shape` is called once per entry, in order.
# NOTE(review): constructor argument meanings live in torchlensmaker, not here.
# Presumably the first argument is the shape's extent and the second (where
# present) a curvature-like parameter — confirm against the library.
test_shapes = [
    tlm.BezierSpline(15.0, X=[3.0], CX=[5.8], CY=[0.2*15/2., 1.2*15./2]),
    tlm.Line(15.0),
    tlm.PiecewiseLine(25.0, X=[-1.0, -3.0, -3.5, -7.0]),
    # Opposite-sign pair, presumably to check both orientations.
    tlm.CircularArc(15.0, -25.),
    tlm.CircularArc(15.0, 25.),
    # Very large second argument: presumably a near-flat arc used as a
    # numerical edge case — TODO confirm.
    tlm.CircularArc(15.0, 1e10),
    tlm.CircularArc(15.0, -1e10),
    # Opposite-sign pair, presumably to check both orientations.
    tlm.Parabola(40.0, 0.02),
    tlm.Parabola(40.0, -0.02),
]

# Lines intersected with every shape, one row of three coefficients per line.
# `test_shape` passes this tensor to both `render_lines` and `shape.collide`.
# NOTE(review): presumably each row is (a, b, c) of the implicit line
# a*x + b*y + c = 0 — confirm against torchlensmaker's line convention.
# Commented-out rows are alternative test lines; uncomment them to include them.
test_lines = torch.tensor([
    #[1., 0., 0.],
    # 1 - 1e-2 rather than exactly 1, presumably to avoid a perfectly
    # axis-aligned line — TODO confirm intent.
    [0., 1. - 1e-2, 0.],
    #[0.5, 1., 1/15.],
    #[-1.5, -9.0, -25.],
    #[-5, 3.6, -20],
    #[0., 3., -18.],
    #[3., 0., -18.],
    #[5.5, -12.0, -15.],
])

def test_shape(shape, lines=None, lim=10):
    """Plot `shape`, a set of test lines, and their in-domain intersections.

    The function does the following:
    - prints the shape and its domain;
    - draws the shape (with its normals) and the lines on a square
      [-lim, lim] x [-lim, lim] window;
    - marks the points where `shape.collide` reports an intersection whose
      parameter lies inside `shape.domain()`.

    Args:
        shape: shape under test. Must provide `domain()`, `evaluate()` and
            `collide()`, plus whatever `render_shape` needs.
        lines: tensor of line coefficients passed to `render_lines` and
            `shape.collide`. Defaults to the module-level `test_lines`.
        lim: half-width of the plotting window. Also used as the x range over
            which the lines are drawn, so the two always stay in sync.
    """
    if lines is None:
        lines = test_lines

    print("Testing shape", shape)
    domain = shape.domain()
    print("domain:", domain)
    tmin, tmax = domain

    fig, ax = plt.subplots(1, 1, figsize=(15, 8))
    ax.set_xlim([-lim, lim])
    ax.set_ylim([-lim, lim])
    render_shape(ax, shape)
    render_lines(ax, lines, xlim=(-lim, lim))

    # collide() yields curve parameters; keep only those inside the domain.
    # NaN entries (if any) fail both comparisons and are dropped as well.
    sols = shape.collide(lines)
    valid = (sols >= tmin) & (sols <= tmax)
    collisions = shape.evaluate(sols[valid])
    render_collision_points(ax, collisions)

    plt.show()
    # This function is called once per shape in a loop; release each figure
    # so they don't accumulate in memory.
    plt.close(fig)
    

def main():
    """Call `test_shape` on each entry of `test_shapes`, in list order."""
    for candidate in test_shapes:
        test_shape(candidate)


main()