In [1]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

from gangen.cdcgan import Generator, Discriminator



In [2]:
# Reproducibility: seed the torch and numpy global RNGs with the same value.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Device configuration: use CUDA when torch reports it as available, else the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Dataset root. Set the TRAFFIC_DATA_DIR environment variable to run the notebook
# on another machine; the original absolute path is kept as the default so the
# existing setup keeps working unchanged.
data_dir = os.environ.get(
    'TRAFFIC_DATA_DIR',
    '/home/takayuki/Desktop/sem6/DL/mini_proj_TRAFFIC/data/',
)
train_dir = os.path.join(data_dir, 'preprocessed', 'Train')
if not os.path.isdir(train_dir):
    # Fail early with an actionable message instead of deep inside ImageFolder.
    raise FileNotFoundError(
        f"Training directory not found: {train_dir!r}. "
        "Set TRAFFIC_DATA_DIR to the folder that contains 'preprocessed/Train'."
    )

# Augmentation + preprocessing applied to every training image, in this order.
data_tf = transforms.Compose([
    # Small random rotation / shift / zoom / shear.
    transforms.RandomAffine(degrees=10, translate=(0.08, 0.08), scale=(0.9, 1.1), shear=5),
    # Every image is brought to a fixed 32x32 resolution.
    transforms.Resize((32, 32)),

    # Mild photometric jitter.
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),

    # Blur is only applied to ~30% of the samples.
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))], p=0.3),

    transforms.ToTensor(),

    # Erasing runs after ToTensor and before Normalize, so value=0 is a black box.
    transforms.RandomErasing(p=0.2, scale=(0.01, 0.04), ratio=(0.5, 2.0), value=0),

    # Per-channel standardisation; generate_samples() inverts it with the same constants.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# One sub-folder per class under train_dir; the transform above is applied on access.
full_dataset = datasets.ImageFolder(root=train_dir, transform=data_tf)

In [4]:
# Hyperparameters
# --- optimisation ---
batch_size = 64             # samples per DataLoader batch
num_epochs = 200            # number of passes of the training loop
learning_rate = 2e-4        # Adam step size, shared by both optimizers
beta1, beta2 = 0.5, 0.999   # Adam beta coefficients

# --- model shape ---
noise_dim = 100             # length of the latent vector fed to the generator
n_classes = 43              # number of label classes passed to G and D
embedding_dim = 64          # `embedding_dim` argument passed to G and D

# --- regularisation ---
lambda_gp = 10              # NOTE(review): not referenced by any cell visible in this notebook

In [5]:
# Shuffled loader; drop_last=True keeps every batch at exactly `batch_size` samples.
data_loader = DataLoader(
    full_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
    pin_memory=True
)

print("Counting class distribution...")
# Read the labels from ImageFolder's `targets` list instead of iterating the
# dataset: iteration would decode and augment every image just to get its label.
targets = np.asarray(full_dataset.targets, dtype=np.int64)
if targets.size and targets.max() >= n_classes:
    raise ValueError(
        f"Dataset contains label {int(targets.max())} but n_classes is {n_classes}"
    )
# Kept as float so the printed counts and the division below behave as before.
class_counts = np.bincount(targets, minlength=n_classes).astype(np.float64)

print("Class counts:", class_counts)

# Inverse-frequency weights. A class with no samples gets weight 0 (it is never
# drawn) instead of inf, which would turn every normalised weight into NaN.
class_weights = np.zeros_like(class_counts)
np.divide(1.0, class_counts, out=class_weights, where=class_counts > 0)
class_weights = class_weights / np.sum(class_weights) * n_classes  # Normalize so the weights sum to n_classes
class_weights = torch.tensor(class_weights, dtype=torch.float32, device=device)

Counting class distribution...
Class counts: [ 210. 2220. 2010. 1320. 2100. 2160.  780.  630.  420. 1110. 1200.  210.
 2250.  360.  330.  390.  510.  270. 1500.  600.  240.  540.  270. 1410.
  450.  780.  240.  689.  420. 1200.  390.  210. 2070.  300. 1980.  360.
  240.  240. 1860.  420. 1440. 1410. 1470.]


