In [None]:

## TODO LIST

# move everything to absolute space
# handle rays that don't collide with a surface better
# add back regression priors
# resample between profile shapes
# better batching of normal / intersection code
# export polygon data for FreeCAD import
# build 3D model
# add tests

In [None]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

import matplotlib.pyplot as plt
from itertools import islice

from torchlensmaker.raytracing import super_refraction, clamped_refraction, ray_point_squared_distance, refraction
from torchlensmaker.optics import FocalPointLoss, ParallelBeamUniform, ParallelBeamRandom, FixedGap, RefractiveSurface, OpticalStack
from torchlensmaker.shapes import Parabola, PiecewiseLine, Line, CircularArc, BezierSpline
from torchlensmaker.training import render, optimize

In [None]:
class DemoPiecewise(OpticalStack):
    """Optical stack with a single refractive surface whose profile is piecewise linear.

    The stack built by ``make_stack`` is: a ``ParallelBeamRandom`` beam, a fixed
    gap of 15, a ``RefractiveSurface`` (n1=1.49 -> n2=1.0) shaped by a
    ``PiecewiseLine`` profile, a gap of ``focal_length``, and a ``FocalPointLoss``.
    The profile heights (``self.coeffs``) are the learnable parameters.

    Args:
        radius: passed to the beam as its radius, and used as the x coordinate
            of the last profile knot.
        focal_length: length of the gap between the surface and the focal loss.
        num_edges: number of profile knots / line segments; must be >= 1.

    Raises:
        ValueError: if ``num_edges`` is less than 1.
    """

    def __init__(self, radius, focal_length, num_edges):
        super().__init__()
        if num_edges < 1:
            raise ValueError(f"num_edges must be >= 1, got {num_edges}")

        self.radius = radius
        self.focal_length = focal_length
        self.num_edges = num_edges

        # Knot x coordinates: delta, 2*delta, ..., radius (evenly spaced by delta).
        delta = radius / num_edges  # width of an edge
        self.x = np.linspace(delta, radius, num_edges, dtype=np.float32)

        # Initial heights: y[k-1] = -5*k/num_edges pairs with x[k-1] = k*delta,
        # i.e. a straight ramp through the origin. The k=0 point (0, 0) is dropped;
        # presumably PiecewiseLine supplies it implicitly (TODO confirm).
        # float32 matches self.x and torch's default dtype. The default float64
        # from np.linspace would make the parameter float64 and mix precisions.
        y = np.linspace(0.0, -5.0, num_edges + 1, dtype=np.float32)[1:]

        # torch.tensor always copies, so the parameter never aliases the numpy buffer.
        self.coeffs = nn.Parameter(torch.tensor(y), requires_grad=True)

    def make_stack(self):
        """Return the list of optical elements, with the profile built from
        ``self.x`` and the learnable heights ``self.coeffs``."""
        profile = PiecewiseLine(self.x, self.coeffs)

        return [
            ParallelBeamRandom(radius=self.radius),
            FixedGap(15.),
            RefractiveSurface(profile, n1=1.49, n2=1.0),
            FixedGap(self.focal_length),
            FocalPointLoss(),
        ]


def demo_piecewise(
    radius=math.sqrt(2) * 25 / 2,
    focal_length=25.0,
    num_edges=8,
    num_iter=200,
    seed=None,
):
    """Optimize a piecewise-linear refractive surface and render before/after.

    Builds a ``DemoPiecewise`` stack, renders it with 10 rays, runs ``optimize``
    with Adam (lr=5e-2, 8 rays, ``num_iter`` iterations), then renders the
    result with 20 rays. All rendering uses ``render_width=radius``.

    Args:
        radius: beam radius / profile extent (default sqrt(2) * 25 / 2).
        focal_length: gap between the surface and the focal-point loss.
        num_edges: number of segments in the piecewise-linear profile.
        num_iter: number of optimization iterations.
        seed: if not None, seeds torch's and numpy's global RNGs first, so
            the random beam and the optimization are reproducible.

    Returns:
        The optimized ``DemoPiecewise`` instance, so its coefficients can be inspected.
    """
    if seed is not None:
        # ParallelBeamRandom draws rays at random. Which RNG it uses is not
        # visible here, so seed both torch and numpy.
        torch.manual_seed(seed)
        np.random.seed(seed)

    optics = DemoPiecewise(radius, focal_length, num_edges)

    render(optics, num_rays=10, render_width=radius)

    optimize(
        optics,
        optimizer=optim.Adam(optics.parameters(), lr=5e-2),
        num_rays=8,
        num_iter=num_iter,
        render_width=radius,
    )

    render(optics, num_rays=20, render_width=radius)
    return optics


# Fixed seed so that Restart & Run All reproduces the same figures.
piecewise_optics = demo_piecewise(seed=0)