# Imager Optimization

## Function Imports and Definitions

In [None]:
import pyrayt
from pyrayt import components
import pyrayt.materials as matl
from pyrayt.utils import lensmakers_equation
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import scipy
import scipy.optimize
from typing import Tuple
from tinygfx.g3d import ObjectGroup
import pandas as pd
from tinygfx.g3d.renderers import draw
from tinygfx.g3d import ObjectGroup

def init_figure() -> Tuple[plt.Figure, plt.Axes]:
    """
    Create a 12x8 figure holding a single axis with its grid switched on.

    Returns:
        The new figure and its axis as a ``(figure, axis)`` pair.
    """
    figure = plt.figure(figsize=(12, 8))
    # a freshly created figure has no axes yet, so gca() makes the one we return
    axes = figure.gca()
    axes.grid()
    return figure, axes


## System Constants

In [None]:
# All spatial units are mm
lens_diameter = 25.4  # used as the aperture of every lens and as the side of the imager/aperture plates
lens_thickness = 5  # thickness of the single lens
system_focus = 50 # The focus of the system
p_sys = 1/system_focus  # optical power of the system: reciprocal of the focal length
f_num = 2.4  # f-number used to size the aperture opening

## A Single Lens Imager

In [None]:
# Creating a simple Lens
lens_material = matl.glass["BK7"]
# Symmetric lens: the radius comes from the thin-lens relation P = 2(n-1)/R,
# with the index evaluated at 0.532 (same wavelength units the materials use)
lens_radius = 2*(lens_material.index_at(0.532)-1)/p_sys
# equal and opposite radii on the two surfaces
lens = pyrayt.components.thick_lens(
    r1=lens_radius, 
    r2=-lens_radius,
    thickness=lens_thickness,
    aperture=lens_diameter,
    material=lens_material)

# render the lens on its own as a visual check
draw(lens)

In [None]:
# Creating the Imager
# A square baffle with sides equal to the lens diameter, moved along x to the nominal system focus
imager = components.baffle((lens_diameter, lens_diameter)).move_x(system_focus)

In [None]:
# Creating the Aperture
aperture_position = system_focus / 2 # place it half way between the lens and the imager
# the opening is the aperture's distance along x divided by the f-number
aperture_diameter = aperture_position / f_num
# the plate itself is a square with sides equal to the lens diameter
aperture = components.aperture(
    size=(lens_diameter, lens_diameter), 
    aperture_size=aperture_diameter
    ).move_x(aperture_position)

In [None]:
# Create a Parallel ray set
# a line of rays spanning 80% of the lens diameter, started at x = -10
source = components.LineOfRays(0.8*lens_diameter, wavelength = 0.633).move_x(-10)

# trace 11 rays from the source through the lens, aperture and imager
tracer = pyrayt.RayTracer(source, [lens, aperture, imager])
tracer.set_rays_per_source(11)
results = tracer.trace()
fig, axis = init_figure()

# draw the traced rays on the prepared axis, viewed in the xy plane
tracer.show(
    #color_function="source",
    ray_width=0.2,
    axis=axis,
    view='xy')
plt.show()

### Spherical Aberrations

In [None]:
n_sources = 5
# beam diameters from 10% to 50% of the lens diameter
source_diameter = lens_diameter*np.linspace(0.1, 0.5, n_sources)
spherical_sources = [components.CircleOfRays(diameter).move_x(-10) for diameter in source_diameter]

tracer = pyrayt.RayTracer(spherical_sources, [lens, aperture, imager])
tracer.set_rays_per_source(111)
results = tracer.trace()
# calculates which source generated each ray and adds it to the dataframe
tracer.calculate_source_ids()

# keep only the ray segments whose surface is the imager
imager_rays = results.loc[results['surface'] == imager.get_id()]

fig, axis = init_figure()
axis.set_aspect('equal')
axis.set_xlabel("y-axis (mm)")
axis.set_ylabel("z-axis (mm)")
axis.set_title("spot size at focus for various beam diameters")
# one trace per source, labelled with that source's beam diameter
for source_index, diameter in enumerate(source_diameter):
    spot = imager_rays.loc[imager_rays['source_id'] == source_index]
    axis.plot(spot['y1'], spot['z1'], '-o', label=f"{diameter:.02f}mm")
