# Imager Optimization

## Function Imports and Definitions

In [None]:
import pyrayt
from pyrayt import components
import pyrayt.materials as matl
from pyrayt.utils import lensmakers_equation
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import scipy
from typing import Tuple
from tinygfx.g3d import ObjectGroup
import pandas as pd
from tinygfx.g3d.renderers import draw

def init_figure(figsize: Tuple[float, float] = (12, 8)) -> Tuple[plt.Figure, plt.Axes]:
    """
    Create a new figure with a single axis that has grid lines enabled.

    Parameters
    ----------
    figsize:
        Figure size in inches, forwarded to ``plt.figure``. Defaults to (12, 8).

    Returns
    -------
    (fig, axis):
        The newly created figure and its current axis.
    """
    fig = plt.figure(figsize=figsize)
    axis = plt.gca()
    # Pass True explicitly: a bare ``grid()`` call toggles visibility, which would
    # turn the grid *off* if the active matplotlib style already enables it.
    axis.grid(True)
    return fig, axis


## System Constants

In [None]:
# System-wide design constants; every spatial quantity is in millimetres.
lens_diameter = 30   # aperture used for every lens and for the imager/aperture plates
lens_thickness = 8   # thickness of the single-lens design
system_focus = 50    # target focal length of the imaging system
f_num = 2.4          # target f-number (focal length / entrance-pupil diameter)

## A Single Lens Imager

In [None]:
# Build a symmetric biconvex singlet sized for system_focus. With R2 = -R1 the
# thin-lens lensmaker's equation 1/f = (n - 1)(1/R1 - 1/R2) reduces to
# R = 2 (n - 1) f; the lens thickness is ignored in this estimate.
lens_material = matl.glass["ideal"]
lens_radius = 2*(lens_material.index_at(0.633)-1)*system_focus

lens = pyrayt.components.thick_lens(
    material=lens_material,
    aperture=lens_diameter,
    thickness=lens_thickness,
    r1=lens_radius,
    r2=-lens_radius,
)

draw(lens)

In [None]:
# The imager is a square baffle, lens_diameter on a side, placed at the nominal
# focal plane (presumably acting as an absorbing detector plane).
imager = components.baffle((lens_diameter, lens_diameter))
imager = imager.move_x(system_focus)

In [None]:
# Creating the aperture stop, placed halfway between the lens and the imager.
aperture_position = system_focus / 2
# A cone converging from an entrance pupil of diameter system_focus / f_num to a
# focus at system_focus has diameter (system_focus - x) / f_num at distance x
# from the lens. Using the distance to focus (rather than aperture_position)
# keeps the f-number correct if the aperture is moved; at the halfway point the
# two are equal, so the current value is unchanged.
aperture_diameter = (system_focus - aperture_position) / f_num
aperture = components.aperture(
    size=(lens_diameter, lens_diameter),
    aperture_size=aperture_diameter,
).move_x(aperture_position)

In [None]:
# Trace a line of parallel rays at 0.633 um through lens, aperture and imager,
# then draw the system in the x-y plane.
source = components.LineOfRays(0.8*lens_diameter, wavelength=0.633).move_x(-10)

tracer = pyrayt.RayTracer(source, [lens, aperture, imager])
tracer.set_rays_per_source(11)
results = tracer.trace()

fig, axis = init_figure()
tracer.show(ray_width=0.2, axis=axis, view='xy')
plt.show()

### Spherical Aberrations

In [None]:
# Trace several circular sources of increasing size and overlay the spots each
# one forms on the imager.
n_sources = 5
source_diameter = np.linspace(0.1, 0.5, n_sources)*lens_diameter
spherical_sources = [components.CircleOfRays(x).move_x(-10) for x in source_diameter]

tracer = pyrayt.RayTracer(spherical_sources, [lens, aperture, imager])
tracer.set_rays_per_source(111)
results = tracer.trace()
# Record which source generated each ray (adds a 'source_id' column to the results).
tracer.calculate_source_ids()

imager_rays = results[results['surface'] == imager.get_id()]

fig, axis = init_figure()
axis.set_aspect('equal')
axis.set_title("spot size at focus for various beam diameters")
axis.set_xlabel("y-axis (mm)")
axis.set_ylabel("z-axis (mm)")
for n, radii in enumerate(source_diameter):
    # Despite its name, ``radii`` holds the size passed to CircleOfRays for source n,
    # which the legend reports as the beam diameter.
    source_rays_on_imager = imager_rays[imager_rays['source_id'] == n]
    axis.plot(
        source_rays_on_imager['y1'],
        source_rays_on_imager['z1'],
        '-o',
        label=f"{radii:.02f}mm",
    )
