In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import os

# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Define the Generator and Discriminator architectures

class Generator(nn.Module):
    """Fully connected generator mapping input vectors to synthetic samples.

    The network widens through hidden layers of 256, 512 and 1024 units,
    each followed by ReLU, and ends in a Tanh, so every output value lies
    in [-1, 1].

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Number of values produced per sample.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        widths = (input_dim, 256, 512, 1024)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], output_dim))
        layers.append(nn.Tanh())
        # The attribute name and layer order determine the state_dict keys
        # ("model.0.weight", ...), so both are kept stable.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the result."""
        return self.model(x)

class Discriminator(nn.Module):
    """Fully connected discriminator scoring each sample with one value.

    The network narrows through hidden layers of 1024, 512 and 256 units,
    each followed by LeakyReLU with slope 0.2, and ends in a single unit
    with a Sigmoid, so every score lies in [0, 1].

    Args:
        input_dim: Size of the last dimension of the input tensor.
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = (input_dim, 1024, 512, 256)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.2))
        layers.append(nn.Linear(widths[-1], 1))
        layers.append(nn.Sigmoid())
        # The attribute name and layer order determine the state_dict keys
        # ("model.0.weight", ...), so both are kept stable.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the scores."""
        return self.model(x)

# Parameters
input_dim = 100  # length of each noise vector fed to the generator
output_dim = 10  # Assuming user profiles have 10 features; generator output width and discriminator input width
num_epochs = 5000  # number of passes over the dataloader
batch_size = 64  # batch size handed to the DataLoader
learning_rate = 0.0002  # Adam learning rate used by both optimizers
beta1 = 0.5  # first Adam beta coefficient (the second is fixed at 0.999)

# Load the dataset
data_path = '../data/processed/synthetic_user_profiles.csv'
# Every column must be numeric; to_numpy raises if one cannot be converted.
data = pd.read_csv(data_path).to_numpy(dtype=np.float64)

# Fail early with a clear message instead of a shape error or NaN losses
# deep inside the training loop.
if data.shape[0] == 0:
    raise ValueError(f"No rows found in {data_path}")
if data.shape[1] != output_dim:
    raise ValueError(
        f"{data_path} has {data.shape[1]} columns but output_dim is {output_dim}"
    )
if not np.isfinite(data).all():
    raise ValueError(f"{data_path} contains NaN or infinite values")

# Normalize data
# Each feature is scaled independently to [-1, 1], the range of the
# generator's final Tanh, so real and generated samples share one range.
# data_min / data_range are kept so samples can be mapped back with
#   original = (scaled + 1) / 2 * data_range + data_min
data_min = data.min(axis=0)
data_range = data.max(axis=0) - data_min
# A constant column has zero range; dividing by 1 instead avoids 0/0 and
# maps the whole column to -1.
data_range[data_range == 0] = 1.0
data = 2.0 * (data - data_min) / data_range - 1.0

# Create DataLoader
tensor_data = torch.tensor(data, dtype=torch.float32)
dataset = TensorDataset(tensor_data)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Build both networks and place them on the selected device
generator = Generator(input_dim=input_dim, output_dim=output_dim)
generator.to(device)
discriminator = Discriminator(input_dim=output_dim)
discriminator.to(device)

# Binary cross-entropy between discriminator outputs and 0/1 labels
criterion = nn.BCELoss()

# One Adam optimizer per network, sharing learning rate and betas
adam_settings = {"lr": learning_rate, "betas": (beta1, 0.999)}
optimizer_g = optim.Adam(generator.parameters(), **adam_settings)
optimizer_d = optim.Adam(discriminator.parameters(), **adam_settings)

# Training loop
# losses_g / losses_d receive one value per epoch: the mean loss over all
# batches of that epoch.
losses_g = []
losses_d = []
for epoch in range(num_epochs):
    # Running sums stay on the device as tensors so the host only
    # synchronises once per epoch (at .item() below), not on every batch.
    g_loss_total = torch.zeros((), device=device)
    d_loss_total = torch.zeros((), device=device)
    num_batches = 0

    for batch in dataloader:
        real_data = batch[0].to(device)
        # The batch actually received may be smaller than the configured
        # size; a separate name keeps the batch_size hyperparameter intact.
        current_batch_size = real_data.size(0)

        # Create labels directly on the target device
        real_labels = torch.ones(current_batch_size, 1, device=device)
        fake_labels = torch.zeros(current_batch_size, 1, device=device)

        # Train Discriminator: real samples should score 1, fakes 0
        optimizer_d.zero_grad()
        outputs = discriminator(real_data)
        d_loss_real = criterion(outputs, real_labels)

        noise = torch.randn(current_batch_size, input_dim, device=device)
        fake_data = generator(noise)
        # detach() stops this backward pass from reaching the generator
        outputs = discriminator(fake_data.detach())
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # Train Generator: it is rewarded when its fakes are scored as real
        optimizer_g.zero_grad()
        noise = torch.randn(current_batch_size, input_dim, device=device)
        fake_data = generator(noise)
        outputs = discriminator(fake_data)
        g_loss = criterion(outputs, real_labels)

        g_loss.backward()
        optimizer_g.step()

        d_loss_total += d_loss.detach()
        g_loss_total += g_loss.detach()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("dataloader produced no batches; nothing to train on")

    epoch_g_loss = (g_loss_total / num_batches).item()
    epoch_d_loss = (d_loss_total / num_batches).item()
    losses_g.append(epoch_g_loss)
    losses_d.append(epoch_d_loss)

    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}]  Loss D: {epoch_d_loss}, Loss G: {epoch_g_loss}")

# Write both networks' weights (state_dict only) under ../models,
# creating the directory first when it does not exist yet
os.makedirs('../models', exist_ok=True)
for net, checkpoint_path in (
    (generator, '../models/generator.pth'),
    (discriminator, '../models/discriminator.pth'),
):
    torch.save(net.state_dict(), checkpoint_path)

# Draw the recorded generator and discriminator loss curves on one axes
fig, ax = plt.subplots(figsize=(10, 5))
for curve, curve_label in (
    (losses_g, 'Generator Loss'),
    (losses_d, 'Discriminator Loss'),
):
    ax.plot(curve, label=curve_label)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()
