In [1]:
# Import necessary libraries
import os
import math
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler

# Import the SyntheticData class from factor_eval.py
# Make sure factor_eval.py is in the same directory as your notebook
from factor_eval import SyntheticData

# Experiment configuration. `seed` is reused below for Python, NumPy, PyTorch
# and for the SyntheticData generator itself.
seed = 42
n_samples = 50000
n_test = 12500
n_sources = 5
k = 10
snr = 5
correlate_sources = False
get_covariance = False
random_scale = False
nuisance = 10

# Seed every random number generator used in this notebook.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# Request deterministic cuDNN kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Generate synthetic data with the configuration above:
# - 50,000 training samples and 12,500 test samples
# - 5 latent factors (n_sources=5) and k=10 variables per factor
# - nuisance=10; the recorded output of this cell shows 60 observed columns,
#   which matches n_sources * k + nuisance (assumed layout -- confirm against
#   factor_eval.py)
# - snr=5
# - No correlation between sources (correlate_sources=False)
# - No random rescaling of variables (random_scale=False)
# - Use a fixed seed for reproducibility
synthetic_data = SyntheticData(
    n_samples=n_samples,
    n_test=n_test, 
    n_sources=n_sources,
    k=k,
    snr=snr,             # Signal-to-noise setting (5 in this run)
    correlate_sources=correlate_sources,  # False: latent factors are not correlated
    get_covariance=get_covariance,
    random_scale=random_scale,    # False: no random per-variable scaling
    nuisance=nuisance,           # 10 in this run (see the note on column count above)
    seed=seed               # Fixed seed for reproducibility
)

# Extract the training and test matrices produced by SyntheticData.
X_train = synthetic_data.train
X_test = synthetic_data.test
n_observed = X_train.shape[1]  # number of observed variables (columns)

# -------------------------- Normalization Block (BEGIN) --------------------------
# Min-max scale every column with statistics computed on the TRAINING data only,
# so no information from the test set leaks into the preprocessing.
train_min = np.min(X_train, axis=0, keepdims=True)
train_max = np.max(X_train, axis=0, keepdims=True)

# Apply min-max normalization: (X - min) / (max - min).
# A column that is constant in the training set has a zero range. Dividing by
# 1.0 maps it to 0 in the training data (exactly what a tiny epsilon would give)
# while leaving any deviation in the test data unamplified; dividing by 1e-8
# would multiply such a deviation by 1e8.
denominator = train_max - train_min
denominator = np.where(denominator == 0, 1.0, denominator)

X_train = (X_train - train_min) / denominator
X_test = (X_test - train_min) / denominator
# -------------------------- Normalization Block (END) --------------------------

# Wrap the normalised arrays as float32 tensors and single-tensor datasets.
train_tensor = torch.tensor(X_train, dtype=torch.float32)
train_dataset = TensorDataset(train_tensor)

test_tensor = torch.tensor(X_test, dtype=torch.float32)
test_dataset = TensorDataset(test_tensor)

# SyntheticData already provides the train/test split, so no random split is
# needed here. A dedicated, seeded generator drives the sampler so the shuffling
# order of the training set is reproducible.
generator = torch.Generator().manual_seed(seed)
train_sampler = RandomSampler(train_dataset, generator=generator)

# Training batches are drawn through the seeded sampler; test batches keep
# their original order.
batch_size = 32  # Adjust the batch size as needed
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
    num_workers=0,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

# Sanity check: array shapes and number of batches per loader.
print(f"Training Data Shape: {X_train.shape}")
print(f"Test Data Shape: {X_test.shape}")
print(f"Number of training batches: {len(train_loader)}")
print(f"Number of test batches: {len(test_loader)}")

Training Data Shape: (50000, 60)
Test Data Shape: (12500, 60)
Number of training batches: 1563
Number of test batches: 391


In [2]:
# Show the first 5 rows of the (normalised) training data.
print("First 5 rows of training data:")
print(X_train[:5])

