<a href="https://colab.research.google.com/github/tbass134/GAN-pytorch/blob/main/GAN_Pytorch_WGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

In [None]:
!pip install wandb -q

import wandb
wandb.init(project='cifar10-gan', entity='tbass134')
config = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33mtbass134[0m (use `wandb login --relogin` to force relogin)


In [None]:
# Plain stand-in for `wandb.config` when running without wandb (uncomment to use).
# class Config():
#   def __init__(self):
#     pass
# config = Config()

In [None]:
class Critic(nn.Module): #aka discriminator
  """WGAN critic: maps an image batch to one unbounded real-valued score per image.

  Architecture:
    - a stride-2 Conv2d + LeakyReLU;
    - three stride-2 Conv2d + InstanceNorm2d + LeakyReLU blocks, with channels
      doubling each time;
    - a final Conv2d(k=4, s=2, p=0) down to a single channel.

  Each k=4/s=2/p=1 conv halves the spatial size, so a 64x64 input reaches the
  last conv as a 4x4 map, which it reduces to 1x1. The output shape is therefore
  (N, 1, 1, 1).

  Args:
    channels_img: number of channels in the input images.
    features_d: channel width of the first conv; later convs use 2x, 4x and 8x.
  """
  def __init__(self, channels_img, features_d):
    super().__init__()
    self.critic = nn.Sequential(
      nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
      nn.LeakyReLU(0.2),

      # Conv biases are dropped: the following InstanceNorm subtracts the
      # per-channel mean, which would cancel them anyway.
      nn.Conv2d(features_d, features_d*2, kernel_size=4, stride=2, padding=1, bias=False),
      nn.InstanceNorm2d(features_d*2, affine=True),
      nn.LeakyReLU(0.2),

      nn.Conv2d(features_d*2, features_d*4, kernel_size=4, stride=2, padding=1, bias=False),
      nn.InstanceNorm2d(features_d*4, affine=True),
      nn.LeakyReLU(0.2),

      nn.Conv2d(features_d*4, features_d*8, kernel_size=4, stride=2, padding=1, bias=False),
      nn.InstanceNorm2d(features_d*8, affine=True),
      nn.LeakyReLU(0.2),

      # No activation on the output. A WGAN critic returns a raw score, not a
      # probability, and a trailing LeakyReLU(0.2) would shrink every negative
      # score by 5x.
      nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0),
    )

  def forward(self, x):
    """Score images x of shape (N, channels_img, 64, 64); returns shape (N, 1, 1, 1)."""
    return self.critic(x)

class Generator(nn.Module):
  """DCGAN-style generator mapping noise (N, channels_noise, 1, 1) to images (N, channels_img, 64, 64).

  Structure:
    - Four ConvTranspose2d(k=4) + BatchNorm2d + ReLU blocks narrow the channels
      noise -> 16f -> 8f -> 4f -> 2f, where f = features_g.
    - The first block uses stride 1 and padding 0, lifting 1x1 to 4x4. The other
      three use stride 2 and padding 1, each doubling height and width, reaching 32x32.
    - An output ConvTranspose2d doubles the size once more to 64x64 and projects
      to channels_img. A Tanh then bounds pixel values to [-1, 1].

  Args:
    channels_noise: number of channels in the input noise tensor.
    channels_img: number of channels in the generated images.
    features_g: base channel width f of the upsampling path.
  """
  def __init__(self, channels_noise, channels_img, features_g):
    super().__init__()
    widths = [channels_noise, features_g * 16, features_g * 8, features_g * 4, features_g * 2]
    layers = []
    for depth, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
      stride, padding = (1, 0) if depth == 0 else (2, 1)
      # Bias is omitted because the BatchNorm that follows re-centres each channel.
      layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False))
      layers.append(nn.BatchNorm2d(c_out))
      layers.append(nn.ReLU())
    layers.append(nn.ConvTranspose2d(widths[-1], channels_img, 4, 2, 1))
    layers.append(nn.Tanh())
    # Kept as one flat Sequential so parameter names stay gen.0, gen.1, ...
    self.gen = nn.Sequential(*layers)

  def forward(self, x):
    """Generate an image batch from the noise batch x."""
    return self.gen(x)

