<a href="https://colab.research.google.com/github/tbass134/GAN-pytorch/blob/main/GAN_Pytorch_CIFAR10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter




<IPython.core.display.Javascript object>

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


wandb: Paste an API key from your profile and hit enter: ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
!pip install wandb
import wandb
wandb.init(project='cifar10-gan', entity='tbass134')
config = wandb.config



VBox(children=(Label(value=' 7.25MB of 7.25MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
loss_d,█▁▁▃▂▂▂▃▂▅▄▄▃▃▅▂▆▂▅▄▅▅▅▅▄▅▅▅▅▅▅▆▄▅▅▆▅▃▅▄
loss_g,▁▇▆▃█▅▄▄▄▅▄▇█▅▃▅▃▅▅▇▃▂▃▄▄▃▃▃▃▃▄▂▃▂▄▃▃▄▃▃

0,1
loss_d,0.27231
loss_g,2.83241


In [None]:
%xmode verbose
%pdb off

Exception reporting mode: Verbose
Automatic pdb calling has been turned OFF


In [None]:
class Discriminator(nn.Module):
  """Fully connected discriminator scoring flattened images as real vs. fake.

  Architecture: img_dim -> 1024 -> 512 -> 256 -> 128 -> 1, with LeakyReLU(0.1)
  after every hidden Linear layer and a final Sigmoid, so the output is a
  probability suitable for nn.BCELoss.

  Args:
    img_dim: number of input features (flattened C*H*W).
  """

  def __init__(self, img_dim):
    super().__init__()
    widths = (img_dim, 1024, 512, 256, 128)
    layers = []
    for fan_in, fan_out in zip(widths, widths[1:]):
      layers.append(nn.Linear(fan_in, fan_out))
      layers.append(nn.LeakyReLU(0.1))
    # Single-unit head squashed into a probability.
    layers.extend([nn.Linear(widths[-1], 1), nn.Sigmoid()])
    self.d = nn.Sequential(*layers)

  def forward(self, x):
    """Return a (batch, 1) tensor of real-image probabilities.

    Assumes ``x`` is (batch, img_dim), i.e. already flattened — TODO confirm
    for callers other than the training loop.
    """
    return self.d(x)

class Generator(nn.Module):
  """Fully connected generator mapping latent noise to flattened images.

  Architecture: z_dim -> 1024 -> 512 -> 256 -> 128 -> img_dim with ReLU after
  every hidden Linear layer and a final Tanh, so every output feature lies in
  [-1, 1].

  Args:
    z_dim: size of the latent noise vector.
    img_dim: number of output features (flattened C*H*W; 3*32*32 = 3072 for
      the CIFAR-10 setup in this notebook).
  """

  def __init__(self, z_dim, img_dim):
    super().__init__()
    # nn.ReLU's only argument is `inplace`, not a negative slope, so plain
    # nn.ReLU() is used; passing 0.1 would merely switch on in-place mode.
    self.g = nn.Sequential(
        nn.Linear(z_dim, 1024),
        nn.ReLU(),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, img_dim),  # flattened image, e.g. 32x32x3 -> 3072
        nn.Tanh(),
    )

  def forward(self, x):
    """Map a (batch, z_dim) noise tensor to (batch, img_dim) values in [-1, 1]."""
    return self.g(x)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def show_img(t, filename):
  """Draw a channel-first image tensor and save it to ``{filename}.png``.

  Args:
    t: 3-D tensor laid out as (C, H, W) (it is transposed to (H, W, C) for
      imshow), e.g. the output of ``torchvision.utils.make_grid``. It may live
      on any device and may be part of an autograd graph.
    filename: output path without the ``.png`` extension.
  """
  # detach() first: .numpy() raises on tensors that require grad (e.g. raw
  # generator output); .cpu() handles CUDA tensors.
  npimg = t.detach().cpu().numpy()
  # A fresh figure per call, so successive calls do not draw onto the same
  # implicit "current" axes.
  fig, ax = plt.subplots()
  ax.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
  fig.savefig(f'{filename}.png')

In [None]:
# ToTensor maps pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to
# [-1, 1], the same range as the Generator's Tanh output. Per-channel CIFAR-10
# mean/std normalization would instead give roughly [-2.0, 2.1], letting the
# discriminator tell real from fake by value range alone.
tfs = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dset = datasets.CIFAR10(root='.', train=True, download=True, transform=tfs)

# CIFAR-10 images are 32x32 RGB.
config.img_height = 32
config.img_width = 32
config.img_channel = 3

config.batch_size = 32

len(dset)


Files already downloaded and verified


50000

In [None]:

device = "cuda" if torch.cuda.is_available() else "cpu"

config.lr = 3e-4
config.z_dim = 100
config.num_epochs = 50
config.seed = 0
config.img_dim = config.img_height * config.img_width * config.img_channel

# Seed torch so weight init, noise and data shuffling are reproducible.
torch.manual_seed(config.seed)

disc = Discriminator(config.img_dim).to(device)
gen = Generator(config.z_dim, config.img_dim).to(device)

# Same latent batch at every logging point, so sample grids are comparable
# across epochs.
fixed_noise = torch.randn(config.batch_size, config.z_dim, device=device)

loader = DataLoader(dset, batch_size=config.batch_size, shuffle=True)
opt_discriminator = optim.Adam(disc.parameters(), lr=config.lr)
opt_generator = optim.Adam(gen.parameters(), lr=config.lr)
criterion = nn.BCELoss()

step = 0  # number of logging events so far

for epoch in range(config.num_epochs):
  for batch_idx, (real_images, _) in enumerate(loader):
    # Flatten (B, C, H, W) -> (B, C*H*W) for the fully connected models.
    real_images = real_images.view(-1, config.img_dim).to(device)
    batch_size = real_images.shape[0]

    # --- Train discriminator ---
    # Use the actual batch size: the last batch of an epoch can be smaller
    # than config.batch_size (50000 % 32 == 16 for the CIFAR-10 train split).
    noise = torch.randn(batch_size, config.z_dim, device=device)
    fake_images = gen(noise)

    d_real = disc(real_images).view(-1)
    loss_d_real = criterion(d_real, torch.ones_like(d_real))

    # detach() keeps the discriminator loss from backpropagating into the
    # generator; fake_images itself (with its graph) is reused below.
    d_fake = disc(fake_images.detach()).view(-1)
    loss_d_fake = criterion(d_fake, torch.zeros_like(d_fake))

    loss_d = (loss_d_real + loss_d_fake) / 2
    opt_discriminator.zero_grad()
    loss_d.backward()
    opt_discriminator.step()

    # --- Train generator ---
    # Non-saturating loss: instead of min log(1 - D(G(z))), max log(D(G(z))),
    # i.e. BCE of D(G(z)) against "real" labels.
    output = disc(fake_images).view(-1)
    loss_g = criterion(output, torch.ones_like(output))
    opt_generator.zero_grad()
    loss_g.backward()
    opt_generator.step()

    if batch_idx == 0:
      loss_d_val, loss_g_val = loss_d.item(), loss_g.item()
      print(
          f"Epoch [{epoch}/{config.num_epochs}] Batch {batch_idx}/{len(loader)} "
          f"Loss D: {loss_d_val:.4f}, loss G: {loss_g_val:.4f}"
      )

      with torch.no_grad():
        fake = gen(fixed_noise).reshape(-1, config.img_channel, config.img_height, config.img_width)
        data = real_images.reshape(-1, config.img_channel, config.img_height, config.img_width)
        img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
        img_grid_real = torchvision.utils.make_grid(data, normalize=True)

      # One wandb.log call per logging event, so losses and image grids are
      # recorded under the same wandb step.
      wandb.log({
          "epoch": epoch,
          "loss_d": loss_d_val,
          "loss_g": loss_g_val,
          "fake_images": wandb.Image(img_grid_fake),
          "real_images": wandb.Image(img_grid_real),
      })
      step += 1
  

Epoch [0/50] Batch 0/1563                     Loss D: 0.6935, loss G: 0.7305
Epoch [1/50] Batch 0/1563                     Loss D: 0.0965, loss G: 5.8390
Epoch [2/50] Batch 0/1563                     Loss D: 0.0086, loss G: 4.5452
Epoch [3/50] Batch 0/1563                     Loss D: 0.0294, loss G: 7.1543
Epoch [4/50] Batch 0/1563                     Loss D: 0.0188, loss G: 6.0301
Epoch [5/50] Batch 0/1563                     Loss D: 0.0323, loss G: 4.1180
Epoch [6/50] Batch 0/1563                     Loss D: 0.0274, loss G: 5.7950
Epoch [7/50] Batch 0/1563                     Loss D: 0.1328, loss G: 4.2625
Epoch [8/50] Batch 0/1563                     Loss D: 0.0745, loss G: 4.3964
Epoch [9/50] Batch 0/1563                     Loss D: 0.1114, loss G: 4.3521
Epoch [10/50] Batch 0/1563                     Loss D: 0.0635, loss G: 4.2178
Epoch [11/50] Batch 0/1563                     Loss D: 0.0262, loss G: 5.0075
Epoch [12/50] Batch 0/1563                     Loss D: 0.0513, loss G: 4.9