In [None]:
## Initialize Hyperparameters and import libraries

import numpy as np
import torch
import gc

import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from timeit import default_timer
import operator
from functools import reduce
from functools import partial

SEED = 0          # global seed for every RNG used below
batch_size = 25
epochs = 50
n_train = 3000    # number of training structures (each is expanded over every wavelength)
n_test = 200      # number of test structures (each is expanded over every wavelength)
n_total = n_train + n_test

# Seed the RNGs this notebook draws from: torch (weight init, DataLoader shuffling)
# and numpy (random test-sample selection in the plotting cell).
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Encoding wave prior: for details on the wave prior, see the following paper
# https://arxiv.org/abs/2209.10098

def wave_prior_y(design_space, wavelength):
    """Encode a y-direction wave prior over a three-region layered design space.

    Each row ``i`` gets a vertical coordinate ``y_i`` and the prior is
    ``sin(design_space[i, :] * 2*pi*y_i / wavelength)``.

    Row pitches (lengths presumably in nm, same unit as ``wavelength`` — confirm):
      * rows 0-14   : 500 over 15 rows   (pitch 500/15)
      * rows 15-74  : 1500 over 60 rows  (pitch 1500/60)
      * rows 75-... : 500 over 150 rows  (pitch 500/150)
    Each region starts exactly where the previous one ends (y = 500 at row 15,
    y = 2000 at row 75).

    Args:
        design_space: 2-D array (rows, cols); its values scale the phase.
        wavelength: scalar wavelength.

    Returns:
        A new floating-point array with the same shape as ``design_space``.
        The input array is not modified.
    """
    design_space = np.asarray(design_space)

    del_y_s = 500/15     # pitch of the first 15 rows
    del_y_d = 1500/60    # pitch of rows 15-74 (the patterned rows in the callers)
    del_y_p = 500/150    # pitch of the remaining rows

    rows = np.arange(design_space.shape[0])
    y = np.where(
        rows < 15,
        del_y_s*rows,
        np.where(
            rows < 75,
            del_y_s*15 + del_y_d*(rows - 15),
            # Offset from row 75, the first row of this region, so y does not jump
            # at the region boundary.
            del_y_s*15 + del_y_d*60 + del_y_p*(rows - 75),
        ),
    )

    phase = 2*np.pi*y/wavelength
    # Broadcast the per-row phase across all columns; always yields a new float array.
    return np.sin(design_space*phase[:, np.newaxis])

In [None]:
def wave_prior_x(design_space, wavelength):
    """Encode an x-direction wave prior over a 2-D design space.

    Column ``j`` gets coordinate ``x_j = j * 1000/32`` and the prior is
    ``sin(design_space[:, j] * 2*pi*x_j / wavelength)``.

    Args:
        design_space: 2-D array (rows, cols); its values scale the phase.
        wavelength: scalar wavelength, same length unit as the column pitch.

    Returns:
        A new floating-point array with the same shape as ``design_space``.
        The input array is not modified.
    """
    design_space = np.asarray(design_space)

    del_x = 1000/32  # column pitch
    # NOTE(review): this pitch spans 1000 over 32 columns, but the notebook builds
    # 64-column design spaces (pattern repeated 2x along columns), which would span
    # 2000 — confirm the intended pitch before enabling the x prior.
    x = del_x*np.arange(design_space.shape[1])
    phase = 2*np.pi*x/wavelength
    # Broadcast the per-column phase across all rows.
    return np.sin(design_space*phase)

In [None]:
# Check wave prior implementation: build the design space for stored pattern #2
# and visualise its y-direction wave prior at a single wavelength.

very_tmp = np.load(f"D:\\Sunghyun Nam\\SANZABOO\\data_Ez_Fermi\\data_Ez\\patternings\\patterning_{2:08d}.npy")
tmp_wl = 500

# Layout: rows 0-14 filled with 1, rows 15-74 hold the pattern upsampled
# 15x along rows and 2x along columns, rows 75-224 filled with 2.
tmp_sv = np.empty((225, 64))
tmp_sv[:15] = 1.0
tmp_sv[15:75] = very_tmp.repeat(15, axis=0).repeat(2, axis=1)
tmp_sv[75:] = 2.0

plt.imshow(wave_prior_y(tmp_sv, tmp_wl), aspect=0.4)
plt.colorbar()


In [None]:
## Import Data: Train
# Every training structure is paired with every wavelength:
# sample row = structure index * n_wl_train + wavelength index.

