In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import os

# Define Generator
class Generator(nn.Module):
    def __init__(self, noise_dim, img_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

# Define Discriminator
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        return self.model(x)

# Gradient Penalty for stabilization
def gradient_penalty(discriminator, real, fake):
    batch_size, img_dim = real.size()
    epsilon = torch.rand(batch_size, 1).repeat(1, img_dim).to("cuda")
    interpolated = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)
    prob_interpolated = discriminator(interpolated)
    gradients = torch.autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                                    grad_outputs=torch.ones_like(prob_interpolated),
                                    create_graph=True, retain_graph=True)[0]
    gradients = gradients.view(gradients.size(0), -1)
    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return penalty

# Hyperparameters
noise_dim = 128  # Increased noise dimensionality
img_dim = 52 * 52  # Adjusted to match dataset size
batch_size = 64
epochs = 50
lr_gen = 0.0002
lr_disc = 0.0001
label_smoothing_real = 0.9
label_smoothing_fake = 0.1
grad_penalty_lambda = 10  # Coefficient for gradient penalty

# Dataset Loader for MixedWM38
class WaferMapDataset(Dataset):
    def __init__(self, file_path):
        with np.load(file_path) as data:
            self.images = data['arr_0']

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx].astype(np.float32).flatten()
        return img

# Load dataset
dataset = WaferMapDataset(file_path="/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/Wafer_Map_Datasets.npz")
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize models, optimizers, and loss
generator = Generator(noise_dim, img_dim).to("cuda")
discriminator = Discriminator(img_dim).to("cuda")
optim_gen = optim.Adam(generator.parameters(), lr=lr_gen)
optim_disc = optim.Adam(discriminator.parameters(), lr=lr_disc)
criterion = nn.BCEWithLogitsLoss()  # Updated to match non-sigmoid output

# Track losses
losses_gen = []
losses_disc = []

# Training loop
for epoch in range(epochs):
    for real in loader:
        real = real.to("cuda")
        batch_size = real.size(0)

        # Add label flipping
        flip_real = torch.rand(batch_size) < 0.05  # 5% chance to flip labels
        flip_fake = torch.rand(batch_size) < 0.05

        # Train Discriminator
        noise = torch.randn(batch_size, noise_dim).to("cuda")
        fake = generator(noise)
        disc_real = discriminator(real).view(-1)
        real_labels = torch.full_like(disc_real, label_smoothing_real)
        real_labels[flip_real] = label_smoothing_fake  # Flip some real labels
        loss_real = criterion(disc_real, real_labels)

        disc_fake = discriminator(fake.detach()).view(-1)
        fake_labels = torch.full_like(disc_fake, label_smoothing_fake)
        fake_labels[flip_fake] = label_smoothing_real  # Flip some fake labels
        loss_fake = criterion(disc_fake, fake_labels)

        gp = gradient_penalty(discriminator, real, fake)  # Apply gradient penalty
        loss_disc = (loss_real + loss_fake) / 2 + grad_penalty_lambda * gp

        optim_disc.zero_grad()
        loss_disc.backward()
        optim_disc.step()

        # Train Generator (twice as frequently)
        for _ in range(2):
            noise = torch.randn(batch_size, noise_dim).to("cuda")
            fake = generator(noise)
            disc_fake = discriminator(fake).view(-1)
            loss_gen = criterion(disc_fake, torch.ones_like(disc_fake))
            optim_gen.zero_grad()
            loss_gen.backward()
            optim_gen.step()

    # Log losses
    losses_gen.append(loss_gen.item())
    losses_disc.append(loss_disc.item())

    print(f"Epoch [{epoch+1}/{epochs}] Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}")

    # Save generated images
    if (epoch + 1) % 10 == 0:
        save_image(fake.view(-1, 1, 52, 52), f"/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/output/enhanced_fake_{epoch+1}.png")

# Save final model
torch.save(generator.state_dict(), "/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/models/enhanced_generator.pth")
torch.save(discriminator.state_dict(), "/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/models/enhanced_discriminator.pth")

# Plot learning curves
plt.figure(figsize=(10, 5))
plt.plot(losses_gen, label='Generator Loss')
plt.plot(losses_disc, label='Discriminator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Learning Curve')
plt.savefig("/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/output/enhanced_learning_curve.png")
plt.close()


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch [1/50] Loss D: 0.5483, Loss G: 2.1871
Epoch [2/50] Loss D: 0.4881, Loss G: 1.9958
Epoch [3/50] Loss D: 0.6351, Loss G: 1.2008
Epoch [4/50] Loss D: 0.7093, Loss G: 1.0119
Epoch [5/50] Loss D: 1.3678, Loss G: 0.7927
Epoch [6/50] Loss D: 0.8886, Loss G: 1.2664
Epoch [7/50] Loss D: 1.5131, Loss G: 0.3397
Epoch [8/50] Loss D: 0.9907, Loss G: 0.5839
Epoch [9/50] Loss D: 0.6661, Loss G: 1.2543
Epoch [10/50] Loss D: 0.5594, Loss G: 1.6173
Epoch [11/50] Loss D: 0.4933, Loss G: 1.9285
Epoch [12/50] Loss D: 0.5614, Loss G: 1.3889
Epoch [13/50] Loss D: 0.5184, Loss G: 2.5397
Epoch [14/50] Loss D: 0.4051, Loss G: 2.0866
Epoch [15/50] Loss D: 0.4106, Loss G: 1.7784
Epoch [16/50] Loss D: 0.5219, Loss G: 2.2811
Epoch [17/50] Loss D: 0.4468, Loss G: 2.2474
Epoch [18/50] Loss D: 0.4207, Loss G: 1.8919
Epoch [19/50] Loss D: 0.4546, Loss G: 1.9933
Epoch [20/50] Loss D: 0.4580, Loss G: 1.8428
Epoch [21/50] Loss D: 0.4169, Loss G: 2.0045
Epoch [22/50] Loss D: 0.4231, Loss G: 1.9616
Epoch [23/50] Loss 