In [1]:
# Mount Google Drive at /content/drive; the dataset, output and model paths
# used later in this notebook all live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

class Generator(nn.Module):
    """Conditional DCGAN-style generator that emits 1x64x64 images.

    The class label is embedded into a vector of length ``noise_dim`` and
    concatenated with the noise along the channel axis. The resulting
    ``2 * noise_dim`` x 1 x 1 tensor is upsampled to 64x64 by five transposed
    convolutions, and a final ``tanh`` bounds the output to [-1, 1].

    Args:
        noise_dim: Number of channels of the noise input (also the label
            embedding size).
        num_classes: Number of distinct class labels the embedding supports.
    """

    def __init__(self, noise_dim, num_classes):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, noise_dim)

        # Projection stage: [B, 2 * noise_dim, 1, 1] -> [B, 512, 4, 4]
        stages = [
            nn.ConvTranspose2d(2 * noise_dim, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        ]

        # Three identical upsampling stages, each doubling the spatial size:
        # 4x4 -> 8x8 -> 16x16 -> 32x32
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            stages += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        # Output stage: [B, 64, 32, 32] -> [B, 1, 64, 64], squashed by tanh
        stages += [
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*stages)

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: Tensor of shape [B, noise_dim, 1, 1].
            labels: Integer tensor of shape [B] with class indices.

        Returns:
            Tensor of shape [B, 1, 64, 64].
        """
        # [B] -> [B, noise_dim] -> [B, noise_dim, 1, 1]
        label_vec = self.label_emb(labels)
        label_map = label_vec[:, :, None, None]
        conditioned = torch.cat((noise, label_map), dim=1)
        return self.model(conditioned)




class Discriminator(nn.Module):
    """Conditional discriminator for single-channel 64x64 images.

    The class label is embedded into a 64*64 vector, reshaped into a 1x64x64
    plane and stacked with the image as a second input channel. The network
    returns one raw logit per sample (no sigmoid is applied here).

    Args:
        num_classes: Number of distinct class labels the embedding supports.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, 64 * 64)

        layers = [
            # [B, 2, 64, 64] -> [B, 128, 32, 32]; first stage has no batch norm
            nn.Conv2d(in_channels=2, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # -> [B, 256, 16, 16]
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # -> [B, 512, 8, 8]
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # An 8x8 kernel over the 8x8 map collapses it to one logit: [B, 1, 1, 1]
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=8, stride=1, padding=0, bias=False),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Score a batch of images against their labels.

        Args:
            img: Tensor of shape [B, 1, 64, 64].
            labels: Integer tensor of shape [B] with class indices.

        Returns:
            1-D tensor with one logit per sample.
        """
        n_samples = labels.size(0)
        label_plane = self.label_emb(labels).view(n_samples, 1, 64, 64)
        stacked = torch.cat((img, label_plane), dim=1)
        logits = self.model(stacked)
        return logits.view(-1)



# Hyperparameters
noise_dim = 100      # channels of the generator's noise input (and label embedding size)
num_classes = 38     # number of distinct class labels
batch_size = 2048    # samples per DataLoader batch
epochs = 50          # full passes over the dataset

# Adam learning rates for the generator and the discriminator
lr_gen = 2e-4
lr_disc = 2e-4

# Soft targets used for the discriminator instead of hard 1 / 0
label_smoothing_real, label_smoothing_fake = 0.9, 0.1

# Dataset Loader
class WaferMapDataset(Dataset):
    """Dataset backed by an ``.npz`` archive of images and label rows.

    The archive is read once at construction time: ``arr_0`` is kept as the
    image array and ``arr_1`` as the label rows. Each distinct combination of
    active (== 1) positions in a label row is treated as one class and is
    assigned an integer index in order of first appearance.

    Args:
        file_path: Path to the ``.npz`` archive containing ``arr_0`` and
            ``arr_1``.
        transform: Optional callable applied to each image (after it is cast
            to ``float32``) in ``__getitem__``.
    """

    def __init__(self, file_path, transform=None):
        with np.load(file_path) as data:
            self.data = data['arr_0']
            self.onehot_labels = data['arr_1']
        self.labels = self._get_class_indices(self.onehot_labels)
        self.transform = transform

    def _get_class_indices(self, onehot_labels):
        """Map every label row to an integer class index.

        Rows with the same set of active positions share an index; indices are
        handed out in order of first appearance, starting at 0.

        Args:
            onehot_labels: Iterable of 1-D label rows.

        Returns:
            ``np.ndarray`` of dtype ``int64`` with one class index per row
            (empty, but still ``int64``, when there are no rows).
        """
        mapping = {}
        class_indices = []
        for row in onehot_labels:
            # Computed once per row; the dict gives O(1) lookup of known patterns.
            active_indices = tuple(np.where(row == 1)[0])
            # len(mapping) is evaluated before insertion, so a new pattern
            # receives the next unused index.
            class_indices.append(mapping.setdefault(active_indices, len(mapping)))
        return np.array(class_indices, dtype=np.int64)

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for sample ``idx``.

        The image is cast to ``float32`` and passed through ``transform`` when
        one was supplied.
        """
        img = self.data[idx].astype(np.float32)
        label = self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# Per-sample preprocessing: convert to tensor, resize to 64x64, then apply
# (x - 0.5) / 0.5.
# NOTE(review): that last step lands in [-1, 1] only when its input is already
# in [0, 1]. WaferMapDataset hands float32 arrays to ToTensor, which is
# documented to rescale only uint8/PIL inputs -- confirm the raw value range.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((64, 64)),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Load dataset
dataset = WaferMapDataset(
    "/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/Wafer_Map_Datasets.npz",
    transform=transform,
)

# Shuffled mini-batches of `batch_size` samples
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

def visualize_real_and_generated(dataset, generator, epoch, num_classes, noise_dim,
                                 output_dir="/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/output"):
    """Save a grid comparing one real and one generated sample per class.

    The upper half of the grid shows the first real sample found for each
    class; the lower half shows one generated sample per class. With 38
    classes the layout is 4 rows x 19 columns. The figure is written to
    ``<output_dir>/epoch_<epoch>_visualization.png`` and then closed.

    Args:
        dataset: Iterable of ``(image, label)`` pairs; each image must be
            reshapeable to 64x64.
        generator: Model called as ``generator(noise, labels)``. It runs on
            whatever device its parameters live on and is left in its current
            train/eval mode.
        epoch: Value interpolated into the output file name.
        num_classes: Number of classes to display (must be positive).
        noise_dim: Channel count of the noise tensor fed to the generator.
        output_dir: Directory for the image; created if it does not exist.

    Raises:
        ValueError: If ``num_classes`` is not positive.
    """
    if num_classes <= 0:
        raise ValueError("num_classes must be positive")

    # Derive the grid from num_classes instead of hard-coding 4 x 19 / offset 38.
    n_cols = min(19, num_classes)
    rows_per_kind = (num_classes + n_cols - 1) // n_cols
    n_rows = 2 * rows_per_kind
    fake_offset = rows_per_kind * n_cols  # index of the first "fake" axis

    fig, axs = plt.subplots(n_rows, n_cols, figsize=(25, 3.75 * n_rows), squeeze=False)
    axs = axs.flatten()

    # Collect one sample per class for real data; labels outside
    # [0, num_classes) are ignored rather than raising KeyError.
    class_samples = {}
    for img, label in dataset:
        cls = int(label)
        if 0 <= cls < num_classes and cls not in class_samples:
            class_samples[cls] = img.reshape(64, 64)
            if len(class_samples) == num_classes:
                break

    # Generate fake samples on the generator's own device; no_grad avoids
    # building an autograd graph that is never used.
    device = next(generator.parameters()).device
    with torch.no_grad():
        noise = torch.randn(num_classes, noise_dim, 1, 1, device=device)
        labels = torch.arange(num_classes, device=device)
        generated = generator(noise, labels).cpu().numpy()

    for cls in range(num_classes):
        # Real sample (upper half); left blank if the class never appeared.
        real_ax = axs[cls]
        if cls in class_samples:
            real_ax.imshow(class_samples[cls], cmap="Blues")
        real_ax.set_title(f"Real Class {cls}", fontsize=8)

        # Generated sample (lower half)
        fake_ax = axs[fake_offset + cls]
        fake_ax.imshow(generated[cls].squeeze(), cmap="Blues")
        fake_ax.set_title(f"Fake Class {cls}", fontsize=8)

    # Hide ticks/frames everywhere, including unused cells of the grid.
    for ax in axs:
        ax.axis("off")

    fig.tight_layout()
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f"epoch_{epoch}_visualization.png"))
    plt.close(fig)


