# Cubic Bezier Spline

- Definition with lens constraints: C1 or G1 continuity, tangent at 0 = 0
- Line intersection with Newton's method across all intervals, differentiable with the no-grad trick
- Normal at the intersection
- Resampling (aka splitting)

=> possibly use a sigmoid function to constrain CX to [0, 1], i.e. within the interval between knots?
=> or just positive constraint?


Parametrization:
    r: radius
    n: number of intervals (min 1)
    X: knots X values = linspace(0, r)
    Y: knots Y values
    C: control points

    X.shape = n+1 | all values fixed
    Y.shape = n parameters | (+first value implicitly 0)
    CX.shape = n+1 parameters
    CY.shape = n parameters | (+first value implicitly 0)

    for origin at radius, translate all points by y value of curve at radius

    t: bezier parameter, from 0 to n


evaluation at t:
    integer part of t gives the interval, and so relevant knots and control points
    fractional part of t gives the parameter for the matrix form of the bezier curve

normal at t:
    like evaluation but use derivative matrix
    rotate pi/2 with [-y, x]

resample:
    compute new knots by evaluating at the new X
    compute new control points by interpolation

intersection:
    Newton's method over the piecewise-defined equation of the intersection with the line ax+by+c=0


In [None]:
import torch

from torchlensmaker.shapes import BezierSpline
from torchlensmaker.shapes.common import intersect_newton, mirror_points

from torchlensmaker.testing import render_lines, render_spline, render_collision_points


import matplotlib.pyplot as plt


def _render_intersections(spline, lines):
    """Plot ``spline`` and ``lines`` on a new figure and mark their collisions.

    The solutions returned by ``intersect_newton`` are filtered so that only
    those lying inside ``spline.domain()`` are evaluated and drawn.

    Args:
        spline: object providing ``domain()`` and ``evaluate()`` and accepted
            by ``render_spline`` / ``intersect_newton``.
        lines: tensor of lines, one row per line.
            NOTE(review): rows are presumably the (a, b, c) coefficients of
            ax + by + c = 0 -- confirm against ``intersect_newton``.
    """
    _, ax = plt.subplots(figsize=(12, 6))
    render_spline(ax, spline)
    render_lines(ax, lines, xlim=(-10, 10))

    sols = intersect_newton(spline, lines)

    # Keep only valid solutions (within domain)
    domain = spline.domain()
    valid = torch.logical_and(sols <= domain[1], sols >= domain[0])
    collisions = spline.evaluate(sols[valid])
    render_collision_points(ax, collisions)

    # Fixed viewport so successive figures can be compared side by side.
    ax.set_xlim([-10, 10])
    ax.set_ylim([-10, 10])
    plt.show()


def test():
    """Visual check of spline/line intersections on a single-knot spline.

    Builds a ``BezierSpline``, dumps it, and plots its intersections with a
    fixed set of lines. The spline is then perturbed with ``wiggle`` and
    rebuilt with ``resample`` three times, plotting the intersections with
    the same lines after each step.
    """
    X = torch.tensor([-2.7340])
    CX = torch.tensor([-4.5])
    CY = torch.tensor([1.23, 9.])

    spline = BezierSpline(7.27*2, X, CX, CY)

    spline.dump()
    lines = torch.tensor([
        # no intersection
        [1., 0.5, 10.],
        [0.5, 1., 1/15.],
        [3., 0., -18.],   # vertical
        [0., 1., 5.],     # horizontal

        # 1 intersection
        [-9.0, -1.5, -25.],
        [3.6, -5, -10],
        [-5, 3.6, -20],

        # 2 intersections
        [-3., 0.5, -3.],
        [3., 0., 5.],   # vertical

        # edge cases: no collision but nearest is within the domain
        [-3., 0.7, 1.],
        [3., 0., -1.],  # vertical
    ])

    # Initial spline.
    _render_intersections(spline, lines)

    # Perturb and resample the spline a few times, re-plotting each result.
    for _ in range(3):
        spline = spline.wiggle(cx=0.05, cy=0.05, x=0.1)
        spline = spline.resample()
        _render_intersections(spline, lines)


# Run the intersection demo.
test()

In [None]:
def test2():
    """Plot in-domain collisions between a lens-sized spline and six lines.

    Creates a ``BezierSpline`` whose height is twice ``lens_width``, prints
    its domain, then draws the spline, the lines and every intersection
    solution that falls inside the spline domain.
    """
    lens_width = 15.0
    spline = BezierSpline(
        height=2*lens_width,
        X=[5.1431],
        CX=[6.6509],
        CY=[7.9189, 16.8972],
    )

    # One row per line; negative zeros are kept exactly as given.
    lines = torch.tensor([
        [1.0, -0.5, 13.4],
        [1.0, -0.0, 10.4222],
        [1.0, -0.0, 1.4889],
        [1.0, -0.0, -1.4889],
        [1.0, -0.0, -10.4222],
        [1.0, -0.0, -13.4],
    ])

    print("domain", spline.domain())

    _, axes = plt.subplots(figsize=(12, 6))
    render_spline(axes, spline)
    render_lines(axes, lines, xlim=(-10, 10))

    params = intersect_newton(spline, lines)

    # Discard solutions that fall outside the spline domain.
    below_upper = params <= spline.domain()[1]
    above_lower = params >= spline.domain()[0]
    in_domain = torch.logical_and(below_upper, above_lower)

    hit_points = spline.evaluate(params[in_domain])
    render_collision_points(axes, hit_points)
    plt.show()


# Run the lens-sized spline demo.
test2()