In [None]:
# Circular arc shape

import torch
import math

from torchlensmaker.shapes import CircularArc
from torchlensmaker.shapes.common import intersect_newton


import matplotlib.pyplot as plt
import numpy as np

def render_arc(ax, arc, *, n_points=50, n_normals=11):
    """Draw ``arc`` on ``ax`` together with a fan of its normal vectors.

    The curve is drawn in steelblue from ``n_points`` parameter values evenly
    spaced over ``arc.domain()``. At ``n_normals`` evenly spaced parameter
    values, the vector returned by ``arc.normal`` is drawn in grey as a
    segment starting at the corresponding point on the arc.

    Args:
        ax: matplotlib Axes to draw on.
        arc: shape exposing ``domain()``, ``evaluate(t)`` and ``normal(t)``.
            ``evaluate``/``normal`` are indexed as 2D arrays whose columns 0
            and 1 are used as x and y.
        n_points: number of samples used to draw the curve (default 50).
        n_normals: number of normal vectors drawn (default 11).
    """
    domain = arc.domain()

    curve_t = torch.linspace(*domain, n_points)
    curve = arc.evaluate(curve_t).detach().cpu().numpy()
    ax.plot(curve[:, 0], curve[:, 1], color="steelblue")

    normal_t = torch.linspace(*domain, n_normals)
    origins = arc.evaluate(normal_t).detach().cpu().numpy()
    vectors = arc.normal(normal_t).detach().cpu().numpy()
    for o, n in zip(origins, vectors):
        # Segment from the arc point o to o + n.
        ax.plot([o[0], o[0] + n[0]], [o[1], o[1] + n[1]], color="grey")


def render_lines(ax, lines, *, over=0.35):
    """Draw each line ``a*x + b*y + c = 0`` on ``ax`` in orange.

    Non-horizontal lines (``a != 0``) are drawn between the current y-limits
    of ``ax``, widened by ``over`` times the y-span on both sides so they
    overshoot the plotted data. Horizontal lines (``a == 0``, ``b != 0``) are
    drawn across the current x-limits. Rows with ``a == b == 0`` describe no
    line and are skipped.

    Args:
        ax: matplotlib Axes to draw on.
        lines: iterable of ``(a, b, c)`` coefficient triples, e.g. an (N, 3)
            tensor or a list of tuples.
        over: fractional margin added below and above the y-limits.
    """
    ymin, ymax = ax.get_ylim()
    # Compute the margin once from the original span so both ends are
    # widened by the same amount.
    margin = over * (ymax - ymin)
    ymin, ymax = ymin - margin, ymax + margin
    xmin, xmax = ax.get_xlim()

    for line in lines:
        a, b, c = (float(coef) for coef in line)
        if a != 0.0:
            # Solve a*x + b*y + c = 0 for x at the two extreme y values.
            x0 = (-c - b * ymin) / a
            x1 = (-c - b * ymax) / a
            ax.plot([x0, x1], [ymin, ymax], color="orange")
        elif b != 0.0:
            # Horizontal line y = -c/b: x cannot be solved from y.
            y = -c / b
            ax.plot([xmin, xmax], [y, y], color="orange")


def render_derivative(shape):
    """Plot the curve traced by ``shape.derivative`` in a brand-new figure.

    Fifty parameter values evenly spaced over ``shape.domain()`` are passed to
    ``shape.derivative``. Column 0 of the result is plotted against column 1
    with equal axis scaling. The new figure is not returned.
    """
    _, deriv_ax = plt.subplots()

    samples = torch.linspace(*shape.domain(), 50)
    derivative = shape.derivative(samples).detach().numpy()

    deriv_ax.plot(derivative[:, 0], derivative[:, 1])
    deriv_ax.set_aspect("equal")

def render_points(ax, points):
    """Scatter ``points`` on ``ax`` as green circle markers.

    Args:
        ax: matplotlib Axes to draw on.
        points: 2D array-like whose columns 0 and 1 are used as x and y.
            A torch.Tensor (on any device, possibly requiring grad), a numpy
            array or a nested sequence of coordinates is accepted. Empty
            input draws nothing.
    """
    if isinstance(points, torch.Tensor):
        # Drop the autograd graph and move to host memory before converting.
        points = points.detach().cpu().numpy()
    points = np.asarray(points)
    if points.size == 0:
        return
    ax.scatter(points[:, 0], points[:, 1], color="green", marker="o")


def render_shape(ax, shape):
    """Draw ``shape`` on ``ax`` and force equal x/y scaling.

    Only ``CircularArc`` shapes are drawn (via ``render_arc``). Any other
    shape draws nothing, but the aspect ratio is still set.
    """
    renderers = ((CircularArc, render_arc),)
    for shape_type, draw in renderers:
        if isinstance(shape, shape_type):
            draw(ax, shape)
            break

    ax.set_aspect("equal")

def test_shape(shape):
    """Visually check ``shape`` against a fixed set of probe lines.

    The steps are:
    - Draw the shape in a new 12x6 figure.
    - Draw its derivative in a separate figure.
    - Overlay four probe lines ``a*x + b*y + c = 0``.
    - Pass the lines to ``intersect_newton`` and evaluate the returned
      parameters on the shape.
    - Mark the resulting points in green, then show all figures.
    """
    print("Testing shape", shape)

    # Each row is (a, b, c) for the line a*x + b*y + c = 0.
    probe_lines = torch.tensor(
        [
            [1.0, 0.5, 15.0],     # no intersection
            [-9.0, -1.5, -25.0],  # 1st intersect
            [3.6, -5.0, -20.0],   # 2nd intersect
            [3.0, 0.0, -18.0],    # 3rd intersect (vertical)
        ]
    )

    _, ax = plt.subplots(figsize=(12, 6))
    render_shape(ax, shape)
    render_derivative(shape)
    render_lines(ax, probe_lines)

    hits = shape.evaluate(intersect_newton(shape, probe_lines))
    render_points(ax, hits)
    plt.show()

def main():
    """Build the demo circular arc, print its domain and run the visual check."""
    demo_arc = CircularArc(lens_radius=12.0, init=[-25.0])
    print("domain", demo_arc.domain())
    test_shape(demo_arc)


main()