In [None]:
## Initialize Hyperparameters and import libraries

import numpy as np
import torch
import gc


import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from timeit import default_timer
import operator
from functools import reduce
from functools import partial


# Mini-batch size shared by the training and test DataLoaders and by the loss bookkeeping.
batch_size = 20
# Number of passes over the training set performed by the training loop.
epochs = 50

In [None]:
## Import Data
# Trained and tested for 20nm wavelength spacing

n_total = 3200
n_train = 3000
n_test = 200

# Layout: structs -> (sample, 1, 225, 64); results -> (sample, real/imag, 11, 225, 64)
train_structs = np.empty((n_train, 1, 225, 64), np.float32)
train_results = np.empty((n_train, 2, 11, 225, 64), np.float32)
test_structs = np.empty((n_test, 1, 225, 64), np.float32)
test_results = np.empty((n_test, 2, 11, 225, 64), np.float32)


def _fill_dataset(structs, results, pattern_prefix, result_prefix, log_every=500):
    """Load every sample from disk into the preallocated arrays, in place.

    For sample index k the files ``<pattern_prefix><k:08d>.npy`` and
    ``<result_prefix><k:08d>.npy`` are read.

    Args:
        structs: array of shape (n, 1, 225, 64) that receives the structure maps.
        results: array of shape (n, 2, 11, 225, 64) that receives the real
            (index 0) and imaginary (index 1) parts of the sampled result.
        pattern_prefix: path prefix of the patterning files.
        result_prefix: path prefix of the result files.
        log_every: print the sample index every ``log_every`` samples.
    """
    for idx in range(structs.shape[0]):
        suffix = f'{idx:08d}' + ".npy"
        pattern = np.load(pattern_prefix + suffix)
        field = np.load(result_prefix + suffix)

        # Keep every third entry along the first axis; the result must match the
        # 11-entry axis of `results` (presumably the wavelength axis - TODO confirm).
        sampled = field[::3]

        layer_map = structs[idx, 0]
        layer_map[:15] = 1      # rows 0-14: constant value 1
        # rows 15-74: pattern upsampled 15x along axis 0 and 2x along axis 1
        # (assumes the pattern file holds a (4, 32) array - TODO confirm)
        layer_map[15:75] = np.repeat(np.repeat(pattern, 15, axis=0), 2, axis=1)
        layer_map[75:] = 2      # rows 75-224: constant value 2

        results[idx, 0] = np.real(sampled)
        results[idx, 1] = np.imag(sampled)

        if idx % log_every == 0:
            print(idx)


_fill_dataset(
    train_structs,
    train_results,
    "D:\\Sunghyun Nam\\SANZABOO\\data_Ez_Fermi\\data_Ez\\patternings\\patterning_",
    "D:\\Sunghyun Nam\\SANZABOO\\data_Ez_Fermi\\data_Ez\\true_results\\result_Ex_",
)

# NOTE(review): the test set reads indices 0..n_test-1 from a different directory
# than the training set. Confirm these are held-out samples and not copies of the
# first n_test training samples.
_fill_dataset(
    test_structs,
    test_results,
    "D:\\Sunghyun Nam\\SANZABOO\\data_Ez\\patternings\\patterning_",
    "D:\\Sunghyun Nam\\SANZABOO\\data_Ez\\true_results\\result_Ex_",
)



In [None]:
# Define 2D FNO Model
# Code mostly copied from https://github.com/neuraloperator/neuraloperator/tree/main

# Complex multiplication
def compl_mul2d(a, b):
    """Complex multiply-and-contract for tensors that store (real, imag) on the last axis.

    a: (batch, in_channel, x, y, 2), b: (in_channel, out_channel, x, y, 2)
    returns: (batch, out_channel, x, y, 2), summed over in_channel.
    """
    contraction = "bixy,ioxy->boxy"
    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]

    # (a_re + i*a_im) * (b_re + i*b_im)
    real_part = torch.einsum(contraction, a_re, b_re) - torch.einsum(contraction, a_im, b_im)
    imag_part = torch.einsum(contraction, a_im, b_re) + torch.einsum(contraction, a_re, b_im)
    return torch.stack((real_part, imag_part), dim=-1)