axis.legend()
plt.show()



In [None]:
def spherical_aberration(system, ray_origin: float, max_radius:float, sample_points=11):
    """
    Trace a fan of rays through ``system`` and find where each one crosses the x-axis.

    Args:
        system: components handed to ``pyrayt.RayTracer``; the imager is assumed
            to be the last thing a ray intersects.
        ray_origin: x position the source is moved to.
        max_radius: sets the source length (0.9*max_radius) and its y offset (max_radius/2).
        sample_points: number of rays generated by the source.

    Returns:
        A DataFrame with one row per ray that reached the final generation:
        ``radius`` is the ray's starting ``y0`` and ``focus`` is the x position
        where the final ray segment meets the x-axis.
    """
    # The fan is kept on the +y side and shifted off zero: an on-axis ray never
    # crosses the axis, so its "focus" would sit at infinity.
    ray_fan = pyrayt.components.LineOfRays(0.9*max_radius).move_x(ray_origin).move_y(max_radius/2)

    fan_tracer = pyrayt.RayTracer(ray_fan, system)
    fan_tracer.set_rays_per_source(sample_points)
    trace = fan_tracer.trace()

    # The imager is not passed in explicitly; rays that reached it are taken to
    # be the ones with the highest generation.
    generation = trace['generation']
    final_rays = trace.loc[generation == generation.max()]

    # Follow each final segment along its (normalized) direction vector until y = 0.
    x_crossing = final_rays['x0'] - final_rays['x_tilt']*final_rays['y0']/final_rays['y_tilt']

    # Starting heights of exactly those rays that made it to the final generation.
    launched = trace['generation'] == 0
    reached_end = trace['id'].isin(final_rays['id'])
    start_heights = trace.loc[launched & reached_end]['y0']

    # Pair the two columns positionally in a fresh frame.
    return pd.DataFrame({'radius': np.asarray(start_heights), 'focus': np.asarray(x_crossing)})



In [None]:
# Spherical aberration of the single-lens imager: 21 rays started at x = -10,
# with max_radius set to 40% of the lens diameter
aberrations = spherical_aberration([lens, aperture, imager], -10, 0.4*lens_diameter, 21)

# plot the x-axis crossing of each ray against its starting height
fig, axis = init_figure()
axis.set_title("Focal Length vs. Beam Radius for a single-lens imager")
axis.plot(aberrations['radius'], aberrations['focus'])
axis.set_xlabel("Beam Radius (mm)")
axis.set_ylabel("Focal Length (mm)")
plt.show()

### Chromatic Aberrations

In [None]:
def chromatic_abberation(system, ray_origin: float, test_radius:float, wavelengths: np.ndarray) -> pd.DataFrame:
    """
    Trace one ray per wavelength through ``system`` and find where each crosses the x-axis.

    Args:
        system: components handed to ``pyrayt.RayTracer``; the imager is assumed
            to be the last thing a ray intersects.
        ray_origin: x position every source is moved to.
        test_radius: y offset every source is moved to.
        wavelengths: wavelengths to test, one source per entry.

    Returns:
        A DataFrame with the ``wavelength`` of each final-generation ray and
        ``focus``, the x position where that ray meets the x-axis.
    """
    # one zero-length line source (a single ray) per wavelength
    sources = []
    for wave in wavelengths:
        single_ray = pyrayt.components.LineOfRays(0, wavelength=wave)
        sources.append(single_ray.move_y(test_radius).move_x(ray_origin))

    tracer = pyrayt.RayTracer(sources, system)
    tracer.set_rays_per_source(1)
    trace = tracer.trace()

    # rays that hit the imager are taken to be those with the highest generation
    last_generation = trace['generation'].max()
    final_rays = trace.loc[trace['generation'] == last_generation]

    # follow each final segment along its direction until it reaches y = 0
    x_crossing = final_rays['x0'] - final_rays['x_tilt']*final_rays['y0']/final_rays['y_tilt']
    return pd.DataFrame({'wavelength': final_rays['wavelength'], 'focus': x_crossing})



