In [None]:
import time
from pylab import *
import matplotlib.gridspec as gridspec

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
from torch.utils.data import DataLoader, TensorDataset, random_split

In [None]:
# Choose the compute device, then mount Google Drive (the model checkpoint is
# written under content_path).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from google.colab import drive
drive.mount('/content/gdrive')
content_path = '/content/gdrive/MyDrive/ic_project/new'

In [None]:
import time
from pylab import *
import matplotlib.gridspec as gridspec
import numpy as np
#construct background states, observations with error

def x_to_y(X, noise_var=0.0):
    """Observation operator: sum each non-overlapping 2x2 window of a 2-D field.

    The original comment called this "averaging", but the code has always
    returned the *sum* of the four pixels; that behaviour is kept.

    Args:
        X: 2-D array-like field. A trailing odd row/column that does not fill a
            complete 2x2 window is ignored. (The old code hard-coded a 20x20
            input regardless of X's real shape.)
        noise_var: variance of i.i.d. Gaussian observation noise added once to
            every output pixel. The default 0 means noise-free output, matching
            the old zero covariance. Non-positive values add no noise.

    Returns:
        float ndarray of shape (X.shape[0] // 2, X.shape[1] // 2).
    """
    field = np.asarray(X, dtype=float)
    rows, cols = field.shape[0] // 2, field.shape[1] // 2
    # Crop to an even size, then expose each 2x2 window as its own pair of axes.
    windows = field[:2 * rows, :2 * cols].reshape(rows, 2, cols, 2)
    Y = windows.sum(axis=(1, 3))
    if noise_var > 0:
        # Draw the noise once per observation. The old code re-added a fresh
        # 10x10 noise field inside the inner loop, i.e. 100 times.
        Y = Y + np.random.normal(0.0, np.sqrt(noise_var), size=Y.shape)
    return Y


class shallow(object):
    """2-D shallow-water solver on a doubly periodic N x N grid.

    Integrates the non-conservative shallow-water equations with linear drag
    (see ``d_dt``). It uses centred finite differences with periodic
    wrap-around and forward-Euler time steps. The initial state is a flat
    surface of height ``h_ini`` with a raised disc of height ``h_ini + Hp``;
    both velocity fields start at zero.
    """

    # Simulated time. `evolve` rebinds it per instance via `self.time += dt`,
    # so this class-level value stays 0.
    time = 0

    # Kept for backward compatibility; nothing in the solver uses them.
    plt = []
    fig = []

    def __init__(self, x=None, y=None, h_ini=1., u=None, v=None, dx=0.01, dt=0.0001,
                 N=64, L=1., px=16, py=16, R=64, Hp=0.1, g=1., b=0.):
        """Build the grid and the initial disc perturbation.

        Args:
            x, y, u, v: accepted for backward compatibility but ignored. The
                grid and the zero velocity fields are always built internally.
                Defaults are ``None`` instead of shared mutable ``[]``.
            h_ini: undisturbed surface height.
            dx, dt: grid spacing and time step (not derived from L / N).
            N: number of grid points per side.
            L: domain length; stored but not used by the solver.
            px, py: disc centre in grid-index units along axis 0 and axis 1.
            R: threshold on the *squared* index distance, so the disc radius is
                sqrt(R) grid cells.
            Hp: height of the disc above ``h_ini``.
            g: gravitational acceleration.
            b: linear drag coefficient on u and v.
        """
        # Perturbation parameters.
        self.px, self.py = px, py
        self.R = R
        self.Hp = Hp

        # Physical parameters.
        self.g = g
        self.b = b
        self.L = L
        self.N = N

        self.dx = dx
        self.dt = dt

        # NOTE(review): self.x varies along axis 0 (rows), while d_dx
        # differentiates along axis 1. So px is measured along the axis that
        # u does *not* point along. This is harmless for the symmetric disc,
        # but the naming is inconsistent.
        self.x, self.y = np.mgrid[:self.N, :self.N]

        self.u = np.zeros((self.N, self.N))
        self.v = np.zeros((self.N, self.N))

        self.h_ini = h_ini
        self.h = self.h_ini * np.ones((self.N, self.N))

        rr = (self.x - px) ** 2 + (self.y - py) ** 2
        self.h[rr < R] = self.h_ini + Hp  # raised-disc initial condition

        # Suggested colour limits for plotting (h, u, v).
        self.lims = [(self.h_ini - self.Hp, self.h_ini + self.Hp), (-0.02, 0.02), (-0.02, 0.02)]

    def dxy(self, A, axis=0):
        """Centred difference of A along ``axis`` with periodic wrap-around.

        dA[i]/dx = (A[i+1] - A[i-1]) / (2 dx). ``np.roll`` supplies the
        neighbours, which makes the domain periodic.
        """
        return (np.roll(A, -1, axis) - np.roll(A, 1, axis)) / (self.dx * 2.)

    def d_dx(self, A):
        """Derivative along axis 1 (columns)."""
        return self.dxy(A, 1)

    def d_dy(self, A):
        """Derivative along axis 0 (rows)."""
        return self.dxy(A, 0)

    def d_dt(self, h, u, v):
        """Time derivatives of (h, u, v).

            du/dt = -g dh/dx - b u
            dv/dt = -g dh/dy - b v
            dh/dt = -d(u h)/dx - d(v h)/dy

        See http://en.wikipedia.org/wiki/Shallow_water_equations#Non-conservative_form
        Here h is the total depth, so no separate mean-depth term is added.
        """
        # Plain ndarrays only: for np.matrix, `*` would mean matrix product.
        for field in (h, u, v):
            assert isinstance(field, np.ndarray) and not isinstance(field, np.matrix)

        g, b = self.g, self.b

        du_dt = -g * self.d_dx(h) - b * u
        dv_dt = -g * self.d_dy(h) - b * v
        dh_dt = -self.d_dx(u * h) - self.d_dy(v * h)

        return dh_dt, du_dt, dv_dt

    def evolve(self):
        """Advance (h, u, v) in place by one forward-Euler step of size ``self.dt``.

        x_{N+1} = x_{N} + dx/dt * dt. Returns the updated (h, u, v) arrays.
        """
        dh_dt, du_dt, dv_dt = self.d_dt(self.h, self.u, self.v)
        dt = self.dt

        self.h += dh_dt * dt
        self.u += du_dt * dt
        self.v += dv_dt * dt
        self.time += dt

        return self.h, self.u, self.v