wav_len_step = 10
wav_len_list = list(range(400, 400 + 300 + 1, wav_len_step))   # 400, 410, ..., 700
wl_list_train = wav_len_list
n_wl_train = len(wl_list_train)

train_wprior  = np.empty((n_train*n_wl_train, 1, 225, 64), np.float32)
train_results = np.empty((n_train*n_wl_train, 2, 225, 64), np.float32)

train_dir = "D:\\Sunghyun Nam\\SANZABOO\\data_Ez_Fermi\\data_Ez"

for i in range(n_train):
    tmp_struct = np.load(f"{train_dir}\\patternings\\patterning_{i:08d}.npy")
    tmp_result = np.load(f"{train_dir}\\true_results\\result_Ex_{i:08d}.npy")

    # Rows 0-14 = 1, rows 15-74 = pattern upsampled 15x (rows) / 2x (cols), rows 75-224 = 2.
    train_space = np.empty((225, 64))
    train_space[:15] = 1.0
    train_space[15:75] = tmp_struct.repeat(15, axis=0).repeat(2, axis=1)
    train_space[75:] = 2.0

    first_row = i*n_wl_train
    for k, tmp_wl in enumerate(wl_list_train):
        # Copy defensively in case wave_prior_y modifies its argument.
        train_wprior[first_row + k, 0] = wave_prior_y(np.copy(train_space), tmp_wl)   # y wave prior
        # Channel 0 = real part, channel 1 = imaginary part of the k-th stored result.
        train_results[first_row + k, 0] = np.real(tmp_result[k])
        train_results[first_row + k, 1] = np.imag(tmp_result[k])

    if i % 500 == 0:
        print(i)

In [None]:
## Import Data: Test
# Same layout as the training set: sample row = structure index * n_wl_test + wavelength index.
# NOTE(review): these files come from "...\\SANZABOO\\data_Ez\\..." while the training files come
# from "...\\SANZABOO\\data_Ez_Fermi\\data_Ez\\...", both loops start at index 0, and `n_total`
# from the hyperparameter cell is never used. Confirm this is the intended, disjoint test set.

wl_list_test = wav_len_list
n_wl_test = len(wl_list_test)

test_wprior   = np.empty((n_test*n_wl_test, 1, 225, 64), np.float32)
test_results  = np.empty((n_test*n_wl_test, 2, 225, 64), np.float32)

test_dir = "D:\\Sunghyun Nam\\SANZABOO\\data_Ez"

for i in range(n_test):
    tmp_struct = np.load(f"{test_dir}\\patternings\\patterning_{i:08d}.npy")
    tmp_result = np.load(f"{test_dir}\\true_results\\result_Ex_{i:08d}.npy")

    # Rows 0-14 = 1, rows 15-74 = pattern upsampled 15x (rows) / 2x (cols), rows 75-224 = 2.
    test_space = np.empty((225, 64))
    test_space[:15] = 1.0
    test_space[15:75] = tmp_struct.repeat(15, axis=0).repeat(2, axis=1)
    test_space[75:] = 2.0

    first_row = i*n_wl_test
    for k, tmp_wl in enumerate(wl_list_test):
        # Copy defensively in case wave_prior_y modifies its argument.
        test_wprior[first_row + k, 0] = wave_prior_y(np.copy(test_space), tmp_wl)   # y wave prior
        # Channel 0 = real part, channel 1 = imaginary part of the k-th stored result.
        test_results[first_row + k, 0] = np.real(tmp_result[k])
        test_results[first_row + k, 1] = np.imag(tmp_result[k])

    if i % 500 == 0:
        print(i)


In [None]:
# Define 2D FNO Model
# Code mostly copied from https://github.com/neuraloperator/neuraloperator/tree/main

# Complex multiplication for tensors that store complex numbers as a trailing
# (real, imag) axis of size 2. SpectralConv2d below does not call this function;
# it uses native complex tensors through its own compl_mul2d method.
def compl_mul2d(a, b):
    """Complex channel mixing: (batch, in, x, y, 2) x (in, out, x, y, 2) -> (batch, out, x, y, 2)."""
    def contract(u, v):
        return torch.einsum("bixy,ioxy->boxy", u, v)

    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]

    # (a_re + i a_im)(b_re + i b_im) = (a_re b_re - a_im b_im) + i (a_im b_re + a_re b_im)
    real = contract(a_re, b_re) - contract(a_im, b_im)
    imag = contract(a_im, b_re) + contract(a_re, b_im)
    return torch.stack([real, imag], dim=-1)