################################################################
# 2D Fourier layer
################################################################
class SpectralConv2d(nn.Module):
    """2D Fourier layer: rFFT2 -> learned complex mixing of the retained modes -> inverse rFFT2.

    Only two corners of the half-spectrum are transformed: the first ``modes1``
    and the last ``modes1`` rows, each restricted to the first ``modes2``
    columns. All other Fourier coefficients of the output are zero.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes1 = modes1
        self.modes2 = modes2

        self.scale = 1 / (in_channels * out_channels)
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2)
        # One complex weight tensor per retained corner of the spectrum
        self.weights1 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))

    def compl_mul2d(self, input, weights):
        """Contract the input-channel axis: (b, i, x, y) x (i, o, x, y) -> (b, o, x, y)."""
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        """Apply the spectral convolution to ``x`` of shape (batch, in_channels, H, W).

        Returns a float32 tensor of shape (batch, out_channels, H, W).
        """
        n_batch = x.shape[0]
        height, width = x.size(-2), x.size(-1)
        m1, m2 = self.modes1, self.modes2

        spectrum = torch.fft.rfft2(x)

        # Start from an all-zero half-spectrum and fill in the two retained corners
        mixed = torch.zeros(
            n_batch, self.out_channels, height, width // 2 + 1,
            dtype=torch.cfloat, device=x.device,
        )
        mixed[:, :, :m1, :m2] = self.compl_mul2d(spectrum[:, :, :m1, :m2], self.weights1)
        mixed[:, :, -m1:, :m2] = self.compl_mul2d(spectrum[:, :, -m1:, :m2], self.weights2)

        # Back to physical space at the original spatial size
        restored = torch.fft.irfft2(mixed, s=(height, width))
        return restored.to(torch.float32)

# 2D FNO Model: input structure output Ex field

class _2DFNO(nn.Module):
    """2D Fourier Neural Operator: one-channel structure map -> (2, 11, H, W) field tensor.

    Pipeline: pointwise lift (``fc0``) -> ``FNO_layer_num`` Fourier blocks, each
    the sum of a SpectralConv2d and a 1x1 Conv2d, with GELU after every block
    except the last -> pointwise projection head (``fc1_ori``, ``fc2_ori``) to
    22 channels, reshaped to (2, 11).

    Args:
        modes_x: Fourier modes kept along the first spatial axis.
        modes_y: Fourier modes kept along the second spatial axis.
        width: channel count inside the Fourier blocks.
        FNO_layer_num: number of Fourier blocks.
        FC_neuron: hidden size of the projection head.
    """

    def __init__(self, modes_x, modes_y, width, FNO_layer_num, FC_neuron):
        super(_2DFNO, self).__init__()

        # default values
        self.modes_x = modes_x
        self.modes_y = modes_y
        self.width = width
        self.FNO_layer_num = FNO_layer_num
        self.FC_neuron = FC_neuron

        # One spectral convolution and one 1x1 convolution per Fourier block
        self.conv_layers = nn.ModuleList([
            SpectralConv2d(self.width, self.width, self.modes_x, self.modes_y) for _ in range(self.FNO_layer_num)
        ])

        self.w_layers = nn.ModuleList([
            nn.Conv2d(self.width, self.width, 1) for _ in range(self.FNO_layer_num)
        ])
        # Lifts the single input channel to `width` channels
        self.fc0 = nn.Linear(1, self.width)

        # Projection head; 22 output channels are reshaped to (2, 11) in forward
        self.fc1_ori = nn.Linear(self.width, self.FC_neuron)
        self.fc2_ori = nn.Linear(self.FC_neuron, 22)

    def forward(self, x):
        """Map ``x`` of shape (batch, 1, H, W) to a tensor of shape (batch, 2, 11, H, W)."""
        # nn.Linear acts on the last axis, so move channels last for the lift
        x = x.permute(0, 2, 3, 1)
        x = self.fc0(x)
        x = F.gelu(x)
        x = x.permute(0, 3, 1, 2)  # Batch size, channels, dims

        # Every block uses its own layer pair. Pairing the layers explicitly
        # (instead of reusing a leftover loop index for the final block) makes
        # sure the last pair is actually applied and that a single-block model
        # works. No activation follows the last block.
        last = self.FNO_layer_num - 1
        for idx, (spectral, pointwise) in enumerate(zip(self.conv_layers, self.w_layers)):
            x = spectral(x) + pointwise(x)
            if idx != last:
                x = F.gelu(x)

        x = x.permute(0, 2, 3, 1)
        x = self.fc1_ori(x)
        x = F.gelu(x)
        x = self.fc2_ori(x)

        x = x.permute(0, 3, 1, 2)

        # Split the 22 channels into (2, 11)
        return x.reshape(x.shape[0], 2, 11, x.shape[2], x.shape[3])

    def get_grid(self, shape, device):
        """Return a coordinate grid of shape (batch, size_x, size_y, 2) with values in [0, 1].

        ``shape`` is indexed as (batch, size_x, size_y). This helper is not
        called by ``forward``.
        """
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)

    def count_params(self):
        """Return the total number of parameter entries.

        A complex-valued entry counts as one. ``numel()`` also handles 0-dim
        parameters, which a product over an empty size list would not.
        """
        return sum(p.numel() for p in self.parameters())

In [None]:
# Prepare training and test data for Ex

# torch.tensor copies the NumPy buffers into new float32 tensors
train_input, train_output = (
    torch.tensor(array, dtype=torch.float32) for array in (train_structs, train_results)
)
test_input, test_output = (
    torch.tensor(array, dtype=torch.float32) for array in (test_structs, test_results)
)

# Loaders for model_x: training batches are reshuffled every epoch, test batches keep their order
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(train_input, train_output),
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(test_input, test_output),
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Train model

# Per-epoch RMSE history, one entry per completed epoch
Train_rmse_arr = []
Test_rmse_arr = []

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = _2DFNO(modes_x=32, modes_y=15, width=10, FNO_layer_num=4, FC_neuron=128).to(device)
print(model.count_params())

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0)

# NOTE(review): the scheduler is stepped once per epoch, so with step_size equal
# to the number of epochs (both 50 in this notebook) the learning rate is only
# halved after the final epoch. Confirm that this is intended.
step_size = 50
gamma = 0.5
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
total_time = 0

# Release memory left over from earlier runs in the same kernel
gc.collect()
if device.type == "cuda":
    torch.cuda.empty_cache()

loss = nn.MSELoss()

for ep in range(epochs):
    t1 = default_timer()

    # ---- training pass ----
    model.train()
    Train_mse = 0.0
    for input_shape, result in train_loader:
        input_shape, result = input_shape.to(device), result.to(device)
        # Use the real batch size: the last batch may be smaller than batch_size
        n_batch = input_shape.shape[0]
        optimizer.zero_grad()
        out = model(input_shape)
        Train_mse_temp = loss(out.reshape(n_batch, -1), result.reshape(n_batch, -1))
        Train_mse_temp.backward()
        optimizer.step()
        # Weight the batch mean by its sample count so the epoch value is a per-sample mean
        Train_mse += np.float64(Train_mse_temp.item()) * n_batch
    scheduler.step()

    # ---- evaluation pass (no gradients) ----
    model.eval()
    Test_mse = 0.0
    with torch.no_grad():
        for input_shape, result in test_loader:
            input_shape, result = input_shape.to(device), result.to(device)
            n_batch = input_shape.shape[0]
            out = model(input_shape)
            Test_mse_temp = loss(out.reshape(n_batch, -1), result.reshape(n_batch, -1))
            Test_mse += np.float64(Test_mse_temp.item()) * n_batch

    Train_mse /= n_train
    Test_mse /= n_test

    Train_rmse = np.sqrt(Train_mse)
    Test_rmse = np.sqrt(Test_mse)

    Train_rmse_arr.append(Train_rmse)
    Test_rmse_arr.append(Test_rmse)

    t2 = default_timer()
    # epoch index, epoch wall time in seconds, training RMSE, test RMSE
    print(ep, t2-t1, Train_rmse, Test_rmse)
    total_time = total_time + t2 - t1

In [None]:
# Number of epochs actually recorded. Kept under its own name so the `epochs`
# hyperparameter is not overwritten (e.g. after an interrupted training run).
n_epochs_run = len(Train_rmse_arr)
epoch_nums = np.arange(1, n_epochs_run + 1)

# Plotting the learning curve for training and test RMSE
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epoch_nums, Train_rmse_arr, label="Training RMSE", color="blue", marker="o")
ax.plot(epoch_nums, Test_rmse_arr, label="Test RMSE", color="red", marker="x")

ax.set_xlabel("Epochs")
ax.set_ylabel("RMSE")
ax.set_title("Learning Curve: 400nm")
ax.legend()

ax.grid(True)
plt.show()

In [None]:
# Generate a random index from the test set
n = np.random.randint(0, n_test)


# Prepare the sample input and pass it through the model.
# Inference runs under no_grad so no autograd graph is built, and the sample is
# moved to whichever device the model parameters live on.
tmp_sample = torch.tensor(np.reshape(test_structs[n], (1, 1, 225, -1)))
with torch.no_grad():
    model_device = next(model.parameters()).device
    sample_result = model(tmp_sample.to(model_device)).cpu().numpy()
sample_result = np.squeeze(sample_result)  # drop the batch axis -> (2, 11, 225, 64)

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Index along the 11-entry axis of the results
wavelength = 1

# Left plot: FNO output (channel 0 holds the real part)
im1 = axes[0].imshow(sample_result[0][wavelength], aspect=0.4)
axes[0].set_title("FNO Output")
fig.colorbar(im1, ax=axes[0])

# Right plot: True output (real part of the stored result)
im2 = axes[1].imshow(test_results[n][0][wavelength], aspect=0.4)
axes[1].set_title("True Output")
fig.colorbar(im2, ax=axes[1])

# Display the plot
plt.show()