plt.legend()
plt.show()



In [None]:
def spherical_aberration(system, ray_origin: float, max_radius:float, sample_points=11):
    """
    Measure where rays launched at different heights cross the optical axis.

    A line of rays is launched from ``x = ray_origin`` on the +y side of the axis
    and traced through ``system``. The axis crossing of each ray's final segment
    is reported against the height the ray started at.

    Parameters
    ----------
    system:
        Components passed to ``pyrayt.RayTracer``. The imager is assumed to be
        the last thing the rays intersect.
    ray_origin:
        x position of the source.
    max_radius:
        Sets the extent of the ray fan. The line of rays is built with size
        ``0.9*max_radius`` and shifted up by ``max_radius/2`` so that no ray lies
        on the axis.
    sample_points:
        Number of rays traced.

    Returns
    -------
    pandas.DataFrame
        Column ``radius`` holds the starting height (y0) of each ray. Column
        ``focus`` holds the x position where that ray's final segment crosses y = 0.
    """
    # The source is a line of rays on the +y side only. It is shifted off-centre so
    # that no ray travels along the axis: such a ray has y_tilt == 0 and its axis
    # crossing is undefined.
    source = pyrayt.components.LineOfRays(0.9*max_radius).move_x(ray_origin).move_y(max_radius/2)

    tracer = pyrayt.RayTracer(source, system)
    tracer.set_rays_per_source(sample_points)
    results = tracer.trace()

    # The imager is not passed in separately, so assume it is the last thing a ray
    # intersects: the rays that hit it have the highest generation.
    imager_rays = results.loc[results['generation'] == np.max(results['generation'])]

    # A segment is (x0, y0) + t*(x_tilt, y_tilt); solving y = 0 for t gives the x
    # position where it crosses the optical axis.
    intercept = -imager_rays['x_tilt']*imager_rays['y0']/imager_rays['y_tilt'] + imager_rays['x0']

    # Look up each final ray's starting height by its id rather than relying on
    # both selections happening to share the same row order. This assumes ids are
    # unique within generation 0. A final ray with no generation-0 match gets NaN
    # instead of silently shifting every later pairing.
    start_heights = results.loc[results['generation'] == 0].set_index('id')['y0']
    radii = imager_rays['id'].map(start_heights)

    # Create a new dataframe with the aberration metrics.
    return pd.DataFrame({'radius': np.asarray(radii), 'focus': np.asarray(intercept)})



In [None]:
# Plot the axis crossing of each ray against its starting height for the singlet.
aberrations = spherical_aberration(
    system=[lens, aperture, imager],
    ray_origin=-10,
    max_radius=0.4*lens_diameter,
    sample_points=21,
)

fig, axis = init_figure()
axis.plot(aberrations['radius'], aberrations['focus'])
axis.set_title("Focal Length vs. Beam Radius for a single-lens imager")
axis.set_xlabel("Beam Radius (mm)")
axis.set_ylabel("Focal Length (mm)")
plt.show()

In [None]:
# Side-by-side spot diagrams on the imager for sources 0-2 of the most recent trace
# that populated ``results`` (the circular-source trace when cells run in order).
endpoints = results.loc[results['surface'] == imager.get_id()]
fig, axes = plt.subplots(1, 3, sharex=True, figsize=(12, 8))
for n, ax in enumerate(axes):
    ax.set_aspect('equal')
    source_rays = endpoints[endpoints['source_id'] == n]
    ax.plot(source_rays['z1'], source_rays['y1'], 'o')
plt.show()

## Making a Doublet

In [None]:
# Two-element doublet: a convex BK7 element followed by a concave SF5 element.
# Goal: keep the power of the whole system unchanged while cancelling first-order
# chromatic dispersion.
p_sys = 1/system_focus  # target lens power of the system (1/mm)
convex_matl = matl.glass["BK7"]
concave_matl = matl.glass["SF5"]

