In [1]:
import time
from pylab import *
import matplotlib.gridspec as gridspec

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.gridspec as gridspec
import numpy as np
import random
import torch.nn.functional as F
import math


In [None]:
# Colab-only setup: this cell needs the google.colab package, so it will not run
# outside a Colab runtime.
from google.colab import drive
drive.mount('/content/gdrive')
# Project folder on the mounted drive; a later cell builds the checkpoint path
# from it (content_path + '/AE_KAN.pth').
content_path = '/content/gdrive/MyDrive/ic_project/new'

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:

#construct background states, observations with error

def x_to_y(X, noise_var=0.0): # summing over 2*2 windows (4 pixels)
    """Down-sample a 2-D field by summing each non-overlapping 2x2 window.

    The output size is derived from the input, so any 2-D array with at least
    two rows and two columns is supported (a trailing odd row/column is
    ignored). Note that the four pixels are summed, not averaged.

    Args:
        X: 2-D array-like field.
        noise_var: variance of the i.i.d. Gaussian observation noise added to
            the result. The default 0.0 adds no noise.

    Returns:
        Float ndarray of shape (X.shape[0] // 2, X.shape[1] // 2).
    """
    X = np.asarray(X)
    rows, cols = X.shape[0] // 2, X.shape[1] // 2

    # Sum of the four pixels of every 2x2 window, computed with strided views.
    Y = (
        X[0:2 * rows:2, 0:2 * cols:2]
        + X[1:2 * rows:2, 0:2 * cols:2]
        + X[0:2 * rows:2, 1:2 * cols:2]
        + X[1:2 * rows:2, 1:2 * cols:2]
    ).astype(float)

    # Noise is drawn once for the whole observation, with a shape that follows Y.
    if noise_var > 0:
        Y = Y + np.random.normal(0.0, np.sqrt(noise_var), size=Y.shape)
    return Y


class shallow(object):
    """Minimal 2-D shallow-water solver on a periodic N x N grid.

    The state is the surface height ``h`` and the velocity components ``u``
    and ``v``. Spatial derivatives are centred differences built with
    ``np.roll`` (hence periodic boundaries) and time stepping is forward Euler.

    NumPy is referenced explicitly through ``np`` so the class does not depend
    on names injected by ``from pylab import *``.
    """

    # Class-level defaults kept for backward compatibility; ``time`` is also
    # initialised per instance in ``__init__``.
    time = 0

    plt = []
    fig = []

    def __init__(self, x=None, y=None, h_ini=1., u=None, v=None, dx=0.01, dt=0.0001, N=64, L=1., px=16, py=16, R=64, Hp=0.1, g=1., b=0.):
        """Set up the grid and the initial height perturbation.

        Args:
            x, y, u, v: accepted for interface compatibility but ignored; the
                grid and the zero velocity fields are always rebuilt here.
            h_ini: background surface height.
            dx: grid spacing used by the finite differences.
            dt: time step of the Euler scheme.
            N: number of grid points per side.
            L: stored as ``self.L``; not used by the solver itself.
            px, py: grid indices of the centre of the perturbation.
            R: threshold on the *squared* index distance to (px, py); cells
                with squared distance below R are raised.
            Hp: height added to ``h_ini`` inside the perturbation.
            g: gravity coefficient of the momentum equations.
            b: linear drag coefficient of the momentum equations.
        """
        # Perturbation of the pressure surface
        self.px, self.py = px, py
        self.R = R
        self.Hp = Hp

        # Physical parameters
        self.g = g
        self.b = b
        self.L = L
        self.N = N

        self.dx = dx
        self.dt = dt
        self.time = 0

        # Integer index grids; x varies along axis 0 and y along axis 1.
        self.x, self.y = np.mgrid[:self.N, :self.N]

        self.u = np.zeros((self.N, self.N))
        self.v = np.zeros((self.N, self.N))

        self.h_ini = h_ini
        self.h = self.h_ini * np.ones((self.N, self.N))

        # Initial condition: raise the surface where the squared distance is < R.
        rr = (self.x - px) ** 2 + (self.y - py) ** 2
        self.h[rr < R] = self.h_ini + Hp

        # Value ranges for (h, u, v).
        self.lims = [(self.h_ini - self.Hp, self.h_ini + self.Hp), (-0.02, 0.02), (-0.02, 0.02)]

    def dxy(self, A, axis=0):
        """
        Compute derivative of array A using balanced finite differences
        Axis specifies direction of spatial derivative (d/dx or d/dy)
        dA[i]/dx =  (A[i+1] - A[i-1] )  / 2dx

        np.roll wraps around, so the boundaries are periodic.
        """
        return (np.roll(A, -1, axis) - np.roll(A, 1, axis)) / (self.dx * 2.)

    def d_dx(self, A):
        """Centred difference of A along array axis 1."""
        # NOTE(review): self.x from np.mgrid varies along axis 0, while this
        # derivative is taken along axis 1 - confirm the intended convention.
        return self.dxy(A, 1)

    def d_dy(self, A):
        """Centred difference of A along array axis 0."""
        return self.dxy(A, 0)

    def d_dt(self, h, u, v):
        """
        Time derivatives (dh/dt, du/dt, dv/dt) of the state.

        http://en.wikipedia.org/wiki/Shallow_water_equations#Non-conservative_form
        """
        for field in [h, u, v]:  # type check
            assert isinstance(field, np.ndarray) and not isinstance(field, np.matrix)

        g, b = self.g, self.b

        du_dt = -g * self.d_dx(h) - b * u
        dv_dt = -g * self.d_dy(h) - b * v

        H = 0  # h.mean() - our definition of h includes this term
        dh_dt = -self.d_dx(u * (H + h)) - self.d_dy(v * (H + h))

        return dh_dt, du_dt, dv_dt

    def evolve(self):
        """
        Evolve state (h, u, v) forward in time using simple Euler method
        x_{N+1} = x_{N} +   dx/dt * d_t

        Returns the instance's own arrays (updated in place); copy them if a
        snapshot must survive later calls.
        """
        dh_dt, du_dt, dv_dt = self.d_dt(self.h, self.u, self.v)
        dt = self.dt

        self.h += dh_dt * dt
        self.u += du_dt * dt
        self.v += dv_dt * dt
        self.time += dt

        return self.h, self.u, self.v





In [None]:

# Generate the dataset
def generate_dataset(num_simulations=500, snapshots_per_simulation=200, N=64, save_every=50):
    """Run randomised shallow-water simulations and collect snapshots.

    For every simulation the perturbation centre, size, height and the drag
    coefficient are drawn with ``random.randint``; the solver is advanced
    ``snapshots_per_simulation`` steps and the state is stored at steps
    0, save_every, 2*save_every, ... (each taken after that step's ``evolve``).

    Args:
        num_simulations: number of independent simulations.
        snapshots_per_simulation: number of solver steps per simulation.
        N: grid size passed to ``shallow``.
        save_every: stride, in solver steps, between stored snapshots.

    Returns:
        dataset: array (num_simulations * saved_per_simulation, 3, N, N) with
            channels ordered (u, v, h).
        first_snapshots: array (num_simulations, 3, N, N) holding the first
            stored snapshot of every simulation.
        paras: array (num_simulations * snapshots_per_simulation, 5) whose row
            k holds (px, py, R, Hp, b) for dataset[k]; rows beyond
            len(dataset) stay zero (allocation size kept from the original).
    """
    # Ceil division: steps 0, save_every, ... < snapshots_per_simulation are all
    # stored, which is one more than the floor when the count is not a multiple.
    saved_per_simulation = (snapshots_per_simulation + save_every - 1) // save_every

    dataset = np.zeros((num_simulations * saved_per_simulation, 3, N, N))
    first_snapshots = np.zeros((num_simulations, 3, N, N))
    paras = np.zeros((num_simulations * snapshots_per_simulation, 5))
    snapshot_idx = 0

    for sim in range(num_simulations):

        px = random.randint(54, 74) * 1.
        py = random.randint(54, 74) * 1.
        R = random.randint(80, 160) * 1.
        Hp = random.randint(5, 20) * 0.01
        b = random.randint(1, 100) * 0.1

        SW = shallow(N=N, px=px, py=py, R=R, Hp=Hp, b=b)

        for step in range(snapshots_per_simulation):
            SW.evolve()
            if step % save_every != 0:
                continue

            # Slice assignment copies the solver arrays, so later in-place
            # updates inside SW do not alter stored snapshots.
            dataset[snapshot_idx, 0, :, :] = SW.u
            dataset[snapshot_idx, 1, :, :] = SW.v
            dataset[snapshot_idx, 2, :, :] = SW.h
            paras[snapshot_idx] = (px, py, R, Hp, b)

            # Save the first snapshot of the simulation. The previous
            # condition (step == snapshots_per_simulation-1, nested under the
            # stride test) almost never fired, leaving this array all zeros.
            if step == 0:
                first_snapshots[sim] = dataset[snapshot_idx]

            snapshot_idx += 1
    return dataset,first_snapshots,paras




def min_max_normalize(dataset):
    """Min-max scale the three channels of ``dataset`` to [0, 1], in place.

    Each channel (axis 1, ordered u, v, h) is scaled independently with its
    own global minimum and maximum. A channel that is constant everywhere is
    set to 0 instead of producing NaN from a 0/0 division.

    Args:
        dataset: float array of shape (n, 3, H, W). It is modified in place.

    Returns:
        (dataset, [u_min, u_max, v_min, v_max, h_min, h_max]) - the same array
        object and the statistics needed to undo the scaling.
    """
    stats = []
    for channel in range(3):
        field = dataset[:, channel, :, :]
        lo, hi = field.min(), field.max()
        span = hi - lo
        if span == 0:
            # Constant channel: avoid dividing by zero.
            dataset[:, channel, :, :] = 0.0
        else:
            dataset[:, channel, :, :] = (field - lo) / span
        stats.extend([lo, hi])

    return dataset, stats

# Generate the dataset and min-max normalise it (nothing is written to disk here).
# 100 simulations of 5000 solver steps on a 128x128 grid; generate_dataset stores
# one snapshot every 50 steps.
dataset,first_snapshots,paras = generate_dataset(num_simulations=100, snapshots_per_simulation=5000, N=128)
# normalization_params holds [u_min, u_max, v_min, v_max, h_min, h_max].
dataset, normalization_params= min_max_normalize(dataset)


Construct 5-to-5 data: each sample maps 5 input snapshots to the 5 snapshots that follow them.

In [None]:


# Number of simulations and stored snapshots per simulation. These must match
# the layout of `dataset`: rows are grouped per simulation, in time order.
num_simulations = 100
snapshots_per_simulation = 100
slice_steps = 5   # frames per input window and per target window
interval = 3      # stride, in stored snapshots, between consecutive frames
N = 128

# Fail early if the dataset does not have the layout assumed by the indexing below.
assert dataset.shape[0] == num_simulations * snapshots_per_simulation, (
    "dataset rows do not match num_simulations * snapshots_per_simulation"
)

# Initialize arrays for slices and their corresponding targets.
# A window pair spans (2 * slice_steps - 1) * interval snapshots after its start.
num_slices_per_simulation = snapshots_per_simulation - (2 * slice_steps-1) * interval
total_slices = num_simulations * num_slices_per_simulation

slices = np.zeros((total_slices, slice_steps, 3, N, N))
targets = np.zeros((total_slices, slice_steps, 3, N, N))
slice_paras = np.zeros((total_slices, slice_steps,5))

# Generate slices and targets
slice_idx = 0
for sim in range(num_simulations):
    sim_start_idx = sim * snapshots_per_simulation
    for start in range(num_slices_per_simulation):
        input_indices = [sim_start_idx + start + i * interval for i in range(slice_steps)]
        # Targets are the slice_steps frames that directly follow the input window.
        target_indices = [sim_start_idx + start + slice_steps * interval + i * interval for i in range(slice_steps)]

        # Ensure indices are within the simulation range
        if target_indices[-1] < (sim + 1) * snapshots_per_simulation:
            slices[slice_idx] = dataset[input_indices]
            targets[slice_idx] = dataset[target_indices]
            slice_paras[slice_idx] = paras[input_indices]
            slice_idx += 1

# Remove unused portion of arrays if any (all three arrays are kept aligned)
slices = slices[:slice_idx]
targets = targets[:slice_idx]
slice_paras = slice_paras[:slice_idx]

# Keep the un-encoded target frames of the test split for image-space metrics.
train_ratio = 0.8
num_samples = targets.shape[0]
num_train = int(num_samples * train_ratio)
targets_test_oriimg = torch.tensor(targets[num_train:], dtype=torch.float32).to(device)



