In [33]:
import torch
import torchvision
from torchvision import transforms
from torchvision.transforms import Resize, ToTensor
import sys
sys.path.append('../')
from DataType.ElectricField import ElectricField
from LightSource.Gaussian_beam import Guassian_beam
from Props.ASM_Prop import ASM_prop
from Props.RSC_Prop import RSC_prop
from Components.Thin_Lens import Thin_LensElement
from Components.Aperture import ApertureElement
from Components.QuantizedDOE import SoftGumbelQuantizedDOELayerv3 as SoftGumbelQuantizedDOELayer
from Components.QuantizedDOE import NaiveGumbelQuantizedDOELayer
from Components.QuantizedDOE import PSQuantizedDOELayer
from Components.QuantizedDOE import STEQuantizedDOELayer
from Components.QuantizedDOE import FullPrecisionDOELayer
from Components.QuantizedDOE import FixDOEElement
from utils.Helper_Functions import normalize, DOE_xyz_cordinates_Generator
from utils.units import *
import numpy as np
import scipy.io
import time
import torch.nn as nn

# 1. Set simulation parameters

In [34]:
# --- Source: single operating frequency -------------------------------------
c0 = 2.998e8            # speed of light [m/s]
f2 = 300e9              # operating frequency, 0.3 THz [Hz]
wavelengths = c0 / f2   # one wavelength as a plain Python float (~1 mm)

# --- Sampling grids ----------------------------------------------------------
input_field_shape = [100, 100]
input_dxy = 1 * mm
doe_shape = [100, 100]
doe_dxy = 1 * mm

# --- Hologram material (high-temperature resin) ------------------------------
epsilon = 2.66          # presumably relative permittivity
tand = 0.003            # presumably loss tangent (tan delta)

# Settings handed to every DOE layer of the network.
doe_params = dict(
    doe_size=doe_shape,
    doe_dxy=doe_dxy,
    doe_level=4,
    look_up_table=None,
    num_unit=None,
    height_constraint_max=1 * mm,
    tolerance=30 * um,
    material=[epsilon, tand],
)

# Settings handed to the quantized DOE layers.
optim_params = dict(
    c_s=100,        # boosts the score; larger values are more robust to Gumbel noise
    tau_max=2.5,    # upper bound of the temperature tau
    tau_min=1.5,    # lower bound of the temperature tau
)

# --- Network geometry --------------------------------------------------------
num_layer = 3           # number of diffractive layers
d_layer = 20 * mm       # spacing between two consecutive layers

# --- Training ----------------------------------------------------------------
batch_size = 128

# 2. Define a 3-layer diffractive optical neural network (DONN)

