In [2]:
#Libraries

import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import expm
import os
import argparse
from scipy.sparse.linalg import eigsh
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import random
from torch.utils.data import Dataset, DataLoader
#import wandb
import torch.nn.functional as F

In [None]:
#Models and Embeddings

#EMBEDDINGS
class GaussianEmbedding(nn.Module):
    """Learned polynomial-of-adjacency node embedding.

    For every output channel ``c`` the embedding of node ``v`` is the ``v``-th
    diagonal entry of ``sum_{i=0}^{num_terms} h[i, c] * A**i``.

    Parameters:
    - num_terms: Highest power of A that is used (the identity term, power 0,
      is always included, so there are ``num_terms + 1`` coefficients per channel).
    - num_channels: Number of output channels per node.
    """

    def __init__(self, num_terms, num_channels):
        super(GaussianEmbedding, self).__init__()
        self.num_terms = num_terms
        self.num_channels = num_channels

        # One coefficient per (power, channel) pair; row 0 multiplies the identity.
        self.h = nn.Parameter(torch.randn(num_terms + 1, num_channels))
        nn.init.xavier_uniform_(self.h)

    def forward(self, A):
        """Compute node embeddings from a batch of adjacency matrices.

        Parameters:
        - A: Tensor of shape (batch_size, num_nodes, num_nodes).

        Returns:
        - Tensor of shape (batch_size, num_nodes, num_channels).
        """
        batch_size, num_nodes, _ = A.shape

        # diag(sum_i h[i, c] * A^i) == sum_i h[i, c] * diag(A^i), so the matrix
        # powers (which do not depend on the channel) are computed only once
        # instead of once per channel.
        diagonals = [torch.ones(batch_size, num_nodes, device=A.device, dtype=A.dtype)]
        A_power = None
        for _ in range(self.num_terms):
            A_power = A if A_power is None else torch.matmul(A_power, A)
            diagonals.append(torch.diagonal(A_power, dim1=-2, dim2=-1))

        # (batch, nodes, num_terms + 1) @ (num_terms + 1, channels)
        diag_powers = torch.stack(diagonals, dim=-1).to(self.h.dtype)
        return diag_powers @ self.h

class EigenEmbedding(nn.Module):
    """Parameter-free node embedding made of eigenvectors of each input matrix.

    Parameters:
    - num_eigenvectors: How many trailing eigenvector columns to keep per graph
      (``torch.linalg.eigh`` orders eigenvalues ascending, so these belong to the
      largest eigenvalues). ``None`` keeps every eigenvector.
    """

    def __init__(self, num_eigenvectors = None):
        super(EigenEmbedding, self).__init__()
        self.num_eigenvectors = num_eigenvectors

    def forward(self, A):
        """Return the stacked eigenvector matrices for a batch ``A`` of shape (B, N, N)."""
        keep = self.num_eigenvectors
        per_graph = []
        # Each graph is decomposed on its own, one eigh call per batch element.
        for matrix in A:
            vectors = torch.linalg.eigh(matrix)[1]
            per_graph.append(vectors if keep is None else vectors[:, -keep:])
        return torch.stack(per_graph, dim=0)