################################################################
# 2D Fourier layer
################################################################
class SpectralConv2d(nn.Module):
    """2D Fourier layer: rFFT2 -> per-mode complex channel mixing -> inverse rFFT2.

    Along the second-to-last axis, two blocks of the rFFT2 spectrum are kept: the
    first ``modes1`` rows and the last ``modes1`` rows (negative frequencies).
    Along the last axis, the first ``modes2`` columns are kept. Each block is mixed
    with its own complex weight tensor (``weights1`` / ``weights2``). Every other
    mode is zeroed.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes kept per axis (at most floor(N/2) + 1).
        self.modes1 = modes1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2)
        # torch.rand with a complex dtype draws real and imaginary parts uniformly in [0, 1).
        self.weights1 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))

    def compl_mul2d(self, input, weights):
        """Per-mode channel mixing: (b, i, x, y) x (i, o, x, y) -> (b, o, x, y)."""
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        n_rows, n_cols = x.size(-2), x.size(-1)
        m1, m2 = self.modes1, self.modes2
        spectrum = torch.fft.rfft2(x)

        # rFFT2 keeps n_cols // 2 + 1 columns along the last axis.
        mixed = torch.zeros(x.shape[0], self.out_channels, n_rows, n_cols // 2 + 1,
                            dtype=torch.cfloat, device=x.device)
        mixed[:, :, :m1, :m2] = self.compl_mul2d(spectrum[:, :, :m1, :m2], self.weights1)
        mixed[:, :, -m1:, :m2] = self.compl_mul2d(spectrum[:, :, -m1:, :m2], self.weights2)

        # Back to physical space at the original spatial resolution.
        return torch.fft.irfft2(mixed, s=(n_rows, n_cols)).to(torch.float32)

# 2D FNO Model: input wave prior, output Ex

class _2DFNO(nn.Module):
    def __init__(self, modes_x, modes_y, width, FNO_layer_num, FC_neuron):
        super(_2DFNO, self).__init__()

        self.modes_x = modes_x
        self.modes_y = modes_y
        self.width = width
        self.FNO_layer_num = FNO_layer_num
        self.FC_neuron = FC_neuron
        
        self.conv_layers = nn.ModuleList([
            SpectralConv2d(self.width, self.width, self.modes_x, self.modes_y) for _ in range(self.FNO_layer_num)
        ])

        self.w_layers = nn.ModuleList([
            nn.Conv2d(self.width, self.width, 1) for _ in range(self.FNO_layer_num)
        ])
        self.fc0 = nn.Linear(1, self.width)

        self.fc1_ori = nn.Linear(self.width, self.FC_neuron)
        self.fc2_ori = nn.Linear(self.FC_neuron, 2)


    def forward(self, x):


        x = x.permute(0, 2, 3, 1)
        #grid = self.get_grid(x.shape, x.device)
        x = self.fc0(x)
        x = F.gelu(x)
        x = x.permute(0, 3, 1, 2) # Batch size, channels

        for i in range(self.FNO_layer_num-1):
            x1 = self.conv_layers[i](x)
            x2 = self.w_layers[i](x)
            x = x1 + x2
            x = F.gelu(x)

        x1 = self.conv_layers[i](x)
        x2 = self.w_layers[i](x)
        x = x1 + x2

        x = x.permute(0, 2, 3, 1)
        x = self.fc1_ori(x)
        x = F.gelu(x)
        x = self.fc2_ori(x)

        x = x.permute(0, 3, 1, 2)

        return x
    
    def get_grid(self, shape, device):
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)
    
    def count_params(self):
        c = 0
        for p in self.parameters():
            c += reduce(operator.mul, list(p.size()))
        return c

In [None]:
# Prepare training and test tensors (input: wave prior; target: real/imag field channels)
# torch.from_numpy shares memory with the float32 numpy arrays instead of copying them.
# With the sizes configured above, the training arrays alone are ~16 GB, so a copy would
# double host RAM. .float() is a no-op for float32 arrays and only converts other dtypes.
# Caveat: in-place ops on these tensors would also modify the numpy arrays.

train_input = torch.from_numpy(train_wprior).float()
train_output = torch.from_numpy(train_results).float()

test_input = torch.from_numpy(test_wprior).float()
test_output = torch.from_numpy(test_results).float()

# Train and test loaders for the FNO model

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_input, train_output), batch_size=batch_size, shuffle=True)
test_loader  = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_input, test_output), batch_size=batch_size, shuffle=False)

In [None]:
# Train model

model = _2DFNO(modes_x=32, modes_y=15, width=10, FNO_layer_num=4, FC_neuron=64).cuda()
print(f"Total parameters in the model: {model.count_params()}")

# Optimizer and scheduler (learning rate halves every 20 epochs)
optimizer = torch.optim.Adam(model.parameters(), lr=0.5e-2, weight_decay=0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

# Loss function: mean squared error over all output elements
loss_fn = nn.MSELoss()

# Per-epoch RMSE history (used by the learning-curve cell)
Train_rmse_arr = []
Test_rmse_arr = []

# Free up GPU memory once before starting training
gc.collect()
torch.cuda.empty_cache()

# Training loop
total_time = 0

for ep in range(epochs):
    t1 = default_timer()
    model.train()
    Train_mse = 0.0

    for input_shape, result in train_loader:
        input_shape, result = input_shape.cuda(), result.cuda()
        # Use the actual batch size: the final batch can be smaller than `batch_size`.
        n_batch = input_shape.size(0)
        optimizer.zero_grad()

        out = model(input_shape)
        Train_mse_temp = loss_fn(out.reshape(n_batch, -1), result.reshape(n_batch, -1))
        Train_mse_temp.backward()
        optimizer.step()

        # Batch-mean MSE * batch size = sum of per-sample MSEs (for a sample-weighted epoch mean).
        Train_mse += Train_mse_temp.item() * n_batch

    scheduler.step()

    # Validation loop
    model.eval()
    Test_mse = 0.0
    with torch.no_grad():
        for input_shape, result in test_loader:
            input_shape, result = input_shape.cuda(), result.cuda()
            n_batch = input_shape.size(0)

            out = model(input_shape)
            Test_mse_temp = loss_fn(out.reshape(n_batch, -1), result.reshape(n_batch, -1))
            Test_mse += Test_mse_temp.item() * n_batch

    # Per-sample mean MSE, then RMSE, for train and test sets
    Train_mse /= len(train_loader.dataset)
    Test_mse /= len(test_loader.dataset)

    Train_rmse = np.sqrt(Train_mse)
    Test_rmse = np.sqrt(Test_mse)

    Train_rmse_arr.append(Train_rmse)
    Test_rmse_arr.append(Test_rmse)

    t2 = default_timer()
    total_time += t2 - t1

    print(f"Epoch {ep+1}, Time: {t2-t1:.2f}s, Train RMSE: {Train_rmse:.4f}, Test RMSE: {Test_rmse:.4f}")

# Training complete
print(f"Training completed in {total_time:.2f} seconds.")

In [None]:
# Learning curve: per-epoch train/test RMSE recorded by the training cell.
# Use a separate name so the `epochs` hyperparameter is not overwritten (re-running the
# training cell after an interrupted run would otherwise train for fewer epochs).
n_epochs_run = len(Train_rmse_arr)
epoch_nums = np.arange(1, n_epochs_run + 1)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epoch_nums, Train_rmse_arr, label="Training RMSE", color="blue", marker="o")
ax.plot(epoch_nums, Test_rmse_arr, label="Test RMSE", color="red", marker="x")

ax.set_xlabel("Epochs")
ax.set_ylabel("RMSE")
# The training set spans every wavelength in wl_list_train, not a single one.
ax.set_title(f"Learning Curve: {min(wl_list_train)}-{max(wl_list_train)} nm")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Pick one random test sample across all structures and wavelengths.
# test_wprior holds n_test * n_wl_test samples; index = structure * n_wl_test + wavelength index.
n = np.random.randint(0, len(test_wprior))
i_struct, k_wl = divmod(n, n_wl_test)

# Inference only: eval mode and no autograd graph.
model.eval()
with torch.no_grad():
    tmp_sample = torch.from_numpy(test_wprior[n:n + 1]).cuda()   # (1, 1, 225, 64)
    sample_result = model(tmp_sample).cpu().numpy()
sample_result = np.squeeze(sample_result)  # (2, 225, 64): channel 0 = real, 1 = imag

# Side-by-side comparison of channel 0 (real part)
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

im1 = axes[0].imshow(sample_result[0], aspect=0.4)
axes[0].set_title("FNO Output")
fig.colorbar(im1, ax=axes[0])

im2 = axes[1].imshow(test_results[n][0], aspect=0.4)
axes[1].set_title("True Output")
fig.colorbar(im2, ax=axes[1])

fig.suptitle(f"Test structure {i_struct}, wavelength {wl_list_test[k_wl]} nm (real part)")
plt.show()