In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import scipy.sparse as sp
import numpy as np

In [7]:
# Custom Dataset for 4D Noisy Matrix
class NoisyMatrixDataset(Dataset):
    def __init__(self, original_matrix, noise_level=0.1, num_samples=1000):
        self.original_matrix = original_matrix
        self.noise_level = noise_level
        self.num_samples = num_samples
        self.data = self.generate_noisy_samples()

    def generate_noisy_samples(self):
        noisy_samples = []
        for _ in range(self.num_samples):
            noisy_sample = self.original_matrix.clone()
            noise = torch.rand_like(noisy_sample)
            mask = noise < self.noise_level
            noisy_sample[mask] = 1 - noisy_sample[mask]  # Flip bits where mask is True (0 to 1 or 1 to 0)
            noisy_samples.append(noisy_sample)
        return torch.stack(noisy_samples)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx].view(-1), self.original_matrix

In [8]:
# Define the Encoder
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, z_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, z_dim)       # Mean of latent variable
        self.fc_logvar = nn.Linear(hidden_dim, z_dim)   # Log variance of latent variable

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

# Define the Decoder
class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = torch.relu(self.fc1(z))
        x_reconstructed = torch.sigmoid(self.fc2(h))
        return x_reconstructed

# Define the VAE
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, z_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, z_dim)
        self.decoder = Decoder(z_dim, hidden_dim, input_dim)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_reconstructed = self.decoder(z)
        return x_reconstructed, mu, logvar

In [9]:
# Loss function for VAE
def vae_loss(x, x_reconstructed, mu, logvar):
    reconstruction_loss = nn.functional.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_loss + kl_divergence

In [15]:
# Generate initial 4D matrix and create dataset
initial_matrix = torch.randint(0, 2, (10, 4, 5, 6)).float()  # Example initial 4D binary matrix
noise_level = 0.1
num_samples = 10000
dataset = NoisyMatrixDataset(initial_matrix, noise_level=noise_level, num_samples=num_samples)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [16]:
# Model parameters
input_dim = initial_matrix.numel()
hidden_dim = 400
z_dim = 20
input_dim

1200

In [17]:
# Set specific GPU device if available
gpu_device = 0  # Change this value to set a specific GPU device
device = torch.device(f'cuda:{gpu_device}' if torch.cuda.is_available() else 'cpu')
device


device(type='cpu')

In [20]:
# Instantiate model, optimizer, and training loop
vae = VAE(input_dim, hidden_dim, z_dim).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3, weight_decay=0.1)

In [21]:
# Training loop
epochs = 10
for epoch in range(epochs):
    for batch_idx, (data, _) in enumerate(dataloader):
        data = data.to(device).float()
        optimizer.zero_grad()
        x_reconstructed, mu, logvar = vae(data)
        loss = vae_loss(data, x_reconstructed, mu, logvar)
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

print("Training complete.")

Epoch [1/10], Loss: 6315.6152
Epoch [2/10], Loss: 6500.3955
Epoch [3/10], Loss: 6426.1162
Epoch [4/10], Loss: 6233.6528
Epoch [5/10], Loss: 6208.0840
Epoch [6/10], Loss: 6133.7461
Epoch [7/10], Loss: 6305.9321
Epoch [8/10], Loss: 6136.2178
Epoch [9/10], Loss: 6167.4648
Epoch [10/10], Loss: 6279.9175
Training complete.
