In [1]:
import torch
import torch.nn as nn
from torch import Tensor
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils

import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

from utils import set_random_seed, load_data, CreateDataLoader

from Network_model import TransposeBN

import config

In [2]:
set_random_seed(config.RANDOM_SEED)

Setting seeds ...... 



In [3]:
PATH = f'DCGAN/{config.USED_DATA}'
PATH

'DCGAN/DOODLE'

In [4]:
if config.USED_DATA == "CIFAR10":
	transform = transforms.Compose([
							   transforms.Resize(64),
							   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
						   ])
	n_channels = 3
elif config.USED_DATA == "MNIST" or config.USED_DATA == "DOODLE":
	transform = transforms.Compose([
							   transforms.Resize(64),
							   transforms.Normalize([0.5], [0.5]),
						   ])
	n_channels = 1

In [5]:
dataset, *_ = load_data(transform)
dataset

<utils.CustomDataSet at 0x7f23d8279df0>

In [6]:
latent_size = 100

In [7]:
dataloader = CreateDataLoader(dataset, batch_size = config.GAN_BATCH_SIZE, device = config.DEVICE)

In [8]:
def weights_init(m):
	classname = m.__class__.__name__
	if classname.find('Conv') != -1:
		m.weight.data.normal_(0.0, 0.02)
	elif classname.find('BatchNorm') != -1:
		m.weight.data.normal_(1.0, 0.02)
		m.bias.data.fill_(0)


In [9]:
class Generator(nn.Module):
	def __init__(self, latent_size, n_channels, filters = [512, 256, 128, 64]):
		super().__init__()

		self.initial = TransposeBN(latent_size, filters[0], 4, 1, 0)


		layers = []

		for i in range(1, len(filters)):
			layers.append(TransposeBN(filters[i-1], filters[i]))
		
		self.Transposed = nn.Sequential(*layers)

		self.out = nn.ConvTranspose2d(filters[-1], n_channels, 4, 2, 1, bias = False)

		self.tanh = nn.Tanh()

	def forward(self, X: Tensor):
		out = self.initial(X)
		out = self.Transposed(out)
		out = self.out(out)
		out = self.tanh(out)

		return out

In [10]:
netG = Generator(latent_size, n_channels).to(config.DEVICE)
netG.apply(weights_init)
netG