In [7]:
class KANLinear(torch.nn.Module):
    """Linear-like layer combining a base branch and a B-spline branch.

    The forward pass returns ``base_output + spline_output`` where the base
    branch is ``F.linear(base_activation(x), base_weight)`` and the spline
    branch applies learnable coefficients to B-spline bases evaluated
    independently on every input feature.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        """Create the knot grid and the (initially unfilled) parameters.

        Args:
            in_features: size of the last input dimension.
            out_features: size of the last output dimension.
            grid_size: number of grid intervals inside ``grid_range``.
            spline_order: order k of the B-splines; the grid is extended by k
                knots on each side.
            scale_noise: amplitude factor of the random spline initialisation.
            scale_base: factor used in the init of ``base_weight``.
            scale_spline: factor used in the init of the spline branch.
            enable_standalone_scale_spline: if True, a separate learnable
                ``spline_scaler`` of shape (out_features, in_features)
                multiplies the spline coefficients.
            base_activation: activation *class*; it is instantiated here.
            grid_eps: blend weight between uniform and data-adaptive grids in
                ``update_grid``.
            grid_range: [low, high] range covered by the initial grid (read
                only, never mutated).
        """
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knots with spacing h, extended by spline_order knots on each
        # side and replicated for every input feature:
        # shape (in_features, grid_size + 2 * spline_order + 1).
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        # Buffer: stored with the module state but not a trainable parameter.
        self.register_buffer("grid", grid)

        # Parameters are allocated uninitialised and filled by reset_parameters().
        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise base_weight, spline_weight and (if enabled) spline_scaler."""
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Small uniform noise in [-0.5, 0.5) * scale_noise / grid_size,
            # one value per (grid point, input feature, output feature).
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            # Fit spline coefficients (least squares) so the spline matches the
            # noise at the grid_size + 1 non-extended grid points.
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Raise the order one step at a time; each step combines two
        # neighbouring lower-order bases and shortens the last axis by one.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        The coefficients are obtained per input feature as the least-squares
        solution of ``b_splines(x) @ coeff = y``.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline coefficients, multiplied by ``spline_scaler`` when it is enabled."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Apply the layer to the last dimension of ``x``.

        Leading dimensions are flattened for the computation and restored
        afterwards, so the output has shape ``x.shape[:-1] + (out_features,)``.
        """
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        # Bases and coefficients are flattened over (in_features, coeff) so the
        # spline branch is evaluated as a single matrix product.
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Re-fit the knot grid to the distribution of ``x`` (no gradients).

        The current spline outputs on ``x`` are recorded, a new grid is built
        from the data, and the spline coefficients are re-fitted by least
        squares against the recorded outputs on the new grid.

        Args:
            x: tensor of shape (batch, in_features).
            margin: padding added on both sides of the data range for the
                uniform part of the grid.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        # Adaptive grid: grid_size + 1 evenly spaced order statistics per feature.
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        # Uniform grid covering [min - margin, max + margin] per feature.
        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        # Blend the two grids, then extend by spline_order knots on each side.
        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        # NOTE(review): the targets above were computed with scaled_spline_weight,
        # but the re-fitted coefficients are written to the unscaled
        # spline_weight - confirm this is intended when
        # enable_standalone_scale_spline is True.
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        # Normalise to a distribution over (out, in) pairs and take its entropy.
        # NOTE(review): an entry of p that is exactly 0 makes p * p.log() nan.
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    """Feed-forward stack of ``KANLinear`` layers.

    ``layers_hidden`` lists the layer widths; one ``KANLinear`` is created for
    every consecutive pair, all sharing the same spline hyper-parameters.
    """

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super().__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Hyper-parameters forwarded unchanged to every layer.
        shared_kwargs = dict(
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        self.layers = torch.nn.ModuleList(
            [
                KANLinear(width_in, width_out, **shared_kwargs)
                for width_in, width_out in zip(layers_hidden, layers_hidden[1:])
            ]
        )

    def forward(self, x: torch.Tensor, update_grid=False):
        """Apply the layers in order.

        When ``update_grid`` is true, each layer first refits its grid to the
        activations it is about to receive.
        """
        activations = x
        for kan_layer in self.layers:
            if update_grid:
                kan_layer.update_grid(activations)
            activations = kan_layer(activations)
        return activations

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Sum of the per-layer regularization losses."""
        total = 0
        for kan_layer in self.layers:
            total = total + kan_layer.regularization_loss(
                regularize_activation, regularize_entropy
            )
        return total

Define architecture of CAE-KAN

In [8]:
class PowerfulAutoencoder(nn.Module):
    """Convolutional autoencoder with a KANLinear bottleneck projection.

    Encoder: four stride-2 convolutions (3 -> 32 -> 64 -> 128 -> 256 channels),
    flatten, then ``KANLinear(8*8*256, latent_dim)``. The 8*8*256 bottleneck
    implies 128x128 inputs (128 -> 64 -> 32 -> 16 -> 8).
    Decoder: ``nn.Linear`` back to 8*8*256, unflatten to (256, 8, 8), four
    stride-2 transposed convolutions down to 3 channels, final sigmoid.

    The module order inside both ``nn.Sequential`` containers is unchanged, so
    ``state_dict`` keys stay compatible with saved checkpoints.
    """

    def __init__(self, latent_dim):
        super().__init__()

        bottleneck_features = 8*8*256
        conv_kwargs = dict(kernel_size=4, stride=2, padding=1)

        # Encoder: each conv halves the spatial size.
        encoder_layers = []
        for c_in, c_out in ((3, 32), (32, 64), (64, 128), (128, 256)):
            encoder_layers.append(nn.Conv2d(c_in, c_out, **conv_kwargs))
            encoder_layers.append(nn.ReLU())
        encoder_layers.append(nn.Flatten())
        # KANLinear is used here in place of nn.Linear(8*8*256, latent_dim).
        encoder_layers.append(KANLinear(bottleneck_features, latent_dim))
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: mirror of the encoder; each transposed conv doubles the size.
        decoder_layers = [
            nn.Linear(latent_dim, bottleneck_features),
            nn.Unflatten(1, (256, 8, 8)),
        ]
        for c_in, c_out in ((256, 128), (128, 64), (64, 32)):
            decoder_layers.append(nn.ConvTranspose2d(c_in, c_out, **conv_kwargs))
            decoder_layers.append(nn.ReLU())
        decoder_layers.append(nn.ConvTranspose2d(32, 3, **conv_kwargs))
        # Sigmoid keeps outputs in [0, 1], matching min-max normalised targets.
        decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the latent vector and decode it back to an image."""
        return self.decoder(self.encoder(x))


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

model_path = content_path + '/AE_KAN.pth'
latent_dim = 16

# Initialize the model and load the trained parameters.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only runtime.
Autoencoder = PowerfulAutoencoder(latent_dim=latent_dim).to(device)
Autoencoder.load_state_dict(torch.load(model_path, map_location=device))
Autoencoder.eval()  # Set the model to evaluation mode

slices_tensor = torch.tensor(slices, dtype=torch.float32).to(device)
targets_tensor = torch.tensor(targets, dtype=torch.float32).to(device)
# as_tensor accepts both the original ndarray and an already converted tensor,
# so re-running this cell does not trigger a copy-construct warning.
slice_paras = torch.as_tensor(slice_paras, dtype=torch.float32).to(device)
# Print the shape only; dumping the full tensor floods the notebook output.
print("slices_tensor", slices_tensor.shape)

# Get the number of slices
num_slices = slices_tensor.shape[0]

# Initialize arrays to store latent vectors
slices_latent = torch.zeros((num_slices, slice_steps, latent_dim)).to(device)
targets_latent = torch.zeros((num_slices, slice_steps, latent_dim)).to(device)

with torch.no_grad():
    for i in range(num_slices):
        # Encode the slice_steps frames of one window as a single batch:
        # (slice_steps, 3, N, N) -> (slice_steps, latent_dim).
        slices_latent[i] = Autoencoder.encoder(slices_tensor[i])
        targets_latent[i] = Autoencoder.encoder(targets_tensor[i])




print("slices_latent",slices_latent.shape)
print("targetss_latent",targets_latent.shape)

Define Surrogate Models

In [11]:
import torch
import torch.nn as nn

class MLP(nn.Module):
    """Fully connected network described by a list of layer widths.

    ``input_dims[k]`` and ``input_dims[k + 1]`` are the input and output sizes
    of the k-th ``nn.Linear``. ``base_activation`` (a module instance, shared
    between positions) follows every linear layer except the last one.
    """

    def __init__(self, input_dims,base_activation = nn.ReLU()):
        super().__init__()
        last_index = len(input_dims) - 2
        modules = []
        for index, (width_in, width_out) in enumerate(zip(input_dims[:-1], input_dims[1:])):
            modules.append(nn.Linear(width_in, width_out))
            # No activation function on the last layer
            if index != last_index:
                modules.append(base_activation)
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the stacked layers."""
        return self.network(x)


