In [None]:
import os

import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, datasets

In [16]:
# Root of the traffic-sign data. Overridable via the TRAFFIC_DATA_DIR environment
# variable so the notebook can run on other machines; defaults to the original path.
data_dir = os.environ.get('TRAFFIC_DATA_DIR', '/home/takayuki/Desktop/sem6/DL/mini_proj_TRAFFIC/data/')
train_dir = os.path.join(data_dir, 'preprocessed', 'Train')

# The random rotation is applied to the PIL image *before* ToTensor/Normalize, so the
# corner pixels it exposes are filled with black (0) in pixel space. When it ran after
# Normalize, fill=0 landed in normalized space, i.e. the corners took the per-channel
# mean colour instead of black.
# NOTE(review): these are ImageNet statistics, so tensors span roughly [-2.1, 2.6].
# If the generator's output layer is a tanh (common for DCGAN; not visible here),
# real and fake images live in different ranges; mean=std=0.5 (-> [-1, 1]) would
# match it. Confirm against gangen.dcgan.Generator.
# NOTE(review): there is no Resize here — this assumes the 'preprocessed' images
# already share the spatial size the networks expect.
data_tf = transforms.Compose([
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

full_dataset = datasets.ImageFolder(root=train_dir, transform=data_tf)

# Let's start by training the GAN on class id 0, which is the minority class here.

In [None]:
# --- Hyperparameters for the single-class GAN run ---
zdims = 100           # length of the latent vector z fed to the generator
num_epochs = 100
batchSize = 8
learning_rate = 1e-3  # shared by the generator and discriminator Adam optimizers
# NOTE(review): the DCGAN paper uses lr=2e-4 with Adam beta1=0.5; worth trying if D dominates.
class_id = 0          # looked up as str(class_id) in ImageFolder's class_to_idx

# Label smoothing: real targets become 1 - label_offset and fake targets
# 0 + label_offset. Keep at 0. for hard 0/1 labels.
label_offset = 0.

In [28]:
# Restrict training to the images of `class_id`.
# `full_dataset.samples` is a list of (file_path, class_index) pairs, so collect the
# positions of the matching samples and wrap them in a Subset. Indexing a Subset calls
# `full_dataset[i]`, which loads the image and applies `data_tf`. Passing the raw
# `samples` list to the DataLoader instead would yield file-path strings rather than
# image tensors.
index = full_dataset.class_to_idx[str(class_id)]
class_indices = [i for i, (_, label) in enumerate(full_dataset.samples) if label == index]
class_dataset = Subset(full_dataset, class_indices)

# drop_last=True makes every batch exactly `batchSize` images long, so the last,
# partial batch of each epoch is discarded. NOTE: a class with fewer than
# `batchSize` images would therefore yield no batches at all.
class_loader = DataLoader(class_dataset, batch_size=batchSize, shuffle=True, drop_last=True)

print("Class Datset size: ", len(class_loader.dataset))

Class Datset size:  210


In [29]:
from gangen.dcgan import Generator, Discriminator

# Use the CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed torch's global RNG before the networks are built, so any random weight
# initialisation (and later z sampling / loader shuffling) is reproducible.
torch.manual_seed(42)

# Construction order (G, then D) is kept as-is: both may draw from the seeded RNG.
G = Generator(z_dims=zdims).to(device)
D = Discriminator().to(device)

# BCELoss expects D to emit probabilities in [0, 1]; one Adam optimizer per network.
criterion = nn.BCELoss()
g_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate)
d_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate)

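### GAN Training
1. Create label vectors of 1s (real) and 0s (fake).
    - These can be softened to 0.9 and 0.1 (label smoothing, controlled by `label_offset`) if D becomes too confident too quickly.
    - Open question: one might gradually move them from 0.1 back to 0 and from 0.9 back to 1 as training progresses and the Generator improves.
2. Train D on a batch of real images (target: the real labels) and take a step with `d_optimizer`.
3. Generate a batch of fake images with G from a batch of latent vectors z sampled from a standard normal distribution.
4. Train D on this batch of generated images with the fake labels from step 1, and take another step with `d_optimizer`.
5. Train G: pass the fake images through D with the real labels as targets, so G is rewarded for fooling D, and take a step with `g_optimizer`.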

In [None]:
# Per-epoch training history (mean D loss, mean G loss, D accuracy).
g_loss_list = []
d_loss_list = []
d_accuracy_list = []

# 1. Label templates of 1s (real) and 0s (fake), optionally smoothed by `label_offset`.
#    They are sized for a full batch and sliced to the actual batch size inside the
#    loop, so a shorter final batch does not cause a BCELoss size mismatch.
#    NOTE(review): the (N, 1) shape assumes D returns one probability per image as (N, 1).
real_labels = torch.ones(batchSize, 1, device=device, dtype=torch.float32) - label_offset
fake_labels = torch.zeros(batchSize, 1, device=device, dtype=torch.float32) + label_offset

for ep in range(num_epochs):
    rdloss = 0.0  # summed D loss over both the real and the fake step of every batch
    rdcorr = 0    # D decisions that were correct: real scored > 0.5, fake scored < 0.5
    rgloss = 0.0  # summed G loss over batches
    n_seen = 0    # real images seen this epoch (the same number of fakes is generated)

    for real_images, _ in class_loader:
        real_images = real_images.to(device)
        b = real_images.size(0)
        real_targets = real_labels[:b]
        fake_targets = fake_labels[:b]
        n_seen += b

        # 2. Train D on this batch of real images and take a step with d_optimizer.
        d_out = D(real_images)
        d_loss = criterion(d_out, real_targets)

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        rdloss += d_loss.item()
        rdcorr += (d_out > 0.5).sum().item()

        # 3. Generate as many fakes as there are real images, from z ~ N(0, I).
        #    G runs once; the D step below sees a detached copy so no gradient reaches G.
        z = torch.randn(b, zdims, 1, 1, device=device, dtype=torch.float32)
        fake_images = G(z)

        # 4. Train D on the generated images with the fake labels.
        d_out_fakes = D(fake_images.detach())
        d_loss = criterion(d_out_fakes, fake_targets)

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        rdloss += d_loss.item()
        # Score the *fake* outputs here (previously this re-scored the real outputs `d_out`).
        rdcorr += (d_out_fakes < 0.5).sum().item()

        # 5. Train G so that the (just-updated) D labels its fakes as real.
        #    This also accumulates grads in D's params; d_optimizer.zero_grad() clears them
        #    before D's next update.
        g_loss = criterion(D(fake_images), real_targets)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        rgloss += g_loss.item()

    # Average over batches (two D steps per batch); guard against an empty loader.
    csize = max(len(class_loader), 1)
    d_acc = rdcorr / max(2 * n_seen, 1)
    rdloss /= 2 * csize
    rgloss /= csize

    d_loss_list.append(rdloss)
    g_loss_list.append(rgloss)
    d_accuracy_list.append(d_acc)

    print(f"Epoch {ep+1}/{num_epochs} | D Loss: {rdloss:.4f}, G Loss: {rgloss:.4f}, D Accuracy: {d_acc:.4f}")