In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

# ----------------
#   Generator
# ----------------
class Generator(nn.Module):
    def __init__(self, ngpu, nc=1, nz=100, ngf=64):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)

# ----------------
#   Discriminator
# ----------------
class Discriminator(nn.Module):
    def __init__(self, ngpu, nc=1, ndf=64):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        return self.main(x).view(-1)

# ----------------
#   Dataset Loader
# ----------------
class WaferMapDataset(Dataset):
    def __init__(self, file_path, transform=None):
        with np.load(file_path) as data:
            self.data = data['arr_0']
            self.onehot_labels = data['arr_1']
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx].astype(np.float32)
        if self.transform:
            img = self.transform(img)
        return img

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((64, 64)),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
])

dataset = WaferMapDataset(
    file_path="/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/Wafer_Map_Datasets.npz",
    transform=transform
)
loader = DataLoader(dataset, batch_size=128, shuffle=True)

# ----------------
#   Weight Initialization
# ----------------
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# ----------------
#   Initialize
# ----------------
generator = Generator(ngpu=1, nc=1, nz=100, ngf=64).to("cuda")
discriminator = Discriminator(ngpu=1, nc=1, ndf=64).to("cuda")

generator.apply(weights_init)
discriminator.apply(weights_init)

criterion = nn.MSELoss()
optim_gen = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optim_disc = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# -----------------------------
#   Simple Visualization
# -----------------------------
def visualize_real_and_generated(dataset, generator, epoch, save_path="./"):
    rows, cols = 2, 8
    fig, axs = plt.subplots(rows, cols, figsize=(cols*2.0, rows*2.0))
    axs = axs.flatten()

    for i in range(cols):
        real_img = dataset[i]
        axs[i].imshow(real_img.squeeze(), cmap="gray")
        axs[i].set_title("Real")
        axs[i].axis("off")

    noise = torch.randn(cols, 100, 1, 1).to("cuda")
    with torch.no_grad():
        fake_imgs = generator(noise).cpu()
    for i in range(cols):
        axs[cols + i].imshow(fake_imgs[i].squeeze(), cmap="gray")
        axs[cols + i].set_title("Fake")
        axs[cols + i].axis("off")

    plt.tight_layout()
    plt.savefig(os.path.join(save_path, f"epoch_{epoch}_visualization    .png"))
    plt.close()

# ----------------
#   Training Loop
# ----------------
losses_gen = []
losses_disc = []

for epoch in range(200):
    pbar = tqdm(loader, desc=f"Epoch {epoch + 1}/200", unit="batch")
    for real in pbar:
        real = real.view(real.size(0), 1, 64, 64).to("cuda")
        batch_size = real.size(0)

        # Train Discriminator
        noise = torch.randn(batch_size, 100, 1, 1).to("cuda")
        fake = generator(noise)

        real_labels = torch.ones(batch_size, device=real.device)
        fake_labels = torch.zeros(batch_size, device=real.device)

        disc_real = discriminator(real).view(-1)
        disc_fake = discriminator(fake.detach()).view(-1)

        loss_disc_real = criterion(disc_real, real_labels)
        loss_disc_fake = criterion(disc_fake, fake_labels)
        loss_disc = (loss_disc_real + loss_disc_fake) / 2

        optim_disc.zero_grad()
        loss_disc.backward()
        optim_disc.step()

        # Train Generator
        fake = generator(noise)
        disc_fake = discriminator(fake).view(-1)
        loss_gen = criterion(disc_fake, real_labels)

        optim_gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

        pbar.set_postfix({"Loss D": loss_disc.item(), "Loss G": loss_gen.item()})

    losses_gen.append(loss_gen.item())
    losses_disc.append(loss_disc.item())

    print(f"Epoch [{epoch + 1}/200] Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}")

    # Visualization
    if (epoch + 1) % 10 == 0:
        visualize_real_and_generated(dataset, generator, epoch + 1, save_path="/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/output")

# Save final models
torch.save(generator.state_dict(), "/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/models/optimized_generator.pth")
torch.save(discriminator.state_dict(), "/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/models/optimized_discriminator.pth")

Epoch 1/200: 100%|██████████| 297/297 [00:14<00:00, 21.09batch/s, Loss D=0.135, Loss G=6.31]


Epoch [1/200] Loss D: 0.1353, Loss G: 6.3112


Epoch 2/200: 100%|██████████| 297/297 [00:13<00:00, 21.46batch/s, Loss D=0.067, Loss G=0.779]


Epoch [2/200] Loss D: 0.0670, Loss G: 0.7789


Epoch 3/200: 100%|██████████| 297/297 [00:14<00:00, 21.21batch/s, Loss D=0.113, Loss G=0.862]


Epoch [3/200] Loss D: 0.1126, Loss G: 0.8616


Epoch 4/200: 100%|██████████| 297/297 [00:13<00:00, 21.31batch/s, Loss D=0.713, Loss G=3.83]


Epoch [4/200] Loss D: 0.7134, Loss G: 3.8333


Epoch 5/200: 100%|██████████| 297/297 [00:13<00:00, 21.50batch/s, Loss D=0.0909, Loss G=0.716]


Epoch [5/200] Loss D: 0.0909, Loss G: 0.7156


Epoch 6/200: 100%|██████████| 297/297 [00:13<00:00, 21.46batch/s, Loss D=0.104, Loss G=0.607]


Epoch [6/200] Loss D: 0.1039, Loss G: 0.6067


Epoch 7/200: 100%|██████████| 297/297 [00:13<00:00, 21.38batch/s, Loss D=0.0433, Loss G=0.671]


