In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

In [None]:
!pip install wandb

Collecting wandb
  Downloading wandb-0.12.6-py2.py3-none-any.whl (1.7 MB)
[?25l[K     |▏                               | 10 kB 25.0 MB/s eta 0:00:01[K     |▍                               | 20 kB 31.3 MB/s eta 0:00:01[K     |▋                               | 30 kB 25.3 MB/s eta 0:00:01[K     |▉                               | 40 kB 20.7 MB/s eta 0:00:01[K     |█                               | 51 kB 15.7 MB/s eta 0:00:01[K     |█▏                              | 61 kB 11.1 MB/s eta 0:00:01[K     |█▍                              | 71 kB 12.2 MB/s eta 0:00:01[K     |█▋                              | 81 kB 13.3 MB/s eta 0:00:01[K     |█▊                              | 92 kB 12.3 MB/s eta 0:00:01[K     |██                              | 102 kB 13.2 MB/s eta 0:00:01[K     |██▏                             | 112 kB 13.2 MB/s eta 0:00:01[K     |██▍                             | 122 kB 13.2 MB/s eta 0:00:01[K     |██▌                             | 133 kB 13.2 MB/s eta 

In [None]:
import os

import wandb

# Log to the W&B entity named by the WANDB_ENTITY environment variable when it is set;
# otherwise fall back to the account this notebook was originally written under.
wandb.init(project='cifar10-gan', entity=os.environ.get('WANDB_ENTITY', 'tbass134'))
# Run-level config object; the hyperparameter cells below assign their values onto it.
config = wandb.config

VBox(children=(Label(value=' 7.52MB of 7.52MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
loss_d,▄▆▁▂▄▄▄▄▂█▁▆▄▄▂
loss_g,▁▅▃▇▃▃▃▄▃▇▄█▁▅▂

0,1
loss_d,0.64843
loss_g,0.67332


In [None]:
class Discriminator(nn.Module):
  """DCGAN discriminator: maps a batch of images to per-image "real" probabilities.

  Four kernel-4 / stride-2 / padding-1 convolutions each halve H and W while
  widening to features_d * 8 channels. A final kernel-4 / padding-0 convolution
  reduces the remaining 4x4 map to a single logit, which Sigmoid maps into (0, 1).
  For a (N, channels_img, 64, 64) input the output has shape (N, 1, 1, 1).

  Args:
    channels_img: number of channels in the input images.
    features_d: base width; the hidden convolutions use 1x, 2x, 4x and 8x this.
  """

  def __init__(self, channels_img, features_d):
    super().__init__()
    # Spatial sizes in the comments below assume a 64x64 input.
    self.d = nn.Sequential(
      # 64x64 -> 32x32; no BatchNorm on the input layer.
      nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
      nn.LeakyReLU(0.2),

      # 32x32 -> 16x16 (conv bias dropped: the following BatchNorm2d has its own shift).
      nn.Conv2d(features_d, features_d*2, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(features_d*2),
      nn.LeakyReLU(0.2),

      # 16x16 -> 8x8
      nn.Conv2d(features_d*2, features_d*4, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(features_d*4),
      nn.LeakyReLU(0.2),

      # 8x8 -> 4x4
      nn.Conv2d(features_d*4, features_d*8, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(features_d*8),
      nn.LeakyReLU(0.2),

      # 4x4 -> 1x1 logit, fed straight into Sigmoid. Putting LeakyReLU(0.2) here
      # would turn sigmoid(x) into sigmoid(0.2 * x) for negative logits, squashing
      # "fake" scores toward 0.5 and scaling their gradients by 0.2.
      nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0),
      nn.Sigmoid()
    )

  def forward(self, x):
    """Return D(x): the probability that each image in ``x`` is real."""
    return self.d(x)

class Generator(nn.Module):
  """DCGAN generator: turns a latent tensor into an image with transposed convolutions.

  An (N, channels_noise, 1, 1) input is projected to a 4x4 map with
  features_g * 16 channels. Four kernel-4 / stride-2 / padding-1 transposed
  convolutions then each double H and W. The result is an
  (N, channels_img, 64, 64) tensor squashed into [-1, 1] by Tanh.

  Args:
    channels_noise: channel count of the latent input.
    channels_img: channel count of the generated image.
    features_g: base width; hidden layers use 16x, 8x, 4x and 2x this.
  """

  def __init__(self, channels_noise, channels_img, features_g):
    super().__init__()

    def up_block(in_ch, out_ch, stride, padding):
      # Transposed conv (bias omitted: BatchNorm2d supplies the shift) + BatchNorm2d + ReLU.
      return [
          nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
          nn.BatchNorm2d(out_ch),
          nn.ReLU(),
      ]

    widths = [features_g * mult for mult in (16, 8, 4, 2)]

    # 1x1 -> 4x4 projection, then 4x4 -> 8x8 -> 16x16 -> 32x32.
    layers = up_block(channels_noise, widths[0], 1, 0)
    for in_ch, out_ch in zip(widths, widths[1:]):
      layers += up_block(in_ch, out_ch, 2, 1)

    # 32x32 -> 64x64 output layer: keeps its bias, no BatchNorm, Tanh to [-1, 1].
    layers += [nn.ConvTranspose2d(widths[-1], channels_img, 4, 2, 1), nn.Tanh()]

    self.g = nn.Sequential(*layers)

  def forward(self, x):
    """Return G(x): a batch of generated images for the latent batch ``x``."""
    return self.g(x)

In [None]:
def initialize_weights(model):
    """Re-initialize a DCGAN model's weights in place.

    - Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02), as in the DCGAN paper.
    - Affine BatchNorm2d scales are drawn from N(1, 0.02) and shifts are set to 0
      (the convention of PyTorch's DCGAN reference). Each normalization layer
      therefore starts near the identity. A zero-mean scale would instead shrink
      its output to ~2% of normal and flip the sign of roughly half the channels.
    - Everything else (e.g. conv biases) keeps PyTorch's default initialization.

    Args:
        model: an nn.Module; every submodule from ``model.modules()`` is visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            # nn.init functions run under no_grad themselves, so no `.data` access is needed.
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.affine:
            # affine=False layers have no weight/bias to initialize.
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

In [None]:
config.img_size = 64
config.img_channels = 3

# ToTensor yields pixels in [0, 1]. Normalize with mean = std = 0.5 per channel maps
# them to [-1, 1], the same range as the generator's Tanh output.
tfs = transforms.Compose(
    [
        transforms.Resize(config.img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * config.img_channels, [0.5] * config.img_channels),
    ]
)
dset = datasets.CIFAR10(root='.', download=True, train=True, transform=tfs)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar-10-python.tar.gz to .


In [None]:

device = "cuda" if torch.cuda.is_available() else "cpu"

config.lr = 2e-4
config.betas = (0.5, 0.999)
config.batch_size = 128

config.noise_dim = 100
config.num_epochs = 15
config.features_d = 64
config.features_g = 64

gen = Generator(config.noise_dim, config.img_channels, config.features_g).to(device)
disc = Discriminator(config.img_channels, config.features_d).to(device)
initialize_weights(gen)
initialize_weights(disc)

wandb.watch(gen, log="all")
wandb.watch(disc, log="all")
print("Generator", gen)
print("Discriminator", disc)

# Latent batch held fixed across epochs so the logged sample grids are comparable over time.
fixed_noise = torch.randn(32, config.noise_dim, 1, 1).to(device)

loader = DataLoader(dset, batch_size=config.batch_size, shuffle=True)

opt_discriminator = optim.Adam(disc.parameters(), lr=config.lr, betas=config.betas)
opt_generator = optim.Adam(gen.parameters(), lr=config.lr, betas=config.betas)

criterion = nn.BCELoss()
gen.train()
disc.train()
step = 0  # number of logging checkpoints written so far

for epoch in range(config.num_epochs):
  for batch_idx, (real_images, _) in enumerate(loader):
    real_images = real_images.to(device)
    # DataLoader keeps a final partial batch by default, so size the fake batch from
    # the real one (not config.batch_size) to keep real/fake counts balanced.
    cur_batch_size = real_images.size(0)

    # --- Train discriminator: maximize log D(x) + log(1 - D(G(z))) ---
    noise = torch.randn(cur_batch_size, config.noise_dim, 1, 1, device=device)
    fake_images = gen(noise)

    d_real = disc(real_images).reshape(-1)
    loss_d_real = criterion(d_real, torch.ones_like(d_real))

    # detach() keeps this loss from reaching the generator; fake_images itself is
    # reused (with its graph) in the generator step below.
    d_fake = disc(fake_images.detach()).reshape(-1)
    loss_d_fake = criterion(d_fake, torch.zeros_like(d_fake))

    loss_d = (loss_d_real + loss_d_fake) / 2
    disc.zero_grad()
    loss_d.backward()
    opt_discriminator.step()

    # --- Train generator: min log(1 - D(G(z))) <--> max log(D(G(z))) ---
    output = disc(fake_images).reshape(-1)
    loss_g = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    loss_g.backward()
    opt_generator.step()

    if batch_idx == 0:
      # .item() converts the 0-dim loss tensors to Python floats for printing and logging.
      losses = {"loss_d": loss_d.item(), "loss_g": loss_g.item()}
      print(
          f"Epoch [{epoch}/{config.num_epochs}] Batch {batch_idx}/{len(loader)} "
          f"Loss D: {losses['loss_d']:.4f}, loss G: {losses['loss_g']:.4f}"
      )

      with torch.no_grad():
        fake = gen(fixed_noise)
      img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
      img_grid_real = torchvision.utils.make_grid(real_images[:32], normalize=True)

      # One wandb.log call per checkpoint keeps losses and sample grids on the same W&B step.
      wandb.log({
          **losses,
          "fake_images": wandb.Image(img_grid_fake),
          "real_images": wandb.Image(img_grid_real),
      })
      step += 1
  

Generator Generator(
  (g): Sequential(
    (0): ConvTranspose2d(100, 1024, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): ReLU()
    (2): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): ReLU()
    (4): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (5): ReLU()
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): ReLU()
    (8): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): Tanh()
  )
)
Discriminator Discriminator(
  (d): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): LeakyReLU(negative_slope=0.2)
    (4): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): LeakyReLU(negative_slope=0.2)
    (6): Conv2d(256,

KeyboardInterrupt: ignored