## 4th Version of Metamaterials GAN
Starting from the MNIST 'template', then adding complexity as the problem demands.

In [81]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from torch import autograd
from torch.autograd import Variable
from torchvision.utils import make_grid
from tensorboardX import SummaryWriter
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import time
import os
# useful v1 functions
import import_ipynb 
import importlib
import metamaterials_GAN_v1
importlib.reload(metamaterials_GAN_v1)

from metamaterials_GAN_v1 import plot_shape, load_item, quarter, dataset, dataloader

if __name__ == "__main__":
    # Report the PyTorch build and the CUDA devices visible to this process.
    print("Torch version:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
    print("CUDA version:", torch.version.cuda)
    n_gpus = torch.cuda.device_count()
    print("Number of GPUs:", n_gpus)
    print("GPU name:", torch.cuda.get_device_name(0) if n_gpus else "No GPU detected")

Torch version: 2.6.0
CUDA available: False
CUDA version: None
Number of GPUs: 0
GPU name: No GPU detected


In [107]:
class Discriminator(nn.Module):
    """
    Discriminator for the conditional GAN.

    Inputs:
        - img: waveguide image, size (batchsize, 1, 32, 32)
        - params: waveguide parameters, size (batchsize, 4)
        - cond: condition vector (eigenmodes concatenated with weights),
          size (batchsize, 8)
    Outputs:
        - probability in [0, 1] that each sample is real, size (batchsize,)
          (always 1-D, including when batchsize == 1)
    Questions:
        - Should I be using dropout in image_fc, or at all in my Discriminator??
        - Am I correct in using conv2d and splitting the problem into
          image convolution and parameter process and then combining?
    """

    def __init__(self):
        super().__init__()

        # Waveguide branch.
        # Conv output size: 1 + (input_size + 2*padding - kernel_size) // stride
        self.image_conv = nn.Sequential(
            # Input (batchsize, 1, 32, 32): single-channel 32x32 image
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),  # (batchsize, 64, 16, 16)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # (batchsize, 128, 8, 8)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # (batchsize, 256, 4, 4)
            # Like the two convs above, follow this one with a non-linearity;
            # otherwise it is only a linear map straight into self.model's first Linear.
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),  # (batchsize, 256 * 4 * 4 = 4096) for the linear head
        )

        # Parameter branch: (batchsize, 4) -> (batchsize, 256), with a hidden
        # layer so features of the parameters can be learned before fusion.
        self.param_fc = nn.Sequential(
            nn.Linear(4, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
        )

        # Condition branch: (batchsize, 8) -> (batchsize, 256), mapped into the
        # same width as the parameter features before concatenation.
        self.cond_fc = nn.Sequential(
            nn.Linear(8, 256),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Head over the concatenated [image | params | cond] features.
        self.model = nn.Sequential(
            nn.Linear(4096 + 256 + 256, 512),  # image + params + cond
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 1),
            nn.Sigmoid(),  # probability output, consumed by nn.BCELoss
        )

    def forward(self, img, params, cond):
        img_feat = self.image_conv(img)      # (batchsize, 4096)
        param_feat = self.param_fc(params)   # (batchsize, 256)
        cond_feat = self.cond_fc(cond)       # (batchsize, 256)

        x = torch.cat([img_feat, param_feat, cond_feat], dim=1)
        # squeeze(-1) drops only the trailing size-1 dim: (batchsize, 1) -> (batchsize,).
        # A bare squeeze() would turn a batch of one into a 0-d tensor.
        return self.model(x).squeeze(-1)

In [108]:
class Generator(nn.Module):
    """
    Generator for the conditional GAN.

    Maps a condition vector to a waveguide image plus its parameters. There is
    no latent noise input, so in eval mode a given condition always produces
    the same output for fixed weights.

    Inputs:
        - cond: condition vector (eigenmodes concatenated with weights),
          size (batchsize, 8)
    Outputs:
        - waveguide image with values in [-1, 1], size (batchsize, 1, 32, 32)
        - parameters, size (batchsize, 4)
    Questions:
        - Should we still be using latent vector like in MNIST, as we want 
          consistent results i.e. for a set of modes, we want as close 
          to the same waveguide as possible each time? 
        - Should I be feeding my generated waveguide shape into my params
          process as well (and maybe in discrim too)? Also, does my params
          process need more layers?
    """

    def __init__(self):
        super().__init__()

        # Lift the 8-dim condition to 4096 = 256 * 4 * 4 features, so it can be
        # reshaped into 256 feature maps of size 4x4 for the transposed convs.
        self.fc = nn.Sequential(
            nn.Linear(8, 256 * 4 * 4),
            nn.BatchNorm1d(256 * 4 * 4),
            nn.ReLU(inplace=True),
        )

        # Transposed-conv output size: (input_size - 1) * stride - 2 * padding + kernel_size.
        # With kernel 4, stride 2, padding 1 every stage doubles the spatial size.
        upsample = []
        for in_ch, out_ch in ((256, 128), (128, 64)):  # 4x4 -> 8x8 -> 16x16
            upsample.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # 16x16 -> 32x32 with a single (greyscale) output channel; Tanh bounds pixels to [-1, 1].
        upsample.extend([
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ])
        self.deconv = nn.Sequential(*upsample)

        # Parameter head: reads the shared 4096-dim features produced by self.fc
        # (not the raw condition) and regresses the 4 waveguide parameters.
        self.param_proc = nn.Sequential(
            nn.Linear(256 * 4 * 4, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 4),
        )

    def forward(self, cond):
        features = self.fc(cond)  # (batchsize, 8) -> (batchsize, 4096)
        feature_maps = features.view(features.size(0), 256, 4, 4)
        return self.deconv(feature_maps), self.param_proc(features)

In [109]:
def generator_train_step(batch, discriminator, generator, g_optimizer, g_criterion, d_criterion, device, adv_w=1, re_w=1):
    """
    Run one optimisation step for the generator and return its total loss.

    Args:
        batch: sequence of four tensors
            ``(eigenmodes, weights, real_params, real_waveguides)``; eigenmodes
            and weights are concatenated along the last dim to form the condition.
        discriminator: model called as ``discriminator(waveguides, params, cond)``.
        generator: model called as ``generator(cond)``, returning
            ``(waveguides, params)``.
        g_optimizer: optimiser stepped after the backward pass.
        g_criterion: reconstruction loss between generated and real tensors.
        d_criterion: adversarial loss applied to the discriminator's output.
        device: device the batch tensors are moved to.
        adv_w: weight of the adversarial term.
        re_w: weight of the reconstruction (image + params) terms; set to 0
            for a purely adversarial objective.

    Returns:
        float: ``adv_w * adv_loss + re_w * (image_loss + params_loss)``.
    """
    generator.train()
    g_optimizer.zero_grad()

    # TODO: batch layout will likely change with the dataset, but the flow stays the same.
    eigenmodes, weights, real_params, real_waveguides = [b.to(device) for b in batch]
    cond = torch.cat([eigenmodes, weights], dim=-1)

    fake_waveguides, fake_params = generator(cond)

    # Adversarial term: the generator wants its samples labelled as real (1).
    validity = discriminator(fake_waveguides, fake_params, cond)
    adv_loss = d_criterion(validity, torch.ones_like(validity))

    # Reconstruction terms against the paired real sample (there is a single
    # real target per condition).
    image_loss = g_criterion(fake_waveguides, real_waveguides)
    params_loss = g_criterion(fake_params, real_params)

    g_loss = adv_loss * adv_w + (image_loss + params_loss) * re_w

    # backward() also accumulates .grad on the discriminator's parameters; only
    # g_optimizer is stepped here, so they are not applied, but the discriminator's
    # optimiser must zero them before its own update.
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()

In [110]:
def discriminator_train_step(batch, discriminator, generator, d_optimizer, d_criterion, device):
    """
    Run one optimisation step for the discriminator and return its loss.

    Real samples are labelled 1 and generator samples 0. Only the discriminator
    is trained here: fake samples are produced without building an autograd
    graph, so no gradients flow into (or accumulate on) the generator.

    Args:
        batch: sequence of four tensors
            ``(eigenmodes, weights, real_params, real_waveguides)``; eigenmodes
            and weights are concatenated along the last dim to form the condition.
        discriminator: model called as ``discriminator(waveguides, params, cond)``.
        generator: model called as ``generator(cond)``, returning
            ``(waveguides, params)``. Its train/eval mode is not changed here.
        d_optimizer: optimiser stepped after the backward pass.
        d_criterion: loss applied to the discriminator's outputs.
        device: device the batch tensors are moved to.

    Returns:
        float: ``real_loss + fake_loss``.
    """
    discriminator.train()
    d_optimizer.zero_grad()

    eigenmodes, weights, real_params, real_waveguides = [b.to(device) for b in batch]
    cond = torch.cat([eigenmodes, weights], dim=-1)

    real_validity = discriminator(real_waveguides, real_params, cond)
    real_loss = d_criterion(real_validity, torch.ones_like(real_validity))

    # The generator is not updated by this step, so skip its graph entirely.
    # This avoids a backward pass through it on every critic iteration.
    with torch.no_grad():
        fake_waveguides, fake_params = generator(cond)
    fake_validity = discriminator(fake_waveguides, fake_params, cond)
    fake_loss = d_criterion(fake_validity, torch.zeros_like(fake_validity))

    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()

    return d_loss.item()

TODO: work out the differences between my dataset structure and the MNIST dataset structure.

TODO: implement the training loop; remember that the generated output must be binarized before being fed to the discriminator!

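Question: won't binarization significantly increase my model's loss? (Note also that a hard threshold is non-differentiable, so gradients cannot flow through it back to the generator.)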

In [119]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # set up device

d = Discriminator().to(device)
g = Generator().to(device)
d_optimizer = torch.optim.Adam(d.parameters(), lr=1e-4)
g_optimizer = torch.optim.Adam(g.parameters(), lr=1e-4)
d_criterion = nn.BCELoss()  # expects discriminator outputs already in [0, 1] (Sigmoid head)
g_criterion = nn.MSELoss()  # reconstruction loss between generated and real tensors

writer = SummaryWriter()
num_epochs = 50
n_critic = 5        # discriminator updates per generator update
display_step = 50   # log image grids every `display_step` global steps
n_preview = 10      # dataset samples shown in each logged grid

# NOTE(review): Generator uses BatchNorm1d/BatchNorm2d, which raise an error in
# training mode on a batch of size 1. If the dataloader can emit a final batch
# of one sample, build it with drop_last=True.

try:
    for epoch in range(num_epochs):
        start = time.perf_counter()
        print(f'Starting epoch {epoch}...', end=' ')

        for batch_idx, batch in enumerate(dataloader):
            step = epoch * len(dataloader) + batch_idx + 1

            # Log the mean discriminator loss over all n_critic updates for this batch.
            d_loss = sum(
                discriminator_train_step(batch, d, g, d_optimizer, d_criterion, device)
                for _ in range(n_critic)
            ) / n_critic
            g_loss = generator_train_step(batch, d, g, g_optimizer, g_criterion, d_criterion, device)
            writer.add_scalars('scalars', {'g_loss': g_loss, 'd_loss': d_loss}, step)

            if step % display_step == 0:
                g.eval()
                preview = [dataset[k] for k in range(min(n_preview, len(dataset)))]
                e_modes, weights, _, real_wguides = zip(*preview)

                cond = torch.cat([torch.stack(e_modes), torch.stack(weights)], dim=1).to(device)  # (n, 8)
                real_wguides = torch.stack(real_wguides).to(device)  # (n, 1, 32, 32)

                with torch.no_grad():
                    fake_wguides, _ = g(cond)  # (n, 1, 32, 32)
                # Switch back to training mode. discriminator_train_step calls g(cond)
                # without setting the mode, so leaving g in eval mode would make the
                # next critic updates use BatchNorm running stats instead of batch stats.
                g.train()

                writer.add_image('Generated_Waveguides', make_grid(fake_wguides, nrow=5, normalize=True), step)
                writer.add_image('Real_Waveguides', make_grid(real_wguides, nrow=5, normalize=True), step)

        elapsed = time.perf_counter() - start
        print(f'Done! - {elapsed} s')
finally:
    # Flush and close the event file even if training is interrupted (e.g. KeyboardInterrupt).
    writer.close()


Starting epoch 0... 

KeyboardInterrupt: 