#LSTM for 16
class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear head applied at every time step.

    Input is (batch, seq_len, input_dim) because ``batch_first=True``; the
    output is (batch, seq_len, output_dim).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, dropout = 0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the sequence from zero hidden and cell states and project every step."""
        state_shape = (self.num_layers, x.size(0), self.hidden_dim)
        # Zero initial states with no gradient history.
        hidden_init = torch.zeros(state_shape, device=x.device)
        cell_init = torch.zeros(state_shape, device=x.device)

        hidden_sequence, _ = self.lstm(x, (hidden_init, cell_init))
        # The head is applied to all time steps, not only the final state.
        return self.fc(hidden_sequence)



# GRU for 16
class GRUModel(nn.Module):
    """Stacked GRU followed by a linear head applied at every time step.

    Input is (batch, seq_len, input_dim) because ``batch_first=True``; the
    output is (batch, seq_len, output_dim).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the sequence from a zero hidden state and project every step."""
        # Zero initial hidden state with no gradient history.
        hidden_init = torch.zeros(
            (self.num_layers, x.size(0), self.hidden_dim), device=x.device
        )
        hidden_sequence, _ = self.gru(x, hidden_init)
        return self.fc(hidden_sequence)









In [12]:
import torch

class EarlyStopping:
    """Track a validation loss, checkpoint on improvement, flag when to stop.

    Calling the instance with ``(val_loss, model)`` saves the model's
    ``state_dict`` to ``path`` whenever the loss improves. A loss that is
    greater than ``best_loss - min_delta`` counts as "no improvement";
    ``early_stop`` becomes True after ``patience`` such calls without an
    improvement in between.
    """

    def __init__(self, patience=5, min_delta=0, path='checkpoint.pth'):
        self.path = path
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        # First observation: it is the best by definition.
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
            return

        threshold = self.best_loss - self.min_delta
        if val_loss > threshold:
            # No sufficient improvement: count it and stop once patience runs out.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # Improvement: remember it, reset the counter and checkpoint.
        self.best_loss = val_loss
        self.counter = 0
        self.save_checkpoint(model)

    def save_checkpoint(self, model):
        """Write the model's ``state_dict`` to ``self.path``."""
        torch.save(model.state_dict(), self.path)

Define the training process

In [None]:
# From here on the surrogate models work on latent sequences, so
# slices_tensor / targets_tensor are rebound to the encoded windows.
# detach().clone() copies an existing tensor without the copy-construct
# warning that torch.tensor(<tensor>) raises.
slices_tensor = slices_latent.detach().clone().to(device=device, dtype=torch.float32)
targets_tensor = targets_latent.detach().clone().to(device=device, dtype=torch.float32)


# Seed every RNG in use so shuffling and weight initialisation are repeatable.
seed = 5
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# Split in order: the first 80% of the windows train, the rest test.
train_ratio = 0.8
num_samples = slices_tensor.shape[0]
num_train = int(num_samples * train_ratio)
num_test = num_samples - num_train

train_slices = slices_tensor[:num_train]
train_paras = slice_paras[:num_train]
train_targets = targets_tensor[:num_train]

test_slices = slices_tensor[num_train:]
test_targets = targets_tensor[num_train:]

# Training batches are (inputs, targets, parameters); test batches are
# (inputs, targets).
train_dataset = torch.utils.data.TensorDataset(train_slices, train_targets,train_paras)
test_dataset = torch.utils.data.TensorDataset(test_slices, test_targets)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

# The test order must stay fixed: the metric functions index
# targets_test_oriimg by position.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)