Generator(
  (initial): TransposeBN(
    (deConv): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (Bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU(inplace=True)
  )
  (Transposed): Sequential(
    (0): TransposeBN(
      (deConv): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (Bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
    )
    (1): TransposeBN(
      (deConv): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (Bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
    )
    (2): TransposeBN(
      (deConv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (Bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_s

In [11]:
class Discriminator(nn.Module):
	def __init__(self, n_channels, filters = [64, 128, 256, 512]):
		super(Discriminator, self).__init__()
		self.main = nn.Sequential(
			# input is (nc) x 64 x 64
			nn.Conv2d(n_channels, filters[0], 4, 2, 1, bias=False),
			nn.LeakyReLU(0.2, inplace=True),
			# state size. (ndf) x 32 x 32
			nn.Conv2d(filters[0], filters[1], 4, 2, 1, bias=False),
			nn.BatchNorm2d(filters[1]),
			nn.LeakyReLU(0.2, inplace=True),
			# state size. (ndf*2) x 16 x 16
			nn.Conv2d(filters[1], filters[2], 4, 2, 1, bias=False),
			nn.BatchNorm2d(filters[2]),
			nn.LeakyReLU(0.2, inplace=True),
			# state size. (ndf*4) x 8 x 8
			nn.Conv2d(filters[2], filters[3], 4, 2, 1, bias=False),
			nn.BatchNorm2d(filters[3]),
			nn.LeakyReLU(0.2, inplace=True),
			# state size. (ndf*8) x 4 x 4
			nn.Conv2d(filters[3], 1, 4, 1, 0, bias=False)
		)

	def forward(self, X: Tensor):
		out = self.main(X)

		return out.view(-1, 1).squeeze(1)

In [12]:
netD = Discriminator(n_channels).to(config.DEVICE)
netD.apply(weights_init)
netD

Discriminator(
  (main): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)

In [13]:
criterion = nn.BCEWithLogitsLoss()
optimizerD = optim.Adam(netD.parameters(), lr = 0.0001, betas = [0.5, 0.999])
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = [0.5, 0.999])

In [14]:
fixed_noise = torch.randn(64, latent_size, 1, 1, device=config.DEVICE)
real_label = 1
fake_label = 0

In [15]:
niter = 25
g_loss = []
d_loss = []

for epoch in range(niter):
	for batch in tqdm(dataloader, leave = True, position=0):
		images, _ = batch
		batch_size = images.size(0)
		############################
		# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
		###########################
		# train with real
		netD.zero_grad()
		label = torch.full((batch_size,), real_label, device=config.DEVICE, dtype=torch.float)

		output = netD(images)
		errD_real = criterion(output, label)
		errD_real.backward()
		D_x = output.mean().item()

		# train with fake
		noise = torch.randn(batch_size, latent_size, 1, 1, device=config.DEVICE)
		fake = netG(noise)
		label.fill_(fake_label)
		output = netD(fake.detach())
		errD_fake = criterion(output, label)
		errD_fake.backward()
		D_G_z1 = output.mean().item()
		errD = errD_real + errD_fake
		optimizerD.step()

		############################
		# (2) Update G network: maximize log(D(G(z)))
		###########################
		netG.zero_grad()
		label.fill_(real_label)  # fake labels are real for generator cost
		output = netD(fake)
		errG = criterion(output, label)
		errG.backward()
		D_G_z2 = output.mean().item()
		optimizerG.step()

		s = '[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' % (epoch, niter-1, errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)
		
		tqdm.write(s, end = "\r")
		
	fake = netG(fixed_noise)
	vutils.save_image(fake.detach(),f'{PATH}/fake_samples_epoch_{epoch:03d}.png', normalize=True)
		
	
	
	# Check pointing for every epoch
	torch.save(netG.state_dict(), f'{PATH}/netG_epoch_{epoch:03d}.pt')
	# torch.save(netD.state_dict(), 'DCGAN/netD_epoch_%d.pth' % (epoch))

  0%|          | 0/782 [00:00<?, ?it/s]

[0/24] Loss_D: 0.2165 Loss_G: 3.7037 D(x): 2.6847 D(G(z)): -2.5340 / -3.6712994607

  0%|          | 0/782 [00:00<?, ?it/s]

[1/24] Loss_D: 0.0770 Loss_G: 5.1433 D(x): 4.8744 D(G(z)): -3.3410 / -5.13536561

  0%|          | 0/782 [00:00<?, ?it/s]

[2/24] Loss_D: 0.2019 Loss_G: 3.4754 D(x): 3.3674 D(G(z)): -2.5097 / -3.43102566

  0%|          | 0/782 [00:00<?, ?it/s]

[3/24] Loss_D: 0.0699 Loss_G: 3.5432 D(x): 3.6589 D(G(z)): -3.8852 / -3.50738836

  0%|          | 0/782 [00:00<?, ?it/s]

[4/24] Loss_D: 0.1304 Loss_G: 4.5626 D(x): 3.9153 D(G(z)): -2.9990 / -4.54742390

  0%|          | 0/782 [00:00<?, ?it/s]

[5/24] Loss_D: 0.0373 Loss_G: 4.7415 D(x): 4.1523 D(G(z)): -5.0187 / -4.72912680

  0%|          | 0/782 [00:00<?, ?it/s]

[6/24] Loss_D: 0.5018 Loss_G: 2.0569 D(x): 2.0322 D(G(z)): -1.1913 / -1.8946179290

  0%|          | 0/782 [00:00<?, ?it/s]

[7/24] Loss_D: 0.2902 Loss_G: 2.4330 D(x): 2.1933 D(G(z)): -2.1951 / -2.31372516

  0%|          | 0/782 [00:00<?, ?it/s]

[8/24] Loss_D: 0.1488 Loss_G: 3.6134 D(x): 3.2394 D(G(z)): -3.1334 / -3.57410498

  0%|          | 0/782 [00:00<?, ?it/s]

[9/24] Loss_D: 0.2597 Loss_G: 3.4613 D(x): 2.8552 D(G(z)): -2.1253 / -3.41962397

  0%|          | 0/782 [00:00<?, ?it/s]

[10/24] Loss_D: 0.0429 Loss_G: 3.9761 D(x): 4.5287 D(G(z)): -4.8745 / -3.94973978

  0%|          | 0/782 [00:00<?, ?it/s]

[11/24] Loss_D: 0.0110 Loss_G: 5.5335 D(x): 5.7814 D(G(z)): -5.9456 / -5.5269734002

  0%|          | 0/782 [00:00<?, ?it/s]

[12/24] Loss_D: 0.0089 Loss_G: 5.7194 D(x): 6.1382 D(G(z)): -5.5084 / -5.7151984017

  0%|          | 0/782 [00:00<?, ?it/s]

[13/24] Loss_D: 0.6158 Loss_G: 3.0813 D(x): 2.0926 D(G(z)): -0.8976 / -3.0299408450

  0%|          | 0/782 [00:00<?, ?it/s]

[14/24] Loss_D: 0.0145 Loss_G: 5.6234 D(x): 5.4420 D(G(z)): -6.0099 / -5.61662

  0%|          | 0/782 [00:00<?, ?it/s]

[15/24] Loss_D: 0.4063 Loss_G: 2.7791 D(x): 1.9782 D(G(z)): -2.2812 / -2.69024114

  0%|          | 0/782 [00:00<?, ?it/s]

[16/24] Loss_D: 0.0257 Loss_G: 4.9420 D(x): 4.8394 D(G(z)): -6.3930 / -4.92983461

  0%|          | 0/782 [00:00<?, ?it/s]

[17/24] Loss_D: 0.0242 Loss_G: 3.0500 D(x): 4.3082 D(G(z)): -6.6574 / -2.96468921

  0%|          | 0/782 [00:00<?, ?it/s]

[18/24] Loss_D: 0.0902 Loss_G: 4.1883 D(x): 4.7420 D(G(z)): -2.9327 / -4.16802965

  0%|          | 0/782 [00:00<?, ?it/s]

[19/24] Loss_D: 0.0105 Loss_G: 5.6877 D(x): 6.0607 D(G(z)): -5.3270 / -5.68324124

  0%|          | 0/782 [00:00<?, ?it/s]

[20/24] Loss_D: 0.0105 Loss_G: 5.5454 D(x): 5.6174 D(G(z)): -6.0616 / -5.53959151

  0%|          | 0/782 [00:00<?, ?it/s]

[21/24] Loss_D: 2.5538 Loss_G: 1.9455 D(x): -2.2304 D(G(z)): -6.4696 / -1.636719

  0%|          | 0/782 [00:00<?, ?it/s]

[22/24] Loss_D: 0.0168 Loss_G: 5.1667 D(x): 6.3359 D(G(z)): -5.8280 / -5.15741617

  0%|          | 0/782 [00:00<?, ?it/s]

[23/24] Loss_D: 0.0249 Loss_G: 2.9029 D(x): 5.4405 D(G(z)): -8.5798 / -2.7541

  0%|          | 0/782 [00:00<?, ?it/s]

[24/24] Loss_D: 0.5430 Loss_G: 2.4967 D(x): 1.2931 D(G(z)): -1.7888 / -2.38734274