In [None]:
# Chromatic aberration of the single-lens imager: one ray per wavelength,
# started at x = -10 and a quarter of the lens diameter off-axis
aberrations = chromatic_abberation([lens, aperture, imager], -10, lens_diameter/4, np.linspace(0.45, 0.75, 11))

fig, axis = init_figure()
axis.set_title("Focal Length vs. wavelength for a single-lens imager")
# the wavelength column is scaled by 1000 so the x-axis reads in nm
axis.plot(1000*aberrations['wavelength'], aberrations['focus'])
# x carries the wavelength and y the focus position; label them accordingly
axis.set_xlabel("Wavelength (nm)")
axis.set_ylabel("Focal Length (mm)")
plt.show()

### Coma Aberration

In [None]:
# Visualizing off-axis rays with a single line source rotated by 12 about z
# (the earlier three-source variant is kept commented out below)
#angle_sources= [pyrayt.components.LineOfRays(x * lens_diameter).move_x(-10).rotate_z(12) for x in np.linspace(0.1, 0.4, 3)]

# the source spans 90% of the lens diameter; the aperture is left out of this system
source = pyrayt.components.LineOfRays(0.9 * lens_diameter).move_x(-10).rotate_z(12)
tracer = pyrayt.RayTracer(source, [lens, imager])
tracer.set_rays_per_source(111)
tracer.trace()
fig, axis = init_figure()
tracer.show(
    axis=axis,
    ray_width=0.15
)

In [None]:
# View the Coma distortion at each angle
tracer.calculate_source_ids()
results = tracer.get_results()
# 'source_id' holds the integer index of the generating source (the other cells
# compare it with 0, 1, 2, ...), not the source object itself. Only one source
# was traced here, so its index is 0.
# NOTE(review): assumes source ids are zero-based, as the other cells do - confirm.
ray_set = results.loc[np.logical_and(results['source_id']==0, results['surface']==imager.get_id())]
fig, axis = init_figure()
# histogram of where the rays land on the imager along y
ray_set.hist('y1', ax=axis)
axis.set_title("Y-Axis intersection with Imaging Plane")
axis.set_xlabel('Ray Focus (mm)')
    

In [None]:
def coma(system, ray_origin: float, max_radius:float, angle: float) -> float:
    """
    Score how far the rays leaving ``system`` deviate from the incoming field angle.

    Args:
        system: components handed to ``pyrayt.RayTracer``; the imager is assumed
            to be the last thing a ray intersects.
        ray_origin: x position the source is moved to before it is rotated.
        max_radius: half the length of the line source.
        angle: rotation applied to the source about z, in degrees.

    Returns:
        Mean squared difference between ``sin(y_tilt)`` of the final-generation
        rays and ``sin(angle)``.
    """
    # Honour the caller's ray_origin instead of a hard-coded -10
    source = (
        pyrayt.components.LineOfRays(2*max_radius)
        .rotate_x(90)
        .move_x(ray_origin)
        .rotate_z(angle)
    )

    tracer = pyrayt.RayTracer(source, system)
    tracer.set_rays_per_source(11)
    results = tracer.trace()

    # rays that hit the imager are taken to be those with the highest generation
    ray_set = results.loc[results['generation'] == np.max(results['generation'])]
    # NOTE(review): 'y_tilt' looks like a direction-vector component, so applying
    # sin() to it may be unintended - left as-is pending confirmation.
    return np.mean(np.square((np.sin(ray_set['y_tilt'])-np.sin(angle*np.pi/180))))

In [None]:
# Coma score of the single lens (no aperture) for a source rotated by 10 degrees
coma([lens, imager], -10, 0.25*lens_diameter, 10)


In [None]:
# Reference value: sine of 10 degrees, the term coma() subtracts for angle=10
np.sin(10*np.pi/180)

## Making a Doublet

In [None]:
def thick_lens_radius(power: float, thickness: float, index: float) -> float:
    """
    Radius of curvature for a thick lens of the given optical power.

    Evaluates ``(index-1) * (1 + sqrt(1 - power*thickness/index)) / power``.
    With ``thickness == 0`` this reduces to ``2*(index-1)/power``.

    Args:
        power: target optical power of the lens.
        thickness: thickness of the lens.
        index: refractive index of the lens material.

    Returns:
        The radius of curvature.
    """
    root = np.sqrt(1 - power*thickness/index)
    return (index-1)*(1+root)/power

