In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

# ----------------
#   Generator
# ----------------
class Generator(nn.Module):
    def __init__(self, ngpu, nc=1, nz=100, ngf=64):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # noise Z -> 4x4
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            # 4x4 -> 8x8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            # 8x8 -> 16x16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            # 16x16 -> 32x32
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            # 32x32 -> 64x64
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )


    def forward(self, input):
        return self.main(input)

# ----------------
#   Discriminator
# ----------------
class Discriminator(nn.Module):
    def __init__(self, ngpu, nc=1, ndf=64):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),   # -> (ndf) x 32 x 32
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),  # -> (ndf*2) x 16 x 16
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),  # -> (ndf*4) x 8 x 8
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),  # -> (ndf*8) x 4 x 4
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # Final layer: kernel_size=4, stride=1, padding=0 => output shape (1 x 1)
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        for i, layer in enumerate(self.main):
            x = layer(x)
        return x.view(-1)



# ----------------
#  Hyperparameters
# ----------------
noise_dim = 100
batch_size = 2048
epochs = 200
lr_gen = 0.0002
lr_disc = 0.0002
label_smoothing_real = 0.9
label_smoothing_fake = 0.1

# ----------------
#   Dataset Loader
# ----------------
class WaferMapDataset(Dataset):
    def __init__(self, file_path, transform=None):
        with np.load(file_path) as data:
            self.data = data['arr_0']
            self.onehot_labels = data['arr_1']  # Not used if unconditional
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx].astype(np.float32)
        if self.transform:
            img = self.transform(img)
        # Return only img if unconditional
        return img

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((64, 64)),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
])

# Load dataset
dataset = WaferMapDataset(
    file_path="/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/Wafer_Map_Datasets.npz",
    transform=transform
)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# -----------------------------
#   Simple Visualization
# -----------------------------
def visualize_real_and_generated(dataset, generator, epoch, noise_dim, save_path="./"):
    # Take a few real images
    rows = 2
    cols = 8
    fig, axs = plt.subplots(rows, cols, figsize=(cols*2.0, rows*2.0))
    axs = axs.flatten()

    # Show real images in the first row
    for i in range(cols):
        real_img = dataset[i]  # get i-th item
        axs[i].imshow(real_img.squeeze(), cmap="gray")
        axs[i].set_title("Real")
        axs[i].axis("off")

    # Generate fake images in the second row
    noise = torch.randn(cols, noise_dim, 1, 1).to("cuda")
    with torch.no_grad():
        fake_imgs = generator(noise).cpu()
    for i in range(cols):
        axs[cols + i].imshow(fake_imgs[i].squeeze(), cmap="gray")
        axs[cols + i].set_title("Fake")
        axs[cols + i].axis("off")

    plt.tight_layout()
    plt.savefig(os.path.join(save_path, f"epoch_{epoch}_visualization.png"))
    plt.close()


# ----------------
#   Initialize
# ----------------
generator = Generator(ngpu=1, nc=1, nz=noise_dim, ngf=64).to("cuda")
discriminator = Discriminator(ngpu=1, nc=1, ndf=64).to("cuda")

optim_gen = optim.Adam(generator.parameters(), lr=lr_gen, betas=(0.5, 0.999))
optim_disc = optim.Adam(discriminator.parameters(), lr=lr_disc, betas=(0.5, 0.999))
criterion = nn.BCEWithLogitsLoss()

losses_gen = []
losses_disc = []

# ----------------
#   Training Loop
# ----------------
for epoch in range(epochs):
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch")
    for real in pbar:  # real is now just images
        real = real.view(real.size(0), 1, 64, 64).to("cuda")
        batch_size = real.size(0)

        # Generate fake images
        noise = torch.randn(batch_size, noise_dim, 1, 1).to("cuda")
        #fake = generator(noise)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # Real label smoothing (sometimes flipping real to fake)
        flip_real = torch.rand(batch_size, device=real.device) < 0.1
        real_labels = torch.full((batch_size,), label_smoothing_real, device=real.device)
        real_labels[flip_real] = label_smoothing_fake

        # Fake label smoothing (sometimes flipping fake to real)
        flip_fake = torch.rand(batch_size, device=real.device) < 0.1
        fake_labels = torch.full((batch_size,), label_smoothing_fake, device=real.device)
        fake_labels[flip_fake] = label_smoothing_real

        disc_real = discriminator(real).view(-1)
        loss_real = criterion(disc_real, real_labels)

        disc_fake = discriminator(fake.detach()).view(-1)
        loss_fake = criterion(disc_fake, fake_labels)

        loss_disc = (loss_real + loss_fake) / 2

        optim_disc.zero_grad()
        loss_disc.backward()
        optim_disc.step()

        # -----------------
        #  Train Generator
        # -----------------
        # Re-generate fake because we detached
        fake = generator(noise)
        disc_fake = discriminator(fake).view(-1)
        loss_gen = criterion(disc_fake, torch.ones_like(disc_fake))

        optim_gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

        pbar.set_postfix({"Loss D": loss_disc.item(), "Loss G": loss_gen.item()})

    losses_gen.append(loss_gen.item())
    losses_disc.append(loss_disc.item())

    print(f"Epoch [{epoch+1}/{epochs}] Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}")

    # Visualization
    if (epoch + 1) % 5 == 0:
        visualize_real_and_generated(dataset, generator, epoch + 1, noise_dim, save_path="/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/output")

# Save final models
torch.save(generator.state_dict(), "/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/models/enhanced_generator.pth")
torch.save(discriminator.state_dict(), "/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/models/enhanced_discriminator.pth")

Epoch 1/200: 100%|██████████| 19/19 [00:10<00:00,  1.75batch/s, Loss D=0.595, Loss G=0.693]


Epoch [1/200] Loss D: 0.5948, Loss G: 0.6926


Epoch 2/200: 100%|██████████| 19/19 [00:10<00:00,  1.77batch/s, Loss D=0.597, Loss G=0.693]


Epoch [2/200] Loss D: 0.5973, Loss G: 0.6928


Epoch 3/200:  84%|████████▍ | 16/19 [00:09<00:01,  1.66batch/s, Loss D=0.594, Loss G=0.693]


KeyboardInterrupt: 