In [None]:
import torch
import torch.nn as nn
import numpy as np
import os
from torch.utils.data import TensorDataset, DataLoader, random_split
import random
import matplotlib.pyplot as plt

In [None]:
class SpectralConv2d(nn.Module):
    """Fourier layer: rfft2, learned channel mixing on a block of retained modes, irfft2.

    Two complex weight tensors of shape (in_channels, out_channels, modes1, modes2)
    are learned: one for the first ``modes1`` rows of the half-spectrum and one for
    the last ``modes1`` rows; both use the first ``modes2`` columns. Every other
    coefficient of the output spectrum stays zero.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.modes1, self.modes2 = modes1, modes2

        # Weights are drawn with torch.rand and shrunk by 1 / (in_channels * out_channels).
        self.scale = 1 / (in_channels * out_channels)
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2)
        self.weights1 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*weight_shape, dtype=torch.cfloat))

    def compl_mul2d(self, input, weights):
        """Contract the input-channel axis: (b, i, x, y) with (i, o, x, y) -> (b, o, x, y)."""
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        """Apply the spectral convolution over the last two axes of ``x``.

        The result has ``out_channels`` channels and the same trailing two sizes as ``x``.
        """
        n_batch = x.shape[0]
        rows, cols = x.size(-2), x.size(-1)
        m1, m2 = self.modes1, self.modes2

        spectrum = torch.fft.rfft2(x)

        # Output half-spectrum, zero everywhere except the two blocks filled below.
        mixed = torch.zeros(
            n_batch, self.out_channels, rows, cols // 2 + 1,
            dtype=torch.cfloat, device=x.device,
        )

        every = slice(None)
        leading_rows = (every, every, slice(None, m1), slice(None, m2))
        trailing_rows = (every, every, slice(-m1, None), slice(None, m2))

        # Written in this order so the trailing block wins wherever the two overlap.
        mixed[leading_rows] = self.compl_mul2d(spectrum[leading_rows], self.weights1)
        mixed[trailing_rows] = self.compl_mul2d(spectrum[trailing_rows], self.weights2)

        return torch.fft.irfft2(mixed, s=(rows, cols))


class FNO2d(nn.Module):
    """2-D Fourier Neural Operator built from a pointwise lift, Fourier layers and a pointwise head.

    Args:
        fno_architecture: mapping with the keys
            ``"modes"`` (used for both spectral axes), ``"width"`` (channel count of the
            hidden representation), ``"n_layers"`` (number of Fourier layers) and
            ``"retrain_fno"`` (value passed to ``torch.manual_seed``).
        device: optional device the module is moved to after construction.
        padding_frac: fraction of each spatial size that is zero-padded on the trailing
            side before the Fourier layers and cropped away afterwards.

    Note:
        Construction calls ``torch.manual_seed``, so it reseeds torch's global RNG.
    """

    def __init__(self, fno_architecture, device=None, padding_frac=1 / 4):
        super(FNO2d, self).__init__()

        self.modes1 = fno_architecture["modes"]
        self.modes2 = fno_architecture["modes"]
        self.width = fno_architecture["width"]
        self.n_layers = fno_architecture["n_layers"]
        self.retrain_fno = fno_architecture["retrain_fno"]

        # Seed before any layer is created so the initial weights depend only on retrain_fno.
        torch.manual_seed(self.retrain_fno)
        self.padding_frac = padding_frac
        # Pointwise lift from the 3 input features to `width` channels.
        self.fc0 = nn.Linear(3, self.width)

        # Each Fourier layer sums a 1x1 convolution and a spectral convolution.
        self.conv_list = nn.ModuleList(
            [nn.Conv2d(self.width, self.width, 1) for _ in range(self.n_layers)]
        )
        self.spectral_list = nn.ModuleList(
            [
                SpectralConv2d(self.width, self.width, self.modes1, self.modes2)
                for _ in range(self.n_layers)
            ]
        )

        # Pointwise projection head: width -> 128 -> 1.
        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

        if device is not None:
            self.to(device)

    def forward(self, x):
        """Map a channels-last field of shape (batch, H, W, 3) to (batch, H, W, 1)."""
        x = self.fc0(x)
        x = x.permute(0, 3, 1, 2)  # channels-last -> channels-first for the conv layers

        # Remember the unpadded extent so the crop below restores exactly this shape.
        height, width = x.shape[-2], x.shape[-1]
        x1_padding = int(round(width * self.padding_frac))
        x2_padding = int(round(height * self.padding_frac))
        # F.pad lists the last axis first: (left, right) for W, then (top, bottom) for H.
        x = nn.functional.pad(x, [0, x1_padding, 0, x2_padding])

        for k, (s, c) in enumerate(zip(self.spectral_list, self.conv_list)):
            x1 = s(x)
            x2 = c(x)
            x = x1 + x2
            # No activation after the final Fourier layer.
            if k != self.n_layers - 1:
                x = nn.functional.gelu(x)

        # Crop by the original extent rather than by negative padding offsets: this keeps
        # the H/W axes matched to their own padding and still works when a padding is 0
        # (a `:-0` slice would be empty).
        x = x[..., :height, :width]

        x = x.permute(0, 2, 3, 1)
        x = self.fc1(x)
        x = nn.functional.gelu(x)
        x = self.fc2(x)
        return x



In [None]:
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters read by FNO2d: "modes" is applied to both spectral axes,
# "width" is the hidden channel count, "n_layers" the number of Fourier layers,
# and "retrain_fno" the value handed to torch.manual_seed.
fno_architecture = dict(modes=32, width=64, n_layers=12, retrain_fno=42)

# Build the model and place it on the chosen device.
fno = FNO2d(fno_architecture, device=device)

In [None]:
# Checkpoint location under the Google Drive mount point (/content/drive).
model_filename = 'fno_epoch_1000.pth'
model_save_dir = '/content/drive/My Drive/Colab Notebooks/saved_models'
model_path = os.path.join(model_save_dir, model_filename)

In [None]:
# Mount Google Drive at /content/drive so the Drive paths used in this notebook resolve.
# The google.colab package is only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Load the checkpoint at model_path into fno, remapping its stored tensors onto `device`.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
fno.load_state_dict(torch.load(model_path, map_location=device))
# Put the module in evaluation mode before running inference.
fno.eval()

In [None]:
# Load the test input/output tensors (named so by their files) from Drive.
# NOTE(review): unlike the checkpoint load, no map_location is passed here, so tensors
# that were saved from a GPU would need CUDA at load time — confirm how they were saved.
input_tensor = torch.load('/content/drive/My Drive/Colab Notebooks/input_tensor_test.pt')
output_tensor = torch.load('/content/drive/My Drive/Colab Notebooks/output_tensor_test.pt')

In [None]:
# Evaluate the FNO on every test sample, plot the first N, and report the mean MSE.
dataset = TensorDataset(input_tensor, output_tensor)
# train_size is 0, so random_split places every sample in testing_set.
# No generator is passed, so the split order comes from torch's global RNG.
train_size = int(0 * len(dataset))
test_size = len(dataset) - train_size
training_set, testing_set = random_split(dataset, [train_size, test_size])
test_loader = DataLoader(testing_set, batch_size=1, shuffle=False)
mse_loss = nn.MSELoss()
total_loss = 0.0
N = 7  # number of leading samples to plot
plot_data = {}  # loader index -> centre-line arrays of the plotted samples

for idx, (input_sample, output_sample) in enumerate(test_loader):

    input_sample = input_sample.to(device)
    output_sample = output_sample.to(device)

    # Use the trained FNO to produce the output data
    with torch.no_grad():
        output_pred = fno(input_sample)
    loss = mse_loss(output_pred, output_sample)
    total_loss += loss.item()

    # Drop the batch axis (batch_size is 1) and move to NumPy for plotting.
    input_sample_np = input_sample.cpu().numpy().squeeze(0)
    output_true_np = output_sample.cpu().numpy().squeeze(0)
    output_pred_np = output_pred.cpu().numpy().squeeze(0)

    # Channel 0 of each field as a 2-D array.
    input_a = input_sample_np[:, :, 0]
    true_sample = output_true_np[:, :, 0]
    pred_sample = output_pred_np[:, :, 0]

    # Plot only the first N samples
    if idx < N:
        # Plot the input a(x,y), true output, and predicted output
        plt.figure(figsize=(18, 5))

        # First Subplot: Input a(x, y)
        plt.subplot(1, 3, 1)
        plt.imshow(input_a, cmap='viridis', origin='lower')
        plt.axis('off')
        plt.colorbar()

        # Second Subplot: True Output
        plt.subplot(1, 3, 2)
        plt.imshow(true_sample, cmap='viridis', origin='lower')
        plt.axis('off')
        plt.colorbar()

        # Third Subplot: Predicted Output
        plt.subplot(1, 3, 3)
        plt.imshow(pred_sample, cmap='viridis', origin='lower')
        plt.axis('off')
        plt.colorbar()

        plt.tight_layout()
        plt.show()

        # Cut through the middle column: the column index must come from axis 1,
        # and the cut itself has one point per row (axis 0).
        center_x = true_sample.shape[1] // 2
        y_values = np.arange(true_sample.shape[0])

        true_center_line = true_sample[:, center_x]
        pred_center_line = pred_sample[:, center_x]

        # Plot the center line values
        plt.figure(figsize=(8, 6))
        index = idx

        # Save the data for the current plot
        plot_data[index] = {
            'y_values': y_values,
            'true_center_line': true_center_line,
            'pred_center_line': pred_center_line
        }
        with plt.rc_context({'font.size': 16}):
            plt.plot(y_values, true_center_line, label='Simulation', color='blue')
            plt.plot(y_values, pred_center_line, label='FNO', color='red', linestyle='--')

            # Set title to the index of the datum
            plt.title(f'Index: {index}')

            plt.legend()
            ax = plt.gca()

            # One data unit spans the same length on both axes.
            # NOTE(review): this is not a square box; confirm 'equal' is the intended look.
            ax.set_aspect('equal', adjustable='box')

            # Tighten the layout
            plt.tight_layout()

            # Remove x and y axis numbers
            ax.set_xticks([])
            ax.set_yticks([])

            # Single show per figure; nothing is drawn after this point in the iteration.
            plt.show()

# Calculate the average MSE loss over all samples
average_loss = total_loss / len(test_loader)
print(f'MSE loss on the test set: {average_loss}')

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Re-run inference over test_loader and, for the first N samples, plot the fields,
# the centre-line cut, and the magnitude spectrum of that cut.

# Start from zero so this cell does not add to a total left by an earlier cell
# and gives the same result when re-run.
total_loss = 0.0

for idx, (input_sample, output_sample) in enumerate(test_loader):
    input_sample = input_sample.to(device)
    output_sample = output_sample.to(device)

    with torch.no_grad():
        output_pred = fno(input_sample)

    loss = mse_loss(output_pred, output_sample)
    total_loss += loss.item()

    # Drop the batch axis (batch_size is 1) and move to NumPy for plotting.
    input_sample_np = input_sample.cpu().numpy().squeeze(0)
    output_true_np = output_sample.cpu().numpy().squeeze(0)
    output_pred_np = output_pred.cpu().numpy().squeeze(0)

    # Channel 0 of each field as a 2-D array.
    input_a = input_sample_np[:, :, 0]
    true_sample = output_true_np[:, :, 0]
    pred_sample = output_pred_np[:, :, 0]

    if idx < N:
        # Side-by-side images: input, reference output, prediction.
        plt.figure(figsize=(18, 5))

        plt.subplot(1, 3, 1)
        plt.imshow(input_a, cmap='viridis', origin='lower')
        plt.axis('off')
        plt.colorbar()
        plt.title('Input a(x, y)')

        plt.subplot(1, 3, 2)
        plt.imshow(true_sample, cmap='viridis', origin='lower')
        plt.axis('off')
        plt.colorbar()
        plt.title('True Output')

        plt.subplot(1, 3, 3)
        plt.imshow(pred_sample, cmap='viridis', origin='lower')
        plt.axis('off')
        plt.colorbar()
        plt.title('Predicted Output')

        plt.tight_layout()
        plt.show()

        # Cut through the middle column (axis 1); one point per row (axis 0).
        center_x = true_sample.shape[1] // 2
        y_values = np.arange(true_sample.shape[0])

        true_center_line = true_sample[:,center_x]
        pred_center_line = pred_sample[:,center_x]

        plt.figure(figsize=(10, 6))
        index = idx

        # Keep the cut so it can be reused after the loop.
        plot_data[index] = {
            'y_values': y_values,
            'true_center_line': true_center_line,
            'pred_center_line': pred_center_line
        }

        with plt.rc_context({'font.size': 16}):
            plt.plot(y_values, true_center_line, label='Simulation', color='blue')
            plt.plot(y_values, pred_center_line, label='FNO', color='red', linestyle='--')
            plt.title(f'Center Line Plot - Index: {index}')
            plt.xlabel('y-coordinate')
            plt.ylabel('Value')
            plt.legend()
            ax = plt.gca()
            ax.set_aspect('auto', adjustable='box')
            plt.tight_layout()
            # Tick marks are hidden; the axis labels stay.
            ax.set_xticks([])
            ax.set_yticks([])
            plt.show()

        # Discrete Fourier transform of both cuts; frequencies are in cycles per sample (d=1).
        fft_true = np.fft.fft(true_center_line)
        fft_pred = np.fft.fft(pred_center_line)
        freqs = np.fft.fftfreq(len(y_values), d=1)
        # Keep only the non-negative frequency bins.
        pos_mask = freqs >= 0
        freqs_pos = freqs[pos_mask]
        fft_true_pos = fft_true[pos_mask]
        fft_pred_pos = fft_pred[pos_mask]

        magnitude_true = np.abs(fft_true_pos)
        magnitude_pred = np.abs(fft_pred_pos)

        # The two spectra in separate panels.
        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 1)
        plt.plot(freqs_pos, magnitude_true, color='blue')
        plt.title('FFT of True Center Line')
        plt.xlabel('Frequency')
        plt.ylabel('Magnitude')
        plt.grid(True)

        plt.subplot(1, 2, 2)
        plt.plot(freqs_pos, magnitude_pred, color='red')
        plt.title('FFT of Predicted Center Line')
        plt.xlabel('Frequency')
        plt.ylabel('Magnitude')
        plt.grid(True)

        plt.tight_layout()
        plt.show()

        # The two spectra overlaid on one axis for direct comparison.
        plt.figure(figsize=(10, 6))
        plt.plot(freqs_pos, magnitude_true, label='Simulation FFT', color='blue')
        plt.plot(freqs_pos, magnitude_pred, label='FNO FFT', color='red', linestyle='--')
        plt.title(f'FFT Comparison - Index: {index}')
        plt.xlabel('Frequency')
        plt.ylabel('Magnitude')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()


In [None]:
# Rebuild the dataset and split it; with train_size == 0 every sample goes to testing_set.
# NOTE(review): random_split is called again here without a generator, so the ordering of
# this testing_set need not match the one produced by the earlier split — confirm intended.
dataset = TensorDataset(input_tensor, output_tensor)
train_size = int(0 * len(dataset))
test_size = len(dataset) - train_size
training_set, testing_set = random_split(dataset, [train_size, test_size])

M = 10  # number of test samples to draw

# Draw M distinct positions in testing_set using Python's global `random` state.
# random.sample raises ValueError when M exceeds len(testing_set).
random_indices = random.sample(range(len(testing_set)), M)

# Parallel lists holding the selected (input, output) tensors.
input_samples = []
output_samples = []

for idx in random_indices:
    input_sample, output_sample = testing_set[idx]
    input_samples.append(input_sample)
    output_samples.append(output_sample)