In [1]:
import torch
import torch.nn as nn
from torch import Tensor
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils

import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

from utils import set_random_seed, load_data, CreateDataLoader

from Generator import Generator

import config

In [2]:
set_random_seed(config.RANDOM_SEED)

Setting seeds ...... 



In [3]:
PATH = f'DCGAN/{config.USED_DATA}'
PATH

'DCGAN/CIFAR10'

In [4]:
if config.USED_DATA == "CIFAR10":
	transform = transforms.Compose([
							   transforms.Resize(64),
							   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
						   ])
	n_channels = 3
elif config.USED_DATA == "MNIST" or config.USED_DATA == "DOODLE":
	transform = transforms.Compose([
							   transforms.Resize(64),
							   transforms.Normalize([0.5], [0.5]),
						   ])
	n_channels = 1

In [5]:
dataset, *_ = load_data(transform)
dataset

<utils.CustomDataSet at 0x7f4e5c03db80>

In [6]:
latent_size = 100

In [7]:
dataloader = CreateDataLoader(dataset, batch_size = config.GAN_BATCH_SIZE, device = config.DEVICE)

In [8]:
def weights_init(m):
	classname = m.__class__.__name__
	if classname.find('Conv') != -1:
		m.weight.data.normal_(0.0, 0.02)
	elif classname.find('BatchNorm') != -1:
		m.weight.data.normal_(1.0, 0.02)
		m.bias.data.fill_(0)


In [9]:
netG = Generator(latent_size, n_channels).to(config.DEVICE)
netG.apply(weights_init)
netG

Generator(
  (initial): TransposeBN(
    (deConv): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (Bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU(inplace=True)
  )
  (Transposed): Sequential(
    (0): TransposeBN(
      (deConv): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (Bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
    )
    (1): TransposeBN(
      (deConv): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (Bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
    )
    (2): TransposeBN(
      (deConv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (Bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_s

In [10]:
class Discriminator(nn.Module):
	def __init__(self, n_channels, filters = [64, 128, 256, 512]):
		super(Discriminator, self).__init__()
		self.main = nn.Sequential(
			# input is (nc) x 64 x 64
			nn.Conv2d(n_channels, filters[0], 4, 2, 1, bias=False),
			nn.LeakyReLU(0.2, inplace=True),
			# state size. (ndf) x 32 x 32
			nn.Conv2d(filters[0], filters[1], 4, 2, 1, bias=False),
			nn.BatchNorm2d(filters[1]),
			nn.LeakyReLU(0.2, inplace=True),
			# state size. (ndf*2) x 16 x 16
			nn.Conv2d(filters[1], filters[2], 4, 2, 1, bias=False),
			nn.BatchNorm2d(filters[2]),
			nn.LeakyReLU(0.2, inplace=True),
			# state size. (ndf*4) x 8 x 8
			nn.Conv2d(filters[2], filters[3], 4, 2, 1, bias=False),
			nn.BatchNorm2d(filters[3]),
			nn.LeakyReLU(0.2, inplace=True),
			# state size. (ndf*8) x 4 x 4
			nn.Conv2d(filters[3], 1, 4, 1, 0, bias=False)
		)

	def forward(self, X: Tensor):
		out = self.main(X)

		return out.view(-1, 1).squeeze(1)

In [11]:
netD = Discriminator(n_channels).to(config.DEVICE)
netD.apply(weights_init)
netD

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)

In [12]:
criterion = nn.BCEWithLogitsLoss()
# optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = [0.5, 0.999])
# optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = [0.5, 0.999])

optimizerD = optim.RMSprop(netD.parameters(), lr = 0.0001)
optimizerG = optim.RMSprop(netG.parameters(), lr = 0.0002)

In [13]:
fixed_noise = torch.randn(64, latent_size, 1, 1, device=config.DEVICE)
real_label = 1
fake_label = 0

In [14]:
niter = 10
g_loss = []
d_loss = []

for epoch in range(10, 10 + niter):
	for batch in tqdm(dataloader, leave = True, position=0):
		images, _ = batch
		batch_size = images.size(0)
		############################
		# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
		###########################
		# train with real
		netD.zero_grad()
		label = torch.full((batch_size,), real_label, device=config.DEVICE, dtype=torch.float)

		output = netD(images)
		errD_real = criterion(output, label)
		errD_real.backward()
		D_x = output.mean().item()

		# train with fake
		noise = torch.randn(batch_size, latent_size, 1, 1, device=config.DEVICE)
		fake = netG(noise)
		label.fill_(fake_label)
		output = netD(fake.detach())
		errD_fake = criterion(output, label)
		errD_fake.backward()
		D_G_z1 = output.mean().item()
		errD = errD_real + errD_fake
		optimizerD.step()

		############################
		# (2) Update G network: maximize log(D(G(z)))
		###########################
		netG.zero_grad()
		label.fill_(real_label)  # fake labels are real for generator cost
		output = netD(fake)
		errG = criterion(output, label)
		errG.backward()
		D_G_z2 = output.mean().item()
		optimizerG.step()

		s = '[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' % (epoch, niter-1, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)
		
		tqdm.write(s, end = "\r")
		
	fake = netG(fixed_noise)
	vutils.save_image(fake.detach(),f'{PATH}/fake_samples_epoch_{epoch:03d}.png', normalize=True)
		
	
	
	# Check pointing for every epoch
	torch.save(netG.state_dict(), f'{PATH}/netG_epoch_{epoch:03d}.pt')

  0%|          | 0/391 [00:00<?, ?it/s]

[10/9] Loss_D: 1.2477 Loss_G: 0.7284 D(x): -0.3896 D(G(z)): -1.1933 / -0.015266

  0%|          | 0/391 [00:00<?, ?it/s]

[11/9] Loss_D: 1.3384 Loss_G: 1.3587 D(x): 0.5868 D(G(z)): 0.2050 / -0.967676

  0%|          | 0/391 [00:00<?, ?it/s]

[12/9] Loss_D: 1.4899 Loss_G: 2.0324 D(x): 1.1621 D(G(z)): 0.7984 / -1.856453

  0%|          | 0/391 [00:00<?, ?it/s]

[13/9] Loss_D: 1.3859 Loss_G: 1.7408 D(x): 0.9505 D(G(z)): 0.5503 / -1.518559

  0%|          | 0/391 [00:00<?, ?it/s]

[14/9] Loss_D: 1.1812 Loss_G: 0.8847 D(x): -0.1502 D(G(z)): -0.8481 / -0.3072

  0%|          | 0/391 [00:00<?, ?it/s]

[15/9] Loss_D: 1.0488 Loss_G: 0.8106 D(x): -0.0452 D(G(z)): -1.0969 / -0.1733

KeyboardInterrupt: 