# NOTE(review): the prescription below is hard-coded rather than derived from
# p_sys. It appears to match a stock f ~ 100 mm BK7/SF5 achromat; if so, the
# imager at system_focus is not at this doublet's focus. Confirm.
lens1 = pyrayt.components.thick_lens(62.8, -45.7, 4, aperture=lens_diameter, material=convex_matl)
# The offset puts lens2 just behind lens1 (presumably 4/2 + 2.5/2 plus a small
# gap, if thick lenses are centred on their origin).
lens2 = pyrayt.components.thick_lens(-45.7, -128.2, 2.5, aperture=lens_diameter, material=concave_matl).move_x(3.26)
doublet = ObjectGroup([lens1, lens2])

# ``test_sources`` was not defined in any earlier cell, so this cell raised
# NameError. Trace one line of rays per wavelength so that
# ``color_function="source"`` shows how each colour is focused.
test_sources = [
    components.LineOfRays(0.8*lens_diameter, wavelength=wave).move_x(-10)
    for wave in (0.44, 0.53, 0.65)
]

tracer = pyrayt.RayTracer(test_sources, [lens1, lens2, imager])
tracer.set_rays_per_source(11)
results = tracer.trace()
fig, axis = init_figure()
tracer.show(
    color_function="source",
    ray_width=0.3,
    axis=axis)
plt.show()

In [None]:
def test_source(wave):
    """
    Return a source at wavelength ``wave`` launched from x = -10, y = lens_diameter/4.

    The line of rays has zero length. With one ray per source (as
    ``chromatic_abberation`` configures), it yields a single off-axis ray.
    """
    return pyrayt.components.LineOfRays(0, wavelength=wave).move_x(-10).move_y(lens_diameter/4)


def _final_segments(results):
    """Return the highest-generation rows, assumed to be the segments reaching the imager."""
    return results.loc[results['generation'] == np.max(results['generation'])]


def _axis_intercept(rays):
    """Return the x position where each ray segment's line crosses y = 0."""
    # Segment is (x0, y0) + t*(x_tilt, y_tilt); y0 + t*y_tilt = 0 gives t = -y0/y_tilt.
    return -rays['x_tilt']*rays['y0']/rays['y_tilt'] + rays['x0']


def chromatic_abberation(system, source_function, wavelengths: np.ndarray):
    """
    Trace one source per wavelength and report where each final ray crosses the axis.

    Parameters
    ----------
    system:
        Components passed to ``pyrayt.RayTracer``. The imager is assumed to be
        the last thing the rays intersect.
    source_function:
        Callable mapping a wavelength to a source object.
    wavelengths:
        Wavelengths to trace, one source each.

    Returns
    -------
    pandas.DataFrame
        Columns ``wavelength`` and ``focus`` (x position of the axis crossing).
        The index is inherited from the tracer results.
    """
    sources = [source_function(wave) for wave in wavelengths]
    tracer = pyrayt.RayTracer(sources, system)
    tracer.set_rays_per_source(1)
    results = tracer.trace()
    imager_rays = _final_segments(results)
    intercept = _axis_intercept(imager_rays)
    return pd.DataFrame({'wavelength': imager_rays['wavelength'], 'focus': intercept})


def spherical_abberation(source, system, sample_points=11):
    """
    Trace a prebuilt ``source`` through ``system`` and report axis crossing vs. start height.

    Parameters
    ----------
    source:
        Source object passed to ``pyrayt.RayTracer``.
    system:
        Components traced. The imager is assumed to be the last thing the rays
        intersect.
    sample_points:
        Rays per source.

    Returns
    -------
    pandas.DataFrame
        Column ``radius`` holds each final ray's starting height (y0). Column
        ``focus`` holds its axis crossing.
    """
    tracer = pyrayt.RayTracer(source, system)
    tracer.set_rays_per_source(sample_points)
    results = tracer.trace()
    imager_rays = _final_segments(results)
    intercept = _axis_intercept(imager_rays)
    # Previously every generation-0 ray was used as a radius. If any ray did not
    # reach the final generation, the two columns had different lengths
    # (ValueError) or were misaligned. Look up each final ray's starting height by
    # id instead (ids assumed unique within generation 0).
    start_heights = results.loc[results['generation'] == 0].set_index('id')['y0']
    radii = imager_rays['id'].map(start_heights)
    return pd.DataFrame({'radius': np.asarray(radii), 'focus': np.asarray(intercept)})


        