def power(r1, r2, thickness, index):
    """
    Optical power of a thick lens.

    Evaluates ``(n-1) * (1/r1 - 1/r2 + (n-1)*t / (n*r1*r2))`` with ``n = index``
    and ``t = thickness``.

    Args:
        r1: radius of curvature of the first surface.
        r2: radius of curvature of the second surface.
        thickness: thickness of the lens.
        index: refractive index of the lens material.

    Returns:
        The optical power (reciprocal of the focal length).
    """
    surface_term = 1/r1 - 1/r2
    thickness_term = (index-1)*thickness/(index*r1*r2)
    return (index-1)*(surface_term + thickness_term)

In [None]:
# Want the power of the whole system to remain unchanged, but cancel out the first-order chromatic dispersion
# Define the two materials we will use
matl1 = matl.glass["BK7"]
matl2 = matl.glass["SF2"]

# the first lens is made out of BK7 and the second out of SF2
# get abbe numbers for each material
v1 = matl1.abbe()
v2 = matl2.abbe()

# calculate the individual lens powers based on the dispersion
# (the two powers add up to p_sys, weighted by the abbe numbers)
p1 = p_sys * v1/(v1 - v2)
p2 = p_sys * v2/(v2 - v1)

# thickness of the first and second lens of the doublet
l1_thickness = 8
l2_thickness = 2

In [None]:
# refractive index of each glass at 0.533 (the single lens above was sized at 0.532)
n1 = matl1.index_at(0.533)
n2 = matl2.index_at(0.533)

# r3 is tied to r2 through a fixed ratio
factor = 2.0
# NOTE(review): for a thin lens with r3 = factor*r2, p2 = (n2-1)*(1 - 1/factor)/r2,
# which would give r2 = (n2-1)*(1 - 1/factor)/p2. The expression below differs
# from that by a factor of 2 at factor=2.0 - confirm whether this is intended.
r2 = (n2-1)/((2/factor)*p2)
r3 = factor*r2
# thin-lens relation for the first lens, sharing r2 with the second: 1/r1 = p1/(n1-1) + 1/r2
r1 = (p1/(n1-1)+1/r2)**-1
# display the three starting radii
r1, r2, r3

In [None]:
# create lenses as a doublet
l1 = pyrayt.components.thick_lens(r1, r2, l1_thickness, aperture = lens_diameter, material=matl1)
# the second lens reuses r2 as its first surface; the literal 2 equals l2_thickness.
# It is moved by just over half the combined thickness, presumably to sit next to l1 without overlapping - TODO confirm
l2 = pyrayt.components.thick_lens(r2,r3, 2, aperture = lens_diameter, material=matl2).move_x(1.01*(l1_thickness+l2_thickness)/2)
# grouped for convenience; the tracer below is given l1 and l2 directly
doublet = ObjectGroup([l1, l2])

# unrotated line source spanning half the lens diameter, started at x = -10
source = pyrayt.components.LineOfRays(0.5*lens_diameter).move_x(-10).rotate_z(0)

# trace 5 rays through the doublet, aperture and imager and draw them
tracer = pyrayt.RayTracer(source, [l1, l2, aperture, imager])
tracer.set_rays_per_source(5)
results = tracer.trace()
fig, axis = init_figure()
tracer.show(
    ray_width=0.1,
    axis=axis
)

In [None]:
## Optimization functions to correct focus

def constant_power(radii) -> float:
    """
    Deviation of the doublet's summed lens power from the target system power.

    Intended as an equality constraint for the solver: the value is zero when
    the two lens powers add up to ``p_sys``.

    Args:
        radii: four radii of curvature; ``radii[0]``/``radii[1]`` belong to the
            first lens and ``radii[2]``/``radii[3]`` to the second.

    Returns:
        ``p1 + p2 - p_sys``, with each lens power evaluated at 0.633.
    """
    # Use the materials the doublet is actually built from; the previous names
    # (bk7 / sf5) are not defined anywhere in this notebook.
    n1 = matl1.index_at(0.633)
    n2 = matl2.index_at(0.633)

    # same thick-lens expression as before, via the shared helper
    p1 = power(radii[0], radii[1], l1_thickness, n1)
    p2 = power(radii[2], radii[3], l2_thickness, n2)

    # return the difference between the actual power and the desired power
    return p1+p2-p_sys