# Initialize models
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(noise_dim, num_classes).to(device)
discriminator = Discriminator(num_classes).to(device)
optim_gen = optim.Adam(generator.parameters(), lr=lr_gen, betas=(0.5, 0.999))
optim_disc = optim.Adam(discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.999))
criterion = nn.BCEWithLogitsLoss()  # the discriminator returns raw logits

# Create the checkpoint folder up front so a missing directory is reported
# before training starts rather than after the last epoch.
model_dir = "/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/models"
os.makedirs(model_dir, exist_ok=True)

gen_steps_per_batch = 2  # generator updates per discriminator update
visualize_every = 1      # save a real-vs-generated grid every N epochs

# Track losses (value of the last batch of each epoch)
losses_gen = []
losses_disc = []

# Training loop
for epoch in range(epochs):
    # Stay defined even if the loader yields no batches.
    loss_disc_value = float("nan")
    loss_gen_value = float("nan")

    # Progress bar for batches
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch")
    for real, labels in pbar:
        # Move data to the device
        real = real.view(real.size(0), 1, 64, 64).to(device)
        labels = labels.to(device)

        # Add noise to the real images for robustness
        real = real + 0.05 * torch.randn_like(real)

        # Size of THIS batch (the last one may be smaller). Kept in its own
        # name so the global `batch_size` hyperparameter is not overwritten.
        cur_batch_size = real.size(0)

        # Generate fake images for the discriminator step. They are only used
        # detached, so no autograd graph is needed.
        with torch.no_grad():
            noise = torch.randn(cur_batch_size, noise_dim, 1, 1, device=device)
            fake = generator(noise, labels)

        # Train Discriminator
        # Smoothed targets (0.9 / 0.1) with ~10% of them randomly flipped.
        flip_real = torch.rand(cur_batch_size, device=device) < 0.1
        real_labels = torch.full((cur_batch_size,), label_smoothing_real, device=device)
        real_labels[flip_real] = label_smoothing_fake

        flip_fake = torch.rand(cur_batch_size, device=device) < 0.1
        fake_labels = torch.full((cur_batch_size,), label_smoothing_fake, device=device)
        fake_labels[flip_fake] = label_smoothing_real

        disc_real = discriminator(real, labels).view(-1)
        loss_real = criterion(disc_real, real_labels)

        disc_fake = discriminator(fake, labels).view(-1)
        loss_fake = criterion(disc_fake, fake_labels)

        loss_disc = (loss_real + loss_fake) / 2

        # Optimize discriminator
        optim_disc.zero_grad()
        loss_disc.backward()
        optim_disc.step()

        # Train Generator (several updates per discriminator update)
        for _ in range(gen_steps_per_batch):
            noise = torch.randn(cur_batch_size, noise_dim, 1, 1, device=device)
            fake = generator(noise, labels)
            disc_fake = discriminator(fake, labels).view(-1)

            # The generator wants its samples to be classified as real.
            loss_gen = criterion(disc_fake, torch.ones_like(disc_fake))

            optim_gen.zero_grad()
            loss_gen.backward()
            optim_gen.step()

        # Update progress bar with losses
        loss_disc_value = loss_disc.item()
        loss_gen_value = loss_gen.item()
        pbar.set_postfix({"Loss D": loss_disc_value, "Loss G": loss_gen_value})

    # Track losses
    losses_gen.append(loss_gen_value)
    losses_disc.append(loss_disc_value)

    # Print progress
    print(f"Epoch [{epoch+1}/{epochs}] Loss D: {loss_disc_value:.4f}, Loss G: {loss_gen_value:.4f}")

    # Visualize real and generated samples every `visualize_every` epochs
    if (epoch + 1) % visualize_every == 0:
        visualize_real_and_generated(dataset, generator, epoch + 1, num_classes, noise_dim)