In [35]:
class DONN(nn.Module):
    """Simulate a diffractive optical neural network (DONN) with several DOE layers.

    Optical path (see ``encode_object`` and ``forward``)::

        input --ASM 50 mm--> aperture --> DOE_1 --ASM d_layer--> aperture --> ...
              --> DOE_N --ASM 50 mm--> detector

    The DOE layers are either full precision or use one of several quantized
    optimization methods, selected with ``q_method``.

    Args:
        input_dxy: sampling pitch of the input field (passed to ``ElectricField``).
        input_field_shape: ``[H, W]`` of the input field.
        doe_params: dict forwarded unchanged to every DOE layer.
        optim_params: dict forwarded to the ``'sgs'``, ``'gs'`` and ``'ste'``
            layers; the ``None`` and ``'psq'`` layers do not use it.
        wavelengths: wavelength(s) passed to ``ElectricField``.
        num_layer: number of DOE layers (must be >= 1).
        d_layer: propagation distance between consecutive DOE layers.
        q_method: ``None`` -> FullPrecisionDOELayer,
            ``'sgs'`` -> SoftGumbelQuantizedDOELayer,
            ``'gs'`` -> NaiveGumbelQuantizedDOELayer,
            ``'psq'`` -> PSQuantizedDOELayer,
            ``'ste'`` -> STEQuantizedDOELayer.

    Raises:
        ValueError: if ``num_layer < 1`` or ``q_method`` is not one of the above.
    """

    # Fixed optimization settings used by the 'psq' layers instead of optim_params.
    PSQ_OPTIM_PARAMS = {'c_s': 300, 'tau_max': 800, 'tau_min': 1}

    def __init__(self,
                 input_dxy,
                 input_field_shape,
                 doe_params,
                 optim_params,
                 wavelengths,
                 num_layer,
                 d_layer,
                 q_method):
        super().__init__()
        if num_layer < 1:
            # forward() always applies self.does[-1], so at least one layer is required.
            raise ValueError(f"num_layer must be >= 1, got {num_layer!r}")

        self.input_dxy = input_dxy
        self.input_field_shape = input_field_shape
        self.doe_params = doe_params
        self.optim_params = optim_params

        self.wavelengths = wavelengths

        self.num_layer = num_layer
        self.d_layer = d_layer

        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        # Propagation from the input plane to the first DOE.
        self.asm_prop2layer = ASM_prop(z_distance=50 * mm,
                                       bandlimit_type='exact',
                                       padding_scale=2,
                                       bandlimit_kernel=True)

        # NOTE(review): 0.08 is presumably in metres (80 mm); confirm the unit
        # expected by ApertureElement.
        self.aperture = ApertureElement(aperture_type='rect',
                                        aperture_size=0.08)

        # The diffractive layers; one independent layer object per position.
        self.does = nn.ModuleList(
            [self._make_doe_layer(q_method) for _ in range(self.num_layer)])

        # Propagation between two consecutive DOEs.
        self.asm_prop_layer = ASM_prop(z_distance=self.d_layer,
                                       bandlimit_type='exact',
                                       padding_scale=2,
                                       bandlimit_kernel=True)

        # Propagation from the last DOE to the detector plane.
        self.asm_prop2detector = ASM_prop(z_distance=50 * mm,
                                          bandlimit_type='exact',
                                          padding_scale=2,
                                          bandlimit_kernel=True)

    def _make_doe_layer(self, q_method):
        """Build one DOE layer for ``q_method``; raise ValueError if it is unknown."""
        if q_method is None:
            return FullPrecisionDOELayer(self.doe_params)
        if q_method == 'sgs':
            return SoftGumbelQuantizedDOELayer(self.doe_params, self.optim_params)
        if q_method == 'gs':
            return NaiveGumbelQuantizedDOELayer(self.doe_params, self.optim_params)
        if q_method == 'psq':
            # PSQ ignores the shared optim_params and uses its own fixed settings;
            # each layer gets its own copy so no layer can mutate another's dict.
            return PSQuantizedDOELayer(self.doe_params, dict(self.PSQ_OPTIM_PARAMS))
        if q_method == 'ste':
            return STEQuantizedDOELayer(self.doe_params, self.optim_params)
        raise ValueError(
            f"Unknown q_method {q_method!r}; expected None, 'sgs', 'gs', 'psq' or 'ste'")

    def encode_object(self, u):
        """Turn ``u`` into an ElectricField arriving at the first DOE.

        A unit-amplitude, zero-phase plane wave of shape ``(1, 1, H, W)`` is
        multiplied by ``u`` (broadcasting). It is then propagated with
        ``asm_prop2layer`` and clipped by the aperture.

        Args:
            u: amplitude modulation. Presumably a scalar or an image batch
                broadcastable against ``(1, 1, H, W)``, e.g. ``(B, 1, H, W)``;
                TODO confirm against the training loop. A non-scalar tensor
                presumably has to be on ``self.device``.

        Returns:
            The ElectricField after propagation and the aperture.
        """
        height, width = self.input_field_shape[0], self.input_field_shape[1]
        amplitude_fields = torch.ones(1, 1, height, width, device=self.device) * u
        phase_fields = torch.zeros_like(amplitude_fields, device=self.device)
        # Complex field = amplitude * exp(i * phase); with zero phase this is a plane wave.
        electric_fields = amplitude_fields * torch.exp(1j * phase_fields)

        field = ElectricField(
                data=electric_fields,
                wavelengths=self.wavelengths,
                spacing=self.input_dxy
        )

        field = self.asm_prop2layer(field)
        field = self.aperture(field)

        return field

    def forward(self, u, iter_frac=None):
        """Propagate ``u`` through all DOE layers to the detector plane.

        The output of each DOE feeds the next one. Between two DOEs the field is
        propagated by ``d_layer`` and clipped by the aperture.

        Args:
            u: input passed to ``encode_object``.
            iter_frac: forwarded to every DOE layer (presumably the training
                progress used by the quantized layers; TODO confirm).

        Returns:
            The field at the detector plane (output of ``asm_prop2detector``).
        """
        field = self.encode_object(u)

        # Every DOE except the last is followed by inter-layer propagation + aperture.
        for doe in self.does[:-1]:
            field = doe(field, iter_frac)
            field = self.asm_prop_layer(field)
            field = self.aperture(field)

        # The last DOE propagates straight to the detector.
        field = self.does[-1](field, iter_frac)
        return self.asm_prop2detector(field)

In [36]:
# Full-precision baseline network (q_method=None) built from the settings above.
donn = DONN(
    q_method=None,
    num_layer=num_layer,
    d_layer=d_layer,
    wavelengths=wavelengths,
    input_field_shape=input_field_shape,
    input_dxy=input_dxy,
    doe_params=doe_params,
    optim_params=optim_params,
)

In [37]:
# Show the module tree: input/inter-layer/detector propagators, aperture and DOE layers.
donn

DONN(
  (asm_prop2layer): ASM_prop()
  (aperture): ApertureElement()
  (does): ModuleList(
    (0-2): 3 x FullPrecisionDOELayer()
  )
  (asm_prop_layer): ASM_prop()
  (asm_prop2detector): ASM_prop()
)

# 3. Define the label generator and data loaders

In [41]:
# Resize MNIST digits to the simulation grid and convert them to float tensors in [0, 1].
trans = transforms.Compose([Resize(input_field_shape), ToTensor()])

mnist_train = torchvision.datasets.MNIST(
        root="./data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.MNIST(
        root="./data", train=False, transform=trans, download=True)

# Train on a small subset of the training images. The seeded generator makes the
# split identical on every Restart & Run All (it was random before).
split_seed = 0
n_train, n_val = 4096, 512
train_set, val_set, _ = torch.utils.data.random_split(
    mnist_train,
    [n_train, n_val, len(mnist_train) - n_train - n_val],
    generator=torch.Generator().manual_seed(split_seed),
)

# Train, validation and test loaders; only training needs shuffling.
train_loader = torch.utils.data.DataLoader(train_set,
                                           batch_size=batch_size,
                                           shuffle=True, num_workers=8, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_set,
                                         batch_size=batch_size,
                                         shuffle=False, num_workers=4)
test_loader = torch.utils.data.DataLoader(mnist_test,
                                          batch_size=batch_size,
                                          shuffle=False, num_workers=4)

# 4. Define the training and validation pipelines

# 5. Train the model with different quantization methods