# Equality constraint holding the total doublet power at the system power.
# NOTE: the minimize call that follows does not pass this list, and a later
# cell rebinds `constraints` to a different set.
constraints = [
    {'type':'eq', 'fun': constant_power}
]

def doublet_performance(radius) -> float:
    """
    Objective for tuning the first surface of the doublet.

    Builds the doublet with ``radius[0]`` as the first surface of the first lens
    (the remaining surfaces come from the module-level ``r2`` and ``r3``), then
    scores its spherical aberration.

    Args:
        radius: sequence whose first entry is the radius being optimized.

    Returns:
        Mean squared distance between each ray's x-axis crossing and ``system_focus``.
    """
    second_lens_offset = 1.01*(l1_thickness+l2_thickness)/2
    first_lens = pyrayt.components.thick_lens(
        radius[0], r2, l1_thickness, aperture=lens_diameter, material=matl1)
    second_lens = pyrayt.components.thick_lens(
        r2, r3, l2_thickness, aperture=lens_diameter, material=matl2).move_x(second_lens_offset)

    # the aperture is left out of the system for now
    optics = [first_lens, second_lens, imager]

    # spherical aberration: focus position as a function of ray height
    sphere = spherical_aberration(optics, -10, lens_diameter/4, 10)
    focus_error = sphere['focus']-system_focus
    return np.mean(np.square(focus_error))
    


In [None]:
# Unconstrained minimization over the first radius only, starting from the analytic r1
optimization = scipy.optimize.minimize(doublet_performance, [r1])
# adopt the optimized radius for the cells that follow
r1 = optimization.x[0]

In [None]:
# create lenses as a doublet, now using the optimized r1
l1 = pyrayt.components.thick_lens(r1, r2, l1_thickness, aperture = lens_diameter, material=matl1)
# the literal 2 equals l2_thickness
l2 = pyrayt.components.thick_lens(r2, r3, 2, aperture = lens_diameter, material=matl2).move_x(1.01*(l1_thickness+l2_thickness)/2)
# grouped for convenience; the tracer below is given l1 and l2 directly
doublet = ObjectGroup([l1, l2])

# unrotated line source spanning half the lens diameter, started at x = -10
source = pyrayt.components.LineOfRays(0.5*lens_diameter).move_x(-10).rotate_z(0)

# trace 5 rays through the doublet and imager (no aperture this time) and draw them
tracer = pyrayt.RayTracer(source, [l1, l2, imager])
tracer.set_rays_per_source(5)
results = tracer.trace()
fig, axis = init_figure()
tracer.show(
    ray_width=0.1,
    axis=axis
)

In [None]:
# Evaluate both aberrations for the single lens and for the doublet (no aperture in either system)
chroma = [
    chromatic_abberation(optics, -10, 0.25*lens_diameter, np.linspace(0.45, 0.75, 11))
    for optics in ((lens, imager), (l1, l2, imager))
]
sphere = [
    spherical_aberration(optics, -10, 0.25*lens_diameter, 11)
    for optics in ((lens, imager), (l1, l2, imager))
]

# chromatic focus shift about the mean, one line per system
fig, axis = init_figure()
for c in chroma:
    axis.plot(c['wavelength'], c['focus']-c['focus'].mean())
plt.show()

# spherical focus shift about the mean, one line per system
fig, axis = init_figure()
for c in sphere:
    axis.plot(c['radius'], c['focus']-c['focus'].mean())
plt.show()

In [None]:
## Optimization functions to correct focus

# thick-lens power of each element at 0.533, using the current radii
# (this rebinds the p1/p2 computed earlier from the abbe numbers)
p1 = power(r1, r2, l1_thickness, matl1.index_at(0.533))
p2 = power(r2, r3, l2_thickness, matl2.index_at(0.533))
opt_power = p1+p2 # optimum power