In [None]:
import random
# Generate the dataset
def generate_dataset(num_simulations=200, snapshots_per_simulation=5000, N=64,
                     save_every=50, seed=None):
    """Run randomised shallow-water simulations and collect (u, v, h) snapshots.

    Each simulation draws a random disc centre (px, py), squared radius R,
    height Hp and drag b. It then takes ``snapshots_per_simulation`` Euler
    steps and records the state after steps 0, save_every, 2*save_every, ...

    Args:
        num_simulations: number of independent simulations.
        snapshots_per_simulation: number of Euler steps per simulation. Despite
            the name, only every ``save_every``-th step is stored.
        N: grid size passed to ``shallow``.
        save_every: stride between stored steps (previously hard-coded 50).
        seed: optional seed for the parameter draws. ``None`` uses the global
            ``random`` module state, as before.

    Returns:
        dataset: array (num_simulations * saves_per_sim, 3, N, N) with channels
            (u, v, h); simulations are stored back to back.
        first_snapshots: array (num_simulations, 3, N, N) with the first stored
            snapshot (state after step 0) of each simulation. Previously this
            stayed all zeros, because the "last step" check sat inside the
            `step % 50 == 0` branch and never fired for 5000 steps.
    """
    rng = random if seed is None else random.Random(seed)
    # Number of steps 0, save_every, ... below snapshots_per_simulation. This is
    # a ceiling division; the old floor division overflowed `dataset` whenever
    # snapshots_per_simulation was not a multiple of the stride.
    saves_per_sim = len(range(0, snapshots_per_simulation, save_every))

    dataset = np.zeros((num_simulations * saves_per_sim, 3, N, N))
    first_snapshots = np.zeros((num_simulations, 3, N, N))
    snapshot_idx = 0

    for sim in range(num_simulations):
        # NOTE(review): the centre range 54..74 is tuned for N=128 (grid centre 64).
        px = rng.randint(54, 74) * 1.
        py = rng.randint(54, 74) * 1.
        R = rng.randint(80, 160) * 1.
        Hp = rng.randint(5, 20) * 0.01
        b = rng.randint(1, 100) * 0.1

        SW = shallow(N=N, px=px, py=py, R=R, Hp=Hp, b=b)
        print("sim", sim)

        for step in range(snapshots_per_simulation):
            SW.evolve()
            if step % save_every == 0:
                state = (SW.u, SW.v, SW.h)
                for channel, field in enumerate(state):
                    dataset[snapshot_idx, channel] = field
                if step == 0:
                    # Save the first snapshot of the simulation.
                    for channel, field in enumerate(state):
                        first_snapshots[sim, channel] = field
                snapshot_idx += 1

    return dataset, first_snapshots