# Save the final models
torch.save(generator.state_dict(), os.path.join(model_dir, "enhanced_generator.pth"))
torch.save(discriminator.state_dict(), os.path.join(model_dir, "enhanced_discriminator.pth"))

Epoch 1/50: 100%|██████████| 38/38 [00:17<00:00,  2.22batch/s, Loss D=0.534, Loss G=2.45]


Epoch [1/50] Loss D: 0.5343, Loss G: 2.4461


Epoch 2/50: 100%|██████████| 38/38 [00:14<00:00,  2.59batch/s, Loss D=0.547, Loss G=1.82]


Epoch [2/50] Loss D: 0.5465, Loss G: 1.8212


Epoch 3/50: 100%|██████████| 38/38 [00:14<00:00,  2.58batch/s, Loss D=0.495, Loss G=2.01]


Epoch [3/50] Loss D: 0.4955, Loss G: 2.0114


Epoch 4/50: 100%|██████████| 38/38 [00:14<00:00,  2.59batch/s, Loss D=0.514, Loss G=0.957]


Epoch [4/50] Loss D: 0.5142, Loss G: 0.9575


Epoch 5/50: 100%|██████████| 38/38 [00:14<00:00,  2.58batch/s, Loss D=0.516, Loss G=1.1]


Epoch [5/50] Loss D: 0.5164, Loss G: 1.0957


