In [None]:
# Mount Google Drive at /content/drive so the dataset, output-image and
# model-checkpoint paths under /content/drive/MyDrive used below are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
import os

# Define Generator
class Generator(nn.Module):
    """Conditional MLP generator: (noise, class label) -> flattened image.

    The class label is embedded into a vector of length ``noise_dim`` and
    concatenated with the noise, so the first Linear layer sees
    ``2 * noise_dim`` features. Three Linear -> BatchNorm1d -> ReLU stages
    (256, 512, 1024 units) follow, then a Linear layer to ``img_dim`` units
    and a Tanh, so every output value lies in [-1, 1].

    Note: BatchNorm1d needs more than one sample per batch in training mode.
    """

    def __init__(self, noise_dim, img_dim, num_classes):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, noise_dim)

        stages = []
        width_in = 2 * noise_dim
        for width_out in (256, 512, 1024):
            stages.extend((nn.Linear(width_in, width_out), nn.BatchNorm1d(width_out), nn.ReLU()))
            width_in = width_out
        # Output head. Layer order (and therefore Sequential indices /
        # state_dict keys) is kept fixed so saved checkpoints stay loadable.
        stages.extend((nn.Linear(width_in, img_dim), nn.Tanh()))
        self.model = nn.Sequential(*stages)

    def forward(self, noise, labels):
        """Generate flattened images for ``noise`` conditioned on ``labels``.

        ``noise`` is concatenated along dim 1 with the label embedding, so it
        is expected to be 2-D with ``noise_dim`` columns; ``labels`` are
        integer class indices (one per row).
        """
        conditioned = torch.cat([noise, self.label_emb(labels)], dim=1)
        return self.model(conditioned)

