In [None]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

import matplotlib.pyplot as plt

import torchlensmaker as tlm

from torchlensmaker.surface import scale_lines

In [None]:
# Profile shapes under test. main() draws one figure per shape, wrapping the
# shape in a tlm.Surface for every entry of `surface_configurations`.
test_shapes = [
    tlm.BezierSpline(15.0, X=[3.0], CX=[5.8], CY=[0.2*15/2., 1.2*15./2]),
    tlm.Line(15.0),
    tlm.PiecewiseLine(height=18.0, X=[-1.0, -3.0, -3.5, -7.0]),
    tlm.CircularArc(15.0, -25.),
    tlm.CircularArc(15.0, 25.),
    tlm.Parabola(10.0, a=0.02),
    tlm.Parabola(10.0, a=-0.02),
]

# Test lines, one per row, as (a, b, c) coefficients of a*x + b*y + c = 0
# (the convention render_lines uses when drawing them). The whole tensor is
# passed to shape.collide() in test_shape().
test_lines = torch.tensor([
    [1., 0.5, 15.],
    [-9.0, -1.5, -25.],
    [3.6, -5, -20],
    [3., 0., -18.],
])

# Surface placements as (origin, scale, anchor) tuples, forwarded positionally
# to tlm.Surface(shape, origin, scale, anchor) in main(). The exact meaning of
# `scale` (e.g. whether -1. mirrors the profile) is defined by tlm.Surface.
surface_configurations = [
    ((15.0, -30.0), 1., "origin"),
    ((0.0, -15.0), 1., "origin"),
    ((-14.3, -5.1), -1., "extent"),
    ((0.0, 0.0), 1., "origin"),
    ((12.51, 5.72), 1., "extent"),
    ((0.0, 15.0), 1., "origin"),
    ((-15.0, 35.0), -1., "origin"),
]

# Fixed axis limits applied to every figure in main().
xlim, ylim = [
    (-50, 50),
    (-50, 50)
]

def render_shape(ax, shape: tlm.Surface):
    """Draw a surface profile, a few of its normals, and its position on ``ax``.

    The profile is sampled at 50 evenly spaced parameters over
    ``shape.domain()`` and drawn in steelblue. Normals are evaluated at 5
    evenly spaced parameters and each is drawn in grey as a segment from the
    profile point to point + normal. ``shape.pos`` is marked with a red "x",
    and the axes aspect is set to equal.
    """
    # Profile curve
    profile_t = torch.linspace(*shape.domain(), 50)
    profile = shape.evaluate(profile_t).detach().numpy()
    ax.plot(profile[:, 0], profile[:, 1], color="steelblue")

    # A handful of normal vectors anchored on the profile
    sample_t = torch.linspace(*shape.domain(), 5)
    starts = shape.evaluate(sample_t).detach().numpy()
    directions = shape.normal(sample_t).detach().numpy()
    for start, direction in zip(starts, directions):
        seg_x = [start[0], start[0] + direction[0]]
        seg_y = [start[1], start[1] + direction[1]]
        ax.plot(seg_x, seg_y, color="grey")

    # Position marker
    position = shape.pos
    ax.plot(position[0], position[1], "x", color="red")

    ax.set_aspect("equal")


def render_lines(ax, lines, color, *, over=0.35):
    """Draw each line ``a*x + b*y + c = 0`` across and beyond the current view.

    Every line is drawn past the axes limits by ``over`` times the view span on
    each side, so it appears unbounded once matplotlib clips it.

    Args:
        ax: Axes providing ``get_xlim``, ``get_ylim`` and ``plot``.
        lines: Iterable of ``(a, b, c)`` coefficient triples (tensor rows,
            array rows, or plain sequences).
        color: Color forwarded to ``ax.plot``.
        over: Fraction of the view span to extend past each limit.

    Lines with ``a != 0`` are solved for x at the extended y limits.
    Horizontal lines (``a == 0``) are drawn across the extended x range,
    because solving for x would divide by zero. Rows with ``a == b == 0``
    describe no line and are skipped.
    """
    ymin, ymax = ax.get_ylim()
    # Compute the span once so both ends get the same margin.
    yspan = ymax - ymin
    ymin, ymax = ymin - over * yspan, ymax + over * yspan

    xmin, xmax = ax.get_xlim()
    xspan = xmax - xmin
    xmin, xmax = xmin - over * xspan, xmax + over * xspan

    for line in lines:
        # float() accepts 0-d tensors, numpy scalars and Python numbers alike.
        a, b, c = (float(v) for v in line)
        if a != 0.0:
            x0 = (-c - b * ymin) / a
            x1 = (-c - b * ymax) / a
            ax.plot([x0, x1], [ymin, ymax], color=color)
        elif b != 0.0:
            # Horizontal line y = -c / b.
            y = -c / b
            ax.plot([xmin, xmax], [y, y], color=color)
        # a == b == 0: not a line, nothing to draw.


def render_points(ax, points):
    """Scatter 2D points on ``ax`` as green circles.

    ``points`` may be a torch tensor (it is detached and converted to numpy
    first) or anything already indexable as ``points[:, 0]`` / ``points[:, 1]``.
    """
    coords = points.detach().numpy() if isinstance(points, torch.Tensor) else points
    xs = coords[:, 0]
    ys = coords[:, 1]
    ax.scatter(xs, ys, color="green", marker="o")


def test_shape(ax, shape):
    """Draw ``shape`` with the shared test lines and mark in-domain collisions.

    Renders the surface, overlays every row of ``test_lines`` in orange, asks
    ``shape.collide`` for collision parameters, keeps those inside the closed
    interval given by ``shape.domain()``, and plots the corresponding points.
    Parameters that fail both comparisons (e.g. NaN) are dropped as well.
    """
    render_shape(ax, shape)
    render_lines(ax, test_lines, color="orange")
    # Optional: also draw the test lines scaled by (1, -1) via scale_lines.
    # render_lines(ax, scale_lines(test_lines, torch.tensor([1.0, -1.0])), color="red")

    params = shape.collide(test_lines)

    # Keep only solutions that fall inside the shape's parameter domain.
    lower, upper = shape.domain()
    in_domain = (params <= upper) & (params >= lower)
    params = params[in_domain]

    if params.numel() == 0:
        return
    render_points(ax, shape.evaluate(params))

def main():
    """Plot every test shape under every surface configuration.

    Creates one figure per entry of ``test_shapes``. Each figure overlays the
    shape placed with each ``(origin, scale, anchor)`` of
    ``surface_configurations``. Before drawing a placement, it checks that for
    every name in ``surface.valid_anchors``,
    ``surface.at(name) - surface.at("origin")`` matches
    ``surface.anchor_offset(name)``.
    """
    for shape in test_shapes:
        _, ax = plt.subplots(1, 1, figsize=(15, 8))
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.grid()

        for origin, scale, anchor in surface_configurations:
            surface = tlm.Surface(shape, origin, scale, anchor)

            # Sanity check anchors. The loop uses `name` rather than `anchor`
            # so the configured anchor above is not shadowed.
            origin_point = surface.at("origin")
            for name in surface.valid_anchors:
                offset = surface.at(name) - origin_point
                assert torch.allclose(offset, surface.anchor_offset(name)), (
                    f"anchor {name!r}: at() offset disagrees with anchor_offset()"
                )

            test_shape(ax, surface)


main()