In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import expm
import os
import argparse
from scipy.sparse.linalg import eigsh
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import random
from torch.utils.data import Dataset, DataLoader
#import wandb
import torch.nn.functional as F

In [3]:
#embeddings and gnns

class GaussianEmbedding(nn.Module):
    def __init__(self, num_terms, num_channels):
        super(GaussianEmbedding, self).__init__()
        self.num_terms = num_terms
        self.num_channels = num_channels

        self.h = nn.Parameter(torch.randn(num_terms + 1, num_channels))
        nn.init.xavier_uniform_(self.h)

    def forward(self, A):
        batch_size, num_nodes, _ = A.shape
        Y_hat = torch.zeros(batch_size, num_nodes, self.num_channels, device=A.device)
        for c in range(self.num_channels):
            result = self.h[0, c] * torch.eye(num_nodes, device=A.device).unsqueeze(0).expand(batch_size, -1, -1)
            for i in range(1, self.num_terms + 1):
                A_power_i = torch.matrix_power(A, i)
                result += self.h[i, c] * A_power_i
            Y_hat[..., c] = torch.diagonal(result, dim1=-2, dim2=-1)
        return Y_hat

class EigenEmbedding(nn.Module):
    def __init__(self):
        super(EigenEmbedding, self).__init__()

    def forward(self, A):
        eigenvectors = []
        for i in range(A.shape[0]):
            _, V = torch.linalg.eigh(A[i])
            eigenvectors.append(V)
        return torch.stack(eigenvectors, dim=0)


class GraphConvolutionLayer(nn.Module):
    def __init__(self, num_terms, num_channels):
        super(GraphConvolutionLayer, self).__init__()
        self.num_terms = num_terms
        self.num_channels = num_channels

        self.H = nn.Parameter(torch.randn(num_terms + 1, num_channels, num_channels))
        nn.init.xavier_uniform_(self.H)

        self.activation = nn.Tanh()
        self.layer_norm = nn.LayerNorm(num_channels)

    def forward(self, A, X):
        Y_hat = X @ self.H[0] 
        for i in range(1, self.num_terms + 1):
            A_power_i = torch.matrix_power(A, i)
            Y_hat += torch.bmm(A_power_i, X) @ self.H[i]
        Y_hat = self.layer_norm(Y_hat)
        Y_hat = self.activation(Y_hat)
        return Y_hat


class NodeVarGraphConvolutionLayer(nn.Module):
    def __init__(self, num_terms, num_channels, num_nodes):
        super(NodeVarGraphConvolutionLayer, self).__init__()
        self.num_terms = num_terms
        self.num_channels = num_channels
        self.num_nodes = num_nodes

        self.h = nn.Parameter(torch.randn(num_terms + 1, num_channels, num_nodes))
        nn.init.xavier_uniform_(self.h)

        self.activation = nn.Tanh()
        self.layer_norm = nn.LayerNorm(num_channels)

    def forward(self, A, X):
        batch_size, num_nodes, num_channels_in = X.shape
        Y_hat = torch.zeros(batch_size, num_nodes, self.num_channels, device=A.device)
        for c in range(self.num_channels):
            result = torch.zeros(batch_size, num_nodes, device=A.device)
            h_diag = torch.diag_embed(self.h[0, c])
            A_w = h_diag @ torch.eye(num_nodes, device=A.device).unsqueeze(0).expand(batch_size, -1, -1)
            for ch in range(num_channels_in):
                result += torch.bmm(A_w, X[..., ch].unsqueeze(-1)).squeeze(-1)

            for i in range(1, self.num_terms + 1):
                A_power_i = torch.matrix_power(A, i)
                h_diag = torch.diag_embed(self.h[i, c])
                A_w = h_diag @ A_power_i
                for ch in range(num_channels_in):
                    result += torch.bmm(A_w, X[..., ch].unsqueeze(-1)).squeeze(-1)
            Y_hat[..., c] = result
        Y_hat = self.layer_norm(Y_hat)
        Y_hat = self.activation(Y_hat)
        return Y_hat