In [None]:
def initialize_weights(model):
    """Re-initialise conv, transposed-conv and batch-norm weights in place, DCGAN style.

    The `weight` of every nn.Conv2d, nn.ConvTranspose2d and nn.BatchNorm2d in
    `model.modules()` (including `model` itself) is redrawn from N(0.0, 0.02).
    This follows the DCGAN paper's "all weights zero-centred, std 0.02" rule.

    Left untouched:
      - biases;
      - modules without a weight, e.g. BatchNorm2d(affine=False);
      - every other module type. In particular nn.InstanceNorm2d is not in the
        list, so its affine weights keep PyTorch's defaults.

    NOTE(review): the PyTorch DCGAN tutorial instead draws BatchNorm weights
    from N(1.0, 0.02) and zeroes BatchNorm biases.

    Args:
        model: an nn.Module, modified in place.

    Returns:
        None.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            # BatchNorm2d(affine=False) registers weight=None; skip it rather than crash.
            if m.weight is not None:
                # nn.init already runs under no_grad, so `.data` is not needed.
                nn.init.normal_(m.weight, 0.0, 0.02)

In [None]:
# Image geometry shared by the preprocessing below and the model definitions.
config.img_size = 64
config.img_channels = 3

# Resize(int) scales the shorter edge to img_size (square CIFAR images become img_size x img_size).
# Normalize with mean = std = 0.5 per channel maps ToTensor's [0, 1] range to [-1, 1].
# That matches the generator's Tanh output range.
tfs = transforms.Compose([
    transforms.Resize(config.img_size),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5] * config.img_channels,
        std=[0.5] * config.img_channels,
    ),
])
dset = datasets.CIFAR10(root='.', train=True, download=True, transform=tfs)


Files already downloaded and verified


In [None]:

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyperparameters. lr, batch_size, critic_iters and weight_clip match Algorithm 1
# of the WGAN paper (alpha=5e-5, m=64, n_critic=5, c=0.01).
config.lr = 5e-5
config.batch_size= 64
config.noise_dim = 128
config.num_epochs = 5
config.features_critic = 64
config.features_g = 64
config.critic_iters = 5
config.weight_clip = 0.01
config.seed = 0

# Seed before building the models so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(config.seed)

gen = Generator(config.noise_dim, config.img_channels, config.features_g).to(device)
critic = Critic(config.img_channels, config.features_critic).to(device)
initialize_weights(critic)
initialize_weights(gen)

# One latent batch reused for every sample grid, so logged images are comparable over time.
fixed_noise = torch.randn(32, config.noise_dim, 1, 1).to(device)

loader = DataLoader(dset, batch_size=config.batch_size, shuffle=True)

opt_critic = optim.RMSprop(critic.parameters(), lr=config.lr)
opt_gen = optim.RMSprop(gen.parameters(), lr=config.lr)

gen.train()
critic.train()
step = 0

for epoch in range(config.num_epochs):
  for batch_idx, (real_images, _) in enumerate(loader):
    real_images = real_images.to(device)
    cur_batch_size = real_images.shape[0]

    # Train the critic `critic_iters` times per generator update.
    for _ in range(config.critic_iters):
      noise = torch.randn(cur_batch_size, config.noise_dim, 1, 1).to(device)
      fake_images = gen(noise)
      critic_real = critic(real_images).reshape(-1)
      # detach() keeps the critic update from backpropagating into the generator.
      # No retain_graph is needed; the generator step below reuses the un-detached fakes.
      critic_fake = critic(fake_images.detach()).reshape(-1)
      # Maximise E[critic(real)] - E[critic(fake)], i.e. minimise its negation.
      loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
      opt_critic.zero_grad()
      loss_critic.backward()
      opt_critic.step()

      # Clip critic weights to the box [-weight_clip, weight_clip] (WGAN weight clipping).
      with torch.no_grad():
        for p in critic.parameters():
          p.clamp_(-config.weight_clip, config.weight_clip)

    # Train the generator on the last batch of fakes: maximise E[critic(fake)].
    gen_fake = critic(fake_images).reshape(-1)
    loss_gen = -torch.mean(gen_fake)
    opt_gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    if batch_idx % 100 == 0 and batch_idx > 0:
      # .item() logs plain floats instead of grad-carrying tensors.
      losses = {"loss_critic": loss_critic.item(), "loss_g": loss_gen.item()}
      print(
          f"Epoch [{epoch + 1}/{config.num_epochs}] Batch {batch_idx}/{len(loader)} "
          f"Loss Critic: {losses['loss_critic']:.4f}, loss G: {losses['loss_g']:.4f}"
      )

      # Sample from the fixed latent batch with BatchNorm in inference mode.
      gen.eval()
      with torch.no_grad():
        fake = gen(fixed_noise)
      gen.train()

      img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
      img_grid_real = torchvision.utils.make_grid(real_images[:32], normalize=True)

      # A single wandb.log call keeps the losses and both image grids on the same wandb step.
      wandb.log({
          **losses,
          "fake_images": wandb.Image(img_grid_fake),
          "real_images": wandb.Image(img_grid_real),
      })
      step += 1
  

{'loss_critic': tensor(-56.7876, device='cuda:0', grad_fn=<NegBackward>), 'loss_g': tensor(9.1556, device='cuda:0', grad_fn=<NegBackward>)}
Epoch [0/5] Batch 100/782                   Loss Critic: -56.7876, loss G: 9.1556
{'loss_critic': tensor(-57.1206, device='cuda:0', grad_fn=<NegBackward>), 'loss_g': tensor(9.3518, device='cuda:0', grad_fn=<NegBackward>)}
Epoch [0/5] Batch 200/782                   Loss Critic: -57.1206, loss G: 9.3518
{'loss_critic': tensor(-57.1512, device='cuda:0', grad_fn=<NegBackward>), 'loss_g': tensor(9.3524, device='cuda:0', grad_fn=<NegBackward>)}
Epoch [0/5] Batch 300/782                   Loss Critic: -57.1512, loss G: 9.3524
{'loss_critic': tensor(-57.1500, device='cuda:0', grad_fn=<NegBackward>), 'loss_g': tensor(9.3565, device='cuda:0', grad_fn=<NegBackward>)}
Epoch [0/5] Batch 400/782                   Loss Critic: -57.1500, loss G: 9.3565
{'loss_critic': tensor(-57.1710, device='cuda:0', grad_fn=<NegBackward>), 'loss_g': tensor(9.3618, device='cuda: