# DCGAN for Image Generation

**Author:** RSK World  
**Website:** https://rskworld.in  
**Email:** help@rskworld.in  
**Phone:** +91 93305 39277

This notebook implements a DCGAN (Deep Convolutional Generative Adversarial Network): a convolutional generator and a convolutional discriminator are trained adversarially so that the generator learns to produce realistic images.

## Features
- DCGAN architecture with convolutional generator and discriminator
- Adversarial training using common DCGAN stabilization settings (e.g. Adam with `beta1 = 0.5`)
- Batch normalization and LeakyReLU activations
- Custom weight initialization (`weights_init`) applied to both networks
- Realistic image generation


In [None]:
# Author: RSK World
# Website: https://rskworld.in
# Email: help@rskworld.in
# Phone: +91 93305 39277

# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
import os

# Import our custom modules
from dcgan_model import Generator, Discriminator, weights_init
from data_loader import get_dataloader, denormalize
from trainer import DCGANTrainer
from utils import save_image_grid, plot_training_losses, generate_and_save_images
import config

print("Libraries imported successfully!")
# Report the runtime environment so GPU availability is visible up front.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    # Name of the first visible CUDA device.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


## Configuration

Set up hyperparameters and training configuration.


In [None]:
# Author: RSK World
# Website: https://rskworld.in
# Email: help@rskworld.in
# Phone: +91 93305 39277

# Configuration: hyperparameters and paths used by the cells below.

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# --- Model architecture ---
nz = 100          # length of the latent noise vector fed to the generator
ngf = 64          # base number of generator feature maps
ndf = 64          # base number of discriminator feature maps
nc = 3            # image channels (3 for RGB, 1 for grayscale)
image_size = 64   # spatial size of generated/real images (64x64 or 128x128)

# --- Optimisation ---
batch_size = 128
num_epochs = 50
lr = 0.0002
beta1 = 0.5       # Adam beta1

# --- Data ---
dataset_name = 'custom'  # one of: 'custom', 'celeba', 'cifar10', 'mnist'
data_root = './data'
custom_data_dir = './data/custom'  # only consulted when dataset_name == 'custom'

# --- Outputs ---
output_dir = './outputs'
checkpoint_dir = './checkpoints'
for _directory in (output_dir, checkpoint_dir):
    os.makedirs(_directory, exist_ok=True)

print("Configuration set!")


## Load Dataset

Load and visualize the training dataset.


In [None]:
# Author: RSK World
# Website: https://rskworld.in
# Email: help@rskworld.in
# Phone: +91 93305 39277

# Load the training dataset and show a grid of real images from one batch.
print(f"Loading dataset: {dataset_name}")
dataloader = get_dataloader(
    dataset_name=dataset_name,
    root=data_root,
    image_size=image_size,
    batch_size=batch_size,
    num_workers=2,
    # Only the custom dataset reads from a user-supplied directory.
    custom_dir=custom_data_dir if dataset_name == 'custom' else None
)

# len(dataloader) * batch_size overcounts whenever the final batch is partial,
# so prefer the size of the underlying dataset when the loader exposes one.
_dataset = getattr(dataloader, 'dataset', None)
total_images = len(_dataset) if _dataset is not None else len(dataloader) * batch_size

print("Dataset loaded successfully!")
print(f"Number of batches: {len(dataloader)}")
print(f"Batch size: {batch_size}")
print(f"Total images: {total_images}")

# Visualize up to 64 real images from the first batch.
real_batch = next(iter(dataloader))
# Labelled datasets yield (images, labels); also accept a bare image tensor.
real_images = real_batch[0] if isinstance(real_batch, (list, tuple)) else real_batch
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(make_grid(
    real_images[:64],
    padding=2,
    normalize=True,
    # Images are assumed to be normalized to [-1, 1] by the data loader.
    value_range=(-1, 1)
).cpu(), (1, 2, 0)))
plt.show()


## Initialize Models

Create and initialize the Generator and Discriminator networks.


In [None]:
# Author: RSK World
# Website: https://rskworld.in
# Email: help@rskworld.in
# Phone: +91 93305 39277

# Build each network on the target device and run weights_init over every
# submodule (nn.Module.apply returns the module itself). The generator is
# built and initialised before the discriminator, so random initialisation
# draws from the RNG in the same order as before.
netG = Generator(nz=nz, ngf=ngf, nc=nc, image_size=image_size).to(device).apply(weights_init)
print("Generator created:")
print(netG)

netD = Discriminator(nc=nc, ndf=ndf, image_size=image_size).to(device).apply(weights_init)
print("\nDiscriminator created:")
print(netD)

def count_parameters(model):
    """Return the total number of trainable scalar parameters in *model*.

    Only parameters with ``requires_grad=True`` are counted; each contributes
    its element count (``numel()``).
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

g_param_count = count_parameters(netG)
d_param_count = count_parameters(netD)
print(f"\nGenerator parameters: {g_param_count:,}")
print(f"Discriminator parameters: {d_param_count:,}")

# One fixed batch of 64 latent vectors, reused during training so samples
# from different epochs are directly comparable.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)


## Setup Training

Create the `DCGANTrainer`, which wraps both networks and holds their optimizers (`optimizerG`, `optimizerD`) used during training.


In [None]:
# Author: RSK World
# Website: https://rskworld.in
# Email: help@rskworld.in
# Phone: +91 93305 39277

# Bundle both networks, the device and the Adam hyperparameters into one trainer.
trainer = DCGANTrainer(
    generator=netG, discriminator=netD,
    nz=nz, lr=lr, beta1=beta1,
    device=device,
)

print("Trainer initialized successfully!")


## Training Loop

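Train the DCGAN model. After each epoch, images generated from the fixed noise batch are shown and saved, and a checkpoint is written every 5 epochs. Training may take a while depending on your hardware.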


In [None]:
# Author: RSK World
# Website: https://rskworld.in
# Email: help@rskworld.in
# Phone: +91 93305 39277

# Training loop: one trainer.train_epoch call per epoch, with periodic sample
# visualisation and checkpointing.
sample_interval = 1      # epochs between sample grids (1 = every epoch)
checkpoint_interval = 5  # epochs between trainer checkpoints

print("Starting training...")
print("=" * 50)

for epoch in range(1, num_epochs + 1):
    trainer.train_epoch(dataloader, epoch, num_epochs)

    # Render the fixed noise batch so progress is comparable across epochs.
    if epoch % sample_interval == 0:
        # NOTE(review): netG stays in whatever mode train_epoch leaves it
        # (presumably train mode), so BatchNorm layers use batch statistics
        # here and still update their running stats even under no_grad —
        # confirm this is intended.
        with torch.no_grad():
            fake = trainer.netG(fixed_noise).detach().cpu()

        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title(f"Generated Images - Epoch {epoch}")
        plt.imshow(np.transpose(make_grid(
            fake,
            padding=2,
            normalize=True,
            value_range=(-1, 1)
        ), (1, 2, 0)))
        plt.show()

        # Also persist the grid so progress can be reviewed later.
        sample_path = os.path.join(output_dir, f'epoch_{epoch:03d}_samples.png')
        save_image_grid(fake, sample_path)

    if epoch % checkpoint_interval == 0:
        trainer.save_checkpoint(epoch, checkpoint_dir)
        print(f"Checkpoint saved at epoch {epoch}")

print("\nTraining completed!")


## Visualize Training Progress

Plot the generator and discriminator losses recorded during training, and save the plot to the output directory.


In [None]:
# Author: RSK World
# Website: https://rskworld.in
# Email: help@rskworld.in
# Phone: +91 93305 39277

# Plot the loss values the trainer recorded for each network, by iteration.
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
for _series, _name in ((trainer.G_losses, "Generator"), (trainer.D_losses, "Discriminator")):
    plt.plot(_series, label=_name)
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.show()

# Write the same curves to disk via the project helper.
loss_plot_path = os.path.join(output_dir, 'training_losses.png')
plot_training_losses(trainer.G_losses, trainer.D_losses, loss_plot_path)
print(f"Loss plot saved to {loss_plot_path}")


## Generate New Images

Generate new images using the trained generator.


In [None]:
# Author: RSK World
# Website: https://rskworld.in
# Email: help@rskworld.in
# Phone: +91 93305 39277

# Generate a fresh batch of images from new random latent vectors.
num_samples = 64
noise = torch.randn(num_samples, nz, 1, 1, device=device)

# Sample in eval mode, then restore the generator's previous mode. Without the
# restore, netG would stay in eval mode if the training cell is re-run and the
# trainer does not switch it back to train mode itself.
_was_training = trainer.netG.training
trainer.netG.eval()
try:
    with torch.no_grad():
        fake_images = trainer.netG(noise).detach().cpu()
finally:
    trainer.netG.train(_was_training)

# Visualize generated images as an 8-column grid.
plt.figure(figsize=(12, 12))
plt.axis("off")
plt.title("Generated Images")
plt.imshow(np.transpose(make_grid(
    fake_images,
    padding=2,
    normalize=True,
    value_range=(-1, 1),
    nrow=8
), (1, 2, 0)))
plt.show()

# Save generated images
output_path = os.path.join(output_dir, 'final_generated_samples.png')
save_image_grid(fake_images, output_path)
print(f"Generated images saved to {output_path}")


## Save Final Model

Save the trained models for future use.


In [None]:
# Author: RSK World
# Website: https://rskworld.in
# Email: help@rskworld.in
# Phone: +91 93305 39277

# Persist the final networks. Each file stores the weights, the optimizer
# state, the recorded loss history, the configured epoch count, and the
# architecture hyperparameters needed to rebuild the model.
final_generator_path = os.path.join(checkpoint_dir, 'final_generator.pth')
final_discriminator_path = os.path.join(checkpoint_dir, 'final_discriminator.pth')

generator_checkpoint = {
    'model_state_dict': trainer.netG.state_dict(),
    'optimizer_state_dict': trainer.optimizerG.state_dict(),
    'epoch': num_epochs,
    'losses': trainer.G_losses,
    'nz': nz,
    'ngf': ngf,
    'nc': nc,
    'image_size': image_size,
}
discriminator_checkpoint = {
    'model_state_dict': trainer.netD.state_dict(),
    'optimizer_state_dict': trainer.optimizerD.state_dict(),
    'epoch': num_epochs,
    'losses': trainer.D_losses,
    'nc': nc,
    'ndf': ndf,
    'image_size': image_size,
}

for _payload, _path in (
    (generator_checkpoint, final_generator_path),
    (discriminator_checkpoint, final_discriminator_path),
):
    torch.save(_payload, _path)

print("Final models saved:")
print(f"  Generator: {final_generator_path}")
print(f"  Discriminator: {final_discriminator_path}")