class GNN(nn.Module):
    def __init__(self, num_layers, num_terms=3, feature_dim_in=10, feature_dim_out=10):
        super(GNN, self).__init__()

        # self.embedding_layer = GaussianEmbedding(num_terms, feature_dim)
        self.embedding_layer = EigenEmbedding()

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(GraphConvolutionLayer(num_terms, feature_dim_in))

        self.out_x = nn.Linear(feature_dim_in, feature_dim_out)
        self.out_y = nn.Linear(feature_dim_in, feature_dim_out)

    def forward(self, A):
        Z = self.embedding_layer(A)
        for layer in self.layers:
            Z = layer(A, Z)
        X = self.out_x(Z)
        Y = self.out_y(Z)
        outer = torch.bmm(X, Y.transpose(1, 2))
        A_pred = torch.sigmoid(outer)

        return X, Y
    
class GNN_symmetric(nn.Module):
    def __init__(self, num_layers, num_terms=3, feature_dim_in=10, feature_dim_out=10):
        super(GNN_symmetric, self).__init__()

        # self.embedding_layer = GaussianEmbedding(num_terms, feature_dim)
        self.embedding_layer = EigenEmbedding()

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(GraphConvolutionLayer(num_terms, feature_dim_in))

        self.out_x = nn.Linear(feature_dim_in, feature_dim_out)


    def forward(self, A):
        Z = self.embedding_layer(A)
        for layer in self.layers:
            Z = layer(A, Z)
        X = self.out_x(Z)

        outer = torch.bmm(X, X.transpose(1, 2))
        return torch.sigmoid(outer), X


class NodeVarGNN(nn.Module):
    def __init__(self, num_layers, num_terms=3, feature_dim=10):
        super(NodeVarGNN, self).__init__()

        # self.embedding_layer = GaussianEmbedding(num_terms, feature_dim)
        self.embedding_layer = EigenEmbedding()

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(NodeVarGraphConvolutionLayer(num_terms, feature_dim, feature_dim))

        self.out_x = nn.Linear(feature_dim, feature_dim)
        self.out_y = nn.Linear(feature_dim, feature_dim)

    def forward(self, A):
        Z = self.embedding_layer(A)
        for layer in self.layers:
            Z = layer(A, Z)
        X = self.out_x(Z)
        Y = self.out_y(Z)
        outer = torch.bmm(X, Y.transpose(1, 2))
        return torch.sigmoid(outer)

In [4]:
#Attention layers


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention module as described in 'Attention Is All You Need' paper.
    
    This implementation supports masked attention and different input/output dimensions.
    """
    
    def __init__(self, d_model, num_heads, d_k=None, d_v=None, dropout=0.0, bias=False):
        """
        Initialize the Multi-Head Attention module.
        
        Parameters:
        - d_model: Model dimension (input and output dimension)
        - num_heads: Number of attention heads
        - d_k: Dimension of keys (default: d_model // num_heads)
        - d_v: Dimension of values (default: d_model // num_heads)
        - dropout: Dropout probability
        """
        super(MultiHeadAttention, self).__init__()
        
        self.num_heads = num_heads
        self.d_model = d_model
        
        # If d_k and d_v are not specified, set them to d_model // num_heads
        self.d_k = d_k if d_k is not None else d_model // num_heads
        self.d_v = d_v if d_v is not None else d_model // num_heads
        
        # Linear projections for queries, keys, and values
        self.W_q = nn.Linear(d_model, num_heads * self.d_k, bias=bias)
        self.W_k = nn.Linear(d_model, num_heads * self.d_k, bias=bias)
        self.W_v = nn.Linear(d_model, num_heads * self.d_v, bias=bias)
        
        # Output projection
        self.W_o = nn.Linear(num_heads * self.d_v, d_model, bias=bias)
        
        # Dropout layer
        self.dropout = nn.Dropout(dropout)
        
        # Layer normalization for the output
        self.layer_norm = nn.LayerNorm(d_model)
        
        # Scaling factor for dot product attention
        # self.scale = 1 / math.sqrt(self.d_k)
        self.scale = 1

        # Linear layer to combine attention scores from different heads
        self.score_combination = nn.Linear(num_heads, 1, bias=False)
    
    def forward(self, x, mask=None, residual=None):
        """
        Forward pass of the Multi-Head Attention module.
        
        Parameters:
        - Q: Query tensor of shape (batch_size, seq_len_q, d_model)
        - K: Key tensor of shape (batch_size, seq_len_k, d_model)
        - V: Value tensor of shape (batch_size, seq_len_v, d_model)
        - mask: Optional mask tensor of shape (batch_size, seq_len_q, seq_len_k)
        - residual: Optional residual connection
        
        Returns:
        - output: Output tensor of shape (batch_size, seq_len_q, d_model)
        - attention: Attention weights of shape (batch_size, num_heads, seq_len_q, seq_len_k)
        """
        batch_size = x.size(0)
        
        # If residual connection is not provided, use Q as residual
        if residual is None:
            residual = x

        # Linear projections and reshaping for multi-head attention
        # Shape: (batch_size, seq_len, num_heads, d_*)

        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)


        q = q.view(batch_size, -1, self.num_heads, self.d_k)
        k = k.view(batch_size, -1, self.num_heads, self.d_k)
        v = v.view(batch_size, -1, self.num_heads, self.d_v)



        # Transpose to shape: (batch_size, num_heads, seq_len, d_*)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        
        # Calculate attention scores
        # (batch_size, num_heads, seq_len_q, seq_len_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        
        # Apply mask if provided
        if mask is not None:
            # Add an extra dimension for the number of heads
            if mask.dim() == 3:  # (batch_size, seq_len_q, seq_len_k)
                mask = mask.unsqueeze(1)
            
            # Set masked positions to a large negative value before softmax
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Apply softmax to get attention weights
        attn_weights = F.softmax(scores, dim=-1)
        
        # Apply dropout to attention weights
        attn_weights = self.dropout(attn_weights)
        
        # Calculate weighted sum of values
        # (batch_size, num_heads, seq_len_q, d_v)
        context = torch.matmul(attn_weights, v)
        
        # Transpose and reshape to (batch_size, seq_len_q, num_heads * d_v)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_v)
        
        # Apply output projection
        output = self.W_o(context)
        
        # Apply dropout and residual connection
        output = self.dropout(output)
        output = self.layer_norm(output + residual)

        # Combine attention scores from different heads using learned weights
        # Transpose scores to have heads dimension last: (batch_size, seq_len_q, seq_len_k, num_heads)
        scores = scores.permute(0, 2, 3, 1)
        # Apply linear combination: (batch_size, seq_len_q, seq_len_k, 1)
        combined_scores = self.score_combination(scores)
        # Remove last singleton dimension
        combined_scores = combined_scores.squeeze(-1)
        
        return output, combined_scores


class MultiLayerAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_layers, d_k=None, d_v=None, dropout=0.0, bias=False):
        super().__init__()
        
        # If d_k and d_v are not specified, set them equal to d_model/num_heads
        if d_k is None:
            d_k = d_model // num_heads
        if d_v is None:
            d_v = d_model // num_heads
            
        # Create stack of attention layers
        self.layers = nn.ModuleList([
            MultiHeadAttention(d_model, num_heads=num_heads, d_k=d_k, d_v=d_v, dropout=dropout, bias=bias)
            for _ in range(num_layers)
        ])
        
    def forward(self, x, mask=None):
        # Keep track of attention scores from each layer
        attention_scores = []
        
        # Pass through each attention layer sequentially
        for layer in self.layers:

            x, scores = layer(x, mask=mask)

            attention_scores.append(scores)
            
        return x, attention_scores

In [5]:
# Define sequential model combining embedding and denoising
class SequentialDenoisingModel(nn.Module):
    def __init__(self, embedding_model, denoising_model = None):
        super(SequentialDenoisingModel, self).__init__()
        self.embedding_model = embedding_model
        self.denoising_model = denoising_model
        
    def forward(self, x):
        # First apply embedding model
        X, Y = self.embedding_model(x)

        Z = torch.cat([X, Y], dim=2)

        if self.denoising_model is not None:
            Z_pred = Z + self.denoising_model(Z)[0]
        else:
            Z_pred = Z

        X_pred = Z_pred[:, :, 0:X.shape[2]]
        Y_pred = Z_pred[:, :, X.shape[2]:]

        A_recon = torch.sigmoid(torch.bmm(X_pred, Y_pred.transpose(1, 2)))                                  
                                                                                                                                                                                                                                                                                                        
        return A_recon

In [6]:
#Functions
def generate_sbm_adjacency(block_sizes, p, q, rng=None):
    """
    Generate an adjacency matrix for a stochastic block model with variable block sizes.

    Parameters:
    - block_sizes: List of sizes for each block.
    - p: Probability of intra-block edges.
    - q: Probability of inter-block edges.
    - rng: Random number generator (optional).

    Returns:
    - Adjacency matrix as a numpy array.
    """
    if rng is None:
        rng = np.random.default_rng()

    n_blocks = len(block_sizes)
    n = sum(block_sizes)

    # Initialize the adjacency matrix with zeros
    
    adj_matrix = np.zeros((n, n))

    # Calculate the starting index of each block
    block_starts = [0]
    for i in range(n_blocks-1):
        block_starts.append(block_starts[-1] + block_sizes[i])

    for i in range(n_blocks):
        for j in range(i, n_blocks):
            density = p if i == j else q
            block_start_i = block_starts[i]
            block_end_i = block_start_i + block_sizes[i]
            block_start_j = block_starts[j]
            block_end_j = block_start_j + block_sizes[j]

            # Generate random edges within or between blocks
            block_i_size = block_sizes[i]
            block_j_size = block_sizes[j]
            adj_matrix[block_start_i:block_end_i, block_start_j:block_end_j] = (
                rng.random((block_i_size, block_j_size)) < density
            ).astype(int)

            # Make the matrix symmetric (for undirected graphs)
            if i != j:
                adj_matrix[block_start_j:block_end_j, block_start_i:block_end_i] = (
                    adj_matrix[block_start_i:block_end_i, block_start_j:block_end_j].T
                )

    return adj_matrix

def add_digress_noise(A, p, rng=None):
    """
    Add noise to an adjacency matrix by flipping edges with probability p.
    
    Parameters:
    - adj_matrix: A 2D numpy array or tensor representing an adjacency matrix (0s and 1s)
    - p: Probability of flipping each element (0 to 1, 1 becomes 0 and 0 becomes 1)
    - rng: Random number generator (optional)
    
    Returns:
    - Noisy adjacency matrix with some edges flipped
    """
    if rng is None:
        rng = np.random.default_rng()
    
    # Create a copy of the original matrix to avoid modifying it
    A_noisy = A
    
    # Generate random values for each element
    random_values = torch.rand_like(torch.tensor(A))
    
    # Create a mask for elements to flip (where random value < p)
    flip_mask = random_values < p
    
    # Flip the elements where the mask is True (using XOR operation)
    # XOR with 1 flips 0→1 and 1→0
    A_noisy = torch.where(flip_mask, 1 - torch.tensor(A), torch.tensor(A))
    
    l, V = torch.linalg.eigh(A_noisy)

    return torch.tensor(A_noisy, dtype=torch.float32), torch.tensor(V, dtype=torch.float32), torch.tensor(l, dtype=torch.float32)


#Functions
def generate_block_sizes(n, min_blocks=2, max_blocks=4, min_size=2, max_size=15):
    # Example usage:
    # n is the number of nodes
    # n = 20
    # partitions = generate_block_sizes(n)
    # print(f"Valid block size partitions for n={n}:")
    # for p in partitions:
    #     print(p)
    valid_partitions = []
    
    # Try different numbers of blocks
    for num_blocks in range(min_blocks, max_blocks + 1):
        def generate_partitions(remaining, blocks_left, current_partition):
            # Base cases
            if blocks_left == 0:
                if remaining == 0:
                    valid_partitions.append(current_partition[:])
                return
            
            # Try different sizes for current block
            start = max(min_size, remaining - (blocks_left-1)*max_size)
            end = min(max_size, remaining - (blocks_left-1)*min_size) + 1
            
            for size in range(start, end):
                if size <= remaining:
                    current_partition.append(size)
                    generate_partitions(remaining - size, blocks_left - 1, current_partition)
                    current_partition.pop()
        
        generate_partitions(n, num_blocks, [])
    
    return valid_partitions

class PermutedAdjacencyDataset(Dataset):
    def __init__(self, adjacency_matrices, num_samples):
        self.adjacency_matrices = adjacency_matrices
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Randomly choose between A_1 and A_2
        matrix_idx = torch.randint(0, len(self.adjacency_matrices), (1,)).item()
        adjacency_matrix = self.adjacency_matrices[matrix_idx]
        
        # Generate random permutation
        permuted_indices = torch.randperm(adjacency_matrix.size(0))
        A_permuted = adjacency_matrix[permuted_indices, :][:, permuted_indices]
        return A_permuted



In [None]:
random.seed(42)  # For reproducibility
torch.manual_seed(42)
np.random.seed(42)

#hyperparameters
#######################################
use_wandb = False

#number of nodes
n = 20

use_transformer = True

#GNN parameters
gnn_num_layers = 2
gnn_num_terms = 2
gnn_feature_dim_in = 20
gnn_feature_dim_out = 5 #this is the low dimensional embedding dimension

#transformer parameters
transformer_num_layers = 4
transformer_num_heads = 4
transformer_d_k = 10
transformer_d_v = 10

loss_type = "BCE" #loss criteria: either "MSE" or "BCE"

learning_rate = 0.005

noise_levels = [0.005, 0.02, 0.05, 0.1, 0.25, 0.4, 0.5]

num_epochs = 200
test_epochs = 10
train_batch_size = 100
test_batch_size = 100
########################################





#gives all possible partitions of n into 2-4 blocks
partitions = generate_block_sizes(n)





# Sample 10 partitions for training and 10 for test
train_partitions = random.sample(partitions, 10)
test_partitions = random.sample([p for p in partitions if p not in train_partitions], 10)

As_train = []
As_test = []

p_intra = 1.0
q_inter = 0.0

for p in train_partitions:
    A = generate_sbm_adjacency(p, p_intra, q_inter)
    A = torch.tensor(A)
    As_train.append(A)

for p in test_partitions:
    A = generate_sbm_adjacency(p, p_intra, q_inter)
    A = torch.tensor(A)
    As_test.append(A)

print("\nTraining partitions:")
for p in train_partitions:
    print(p)
    
print("\nTest partitions:") 
for p in test_partitions:
    print(p)


num_samples = 1000  # Define the number of samples you want
train_dataset = PermutedAdjacencyDataset(As_train, num_samples)
train_dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True)

test_dataset = PermutedAdjacencyDataset(As_test, num_samples)
test_dataloader = DataLoader(test_dataset, batch_size=100, shuffle=True)

In [11]:
len(As_train)

10

In [None]:
if use_wandb:
    wandb.init(project="graph-denoising", name="training_run")



model_embedding = GNN(num_layers=gnn_num_layers, num_terms=gnn_num_terms, feature_dim_in=gnn_feature_dim_in, feature_dim_out=gnn_feature_dim_out)
model_embedding = model_embedding.double()



model_denoiser = MultiLayerAttention(2*gnn_feature_dim_out, num_heads = transformer_num_heads, d_k = transformer_d_k, d_v = transformer_d_v, num_layers = transformer_num_layers, bias = True)
model_denoiser.double()



# Create combined sequential model
if use_transformer == True:
    model = SequentialDenoisingModel(model_embedding, model_denoiser)
else:
    model = SequentialDenoisingModel(model_embedding, None)




# Log model hyperparameters if using wandb
if use_wandb:
    wandb.config.update({
        "gnn_num_layers": gnn_num_layers,
        "gnn_num_terms": gnn_num_terms,
        "gnn_feature_dim_in": gnn_feature_dim_in,
        "gnn_feature_dim_out": gnn_feature_dim_out,
        "transformer_num_layers": transformer_num_layers,
        "transformer_num_heads": transformer_num_heads,
        "transformer_d_k": transformer_d_k,
        "transformer_d_v": transformer_d_v,
        "loss_type": loss_type,
        "use_transformer": use_transformer,
        "train_batch_size": train_batch_size,
        "test_batch_size": test_batch_size,
        "learning_rate": learning_rate,
    })

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)

if loss_type == "MSE":  
    criterion = nn.MSELoss()
elif loss_type == "BCE":
    criterion = nn.BCELoss()

#Training loop
for epoch in tqdm(range(num_epochs), desc="Training epochs"):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    
    for batch in train_dataloader:
        optimizer.zero_grad()

        # Sample random noise level for this batch
        eps = np.random.choice(noise_levels)
        batch_noisy = (add_digress_noise(batch, eps))[0]

        batch_noisy = batch_noisy.double()
        batch = batch.double()

        output = model(batch_noisy)
        
        # Compute loss
        loss = criterion(output, batch)
        epoch_loss += loss.item()
        num_batches += 1
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
    
    # Step the scheduler
    scheduler.step()
    
    # Calculate average epoch loss
    avg_epoch_loss = epoch_loss / num_batches
    
    #evaluate test error
    test_loss = 0.0
    num_batches = 0
    model.eval()
    with torch.no_grad():
        for batch in test_dataloader:
            # Sample random noise level for this batch
            eps = np.random.choice(noise_levels)
            batch_noisy = (add_digress_noise(batch, eps))[0]

            batch_noisy = batch_noisy.double()
            batch = batch.double()

            output = model(batch_noisy)
            
            # Compute loss
            loss = criterion(output, batch)
            test_loss += loss.item()
            num_batches += 1
    avg_test_loss = test_loss / num_batches
    
    # Log metrics to wandb if enabled
    if use_wandb:
        wandb.log({
            "train_loss": avg_epoch_loss,
            "test_loss": avg_test_loss,
            "noise_level": eps,
            "learning_rate": scheduler.get_last_lr()[0]
        })
    
    if (epoch) % test_epochs == 0:

        
        if use_wandb:
            #evaluate train and test error per noise level
            model.eval()
            with torch.no_grad():
                train_losses = []
                test_losses = []
                for eps in noise_levels:

                    train_loss = 0.0
                    test_loss = 0.0
                    num_batches = 0

                    for batch in train_dataloader:
                        batch_noisy = (add_digress_noise(batch, eps))[0]
                        batch_noisy = batch_noisy.double()
                        batch = batch.double()
                        output = model(batch_noisy)
                        loss = criterion(output, batch)
                        train_loss += loss.item()
                        num_batches += 1

                    train_losses.append(train_loss / num_batches)
                    
                    for batch in test_dataloader:
                        batch_noisy = (add_digress_noise(batch, eps))[0]
                        batch_noisy = batch_noisy.double()
                        batch = batch.double()
                        output = model(batch_noisy)
                        loss = criterion(output, batch)
                        test_loss += loss.item()
                        num_batches += 1

                    test_losses.append(test_loss / num_batches)

                    wandb.log({
                        "eps_" + str(eps) + "_train_loss": train_losses[-1],
                        "eps_" + str(eps) + "_test_loss": test_losses[-1],
                    })

               

                #visualize the results
                eps_values = noise_levels
                # Get reconstructions from the model for different noise levels
                test_idx = np.random.randint(len(As_test))
                train_idx = np.random.randint(len(As_train))
                A_test = As_test[test_idx]
                A_train = As_train[train_idx]

                for j, A_orig in enumerate([A_test, A_train]):
                    A_orig = A_orig.unsqueeze(0)

                    fig_size = 4

                    # Create a figure with subplots - 2 rows for noisy/recon pairs, 5 columns for noise levels
                    fig, axes = plt.subplots(2, len(eps_values), figsize=(fig_size*len(eps_values), fig_size*2))

                    # Plot pairs of noisy and reconstructed matrices for each noise level
                    for i, eps in enumerate(eps_values):
                        A_noisy = add_digress_noise(A_orig, eps)[0]
                        with torch.no_grad():
                            A_recon = model(A_noisy.double())[0]
                        
                        # Plot noisy matrix on top row
                        im1 = axes[0,i].imshow(A_noisy.squeeze(0).numpy(), cmap='viridis')
                        axes[0,i].set_title(f'Noisy (ε={eps})')
                        
                        # Plot reconstructed matrix below
                        im2 = axes[1,i].imshow(A_recon.squeeze(0).numpy(), cmap='viridis')
                        axes[1,i].set_title(f'Reconstructed')

                    # Log visualization to wandb if enabled
                    if j == 0:
                        wandb.log({"Test graph reconstruction": wandb.Image(fig)})
                    else:
                        wandb.log({"Train graph reconstruction": wandb.Image(fig)})
                    plt.close(fig)

if use_wandb:
    wandb.finish()


Training partitions:
[12, 2, 6]
[3, 3, 14]
[3, 8, 7, 2]
[3, 4, 10, 3]
[3, 2, 11, 4]
[2, 3, 3, 12]
[10, 5, 5]
[10, 4, 3, 3]
[8, 7, 5]
[6, 3, 7, 4]

Test partitions:
[3, 11, 6]
[3, 9, 8]
[9, 6, 5]
[3, 2, 12, 3]
[3, 4, 4, 9]
[8, 8, 2, 2]
[3, 6, 11]
[2, 10, 6, 2]
[11, 3, 3, 3]
[6, 4, 5, 5]


  random_values = torch.rand_like(torch.tensor(A))
  A_noisy = torch.where(flip_mask, 1 - torch.tensor(A), torch.tensor(A))
  return torch.tensor(A_noisy, dtype=torch.float32), torch.tensor(V, dtype=torch.float32), torch.tensor(l, dtype=torch.float32)
Training epochs:  47%|████████████████████████████████████████████████████████████████████████████▏                                                                                     | 94/200 [00:40<00:44,  2.41it/s]