# In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f'Using device: {DEVICE}')

# Training hyper-parameters.
LEARNING_RATE = 1e-5   # kept low for stability
BATCH_SIZE = 256       # samples per optimisation step
EPOCHS = 5             # full passes over the dataset
MAX_GRAD_NORM = 5.0    # threshold used when clipping the gradient norm

# -----------------------
# Load and Preprocess Data
# (Replace 'ARP_MitM_dataset.csv' below with the path to your own dataset)
# -----------------------
# header=None makes pandas treat the first CSV line as data, not column names.
data = pd.read_csv('ARP_MitM_dataset.csv', header=None).values  # Example dataset
# Example: If you have labels: labels = pd.read_csv('labels.csv', header=None).values.flatten()

# Report non-finite values in the raw data before any transformation.
print("Checking dataset for NaNs and Infs...")
print("WARNING: NaNs found in the dataset!" if np.any(np.isnan(data))
      else "No NaNs in dataset.")
print("WARNING: Infs found in the dataset!" if np.any(np.isinf(data))
      else "No Infs in dataset.")

# Optional: standardise every column. The fitted scaler stays available as
# `scaler` so the transform can be reused later.
scaler = StandardScaler()
data = scaler.fit_transform(data)

# Scaling can itself introduce non-finite values, so check once more.
if np.any(np.isnan(data)):
    print("WARNING: NaNs found after scaling!")
if np.any(np.isinf(data)):
    print("WARNING: Infs found after scaling!")

# -----------------------
# Simple Dataset
# -----------------------
class SimpleDataset(Dataset):
    """Map-style dataset that serves the rows of a feature matrix one by one.

    Args:
        data: Array-like of numeric values whose first axis indexes the
            samples (NumPy array, nested list, pandas DataFrame, ...).
            The values are copied into a ``float32`` NumPy array, so later
            changes to ``data`` do not affect the dataset.
    """

    def __init__(self, data):
        # np.array(...) copies like ndarray.astype did, but also accepts
        # inputs that have no .astype method, such as plain Python lists.
        self.data = np.array(data, dtype=np.float32)

    def __len__(self):
        """Return the number of samples (length of the first axis)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx`` as a float32 array."""
        return self.data[idx]

# Wrap the feature matrix and serve it in shuffled mini-batches of BATCH_SIZE.
dataset = SimpleDataset(data)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

