In [25]:
import csv
import torch
import random
import os
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
import numpy as np
from tqdm import tqdm
from rich.progress import Progress
import torch.nn.functional as F

In [26]:
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a learnable codebook.

    The input is cut into vectors of size ``embedding_dim`` and each vector is
    replaced by the closest codebook entry (Euclidean distance). Gradients
    reach the input through a straight-through estimator.

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Size of each codebook vector.
        commitment_cost: Weight of the commitment term in the returned loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super(VectorQuantizer, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost

        # Codebook, initialised uniformly in [-1/num_embeddings, 1/num_embeddings]
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, x):
        """Quantize ``x`` against the codebook.

        Args:
            x: Tensor whose element count is a multiple of ``embedding_dim``.

        Returns:
            A tuple ``(quantized, indices, loss)``:
            quantized has the shape of ``x`` and holds the selected codebook
            vectors (with straight-through gradients to ``x``); indices is a
            1-D tensor with one codebook index per flattened vector; loss is
            the codebook loss plus ``commitment_cost`` times the commitment
            loss.
        """
        # reshape (unlike view) also accepts non-contiguous inputs
        flat_x = x.reshape(-1, self.embedding_dim)

        # Index of the nearest codebook vector (L2 distance) for every row
        distances = torch.cdist(flat_x, self.embedding.weight, p=2)
        indices = torch.argmin(distances, dim=1)

        # Look the selected codebook vectors up and restore the input shape
        quantized = self.embedding(indices).view(x.shape)

        # Commitment term moves x toward the (frozen) codebook vectors;
        # codebook term moves the codebook vectors toward the (frozen) x.
        e_latent_loss = F.mse_loss(quantized.detach(), x)
        q_latent_loss = F.mse_loss(quantized, x.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: the forward value is the quantized
        # vector, while the backward pass treats quantization as identity.
        quantized = x + (quantized - x).detach()
        return quantized, indices, loss


class VQVAE(nn.Module):
    """Fully connected VQ-VAE: MLP encoder, vector quantizer, MLP decoder.

    Args:
        input_dim: Number of features in each input row.
        latent_dim: Encoder output size; also the codebook vector size.
        num_embeddings: Number of codebook entries in the quantizer.
        dropout: Dropout probability used in both MLPs.
        use_batch_norm: Insert ``BatchNorm1d`` after the first activation of
            each MLP when true, otherwise an ``Identity`` placeholder.
        commitment_cost: Forwarded to ``VectorQuantizer``.
    """

    def __init__(
        self,
        input_dim,
        latent_dim,
        num_embeddings,
        dropout=0.0,
        use_batch_norm=False,
        commitment_cost=0.25,
    ):
        super().__init__()
        hidden_dim = 128

        def make_norm():
            # Identity keeps the Sequential indices stable when BN is off
            if use_batch_norm:
                return nn.BatchNorm1d(hidden_dim)
            return nn.Identity()

        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            make_norm(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        self.quantizer = VectorQuantizer(num_embeddings, latent_dim, commitment_cost)

        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            make_norm(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, input_dim),
            # Sigmoid output assumes the inputs are normalized to [0, 1]
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Returns:
            A tuple ``(latent, quantized, reconstructed, quantization_loss)``;
            the codebook indices produced by the quantizer are discarded.
        """
        encoded = self.encoder(x)
        codes, _indices, vq_loss = self.quantizer(encoded)
        decoded = self.decoder(codes)
        return encoded, codes, decoded, vq_loss


In [27]:
# Custom dataset class to handle CSV input and padding
class TreeDataset(Dataset):
    """In-memory dataset of integer CSV rows scaled to floats.

    Every non-blank CSV record is parsed as integers and min-max normalized
    (using the defaults of ``_min_max_normalize``) into a ``float32`` NumPy
    vector. The whole file is loaded eagerly in ``__init__``.

    Args:
        csv_file: Path of the CSV file to read; every field must parse as int.
        padding_length: Stored on the instance as ``padding_length``.
            NOTE(review): no padding or truncation is applied anywhere in
            this class, so rows keep whatever length they have in the file -
            confirm whether padding to this length was intended.
    """

    def __init__(self, csv_file, padding_length=161):
        self.padding_length = padding_length
        self.data = self._load_csv(csv_file)

    def _min_max_normalize(self, array, min_val=0, max_val=100):
        """Scale ``array`` linearly so ``min_val`` maps to 0 and ``max_val`` to 1.

        Values outside ``[min_val, max_val]`` are not clipped.

        Returns:
            A ``float32`` NumPy array with the same length as ``array``.
        """
        array = np.array(array, dtype=np.float32)
        return (array - min_val) / (max_val - min_val)

    def _parse_row(self, row):
        """Convert one CSV record to a normalized vector, or None if it is blank."""
        # csv.reader yields [] for blank lines; turning those into zero-length
        # samples would break batching, so they are reported as None instead.
        if not row:
            return None
        return self._min_max_normalize(list(map(int, row)))

    def _load_csv(self, csv_file):
        """Read ``csv_file`` and return a list with one normalized array per row.

        Blank lines are skipped. A ``rich`` progress bar is advanced once per
        CSV record.

        Raises:
            ValueError: If a field cannot be parsed as an integer.
        """
        data = []
        # newline="" lets the csv module do its own newline handling
        with open(csv_file, "r", newline="") as csvfile:
            # First pass only counts lines so the progress bar has a total
            total_lines = sum(1 for _ in csvfile)
            csvfile.seek(0)  # Rewind for the parsing pass

            reader = csv.reader(csvfile)
            with Progress() as progress:
                task = progress.add_task("[cyan]Processing CSV...", total=total_lines)
                for row in reader:
                    parsed = self._parse_row(row)
                    if parsed is not None:
                        data.append(parsed)
                    progress.update(task, advance=1)
        return data

    def __len__(self):
        """Return the number of loaded rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the normalized row stored at position ``idx``."""
        return self.data[idx]

In [28]:
def set_seed_for_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also writes ``OMP_NUM_THREADS``, ``MKL_NUM_THREADS`` and
    ``PYTHONHASHSEED`` into ``os.environ``.

    NOTE(review): these environment variables are usually read when the
    interpreter / numeric libraries start, so setting them here may not
    affect the already-running process - confirm.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic cuDNN kernels, no auto-tuner
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Thread limits and hash seed, exported through the environment
    env_settings = {
        "OMP_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "PYTHONHASHSEED": str(seed),
    }
    for key, value in env_settings.items():
        os.environ[key] = value

In [29]:
def train_model(model, optimizer, epochs, train_loader, device):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    Each step minimises the reconstruction MSE plus the quantization loss
    returned by the model. After every epoch the sample-weighted averages of
    the per-batch reconstruction loss, quantization loss, total loss and RMSE
    are printed. The reported RMSE is the average of the per-batch
    ``sqrt(MSE)`` values, not the square root of the epoch MSE.

    Args:
        model: Module whose forward pass returns
            ``(latent, quantized, reconstructed, quantization_loss)``.
        optimizer: Optimizer that updates the model parameters.
        epochs: Number of passes over ``train_loader``.
        train_loader: Iterable of input batches (tensors).
        device: Device each batch is moved to before the forward pass.

    Returns:
        None. Metrics are only printed.
    """
    print(f"Training the Autoencoder, Total epochs: {epochs}")
    for epoch in range(epochs):
        model.train()
        epoch_reconstruction_loss = 0
        epoch_quantization_loss = 0
        epoch_total_loss = 0
        epoch_rmse = 0
        total_samples = 0
        for batch in tqdm(
            train_loader, desc=f"Epoch [{epoch+1}/{epochs}]", unit="batch"
        ):
            batch = batch.to(device)
            optimizer.zero_grad()
            # Only the reconstruction and the quantization loss are needed here
            _, _, reconstructed, quantization_loss = model(batch)
            reconstruction_loss = F.mse_loss(reconstructed, batch)
            rmse = torch.sqrt(reconstruction_loss)
            total_loss = reconstruction_loss + quantization_loss
            total_loss.backward()
            optimizer.step()

            # Weight every batch by its size so a smaller final batch does
            # not skew the epoch averages
            batch_size = batch.size(0)
            total_samples += batch_size

            epoch_reconstruction_loss += reconstruction_loss.item() * batch_size
            epoch_quantization_loss += quantization_loss.item() * batch_size
            epoch_total_loss += total_loss.item() * batch_size
            epoch_rmse += rmse.item() * batch_size

        # Compute average loss and RMSE over all samples
        epoch_total_loss /= total_samples
        epoch_reconstruction_loss /= total_samples
        epoch_quantization_loss /= total_samples
        epoch_rmse /= total_samples

        print(
            f"Epoch [{epoch+1}/{epochs}], Reconstruction Loss: {epoch_reconstruction_loss:.6f}, "
            f"Quantization Loss: {epoch_quantization_loss:.6f}, Total Loss: {epoch_total_loss:.6f}, RMSE: {epoch_rmse:.6f}"
        )

In [30]:
def validate_model(model, val_loader, device):
    """
    Evaluate the VQ-VAE model on the validation set.

    The model is switched to evaluation mode and gradients are disabled.
    Sample-weighted averages of the per-batch reconstruction loss,
    quantization loss, total loss and RMSE are printed, together with the
    shape of the concatenated encoder outputs.

    Args:
        model: The VQ-VAE model; its forward pass returns
            ``(latent, quantized, reconstructed, quantization_loss)``.
        val_loader: DataLoader for validation data.
        device: Device each batch is moved to; must be the device the model
            lives on.

    Returns:
        val_rmse: Sample-weighted average of the per-batch RMSE values
            (square root of each batch's reconstruction MSE).
    """
    model.eval()  # Set model to evaluation mode
    val_reconstruction_loss = 0.0
    val_quantization_loss = 0.0
    val_total_loss = 0.0
    val_rmse = 0.0
    total_samples = 0
    latent_representations = []

    with torch.no_grad():  # Disable gradient computation
        for batch in tqdm(val_loader, desc="Validating...", unit="batch"):
            # Honour the caller's device instead of re-detecting CUDA, so the
            # batch always ends up where the model is
            x = batch.to(device)

            # Forward pass
            latent, _, reconstructed, quantization_loss = model(x)

            # Compute reconstruction loss and RMSE
            reconstruction_loss = F.mse_loss(reconstructed, x)
            rmse = torch.sqrt(reconstruction_loss)

            # Compute total loss
            total_loss = reconstruction_loss + quantization_loss

            # Accumulate losses weighted by batch size
            batch_size = x.size(0)
            total_samples += batch_size
            val_reconstruction_loss += reconstruction_loss.item() * batch_size
            val_quantization_loss += quantization_loss.item() * batch_size
            val_total_loss += total_loss.item() * batch_size
            val_rmse += rmse.item() * batch_size

            # Keep latents on the CPU so they do not pile up in device memory
            latent_representations.append(latent.cpu())

    # Compute average losses and RMSE over all samples
    val_reconstruction_loss /= total_samples
    val_quantization_loss /= total_samples
    val_total_loss /= total_samples
    val_rmse /= total_samples

    # Concatenate latent representations (already on the CPU)
    latent_representations = torch.cat(latent_representations).numpy()

    # Print validation metrics
    print(f"Latent representations shape: {latent_representations.shape}")
    print(
        f"Validation Reconstruction Loss: {val_reconstruction_loss:.6f}, "
        f"Validation Quantization Loss: {val_quantization_loss:.6f}, "
        f"Validation Total Loss: {val_total_loss:.6f}, "
        f"Validation RMSE: {val_rmse:.6f}"
    )

    return val_rmse


In [None]:
# Pick the accelerator when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

seed = 1234
set_seed_for_everything(seed=seed)
batch_size = 32


def seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    PyTorch gives every worker its own torch seed, so deriving the other
    seeds from ``torch.initial_seed()`` keeps workers reproducible without
    making them all share one identical random stream. A module-level
    function (unlike a lambda) can also be pickled for spawn-based workers.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Load dataset
csv_file = "data/train_data.csv"  # Replace with your CSV file path
dataset = TreeDataset(csv_file)
print(f"Dataset size: {len(dataset)}")

# 80/20 train/validation split; the split relies on the global torch seed
# set by set_seed_for_everything above
train_size = int(0.8 * len(dataset))
print(f"Train size: {train_size}")
val_size = len(dataset) - train_size
print(f"Validation size: {val_size}")
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    worker_init_fn=seed_worker,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    worker_init_fn=seed_worker,
)

Output()

Using device: cuda


Dataset size: 851229
Train size: 680983
Validation size: 170246


In [33]:
# Hyperparameters for the model and the optimizer
input_dim = 161  # Number of features in each input row
latent_dim = 64  # Encoder output size (also the codebook vector size)
num_embeddings = 64  # Number of codebook entries
epochs = 10  # Number of training epochs
lr = 1e-4  # Adam learning rate
dropout = 0  # Dropout probability
weight_decay = 0  # Adam weight decay

# Build the VQ-VAE with batch normalization and move it to the chosen device
model = VQVAE(
    input_dim=input_dim,
    latent_dim=latent_dim,
    num_embeddings=num_embeddings,
    dropout=dropout,
    use_batch_norm=True,
).to(device)

optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

In [34]:
# Fit on the training split, then report metrics on the held-out split
train_model(
    model=model,
    optimizer=optimizer,
    epochs=epochs,
    train_loader=train_loader,
    device=device,
)
val_rmse = validate_model(model=model, val_loader=val_loader, device=device)

Training the Autoencoder, Total epochs: 10


Epoch [1/10]: 100%|██████████| 21281/21281 [00:26<00:00, 814.94batch/s]


Epoch [1/10], Reconstruction Loss: 0.006435, Quantization Loss: 0.015624, Total Loss: 0.022059, RMSE: 0.031765


Epoch [2/10]: 100%|██████████| 21281/21281 [00:26<00:00, 808.48batch/s]


Epoch [2/10], Reconstruction Loss: 0.000075, Quantization Loss: 0.000207, Total Loss: 0.000282, RMSE: 0.008575


Epoch [3/10]: 100%|██████████| 21281/21281 [00:26<00:00, 808.06batch/s]


Epoch [3/10], Reconstruction Loss: 0.000072, Quantization Loss: 0.000219, Total Loss: 0.000291, RMSE: 0.008431


Epoch [4/10]: 100%|██████████| 21281/21281 [00:25<00:00, 821.94batch/s]


Epoch [4/10], Reconstruction Loss: 0.000071, Quantization Loss: 0.000224, Total Loss: 0.000295, RMSE: 0.008371


Epoch [5/10]: 100%|██████████| 21281/21281 [00:26<00:00, 808.65batch/s]


Epoch [5/10], Reconstruction Loss: 0.000071, Quantization Loss: 0.000225, Total Loss: 0.000296, RMSE: 0.008332


Epoch [6/10]: 100%|██████████| 21281/21281 [00:26<00:00, 816.05batch/s]


Epoch [6/10], Reconstruction Loss: 0.000070, Quantization Loss: 0.000223, Total Loss: 0.000294, RMSE: 0.008305


Epoch [7/10]: 100%|██████████| 21281/21281 [00:26<00:00, 807.91batch/s]


Epoch [7/10], Reconstruction Loss: 0.000070, Quantization Loss: 0.000220, Total Loss: 0.000290, RMSE: 0.008288


Epoch [8/10]: 100%|██████████| 21281/21281 [00:26<00:00, 815.15batch/s]


Epoch [8/10], Reconstruction Loss: 0.000070, Quantization Loss: 0.000218, Total Loss: 0.000288, RMSE: 0.008276


Epoch [9/10]: 100%|██████████| 21281/21281 [00:26<00:00, 812.06batch/s]


Epoch [9/10], Reconstruction Loss: 0.000070, Quantization Loss: 0.000221, Total Loss: 0.000290, RMSE: 0.008270


Epoch [10/10]: 100%|██████████| 21281/21281 [00:26<00:00, 812.46batch/s]


Epoch [10/10], Reconstruction Loss: 0.000069, Quantization Loss: 0.000221, Total Loss: 0.000290, RMSE: 0.008256


Validating...: 100%|██████████| 5321/5321 [00:02<00:00, 1775.18batch/s]

Latent representations shape: (170246, 64)
Validation Reconstruction Loss: 0.000065, Validation Quantization Loss: 0.000201, Validation Total Loss: 0.000265, Validation RMSE: 0.007978





In [57]:
# Persist only the learned parameters (state_dict), not the module object
checkpoint_path = "VQVAE_64dim_BN_64n.pth"
torch.save(model.state_dict(), checkpoint_path)