In [None]:
# Mount Google Drive at /content/drive; the dataset, output and model paths
# used later in this notebook all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import os

# Define Generator
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to a flattened image.

    The network is ``noise_dim -> 256 -> 512 -> 1024 -> img_dim`` with ReLU
    after every hidden layer and a final Tanh, so outputs lie in [-1, 1].

    Args:
        noise_dim: Length of the input noise vector.
        img_dim: Length of the flattened output image.
    """

    def __init__(self, noise_dim, img_dim):
        super().__init__()
        widths = [noise_dim, 256, 512, 1024]

        stack = []
        # Hidden layers: Linear followed by ReLU, in increasing width.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # Output layer squashed to [-1, 1].
        stack.append(nn.Linear(widths[-1], img_dim))
        stack.append(nn.Tanh())

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` (last dimension ``noise_dim``) through the network."""
        generated = self.model(x)
        return generated

# Define Discriminator
class Discriminator(nn.Module):
    """Single-hidden-layer discriminator that scores flattened images.

    The network is ``img_dim -> 1024 -> 1`` with LeakyReLU(0.2) and
    Dropout(0.3) on the hidden layer. The output is a raw logit (no sigmoid).

    Args:
        img_dim: Length of the flattened input image.
    """

    def __init__(self, img_dim):
        super().__init__()
        hidden = nn.Linear(img_dim, 1024)
        activation = nn.LeakyReLU(0.2)
        regulariser = nn.Dropout(0.3)
        logit_head = nn.Linear(1024, 1)
        self.model = nn.Sequential(hidden, activation, regulariser, logit_head)

    def forward(self, x):
        """Return one logit per row of ``x``."""
        logits = self.model(x)
        return logits

# Gradient Penalty for stabilization
def gradient_penalty(discriminator, real, fake):
    """Compute a gradient penalty on random interpolations of real and fake samples.

    For each sample a coefficient ``epsilon ~ U[0, 1)`` is drawn and the point
    ``epsilon * real + (1 - epsilon) * fake`` is scored by the discriminator.
    The penalty is the batch mean of ``(||grad||_2 - 1) ** 2``, where ``grad``
    is the gradient of the score with respect to that interpolated point.

    Args:
        discriminator: Callable mapping a batch to one score per sample.
        real: Batch of real samples, shape ``(batch, ...)``.
        fake: Batch of generated samples with the same shape as ``real``.

    Returns:
        A scalar tensor. It stays attached to the discriminator's graph
        (``create_graph=True``) so it can be added to the discriminator loss.
    """
    # The penalty only regularises the discriminator, so cut any graph that
    # leads back into the generator (or the data pipeline).
    real = real.detach()
    fake = fake.detach()

    # One mixing coefficient per sample, broadcast over every other dimension.
    # It is created on the inputs' own device/dtype instead of assuming CUDA.
    eps_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    epsilon = torch.rand(eps_shape, device=real.device, dtype=real.dtype)

    interpolated = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)
    prob_interpolated = discriminator(interpolated)

    gradients = torch.autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                                    grad_outputs=torch.ones_like(prob_interpolated),
                                    create_graph=True, retain_graph=True)[0]

    # Flatten per sample so the norm is taken over all non-batch dimensions.
    gradients = gradients.reshape(gradients.size(0), -1)
    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return penalty

# Hyperparameters

# -- Data / model dimensions --
img_dim = 52 * 52           # length of one flattened 52x52 image
noise_dim = 52*52           # length of the generator's input noise vector

# -- Optimisation --
batch_size = 64
epochs = 100
lr_gen = 2e-4               # generator learning rate
lr_disc = 1e-5              # discriminator learning rate

# -- Regularisation --
label_smoothing_real = 0.9  # target used for "real" instead of 1.0
label_smoothing_fake = 0.1  # target used for "fake" instead of 0.0
grad_penalty_lambda = 10    # coefficient for gradient penalty

