In [1]:
!pip install wandb -qU

In [2]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mreneelin2020[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard

In [4]:
class Discriminator(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)


class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, img_dim),
            nn.Tanh(),  # normalize inputs to [-1, 1] so make outputs [-1, 1]
        )

    def forward(self, x):
        return self.gen(x)


In [5]:
# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1  # 784
batch_size = 32
num_epochs = 50

In [6]:
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(f"logs/fake")
writer_real = SummaryWriter(f"logs/real")
step = 0

In [7]:
# Launch 5 simulated experiments
total_runs = 2
for run in range(total_runs):
  # 🐝 1️⃣ Start a new run to track this script
  wandb.init(
      # Set the project where this run will be logged
      project="basic-GAN", 
      # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
      name=f"experiment_{run}", 
      # Track hyperparameters and run metadata
      config={
      "learning_rate": 0.0003,
      "architecture": "NN",
      "dataset": "MNIST",
      "epochs": 50,
      })
  
  # This simple block simulates a training loop logging metrics
  for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            wandb.log({"epoch":epoch,"batch_idx":batch_idx, "loss_D": lossD, "loss_G": lossG})
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                # img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                # img_grid_real = torchvision.utils.make_grid(data, normalize=True)
                fake_img = wandb.Image(fake)

                wandb.log({"images_fake": fake_img})
                # writer_fake.add_image(
                #     "Mnist Fake Images", img_grid_fake, global_step=step
                # )
                # writer_real.add_image(
                #     "Mnist Real Images", img_grid_real, global_step=step
                # )
                # step += 1
      
  # Mark the run as finished
  wandb.finish()

Epoch [0/50] Batch 0/1875                       Loss D: 0.5930, loss G: 0.7047
Epoch [1/50] Batch 0/1875                       Loss D: 0.8890, loss G: 0.7075
Epoch [2/50] Batch 0/1875                       Loss D: 0.6323, loss G: 0.8811
Epoch [3/50] Batch 0/1875                       Loss D: 0.9070, loss G: 0.6693
Epoch [4/50] Batch 0/1875                       Loss D: 0.4699, loss G: 1.2590
Epoch [5/50] Batch 0/1875                       Loss D: 0.2585, loss G: 1.8989
Epoch [6/50] Batch 0/1875                       Loss D: 0.9006, loss G: 0.6634
Epoch [7/50] Batch 0/1875                       Loss D: 0.4925, loss G: 1.2457
Epoch [8/50] Batch 0/1875                       Loss D: 0.5818, loss G: 0.9829
Epoch [9/50] Batch 0/1875                       Loss D: 1.0409, loss G: 0.7607
Epoch [10/50] Batch 0/1875                       Loss D: 0.8918, loss G: 0.7114
Epoch [11/50] Batch 0/1875                       Loss D: 0.4137, loss G: 1.5105
Epoch [12/50] Batch 0/1875                       L

VBox(children=(Label(value='1.822 MB of 1.822 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
batch_idx,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_D,▅█▅█▁█▄▄█▃▃▂▄▃▃▅▆▅▃▄▃▅▄▄▄▂▃▅▅▄▅▅▅▅▇▅▅▅▆▄
loss_G,▁▁▂▁█▁▄▃▁▆▅▇▅▆▇▃▃▂▆▄▅▅▅▅▄█▆▅▄▅▃▆▄▄▂▃▃▄▃▄

0,1
batch_idx,0.0
epoch,49.0
loss_D,0.55968
loss_G,1.12033


Epoch [0/50] Batch 0/1875                       Loss D: 0.5575, loss G: 1.0861
Epoch [1/50] Batch 0/1875                       Loss D: 0.4913, loss G: 1.2609
Epoch [2/50] Batch 0/1875                       Loss D: 0.5731, loss G: 1.2958
Epoch [3/50] Batch 0/1875                       Loss D: 0.6185, loss G: 0.8986
Epoch [4/50] Batch 0/1875                       Loss D: 0.6025, loss G: 0.9673
Epoch [5/50] Batch 0/1875                       Loss D: 0.8016, loss G: 0.7294
Epoch [6/50] Batch 0/1875                       Loss D: 0.8113, loss G: 0.9677
Epoch [7/50] Batch 0/1875                       Loss D: 0.5424, loss G: 1.2366
Epoch [8/50] Batch 0/1875                       Loss D: 0.6061, loss G: 1.0007
Epoch [9/50] Batch 0/1875                       Loss D: 0.6339, loss G: 1.0671
Epoch [10/50] Batch 0/1875                       Loss D: 0.5924, loss G: 1.2110
Epoch [11/50] Batch 0/1875                       Loss D: 0.6130, loss G: 1.0099
Epoch [12/50] Batch 0/1875                       L

VBox(children=(Label(value='1.192 MB of 1.199 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.993963…

0,1
batch_idx,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_D,▂▁▃▃▇▇▂▃▃▃▃▃█▂▅▅▂▄▃▅▄▄▄▄▃▂▅▄▄▄▂▄▃▅▅▄▄▃▃▅
loss_G,▅██▃▁▄▇▄▇▄▅▄▂▅▃▃▇▄▅▃▄▂▅▃▄▇▆▄▃▄▄▄▄▃▂▃▂▄▃▃

0,1
batch_idx,0.0
epoch,49.0
loss_D,0.7326
loss_G,0.85272