def inner_radii(radii) -> float:
    """
    Difference in magnitude between the two inner radii of the doublet.

    Args:
        radii: sequence of four radii; entries 1 and 2 are the inner pair.

    Returns:
        ``|radii[2]| - |radii[1]|``.
    """
    first_inner = np.abs(radii[1])
    second_inner = np.abs(radii[2])
    return second_inner - first_inner

def r1_dir(radii) -> float:
    """Return the first radius unchanged."""
    first_radius = radii[0]
    return first_radius

def r2_dir(radii) -> float:
    """Return the negated second radius."""
    second_radius = radii[1]
    return -second_radius

def r3_dir(radii) -> float:
    """Return the negated third radius."""
    third_radius = radii[2]
    return -third_radius

def r4_dir(radii) -> float:
    """Return the negated fourth radius."""
    fourth_radius = radii[3]
    return -fourth_radius

# Inequality constraints for the four-radius optimization. scipy treats an
# 'ineq' constraint as satisfied when its function is non-negative, so together
# these ask for radii[0] >= 0, radii[1..3] <= 0 and |radii[2]| >= |radii[1]|.
constraints = [
    {'type':'ineq', 'fun': inner_radii},
    {'type':'ineq', 'fun': r1_dir},
    {'type':'ineq', 'fun': r2_dir},
    {'type':'ineq', 'fun': r3_dir},
    {'type':'ineq', 'fun': r4_dir},
]

def doublet_performance(radius) -> float:
    """
    Objective for optimizing all four radii of the doublet.

    Args:
        radius: four radii of curvature; the first two shape the first lens and
            the last two shape the second lens.

    Returns:
        Weighted sum of the chromatic and spherical focus errors, each being the
        mean squared distance between the rays' x-axis crossings and ``system_focus``.
    """
    # NOTE(review): this offset uses a 1.001 factor while the cells that build the
    # lenses for display use 1.01 - confirm which spacing is intended.
    second_lens_offset = 1.001*(l1_thickness+l2_thickness)/2
    first_lens = pyrayt.components.thick_lens(
        radius[0], radius[1], l1_thickness, aperture=lens_diameter, material=matl1)
    second_lens = pyrayt.components.thick_lens(
        radius[2], radius[3], l2_thickness, aperture=lens_diameter, material=matl2).move_x(second_lens_offset)

    # the aperture is left out of the system for now
    optics = [first_lens, second_lens, imager]

    # chromatic term: focus position across wavelengths
    w_chroma = 1
    chroma = chromatic_abberation(optics, -10, lens_diameter/4, np.linspace(0.45,0.7,11))
    chroma_error = np.mean(np.square(chroma['focus']-system_focus))

    # spherical term: focus position across ray heights
    w_sphere = 1
    sphere = spherical_aberration(optics, -10, lens_diameter/4, 10)
    sphere_error = np.mean(np.square(sphere['focus']-system_focus))

    # a coma term (weight 10000) was tried here and is currently disabled
    return chroma_error*w_chroma + sphere_error*w_sphere

In [None]:
# Constrained minimization over all four radii; both inner surfaces start from r2
optimization = scipy.optimize.minimize(doublet_performance, [r1, r2, r2, r3], constraints=constraints)

In [None]:
# display the solver result
optimization

In [None]:
# create lenses as a doublet from the four optimized radii
l1 = pyrayt.components.thick_lens(optimization.x[0], optimization.x[1], l1_thickness, aperture = lens_diameter, material=matl1)
# NOTE(review): the spacing factor here is 1.01, while the objective that produced
# these radii used 1.001 - confirm which is intended
l2 = pyrayt.components.thick_lens(optimization.x[2], optimization.x[3], l2_thickness, aperture = lens_diameter, material=matl2).move_x(1.01*(l1_thickness+l2_thickness)/2)
# grouped for convenience; the tracer below is given l1 and l2 directly
doublet = ObjectGroup([l1, l2])

# unrotated line source spanning half the lens diameter, started at x = -10
source = pyrayt.components.LineOfRays(0.5*lens_diameter).move_x(-10).rotate_z(0)