Epoch 6/50: 100%|██████████| 38/38 [00:14<00:00,  2.59batch/s, Loss D=0.505, Loss G=1.24]


Epoch [6/50] Loss D: 0.5054, Loss G: 1.2416


Epoch 7/50: 100%|██████████| 38/38 [00:14<00:00,  2.59batch/s, Loss D=0.521, Loss G=1.23]


Epoch [7/50] Loss D: 0.5213, Loss G: 1.2330


Epoch 8/50: 100%|██████████| 38/38 [00:14<00:00,  2.59batch/s, Loss D=0.477, Loss G=2.05]


Epoch [8/50] Loss D: 0.4773, Loss G: 2.0517


Epoch 9/50: 100%|██████████| 38/38 [00:14<00:00,  2.60batch/s, Loss D=0.484, Loss G=1.86]


Epoch [9/50] Loss D: 0.4844, Loss G: 1.8585


Epoch 10/50: 100%|██████████| 38/38 [00:14<00:00,  2.60batch/s, Loss D=0.471, Loss G=1.6]


Epoch [10/50] Loss D: 0.4708, Loss G: 1.5969


Epoch 11/50: 100%|██████████| 38/38 [00:14<00:00,  2.60batch/s, Loss D=0.517, Loss G=2.96]


Epoch [11/50] Loss D: 0.5172, Loss G: 2.9576


Epoch 12/50: 100%|██████████| 38/38 [00:14<00:00,  2.59batch/s, Loss D=0.487, Loss G=1.39]


Epoch [12/50] Loss D: 0.4875, Loss G: 1.3889


Epoch 13/50: 100%|██████████| 38/38 [00:14<00:00,  2.60batch/s, Loss D=0.507, Loss G=1.79]


Epoch [13/50] Loss D: 0.5074, Loss G: 1.7869


Epoch 14/50: 100%|██████████| 38/38 [00:14<00:00,  2.59batch/s, Loss D=0.492, Loss G=1.44]


Epoch [14/50] Loss D: 0.4922, Loss G: 1.4404


Epoch 15/50: 100%|██████████| 38/38 [00:14<00:00,  2.58batch/s, Loss D=0.51, Loss G=1.75]


Epoch [15/50] Loss D: 0.5099, Loss G: 1.7517


Epoch 16/50: 100%|██████████| 38/38 [00:14<00:00,  2.59batch/s, Loss D=0.469, Loss G=2.39]


Epoch [16/50] Loss D: 0.4687, Loss G: 2.3903


Epoch 17/50: 100%|██████████| 38/38 [00:14<00:00,  2.59batch/s, Loss D=0.497, Loss G=1.11]


Epoch [17/50] Loss D: 0.4968, Loss G: 1.1130


Epoch 18/50: 100%|██████████| 38/38 [00:14<00:00,  2.58batch/s, Loss D=0.484, Loss G=1.5]


Epoch [18/50] Loss D: 0.4842, Loss G: 1.4964


Epoch 19/50: 100%|██████████| 38/38 [00:14<00:00,  2.60batch/s, Loss D=0.444, Loss G=1.95]


Epoch [19/50] Loss D: 0.4442, Loss G: 1.9497


Epoch 20/50: 100%|██████████| 38/38 [00:14<00:00,  2.60batch/s, Loss D=0.535, Loss G=2.09]


Epoch [20/50] Loss D: 0.5349, Loss G: 2.0940


Epoch 21/50: 100%|██████████| 38/38 [00:14<00:00,  2.60batch/s, Loss D=0.443, Loss G=1.85]


Epoch [21/50] Loss D: 0.4429, Loss G: 1.8546


Epoch 22/50: 100%|██████████| 38/38 [00:14<00:00,  2.59batch/s, Loss D=0.519, Loss G=2.01]


Epoch [22/50] Loss D: 0.5191, Loss G: 2.0050


Epoch 23/50:  87%|████████▋ | 33/38 [00:13<00:01,  2.55batch/s, Loss D=0.476, Loss G=2.05]

In [None]:
# Plot learning curves
# One point per epoch, taken from the `losses_gen` / `losses_disc` lists
# filled by the training cell.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(losses_gen, label='Generator Loss')
ax.plot(losses_disc, label='Discriminator Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Learning Curve')
ax.legend()

# Save BEFORE showing: once plt.show() has displayed the figure it may be
# released, and a later pyplot-level savefig would write an empty canvas.
fig.savefig("/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/output/enhanced_learning_curve.png")
plt.show()
plt.close(fig)