Epoch [7/200] Loss D: 0.0433, Loss G: 0.6708


Epoch 8/200: 100%|██████████| 297/297 [00:13<00:00, 21.47batch/s, Loss D=0.0776, Loss G=0.937]


Epoch [8/200] Loss D: 0.0776, Loss G: 0.9368


Epoch 9/200: 100%|██████████| 297/297 [00:14<00:00, 21.19batch/s, Loss D=0.0633, Loss G=0.792]


Epoch [9/200] Loss D: 0.0633, Loss G: 0.7921


Epoch 10/200: 100%|██████████| 297/297 [00:13<00:00, 21.26batch/s, Loss D=0.0801, Loss G=0.718]


Epoch [10/200] Loss D: 0.0801, Loss G: 0.7177


Epoch 11/200: 100%|██████████| 297/297 [00:13<00:00, 21.32batch/s, Loss D=0.0137, Loss G=0.985]


Epoch [11/200] Loss D: 0.0137, Loss G: 0.9850


Epoch 12/200: 100%|██████████| 297/297 [00:13<00:00, 21.24batch/s, Loss D=0.0216, Loss G=0.845]


Epoch [12/200] Loss D: 0.0216, Loss G: 0.8447


Epoch 13/200: 100%|██████████| 297/297 [00:13<00:00, 21.32batch/s, Loss D=0.00718, Loss G=0.894]


Epoch [13/200] Loss D: 0.0072, Loss G: 0.8941


Epoch 14/200: 100%|██████████| 297/297 [00:13<00:00, 21.54batch/s, Loss D=1.14, Loss G=0.0737]


Epoch [14/200] Loss D: 1.1403, Loss G: 0.0737


Epoch 15/200: 100%|██████████| 297/297 [00:13<00:00, 21.42batch/s, Loss D=0.0748, Loss G=0.487]


Epoch [15/200] Loss D: 0.0748, Loss G: 0.4869


Epoch 16/200: 100%|██████████| 297/297 [00:14<00:00, 21.20batch/s, Loss D=0.0132, Loss G=0.74]


Epoch [16/200] Loss D: 0.0132, Loss G: 0.7404


Epoch 17/200: 100%|██████████| 297/297 [00:13<00:00, 21.35batch/s, Loss D=0.00997, Loss G=0.971]


Epoch [17/200] Loss D: 0.0100, Loss G: 0.9708


Epoch 18/200: 100%|██████████| 297/297 [00:13<00:00, 21.26batch/s, Loss D=0.00657, Loss G=1.02]


Epoch [18/200] Loss D: 0.0066, Loss G: 1.0183


Epoch 19/200: 100%|██████████| 297/297 [00:13<00:00, 21.24batch/s, Loss D=0.0553, Loss G=0.647]


Epoch [19/200] Loss D: 0.0553, Loss G: 0.6471


Epoch 20/200: 100%|██████████| 297/297 [00:13<00:00, 21.26batch/s, Loss D=0.0212, Loss G=0.72]


Epoch [20/200] Loss D: 0.0212, Loss G: 0.7197


Epoch 21/200: 100%|██████████| 297/297 [00:13<00:00, 21.39batch/s, Loss D=0.00849, Loss G=0.917]


Epoch [21/200] Loss D: 0.0085, Loss G: 0.9174


Epoch 22/200: 100%|██████████| 297/297 [00:13<00:00, 21.34batch/s, Loss D=0.0233, Loss G=1.36]


Epoch [22/200] Loss D: 0.0233, Loss G: 1.3552


Epoch 23/200: 100%|██████████| 297/297 [00:13<00:00, 21.34batch/s, Loss D=0.0341, Loss G=0.635]


Epoch [23/200] Loss D: 0.0341, Loss G: 0.6350


Epoch 24/200: 100%|██████████| 297/297 [00:13<00:00, 21.48batch/s, Loss D=0.00515, Loss G=0.893]


Epoch [24/200] Loss D: 0.0051, Loss G: 0.8931


Epoch 25/200: 100%|██████████| 297/297 [00:13<00:00, 21.31batch/s, Loss D=0.00279, Loss G=0.98]


Epoch [25/200] Loss D: 0.0028, Loss G: 0.9798


Epoch 26/200:  99%|█████████▉| 294/297 [00:13<00:00, 21.66batch/s, Loss D=0.026, Loss G=1.36]

In [None]:
import matplotlib.pyplot as plt

def plot_loss_curves(losses_gen, losses_disc, save_path=None):
    """
    Plot the generator and discriminator loss curves.

    Args:
        losses_gen (list): Generator losses per epoch.
        losses_disc (list): Discriminator losses per epoch.
        save_path (str, optional): Path to save the plot. Defaults to None.
    """
    plt.figure(figsize=(10, 5))
    plt.plot(losses_gen, label="Generator Loss", linewidth=2)
    plt.plot(losses_disc, label="Discriminator Loss", linewidth=2)
    plt.title("Loss Curves", fontsize=16)
    plt.xlabel("Epoch", fontsize=14)
    plt.ylabel("Loss", fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(True)

    # Save or show the plot
    if save_path:
        plt.savefig(save_path)
        print(f"Loss curves saved to {save_path}")
    else:
        plt.show()

# Call the function after training
plot_loss_curves(
    losses_gen=losses_gen,
    losses_disc=losses_disc,
    save_path="/content/drive/MyDrive/Artificial_Intelligence_Course_NTUT/hw3a/output/loss_curves.png"
)