#MODELS
class GraphConvolutionFilter(nn.Module):
    """Bank of K-tap polynomial graph filters.

    Computes ``Z = sum_{k=0}^{K-1} A^k X H[k]``, optionally followed by layer
    normalization over the output channels.

    Parameters:
    - K: Number of filter taps (powers 0 .. K-1 of A are used).
    - F: Number of input channels.
    - G: Number of output channels.
    - layer_norm: If True, apply ``nn.LayerNorm(G)`` to the output.
    """

    def __init__(self, K, F, G, layer_norm = False):
        super(GraphConvolutionFilter, self).__init__()

        self.K = K #number of filter taps
        self.F = F #number of input channels
        self.G = G #number of output channels

        # One F x G coefficient matrix per tap.
        self.H = nn.Parameter(torch.randn(K, F, G))
        nn.init.xavier_uniform_(self.H)

        if layer_norm:
            self.layer_norm = nn.LayerNorm(G)
        else:
            self.layer_norm = None

    def forward(self, A, X):
        """Apply the filter bank.

        Parameters:
        - A: Tensor of shape (B, N, N).
        - X: Tensor of shape (B, N, F).

        Returns:
        - Tensor of shape (B, N, G).
        """
        # Horner evaluation of the matrix polynomial:
        #   Z = X H0 + A (X H1 + A (X H2 + ...))
        # This never forms the dense powers A^k and needs only (N x N)(N x G)
        # products, while being mathematically identical to sum_k A^k X H[k].
        Z = X @ self.H[self.K - 1]
        for k in range(self.K - 2, -1, -1):
            Z = X @ self.H[k] + torch.bmm(A, Z)

        if self.layer_norm is not None:
            Z = self.layer_norm(Z)

        return Z
    