# Dataset Loader for MixedWM38
class WaferMapDataset(Dataset):
    """Dataset of images stored as one array inside a NumPy ``.npz`` archive.

    Each item is one image converted to ``float32`` and flattened to 1-D.
    Values are returned exactly as stored (no scaling).

    NOTE(review): the Generator in this notebook ends in Tanh, i.e. produces
    values in [-1, 1]. Confirm the stored images use a comparable range, or
    rescale them before training.

    Args:
        file_path: Path to the ``.npz`` archive.
        key: Name of the array inside the archive. Defaults to ``'arr_0'``,
            the name NumPy gives to the first positionally saved array.

    Raises:
        KeyError: If ``key`` is not present in the archive.
    """

    def __init__(self, file_path, key='arr_0'):
        with np.load(file_path) as data:
            if key not in data.files:
                raise KeyError(
                    f"Array {key!r} not found in {file_path!r}; "
                    f"available arrays: {list(data.files)}"
                )
            # Indexing reads the array into memory, so it remains usable
            # after the archive is closed.
            self.images = data[key]

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return image ``idx`` as a flattened ``float32`` NumPy array."""
        img = self.images[idx].astype(np.float32).flatten()
        return img



In [None]:
# Load dataset
# Read the images from Drive and serve them in shuffled mini-batches.
dataset = WaferMapDataset(
    file_path="/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/Wafer_Map_Datasets.npz"
)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


In [None]:
# Initialize models, optimizers, and loss

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on: {device}")

# Output locations. Create them up front so a missing folder cannot abort the
# run (or lose the trained weights) after training has already finished.
output_dir = "/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/output"
models_dir = "/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/models"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)
learning_curve_path = os.path.join(output_dir, "enhanced_learning_curve.png")

# Training-loop knobs
input_noise_std = 0.05   # std of Gaussian noise added to real samples
label_flip_prob = 0.1    # chance of flipping a discriminator target
generator_steps = 2      # generator updates per discriminator update
sample_every = 10        # save samples and the learning curve every N epochs


def save_learning_curve(gen_history, disc_history, path):
    """Plot the per-epoch generator/discriminator losses and save to ``path``."""
    plt.figure(figsize=(10, 5))
    plt.plot(gen_history, label='Generator Loss')
    plt.plot(disc_history, label='Discriminator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Learning Curve')
    plt.savefig(path)
    plt.close()


generator = Generator(noise_dim, img_dim).to(device)
discriminator = Discriminator(img_dim).to(device)
optim_gen = optim.Adam(generator.parameters(), lr=lr_gen)
optim_disc = optim.Adam(discriminator.parameters(), lr=lr_disc)
criterion = nn.BCEWithLogitsLoss()  # both networks are trained on raw logits

# Track losses (value of the last batch of each epoch)
losses_gen = []
losses_disc = []

# Training loop
generator.train()
discriminator.train()
for epoch in range(epochs):
    for real in loader:
        real = real.to(device)
        # A bit of noise to make it harder for the discriminator
        real = real + input_noise_std * torch.randn_like(real)
        # Size of *this* batch (the last one may be smaller). It gets its own
        # name so the `batch_size` hyperparameter is not overwritten.
        cur_batch_size = real.size(0)

        # Label flipping masks, created on the same device as the labels
        flip_real = torch.rand(cur_batch_size, device=device) < label_flip_prob
        flip_fake = torch.rand(cur_batch_size, device=device) < label_flip_prob

        # ---- Train Discriminator ----
        noise = torch.randn(cur_batch_size, noise_dim, device=device)
        # Detach once: the discriminator update must not reach the generator.
        fake = generator(noise).detach()

        disc_real = discriminator(real).view(-1)
        real_labels = torch.full_like(disc_real, label_smoothing_real)
        real_labels[flip_real] = label_smoothing_fake  # Flip some real labels
        loss_real = criterion(disc_real, real_labels)

        disc_fake = discriminator(fake).view(-1)
        fake_labels = torch.full_like(disc_fake, label_smoothing_fake)
        fake_labels[flip_fake] = label_smoothing_real  # Flip some fake labels
        loss_fake = criterion(disc_fake, fake_labels)

        gp = gradient_penalty(discriminator, real, fake)  # Apply gradient penalty
        loss_disc = (loss_real + loss_fake) / 2 + grad_penalty_lambda * gp

        optim_disc.zero_grad()
        loss_disc.backward()
        optim_disc.step()

        # ---- Train Generator (more often than the discriminator) ----
        for _ in range(generator_steps):
            noise = torch.randn(cur_batch_size, noise_dim, device=device)
            fake = generator(noise)
            disc_fake = discriminator(fake).view(-1)
            # The generator wants its samples classified as real (target 1).
            loss_gen = criterion(disc_fake, torch.ones_like(disc_fake))
            optim_gen.zero_grad()
            loss_gen.backward()
            optim_gen.step()

    # Log losses
    losses_gen.append(loss_gen.item())
    losses_disc.append(loss_disc.item())

    print(f"Epoch [{epoch+1}/{epochs}] Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}")

    # Save generated images and the learning curve so far
    if (epoch + 1) % sample_every == 0:
        # NOTE(review): save_image is called without normalisation while the
        # generator's Tanh output spans [-1, 1] -- confirm negative values
        # should be clipped rather than rescaled.
        save_image(fake.detach().view(-1, 1, 52, 52),
                   os.path.join(output_dir, f"enhanced_fake_{epoch+1}.png"))
        save_learning_curve(losses_gen, losses_disc, learning_curve_path)

# Save final model
torch.save(generator.state_dict(), os.path.join(models_dir, "enhanced_generator.pth"))
torch.save(discriminator.state_dict(), os.path.join(models_dir, "enhanced_discriminator.pth"))

# Plot final learning curves
save_learning_curve(losses_gen, losses_disc, learning_curve_path)


In [None]:
from sklearn.metrics import mean_squared_error

# Evaluate MSE
def calculate_mse(real_images, fake_images):
    """Return the mean squared error between two equally sized image batches.

    Both batches are flattened to ``(batch, features)`` and the squared
    difference is averaged uniformly over every element.

    Args:
        real_images: Tensor of shape ``(batch, ...)``.
        fake_images: Tensor with the same number of samples and features.

    Returns:
        The MSE as a Python ``float``.

    Raises:
        ValueError: If either batch is empty or the flattened shapes differ.
    """
    if real_images.size(0) == 0 or fake_images.size(0) == 0:
        raise ValueError("calculate_mse needs at least one sample in each batch")

    # reshape (unlike view) also accepts non-contiguous tensors.
    real_flat = real_images.detach().reshape(real_images.size(0), -1)
    fake_flat = fake_images.detach().reshape(fake_images.size(0), -1)
    if real_flat.shape != fake_flat.shape:
        raise ValueError(
            f"Shape mismatch: real {tuple(real_flat.shape)} vs "
            f"fake {tuple(fake_flat.shape)}"
        )

    # Accumulate in float64 on the CPU for a stable mean.
    diff = real_flat.cpu().double() - fake_flat.cpu().double()
    mse = diff.pow(2).mean().item()
    return mse

# Generate samples for MSE calculation
# Use whatever device the trained generator lives on.
eval_device = next(generator.parameters()).device

# The dataset yields images only (no labels), so a batch is a single tensor.
real_samples = next(iter(loader)).to(eval_device)  # Get a batch of real samples

# Draw exactly as many fake samples as there are real ones so the two batches
# can be compared; no gradients are needed for evaluation.
with torch.no_grad():
    noise = torch.randn(real_samples.size(0), noise_dim, device=eval_device)
    generated_samples = generator(noise)

# Calculate MSE
# NOTE(review): real and generated samples are not paired, so this measures
# an average pixel-wise distance between two random batches -- confirm this is
# the metric the assignment asks for.
mse = calculate_mse(real_samples, generated_samples)
print(f"Mean Squared Error (MSE): {mse}")