# trace 5 rays through the optimized doublet and imager and draw them
tracer = pyrayt.RayTracer(source, [l1, l2, imager])
tracer.set_rays_per_source(5)
results = tracer.trace()
fig, axis = init_figure()
tracer.show(
    ray_width=0.1,
    axis=axis
)

In [None]:
# Evaluate both aberrations for the single lens and for the optimized doublet
chroma = [chromatic_abberation(optics, -10, 0.25*lens_diameter, np.linspace(0.45, 0.75, 11)) for optics in [(lens, imager), (l1, l2, imager)]]
sphere = [spherical_aberration(optics, -10, 0.25*lens_diameter, 11) for optics in [(lens, imager), (l1, l2, imager)]]

fig = plt.figure(figsize=(16,8))

# Left panel: chromatic focus shift about the mean
axis = plt.subplot(121)
axis.set_xlabel("Light Wavelength (nm)")
axis.set_ylabel("Focal Length shift (mm)")
axis.set_title("Chromatic Aberration")
for c in chroma:
    # the wavelength column is scaled by 1000 so the data matches the nm axis
    # label, as in the single-lens chromatic plot
    axis.plot(1000*c['wavelength'], c['focus']-np.mean(c['focus']))
axis.grid()
axis.legend(("single lens","doublet lens"))

# Right panel: spherical focus shift about the mean
axis = plt.subplot(122)
axis.grid()
axis.set_title("Spherical Aberration")
# the x data is each ray's starting height, labelled as in the single-lens plot
axis.set_xlabel("Beam Radius (mm)")
axis.set_ylabel("Focal Length shift (mm)")
for c in sphere:
    axis.plot(c['radius'], c['focus']-np.mean(c['focus']))
axis.legend(("single lens","doublet lens"))
plt.show()

In [None]:
# Visualizing with three sources
# All three are line sources spanning half the lens diameter, started at x = -10
# and rotated by 12 about z. The second and third are first rotated by 90 about x,
# and the third is additionally shifted by 5 in y.
angle_sources= [pyrayt.components.LineOfRays(0.5 * lens_diameter).move_x(-10).rotate_z(12),
pyrayt.components.LineOfRays(0.5*lens_diameter).rotate_x(90).move_x(-10).rotate_z(12),
pyrayt.components.LineOfRays(0.5*lens_diameter).rotate_x(90).move_x(-10).move_y(5).rotate_z(12)]

# trace 5 rays per source through the single lens and imager (no aperture)
tracer = pyrayt.RayTracer(angle_sources, [lens, imager])
tracer.set_rays_per_source(5)
results = tracer.trace()
# tag each ray with the index of its source
tracer.calculate_source_ids()
fig, axis = init_figure()
tracer.show(
    axis=axis,
    color_function="source",
    ray_width=0.1
)

In [None]:
# keep only the ray segments whose surface is the imager
imager_rays = results.loc[results['surface']==imager.get_id()]
fig, axis = init_figure()
axis.set_aspect('equal')
# one set of markers per source (ids 0, 1, 2)
for source in (0,1,2):
    subset = imager_rays.loc[imager_rays['source_id']==source]
    # NOTE(review): the 10.5 offset presumably re-centres the spots around the
    # off-axis image height - confirm where this value comes from
    axis.plot(subset['z1'], subset['y1']-10.5,'o')

In [None]:
# Reference value: cosine of 10 degrees
np.cos(10*np.pi/180)

In [None]:
# `np.math` was only an alias of the standard-library module and is gone from
# NumPy 2.0, so import math directly for the factorial
import math

# Compare sin(x) with its first- and third-order series approximations over +/-60 degrees
x = np.linspace(-np.pi/3, np.pi/3, 1001)
y1 = x                              # first-order approximation
y2 = x - x**3/math.factorial(3)     # third-order approximation
y3 = np.sin(x)                      # exact value

# plot all three against the angle in degrees
fig, axis = init_figure()
axis.plot(x*180/np.pi, y1)
axis.plot(x*180/np.pi, y2)
axis.plot(x*180/np.pi, y3)
plt.show()