def train(model,RNN_flag, train_loader,test_loader,early_stopping, num_epochs=30):
    """Train a latent-space surrogate model with MSE loss.

    Uses AdamW (lr=1e-3, weight_decay=1e-4) with a ReduceLROnPlateau schedule
    driven by the loss on ``test_loader``; the same loss is passed to
    ``early_stopping`` after each epoch.

    Args:
        model: network to train.
        RNN_flag: truthy for sequence models fed (batch, steps, latent)
            tensors directly; falsy for models (MLP/KAN) that take the window
            flattened to (batch, steps * latent).
        train_loader: yields batches whose first two items are
            (inputs, targets); extra items (e.g. parameters) are ignored.
        test_loader: yields (inputs, targets) batches.
        early_stopping: callable ``early_stopping(loss, model)`` exposing
            ``early_stop`` and ``best_loss`` attributes.
        num_epochs: maximum number of epochs.

    Returns:
        (train_losses, test_losses): per-epoch mean batch losses, one entry
        for every epoch that actually ran.
    """
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # Halve the learning rate after 5 epochs without improvement.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

    def predict(batch_inputs):
        # Shared by the training and evaluation loops so both use the same path.
        if RNN_flag:
            return model(batch_inputs)
        flat_inputs = batch_inputs.reshape(batch_inputs.shape[0], -1)
        return model(flat_inputs).reshape(batch_inputs.shape)

    train_losses = []
    test_losses = []

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for data in train_loader:
            inputs, targets = data[0], data[1]
            optimizer.zero_grad()

            loss = criterion(predict(inputs), targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)
        train_losses.append(train_loss)

        model.eval()
        test_loss = 0
        with torch.no_grad():
            for data in test_loader:
                inputs, targets = data[0], data[1]
                test_loss += criterion(predict(inputs), targets).item()

        test_loss /= len(test_loader)
        test_losses.append(test_loss)
        scheduler.step(test_loss)

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.6f}, Test Loss: {test_loss:.6f}')
        #Check early stopping condition
        early_stopping(test_loss,model)
        if early_stopping.early_stop:

            print("Early stopping")
            print("best_loss",early_stopping.best_loss)
            break

    print('Training complete')

    # The loss histories are returned so the caller can plot or log them.
    return train_losses, test_losses

In [14]:
import torch
import torch.nn.functional as F
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
import numpy as np

# Define the function to calculate RRMSE
def rrmse(img1, img2):
    """Relative RMSE: RMS of (img1 - img2) divided by the RMS of img1."""
    residual_rms = np.sqrt(np.mean((img1 - img2) ** 2))
    reference_rms = np.sqrt(np.mean(img1 ** 2))
    return residual_rms / reference_rms

# def compute_ssim(img1, img2):
#     return ssim(img1, img2, data_range=img2.max() - img2.min(), multichannel=True)
def compute_ssim(img1, img2):
    """Structural similarity of two images whose channels lie on axis 2.

    The data range passed to ``ssim`` is the value span of ``img2``.
    """
    value_span = img2.max() - img2.min()
    return ssim(img1, img2, data_range=value_span, channel_axis=2)

def compute_psnr(img1, img2):
    """Peak signal-to-noise ratio in dB for a peak value of 1.0.

    Returns ``inf`` when the two images are identical (zero MSE).
    """
    max_pixel = 1.0
    mse = np.mean((img1 - img2) ** 2)
    return float('inf') if mse == 0 else 20 * np.log10(max_pixel / np.sqrt(mse))

def plot_images(output_img, error_img):
    """Show the output image and the error map in two separate figures.

    Both panels use the 'viridis' colormap, hide the axes and carry their own
    colorbar; the output image is drawn first, then the error heatmap.
    """
    for panel in (output_img, error_img):
        plt.figure(figsize=(6, 6))
        mappable = plt.imshow(panel, cmap='viridis')
        plt.title('')
        plt.axis('off')
        plt.colorbar(mappable)  # Colorbar for this panel
        plt.show()

def gaim_metrics(test_loader, Autoencoder, model, RNN_flag, plot_index=(100, 3)):
    """Evaluate one-shot surrogate predictions in image space.

    Each test window is passed through ``model`` once, the predicted latents
    are decoded with ``Autoencoder.decoder`` and compared with the un-encoded
    target frames in the global ``targets_test_oriimg``. Relies on the globals
    ``slice_steps``, ``N`` and ``device`` as well; ``test_loader`` must not
    shuffle, because targets are matched by position.

    Args:
        test_loader: yields (inputs, targets) batches of latent windows.
        Autoencoder: trained autoencoder providing ``decoder``.
        model: surrogate model operating on latent windows.
        RNN_flag: truthy for sequence models fed (batch, steps, latent);
            falsy for models that take the flattened window.
        plot_index: (sample, step) pair whose channel-0 prediction and
            absolute error are plotted.

    Returns:
        (rrmse_values, ssim_values, mse_values, psnr_values, decoded_outputs):
        four lists with one entry per (sample, step) and the decoded tensor of
        shape (num_samples, slice_steps, 3, N, N).
    """
    # Ensure the models are in evaluation mode
    model.eval()
    print(model)
    Autoencoder.eval()

    decoded_outputs = torch.zeros((len(test_loader.dataset), slice_steps, 3, N, N)).to(device)

    with torch.no_grad():
        global_idx = 0
        for inputs, _ in test_loader:
            batch_size = inputs.size(0)
            if RNN_flag:
                outputs_encoded = model(inputs)
            else:
                flatten_inputs = inputs.reshape(batch_size, -1)
                outputs_encoded = model(flatten_inputs).reshape(inputs.shape)

            # Decode every latent of the batch in one decoder call instead of
            # one call per (sample, step).
            latents = outputs_encoded[:, :slice_steps].reshape(batch_size * slice_steps, -1)
            decoded = Autoencoder.decoder(latents)
            decoded_outputs[global_idx:global_idx + batch_size] = decoded.reshape(
                batch_size, slice_steps, *decoded.shape[1:]
            )

            global_idx += batch_size

    # Loop through the decoded outputs and calculate the metrics
    rrmse_values = []
    ssim_values = []
    mse_values = []
    psnr_values = []
    for i in range(len(decoded_outputs)):
        for j in range(slice_steps):
            # (3, N, N) -> (N, N, 3): compute_ssim expects channels on axis 2.
            target_img = targets_test_oriimg[i, j].cpu().numpy().transpose((1, 2, 0))
            output_img = decoded_outputs[i, j].cpu().numpy().transpose((1, 2, 0))

            # Compute the Mean Squared Error
            mse = np.mean((target_img.flatten() - output_img.flatten()) ** 2)
            mse_values.append(mse)

            rrmse_values.append(rrmse(target_img, output_img).item())
            ssim_values.append(compute_ssim(target_img, output_img))
            psnr_values.append(compute_psnr(target_img, output_img))

            if (i, j) == tuple(plot_index):
                # Plot channel 0 of the prediction and its absolute error.
                output_img_channel_0 = output_img[:, :, 0]
                error_img = np.abs(target_img[:, :, 0] - output_img_channel_0)
                plot_images(output_img_channel_0, error_img)
    return rrmse_values,ssim_values,mse_values,psnr_values,decoded_outputs






Roll out

In [15]:

def rollout_metrics(test_loader, Autoencoder, model, RNN_flag):
    """Evaluate ``model`` in autoregressive roll-out mode and score the decoded frames.

    For every batch of ``test_loader`` only the first sample is used as a seed.
    The model is then called once per sample in the batch, each call receiving
    the previous call's output as its input. The predicted latent sequences are
    decoded with ``Autoencoder.decoder`` and compared with the notebook-level
    ``targets_test_oriimg`` tensor.

    Relies on the notebook globals ``slice_steps``, ``N``, ``device``,
    ``targets_test_oriimg``, ``rrmse``, ``compute_ssim`` and ``compute_psnr``.

    Args:
        test_loader: DataLoader yielding ``(inputs, targets)`` pairs.
        Autoencoder: module exposing a ``decoder`` used to map latents to images.
        model: the sequence predictor to roll out.
        RNN_flag: ``False`` for models fed a flattened sequence, ``True`` for
            models fed the sequence as-is.

    Returns:
        ``(rrmse_values, ssim_values, mse_values, psnr_values)``: four flat lists
        holding one entry per (sample, step) pair, in sample-major order.
    """
    model.eval()
    print(model)
    Autoencoder.eval()

    # Buffer for every decoded prediction of the whole dataset.
    n_samples = len(test_loader.dataset)
    decoded_outputs = torch.zeros((n_samples, slice_steps, 3, N, N)).to(device)

    with torch.no_grad():
        write_offset = 0
        for inputs, targets in test_loader:
            batch_size = inputs.size(0)

            if RNN_flag == False:
                # Flatten the seed sequence into a single feature vector per sample.
                seed = inputs[0].unsqueeze(0)
                outputs_encoded = torch.zeros((batch_size, slice_steps, 16)).to(device)
                model_input = seed.reshape(seed.shape[0], seed.shape[1] * seed.shape[2])
                # Roll out: each prediction becomes the next input.
                for b in range(batch_size):
                    model_out = model(model_input.unsqueeze(0))
                    outputs_encoded[b] = model_out.reshape(1, 5, 16).squeeze(0)
                    model_input = model_out.squeeze(0)

            elif RNN_flag == True:
                model_input = inputs[0]
                outputs_encoded = torch.zeros((batch_size, slice_steps, 16)).to(device)
                # Roll out: each prediction becomes the next input.
                for b in range(batch_size):
                    model_out = model(model_input.unsqueeze(0))
                    rolled = model_out.squeeze(0)
                    outputs_encoded[b] = rolled
                    model_input = rolled

            # Decode the predicted latents one at a time into image space.
            for i in range(batch_size):
                for j in range(slice_steps):
                    latent = outputs_encoded[i, j].unsqueeze(0)
                    decoded_outputs[write_offset + i, j] = Autoencoder.decoder(latent).squeeze(0)

            write_offset += batch_size

    # Score every decoded frame against its target.
    rrmse_values, ssim_values, mse_values, psnr_values = [], [], [], []
    with torch.no_grad():
        for i in range(len(decoded_outputs)):
            for j in range(slice_steps):
                # Channels-first -> channels-last for the image metrics.
                target_img = targets_test_oriimg[i, j].cpu().numpy().transpose((1, 2, 0))
                output_img = decoded_outputs[i, j].cpu().numpy().transpose((1, 2, 0))

                squared_error = (target_img.flatten() - output_img.flatten()) ** 2
                mse_values.append(np.mean(squared_error))
                rrmse_values.append(rrmse(target_img, output_img).item())
                ssim_values.append(compute_ssim(target_img, output_img))
                psnr_values.append(compute_psnr(target_img, output_img))

    return rrmse_values, ssim_values, mse_values, psnr_values







KAN

In [None]:
# Train and evaluate the KAN model on flattened 5 * 16 = 80-value sequences.
input_dim = 5 * 16
output_dim = 5 * 16
input_dims = [input_dim,256,output_dim]  # layer widths: input, hidden, output

kan = KAN(input_dims).to(device)
RNN_flag = False  # feed-forward model: sequences are flattened before the forward pass
print(kan)
# content_path has no trailing slash, so the separator must be part of the file name.
model_save_path = content_path + '/kan.pth'
early_stopping = EarlyStopping(patience=10, min_delta=1e-5, path=model_save_path)
train(kan,RNN_flag, train_loader,test_loader,early_stopping)

# One-shot prediction metrics on the test set.
kan_rrmse,kan_ssim,kan_mse,kan_psnr,kan_decoded = gaim_metrics(test_loader, Autoencoder, kan, RNN_flag)

# Roll-out metrics: unshuffled loader with batches of 73 samples.
test_ro_loader = torch.utils.data.DataLoader(test_dataset, batch_size=73, shuffle=False)
kan_ro_rrmse,kan_ro_ssim,kan_ro_mse,kan_ro_psnr = rollout_metrics(test_ro_loader, Autoencoder, kan, RNN_flag)

# Mean metrics
kan_avg_ssim = np.mean(kan_ssim)
kan_avg_rrmse = np.mean(kan_rrmse)
kan_avg_mse = np.mean(kan_mse)
kan_avg_psnr = np.mean(kan_psnr)
# Mean roll-out metrics
kan_avg_ro_ssim = np.mean(kan_ro_ssim)
kan_avg_ro_rrmse = np.mean(kan_ro_rrmse)
kan_avg_ro_mse = np.mean(kan_ro_mse)
kan_avg_ro_psnr = np.mean(kan_ro_psnr)


print(f'Mean RRMSE: {kan_avg_rrmse:.6f}')
print(f'Mean SSIM: {kan_avg_ssim:.6f}')
print(f'Mean MSE: {kan_avg_mse:.6f}')
print(f'Mean PSNR: {kan_avg_psnr:.6f}')

print(f'Mean RRMSE rollout: {kan_avg_ro_rrmse:.6f}')
print(f'Mean SSIM rollout: {kan_avg_ro_ssim:.6f}')
print(f'Mean MSE rollout: {kan_avg_ro_mse:.6f}')
print(f'Mean PSNR rollout: {kan_avg_ro_psnr:.6f}')


MLP


In [None]:
# Train and evaluate the MLP baseline on flattened 5 * 16 = 80-value sequences.
input_dim = 5 * 16
output_dim = 5 * 16
input_dims = [input_dim,256,output_dim]  # layer widths: input, hidden, output

mlp = MLP(input_dims).to(device)
RNN_flag = False  # feed-forward model: sequences are flattened before the forward pass
print(mlp)
mlp_save_path = content_path + '/mlp.pth'
early_stopping = EarlyStopping(patience=10, min_delta=1e-5, path=mlp_save_path)
train(mlp,RNN_flag, train_loader,test_loader,early_stopping)

# One-shot prediction metrics on the test set.
mlp_rrmse,mlp_ssim,mlp_mse,mlp_psnr,mlp_decoded = gaim_metrics(test_loader, Autoencoder, mlp, RNN_flag)

