# Generating Synthetic Industrial Data with a Hybrid GAN

Welcome! In this notebook, we're going to tackle a common problem in data science: not having enough data. Our solution? We'll teach an AI to create realistic, synthetic data for us.

We'll be working with a simulated dataset for predictive maintenance in machines. Our approach is a bit of a hybrid: we'll use a ready-made model called **CTGAN** (which is great for tabular data) and also build our own **WGAN** (Wasserstein GAN) from scratch to see how they compare and combine.

### Step 1: Getting Our Tools Ready

First things first, let's import all the Python libraries we'll need for this experiment. If you haven't installed them yet, you can use the `requirements.txt` file in this repository.

In [ ]:
# For data handling
import pandas as pd
import numpy as np

# For building our own GAN with PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# For the pre-built CTGAN model
from ctgan import CTGAN

# For data scaling and visualization
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import seaborn as sns

print("All libraries are ready to go!")

### Step 2: Creating Our "Real" Dataset

Before we can generate fake data, we need some real data to learn from. Since we don't have a real industrial dataset on hand, we'll simulate one. This will be our "ground truth."

We'll also scale every column, including `machine_age` and `failure`, to a range between 0 and 1. Putting all features on a common, bounded scale is a common practice that helps GANs train more effectively.

In [ ]:
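# Let's create some predictable randomness for reproducibility.
# NumPy drives the simulated data; PyTorch drives the WGAN weights and noise.
np.random.seed(42)
torch.manual_seed(42)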

# Building a DataFrame with simulated machine sensor data
real_industrial_data = pd.DataFrame({
    'temperature': np.random.uniform(20, 100, 1000),
    'pressure': np.random.uniform(1, 10, 1000),
    'vibration': np.random.uniform(0.1, 5.0, 1000),
    'machine_age': np.random.randint(1, 10, 1000),
    'failure': np.random.choice([0, 1], size=1000) # 0 = No Failure, 1 = Failure
})

# GANs work best with normalized data, so let's scale our features to be between 0 and 1.
scaler = MinMaxScaler()
scaled_real_data = pd.DataFrame(scaler.fit_transform(real_industrial_data), columns=real_industrial_data.columns)

print("Here's a peek at our scaled 'real' data:")
scaled_real_data.head()
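
To make the scaling step concrete, here's a minimal pure-Python sketch of what min-max scaling does to a single column and how it gets undone later. The helper names `min_max_scale` and `min_max_unscale` are made up for this illustration; in the notebook itself, `MinMaxScaler` does this work for every column at once.

```python
def min_max_scale(values):
    """Map values linearly so the smallest becomes 0.0 and the largest 1.0."""
    lo, hi = min(values), max(values)
    scaled = [(v - lo) / (hi - lo) for v in values]
    return scaled, lo, hi


def min_max_unscale(scaled, lo, hi):
    """Undo min_max_scale using the stored minimum and maximum."""
    return [s * (hi - lo) + lo for s in scaled]


# Hypothetical temperature readings, chosen so the arithmetic is easy to follow
temperatures = [20.0, 40.0, 60.0, 100.0]

scaled, lo, hi = min_max_scale(temperatures)
restored = min_max_unscale(scaled, lo, hi)

print(scaled)    # [0.0, 0.25, 0.5, 1.0]
print(restored)  # [20.0, 40.0, 60.0, 100.0]
```

The important detail is that the minimum and maximum are kept around, which is why we hold on to the fitted `scaler` object: we'll need it to map synthetic rows back to real-world units.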

### Step 3: Training the CTGAN

Now for our first model. CTGAN is specialized for generating tabular data like ours. We just need to point it to our data, tell it which columns are discrete (like our 'failure' column), and let it train.

In [ ]:
# Initialize the CTGAN model. We'll train it for 50 epochs.
ctgan_model = CTGAN(epochs=50)

# Train the model on our scaled data. We need to tell it which columns are categorical/discrete.
ctgan_model.fit(scaled_real_data, discrete_columns=['failure'])

# Let's generate 500 new samples!
ctgan_synthetic_samples = ctgan_model.sample(500)

### Step 4: Building and Training Our Own WGAN

This is the fun part! We're building a WGAN from scratch. A GAN has two main parts:
- **The Generator**: The "artist" that tries to create realistic-looking data from random noise.
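- **The Discriminator**: The "critic" that tries to tell the difference between real data and the Generator's fake data. In a WGAN it outputs an unbounded "realness" score rather than a probability, which is why it's usually called a critic.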

They train in a cat-and-mouse game. The Generator gets better at making fakes, and the Discriminator gets better at spotting them. Over time, the Generator becomes a master forger, creating very realistic data.
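
In symbols, and assuming the standard WGAN formulation (the notation here is ours, not something defined elsewhere in this notebook), the game looks roughly like this, where $G$ is the generator, $D$ is the critic, $x$ is a real row, and $z$ is random noise:

$$
\min_G \; \max_{D \text{ Lipschitz}} \; \mathbb{E}_{x \sim p_{\text{data}}}\big[D(x)\big] - \mathbb{E}_{z \sim p_z}\big[D(G(z))\big]
$$

In practice each expectation is estimated with a batch mean, and because optimizers minimize, the two losses become:

$$
L_D = -\Big(\operatorname{mean}_x D(x) - \operatorname{mean}_z D(G(z))\Big), \qquad L_G = -\operatorname{mean}_z D(G(z))
$$

These correspond to `loss_D` and `loss_G` in the code cell that follows.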

In [ ]:
# A helper class to prepare our data for PyTorch
class IndustrialDataset(Dataset):
    def __init__(self, data):
        self.data = torch.tensor(data.values, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# The Generator network - our 'artist'
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
            nn.Sigmoid() # Sigmoid keeps the output between 0 and 1, matching our MinMax-scaled data
        )

    def forward(self, z):
        return self.model(z)

# The Discriminator network - our 'critic'
class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1) # Outputs a single score indicating how 'real' it thinks the data is
        )

    def forward(self, x):
        return self.model(x)

# --- Training Setup ---
# Create instances of our models
latent_space_dim = 10
data_dim = scaled_real_data.shape[1]
wgan_generator = Generator(input_dim=latent_space_dim, output_dim=data_dim)
wgan_discriminator = Discriminator(input_dim=data_dim)

# Set up the optimizers
optimizer_G = optim.RMSprop(wgan_generator.parameters(), lr=0.00005)
optimizer_D = optim.RMSprop(wgan_discriminator.parameters(), lr=0.00005)

# Prepare the data loader
real_data_torch = IndustrialDataset(scaled_real_data)
loader = DataLoader(real_data_torch, batch_size=64, shuffle=True)

# --- The Training Loop ---
print("Starting WGAN training...")
for epoch in range(50):
    for i, real_samples in enumerate(loader):
        # --- Train the Critic (Discriminator) ---
        # The critic needs to get better at telling real from fake.
        optimizer_D.zero_grad()
        
        # Generate some fake samples
        z = torch.randn(real_samples.size(0), latent_space_dim)
        fake_samples = wgan_generator(z).detach() # .detach() so we don't train the generator here
        
        # The WGAN loss aims to maximize the distance between the critic's scores for real vs. fake data
        loss_D = -(torch.mean(wgan_discriminator(real_samples)) - torch.mean(wgan_discriminator(fake_samples)))
        loss_D.backward()
        optimizer_D.step()
        
        # This is a key part of the original WGAN: clamp the critic's weights to a small range.
        # Clipping is a crude way to keep the critic roughly Lipschitz-bounded, which the Wasserstein loss relies on for stable training.
        for p in wgan_discriminator.parameters():
            p.data.clamp_(-0.01, 0.01)
        
        # --- Train the Artist (Generator) ---
        # The artist's goal is to fool the critic.
        optimizer_G.zero_grad()
        
        # Generate a new batch of fake samples
        z = torch.randn(real_samples.size(0), latent_space_dim)
        fake_samples = wgan_generator(z)
        
        # We want to maximize the critic's score for fake samples (i.e., make it think they're real).
        loss_G = -torch.mean(wgan_discriminator(fake_samples))
        loss_G.backward()
        optimizer_G.step()
    
    print(f"Epoch {epoch+1}, Discriminator Loss: {loss_D.item():.4f}, Generator Loss: {loss_G.item():.4f}")
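
The loop above gives the generator one update for every critic update. A common variation in the original WGAN recipe is to let the critic take several steps (often five) per generator step, so its scores are more reliable before the generator learns from them. Here's a minimal sketch of that schedule on tiny stand-in networks and random data. The names `n_critic` and `clip_value`, the layer sizes, and the fake "real" batches are all assumptions for illustration, not part of the notebook's model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes and settings, chosen only to keep the sketch small
latent_dim, data_dim = 4, 3
n_critic = 5        # critic updates per generator update
clip_value = 0.01   # weight-clipping range for the critic

generator = nn.Sequential(
    nn.Linear(latent_dim, 16), nn.ReLU(), nn.Linear(16, data_dim), nn.Sigmoid()
)
critic = nn.Sequential(nn.Linear(data_dim, 16), nn.ReLU(), nn.Linear(16, 1))
opt_g = torch.optim.RMSprop(generator.parameters(), lr=5e-5)
opt_c = torch.optim.RMSprop(critic.parameters(), lr=5e-5)

# Stand-in for a DataLoader: ten batches of random rows in [0, 1)
real_batches = [torch.rand(8, data_dim) for _ in range(10)]

generator_updates = 0
for step, real in enumerate(real_batches, start=1):
    # Critic step: push real scores up and fake scores down
    opt_c.zero_grad()
    fake = generator(torch.randn(real.size(0), latent_dim)).detach()
    loss_c = critic(fake).mean() - critic(real).mean()
    loss_c.backward()
    opt_c.step()

    # Clip the critic's weights after every critic update
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-clip_value, clip_value)

    # Generator step: only once every n_critic critic steps
    if step % n_critic == 0:
        opt_g.zero_grad()
        fake = generator(torch.randn(real.size(0), latent_dim))
        loss_g = -critic(fake).mean()
        loss_g.backward()
        opt_g.step()
        generator_updates += 1

print(f"Critic steps: {len(real_batches)}, generator steps: {generator_updates}")
```

With 1,000 rows and a batch size of 64 there are only 16 batches per epoch, so this schedule would mean roughly three generator updates per epoch. If you try it, you'll likely want more epochs to compensate.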

### Step 5: Generating Data with Our WGAN

Now that our custom WGAN is trained, let's use the generator to create another 500 synthetic samples.

In [ ]:
# Create some random noise to feed our generator
z_noise = torch.randn(500, latent_space_dim)

# Generate the scaled synthetic data
wgan_synthetic_samples = wgan_generator(z_noise).detach().numpy()
wgan_synthetic_samples = pd.DataFrame(wgan_synthetic_samples, columns=real_industrial_data.columns)

print("WGAN synthetic samples generated.")
wgan_synthetic_samples.head()

### Step 6: Creating the Hybrid Dataset

Time to combine our efforts! We'll take the samples from both the CTGAN and our WGAN and merge them into one final synthetic dataset. We'll also use our scaler to transform the data back from the 0-1 range to its original scale.

In [ ]:
# Combine the synthetic data from both models
hybrid_synthetic_data_scaled = pd.concat([ctgan_synthetic_samples, wgan_synthetic_samples], ignore_index=True)

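# Convert the scaled data back to its original range
hybrid_synthetic_data = pd.DataFrame(scaler.inverse_transform(hybrid_synthetic_data_scaled), columns=real_industrial_data.columns)

# The WGAN produces continuous values for every column, so snap the
# integer-valued columns back to whole numbers within their valid ranges.
hybrid_synthetic_data['machine_age'] = hybrid_synthetic_data['machine_age'].round().clip(1, 9).astype(int)
hybrid_synthetic_data['failure'] = hybrid_synthetic_data['failure'].round().clip(0, 1).astype(int)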

# Save our final creation to a CSV file
hybrid_synthetic_data.to_csv('synthetic_industrial_data.csv', index=False)

print("Synthetic Hybrid Dataset Generated and Saved Successfully!")
hybrid_synthetic_data.head()

### Step 7: The Moment of Truth - Visualization

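Did our hybrid GAN do a good job? A quick way to find out is to visualize it. We'll overlay the distribution of each feature from our original data and our new synthetic data on the same axes, then compare one pairwise relationship with a scatter plot. Because we simulated every feature as an independent uniform draw, a good result looks like roughly flat, overlapping histograms and a scatter plot with no visible trend. Similar plots are an encouraging sign rather than proof, so treat this as a first sanity check.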

In [ ]:
# Set a nice style for our plots
sns.set_theme(style='whitegrid')

# --- Plotting the Distributions ---
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('Comparing Real vs. Synthetic Data Distributions', fontsize=16)

# Temperature
sns.histplot(real_industrial_data['temperature'], kde=True, ax=axes[0, 0], color='blue', label='Original')
sns.histplot(hybrid_synthetic_data['temperature'], kde=True, ax=axes[0, 0], color='orange', label='Synthetic')
axes[0, 0].set_title('Temperature Distribution')
axes[0, 0].legend()

# Pressure
sns.histplot(real_industrial_data['pressure'], kde=True, ax=axes[0, 1], color='blue', label='Original')
sns.histplot(hybrid_synthetic_data['pressure'], kde=True, ax=axes[0, 1], color='orange', label='Synthetic')
axes[0, 1].set_title('Pressure Distribution')
axes[0, 1].legend()

# Vibration
sns.histplot(real_industrial_data['vibration'], kde=True, ax=axes[1, 0], color='blue', label='Original')
sns.histplot(hybrid_synthetic_data['vibration'], kde=True, ax=axes[1, 0], color='orange', label='Synthetic')
axes[1, 0].set_title('Vibration Distribution')
axes[1, 0].legend()

# Machine Age (integer-valued, so use one bin per whole number instead of a smoothed KDE)
sns.histplot(real_industrial_data['machine_age'], discrete=True, ax=axes[1, 1], color='blue', label='Original')
sns.histplot(hybrid_synthetic_data['machine_age'], discrete=True, ax=axes[1, 1], color='orange', label='Synthetic')
axes[1, 1].set_title('Machine Age Distribution')
axes[1, 1].legend()

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

# --- Plotting the Relationship between two variables ---
fig, ax = plt.subplots(figsize=(10, 6))
sns.scatterplot(x=real_industrial_data['temperature'], y=real_industrial_data['pressure'], color='blue', label='Original', alpha=0.6)
sns.scatterplot(x=hybrid_synthetic_data['temperature'], y=hybrid_synthetic_data['pressure'], color='orange', label='Synthetic', alpha=0.6)
ax.set_title('Scatter Plot: Real vs Synthetic (Temperature vs Pressure)')
ax.legend()
plt.show()
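
Eyeballing plots is subjective, so it can help to put a number next to each one. One simple option is the 1-D Wasserstein distance between a real column and its synthetic counterpart: the average amount you'd have to shift values to turn one sample into the other. Below is a minimal pure-Python sketch that assumes both samples have the same length (true here, with 1,000 real and 1,000 synthetic rows). The function name `wasserstein_1d` is made up for this example; for samples of different sizes, `scipy.stats.wasserstein_distance` handles the general case.

```python
def wasserstein_1d(sample_a, sample_b):
    """1-D Wasserstein distance between two equal-length samples.

    For equal sample sizes this is the mean absolute difference
    between the two samples after sorting each of them.
    """
    if len(sample_a) != len(sample_b):
        raise ValueError("This sketch only handles samples of equal length.")
    sorted_a, sorted_b = sorted(sample_a), sorted(sample_b)
    return sum(abs(a - b) for a, b in zip(sorted_a, sorted_b)) / len(sorted_a)


# Toy data: the second sample is the first one shifted up by exactly 1
real = [1.0, 2.0, 3.0, 4.0]
shifted = [2.0, 3.0, 4.0, 5.0]

print(wasserstein_1d(real, real))     # 0.0
print(wasserstein_1d(real, shifted))  # 1.0
```

To use it on our data, you could pass `real_industrial_data['temperature'].tolist()` and `hybrid_synthetic_data['temperature'].tolist()`. The result is in the column's own units, so a distance of 2 means something different for temperature than for vibration; smaller is better, and 0 means the two samples have identical sorted values.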