#apply min_max normalization
def min_max_normalize(dataset):
    """Rescale every channel of ``dataset`` to [0, 1] **in place**.

    The array is modified in place, and the returned array is the same object.
    This avoids doubling memory for the multi-GB simulation dataset. Min and
    max are taken over all samples and all spatial positions of a channel.

    Args:
        dataset: float ndarray of shape (samples, channels, ...). Integer
            arrays are not supported because the update happens in place.

    Returns:
        (dataset, params). ``params`` is the flat tuple
        (c0_min, c0_max, c1_min, c1_max, ...); for the 3-channel (u, v, h)
        data this is (u_min, u_max, v_min, v_max, h_min, h_max).
        A constant channel (max == min) becomes all zeros instead of NaN.
    """
    params = []
    for c in range(dataset.shape[1]):
        channel = dataset[:, c]  # a view, so the updates below write into dataset
        lo, hi = channel.min(), channel.max()
        span = hi - lo
        channel -= lo
        if span != 0:
            channel /= span
        params.extend((lo, hi))
    return dataset, tuple(params)

# Generate the dataset.
# Seed `random` first: the simulation parameters are drawn from it, and the
# notebook's other seeding happens only after this (expensive) step. Without
# this, the dataset changed on every Restart & Run All.
DATA_SEED = 5
random.seed(DATA_SEED)
dataset, first_snapshots = generate_dataset(num_simulations=200, snapshots_per_simulation=5000, N=128)
dataset_norm, normalization_params = min_max_normalize(dataset)

Define the KAN (Kolmogorov–Arnold Network) layers

In [None]:

class KANLinear(torch.nn.Module):
    """Kolmogorov-Arnold layer: y = base_activation(x) @ base_weight.T + spline term.

    Every (input, output) edge combines a linear base path applied to
    ``base_activation(x)`` with a learnable combination of
    ``grid_size + spline_order`` B-spline basis functions of degree
    ``spline_order``, evaluated on a per-input knot grid.

    Args:
        in_features, out_features: layer widths.
        grid_size: number of grid intervals spanning ``grid_range``.
        spline_order: B-spline degree (3 = cubic).
        scale_noise: amplitude of the random curves used to initialise the
            spline weights.
        scale_base: multiplies the ``a`` argument of the Kaiming init of
            ``base_weight``.
        scale_spline: a fixed multiplier on the initial spline weights when
            ``enable_standalone_scale_spline`` is False. Otherwise it multiplies
            the ``a`` argument of the Kaiming init of ``spline_scaler``.
        enable_standalone_scale_spline: if True, learn a per-edge
            ``spline_scaler`` that multiplies the spline coefficients.
        base_activation: activation class; instantiated once in ``__init__``.
        grid_eps: blend factor used by ``update_grid`` between a uniform grid
            (1.0) and a data-quantile grid (0.0).
        grid_range: interval covered by the interior knots of the initial grid.
    """
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Knot spacing of the uniform initial grid.
        h = (grid_range[1] - grid_range[0]) / grid_size
        # grid_size + 2 * spline_order + 1 uniformly spaced knots: grid_range
        # plus spline_order extra knots on each side, repeated for every input.
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        # A buffer rather than a Parameter: it follows .to(device) and is saved
        # in state_dict, but inside this class only update_grid changes it.
        self.register_buffer("grid", grid)

        # Weights of the base path (applied to base_activation(x)).
        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        # One coefficient per B-spline basis function for every (out, in) edge.
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            # Learnable per-edge multiplier on the spline coefficients.
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise base_weight (Kaiming-uniform) and fit spline_weight to small random curves.

        The spline coefficients are the least-squares fit (``curve2coeff``) of
        uniform noise in [-scale_noise / (2 * grid_size), scale_noise / (2 * grid_size)),
        sampled at the grid_size + 1 interior knots.
        """
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    # Interior knots only: drop spline_order knots at each end.
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Degree-0 bases: indicator of the knot interval [t_j, t_{j+1}) that
        # contains x. Inputs outside [grid[0], grid[-1]) get all-zero bases,
        # so their spline contribution is zero.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Cox-de Boor recursion: each step raises the degree by one and leaves
        # one fewer basis function.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        # Solve one batched least-squares problem per input feature:
        # bases(x) @ coeff ~= y.
        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """spline_weight multiplied by spline_scaler (broadcast over the basis axis) when enabled."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Apply the layer to x of shape (..., in_features) and return (..., out_features).

        Leading dimensions are flattened for the computation and restored on output.
        """
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        # Flatten (in_features, n_bases) so a single matmul sums over both.
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Move the knots to follow the distribution of ``x`` and refit the spline weights.

        The new interior knots are ``grid_eps * uniform + (1 - grid_eps) * quantiles``
        of each input feature over the batch. They are extended by
        ``spline_order`` uniformly spaced knots on each side. ``spline_weight``
        is then re-solved by least squares so the per-edge spline outputs on
        ``x`` stay approximately unchanged.

        Args:
            x: (batch, in_features) tensor.
            margin: padding added on both ends of the uniform grid.

        NOTE(review): the target outputs use ``scaled_spline_weight`` (already
        multiplied by ``spline_scaler``), but the fit is written into the
        unscaled ``spline_weight``. With ``enable_standalone_scale_spline``, the
        scaler therefore looks applied twice after an update -- verify.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        # grid_size + 1 evenly spaced order statistics per feature (empirical quantiles).
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Pad spline_order knots on each side with the uniform step.
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        # NOTE(review): an entry of p that is exactly 0 yields 0 * log(0) = nan here.
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    """A stack of KANLinear layers.

    ``layers_hidden`` lists the feature widths, e.g. [in, hidden, out]. One
    KANLinear is created for each consecutive pair of widths, and all layers
    share the same spline/grid hyper-parameters.
    """

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        shared_options = dict(
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        widths = list(layers_hidden)
        self.layers = torch.nn.ModuleList(
            [
                KANLinear(width_in, width_out, **shared_options)
                for width_in, width_out in zip(widths[:-1], widths[1:])
            ]
        )

    def forward(self, x: torch.Tensor, update_grid=False):
        """Apply the layers in order, optionally refitting each layer's grid to its input first."""
        activations = x
        for layer in self.layers:
            if update_grid:
                layer.update_grid(activations)
            activations = layer(activations)
        return activations

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Sum of the per-layer ``regularization_loss`` values."""
        total = 0
        for layer in self.layers:
            total = total + layer.regularization_loss(regularize_activation, regularize_entropy)
        return total

In [None]:

class PowerfulAutoencoder(nn.Module):
    """Convolutional autoencoder with a KANLinear bottleneck.

    The encoder halves the spatial size four times with stride-2 convolutions
    and maps the flattened features to ``latent_dim`` with a KANLinear layer.
    The decoder mirrors it with a plain Linear layer and transposed
    convolutions, ending in a Sigmoid so outputs lie in (0, 1).

    Args:
        latent_dim: size of the latent vector z.
        in_channels: number of input/output channels (3 for u, v, h).
        image_size: input height/width; must be divisible by 16.
            The defaults reproduce the original hard-coded 3 x 128 x 128 model,
            with the same module order and hence the same state_dict keys.

    Raises:
        ValueError: if ``image_size`` is not divisible by 16.
    """

    def __init__(self, latent_dim, in_channels=3, image_size=128):
        super(PowerfulAutoencoder, self).__init__()
        if image_size % 16 != 0:
            raise ValueError(f"image_size must be divisible by 16, got {image_size}")
        side = image_size // 16   # spatial size after four stride-2 convs (8 for 128)
        flat = 256 * side * side  # flattened bottleneck features (8*8*256 for 128)

        # Encoder: (C, S, S) -> (32, S/2) -> (64, S/4) -> (128, S/8) -> (256, S/16) -> latent
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            KANLinear(flat, latent_dim),
        )

        # Decoder: latent -> (256, S/16) -> ... -> (C, S, S), mirroring the encoder.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat),
            nn.Unflatten(1, (256, side, side)),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),  # matches inputs min-max normalised to [0, 1]
        )

    def forward(self, x):
        """Return (reconstruction, latent code) for a batch x of shape (B, C, S, S)."""
        z = self.encoder(x)
        x_reconstructed = self.decoder(z)
        return x_reconstructed, z


def loss_function(reconstructed, original, latent, lambda_reg=0.1):
    """Reconstruction MSE plus an L2 penalty on the latent codes.

    Returns (total, reconstruction_mse, latent_term). latent_term is the batch
    mean of the L2 norm of each latent vector (taken over dim 1), and
    total = reconstruction_mse + lambda_reg * latent_term.
    """
    mse = F.mse_loss(reconstructed, original)
    latent_norm = latent.norm(p=2, dim=1).mean()
    return mse + lambda_reg * latent_norm, mse, latent_norm

In [None]:
# Seed every RNG *before* anything random happens. Previously the seeds were
# set after random_split, so the train/test split changed on every run.
seed = 5
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

dataset = torch.tensor(dataset_norm, dtype=torch.float32)

# Split the dataset into training and testing sets (80% training, 20% testing).
# NOTE(review): snapshots of the same simulation can land in both splits, so the
# test loss is not a fully independent estimate; consider splitting by simulation.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(seed)
train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=split_generator)

# Create DataLoaders (the unused full-dataset loader was removed).
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Use the GPU when one is present, and build the model on that device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
model = PowerfulAutoencoder(latent_dim=16)
model = model.to(device)
print(model)
# Weight initialization function.
# NOTE(review): not applied anywhere in this notebook (no `model.apply(weights_init)`).
def weights_init(m):
    """Kaiming-normal weights and zero biases for Conv2d, ConvTranspose2d and Linear.

    Other module types are left untouched. Intended for ``model.apply(weights_init)``.
    """
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)

# (The unused `criterion = nn.MSELoss()` was removed; `loss_function` computes the MSE.)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 30
lambda_reg = 1e-6
train_losses = []
test_losses = []


def run_epoch(loader, train):
    """Make one pass over `loader` and return (mean total loss, mean latent loss).

    Means are taken per batch. When `train` is True, the model is in train mode
    and is updated with `optimizer`. Otherwise it runs in eval mode with
    gradients disabled, matching the old separate torch.no_grad() loop.
    """
    model.train(train)
    total_loss = 0.0
    total_latent = 0.0
    with torch.set_grad_enabled(train):
        for data in loader:
            data = data.to(device)  # Move data to GPU

            output, latent = model(data)
            loss, _, latent_reg_loss = loss_function(output, data, latent, lambda_reg=lambda_reg)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            total_latent += latent_reg_loss.item()
    return total_loss / len(loader), total_latent / len(loader)


for epoch in range(num_epochs):
    train_loss, train_latent_loss = run_epoch(train_loader, train=True)
    train_losses.append(train_loss)

    test_loss, test_latent_loss = run_epoch(test_loader, train=False)
    test_losses.append(test_loss)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.6f}, Test Loss: {test_loss:.6f},Train Latent loss: {train_latent_loss:.6f}, Test Latent Loss: {test_latent_loss:.6f}')

print('Training complete')

# Save the trained model. The message now reports the real path; it used to
# say "autoencoder.pth" although the file is AE_KAN.pth.
model_path = content_path + '/AE_KAN.pth'
torch.save(model.state_dict(), model_path)
print(f'Model saved to {model_path}')

# Plot the training and testing losses per epoch (linear scale).
fig, ax = plt.subplots(figsize=(10, 5))
epochs = range(1, num_epochs + 1)
ax.plot(epochs, train_losses, label='Train Loss')
ax.plot(epochs, test_losses, label='Test Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Testing Loss per Epoch')
ax.legend()
plt.show()

Plot the losses on a logarithmic scale

In [None]:
# Same curves as above, with a logarithmic y axis.
fig, ax = plt.subplots(figsize=(10, 5))
epochs = range(1, num_epochs + 1)
ax.semilogy(epochs, train_losses, label='Train Loss')
ax.semilogy(epochs, test_losses, label='Test Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Testing Loss per Epoch/logscale')
ax.legend()
plt.show()

In [None]:

def visualize_reconstruction(model, data_loader, num_images=5):
    """Plot original vs reconstructed fields for the first batch of `data_loader`.

    Each row is one sample. The six columns alternate original and
    reconstruction for channels 1-3 (u, v, h in this notebook's dataset).

    Args:
        model: autoencoder exposing ``encoder`` and ``decoder`` sub-modules.
        data_loader: loader yielding image tensors of shape (B, C, H, W).
        num_images: number of rows to draw; clamped to the batch size.
    """
    model.eval()
    # Use the model's own device instead of relying on a global `device`.
    model_device = next(model.parameters()).device
    original_data = next(iter(data_loader)).to(model_device)

    with torch.no_grad():
        reconstructed_data = model.decoder(model.encoder(original_data))

    original_data = original_data.cpu().numpy()
    reconstructed_data = reconstructed_data.cpu().numpy()

    num_images = min(num_images, original_data.shape[0])
    # squeeze=False keeps `axes` 2-D even when num_images == 1.
    fig, axes = plt.subplots(num_images, 6, figsize=(18, num_images * 3), squeeze=False)
    for i in range(num_images):
        for j in range(3):
            ax = axes[i, j * 2]
            ax.imshow(original_data[i, j], cmap='viridis')
            ax.set_title(f'Original - Channel {j+1}')
            ax.axis('off')

            ax = axes[i, j * 2 + 1]
            ax.imshow(reconstructed_data[i, j], cmap='viridis')
            ax.set_title(f'Reconstructed - Channel {j+1}')
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize some reconstructions
visualize_reconstruction(model, test_loader, num_images=10)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def visualize_reconstruction(model, data_loader, num_images=5, batch_index=6, channel=2):
    """Plot input, reconstruction and absolute error for one channel of one batch.

    NOTE: this redefinition shadows the earlier ``visualize_reconstruction``
    from this cell onward.

    Args:
        model: autoencoder exposing ``encoder`` and ``decoder`` sub-modules.
        data_loader: loader yielding image tensors of shape (B, C, H, W).
        num_images: number of columns to draw; clamped to the batch size.
        batch_index: zero-based index of the batch to show. The default 6
            (the 7th batch) matches the old hard-coded choice.
        channel: channel to display. The default 2 is h in this notebook's
            (u, v, h) data.

    Raises:
        ValueError: if the loader yields fewer than ``batch_index + 1``
            batches. Previously this surfaced as an UnboundLocalError.
    """
    model.eval()
    # Use the model's own device instead of relying on a global `device`.
    model_device = next(model.parameters()).device

    batch = None
    for idx, data in enumerate(data_loader):
        if idx == batch_index:
            batch = data
            break
    if batch is None:
        raise ValueError(f"data_loader has fewer than {batch_index + 1} batches")
    original_data = batch.to(model_device)

    with torch.no_grad():
        reconstructed_data = model.decoder(model.encoder(original_data))

    original_data = original_data.cpu().numpy()[:, channel]
    reconstructed_data = reconstructed_data.cpu().numpy()[:, channel]
    num_images = min(num_images, original_data.shape[0])

    # Rows: input / reconstruction / |error|; one column per sample.
    # squeeze=False keeps `axes` 2-D even when num_images == 1.
    fig, axes = plt.subplots(3, num_images, figsize=(num_images * 5, 15), squeeze=False)
    for i in range(num_images):
        error_img = np.abs(original_data[i] - reconstructed_data[i])
        panels = (
            ('Input', original_data[i]),
            ('Reconstruction', reconstructed_data[i]),
            ('Error', error_img),
        )
        for row, (title, image) in enumerate(panels):
            ax = axes[row, i]
            img = ax.imshow(image, cmap='viridis')
            ax.set_title(title)
            ax.axis('off')
            fig.colorbar(img, ax=ax, orientation='vertical')

    plt.tight_layout()
    plt.show()

# Visualize some reconstructions
visualize_reconstruction(model, test_loader, num_images=5)



In [None]:
from skimage.metrics import structural_similarity as ssim
# Define RRMSE and SSIM functions
def rrmse(img1, img2):
    """Relative RMSE: RMS of (img1 - img2) divided by RMS of img1, the reference image."""
    residual_rms = np.sqrt(np.mean(np.square(img1 - img2)))
    reference_rms = np.sqrt(np.mean(np.square(img1)))
    return residual_rms / reference_rms

def compute_ssim(img1, img2, data_range=None):
    """SSIM between two channel-last images of shape (H, W, C), via skimage.

    Args:
        img1, img2: images with channels on the last axis.
        data_range: dynamic range passed to ``structural_similarity``. When
            ``None``, it falls back to the previous behaviour,
            ``img2.max() - img2.min()`` (the range of the *second* image).
            NOTE(review): for data normalised to [0, 1], ``data_range=1.0``
            is the more conventional choice.

    Returns:
        Mean SSIM over the image (win_size=5).
    """
    if data_range is None:
        data_range = img2.max() - img2.min()
    # `multichannel=True` was dropped: it is deprecated in favour of
    # `channel_axis` (already passed here) since scikit-image 0.19.
    return ssim(img1, img2, data_range=data_range, win_size=5, channel_axis=-1)

def compute_psnr(img1, img2, max_pixel=1.0):
    """Peak signal-to-noise ratio in dB: 20 * log10(max_pixel / sqrt(MSE)).

    Args:
        img1, img2: arrays of the same shape.
        max_pixel: peak signal value. The default 1.0 assumes data normalised
            to [0, 1] (previously hard-coded).

    Returns:
        PSNR as a float, or ``inf`` when the inputs are identical.
    """
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(max_pixel / np.sqrt(mse))

# Compute RRMSE, SSIM, MSE and PSNR for the entire test dataset
def evaluate_model(model, data_loader, label="AE with KAN"):
    """Compute per-image reconstruction metrics over `data_loader` and print their means.

    Each image is transposed to channel-last (H, W, C) before scoring.

    Args:
        model: autoencoder whose ``forward`` returns (reconstruction, latent).
        data_loader: loader yielding image tensors of shape (B, C, H, W).
        label: model description for the printed header. The old header said
            "without KAN", but this notebook's encoder ends in a KANLinear layer.

    Returns:
        (mean_rrmse, mean_ssim). Mean MSE and PSNR are only printed.
    """
    model.eval()
    # Use the model's own device instead of relying on a global `device`.
    model_device = next(model.parameters()).device
    rrmse_values, ssim_values, mse_values, psnr_values = [], [], [], []

    with torch.no_grad():
        for data in data_loader:
            reconstructed, _ = model(data.to(model_device))
            original_batch = data.cpu().numpy()
            reconstructed_batch = reconstructed.cpu().numpy()

            for original_img, reconstructed_img in zip(original_batch, reconstructed_batch):
                original_img = original_img.transpose(1, 2, 0)
                reconstructed_img = reconstructed_img.transpose(1, 2, 0)

                rrmse_values.append(rrmse(original_img, reconstructed_img))
                ssim_values.append(compute_ssim(original_img, reconstructed_img))
                mse_values.append(np.mean((original_img - reconstructed_img) ** 2))
                psnr_values.append(compute_psnr(original_img, reconstructed_img))

    mean_rrmse = np.mean(rrmse_values)
    mean_ssim = np.mean(ssim_values)
    mean_mse = np.mean(mse_values)
    mean_psnr = np.mean(psnr_values)
    print(f"metric for {label}")
    print(f'Mean RRMSE: {mean_rrmse:.4f}')
    print(f'Mean SSIM: {mean_ssim:.4f}')
    print(f'Mean MSE: {mean_mse:.6f}')
    print(f'Mean PSNR: {mean_psnr:.4f}')

    return mean_rrmse, mean_ssim

# Evaluate the model on the test dataset
mean_rrmse, mean_ssim = evaluate_model(model, test_loader)