# Roll-out metrics: unshuffled loader with batches of 73 samples.
test_ro_loader = torch.utils.data.DataLoader(test_dataset, batch_size=73, shuffle=False)
mlp_ro_rrmse,mlp_ro_ssim,mlp_ro_mse,mlp_ro_psnr = rollout_metrics(test_ro_loader, Autoencoder, mlp, RNN_flag)

# Mean metrics
mlp_avg_ssim = np.mean(mlp_ssim)
mlp_avg_rrmse = np.mean(mlp_rrmse)
mlp_avg_mse = np.mean(mlp_mse)
mlp_avg_psnr = np.mean(mlp_psnr)
# Mean roll-out metrics
mlp_avg_ro_ssim = np.mean(mlp_ro_ssim)
mlp_avg_ro_rrmse = np.mean(mlp_ro_rrmse)
mlp_avg_ro_mse = np.mean(mlp_ro_mse)
mlp_avg_ro_psnr = np.mean(mlp_ro_psnr)


print(f'Mean RRMSE: {mlp_avg_rrmse:.6f}')
print(f'Mean SSIM: {mlp_avg_ssim:.6f}')
print(f'Mean MSE: {mlp_avg_mse:.6f}')
# Report this model's own PSNR (not the KAN value).
print(f'Mean PSNR: {mlp_avg_psnr:.6f}')

print(f'Mean RRMSE rollout: {mlp_avg_ro_rrmse:.6f}')
print(f'Mean SSIM rollout: {mlp_avg_ro_ssim:.6f}')
print(f'Mean MSE rollout: {mlp_avg_ro_mse:.6f}')
print(f'Mean PSNR rollout: {mlp_avg_ro_psnr:.6f}')

LSTM

In [None]:
# Train and evaluate the LSTM baseline.
# NOTE(review): the positional arguments are presumably
# (input size, hidden size, output size, layers, dropout) - confirm against LSTMModel.
lstm = LSTMModel(16,100,16,1,0.2).to(device)
RNN_flag = True  # recurrent model: sequences are passed without flattening
print(lstm)
model_save_path = content_path + '/lstm.pth'
early_stopping = EarlyStopping(patience=10, min_delta=1e-5, path=model_save_path)
train(lstm,RNN_flag, train_loader,test_loader,early_stopping)

# One-shot prediction metrics on the test set.
lstm_rrmse,lstm_ssim,lstm_mse,lstm_psnr,lstm_decoded = gaim_metrics(test_loader, Autoencoder, lstm, RNN_flag)

# Roll-out metrics: unshuffled loader with batches of 73 samples.
test_ro_loader = torch.utils.data.DataLoader(test_dataset, batch_size=73, shuffle=False)
lstm_ro_rrmse,lstm_ro_ssim,lstm_ro_mse,lstm_ro_psnr = rollout_metrics(test_ro_loader, Autoencoder, lstm, RNN_flag)

# Mean metrics
lstm_avg_ssim = np.mean(lstm_ssim)
lstm_avg_rrmse = np.mean(lstm_rrmse)
lstm_avg_mse = np.mean(lstm_mse)
lstm_avg_psnr = np.mean(lstm_psnr)
# Mean roll-out metrics
lstm_avg_ro_ssim = np.mean(lstm_ro_ssim)
lstm_avg_ro_rrmse = np.mean(lstm_ro_rrmse)
lstm_avg_ro_mse = np.mean(lstm_ro_mse)
lstm_avg_ro_psnr = np.mean(lstm_ro_psnr)


print(f'Mean RRMSE: {lstm_avg_rrmse:.6f}')
print(f'Mean SSIM: {lstm_avg_ssim:.6f}')
print(f'Mean MSE: {lstm_avg_mse:.6f}')
# Report this model's own PSNR (not the KAN value).
print(f'Mean PSNR: {lstm_avg_psnr:.6f}')

print(f'Mean RRMSE rollout: {lstm_avg_ro_rrmse:.6f}')
print(f'Mean SSIM rollout: {lstm_avg_ro_ssim:.6f}')
print(f'Mean MSE rollout: {lstm_avg_ro_mse:.6f}')
print(f'Mean PSNR rollout: {lstm_avg_ro_psnr:.6f}')

GRU

In [None]:
# Train and evaluate the GRU baseline.
# NOTE(review): the positional arguments are presumably
# (input size, hidden size, output size, layers, dropout) - confirm against GRUModel.
gru = GRUModel(16,100,16,1,0.2).to(device)
RNN_flag = True  # recurrent model: sequences are passed without flattening
print(gru)
model_save_path = content_path + '/gru.pth'
early_stopping = EarlyStopping(patience=10, min_delta=1e-5, path=model_save_path)
train(gru,RNN_flag, train_loader,test_loader,early_stopping)

# One-shot prediction metrics on the test set.
gru_rrmse,gru_ssim,gru_mse,gru_psnr,gru_decoded = gaim_metrics(test_loader, Autoencoder, gru, RNN_flag)

# Roll-out metrics: unshuffled loader with batches of 73 samples.
test_ro_loader = torch.utils.data.DataLoader(test_dataset, batch_size=73, shuffle=False)
gru_ro_rrmse,gru_ro_ssim,gru_ro_mse,gru_ro_psnr = rollout_metrics(test_ro_loader, Autoencoder, gru, RNN_flag)

# Mean metrics
gru_avg_ssim = np.mean(gru_ssim)
gru_avg_rrmse = np.mean(gru_rrmse)
gru_avg_mse = np.mean(gru_mse)
gru_avg_psnr = np.mean(gru_psnr)
# Mean roll-out metrics
gru_avg_ro_ssim = np.mean(gru_ro_ssim)
gru_avg_ro_rrmse = np.mean(gru_ro_rrmse)
gru_avg_ro_mse = np.mean(gru_ro_mse)
gru_avg_ro_psnr = np.mean(gru_ro_psnr)


print(f'Mean RRMSE: {gru_avg_rrmse:.6f}')
print(f'Mean SSIM: {gru_avg_ssim:.6f}')
print(f'Mean MSE: {gru_avg_mse:.6f}')
# Report this model's own PSNR (not the KAN value).
print(f'Mean PSNR: {gru_avg_psnr:.6f}')

print(f'Mean RRMSE rollout: {gru_avg_ro_rrmse:.6f}')
print(f'Mean SSIM rollout: {gru_avg_ro_ssim:.6f}')
print(f'Mean MSE rollout: {gru_avg_ro_mse:.6f}')
print(f'Mean PSNR rollout: {gru_avg_ro_psnr:.6f}')

In [None]:
import matplotlib.pyplot as plt

# Define the step axis; it has the same length as the (truncated) SSIM series.
steps = range(len(kan_ro_ssim[:73]))