# Define Discriminator
class Discriminator(nn.Module):
    """Conditional MLP critic: (flattened image, class label) -> raw score.

    The class label is embedded into a vector of length ``img_dim`` and
    concatenated with the image, giving ``2 * img_dim`` input features.
    Two Linear -> LeakyReLU(0.2) -> Dropout(0.3) stages (1024, 512 units),
    one Linear -> LeakyReLU(0.2) stage (256 units) and a final Linear(256, 1)
    follow. There is no sigmoid; the unbounded score is used directly in the
    Wasserstein-style losses of the training loop.
    """

    def __init__(self, img_dim, num_classes):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, img_dim)

        layers = []
        fan_in = 2 * img_dim
        for fan_out in (1024, 512):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2), nn.Dropout(0.3)]
            fan_in = fan_out
        # Last hidden stage has no dropout; layer order keeps state_dict keys
        # identical to previously saved checkpoints.
        layers += [nn.Linear(fan_in, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1)]
        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Score flattened images ``img`` under class labels ``labels``.

        ``img`` is concatenated along dim 1 with the label embedding, so it is
        expected to be 2-D with ``img_dim`` columns; returns one score per row.
        """
        joint = torch.cat([img, self.label_emb(labels)], dim=1)
        return self.model(joint)

# Hyperparameters
# ---- Hyperparameters -------------------------------------------------------
# Model sizes
noise_dim = 100         # length of the latent noise vector fed to the generator
img_dim = 52 ** 2       # flattened 52x52 wafer map
num_classes = 8         # number of label classes (rows of the label embeddings)

# Optimisation
batch_size = 128
epochs = 50
lr_gen = 8e-5           # Adam learning rate for the generator
lr_disc = 4e-5          # Adam learning rate for the critic (half of lr_gen)

# WGAN-GP
lambda_gp = 2.0         # weight of the gradient-penalty term in the critic loss
num_disc_updates = 2    # critic updates per generator update
# NOTE(review): the WGAN-GP paper uses lambda = 10 and 5 critic updates per
# generator update; the values here are lower — revisit if the critic loss
# stays far from zero.

# Dataset Loader for MixedWM38
class WaferMapDataset(Dataset):
    """Map-style dataset over a wafer-map ``.npz`` archive (MixedWM38 layout).

    Each item is ``(flat_image, class_index)``: the wafer map converted to
    ``float32`` and scaled to ``[0, 1]``, optionally passed through
    ``transform``, then flattened; and the argmax of the stored label vector.

    Args:
        file_path: Path to an ``.npz`` file with wafer maps in ``arr_0`` and
            label vectors in ``arr_1``.
        transform: Optional callable applied to each scaled ``float32`` map
            before flattening.
        max_value: Divisor that brings raw pixel values into ``[0, 1]``.
            ``None`` (default) infers it as ``max(images.max(), 1)``, so data
            that is already in ``[0, 1]`` is left unchanged.

    Raises:
        ValueError: if an explicit ``max_value`` is not positive.
    """

    def __init__(self, file_path, transform=None, max_value=None):
        with np.load(file_path) as data:
            self.images = data['arr_0']  # Wafer maps
            self.labels = data['arr_1']  # Label vectors (argmax taken per item)
        self.transform = transform
        if max_value is None:
            # Never scale *up*; an empty archive falls back to no scaling.
            max_value = max(float(self.images.max()), 1.0) if self.images.size else 1.0
        if max_value <= 0:
            raise ValueError(f"max_value must be positive, got {max_value}")
        self.max_value = float(max_value)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # torchvision's ToTensor only rescales uint8 arrays / PIL images; a
        # float32 array passes through unscaled. Scale explicitly so that
        # Normalize((0.5,), (0.5,)) maps pixels onto [-1, 1], the range of the
        # generator's Tanh output. astype() copies, so in-place /= is safe.
        img = self.images[idx].astype(np.float32)
        img /= self.max_value
        # NOTE(review): the label row is treated as one-hot. If a row is
        # multi-hot (mixed defects) argmax keeps only the first set position,
        # and an all-zero row maps to class 0 — confirm against the dataset.
        label = np.argmax(self.labels[idx])
        if self.transform:
            img = self.transform(img)
        return img.flatten(), label

# Per-map preprocessing: HxW ndarray -> 1xHxW float tensor -> 52x52 ->
# (x - 0.5) / 0.5. Note that ToTensor rescales only uint8 arrays / PIL images
# to [0, 1]; float arrays pass through unscaled, so the input should already
# be in [0, 1] for Normalize to yield [-1, 1]. Resize is presumably a no-op
# if the maps are already 52x52.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(52, 52)),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Load dataset
# Load the wafer-map archive and wrap it in a shuffling DataLoader.
dataset = WaferMapDataset(file_path="/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/Wafer_Map_Datasets.npz", transform=transform)
# The generator's BatchNorm1d raises on a one-sample batch in training mode,
# so drop the final batch only when the dataset size would leave exactly one
# sample in it; in every other case all samples are still used.
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=len(dataset) % batch_size == 1,
)

# Initialize models, optimizers, and loss
# Build each network on the GPU followed by its Adam optimiser
# (betas = (0.5, 0.999); each network has its own learning rate).
# Generator is still constructed before Discriminator, so the random weight
# initialisation order is unchanged.
generator = Generator(noise_dim, img_dim, num_classes).to(device="cuda")
optim_gen = optim.Adam(generator.parameters(), lr=lr_gen, betas=(0.5, 0.999))

discriminator = Discriminator(img_dim, num_classes).to(device="cuda")
optim_disc = optim.Adam(discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.999))

# Gradient Penalty for stabilization
def gradient_penalty(discriminator, real, fake, labels):
    """Return the WGAN-GP gradient penalty for a conditional critic.

    Draws one interpolation weight per example, evaluates the critic at the
    points between ``real`` and ``fake``, and penalises the squared deviation
    of each example's input-gradient L2 norm from 1.

    Args:
        discriminator: Critic called as ``discriminator(x, labels)``.
        real: Batch of real samples; dim 0 is the batch, any trailing shape.
        fake: Generated samples with the same shape as ``real``. Detached
            here, so the penalty never back-propagates into the generator.
        labels: Class labels forwarded unchanged to the critic.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)`` that stays attached to
        the critic's parameters (``create_graph=True``) so it can be added to
        the critic loss.
    """
    batch_size = real.size(0)
    # One weight per example, broadcast over the remaining dims, created on
    # the inputs' device rather than a hard-coded "cuda".
    epsilon = torch.rand(
        (batch_size,) + (1,) * (real.dim() - 1),
        device=real.device,
        dtype=real.dtype,
    )
    interpolated = (epsilon * real.detach() + (1 - epsilon) * fake.detach()).requires_grad_(True)
    critic_out = discriminator(interpolated, labels)
    gradients = torch.autograd.grad(
        outputs=critic_out,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_out),
        create_graph=True,
        retain_graph=True,
    )[0]
    # reshape (not view) also copes with non-contiguous gradient tensors.
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# Training loop
# Training loop (WGAN-GP): `num_disc_updates` critic updates per generator update.
for epoch in range(epochs):
    for real, labels in loader:
        real, labels = real.to("cuda"), labels.to("cuda")
        # Size of this (possibly shorter, final) batch. Kept in its own name so
        # the global `batch_size` hyperparameter is not overwritten.
        cur_batch_size = real.size(0)

        # Train Discriminator (critic)
        for _ in range(num_disc_updates):
            noise = torch.randn(cur_batch_size, noise_dim).to("cuda")
            # Critic updates never step the generator, so build `fake` without
            # an autograd graph: loss_disc.backward() (gradient-penalty term
            # included) then cannot reach the generator's weights.
            with torch.no_grad():
                fake = generator(noise, labels)
            disc_real = discriminator(real, labels).view(-1)
            disc_fake = discriminator(fake, labels).view(-1)
            gp = gradient_penalty(discriminator, real, fake, labels)
            # Wasserstein critic loss plus weighted gradient penalty.
            loss_disc = -torch.mean(disc_real) + torch.mean(disc_fake) + lambda_gp * gp
            optim_disc.zero_grad()
            loss_disc.backward()
            nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=10)  # Clip gradients
            optim_disc.step()

        # Train Generator: maximise the critic's score on fresh fakes.
        noise = torch.randn(cur_batch_size, noise_dim).to("cuda")
        fake = generator(noise, labels)
        disc_fake = discriminator(fake, labels).view(-1)
        loss_gen = -torch.mean(disc_fake)
        optim_gen.zero_grad()
        loss_gen.backward()
        nn.utils.clip_grad_norm_(generator.parameters(), max_norm=10)  # Clip gradients
        optim_gen.step()

    # Losses reported are those of the last batch of the epoch.
    print(f"Epoch [{epoch+1}/{epochs}] Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}")

    # Save generated images
    if (epoch + 1) % 10 == 0:
        out_path = f"/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/output/fake_{epoch+1}.png"
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        # Generator output is Tanh in [-1, 1]; save_image expects [0, 1] and
        # clamps, so rescale first or every negative pixel is saved as black.
        save_image(fake.detach().view(-1, 1, 52, 52) * 0.5 + 0.5, out_path)

# Save final model
# Save the final weights (state_dicts only). Create the target folder first so
# a missing directory cannot throw away the whole training run at the last step.
os.makedirs("/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/models", exist_ok=True)
torch.save(generator.state_dict(), "/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/models/generator.pth")
torch.save(discriminator.state_dict(), "/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/models/discriminator.pth")

Epoch [1/50] Loss D: -141.6975, Loss G: 33.9487
Epoch [2/50] Loss D: -144.7217, Loss G: 10.3861
Epoch [3/50] Loss D: -150.3700, Loss G: -29.5455
Epoch [4/50] Loss D: -153.1742, Loss G: -43.1910
Epoch [5/50] Loss D: -145.6715, Loss G: -49.6991
Epoch [6/50] Loss D: -156.2062, Loss G: -54.6608
Epoch [7/50] Loss D: -150.1908, Loss G: -65.3483
Epoch [8/50] Loss D: -163.0183, Loss G: -61.5171
Epoch [9/50] Loss D: -157.3172, Loss G: -67.9827
Epoch [10/50] Loss D: -153.1800, Loss G: -81.5960
Epoch [11/50] Loss D: -146.5420, Loss G: -67.8680
Epoch [12/50] Loss D: -161.9406, Loss G: -75.5958
Epoch [13/50] Loss D: -149.2812, Loss G: -82.1329
Epoch [14/50] Loss D: -147.2630, Loss G: -75.0279
Epoch [15/50] Loss D: -148.8197, Loss G: -73.3200
Epoch [16/50] Loss D: -151.3278, Loss G: -78.2941
Epoch [17/50] Loss D: -157.0057, Loss G: -73.7788
Epoch [18/50] Loss D: -155.1833, Loss G: -80.1763
Epoch [19/50] Loss D: -140.5251, Loss G: -60.2681
Epoch [20/50] Loss D: -156.9916, Loss G: -73.9149
Epoch [21/5