In [6]:
# Build both networks and move them to `device`. G is created before D, exactly as
# before, so the order of parameter-initialisation RNG draws is unchanged.
G = Generator(noise_dim=noise_dim, n_classes=n_classes, embedding_dim=embedding_dim)
G = G.to(device)
D = Discriminator(n_classes=n_classes, embedding_dim=embedding_dim)
D = D.to(device)

# Each network has its own Adam optimizer; both use the same lr and betas.
adam_settings = {'lr': learning_rate, 'betas': (beta1, beta2)}
optimizer_G = optim.Adam(G.parameters(), **adam_settings)
optimizer_D = optim.Adam(D.parameters(), **adam_settings)

# Real/fake loss computed on raw discriminator logits (the training loop
# thresholds D's output at 0 accordingly).
criterion = nn.BCEWithLogitsLoss()

In [7]:
# Setup TensorBoard
# Each run logs into its own directory, named after the wall-clock start time.
run_stamp = time.strftime("%Y%m%d-%H%M%S")
log_dir = f'logs/gan_training_{run_stamp}'
writer = SummaryWriter(log_dir)

In [8]:
def generate_samples(num_samples=16, fixed_noise=None, fixed_labels=None):
    """Run the global generator ``G`` and return displayable images with their labels.

    Args:
        num_samples: Number of samples to draw when ``fixed_noise`` is not given.
        fixed_noise: Optional latent batch of shape ``(N, noise_dim)``. When given,
            its first dimension decides how many samples (and default labels) are
            produced, so noise and labels always have matching batch sizes.
        fixed_labels: Optional tensor of class indices, one per sample. When
            omitted, labels cycle through ``0 .. n_classes - 1``.

    Returns:
        A tuple ``(images, labels)``: ``images`` is a numpy array in channel-last
        layout with values clipped to ``[0, 1]``; ``labels`` is a numpy array of
        the class indices that were used.

    Notes:
        ``G`` is switched to eval mode for the forward pass and put back into
        whatever mode it was in before the call, so sampling in the middle of
        training does not silently leave the generator in eval mode.
    """
    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            if fixed_noise is None:
                fixed_noise = torch.randn(num_samples, noise_dim, device=device)
            if fixed_labels is None:
                # One label per noise row, cycling over the classes. This gives the
                # same labels as slicing a repeated arange and also works for 0 samples.
                sample_count = fixed_noise.size(0)
                fixed_labels = torch.arange(sample_count, device=device) % n_classes

            fake_samples = G(fixed_noise, fixed_labels)
    finally:
        # Restore the caller's train/eval mode even if the forward pass raised.
        G.train(was_training)

    # Convert to displayable format
    fake_samples = fake_samples.detach().cpu().numpy()
    # Move channel dimension to the end for plotting: (N, C, H, W) -> (N, H, W, C)
    fake_samples = np.transpose(fake_samples, (0, 2, 3, 1))
    # Denormalize with the mean/std used by the training transform.
    # NOTE(review): assumes G's output lives in that normalised space; confirm
    # against the generator's final activation.
    fake_samples = fake_samples * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    fake_samples = np.clip(fake_samples, 0, 1)

    return fake_samples, fixed_labels.cpu().numpy()