In [None]:
# Compare the focus shift vs. wavelength for the single lens and the doublet.
abberation = chromatic_abberation([lens, imager], test_source, np.linspace(0.44, 0.75, 11))
abberation_doublet = chromatic_abberation([*doublet, imager], test_source, np.linspace(0.44, 0.75, 21))

fig, axis = init_figure()
# Wavelengths are in micrometres, so scale by 1000 to plot in nm. The single lens
# is referenced to system_focus (its design focus). The doublet, whose radii are
# not derived from system_focus, is referenced to its own mean focus.
axis.plot(
    1000*abberation['wavelength'],
    abberation['focus'] - system_focus,
    label="single lens",
)
axis.plot(
    1000*abberation_doublet['wavelength'],
    abberation_doublet['focus'] - np.mean(abberation_doublet['focus']),
    label="doublet",
)
axis.set_xlabel("Wavelength (nm)")
axis.set_ylabel("Focal Length Shift (mm)")
axis.legend()
plt.show()

In [None]:
# Compare the focus shift vs. starting ray height for the single lens and the doublet.
abberation = spherical_abberation(
    pyrayt.components.LineOfRays(aperture_diameter*0.4).move(-10, 0.22*aperture_diameter),
    [lens, imager],
)
abberation_doublet = spherical_abberation(
    pyrayt.components.LineOfRays(aperture_diameter*0.4).move(-10, 0.22*aperture_diameter),
    [*doublet, imager],
)

fig, axis = init_figure()
axis.plot(
    abberation['radius'],
    abberation['focus'] - system_focus,
    label="single lens",
)
axis.plot(
    abberation_doublet['radius'],
    abberation_doublet['focus'] - np.mean(abberation_doublet['focus']),
    label="doublet",
)
# The x-axis is the ray's starting height, not wavelength (label was copy-pasted).
axis.set_xlabel("Beam Radius (mm)")
axis.set_ylabel("Focal Length Shift (mm)")
axis.legend()
plt.show()

In [None]:
# Display the highest-generation rows (the final ray segments) of the latest trace.
results[results['generation'] == results['generation'].max()]

In [None]:
# Trace the parallel-ray ``source`` through the optimised lens. Show the whole
# system, then zoom to +/-1 mm around the nominal focus.
# NOTE(review): ``optimized_lens`` is not defined in any cell visible here, so this
# cell raises NameError unless it is created elsewhere. Confirm.
tracer = pyrayt.RayTracer(source, [optimized_lens, aperture, imager])
tracer.set_rays_per_source(11)
results = tracer.trace()
fig, axis = init_figure()
tracer.show(
    ray_width=0.1,
    axis=axis)
plt.show()

fig, axis = init_figure()
tracer.show(
    ray_width=0.025,
    axis=axis)
# A bare ``axis.grid()`` toggles visibility and would undo the grid that
# init_figure turned on; request it explicitly instead.
axis.grid(True)
axis.set_xlim(system_focus - 1, system_focus + 1)
axis.set_ylim(-1, 1)
plt.show()

### Chromatic Aberration

In [None]:
# Chromatic behaviour of the optimised lens: one line of rays per wavelength
# (0.44, 0.53, 0.65 um), launched from x = -system_focus.
# NOTE(review): ``optimized_lens`` is not defined in any cell visible here, so this
# cell raises NameError unless it is created elsewhere. Confirm.
sources = [
    components.LineOfRays(0.3*lens_diameter, wavelength=x).move_x(-system_focus).rotate_z(0)
    for x in [0.44, 0.53, 0.65]
]
# NOTE(review): ``move_x(3)`` is called on the shared ``imager`` object. If move_x
# moves the object in place and returns it (as its fluent use when building the
# imager suggests), every re-run of this cell shifts the imager another 3 mm and
# later cells see the shifted imager. Confirm.
tracer = pyrayt.RayTracer(sources, [optimized_lens, aperture, imager.move_x(3)])
tracer.set_rays_per_source(2)
results = tracer.trace()
fig, axis = init_figure()
tracer.show(
    color_function="wavelength",
    ray_width=0.1,
    axis=axis)
plt.show()

fig, axis = init_figure()
tracer.show(
    color_function="wavelength",
    ray_width=0.025,
    axis=axis)
# A bare ``axis.grid()`` toggles visibility and would undo the grid that
# init_figure turned on; request it explicitly instead.
axis.grid(True)
axis.set_xlim(system_focus - .5, system_focus + 1)
axis.set_ylim(-.5, 0.5)
plt.show()