# One panel per metric; each panel overlays the four models' roll-out curves.
model_labels = ['KAN', 'MLP', 'LSTM', 'GRU']
rollout_panels = [
    ('SSIM', [kan_ro_ssim, mlp_ro_ssim, lstm_ro_ssim, gru_ro_ssim]),
    ('RRMSE', [kan_ro_rrmse, mlp_ro_rrmse, lstm_ro_rrmse, gru_ro_rrmse]),
]

plt.figure(figsize=(12, 6))

for panel_idx, (metric_name, model_series) in enumerate(rollout_panels, start=1):
    plt.subplot(1, 2, panel_idx)
    for model_label, series in zip(model_labels, model_series):
        if model_label == 'KAN':
            # KAN is drawn as a solid, slightly thicker line.
            plt.plot(steps, series[:73], label=model_label, linewidth=1)
        else:
            plt.plot(steps, series[:73], label=model_label, linewidth=0.9, linestyle='--')
    plt.xlabel('Step')
    plt.ylabel(metric_name)
    plt.title(f'{metric_name} over Steps')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt

decoded_outputs = kan_decoded
# Unpack (samples, time steps, channels, height, width) from the decoded tensor.
_, steps, channels, image_height, image_width = decoded_outputs.shape

# Two rows per channel (original above reconstruction), one column per time step.
fig, axs = plt.subplots(channels * 2, steps, figsize=(20, 20))  # Adjust the figsize as needed

sample_idx = 10  # test sample shown in the grid
for step in range(steps):
    for channel in range(channels):
        row_specs = (
            (channel * 2, targets_test_oriimg, 'Original'),
            (channel * 2 + 1, decoded_outputs, 'Reconstructed'),
        )
        for row, frames, tag in row_specs:
            ax = axs[row, step]
            ax.imshow(frames[sample_idx, step, channel].cpu().numpy(), cmap='viridis')
            ax.axis('off')
            ax.set_title(f'{tag} - Step {step + 1}, Channel {channel + 1}')

# Adjust layout
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Function to normalize images
def normalize_image(image):
    """Linearly rescale ``image`` so its minimum maps to 0 and its maximum to 1.

    Args:
        image: non-empty NumPy array of pixel values.

    Returns:
        Array of the same shape as ``image`` with values in ``[0, 1]``. A
        constant image (``max == min``) is mapped to all zeros instead of the
        NaNs a 0/0 division would produce.
    """
    min_val = np.min(image)
    max_val = np.max(image)
    value_range = max_val - min_val
    if value_range == 0:
        # Constant image: the numerator is all zeros, so any non-zero divisor
        # yields zeros and avoids a 0/0 division.
        value_range = 1
    normalized_image = (image - min_val) / value_range
    return normalized_image

# Function to combine images
def combine_images(images, num_images, steps, channels):
    """Build one horizontal strip of normalised frames per sample and channel.

    Args:
        images: array indexed as ``images[sample, step, channel]``, each entry
            being a 2-D frame.
        num_images: number of leading samples to process.
        steps: number of time steps concatenated side by side (along axis 1).
        channels: number of channels to process per sample.

    Returns:
        A list with ``num_images`` entries, each a list of ``channels`` strips.
    """
    return [
        [
            np.concatenate(
                [normalize_image(images[sample, step, channel]) for step in range(steps)],
                axis=1,
            )
            for channel in range(channels)
        ]
        for sample in range(num_images)
    ]

# Function to visualize the original and decoded data
def visualize_reconstruction(inputs, outputs, num_images=5):
    """Show decoded inputs next to decoded outputs, one figure row per sample.

    Each sample is rendered as a mosaic: 5 time steps side by side, with the
    3 channels stacked vertically. Every frame is normalised independently
    by ``combine_images``.

    Args:
        inputs: tensor of decoded input frames, indexed ``[sample, step, channel]``.
        outputs: tensor of decoded output frames with the same layout.
        num_images: number of leading samples to display.
    """
    input_strips = combine_images(inputs.cpu().numpy(), num_images, 5, 3)
    output_strips = combine_images(outputs.cpu().numpy(), num_images, 5, 3)

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(num_images, 2, figsize=(18, num_images * 6), squeeze=False)

    for row in range(num_images):
        # Stack the per-channel strips vertically into one mosaic.
        input_image = np.concatenate(input_strips[row], axis=0)
        output_image = np.concatenate(output_strips[row], axis=0)
        print("input_image.shape", input_image.shape)
        print("output_image.shape", output_image.shape)

        panels = (
            (axes[row, 0], input_image, 'Input 5 steps in u,v,h'),
            (axes[row, 1], output_image, 'Output 5 steps in u,v,h'),
        )
        for ax, mosaic, title in panels:
            ax.imshow(mosaic, cmap='viridis')
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# Select a batch of test data for visualization: the batch at index 20, or the
# last batch when the loader yields fewer than 21 batches.
for batch_idx, test_data in enumerate(test_loader):
    if batch_idx == 20:
        break

test_inputs, test_targets = test_data

# This cell always visualises `kan`, a feed-forward model that needs flattened
# input. The global RNN_flag is overwritten by the LSTM/GRU cells, so the choice
# is pinned here instead of being read from that leftover state.
viz_model = kan
viz_is_rnn = False

viz_model.eval()
Autoencoder.eval()
with torch.no_grad():  # inference only: do not build an autograd graph
    if viz_is_rnn:
        test_outputs = viz_model(test_inputs)
    else:
        # Flatten (batch, steps, latent) to (batch, steps * latent) and back.
        flatten_inputs = test_inputs.reshape(test_inputs.shape[0],test_inputs.shape[1]*test_inputs.shape[2])
        flatten_outputs = viz_model(flatten_inputs)
        test_outputs = flatten_outputs.reshape(test_inputs.shape[0],test_inputs.shape[1],test_inputs.shape[2])

    # Decode the latent vectors to original shape
    decoded_inputs = torch.zeros(test_inputs.size(0), test_inputs.size(1), 3, N, N).to(device)
    decoded_outputs = torch.zeros(test_outputs.size(0), test_outputs.size(1), 3, N, N).to(device)
    for i in range(test_inputs.size(0)):
        for j in range(test_inputs.size(1)):
            decoded_inputs[i, j] = Autoencoder.decoder(test_inputs[i, j].unsqueeze(0)).squeeze(0)
            decoded_outputs[i, j] = Autoencoder.decoder(test_outputs[i, j].unsqueeze(0)).squeeze(0)

# Visualize the first set of slices and their corresponding targets
visualize_reconstruction(decoded_inputs, decoded_outputs, num_images=2)