# Imager Optimization

## Function Imports and Definitions

In [None]:
import pyrayt
from pyrayt import components
import pyrayt.materials as matl
from pyrayt.utils import lensmakers_equation
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import scipy
from typing import Tuple
from tinygfx.g3d import ObjectGroup
import pandas as pd


from tinygfx.g3d.renderers import draw

# Use a larger sans-serif font for every figure so axis labels stay legible.
font = dict(family='sans-serif', size=18)
matplotlib.rc('font', **font)

def init_figure(figsize: Tuple[float, float] = (12, 8)) -> Tuple[plt.Figure, plt.Axes]:
    """
    Create a new figure with a single axis that has grid lines switched on.

    Args:
        figsize: figure size in inches as (width, height); defaults to (12, 8).

    Returns:
        The newly created figure and its axis.
    """
    fig = plt.figure(figsize=figsize)
    # Take the axis from *this* figure rather than pyplot's global "current figure".
    axis = fig.gca()
    # Pass True explicitly: a bare grid() call toggles visibility, which would
    # turn the grid *off* if rcParams['axes.grid'] is already enabled.
    axis.grid(True)
    return fig, axis


## System Constants

In [None]:
# All spatial units are mm
lens_diameter = 25.4  # 1 inch optic
system_focus = 50  # design focal length of the system
f_num = 2.4  # f-number used to size the aperture stop

# Sensor (width, height); 36 x 24 mm is the full-frame "35mm" format.
sensor_size = (36, 24)
# Sensor diagonal. It is used as the side length of the square imager baffle,
# so the baffle fully contains the rectangular sensor area.
sensor_spot_diameter = np.linalg.norm(sensor_size)


## A Single Lens Imager

In [None]:
# Creating a BK7 Lens
lens_material = matl.glass["BK7"]
lens_thickness = 5
# Thin-lens lensmaker's equation for a symmetric biconvex lens whose two surface
# radii both have magnitude R:  1/f = (n - 1) * 2/R  =>  R = 2 * (n - 1) * f.
# n is evaluated at 0.633 (presumably um, i.e. 633 nm -- confirm).
lens_radius = 2*(lens_material.index_at(0.633)-1)*system_focus # We're discarding the lens thickness in the radius calculation

# Creating the Aperture
aperture_position = system_focus / 2 # place it half way between the lens and the imager
# NOTE(review): f_num divides the aperture's distance from the lens
# (system_focus / 2) rather than system_focus itself -- confirm this is intended.
aperture_diameter = aperture_position/f_num
aperture = components.aperture((lens_diameter, lens_diameter), aperture_diameter).move_x(aperture_position)

# The lens is not moved (presumably centred at the origin); the imager is a
# square baffle sized to the sensor diagonal, placed at the design focus.
lens = components.biconvex_lens(lens_radius, lens_radius, lens_thickness, aperture=lens_diameter, material = lens_material)
imager = components.baffle((sensor_spot_diameter, sensor_spot_diameter)).move_x(system_focus)

# Line of rays spanning half the lens diameter, placed system_focus before the
# lens and tilted by rotate_z(10) (presumably degrees) to give an off-axis beam.
source = components.LineOfRays(0.5*lens_diameter, wavelength = 0.633).move_x(-system_focus).rotate_z(10)

# NOTE(review): `aperture` is built above but is not passed to the tracer, so
# this trace runs without the stop -- confirm that is intended.
tracer = pyrayt.RayTracer(source, [lens, imager])
tracer.set_rays_per_source(11)
results = tracer.trace()
fig, axis = init_figure()
tracer.show(
    ray_width=0.1,
    axis=axis)
plt.show()

## Making a Doublet

In [None]:
# Want the power of the whole system to remain unchanged, but cancel out for the first Order Chromatic Dispersion
p_sys = 1/system_focus # the lens power of the system
convex_matl = matl.glass["BK7"]
concave_matl = matl.glass["SF5"]
# NOTE(review): this lens_radius is never used below (the doublet radii are
# hard-coded) and it overwrites the singlet's lens_radius -- looks like a leftover.
lens_radius = (2*(convex_matl.index_at(0.633)-1)-(concave_matl.index_at(0.633)-1)) / p_sys
#lens_radius = np.sqrt(2)*(lens_material.index_at(0.633)-1)*system_focus

# Positive BK7 element (R = 62.8 / -45.7 mm, 4 mm thick), followed by a negative
# SF5 meniscus (R = -45.7 / -128.2 mm, 2.5 mm thick). The meniscus is offset 3.26 mm
# along x so the shared -45.7 mm surfaces roughly coincide; exact placement depends
# on where thick_lens puts its origin.
# NOTE(review): these values appear to match a stock 100 mm achromat
# (e.g. Thorlabs AC254-100-A) -- confirm the source of the prescription.
lens1 = pyrayt.components.thick_lens(62.8, -45.7, 4, aperture = lens_diameter, material = convex_matl)
lens2 = pyrayt.components.thick_lens(-45.7, -128.2, 2.5, aperture = lens_diameter, material = concave_matl).move_x(3.26)
# Imager at 100 mm, presumably the doublet's focal length (it is not system_focus).
imager = components.baffle((sensor_spot_diameter, sensor_spot_diameter)).move_x(100)
doublet = ObjectGroup([lens1, lens2])

# Reuses the tilted `source` from the single-lens cell.
tracer = pyrayt.RayTracer(source, [lens1, lens2, imager])
tracer.set_rays_per_source(11)
results = tracer.trace()
fig, axis = init_figure()
tracer.show(
    ray_width=0.3,
    axis=axis)