# Create a grid of images
def create_image_grid(images, labels, nrow=4):
    """Plot ``images`` on an ``nrow``-row grid, titling each cell with its class label.

    Args:
        images: Array indexable as ``images[i]`` with the image count in
            ``images.shape[0]``; each entry is passed to ``imshow``.
        labels: Sequence with one label per image.
        nrow: Number of grid rows (must be >= 1).

    Returns:
        The matplotlib figure holding the grid.

    Raises:
        ValueError: If there are no images or ``nrow`` is smaller than 1.
    """
    num_images = images.shape[0]
    if num_images == 0:
        raise ValueError("create_image_grid needs at least one image")
    if nrow < 1:
        raise ValueError(f"nrow must be >= 1, got {nrow}")

    # Ceiling division: floor division silently dropped the trailing images when the
    # count was not a multiple of nrow, and gave 0 columns for fewer than nrow images.
    ncol = -(-num_images // nrow)
    # squeeze=False keeps `axes` a 2-D array even for a single row or column.
    fig, axes = plt.subplots(nrow, ncol, figsize=(12, 12), squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < num_images:
            ax.imshow(images[i])
            ax.set_title(f"Class: {labels[i]}")
        # Unused cells are blanked as well, so no empty axes frames are drawn.
        ax.axis('off')

    plt.tight_layout()
    return fig

# Function to save checkpoint
def save_checkpoint(G, D, optimizer_G, optimizer_D, epoch, filename='checkpoint.pth'):
    """Write model and optimizer state to ``filename`` with ``torch.save``.

    Args:
        G: Generator whose ``state_dict()`` is stored.
        D: Discriminator whose ``state_dict()`` is stored.
        optimizer_G: Optimizer of ``G``; its ``state_dict()`` is stored.
        optimizer_D: Optimizer of ``D``; its ``state_dict()`` is stored.
        epoch: Value stored under the ``'epoch'`` key, exactly as given.
        filename: Destination path (or any target ``torch.save`` accepts).

    The parent directory of a path-like ``filename`` is created when missing, so
    saving to e.g. ``checkpoints/model.pth`` works on a fresh working directory.
    """
    # Only path-like targets have a directory to create; file objects are passed through.
    if isinstance(filename, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(filename))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.save({
        'epoch': epoch,
        'generator_state_dict': G.state_dict(),
        'discriminator_state_dict': D.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
    }, filename)

# Create fixed noise for visualization
# The same latent vectors and labels are rendered at every logging step so the
# sample grids are comparable across epochs.
num_vis_samples = 16
fixed_noise = torch.randn(num_vis_samples, noise_dim, device=device)
fixed_labels = torch.arange(num_vis_samples, device=device) % n_classes

# Output directories are created once, up front, so neither the periodic nor the
# final save can fail on a missing folder (e.g. when num_epochs < 10).
os.makedirs('generated_samples', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)

num_batches = len(data_loader)
if num_batches == 0:
    raise ValueError("data_loader yields no batches; the per-epoch averages would divide by zero")

# Training loop
# NOTE(review): the recorded run of this cell stopped on the first batch with
# "TypeError: Sequential.forward() takes 2 positional arguments but 3 were given".
# The (input, labels) calls below must match Generator/Discriminator.forward in
# gangen.cdcgan, which is not visible here - confirm there.
print("Starting training...")

try:
    for epoch in range(num_epochs):
        G.train()
        D.train()

        # Metrics for this epoch
        epoch_d_loss = 0.0
        epoch_g_loss = 0.0
        epoch_d_real_accuracy = 0.0
        epoch_d_fake_accuracy = 0.0

        progress_bar = tqdm(data_loader, total=num_batches, desc=f"Epoch {epoch+1}/{num_epochs}")

        for real_images, real_labels in progress_bar:
            # Move data to device
            real_images = real_images.to(device)
            real_labels = real_labels.to(device)
            # Deliberately a separate name: the `batch_size` hyperparameter defined
            # in an earlier cell must not be overwritten from inside the loop.
            cur_batch_size = real_images.size(0)

            # Create targets for real and fake images
            real_target = torch.ones(cur_batch_size, 1, device=device)
            fake_target = torch.zeros(cur_batch_size, 1, device=device)

            # ---------------------
            # Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            # Forward pass real batch through D
            d_real_output = D(real_images, real_labels)
            d_real_loss = criterion(d_real_output, real_target)

            # Generate batch of fake images
            z = torch.randn(cur_batch_size, noise_dim, device=device)

            # Draw fake labels from the inverse-frequency weights to address class imbalance
            fake_labels_dist = torch.multinomial(class_weights, cur_batch_size, replacement=True)

            fake_images = G(z, fake_labels_dist)

            # Classify fake batch with D; detach so this step sends no gradient into G
            d_fake_output = D(fake_images.detach(), fake_labels_dist)
            d_fake_loss = criterion(d_fake_output, fake_target)

            # Add losses
            d_loss = d_real_loss + d_fake_loss

            # D outputs logits, so 0 is the real/fake decision boundary
            d_real_accuracy = ((d_real_output > 0).float().mean()).item()
            d_fake_accuracy = ((d_fake_output < 0).float().mean()).item()

            # Backpropagation
            d_loss.backward()
            optimizer_D.step()

            # ---------------------
            # Train Generator
            # ---------------------
            optimizer_G.zero_grad()

            # Run G again on the same z and fake_labels_dist as above
            fake_images = G(z, fake_labels_dist)

            # Try to fool the (just updated) discriminator
            g_output = D(fake_images, fake_labels_dist)
            g_loss = criterion(g_output, real_target)  # We want generator to produce images discriminator thinks are real

            # Backpropagation
            g_loss.backward()
            optimizer_G.step()

            # Update metrics
            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            epoch_d_real_accuracy += d_real_accuracy
            epoch_d_fake_accuracy += d_fake_accuracy

            # Update progress bar
            progress_bar.set_postfix({
                'D Loss': f"{d_loss.item():.4f}",
                'G Loss': f"{g_loss.item():.4f}",
                'D Real Acc': f"{d_real_accuracy:.3f}",
                'D Fake Acc': f"{d_fake_accuracy:.3f}"
            })

        # End of epoch: average the per-batch metrics
        avg_d_loss = epoch_d_loss / num_batches
        avg_g_loss = epoch_g_loss / num_batches
        avg_d_real_acc = epoch_d_real_accuracy / num_batches
        avg_d_fake_acc = epoch_d_fake_accuracy / num_batches

        # Log to TensorBoard
        writer.add_scalar('Loss/Discriminator', avg_d_loss, epoch)
        writer.add_scalar('Loss/Generator', avg_g_loss, epoch)
        writer.add_scalar('Accuracy/Discriminator_Real', avg_d_real_acc, epoch)
        writer.add_scalar('Accuracy/Discriminator_Fake', avg_d_fake_acc, epoch)

        print(f"\nEpoch [{epoch+1}/{num_epochs}] - "
              f"D Loss: {avg_d_loss:.4f}, G Loss: {avg_g_loss:.4f}, "
              f"D Real Acc: {avg_d_real_acc:.3f}, D Fake Acc: {avg_d_fake_acc:.3f}")

        # Generate and log sample images on the first epoch and every 5th one
        if (epoch + 1) % 5 == 0 or epoch == 0:
            fake_samples, sample_labels = generate_samples(
                num_samples=num_vis_samples, fixed_noise=fixed_noise, fixed_labels=fixed_labels
            )
            fig = create_image_grid(fake_samples, sample_labels)

            # Save figure to TensorBoard
            writer.add_figure(f'Generated Traffic Signs/Epoch {epoch+1}', fig, epoch)

            # Save figure to disk, then release it so figures do not pile up in memory
            fig.savefig(f'generated_samples/epoch_{epoch+1}.png')
            plt.close(fig)

        # Save checkpoint every 10 epochs.
        # NOTE(review): the stored 'epoch' here is the zero-based index, while the
        # final save below stores num_epochs - keep that in mind when resuming.
        if (epoch + 1) % 10 == 0:
            save_checkpoint(G, D, optimizer_G, optimizer_D, epoch, f'checkpoints/gan_checkpoint_epoch_{epoch+1}.pth')

    # Save final model
    save_checkpoint(G, D, optimizer_G, optimizer_D, num_epochs, 'checkpoints/gan_final_model.pth')

    print("Training complete!")
finally:
    # Close the event file even when training is interrupted or raises.
    writer.close()

Starting training...


Epoch 1/200:   0%|          | 0/612 [00:00<?, ?it/s]


TypeError: Sequential.forward() takes 2 positional arguments but 3 were given

In [None]:

# Render one last batch with fresh random noise as a 6-row grid and write it to
# disk. The figure is closed afterwards because it is only needed as a file.
final_count, final_rows = 36, 6
fake_samples, sample_labels = generate_samples(num_samples=final_count)
fig = create_image_grid(fake_samples, sample_labels, nrow=final_rows)
fig.savefig('final_generated_samples.png')
plt.close(fig)