# Summary statistics of the leading columns, to eyeball their distribution.
num_cols_to_inspect = 5  # you can change this number
cols_to_inspect = X_train[:, :num_cols_to_inspect]
# The slice may hold fewer columns than requested if the data is narrow, so the
# header and the loop below use the number of columns actually selected.
num_cols_inspected = cols_to_inspect.shape[1]

print(f"\nSummary statistics for the first {num_cols_inspected} columns of the training data:")
means = cols_to_inspect.mean(axis=0)
stds = cols_to_inspect.std(axis=0)
mins = cols_to_inspect.min(axis=0)
maxs = cols_to_inspect.max(axis=0)

for i in range(num_cols_inspected):
    print(f"Column {i}: mean={means[i]:.4f}, std={stds[i]:.4f}, min={mins[i]:.4f}, max={maxs[i]:.4f}")

First 5 rows of training data:
[[0.50868695 0.4987838  0.61679009 0.51139569 0.57808391 0.53119005
  0.59510766 0.49915261 0.48816219 0.56378535 0.54389231 0.46161187
  0.44176781 0.44011464 0.5212177  0.44182656 0.50323608 0.47318434
  0.48656391 0.43129428 0.62553241 0.63957391 0.55248758 0.50820067
  0.5742377  0.5785061  0.5395921  0.58961463 0.62899889 0.52459604
  0.59427365 0.68768851 0.60007923 0.67818638 0.65084412 0.63100136
  0.65105545 0.60424543 0.67327988 0.57388115 0.45481663 0.58123842
  0.5387345  0.37032482 0.46125067 0.51695561 0.4722074  0.52507556
  0.42781294 0.38517918 0.67877306 0.32525958 0.45856828 0.62778782
  0.54594568 0.28734671 0.23028904 0.49383711 0.59874729 0.41008683]
 [0.48820029 0.41353023 0.45875492 0.43953774 0.46819887 0.50459208
  0.4889965  0.47278801 0.51049108 0.49991677 0.63305502 0.63168452
  0.68632706 0.72601394 0.65869075 0.69213919 0.69485117 0.65005208
  0.52951655 0.60731036 0.63851802 0.54294673 0.591343   0.58300996
  0.56239008 0.5

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Variational encoder that maps observations to per-factor embedding tokens.

    An MLP trunk feeds two linear heads producing ``mu`` and ``log_var`` for
    ``output_dim`` latent factors. A latent sample ``z`` is drawn with the
    reparameterization trick and each scalar ``z_i`` scales its own learnable
    embedding vector ``e_i``, giving one token per latent factor.

    Args:
        input_dim: Number of observed variables. ``None`` (the default) falls
            back to the notebook-level ``n_observed`` at construction time.
        hidden_dims: Widths of the shared hidden layers (each Linear -> ReLU).
            May be empty, in which case the heads read the input directly.
        output_dim: Number of latent factors.
        embedding_dim: Size of the embedding vector attached to each factor.

    Raises:
        ValueError: If any of the dimensions is smaller than 1.
    """

    def __init__(
        self,
        input_dim=None,               # Number of observed variables (None -> n_observed)
        hidden_dims=[128, 64],        # Shared hidden layers (never mutated; copied below)
        output_dim=5,                 # Number of latent factors
        embedding_dim=64
    ):
        super(Encoder, self).__init__()

        # Resolve the fallback lazily so the class can be defined (and reused)
        # without the notebook global being present.
        if input_dim is None:
            input_dim = n_observed
        hidden_dims = list(hidden_dims)

        for dim_name, dim_value in (
            ("input_dim", input_dim),
            ("output_dim", output_dim),
            ("embedding_dim", embedding_dim),
        ):
            if dim_value < 1:
                raise ValueError(f"{dim_name} must be a positive integer, got {dim_value}")

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embedding_dim = embedding_dim

        # ---------------------------------------------------------------------
        # 1. Shared MLP trunk feeding the mu and log_var heads.
        # ---------------------------------------------------------------------
        dims = [input_dim] + hidden_dims  # e.g. [n_observed, 128, 64]
        self.shared_layers = nn.ModuleList()
        for i in range(len(dims) - 1):
            # Each hidden layer is [Linear -> ReLU]
            self.shared_layers.append(
                nn.Sequential(
                    nn.Linear(dims[i], dims[i + 1]),
                    nn.ReLU()
                )
            )

        # ---------------------------------------------------------------------
        # 2. Final heads for mu and log_var (each of dimension output_dim).
        # ---------------------------------------------------------------------
        last_dim = dims[-1]
        self.fc_mu = nn.Linear(last_dim, output_dim)
        self.fc_log_var = nn.Linear(last_dim, output_dim)

        # ---------------------------------------------------------------------
        # 3. Learnable embedding e_i for each latent dimension z_i,
        #    shape: (output_dim, embedding_dim).
        # ---------------------------------------------------------------------
        self.e = nn.Parameter(torch.randn(output_dim, embedding_dim))
        assert self.e.shape == (output_dim, embedding_dim), \
            f"Expected embedding matrix e to have shape ({output_dim}, {embedding_dim}), got {self.e.shape}"

    def forward(self, x):
        """Encode a batch of observations.

        Args:
            x: Tensor of shape (batch_size, input_dim).

        Returns:
            hat_Z: Tensor of shape (batch_size, output_dim, embedding_dim).
            mu, log_var: Tensors of shape (batch_size, output_dim) each.

        Raises:
            AssertionError: If ``x`` is not 2D or has the wrong feature count.
        """
        assert x.dim() == 2, f"Expected x to be a 2D tensor, but got {x.dim()}D."
        batch_size = x.size(0)
        assert x.size(1) == self.input_dim, \
            f"Expected x to have {self.input_dim} features, got {x.shape[1]}"

        # Shared hidden layers.
        for layer in self.shared_layers:
            x = layer(x)

        # Posterior parameters, each squashed by an activation.
        mu = self.fc_mu(x)             # shape: (batch_size, output_dim)
        mu = torch.tanh(mu)            # constrain mu to -1..1

        log_var = self.fc_log_var(x)   # shape: (batch_size, output_dim)
        # NOTE(review): the sigmoid keeps log_var in (0, 1), so std = exp(0.5 * log_var)
        # is always above 1 while |mu| stays below 1. The sampling noise therefore
        # cannot shrink below the scale of the mean -- confirm this is intended.
        log_var = torch.sigmoid(log_var)

        # Reparameterization trick: z = mu + eps * std with eps ~ N(0, I).
        # A sample is drawn on every call, in eval mode as well.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + std * eps             # shape: (batch_size, output_dim)

        # Scale each factor's embedding by its latent value:
        # (batch, output_dim, 1) * (1, output_dim, embedding_dim)
        z_expanded = z.unsqueeze(2)
        e_expanded = self.e.unsqueeze(0)
        hat_Z = z_expanded * e_expanded
        assert hat_Z.shape == (batch_size, self.output_dim, self.embedding_dim), \
            f"Expected hat_Z to have shape ({batch_size}, {self.output_dim}, {self.embedding_dim}), but got {hat_Z.shape}"

        return hat_Z, mu, log_var


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Decoder(nn.Module):
    """Attention decoder that reconstructs each observed variable from latent tokens.

    Every observed variable owns a learnable query embedding. The queries attend
    (single head) over the latent-factor tokens, the result goes through a
    residual connection and layer normalization, and a separate MLP per
    variable maps its context vector to a scalar reconstruction.

    Args:
        input_dim: Number of observed variables. ``None`` (the default) falls
            back to the notebook-level ``n_observed`` at construction time.
        embedding_dim: Size of the query embeddings and of the latent tokens.
        hidden_dims: Hidden widths of each per-variable MLP (Linear -> ReLU);
            an empty list gives a single Linear(embedding_dim, 1) per variable.

    Raises:
        ValueError: If ``input_dim`` or ``embedding_dim`` is smaller than 1.
    """

    def __init__(self, input_dim=None, embedding_dim=64, hidden_dims=[]):
        super(Decoder, self).__init__()

        # Resolve the fallback lazily so the class does not need the notebook
        # global at definition time.
        if input_dim is None:
            input_dim = n_observed
        hidden_dims = list(hidden_dims)  # the default list is never mutated

        if input_dim < 1:
            raise ValueError(f"input_dim must be a positive integer, got {input_dim}")
        if embedding_dim < 1:
            raise ValueError(f"embedding_dim must be a positive integer, got {embedding_dim}")

        self.input_dim = input_dim      # Number of observed variables (n)
        self.embedding_dim = embedding_dim

        # Learnable query embeddings (e1, e2, ..., e_n)
        self.query_embeddings = nn.Parameter(torch.randn(input_dim, embedding_dim))

        # MultiheadAttention module with 1 head
        self.attention = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=1, batch_first=True)

        # Layer normalization
        self.layer_norm = nn.LayerNorm(embedding_dim)

        # One independent MLP per observed variable. The nesting (a Sequential
        # of [Linear, ReLU] pairs followed by a bare Linear) is kept as-is so
        # that state_dict keys stay compatible with saved checkpoints.
        dims = [embedding_dim] + hidden_dims + [1]
        self.mlp_layers = nn.ModuleList([
            nn.Sequential(*[
                nn.Linear(dims[i], dims[i+1]) if i == len(dims) - 2 else nn.Sequential(
                    nn.Linear(dims[i], dims[i+1]),
                    nn.ReLU()
                )
                for i in range(len(dims) - 1)
            ])
            for _ in range(input_dim)
        ])

        assert len(self.mlp_layers) == input_dim, \
            f"Expected {input_dim} MLPs in mlp_layers, but got {len(self.mlp_layers)}"

        # Verify that the MLPs do not share parameters. A single pass that
        # records the owner of every parameter replaces the pairwise comparison
        # of parameter sets, which grows quadratically with input_dim.
        param_owner = {}
        for mlp_index, mlp in enumerate(self.mlp_layers):
            for param in mlp.parameters():
                first_owner = param_owner.setdefault(id(param), mlp_index)
                assert first_owner == mlp_index, \
                    f"MLP {first_owner} and MLP {mlp_index} share parameters"

    def forward(self, hat_Z):
        """Reconstruct the observed variables from latent tokens.

        Args:
            hat_Z: Tensor of shape (batch_size, n_latent, embedding_dim).

        Returns:
            x_hat: Tensor of shape (batch_size, input_dim).
            attn_weights: Tensor of shape (batch_size, input_dim, n_latent).

        Raises:
            AssertionError: If ``hat_Z`` is not 3D or its last dimension does
                not match ``embedding_dim``.
        """
        assert hat_Z.dim() == 3, f"Expected hat_Z to be 3D, got {hat_Z.dim()}D."
        batch_size, output_dim, embedding_dim = hat_Z.shape
        assert embedding_dim == self.embedding_dim, \
            f"Expected hat_Z embedding_dim to be {self.embedding_dim}, but got {embedding_dim}"

        # Same queries for every sample: expand to the batch without copying.
        query_embeddings = self.query_embeddings.unsqueeze(0).expand(batch_size, -1, -1)

        # Queries attend over the latent tokens (keys == values == hat_Z).
        attn_output, attn_weights = self.attention(query_embeddings, hat_Z, hat_Z)
        assert attn_output.shape == (batch_size, self.input_dim, self.embedding_dim), \
            f"Expected attn_output to have shape ({batch_size}, {self.input_dim}, {self.embedding_dim}), got {attn_output.shape}"
        assert attn_weights.shape == (batch_size, self.input_dim, output_dim), \
            f"Expected attn_weights to have shape ({batch_size}, {self.input_dim}, {output_dim}), got {attn_weights.shape}"

        # Add residual connection and apply layer normalization
        out = self.layer_norm(attn_output + query_embeddings)

        # Pass each context vector through its corresponding MLP
        x_hat = torch.cat(
            [self.mlp_layers[i](out[:, i, :]) for i in range(self.input_dim)],
            dim=1,
        )  # (batch_size, input_dim)
        assert x_hat.shape == (batch_size, self.input_dim), \
            f"Expected x_hat to have shape ({batch_size}, {self.input_dim}), got {x_hat.shape}"

        return x_hat, attn_weights


In [5]:
class Model(nn.Module):
    """Encoder/decoder pair: observations -> latent tokens -> reconstruction.

    Args:
        input_dim: Number of observed variables, given to both sub-modules.
        output_dim: Number of latent factors produced by the encoder.
        embedding_dim: Token size shared by the encoder and the decoder.
        encoder_hidden_dims: Hidden layer widths forwarded to the encoder.
        decoder_hidden_dims: Hidden layer widths forwarded to the decoder.
    """

    def __init__(self, input_dim, output_dim, embedding_dim, encoder_hidden_dims=[], decoder_hidden_dims=[]):
        super().__init__()
        # The encoder is built before the decoder, so random initial weights
        # are drawn in that order.
        self.encoder = Encoder(
            input_dim=input_dim,
            hidden_dims=encoder_hidden_dims,
            output_dim=output_dim,
            embedding_dim=embedding_dim,
        )
        self.decoder = Decoder(
            input_dim=input_dim,
            hidden_dims=decoder_hidden_dims,
            embedding_dim=embedding_dim,
        )

    def forward(self, x):
        """Return ``(x_hat, attn_weights, mu, log_var)`` for the batch ``x``.

        ``mu`` and ``log_var`` are passed through from the encoder so the
        caller can compute the KL divergence term.
        """
        latent_tokens, mu, log_var = self.encoder(x)
        reconstruction, attention = self.decoder(latent_tokens)
        return reconstruction, attention, mu, log_var


In [6]:
import os
import math
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import adjusted_rand_score
from torch.utils.data import DataLoader, TensorDataset

# 'synthetic_data', 'n_observed', 'train_loader' and 'test_loader' are defined by the data cell above;
# 'Model' is defined by the model cells. The model itself is instantiated further down in this cell.
# Model.forward returns (x_hat, attn_weights, mu, log_var).

# Ground-truth cluster label of every observed variable. Labels of -1 are
# filtered out before the ARI is computed (see compute_ari_per_sample).
true_labels = np.array(synthetic_data.clusters, dtype=int)
assert np.all(true_labels >= -1), "Some observed variables have invalid cluster labels."

# Network dimensions. Both sizes come from the data/config cell and are not
# re-typed as literals here, so changing the configuration keeps them in sync.
input_dim = n_observed    # Number of observed variables
output_dim = n_sources    # Number of latent factors
embedding_dim = 64
encoder_hidden_dims = [128, 64]
decoder_hidden_dims = [64, 32]

# Instantiate the variational model
model = Model(
    input_dim=input_dim,
    output_dim=output_dim,
    embedding_dim=embedding_dim,
    encoder_hidden_dims=encoder_hidden_dims,
    decoder_hidden_dims=decoder_hidden_dims
)

# Device setup: prefer Apple MPS, then CUDA, otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else 
                      "cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model.to(device)

# Reconstruction loss (mean-squared error)
criterion = nn.MSELoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training parameters
num_epochs = 20
print_every = 1

# 1) Entropy regularizer coefficient.
#    With a maximum weight of 0.0 the attention entropy is still computed and
#    logged, but it contributes nothing to the loss.
max_lambda_entropy = 0.0
use_entropy_regularizer = True

# 2) KL divergence coefficient
max_lambda_kl = 5*1e-2       # you can pick a different value if desired
use_kl_regularizer = True

# Function to schedule entropy weight
def get_lambda_entropy(epoch, num_epochs, max_lambda_entropy, schedule_type='exponential',
                       use_entropy_regularizer=True):
    """Return the entropy-regularizer weight to use at ``epoch``.

    Args:
        epoch: Current epoch index (the training loop passes 0-based values).
        num_epochs: Total number of epochs the schedule spans.
        max_lambda_entropy: Weight the schedule scales.
        schedule_type: One of 'constant', 'linear', 'exponential', 'logarithmic'.
        use_entropy_regularizer: When False the weight is always 0.0.

    Returns:
        The scheduled weight as a float.

    Raises:
        ValueError: If ``schedule_type`` is not one of the supported names.
    """
    if not use_entropy_regularizer:
        return 0.0

    if schedule_type == 'constant':
        return max_lambda_entropy

    if schedule_type == 'linear':
        progress = epoch / num_epochs
        return max_lambda_entropy * progress

    if schedule_type == 'exponential':
        # Normalised exponential ramp: 0 at epoch 0, 1 at epoch == num_epochs.
        steepness = 5
        ramp = (math.exp(steepness * epoch / num_epochs) - 1) / (math.exp(steepness) - 1)
        return max_lambda_entropy * ramp

    if schedule_type == 'logarithmic':
        if epoch == 0:
            return 0.0
        return max_lambda_entropy * math.log(epoch + 1) / math.log(num_epochs + 1)

    raise ValueError(f"Unknown schedule_type: {schedule_type}")

# Function to schedule KL weight similarly (optional)
def get_lambda_kl(epoch, num_epochs, max_lambda_kl, schedule_type='constant',
                  use_kl_regularizer=True):
    """Return the KL-divergence weight to use at ``epoch``.

    Supports the same schedules as the entropy weight. Unknown schedule names
    raise an error; they no longer fall back silently to the constant weight,
    so a misspelt schedule cannot go unnoticed.

    Args:
        epoch: Current epoch index (the training loop passes 0-based values).
        num_epochs: Total number of epochs the schedule spans.
        max_lambda_kl: Weight the schedule scales.
        schedule_type: One of 'constant', 'linear', 'exponential', 'logarithmic'.
        use_kl_regularizer: When False the weight is always 0.0.

    Returns:
        The scheduled weight as a float.

    Raises:
        ValueError: If ``schedule_type`` is not one of the supported names.
    """
    if not use_kl_regularizer:
        return 0.0

    if schedule_type == 'constant':
        return max_lambda_kl

    if schedule_type == 'linear':
        return max_lambda_kl * (epoch / num_epochs)

    if schedule_type == 'exponential':
        # Normalised exponential ramp: 0 at epoch 0, 1 at epoch == num_epochs.
        steepness = 5
        numerator = math.exp(steepness * epoch / num_epochs) - 1
        denominator = math.exp(steepness) - 1
        return max_lambda_kl * (numerator / denominator)

    if schedule_type == 'logarithmic':
        if epoch == 0:
            return 0.0
        return max_lambda_kl * math.log(epoch + 1) / math.log(num_epochs + 1)

    raise ValueError(f"Unknown schedule_type: {schedule_type}")

def compute_ari_per_sample(true_labels, predicted_labels):
    """Compute the adjusted Rand index between true and predicted labels.

    Positions whose true label is -1 (nuisance) are removed from both label
    vectors before scoring.

    Args:
        true_labels: 1D array-like of ground-truth cluster labels.
        predicted_labels: 1D array-like of predicted labels, same length.

    Returns:
        The ARI of the remaining positions, or 1.0 when nothing remains after
        filtering (including empty input).
    """
    # Coerce to arrays so plain lists work too: comparing a list with -1 yields
    # a single bool instead of an element-wise mask.
    true_labels = np.asarray(true_labels)
    predicted_labels = np.asarray(predicted_labels)

    mask = true_labels != -1
    filtered_true = true_labels[mask]
    filtered_pred = predicted_labels[mask]
    if len(filtered_true) == 0:
        return 1.0
    return adjusted_rand_score(filtered_true, filtered_pred)

# Resume from a saved state dict when one is present next to the notebook.
model_path = "trained_model.pth"
if not os.path.exists(model_path):
    print("No trained model found. Starting from scratch.")
else:
    print("Trained model found. Loading the model.")
    # NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)
    print("Model loaded successfully!")

# Per-epoch bookkeeping filled in by the training loop.
attention_matrices = []   # average attention matrix of each epoch
train_ari_list = []       # mean training ARI of each epoch
test_ari_list = []        # mean test ARI of each epoch

# Normalisation constant for the attention entropy: log(output_dim) is the
# largest entropy a distribution over output_dim entries can have, and there
# is one such distribution per observed variable.
ent_norm = 1.0 / (input_dim * math.log(output_dim))

# ------------------ Training Loop ------------------
def _batch_losses(x_hat, batch, attn_weights, mu, log_var):
    """Return (recon_loss, kl_loss, entropy_reg) for one batch as scalar tensors.

    Reads the notebook-level `criterion`, `use_entropy_regularizer` and
    `ent_norm`. All three results are tensors, so `.item()` is always valid,
    even when the entropy regularizer is switched off.
    """
    # MSE reconstruction loss (averaged over all elements of the batch)
    recon_loss = criterion(x_hat, batch)

    # KL divergence: -0.5 * sum(1 + log_var - mu^2 - exp(log_var)),
    # summed over the batch and then averaged per sample.
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    kl_loss = kl_loss / batch.size(0)

    if use_entropy_regularizer:
        epsilon = 1e-8  # keeps log() finite for zero attention weights
        # attn_weights shape: (batch_size, input_dim, output_dim);
        # entropy of each query's attention distribution over the latent factors.
        entropy = -torch.sum(attn_weights * torch.log(attn_weights + epsilon), dim=2)
        # sum entropies over queries -> average over batch -> normalise
        entropy_reg = ent_norm * torch.mean(torch.sum(entropy, dim=1))
    else:
        # A zero tensor rather than the float 0.0, which has no .item().
        entropy_reg = torch.zeros((), device=batch.device)

    return recon_loss, kl_loss, entropy_reg


def _batch_mean_ari(attn_weights):
    """Return the mean ARI over the samples of one batch.

    Each observed variable is assigned to the latent factor it attends to most;
    that assignment is scored against the notebook-level `true_labels`.
    """
    attn_np = attn_weights.detach().cpu().numpy()
    return np.mean([
        compute_ari_per_sample(true_labels, np.argmax(sample_attn, axis=1))
        for sample_attn in attn_np
    ])


for epoch in range(num_epochs):
    # Compute current lambda for entropy and KL
    lambda_entropy = get_lambda_entropy(
        epoch, num_epochs, max_lambda_entropy, schedule_type='exponential',
        use_entropy_regularizer=use_entropy_regularizer)
    lambda_kl = get_lambda_kl(
        epoch, num_epochs, max_lambda_kl, schedule_type='constant',
        use_kl_regularizer=use_kl_regularizer)

    # ------ Training phase ------
    model.train()
    running_loss = 0.0
    running_recon_loss = 0.0
    running_kl_loss = 0.0
    running_entropy_loss = 0.0

    epoch_attn_weights = []
    epoch_ari = []

    for batch_idx, (batch,) in enumerate(train_loader):
        batch = batch.to(device)
        # NOTE(review): this rebinds the notebook-level `batch_size` that
        # configured the DataLoaders; kept so the state left behind by this
        # cell is unchanged.
        batch_size = batch.size(0)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass => (x_hat, attn_weights, mu, log_var)
        x_hat, attn_weights, mu, log_var = model(batch)

        recon_loss, kl_loss, entropy_reg = _batch_losses(x_hat, batch, attn_weights, mu, log_var)

        # Total loss
        loss = recon_loss + lambda_kl * kl_loss + lambda_entropy * entropy_reg

        # Backward + optimize
        loss.backward()
        optimizer.step()

        # Accumulate stats
        running_loss += loss.item()
        running_recon_loss += recon_loss.item()
        running_kl_loss += kl_loss.item()
        running_entropy_loss += entropy_reg.item()

        # Mean ARI of the batch
        epoch_ari.append(_batch_mean_ari(attn_weights))

        # Save attention for average attention matrix
        epoch_attn_weights.append(attn_weights.detach().cpu())

    # Compute epoch-level stats (per-batch values averaged over the batches)
    avg_train_loss = running_loss / len(train_loader)
    avg_train_recon_loss = running_recon_loss / len(train_loader)
    avg_train_kl_loss = running_kl_loss / len(train_loader)
    avg_train_entropy_loss = running_entropy_loss / len(train_loader)
    avg_train_ari = np.mean(epoch_ari)
    train_ari_list.append(avg_train_ari)

    # Compute average attention matrix for the epoch
    epoch_attn_weights_tensor = torch.cat(epoch_attn_weights, dim=0)
    avg_attn_weights_epoch = epoch_attn_weights_tensor.mean(dim=0)  # shape (input_dim, output_dim)
    attention_matrices.append(avg_attn_weights_epoch.cpu().numpy().T)

    # ------ Testing phase ------
    model.eval()
    test_loss = 0.0
    test_recon_loss = 0.0
    test_kl_loss = 0.0
    test_entropy_loss = 0.0
    epoch_ari_test = []

    with torch.no_grad():
        for batch_idx, (batch,) in enumerate(test_loader):
            batch = batch.to(device)
            batch_size = batch.size(0)

            x_hat, attn_weights, mu, log_var = model(batch)

            recon_loss, kl_loss, entropy_reg = _batch_losses(x_hat, batch, attn_weights, mu, log_var)

            # total
            loss = recon_loss + lambda_kl * kl_loss + lambda_entropy * entropy_reg

            test_loss += loss.item()
            test_recon_loss += recon_loss.item()
            test_kl_loss += kl_loss.item()
            test_entropy_loss += entropy_reg.item()

            # ARI
            epoch_ari_test.append(_batch_mean_ari(attn_weights))

    avg_test_loss = test_loss / len(test_loader)
    avg_test_recon_loss = test_recon_loss / len(test_loader)
    avg_test_kl_loss = test_kl_loss / len(test_loader)
    avg_test_entropy_loss = test_entropy_loss / len(test_loader)
    avg_test_ari = np.mean(epoch_ari_test)
    test_ari_list.append(avg_test_ari)

    # Print results
    if (epoch + 1) % print_every == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], "
              f"lambda_entropy={lambda_entropy:.6f}, lambda_kl={lambda_kl:.6f}, "
              f"Train: Loss={avg_train_loss:.4f}, Recon={avg_train_recon_loss:.4f}, KL={avg_train_kl_loss:.4f}, Entropy={avg_train_entropy_loss:.4f}, ARI={avg_train_ari:.4f} | "
              f"Test: Loss={avg_test_loss:.4f}, Recon={avg_test_recon_loss:.4f}, KL={avg_test_kl_loss:.4f}, Entropy={avg_test_entropy_loss:.4f}, ARI={avg_test_ari:.4f}")

# Optionally save model
# torch.save(model.state_dict(), "trained_model.pth")
# print("Training complete and model saved.")


Using device: mps
No trained model found. Starting from scratch.
Epoch [1/20], lambda_entropy=0.000000, lambda_kl=0.050000, Train: Loss=0.0152, Recon=0.0150, KL=0.0040, Entropy=0.9692, ARI=0.0264 | Test: Loss=0.0144, Recon=0.0144, KL=0.0000, Entropy=0.9888, ARI=0.0426
Epoch [2/20], lambda_entropy=0.000000, lambda_kl=0.050000, Train: Loss=0.0142, Recon=0.0142, KL=0.0000, Entropy=0.9901, ARI=0.0478 | Test: Loss=0.0142, Recon=0.0142, KL=0.0000, Entropy=0.9915, ARI=0.0337
Epoch [3/20], lambda_entropy=0.000000, lambda_kl=0.050000, Train: Loss=0.0142, Recon=0.0142, KL=0.0000, Entropy=0.9922, ARI=0.0371 | Test: Loss=0.0142, Recon=0.0142, KL=0.0000, Entropy=0.9884, ARI=0.0526
Epoch [4/20], lambda_entropy=0.000000, lambda_kl=0.050000, Train: Loss=0.0142, Recon=0.0142, KL=0.0000, Entropy=0.9919, ARI=0.0397 | Test: Loss=0.0142, Recon=0.0142, KL=0.0000, Entropy=0.9933, ARI=0.0341
Epoch [5/20], lambda_entropy=0.000000, lambda_kl=0.050000, Train: Loss=0.0142, Recon=0.0142, KL=0.0000, Entropy=0.9945,