In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the Big Five test responses published with LinearCorex.
# NOTE(review): pd.read_csv treats the first row as a header — confirm the CSV has one,
# otherwise the first sample is silently consumed as column names.
url = 'https://raw.githubusercontent.com/gregversteeg/LinearCorex/master/tests/data/test_big5.csv'
df = pd.read_csv(url)

# Dividing by 4 only maps scores to [0, 1] if the raw values lie in [0, 4]. The decoder ends
# in a sigmoid, so targets outside [0, 1] (or NaNs) would make the reconstruction loss
# unreachable / NaN — check the assumption instead of trusting it.
assert not df.isna().any().any(), "dataset contains missing values"
raw_min, raw_max = df.min().min(), df.max().max()
assert 0 <= raw_min and raw_max <= 4, f"expected raw scores in [0, 4], got [{raw_min}, {raw_max}]"
df = df / 4.0  # Normalize to [0, 1]
data = df.values.astype(np.float32)  # Shape: (num_samples, 50)

# Convert data to torch tensors
data_tensor = torch.tensor(data)
dataset = TensorDataset(data_tensor)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [2]:
# Define the Vector Quantizer class
class VectorQuantizer(nn.Module):
    """Vector-quantization bottleneck with a learnable codebook (VQ-VAE style).

    Every ``embedding_dim``-sized vector of the input is replaced by its nearest
    codebook entry under squared Euclidean distance; gradients reach the encoder
    through the straight-through estimator.

    Args:
        num_embeddings: Number of codebook vectors.
        embedding_dim: Size of each codebook vector. The input's element count
            must be divisible by it.
        commitment_cost: Weight applied to the commitment term (the loss that
            pulls encoder outputs toward their selected codes).
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.embedding.weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)

    def forward(self, inputs):
        """Quantize ``inputs`` against the codebook.

        Args:
            inputs: Tensor whose elements are grouped into vectors of length
                ``embedding_dim`` (in this notebook: the encoder output,
                shape ``(batch, embedding_dim)``).

        Returns:
            Tuple ``(loss, quantized, perplexity)``:
              * ``loss``: codebook loss + ``commitment_cost`` * commitment loss.
              * ``quantized``: straight-through quantized tensor, same shape as
                ``inputs``.
              * ``perplexity``: exp(entropy) of the batch's code-usage
                distribution — 1 when a single code is used, ``num_embeddings``
                when usage is uniform.
        """
        # Remember the caller's shape: flattening must not leak into the output.
        input_shape = inputs.shape
        # reshape (not view) also accepts non-contiguous inputs.
        flat_inputs = inputs.reshape(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # Squared distances ||x||^2 + ||e||^2 - 2 x.e, shape (N, num_embeddings).
        distances = (torch.sum(flat_inputs ** 2, dim=1, keepdim=True)
                     + torch.sum(codebook ** 2, dim=1)
                     - 2 * torch.matmul(flat_inputs, codebook.t()))

        # One-hot encodings created on the input's device and dtype (not a
        # notebook-global device), so the module works anywhere and in any precision.
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.size(0), self.num_embeddings,
                                device=flat_inputs.device, dtype=flat_inputs.dtype)
        encodings.scatter_(1, encoding_indices, 1)

        # Select codebook vectors; the matmul keeps a gradient path to the codebook.
        quantized = torch.matmul(encodings, codebook)

        # Codebook loss moves codes toward encoder outputs; the (weighted)
        # commitment loss moves encoder outputs toward their codes.
        e_latent_loss = torch.mean((quantized.detach() - flat_inputs) ** 2)
        q_latent_loss = torch.mean((quantized - flat_inputs.detach()) ** 2)
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: forward value = codes, backward = identity.
        quantized = flat_inputs + (quantized - flat_inputs).detach()
        quantized = quantized.view(input_shape)

        # Perplexity of code usage within this batch.
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return loss, quantized, perplexity


In [3]:
class Encoder(nn.Module):
    """MLP encoder mapping an input vector to a continuous latent vector.

    Layout: ``input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1] -> embedding_dim``,
    with a ReLU after each hidden Linear layer and no activation on the final
    (latent) projection.
    """

    def __init__(self, input_dim, hidden_dims, embedding_dim):
        super().__init__()
        self.fc_layers = nn.Sequential()
        widths = [input_dim] + list(hidden_dims)
        # Consecutive width pairs give each hidden layer's (in, out) sizes.
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            self.fc_layers.add_module(f"fc_{idx}", nn.Linear(fan_in, fan_out))
            self.fc_layers.add_module(f"relu_{idx}", nn.ReLU())
        # Linear projection into the latent space of size embedding_dim.
        self.fc_layers.add_module("fc_latent", nn.Linear(widths[-1], embedding_dim))

    def forward(self, x):
        """Return the pre-quantization latent ``z_e`` for ``x``."""
        return self.fc_layers(x)

In [4]:
class Decoder(nn.Module):
    """MLP decoder mapping a (quantized) latent vector back to feature space.

    Layout: ``embedding_dim -> hidden_dims[0] -> ... -> hidden_dims[-1] -> output_dim``,
    with a ReLU after each hidden Linear layer and a Sigmoid on the output so
    every reconstructed feature is squashed into the unit interval.
    """

    def __init__(self, output_dim, hidden_dims, embedding_dim):
        super().__init__()
        self.fc_layers = nn.Sequential()
        sizes = [embedding_dim, *hidden_dims]
        for layer_no, (n_in, n_out) in enumerate(zip(sizes, sizes[1:])):
            self.fc_layers.add_module(f"fc_{layer_no}", nn.Linear(n_in, n_out))
            self.fc_layers.add_module(f"relu_{layer_no}", nn.ReLU())
        # Final projection to the feature count, then squash to (0, 1).
        self.fc_layers.add_module("fc_output", nn.Linear(sizes[-1], output_dim))
        self.fc_layers.add_module("sigmoid", nn.Sigmoid())

    def forward(self, z_q):
        """Reconstruct features from the latent ``z_q``."""
        return self.fc_layers(z_q)

In [5]:
class VQVAE(nn.Module):
    """VQ-VAE: MLP encoder -> vector-quantization bottleneck -> MLP decoder.

    The decoder mirrors the encoder by using ``hidden_dims`` in reverse order.
    ``forward`` returns ``(vq_loss, reconstruction, perplexity)``; the
    reconstruction loss is not included in ``vq_loss`` and must be added by
    the caller.
    """

    def __init__(self, input_dim, hidden_dims, embedding_dim, num_embeddings, commitment_cost):
        super().__init__()
        # Sub-modules are built in this order (encoder, quantizer, decoder), which
        # fixes both the RNG draws for weight init and the state_dict ordering.
        self.encoder = Encoder(input_dim, hidden_dims, embedding_dim)
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        mirrored_dims = hidden_dims[::-1]
        self.decoder = Decoder(input_dim, mirrored_dims, embedding_dim)

    def forward(self, x):
        """Encode, quantize and decode ``x``."""
        latents = self.encoder(x)
        vq_loss, quantized_latents, perplexity = self.quantizer(latents)
        return vq_loss, self.decoder(quantized_latents), perplexity

In [6]:
# Model Hyperparameters
SEED = 0                   # fixes weight init and DataLoader shuffling for Restart & Run All
input_dim = data.shape[1]  # number of input features (50 for this dataset)
hidden_dims = [128, 64]
embedding_dim = 5
num_embeddings = 128
commitment_cost = 0.25

# Seed torch's global RNG: it drives both nn.Module weight init and the
# DataLoader's shuffle order, so re-running this cell reproduces the same run.
torch.manual_seed(SEED)

model = VQVAE(input_dim, hidden_dims, embedding_dim, num_embeddings, commitment_cost)
model = model.to(device)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=3e-4)

# Training parameters
num_epochs = 100
log_interval = 10
save_path = None  # set to e.g. 'vqvae_model.pth' to persist the trained weights

# Training Loop
# NOTE(review): in the logged output, perplexity falls to ~1.7 within the first
# epoch, i.e. only about two of the 128 codes are used — consider dead-code
# re-initialisation or EMA codebook updates if this persists.
for epoch in range(1, num_epochs + 1):
    model.train()
    train_loss = 0.0
    num_seen = 0
    for batch_idx, (data_batch,) in enumerate(data_loader):
        data_batch = data_batch.to(device)

        optimizer.zero_grad()
        vq_loss, recon_batch, perplexity = model(data_batch)

        recon_loss = ((data_batch - recon_batch) ** 2).mean()
        loss = recon_loss + vq_loss

        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its size: the last batch can be
        # smaller, and an unweighted mean of batch losses would over-weight it.
        n_batch = data_batch.size(0)
        train_loss += loss.item() * n_batch
        num_seen += n_batch

        if batch_idx % log_interval == 0:
            print(f'Epoch [{epoch}/{num_epochs}] Batch [{batch_idx}/{len(data_loader)}] '
                  f'Loss: {loss.item():.4f}, Recon Loss: {recon_loss.item():.4f}, '
                  f'VQ Loss: {vq_loss.item():.4f}, Perplexity: {perplexity.item():.4f}')

    avg_loss = train_loss / max(num_seen, 1)
    print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}')

# Save the trained model (opt-in via save_path)
if save_path:
    torch.save(model.state_dict(), save_path)
    print("Training complete and model saved.")

Epoch [1/100] Batch [0/63] Loss: 0.1278, Recon Loss: 0.1214, VQ Loss: 0.0064, Perplexity: 2.2360
Epoch [1/100] Batch [10/63] Loss: 0.1217, Recon Loss: 0.1211, VQ Loss: 0.0006, Perplexity: 13.0288
Epoch [1/100] Batch [20/63] Loss: 0.1143, Recon Loss: 0.1140, VQ Loss: 0.0003, Perplexity: 12.5305
Epoch [1/100] Batch [30/63] Loss: 0.1087, Recon Loss: 0.1083, VQ Loss: 0.0003, Perplexity: 5.4630
Epoch [1/100] Batch [40/63] Loss: 0.1110, Recon Loss: 0.1103, VQ Loss: 0.0007, Perplexity: 1.6679
Epoch [1/100] Batch [50/63] Loss: 0.1030, Recon Loss: 0.1026, VQ Loss: 0.0004, Perplexity: 1.6679
Epoch [1/100] Batch [60/63] Loss: 0.1050, Recon Loss: 0.1042, VQ Loss: 0.0008, Perplexity: 1.7548
====> Epoch: 1 Average loss: 0.1100
Epoch [2/100] Batch [0/63] Loss: 0.0908, Recon Loss: 0.0899, VQ Loss: 0.0009, Perplexity: 1.6202
Epoch [2/100] Batch [10/63] Loss: 0.1041, Recon Loss: 0.1028, VQ Loss: 0.0013, Perplexity: 1.6910
Epoch [2/100] Batch [20/63] Loss: 0.0937, Recon Loss: 0.0924, VQ Loss: 0.0013, Per