plt.show()

In [None]:
def test_source(wave):
    """Return a probe LineOfRays at wavelength `wave`.

    The source is built with a first argument of 0 (presumably the line length,
    so every ray starts at one point). It is shifted 10 units back along x and
    a quarter of `lens_diameter` up along y, i.e. off the optical axis.
    """
    probe = pyrayt.components.LineOfRays(0, wavelength=wave)
    probe = probe.move_x(-10)
    return probe.move_y(lens_diameter / 4)

def _axis_crossing(rays: pd.DataFrame) -> pd.Series:
    """Return the x coordinate where each ray crosses the optical axis (y = 0).

    Each row is a ray starting at (x0, y0) with direction (x_tilt, y_tilt).
    A ray with y_tilt == 0 never crosses the axis. It gets NaN instead of the
    +/-inf a raw division would give, so pandas reductions (which skip NaN by
    default) are not poisoned by it.
    """
    y_tilt = rays['y_tilt'].where(rays['y_tilt'] != 0)  # zero slope -> NaN
    return rays['x0'] - rays['x_tilt'] * rays['y0'] / y_tilt


def chromatic_abberation(system, source_function, wavelengths: np.ndarray):
    """Trace one probe ray per wavelength and report where each one focuses.

    Args:
        system: list of pyrayt components to trace through.
        source_function: callable mapping a wavelength to a pyrayt source.
        wavelengths: wavelengths to probe, in the units the sources expect
            (um in this notebook, judging by the nm conversion when plotting).

    Returns:
        DataFrame with columns 'wavelength' and 'focus'. 'focus' is the x position
        at which each highest-generation ray, extrapolated as a straight line,
        crosses the optical axis (y = 0); it is NaN for rays parallel to the axis.
    """
    sources = [source_function(wave) for wave in wavelengths]
    tracer = pyrayt.RayTracer(sources, system)
    tracer.set_rays_per_source(1)
    traced = tracer.trace()

    # Keep only the rays of the highest generation (presumably those leaving the
    # last optical surface) and extrapolate them to the axis.
    is_final = traced['generation'] == traced['generation'].max()
    final_rays = traced.loc[is_final]
    return pd.DataFrame({'wavelength': final_rays['wavelength'],
                         'focus': _axis_crossing(final_rays)})

        

In [None]:
# Axial focus shift vs. wavelength for the singlet and the doublet
# (wavelengths presumably in um, converted to nm for plotting).
abberation = chromatic_abberation([lens, imager], test_source, np.linspace(0.44, 0.75, 11))
abberation_doublet = chromatic_abberation([*doublet, imager], test_source, np.linspace(0.44, 0.75, 21))

fig, axis = init_figure()
# The singlet is referenced to the design focus. The doublet's focus is not
# system_focus, so it is referenced to its own mean focus.
axis.plot(1000*abberation['wavelength'], abberation['focus'] - system_focus, label="Singlet")
axis.plot(1000*abberation_doublet['wavelength'],
          abberation_doublet['focus'] - np.mean(abberation_doublet['focus']),
          label="Doublet")
axis.set_xlabel("Wavelength (nm)")
axis.set_ylabel("Focal Length Shift (mm)")
axis.legend()
plt.show()

In [None]:
# Display the rays from the last generation of the most recent trace.
results[results['generation'] == results['generation'].max()]

In [None]:
# NOTE(review): `optimized_lens` is not defined anywhere in this notebook as
# shown. This cell raises NameError unless a missing earlier cell defines it.
tracer = pyrayt.RayTracer(source, [optimized_lens, aperture, imager])
tracer.set_rays_per_source(11)
results = tracer.trace()

# Full view of the traced system.
fig, axis = init_figure()
tracer.show(
    ray_width=0.1,
    axis=axis)
plt.show()

# Zoomed view: a +/-1 mm window around the design focus.
fig, axis = init_figure()
tracer.show(
    ray_width=0.025,
    axis=axis)
# Enable the grid explicitly: a bare axis.grid() toggles visibility and would
# switch off the grid that init_figure() enables.
axis.grid(True)
axis.set_xlim(system_focus - 1, system_focus + 1)
axis.set_ylim(-1, 1)
plt.show()

### Chromatic Aberration

In [None]:
# Trace three wavelengths (0.44, 0.53, 0.65 -- presumably um) through the same
# system to visualise the chromatic focal shift.
sources = [components.LineOfRays(0.3*lens_diameter, wavelength = x).move_x(-system_focus).rotate_z(0) for x in [0.44, 0.53, 0.65]]
# NOTE(review): if move_x mutates `imager` in place, re-running this cell shifts
# the shared imager a further 3 mm each time -- confirm tinygfx move semantics.
tracer = pyrayt.RayTracer(sources, [optimized_lens, aperture, imager.move_x(3)])
tracer.set_rays_per_source(2)
results = tracer.trace()

# Full view, rays coloured by wavelength.
fig, axis = init_figure()
tracer.show(
    color_function="wavelength",
    ray_width=0.1,
    axis=axis)
plt.show()

# Zoomed view around the design focus.
fig, axis = init_figure()
tracer.show(
    color_function="wavelength",
    ray_width=0.025,
    axis=axis)
# Enable the grid explicitly: a bare axis.grid() toggles visibility and would
# switch off the grid that init_figure() enables.
axis.grid(True)
axis.set_xlim(system_focus - .5, system_focus + 1)
axis.set_ylim(-.5, 0.5)
plt.show()