class LinearPE(nn.Module):
    """Node-wise affine map ``Z = X W + b`` with optional layer normalization.

    Parameters:
    - F: Number of input channels.
    - D: Number of output channels.
    - bias: If True, ``b`` is a learnable parameter; otherwise it is a fixed zero vector.
    - layer_norm: If True, apply ``nn.LayerNorm(D)`` to the output.
    """

    def __init__(self, F, D, bias = True, layer_norm = False):
        super(LinearPE, self).__init__()

        self.F = F #number of input channels
        self.D = D #number of output channels

        self.W = nn.Parameter(torch.randn(F, D))

        if bias:
            self.b = nn.Parameter(torch.randn(D))
        else:
            # Registered as a buffer (not a plain attribute) so that .to(device),
            # .float(), .double() etc. move/cast it together with the module.
            # persistent=False keeps it out of state_dict, so checkpoints look
            # exactly as they did when this was a plain tensor.
            self.register_buffer("b", torch.zeros(D), persistent=False)

        if layer_norm:
            self.layer_norm = nn.LayerNorm(D)
        else:
            self.layer_norm = None

    def forward(self, A, X):
        """Apply the affine map to ``X`` of shape (B, N, F).

        ``A`` is accepted only so the call signature matches the other denoisers
        in this file; it is not used.
        """
        Z = X @ self.W + self.b

        if self.layer_norm is not None:
            Z = self.layer_norm(Z)

        return Z
    
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention layer in the style of the 'Attention Is All You Need' paper.

    Queries, keys and values are all projected from the same input ``x``. The
    attended values go through an output projection, dropout, a residual
    connection and layer normalization. Unlike the paper, the attention logits
    are NOT scaled by 1/sqrt(d_k) here (``self.scale`` is fixed to 1).
    """

    def __init__(self, d_model, num_heads, d_k=None, d_v=None, dropout=0.0, bias=False):
        """
        Initialize the Multi-Head Attention module.

        Parameters:
        - d_model: Model dimension (input and output dimension)
        - num_heads: Number of attention heads
        - d_k: Dimension of keys/queries per head (default: d_model // num_heads)
        - d_v: Dimension of values per head (default: d_model // num_heads)
        - dropout: Dropout probability (applied to attention weights and to the output)
        - bias: Whether the four linear projections have bias terms
        """
        super(MultiHeadAttention, self).__init__()

        self.num_heads = num_heads
        self.d_model = d_model

        # If d_k and d_v are not specified, set them to d_model // num_heads
        self.d_k = d_k if d_k is not None else d_model // num_heads
        self.d_v = d_v if d_v is not None else d_model // num_heads

        # Linear projections for queries, keys, and values
        self.W_q = nn.Linear(d_model, num_heads * self.d_k, bias=bias)
        self.W_k = nn.Linear(d_model, num_heads * self.d_k, bias=bias)
        self.W_v = nn.Linear(d_model, num_heads * self.d_v, bias=bias)

        # Output projection
        self.W_o = nn.Linear(num_heads * self.d_v, d_model, bias=bias)

        # Dropout layer
        self.dropout = nn.Dropout(dropout)

        # Layer normalization for the output
        self.layer_norm = nn.LayerNorm(d_model)

        # Scaling factor for dot product attention
        # self.scale = 1 / math.sqrt(self.d_k)
        self.scale = 1

        # Linear layer that could combine attention scores from different heads.
        # forward() does not use it; it is kept so that parameter initialisation
        # order and state_dict keys stay unchanged for existing checkpoints.
        self.score_combination = nn.Linear(num_heads, 1, bias=False)

    def forward(self, A, x, mask=None, residual=None):
        """
        Forward pass of the Multi-Head Attention module.

        Parameters:
        - A: Unused; accepted so the call signature matches the other denoisers
          in this file, which are called as ``module(A, X)``.
        - x: Input tensor of shape (batch_size, seq_len, d_model); queries, keys
          and values are all computed from it.
        - mask: Optional tensor of shape (batch_size, seq_len, seq_len) or
          (batch_size, 1 or num_heads, seq_len, seq_len). Positions where the
          mask equals 0 are excluded from attention.
        - residual: Optional tensor added to the projected output before layer
          normalization (default: ``x``).

        Returns:
        - output: Tensor of shape (batch_size, seq_len, d_model)
        """

        batch_size = x.size(0)

        # If residual connection is not provided, use the input as residual
        if residual is None:
            residual = x

        # Linear projections, reshaped to (batch_size, seq_len, num_heads, d_*)
        # and then transposed to (batch_size, num_heads, seq_len, d_*)
        q = self.W_q(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(batch_size, -1, self.num_heads, self.d_v).transpose(1, 2)

        # Attention scores: (batch_size, num_heads, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Apply mask if provided
        if mask is not None:
            # Add an extra dimension for the number of heads
            if mask.dim() == 3:  # (batch_size, seq_len, seq_len)
                mask = mask.unsqueeze(1)

            # Set masked positions to a large negative value before softmax
            scores = scores.masked_fill(mask == 0, -1e9)

        # Apply softmax to get attention weights, then dropout
        attn_weights = torch.nn.functional.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values: (batch_size, num_heads, seq_len, d_v)
        context = torch.matmul(attn_weights, v)

        # Transpose and reshape to (batch_size, seq_len, num_heads * d_v)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_v)

        # Output projection, dropout, residual connection and layer norm
        output = self.W_o(context)
        output = self.dropout(output)
        output = self.layer_norm(output + residual)

        return output

In [None]:
#Define sequential model combining embedding and denoising
class SequentialDenoisingModel(nn.Module):
    """Embedding -> denoiser -> (optional) asymmetric heads -> outer-product reconstruction.

    Parameters:
    - embedding_model: Callable mapping A to node embeddings.
    - denoising_model: Callable taking ``(A, embeddings)`` and returning node features Z.
    - activation: Optional callable applied to the reconstructed matrix.
    - asymmetric: If True, Z is sent through two separate learned linear maps
      (each followed by a shared LayerNorm) to obtain the left and right factors;
      otherwise Z is used for both factors.
    - F: Feature dimension of Z (only used when ``asymmetric`` is True).
    - D: Output dimension of the two linear maps (only used when ``asymmetric`` is True).
    """

    def __init__(self, embedding_model, denoising_model, activation = None, asymmetric = False, F = None, D = None):
        super(SequentialDenoisingModel, self).__init__()
        self.embedding_model = embedding_model
        self.denoising_model = denoising_model
        self.activation = activation
        self.asymmetric = asymmetric

        # The extra projection heads only exist in the asymmetric configuration.
        if not self.asymmetric:
            return

        self.D = D
        self.F = F
        self.layer_norm = nn.LayerNorm(D)
        self.W_x = nn.Parameter(torch.randn(F, D))
        self.W_y = nn.Parameter(torch.randn(F, D))

    def forward(self, A):
        """Reconstruct a batch of matrices from the (noisy) batch ``A``."""
        # Stage 1: node embeddings from the input matrices.
        embeddings = self.embedding_model(A)

        # Stage 2: denoised node features (embeddings are cast to float32 first).
        latent = self.denoising_model(A, embeddings.float())

        # Stage 3: left/right factors of the reconstruction.
        if self.asymmetric:
            left = self.layer_norm(latent @ self.W_x)
            right = self.layer_norm(latent @ self.W_y)
        else:
            left = right = latent

        # Stage 4: batched outer product, optionally passed through the activation.
        gram = torch.bmm(left, right.transpose(1, 2))
        return gram if self.activation is None else self.activation(gram)

In [None]:
#Assorted functions

def generate_sbm_adjacency(block_sizes, p, q, rng=None):
    """
    Generate a symmetric adjacency matrix for a stochastic block model with variable block sizes.

    Diagonal entries are sampled like any other intra-block entry, so self-loops
    appear with probability ``p`` (in particular ``p = 1`` yields all-ones
    diagonal blocks including the diagonal).

    Parameters:
    - block_sizes: List of sizes for each block.
    - p: Probability of intra-block edges.
    - q: Probability of inter-block edges.
    - rng: numpy random Generator (optional; a fresh unseeded one is created if omitted).

    Returns:
    - Adjacency matrix as a float numpy array of shape (sum(block_sizes), sum(block_sizes)).
    """
    if rng is None:
        rng = np.random.default_rng()

    n_blocks = len(block_sizes)
    n = sum(block_sizes)

    adj_matrix = np.zeros((n, n))

    # offsets[i] is the first row/column of block i; offsets[-1] == n
    offsets = np.concatenate(([0], np.cumsum(block_sizes))).astype(int)

    for i in range(n_blocks):
        rows = slice(offsets[i], offsets[i + 1])
        for j in range(i, n_blocks):
            cols = slice(offsets[j], offsets[j + 1])
            density = p if i == j else q

            # Sample a full rectangular block of Bernoulli(density) edges.
            edges = (rng.random((block_sizes[i], block_sizes[j])) < density).astype(int)

            if i == j:
                # Intra-block edges must be symmetric as well: keep the upper
                # triangle (with diagonal) and mirror its strict part below.
                adj_matrix[rows, cols] = np.triu(edges) + np.triu(edges, k=1).T
            else:
                # Inter-block edges: fill block (i, j) and mirror into (j, i).
                adj_matrix[rows, cols] = edges
                adj_matrix[cols, rows] = edges.T

    return adj_matrix

def add_digress_noise(A, p, rng=None, symmetric=False):
    """
    Add noise to an adjacency matrix by flipping entries with probability p.

    Parameters:
    - A: numpy array or tensor of 0s and 1s with shape (N, N) or (..., N, N)
      (leading batch dimensions are supported).
    - p: Probability of flipping each element (1 becomes 0 and 0 becomes 1).
    - rng: Optional numpy random Generator. If given, the flip pattern is drawn
      from it; if omitted, torch's global random state is used.
    - symmetric: If True, the flip pattern is symmetric: it is drawn for the upper
      triangle (including the diagonal) and mirrored to the lower triangle.

    Returns:
    - Tuple ``(A_noisy, V, l)`` of float32 tensors: the noisy matrix, and the
      eigenvectors / eigenvalues returned by ``torch.linalg.eigh(A_noisy)``.
      Note that when ``symmetric`` is False ``A_noisy`` is generally not
      symmetric, and ``eigh`` only reads one triangle of its input.
    """
    # as_tensor avoids the copy-construct warning torch.tensor() raises for tensor
    # inputs; detach() keeps the previous "no autograd history" behaviour.
    A_tensor = torch.as_tensor(A).detach()
    if not A_tensor.is_floating_point():
        # Integer/bool inputs cannot be used with rand_like or `1 - A`.
        A_tensor = A_tensor.float()
    device = A_tensor.device

    # One uniform random value per matrix entry.
    if rng is not None:
        random_values = torch.as_tensor(rng.random(tuple(A_tensor.shape)), device=device)
    elif symmetric:
        random_values = torch.rand(tuple(A_tensor.shape), device=device)
    else:
        random_values = torch.rand_like(A_tensor)
    flip_mask = random_values < p

    if symmetric:
        # Keep the decisions of the upper triangle only and mirror them, working
        # on the last two dimensions so batched inputs are handled correctly.
        upper = torch.triu(flip_mask)
        flip_mask = upper | upper.transpose(-1, -2)

    # Flip the selected entries.
    A_noisy = torch.where(flip_mask, 1 - A_tensor, A_tensor).to(torch.float32)

    l, V = torch.linalg.eigh(A_noisy)

    return A_noisy, V, l

def generate_block_sizes(n, min_blocks=2, max_blocks=4, min_size=2, max_size=15):
    """
    Enumerate every ordered way of splitting n nodes into blocks.

    Parameters:
    - n: Total number of nodes to distribute.
    - min_blocks, max_blocks: Inclusive range for the number of blocks.
    - min_size, max_size: Inclusive range for the size of each block.

    Returns:
    - List of lists of block sizes; each inner list sums to n. Different
      orderings of the same sizes are listed separately (e.g. [2, 4] and [4, 2]).
      Results are grouped by number of blocks, smallest first.

    Example:
        partitions = generate_block_sizes(20)
        for p in partitions:
            print(p)
    """
    def _compositions(total, parts):
        # Yield every list of `parts` sizes within [min_size, max_size] summing to `total`.
        if parts == 0:
            if total == 0:
                yield []
            return

        # Bounds that leave a feasible amount for the remaining parts.
        lowest = max(min_size, total - (parts - 1) * max_size)
        highest = min(max_size, total - (parts - 1) * min_size)

        for first in range(lowest, highest + 1):
            if first > total:
                continue
            for rest in _compositions(total - first, parts - 1):
                yield [first] + rest

    found = []
    for block_count in range(min_blocks, max_blocks + 1):
        found.extend(_compositions(n, block_count))
    return found

class PermutedAdjacencyDataset(Dataset):
    """Dataset that serves randomly node-permuted copies of a fixed set of matrices.

    Parameters:
    - adjacency_matrices: Sequence of square tensors to sample from.
    - num_samples: Reported dataset length (samples are generated on the fly).
    """

    def __init__(self, adjacency_matrices, num_samples):
        self.adjacency_matrices = adjacency_matrices
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # `idx` is ignored: every access draws a fresh random sample.
        # Pick one of the stored matrices uniformly at random ...
        choice = int(torch.randint(0, len(self.adjacency_matrices), (1,)))
        source = self.adjacency_matrices[choice]

        # ... and apply the same random permutation to its rows and columns.
        order = torch.randperm(source.size(0))
        return source.index_select(0, order).index_select(1, order)



In [None]:
#Training Loop

# Experiment configuration and dataset construction.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


random.seed(42)  # For reproducibility
torch.manual_seed(42)
np.random.seed(42)
# generate_sbm_adjacency draws from a numpy Generator, which the legacy
# np.random.seed call above does not affect; give it its own seeded generator.
sbm_rng = np.random.default_rng(42)

#hyperparameters
#######################################
use_wandb = False

#number of nodes
n = 20

#GNN parameters
gnn_num_layers = 2
gnn_num_terms = 2
gnn_feature_dim_in = 50
gnn_feature_dim_out = 10 #this is the low dimensional embedding dimension

#transformer parameters
transformer_num_layers = 1
transformer_num_heads = 1
transformer_d_k = 10
transformer_d_v = 10
H = 2 #number of heads in the transformer

loss_type = "MSE" #loss criteria: either "MSE" or "BCE"

learning_rate = 0.01

# Edge-flip probabilities; one entry is drawn at random for every batch.
# (This is the single definition: an earlier [0.2, 0.2] assignment was always
# overwritten by this value before being used.)
level = 0.1
noise_levels = [level, level]

num_samples = 128  # Define the number of samples you want
num_epochs = 500
test_epochs = 10
train_batch_size = 64
test_batch_size = 64

# NOTE: the sweep loop that follows rebinds embedding_type, so this value only
# matters if that loop is removed.
embedding_type = "gaussian"
num_eigenvectors = 10
dropout = 0.0
num_partitions = 1

p_intra = 1.0
q_inter = 0.0
########################################

#gives all possible partitions of n into 2-4 blocks
partitions = generate_block_sizes(n)

# Sample num_partitions many partitions for training and num_partitions many partitions for test
train_partitions = random.sample(partitions, num_partitions)
test_partitions = random.sample([p for p in partitions if p not in train_partitions], num_partitions)

As_train = []
As_test = []

for p in train_partitions:
    A = generate_sbm_adjacency(p, p_intra, q_inter, rng=sbm_rng)
    A = torch.tensor(A, dtype=torch.float)
    As_train.append(A)

for p in test_partitions:
    A = generate_sbm_adjacency(p, p_intra, q_inter, rng=sbm_rng)
    A = torch.tensor(A, dtype=torch.float)
    As_test.append(A)

print("\nTraining partitions:")
for p in train_partitions:
    print(p)

print("\nTest partitions:")
for p in test_partitions:
    print(p)

# Each dataset item is a freshly permuted copy of one of the stored graphs.
train_dataset = PermutedAdjacencyDataset(As_train, num_samples)
train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)

test_dataset = PermutedAdjacencyDataset(As_test, num_samples)
test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True)



# Experiment sweep. Every loop below currently ranges over a single-element list,
# so exactly one configuration is trained; add values to a list to sweep over it.
for layer_norm in [True]:

    for final_nonlinearity in [True]:

        for asymmetric in [True]:

            for embedding_type in ["eigen"]:

                for model_type in ['GraphConvolutionFilter']:

                    print('layer_norm: ', layer_norm)
                    print('final_nonlinearity: ', final_nonlinearity)
                    print('asymmetric: ', asymmetric)
                    print('embedding_type: ', embedding_type)
                    print('model_type: ', model_type)

                    if use_wandb:
                        wandb.init(project="graph-denoising")

                    # NOTE: `F` rebinds the module-level alias that the import cell
                    # uses for torch.nn.functional. The names are kept because
                    # other cells may read K / F / G.
                    K = 3 #number of filter taps in the filter bank
                    F = 50
                    G = 10

                    # Number of channels produced by the embedding model.
                    if embedding_type == "gaussian":
                        embedding_dim = F
                    elif embedding_type == "eigen":
                        embedding_dim = num_eigenvectors
                    else:
                        raise ValueError(f"Unknown embedding_type: {embedding_type!r}")

                    # The denoiser is built before the embedding so that parameter
                    # initialisation consumes the random stream in a fixed order.
                    #case 1: simple linear graph filter bank
                    if model_type == 'GraphConvolutionFilter':
                        model_denoiser = GraphConvolutionFilter(K=K, F=embedding_dim, G=G, layer_norm = layer_norm)
                        latent_dim = G
                    #case 2: linear PE
                    elif model_type == 'LinearPE':
                        model_denoiser = LinearPE(F=embedding_dim, D=G, layer_norm = layer_norm)
                        latent_dim = G
                    #case 3: self-attention (its output keeps the input width d_model)
                    elif model_type == 'SelfAttention':
                        model_denoiser = MultiHeadAttention(d_model=embedding_dim, num_heads = H, d_k=G, d_v=G)
                        latent_dim = embedding_dim
                    else:
                        raise ValueError(f"Unknown model_type: {model_type!r}")

                    if embedding_type == "gaussian":
                        model_embedding = GaussianEmbedding(num_terms=K, num_channels=F)
                    else:
                        model_embedding = EigenEmbedding(num_eigenvectors = num_eigenvectors)

                    model_denoiser.float().to(device)
                    model_embedding.float().to(device)

                    # NOTE(review): the self-attention + gaussian combination has never
                    # applied the final sigmoid, regardless of final_nonlinearity.
                    # That behaviour is preserved here - confirm whether it is intended.
                    skip_activation = model_type == 'SelfAttention' and embedding_type == "gaussian"
                    activation = nn.Sigmoid() if (final_nonlinearity and not skip_activation) else None
                    model = SequentialDenoisingModel(model_embedding, model_denoiser, activation = activation, asymmetric = asymmetric, F = latent_dim, D = G)
                    model = model.to(device)

                    # Log model hyperparameters if using wandb
                    if use_wandb:
                        wandb.config.update({
                            "loss_type": loss_type,
                            "train_batch_size": train_batch_size,
                            "test_batch_size": test_batch_size,
                            "learning_rate": learning_rate,
                            "embedding_type": embedding_type,
                            "num_partitions": num_partitions,
                            "num_samples": num_samples,
                            "num_epochs": num_epochs,
                            "dropout": dropout,
                            "layer_norm": layer_norm,
                            "final_nonlinearity": final_nonlinearity,
                            "asymmetric": asymmetric,
                            "num_eigenvectors": num_eigenvectors,
                            "K": K,
                            "F": F,
                            "G": G,
                            "model_type": model_type,
                        })

                    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
                    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1000, T_mult=2)
                    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

                    if loss_type == "MSE":
                        criterion = nn.MSELoss()
                    elif loss_type == "BCE":
                        criterion = nn.BCELoss()
                    else:
                        raise ValueError(f"Unknown loss_type: {loss_type!r}")

                    #Training loop
                    for epoch in tqdm(range(num_epochs), desc="Training epochs"):
                        model.train()
                        epoch_loss = 0.0
                        num_batches = 0

                        for batch in train_dataloader:
                            optimizer.zero_grad()

                            # Sample random noise level for this batch
                            eps = np.random.choice(noise_levels)
                            batch_noisy = (add_digress_noise(batch, eps))[0]

                            batch_noisy = batch_noisy.float().to(device)
                            batch = batch.float().to(device)

                            output = model(batch_noisy)

                            # Compute loss against the clean batch
                            loss = criterion(output, batch)
                            epoch_loss += loss.item()
                            num_batches += 1

                            # Backward pass and optimization
                            loss.backward()
                            optimizer.step()

                        # Step the scheduler (once per epoch)
                        scheduler.step()

                        # Calculate average epoch loss
                        avg_epoch_loss = epoch_loss / num_batches

                        #evaluate test error
                        test_loss = 0.0
                        num_batches = 0
                        model.eval()
                        with torch.no_grad():
                            for batch in test_dataloader:
                                # Sample random noise level for this batch
                                eps = np.random.choice(noise_levels)
                                batch_noisy = (add_digress_noise(batch, eps))[0]

                                batch_noisy = batch_noisy.float().to(device)
                                batch = batch.float().to(device)

                                output = model(batch_noisy)

                                # Compute loss
                                loss = criterion(output, batch)
                                test_loss += loss.item()
                                num_batches += 1
                        avg_test_loss = test_loss / num_batches

                        # Log metrics to wandb if enabled
                        # ("noise_level" is the level drawn for the last test batch)
                        if use_wandb:
                            wandb.log({
                                "train_loss": avg_epoch_loss,
                                "test_loss": avg_test_loss,
                                "noise_level": eps,
                                "learning_rate": scheduler.get_last_lr()[0]
                            }, step=epoch)

                        # Detailed diagnostics every test_epochs epochs (wandb only)
                        if use_wandb and epoch % test_epochs == 0:
                            model.eval()
                            with torch.no_grad():
                                #evaluate train and test error per noise level
                                train_losses = []
                                test_losses = []
                                for eps in noise_levels:
                                    split_losses = {}
                                    for split_name, loader in (("train", train_dataloader), ("test", test_dataloader)):
                                        # Loss sum and batch count are reset per split so
                                        # each average uses only its own batches.
                                        running_loss = 0.0
                                        split_batches = 0
                                        for batch in loader:
                                            batch_noisy = (add_digress_noise(batch, eps))[0]
                                            batch_noisy = batch_noisy.float().to(device)
                                            batch = batch.float().to(device)
                                            output = model(batch_noisy)
                                            loss = criterion(output, batch)
                                            running_loss += loss.item()
                                            split_batches += 1
                                        split_losses[split_name] = running_loss / split_batches

                                    train_losses.append(split_losses["train"])
                                    test_losses.append(split_losses["test"])

                                    wandb.log({
                                        "eps_" + str(eps) + "_train_loss": train_losses[-1],
                                        "eps_" + str(eps) + "_test_loss": test_losses[-1],
                                    }, step=epoch)

                                #visualize the results
                                eps_values = noise_levels
                                # Pick one random test graph and one random train graph
                                test_idx = np.random.randint(len(As_test))
                                train_idx = np.random.randint(len(As_train))
                                A_test = As_test[test_idx]
                                A_train = As_train[train_idx]

                                for j, A_orig in enumerate([A_test, A_train]):
                                    A_orig = A_orig.unsqueeze(0)

                                    l_orig, V_orig = torch.linalg.eigh(A_orig)
                                    # Leading 3 eigenvectors of the clean graph
                                    V_orig_top = (V_orig.squeeze(0)[:, -3:]).to(device)

                                    fig_size = 4

                                    # 2 rows (noisy / reconstructed), one column per noise level;
                                    # squeeze=False keeps `axes` 2-D even for a single column
                                    fig, axes = plt.subplots(2, len(eps_values), figsize=(fig_size*len(eps_values), fig_size*2), squeeze=False)

                                    # Plot pairs of noisy and reconstructed matrices for each noise level
                                    for i, eps in enumerate(eps_values):
                                        A_noisy = add_digress_noise(A_orig, eps)[0]
                                        A_recon = model(A_noisy.float().to(device))[0]

                                        l_noisy, V_noisy = torch.linalg.eigh(A_noisy)
                                        l_recon, V_recon = torch.linalg.eigh(A_recon)

                                        V_noisy_top = (V_noisy.squeeze(0)[:, -3:]).to(device)
                                        V_recon_top = (V_recon[:, -3:]).to(device)

                                        # Distance between the projectors onto the leading subspaces
                                        d_subspace_noisy_orig = torch.linalg.norm(V_orig_top@V_orig_top.T - V_noisy_top@V_noisy_top.T).item()
                                        d_subspace_recon_orig = torch.linalg.norm(V_orig_top@V_orig_top.T - V_recon_top@V_recon_top.T).item()

                                        # Plot noisy matrix on top row
                                        im1 = axes[0,i].imshow(A_noisy.squeeze(0).detach().cpu().numpy(), cmap='viridis')
                                        axes[0,i].set_title(f'Noisy (ε={eps})')

                                        # Plot reconstructed matrix below
                                        im2 = axes[1,i].imshow(A_recon.squeeze(0).detach().cpu().numpy(), cmap='viridis')
                                        axes[1,i].set_title(f'Reconstructed')

                                    # The logged distances are those of the last noise level
                                    if j == 0:
                                        wandb.log({"d_subspace_noisy_orig (test)": d_subspace_noisy_orig, "d_subspace_recon_orig (test)": d_subspace_recon_orig}, step=epoch)
                                    else:
                                        wandb.log({"d_subspace_noisy_orig (train)": d_subspace_noisy_orig, "d_subspace_recon_orig (train)": d_subspace_recon_orig}, step=epoch)

                                    # Log visualization to wandb
                                    if j == 0:
                                        wandb.log({"Test graph reconstruction": wandb.Image(fig)}, step=epoch)
                                    else:
                                        wandb.log({"Train graph reconstruction": wandb.Image(fig)}, step=epoch)
                                    plt.close(fig)

                    # Save the trained model (create the target directory first so a
                    # missing folder cannot throw away a finished training run)
                    os.makedirs('models', exist_ok=True)
                    torch.save(model.state_dict(), 'models/denoiser_model_1.pt')
                    print("Model saved successfully to denoiser_model_1.pt")

                    if use_wandb:
                        wandb.finish()