# PHY206: General Physics for the Life Sciences II

## Tucker Knaak - Department of Physics, Creighton University - Spring 2024

#### In its simplest description, light can be modeled as a ray that travels in a straight line.  When light passes through an ideal thin lens, it is refracted and its path bends.  Together, these two ideas form the basis of geometrical optics.  This notebook lets the user investigate geometrical optics: the user builds a system of lenses in front of an object and explores how the lenses redirect the rays of light and form images.

#### This cell imports the libraries and functions required to use the geometrical_optics class.

In [1]:
'''Required Libraries'''
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np

'''Required Functions'''
from matplotlib.patches import Arrow, FancyArrowPatch, ConnectionPatch
from ipywidgets import interact, interactive, FloatSlider
from IPython.display import HTML, display
from copy import copy

#### This cell defines the geometrical_optics class, which models rays of light in a simple form.  This class can be imported into other Jupyter notebooks.

In [2]:
class geometrical_optics:
    '''Model a system of ideal thin lenses and draw its ray diagrams.

    Lenses are numbered from 1 in the order they are added.  The object of
    lens 1 sits at position 0.0 cm with height 1.0 cm; the object of every
    later lens is the image formed by the lens before it, and that lens's
    object distance is measured from this image.  All distances and heights
    are in centimeters.

    Public methods: add_lens, delete_lens, ray_diagram, animate_ray_diagram
    and interactive_ray_diagram.
    '''

    def __init__(self):
        '''Start with an empty system, no computed values and default axis limits.'''

        # System given by user
        self.num_lenses = 0    #total number of lenses in system
        self.color_dict = {}   #random colors for each lens {num: [lens / object / ray color, image color]}
        self.system_dict = {}  #values of system given by user {num: [focal length, object distance]}

        # Values computed
        self.lens_dict = {}  #values of the lenses {num: [focal length, lens position, f_before position,
                             #                            f_after position, magnification, total magnification]}
        self.obj_dict = {}   #values of the objects {num: [object distance, object height, object position]}
        self.img_dict = {}   #values of the images {num: [image distance, image height, image position]}

        # Patches drawn
        self.patch_dict = {}  #patches for system {num: [lens, object, image, rays]}

        # Axis limits computed (max_height, min_height, max_dist, min_dist)
        self.reset_axes()

        # Animation frames
        self.num_frames = 25  #number of frames for each animated ray segment (before / after the lens)

    def get_colors(self, num: int):
        '''Store [lens / object / ray color, image color] for lens `num` in color_dict.

        Lens 1 gets a random first color; every later lens reuses the image
        color of the previous lens, because that image is its object.  The
        image color is redrawn until it is visibly distinct from the first.
        '''
        if num == 1:
            color1 = plt.cm.hsv(np.random.uniform(0.25, 1))  #lens / object / ray color
        else:
            color1 = self.color_dict[num - 1][1]  #previous image -> current object

        while True:
            color2 = plt.cm.hsv(np.random.uniform(0.25, 1))  #image color
            if np.linalg.norm(np.subtract(color1, color2)) >= 0.2:  #force color1 & color2 to be visibly distinct
                break

        self.color_dict[num] = [color1, color2]

    @staticmethod
    def _prompt_value(message: str, cast, is_valid):
        '''Prompt with `message` until the reply converts with `cast` and passes `is_valid`.

        Replies that `cast` rejects with ValueError (e.g. non-numeric text)
        are ignored and the prompt is shown again.
        '''
        while True:
            try:
                value = cast(input(message))
            except ValueError:
                continue
            if is_valid(value):
                return value

    def add_lens(self, focal_length: float, object_distance: float):
        '''Append a lens to the end of the system.

        Parameters
        ----------
        focal_length : float
            Focal length f [cm]; positive for a converging lens, negative for
            a diverging lens.  Zero is rejected and the user is prompted again.
        object_distance : float
            Distance d_o [cm] from this lens's object to the lens (for the
            first lens the object is at 0 cm; for later lenses it is the image
            of the previous lens).  It must be positive and must differ from
            the focal length (d_o = f puts the image at infinity); otherwise
            the user is prompted again.
        '''
        focal_length = float(focal_length)
        if focal_length == 0.0:
            print('A thin lens must have a non-zero focal length (f =/= 0)!')
            focal_length = self._prompt_value('What is the focal length of the lens [cm]? ', float,
                                              lambda f: f != 0)

        object_distance = float(object_distance)
        if object_distance <= 0.0:
            print('An object must be placed in front of the lens (d_o > 0)!')
            object_distance = self._prompt_value('What is the object distance of the lens [cm]? ', float,
                                                 lambda d: d > 0)
        if object_distance == focal_length:
            # 1/f - 1/d_o = 0 would divide by zero in thin_lens
            print('An object at the focal point forms no finite image (d_o =/= f)!')
            object_distance = self._prompt_value('What is the object distance of the lens [cm]? ', float,
                                                 lambda d: d > 0 and d != focal_length)

        # Count the lens only once it is valid, so an interrupted prompt leaves the system unchanged
        self.num_lenses += 1
        self.get_colors(self.num_lenses)
        self.system_dict[self.num_lenses] = [focal_length, object_distance]

    def delete_lens(self, lens_number: int):
        '''Remove lens `lens_number` and shift every later lens down by one.

        A lens number that is not in the system prompts the user for one that
        is.  Colors are regenerated from the deleted position onward so each
        object keeps the color of the image that forms it, and computed values
        are cleared (every drawing method recomputes them from system_dict).
        '''
        if not self.system_dict:
            print('There are no lenses in the system to delete!')
            return

        lens_number = int(lens_number)
        if lens_number not in self.system_dict:  #lenses indexed by 1
            print('This lens is not currently in the system!')
            lens_number = self._prompt_value('Which lens should be deleted? ', int,
                                             lambda n: n in self.system_dict)

        # Re-index later lenses by -1, then drop the now-duplicated last entry
        last = self.num_lenses
        for num in range(lens_number, last):
            self.system_dict[num] = self.system_dict[num + 1]
        del self.system_dict[last]
        del self.color_dict[last]
        self.num_lenses = len(self.system_dict)

        # Re-link object colors to the new previous image colors from the deleted slot onward
        for num in range(lens_number, self.num_lenses + 1):
            self.get_colors(num)

        # Computed values / patches are still indexed by the old numbering
        for computed in (self.lens_dict, self.obj_dict, self.img_dict, self.patch_dict):
            computed.clear()

    def thin_lens(self, num: int, focal_length: float, object_distance: float):
        '''Solve the thin-lens and magnification equations for lens `num` and store the results.

        Fills lens_dict[num], obj_dict[num] and img_dict[num].  Lens 1 uses an
        object of height 1.0 cm at position 0.0 cm; a later lens uses the image
        of lens num - 1 (which must already be computed) as its object.

        Raises ZeroDivisionError when focal_length or object_distance is 0, or
        when object_distance equals focal_length (1/f - 1/d_o = 0).
        '''
        # Thin-lens and magnification equations
        image_distance = 1 / ((1 / focal_length) - (1 / object_distance))     #d_i = 1 / ((1 / f) - (1 / d_o))
        magnification = -image_distance / object_distance                     #m = -(d_i / d_o)
        if num == 1:
            object_height = 1.0                                               #WLOG set first object height to 1.0cm
            total_magnification = magnification
        else:
            object_height = self.img_dict[num - 1][1]                         #previous image -> current object
            total_magnification = self.lens_dict[num - 1][5] * magnification  #M = m_1 * m_2 * ... * m_n
        image_height = object_height * magnification                          #h_i = h_o * m

        # Positions on the ray diagram
        if num == 1:
            object_position = 0.0                                             #WLOG set first object position to 0.0cm
        else:
            object_position = self.img_dict[num - 1][2]                       #previous image -> current object
        lens_position = object_position + object_distance
        image_position = lens_position + image_distance
        f_before_position = lens_position - abs(focal_length)                 #focal point before lens
        f_after_position = lens_position + abs(focal_length)                  #focal point after lens

        self.lens_dict[num] = [focal_length, lens_position, f_before_position, f_after_position,
                               magnification, total_magnification]
        self.obj_dict[num] = [object_distance, object_height, object_position]
        self.img_dict[num] = [image_distance, image_height, image_position]

    def _compute_system(self):
        '''Run thin_lens for every lens in order, so each lens sees its predecessor's image.'''
        for num in range(1, self.num_lenses + 1):
            self.thin_lens(num, self.system_dict[num][0], self.system_dict[num][1])

    def _selected_system(self, args):
        '''Return the lens numbers given by the user, or every lens when none were given.'''
        return list(args) if args else list(range(1, self.num_lenses + 1))

    def get_axes(self, system: list):
        '''Grow the stored axis limits so every lens in `system` fits on the diagram.

        Limits only ever expand; reset_axes restores the defaults.
        '''
        for num in system:
            # Vertical limits: tallest object or image, mirrored below the axis
            self.max_height = max(self.max_height, abs(self.obj_dict[num][1]), abs(self.img_dict[num][1]))
            self.min_height = -self.max_height

            # Horizontal limits: rightmost image / far focal point, leftmost object / image / near focal point
            self.max_dist = max(self.max_dist, self.img_dict[num][2], self.lens_dict[num][3])
            self.min_dist = min(self.min_dist, self.obj_dict[num][2], self.img_dict[num][2],
                                self.lens_dict[num][2])

    def reset_axes(self):
        '''Restore the default axis limits so the next diagram starts fresh.'''
        self.max_height = 0.0    #absolute value of largest object or image
        self.min_height = 0.0    #negative value of max height
        self.max_dist = -1000.0  #farthest focal point or image to the right (first lens always exceeds it)
        self.min_dist = 1000.0   #farthest object or image to the left (first lens always undercuts it)

    def draw_lens(self, num: int):
        '''Store the patch(es) for lens `num` in patch_dict, replacing any earlier ones.

        A converging lens (f > 0) is a double-headed arrow; a diverging lens
        (f < 0) is a vertical line capped by two arrowheads pointing toward the axis.
        '''
        lens_color = self.color_dict[num][0]
        x = self.lens_dict[num][1]
        if self.lens_dict[num][0] > 0:
            lens_patch = FancyArrowPatch((x, self.min_height - 0.5), (x, self.max_height + 0.5),
                                         arrowstyle = '<|-|>', mutation_scale = 20, linewidth = 1.5,
                                         facecolor = lens_color, edgecolor = 'darkgray', zorder = 10)
            self.patch_dict[num] = [lens_patch]
        else:
            lens_patch_top = FancyArrowPatch((x, self.max_height + 0.375), (x, self.max_height + 0.275),
                                             arrowstyle = '-|>', mutation_scale = 20, linewidth = 1.5,
                                             facecolor = lens_color, edgecolor = 'darkgray', zorder = 10)
            lens_patch_btm = FancyArrowPatch((x, self.min_height - 0.375), (x, self.min_height - 0.275),
                                             arrowstyle = '-|>', mutation_scale = 20, linewidth = 1.5,
                                             facecolor = lens_color, edgecolor = 'darkgray', zorder = 10)
            lens_patch_body = ConnectionPatch((x, self.min_height - 0.375), (x, self.max_height + 0.375),
                                              coordsA = 'data', coordsB = 'data', linestyle = 'solid',
                                              linewidth = 1.5, color = 'darkgray', zorder = 10)
            self.patch_dict[num] = [lens_patch_top, lens_patch_btm, lens_patch_body]

    def draw_obj_img(self, num: int):
        '''Append the object and image arrows for lens `num` to patch_dict[num] (draw_lens must run first).'''
        obj_color, img_color = self.color_dict[num][0], self.color_dict[num][1]
        obj_patch = Arrow(self.obj_dict[num][2], 0, 0, self.obj_dict[num][1], color = obj_color, width = 0.5, zorder = 10)
        img_patch = Arrow(self.img_dict[num][2], 0, 0, self.img_dict[num][1], color = img_color, width = 0.5, zorder = 10)
        self.patch_dict[num].extend([obj_patch, img_patch])

    def _linestyle_after(self, num: int):
        '''Return 'dashed' for a virtual image (d_i < 0) and 'solid' for a real one.

        For a real object (d_o > 0) this matches "f < 0 or image left of the
        object": a diverging lens always gives d_i < 0, and a converging lens
        gives d_i < 0 exactly when the image lands left of the object.
        '''
        return 'dashed' if self.img_dict[num][0] < 0 else 'solid'

    def draw_rays(self, num: int):
        '''Append the two principal rays for lens `num` to patch_dict[num] (draw_lens must run first).

        Ray 1 leaves the object tip parallel to the axis; ray 2 passes through
        the lens center.  Both are drawn from the lens to the image tip,
        dashed when the image is virtual.
        '''
        ray_color = self.color_dict[num][0]
        linestyle_after = self._linestyle_after(num)
        obj_x, obj_h = self.obj_dict[num][2], self.obj_dict[num][1]
        lens_x = self.lens_dict[num][1]
        img_x, img_h = self.img_dict[num][2], self.img_dict[num][1]

        ray1_before = ConnectionPatch((obj_x, obj_h), (lens_x, obj_h), coordsA = 'data', coordsB = 'data',
                                      linestyle = 'solid', color = ray_color, zorder = 5)
        ray2_before = ConnectionPatch((obj_x, obj_h), (lens_x, 0), coordsA = 'data', coordsB = 'data',
                                      linestyle = 'solid', color = ray_color, zorder = 5)
        ray1_after = ConnectionPatch((lens_x, obj_h), (img_x, img_h), coordsA = 'data', coordsB = 'data',
                                     linestyle = linestyle_after, color = ray_color, zorder = 5)
        ray2_after = ConnectionPatch((lens_x, 0), (img_x, img_h), coordsA = 'data', coordsB = 'data',
                                     linestyle = linestyle_after, color = ray_color, zorder = 5)
        self.patch_dict[num].extend([ray1_before, ray2_before, ray1_after, ray2_after])

    def _draw_system(self, ax, system: list, with_rays: bool):
        '''Draw labels, grid, optical axis, focal points and lens/object/image (and optionally ray) patches.

        get_axes(system) must already have been called so the lens patches
        and optical axis span the right range.
        '''
        ax.set_xlabel('Distance [cm]')
        ax.set_ylabel('Height [cm]')
        ax.grid(True, linestyle = 'dashed', color = 'darkgray', alpha = 0.25)
        ax.plot([self.min_dist - 5, self.max_dist + 5], [0, 0], linestyle = 'solid', linewidth = 1, color = 'black')
        for num in system:
            ax.scatter([self.lens_dict[num][2], self.lens_dict[num][3]], [0, 0], s = 15,
                       color = self.color_dict[num][0], zorder = 5)  #focal points
            self.draw_lens(num)
            self.draw_obj_img(num)
            if with_rays:
                self.draw_rays(num)
            for patch in self.patch_dict[num]:
                ax.add_patch(copy(patch))  #a patch can belong to one axes only, so add a fresh copy

    def _draw_table(self, ax, system: list):
        '''Render a table of f, d_o, h_o, d_i, h_i, m and M (rounded to 2 decimals) for each lens.'''
        ax.axis('off')
        columns = ['Lens', r'$\mathbf{f}$ [cm]', r'$\mathbf{d_o}$ [cm]', r'$\mathbf{h_o}$ [cm]',
                   r'$\mathbf{d_i}$ [cm]', r'$\mathbf{h_i}$ [cm]', r'$\mathbf{m}$ [x]', r'$\mathbf{M}$ [x]']
        table = []
        colors = []
        for num in system:
            values = [self.lens_dict[num][0], self.obj_dict[num][0], self.obj_dict[num][1],
                      self.img_dict[num][0], self.img_dict[num][1], self.lens_dict[num][4], self.lens_dict[num][5]]
            table.append([num] + [round(value, 2) for value in values])
            colors.append([self.color_dict[num][0]] + ['white'] * len(values))
        the_table = ax.table(cellText = table, colLabels = columns, cellColours = colors,
                             loc = 'center', cellLoc = 'center')
        for (row, col), cell in the_table.get_celld().items():
            if (row == 0) or (col == -1):
                cell.set_text_props(fontweight = 'bold')  #make column headers bold text
        the_table.scale(1.0, 2.0)                         #scale height

    @staticmethod
    def _download_path(filename: str, extension: str):
        '''Return <home>/Downloads/<filename>.<extension> for the current user on any OS.'''
        from pathlib import Path
        return Path.home() / 'Downloads' / '{filename}.{extension}'.format(filename = filename, extension = extension)

    def _prompt_save_path(self, extension: str):
        '''Ask the user for a file name and return its path in their Downloads folder.'''
        filename = input('What is the filename? ')
        return self._download_path(filename, extension)

    def ray_diagram(self, *args: int, save_diagram = False):
        '''Draw the ray diagram and a table of values for the chosen lenses.

        Parameters
        ----------
        *args : int
            Lens numbers to draw; every lens is drawn when none are given.
        save_diagram : bool
            When True, prompt for a file name and save the figure as a .png
            in the user's Downloads folder.
        '''
        self._compute_system()
        system = self._selected_system(args)
        if not system:
            print('There are no lenses in the system to draw!')
            return

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize = (12, 10), gridspec_kw = {'height_ratios': [1.5, 1]})
        fig.tight_layout(pad = 2.0)
        try:
            self.get_axes(system)
            ax1.set_title('Ray Diagram ')
            self._draw_system(ax1, system, with_rays = True)
            self._draw_table(ax2, system)
            if save_diagram:
                fig.savefig(self._prompt_save_path('png'), bbox_inches = 'tight')
        finally:
            self.reset_axes()  #limits must not leak into the next diagram, even after an error

    def ray_coords_before(self, num: int, step: int):
        '''Return (x, ray 1 heights, ray 2 heights, linestyle) for animation `step` before lens `num`.

        The rays run from the object toward the lens, covering
        step / (num_frames - 1) of the object distance.  Ray 1 stays at the
        object height; ray 2 heads for the lens center.
        '''
        linestyle_before = 'solid'
        delta_x = (self.obj_dict[num][0] * step) / (self.num_frames - 1)  #(d_o * step) / (num_frames - 1)

        #x = [obj_pos -> obj_pos + delta_x]
        x_coords = np.linspace(self.obj_dict[num][2], self.obj_dict[num][2] + delta_x, self.num_frames, endpoint = True)

        #y(x) = h_o
        ray1_coords = [self.obj_dict[num][1] for x in x_coords]

        #y(x) = (h_o / d_o) * (lens_pos - x)
        ray2_coords = [(self.obj_dict[num][1] / self.obj_dict[num][0]) * (self.lens_dict[num][1] - x) for x in x_coords]

        return x_coords, ray1_coords, ray2_coords, linestyle_before

    def ray_coords_after(self, num: int, step: int):
        '''Return (x, ray 1 heights, ray 2 heights, linestyle) for animation `step` after lens `num`.

        The rays run from the lens toward the image, covering
        step / (num_frames - 1) of the image distance (backward for a virtual
        image, which is then dashed).
        '''
        linestyle_after = self._linestyle_after(num)
        delta_x = (self.img_dict[num][0] * step) / (self.num_frames - 1)  #(d_i * step) / (num_frames - 1)

        #x = [lens_pos -> lens_pos + delta_x]
        x_coords = np.linspace(self.lens_dict[num][1], self.lens_dict[num][1] + delta_x, self.num_frames, endpoint = True)

        #y(x) = ((h_i - h_o) / d_i) * (x - img_pos) + h_i
        ray1_coords = [((self.img_dict[num][1] - self.obj_dict[num][1]) / self.img_dict[num][0]) *
                       (x - self.img_dict[num][2]) + self.img_dict[num][1] for x in x_coords]

        #y(x) = (h_i / d_i) * (x - lens_pos)
        ray2_coords = [(self.img_dict[num][1] / self.img_dict[num][0]) * (x - self.lens_dict[num][1]) for x in x_coords]

        return x_coords, ray1_coords, ray2_coords, linestyle_after

    def _frame_coords(self, num: int, step: int):
        '''Return ray coordinates for animation step 0 .. 2 * num_frames - 1 of lens `num`.

        The first num_frames steps trace the rays before the lens; the rest are
        re-indexed by -num_frames and trace the rays after it.
        '''
        if step < self.num_frames:
            return self.ray_coords_before(num, step)
        return self.ray_coords_after(num, step - self.num_frames)

    def animate_ray_diagram(self, *args: int, save_animation = False):
        '''Animate the principal rays travelling through the chosen lenses, one lens after another.

        Parameters
        ----------
        *args : int
            Lens numbers to animate; every lens is animated when none are given.
        save_animation : bool
            When True, prompt for a file name and save a .gif in the user's
            Downloads folder.
        '''
        self._compute_system()
        system = self._selected_system(args)
        if not system:
            print('There are no lenses in the system to animate!')
            return

        fig, ax = plt.subplots(1, 1, figsize = (12, 5))
        fig.tight_layout(pad = 2.0)
        try:
            self.get_axes(system)
            ax.set_title('Animated Ray Diagram')
            self._draw_system(ax, system, with_rays = False)

            # Each frame shows every ray segment drawn so far (earlier lenses stay visible)
            ray1_lines = []
            ray2_lines = []
            frames = []
            for num in system:
                ray_color = self.color_dict[num][0]
                for step in range(2 * self.num_frames):
                    x_coords, ray1_coords, ray2_coords, linestyle = self._frame_coords(num, step)
                    ray1, = ax.plot(x_coords, ray1_coords, linestyle = linestyle, linewidth = 0.8, color = ray_color)
                    ray2, = ax.plot(x_coords, ray2_coords, linestyle = linestyle, linewidth = 0.8, color = ray_color)
                    ray1_lines.append(ray1)
                    ray2_lines.append(ray2)
                    frames.append(ray1_lines + ray2_lines)  #new list -> snapshot of both lines so far

            ani = animation.ArtistAnimation(fig, frames, interval = 100, repeat = False)
            display(HTML(ani.to_jshtml()))
            plt.close(fig)

            if save_animation:
                writergif = animation.PillowWriter(fps = len(frames) / (len(system) * 4))
                ani.save(self._prompt_save_path('gif'), writer = writergif)
        finally:
            self.reset_axes()

    def interactive_ray_diagram(self, *args):
        '''Display sliders for the focal length and absolute position of each chosen lens.

        Every slider change redraws ray_diagram for those lenses and updates
        system_dict in place.  Lens numbers are as in ray_diagram; every lens
        gets sliders when none are given.
        '''
        self._compute_system()
        system = self._selected_system(args)
        if not system:
            print('There are no lenses in the system to explore!')
            return

        # Slider names carry the real lens number so update_ray_diagram can map them back
        focal_length_sliders = {}
        lens_position_sliders = {}
        for num in system:
            initial_focal_length = self.lens_dict[num][0]
            initial_lens_position = self.lens_dict[num][1]
            initial_object_position = self.obj_dict[num][2]
            focal_length_sliders['lens{}_focal_length'.format(num)] = FloatSlider(
                value = initial_focal_length, min = initial_focal_length - 20.0,
                max = initial_focal_length + 20.0, step = 1.0, description = '$f_{}$ [cm]'.format(num))
            lens_position_sliders['lens{}_position'.format(num)] = FloatSlider(
                value = initial_lens_position, min = initial_object_position,
                max = initial_lens_position + 100.0, step = 1.0, description = 'Lens {} Pos. [cm]'.format(num))

        # Focal-length sliders first, then position sliders (widget display order)
        slider_dict = {**focal_length_sliders, **lens_position_sliders}
        interactive_plot = interactive(self.update_ray_diagram, **slider_dict)
        display(interactive_plot)

    def _apply_slider_values(self, slider_dict: dict):
        '''Copy slider values into system_dict and recompute the whole system.

        Slider keys look like 'lens3_focal_length' and 'lens3_position'.  The
        position slider holds the absolute lens position, so the object
        distance is measured from that lens's current object, which is the
        freshly recomputed image of the lens before it.  Lenses are processed
        in order so each one sees its predecessor's updated image.

        Returns the sorted lens numbers that have sliders.  Returns None, after
        printing a message, when a value would divide by zero in the thin-lens
        equation (f = 0, d_o = 0 or d_o = f).
        '''
        system = sorted(int(key[len('lens'):key.index('_')])
                        for key in slider_dict if key.endswith('_focal_length'))
        for num in range(1, self.num_lenses + 1):
            if num in system:
                focal_length = slider_dict['lens{}_focal_length'.format(num)]
                object_position = 0.0 if num == 1 else self.img_dict[num - 1][2]  #same convention as thin_lens
                object_distance = slider_dict['lens{}_position'.format(num)] - object_position
                if focal_length == 0 or object_distance == 0 or object_distance == focal_length:
                    print('Lens {}: the thin-lens equation needs f =/= 0, d_o =/= 0, and d_o =/= f!'.format(num))
                    return None
                self.system_dict[num][0] = focal_length
                self.system_dict[num][1] = object_distance
            self.thin_lens(num, self.system_dict[num][0], self.system_dict[num][1])
        return system

    def update_ray_diagram(self, **slider_dict):
        '''Apply the slider values and redraw the ray diagram (callback for interactive_ray_diagram).'''
        system = self._apply_slider_values(slider_dict)
        if system:
            self.ray_diagram(*system)