# -----------------------
# VAE Model
# -----------------------
class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    The encoder maps an input to the mean and log-variance of a latent
    Gaussian; the decoder maps a sampled latent code back to input space
    through a linear output layer.

    Args:
        input_dim: Number of features in each input sample.
        hidden_dim: Width of the hidden layer in encoder and decoder.
        latent_dim: Size of the latent code.
    """

    def __init__(self, input_dim=100, hidden_dim=64, latent_dim=16):
        super().__init__()
        # Encoder: input -> hidden -> (mu, logvar)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder: latent -> hidden -> reconstruction
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)
        self.relu = nn.ReLU()

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent distribution for ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc_mu(hidden), self.fc_logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + noise * sigma`` with standard-normal ``noise``.

        ``sigma`` is ``exp(0.5 * logvar)``; sampling this way keeps the
        result differentiable with respect to ``mu`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map a latent code ``z`` back to input space (no output activation)."""
        return self.fc3(self.relu(self.fc2(z)))

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            Tuple ``(recon_x, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        recon_x = self.decode(self.reparameterize(mu, logvar))
        return recon_x, mu, logvar

def vae_loss(recon_x, x, mu, logvar):
    """Return the VAE training loss, summed over the whole batch.

    The loss is the squared reconstruction error plus the closed-form KL
    divergence between N(mu, exp(logvar)) and the standard normal prior.

    Both terms use a *sum* reduction. Mixing a mean-reduced reconstruction
    term with a summed KL term would make the KL term larger by roughly a
    factor of batch_size * feature count and let it dominate training.
    Because the result is a per-batch sum, dividing the accumulated loss by
    the number of samples gives a per-sample average.

    Args:
        recon_x: Reconstruction produced by the decoder, same shape as ``x``.
        x: Original input batch.
        mu: Latent means.
        logvar: Latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor ``recon_loss + kld_loss``.
    """
    recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
    # No epsilon is added to exp(logvar): exp() is never negative and a small
    # additive constant cannot prevent overflow, so it had no stabilising
    # effect.
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld_loss

# -----------------------
# Initialize Model
# -----------------------
# One input unit per column of the feature matrix.
input_dim = data.shape[1]
vae = VAE(input_dim=input_dim, hidden_dim=64, latent_dim=16)
vae = vae.to(DEVICE)
optimizer = torch.optim.Adam(params=vae.parameters(), lr=LEARNING_RATE)

# -----------------------
# Debugging Before Training
# -----------------------
# Test a forward pass on a small batch
# test_batch = torch.tensor(data[:10], dtype=torch.float32).to(DEVICE)
# vae.eval()
# with torch.no_grad():
#     recon_x_test, mu_test, logvar_test = vae(test_batch)
#     print("Test forward pass outputs:")
#     print("recon_x_test:", recon_x_test)
#     print("mu_test:", mu_test)
#     print("logvar_test:", logvar_test)
#     if torch.isnan(recon_x_test).any():
#         print("NaN found in recon_x_test before training.")
#     if torch.isnan(mu_test).any():
#         print("NaN found in mu_test before training.")
#     if torch.isnan(logvar_test).any():
#         print("NaN found in logvar_test before training.")

# vae.train()

# -----------------------
# Training Loop with Debugging
# -----------------------
# Train the VAE for EPOCHS passes over `dataloader`. A diagnostic line is
# printed whenever a NaN appears in the inputs, the forward-pass outputs, the
# loss, the gradients, or (after each epoch) the parameters.
for epoch in range(EPOCHS):
    total_loss = 0.0
    samples_used = 0  # samples from batches that actually updated the model
    for i, batch in enumerate(dataloader):
        batch = batch.to(DEVICE)

        # Check input batch for NaNs
        if torch.isnan(batch).any():
            print(f"NaN in input batch at iteration {i}, epoch {epoch}")

        optimizer.zero_grad()
        recon_x, mu, logvar = vae(batch)

        # Check for NaNs in the forward pass
        if torch.isnan(recon_x).any():
            print(f"NaN in recon_x at iteration {i}, epoch {epoch}")
        if torch.isnan(mu).any():
            print(f"NaN in mu at iteration {i}, epoch {epoch}")
        if torch.isnan(logvar).any():
            print(f"NaN in logvar at iteration {i}, epoch {epoch}")

        loss = vae_loss(recon_x, batch, mu, logvar)

        # Check for NaNs in loss
        if torch.isnan(loss):
            print(f"NaN in loss at iteration {i}, epoch {epoch}")

        # backward() still runs on a bad loss so the per-parameter gradient
        # report below can show where the NaNs end up.
        loss.backward()

        # Check gradients for NaNs before optimizer step
        for name, param in vae.named_parameters():
            if param.grad is not None and torch.isnan(param.grad).any():
                print(f"NaN gradient in {name} at iteration {i}, epoch {epoch}")

        # Gradient clipping to prevent exploding gradients. The returned total
        # norm is NaN/Inf whenever any gradient is NaN/Inf.
        grad_norm = torch.nn.utils.clip_grad_norm_(vae.parameters(), MAX_GRAD_NORM)

        # Stepping with a non-finite loss or gradient would write NaN/Inf into
        # every parameter and ruin all later iterations, so skip this batch.
        step_is_safe = (bool(torch.isfinite(loss))
                        and bool(torch.isfinite(torch.as_tensor(grad_norm))))
        if not step_is_safe:
            print(f"Skipping optimizer step at iteration {i}, epoch {epoch}: "
                  f"non-finite loss or gradient norm")
            continue

        optimizer.step()
        total_loss += loss.item()
        samples_used += batch.size(0)

    # Average over the samples that contributed to total_loss; this equals
    # len(dataloader.dataset) when no batch was skipped. max(..., 1) avoids
    # dividing by zero if every batch was skipped.
    avg_loss = total_loss / max(samples_used, 1)
    print(f"Epoch [{epoch+1}/{EPOCHS}] - Avg Loss: {avg_loss:.4f}")

    # Check parameter stats after each epoch
    with torch.no_grad():
        for name, param in vae.named_parameters():
            if param.requires_grad:
                if torch.isnan(param).any():
                    print(f"NaN in parameters {name} at epoch {epoch}")
                # Print a summary of parameter stats
                print(f"{name} stats: mean={param.mean().item():.4f}, std={param.std().item():.4f}, "
                      f"max={param.max().item():.4f}, min={param.min().item():.4f}")

# If the code runs without printing NaNs, you have stable training.
# Otherwise, the print statements should guide you to where the NaNs occur.