In [1]:
import torch
import torch.nn as nn
import torchvision
import os
import matplotlib.pyplot as plt
import numpy as np
from math import ceil
import time

In [2]:
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# --- Paths -------------------------------------------------------------------
dataroot = '../storage/data/dogscats/train/'                  # ImageFolder root with the training images
saveroot = '../storage/data/AS_storage/generatedImages/'      # where generated sample grids are written

# --- Data loading ------------------------------------------------------------
batch_size = 128
workers = 2          # DataLoader worker processes
image_size = 64      # images are resized / center-cropped to image_size x image_size
nc = 3               # image channels

# --- Model sizes -------------------------------------------------------------
nz = 100             # latent vector length fed to the generator
ngf = 64             # base feature-map width of the generator
ndf = 64             # base feature-map width of the discriminator

# --- Optimisation ------------------------------------------------------------
num_epochs = 150
num_steps = 150      # maximum number of batches consumed per epoch
lr = 0.0002
beta1 = 0.5          # first Adam beta

# Resize + center-crop to a square, then scale pixels from [0, 1] to [-1, 1]
# (mean 0.5, std 0.5 per channel), matching the generator's Tanh output range.
_half = (0.5,) * nc
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(image_size),
    torchvision.transforms.CenterCrop(image_size),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(_half, _half),
])

In [4]:
# One sub-folder per class under dataroot; every image goes through `transforms`.
ds = torchvision.datasets.ImageFolder(root=dataroot, transform=transforms)
# Shuffled mini-batches, loaded by `workers` background processes.
dl = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

FileNotFoundError: [Errno 2] No such file or directory: '../storage/data/dogscats/train/'

In [None]:
# Number of samples in the dataset loaded above.
len(ds)

In [None]:
# Pull a single batch from the loader as a sanity check on shapes and content.
img, labels = next(iter(dl))
print('Image:', img.shape)
print('Labels:', labels.shape)

# Show the second image of the batch; normalize=True rescales it into [0, 1]
# for display, and permute moves channels last as imshow expects.
grid = torchvision.utils.make_grid(img[1], normalize=True)
plt.imshow(grid.permute(1, 2, 0))
print(labels[1])

In [None]:
class Generator(nn.Module):
    """DCGAN generator: upsamples a latent vector into an image.

    With the (N, nz, 1, 1) noise used in this notebook, the transposed
    convolutions grow the spatial size 1 -> 4 -> 8 -> 16 -> 32 -> 64 and the
    final Tanh bounds the (N, nc, 64, 64) output to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride, padding) for each
        # ConvTranspose2d -> BatchNorm2d -> ReLU stage; kernel size is always 4.
        stages = [
            (nz,      ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf,     2, 1),
        ]

        layers = []
        for in_ch, out_ch, stride, pad in stages:
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, pad, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        # Output stage: project to nc image channels, no batch norm.
        layers += [
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run `input` through the layer stack and return the generated batch."""
        generated = self.main(input)
        return generated

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: scores images as real (towards 1) or fake (towards 0).

    For 64x64 inputs the strided convolutions shrink the spatial size
    64 -> 32 -> 16 -> 8 -> 4 -> 1; a Sigmoid turns the final single-channel
    activation into a probability-like score.
    """

    def __init__(self):
        super().__init__()

        # Input stage: no batch norm on the first convolution.
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three downsampling stages, each doubling the channel count
        # (ndf -> 2*ndf -> 4*ndf -> 8*ndf).
        width = ndf
        for _ in range(3):
            layers += [
                nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            width *= 2

        # Output stage: collapse to one channel and squash to (0, 1).
        layers += [
            nn.Conv2d(width, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the scores for `input`, reshaped to one column per row."""
        scores = self.main(input)
        return scores.view(-1, 1)

In [None]:
def weights_init(m):
    """Re-initialise one module in place; intended for use with `Module.apply`.

    Modules whose class name contains 'Conv' get weights drawn from N(0, 0.02).
    Modules whose class name contains 'BatchNorm' get weights drawn from
    N(1, 0.02) and a zero bias. Any other module is left untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
# Build both networks on the selected device first, then re-initialise them
# (generator before discriminator) with the custom weight initialiser.
generator, discriminator = Generator().to(device), Discriminator().to(device)

generator.apply(weights_init)
discriminator.apply(weights_init)

## LOSS FUNCTION AND OPTIMIZERS

In [None]:
# Binary cross-entropy between discriminator scores and 0/1 targets.
criterion = nn.BCELoss()

# One Adam optimizer per network, sharing the learning rate and betas.
adam_betas = (beta1, 0.999)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)
opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)

In [None]:
# 36 latent vectors that stay constant, so saved sample grids can be compared
# from one epoch to the next.
fixed_noise = torch.randn(36, nz, 1, 1).to(device)

# Target columns for the BCE loss: 1 for "real", 0 for "fake".
label_shape = (batch_size, 1)
real_labels = torch.ones(label_shape).to(device)
fake_labels = torch.zeros(label_shape).to(device)
print(real_labels.dtype, fake_labels.shape)

## PRE-TRAIN TEST

In [None]:
# Smoke test of the untrained networks: check the tensor shapes end to end and
# that the loss can be evaluated against the "real" targets.
pretrain_noise = torch.randn(batch_size, nz, 1, 1).to(device)
sample_img = generator(pretrain_noise).detach()
print(sample_img.shape)

sample_out = discriminator(sample_img).detach()
print(sample_out.shape)

loss = criterion(sample_out, real_labels)
print(loss)

In [None]:
def _save_sample_grid(filename):
    """Render `fixed_noise` through the generator and save the grid under `saveroot`.

    Returns the generated batch so the caller can keep it around. The
    generator's train/eval mode is not changed here.
    """
    # No gradients are needed for visualisation.
    with torch.no_grad():
        images = generator(fixed_noise).detach()
    grid = torchvision.utils.make_grid(
        images.view(-1, nc, image_size, image_size), nrow=6, pad_value=1, normalize=True
    )
    torchvision.utils.save_image(grid.detach().cpu(), os.path.join(saveroot, filename))
    return images


start_time = time.time()

# Per-epoch mean losses, stored as plain Python floats so they can be plotted
# directly and do not keep autograd graphs (or GPU memory) alive.
loss_d = []
loss_g = []

# Make sure the output directory exists before the first image is written.
os.makedirs(saveroot, exist_ok=True)

for epoch in range(num_epochs):
    loss_d_ = 0.0
    loss_g_ = 0.0
    steps_run = 0
    for i, (images, _) in enumerate(dl):

        # Cap the number of batches consumed per epoch.
        if i == num_steps:
            break

        real_images = images.to(device)
        # The last batch of the loader can be smaller than batch_size; size
        # the noise and the targets from the batch actually received so the
        # BCE input/target shapes always match.
        cur_batch = real_images.size(0)
        real_targets = real_labels[:cur_batch]
        fake_targets = fake_labels[:cur_batch]

        # DISCRIMINATOR: push real images towards 1 and generated ones towards 0.
        opt_d.zero_grad()

        # detach() keeps the discriminator update from back-propagating into
        # the generator.
        fake_images = generator(torch.randn(cur_batch, nz, 1, 1).to(device)).detach()

        real_outputs = discriminator(real_images)
        fake_outputs = discriminator(fake_images)

        real_loss = criterion(real_outputs, real_targets)
        fake_loss = criterion(fake_outputs, fake_targets)
        real_loss.backward()
        fake_loss.backward()

        opt_d.step()

        # .item() extracts a Python float, dropping the autograd history.
        loss_d_ += real_loss.item() + fake_loss.item()

        # GENERATOR: make the discriminator score fresh fakes as real.
        opt_g.zero_grad()

        outputs = discriminator(generator(torch.randn(cur_batch, nz, 1, 1).to(device)))

        loss = criterion(outputs, real_targets)
        loss.backward()

        opt_g.step()

        loss_g_ += loss.item()
        steps_run += 1

    # Average over the batches that actually ran (the loader may hold fewer
    # than num_steps); max() guards against an empty loader.
    steps_done = max(steps_run, 1)
    epoch_loss_g = loss_g_ / steps_done
    epoch_loss_d = loss_d_ / steps_done
    loss_g.append(epoch_loss_g)
    loss_d.append(epoch_loss_d)
    print(f'{epoch}/{num_epochs} | Generator_Loss: {epoch_loss_g:.8f} | Discriminator_Loss: {epoch_loss_d:.8f} | Time Elapsed: {time.time() - start_time:.0f} seconds')

    # Saves a sample grid every epoch; raise the modulus to save less often.
    if epoch % 1 == 0:
        sample = _save_sample_grid('DOGSCATS_DCGAN_{}.jpg'.format(str(epoch).zfill(3)))

print(f'\nTOTAL DURATION: {time.time() - start_time:.0f} seconds')
# Final grid, numbered after the last epoch so it cannot overwrite one of the
# per-epoch images (which are numbered 000 .. num_epochs - 1).
sample = _save_sample_grid('DOGSCATS_DCGAN_{}.jpg'.format(str(num_epochs).zfill(3)))

In [None]:
# float() accepts plain numbers as well as single-element tensors (on any
# device, with or without autograd history), so the curves can be drawn
# regardless of how the per-epoch losses were stored.
loss_d_values = [float(value) for value in loss_d]
loss_g_values = [float(value) for value in loss_g]

plt.figure()

plt.plot(loss_d_values, label='LOSS D')
plt.plot(loss_g_values, label='LOSS G')
plt.xlabel('Epoch')
plt.ylabel('Mean BCE loss')
plt.title('DCGAN training losses')
plt.legend()
plt.show()

In [None]:
# Persist only the learned parameters (state_dict) of each network.
checkpoints = (
    (generator, './catsdogs_dcgan_generator.pt'),
    (discriminator, './catsdogs_dcgan_discriminator.pt'),
)
for network, checkpoint_path in checkpoints:
    torch.save(network.state_dict(), checkpoint_path)

In [None]:
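# # LOAD MODEL (uncomment to restore the weights written by the save cell above)
# discriminator.load_state_dict(torch.load('./catsdogs_dcgan_discriminator.pt', map_location=device))
# generator.load_state_dict(torch.load('./catsdogs_dcgan_generator.pt', map_location=device))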

In [None]:
# Generate a fresh batch; no gradients are needed for visualisation.
with torch.no_grad():
    sample_img = generator(torch.randn(batch_size, nz, 1, 1).to(device)).detach()

# Display one image from the batch generated in this cell (index 10), rather
# than a leftover variable from an earlier cell.
grid = torchvision.utils.make_grid(sample_img.view(-1, 3, 64, 64)[10], nrow=1, pad_value=1, normalize=True)
plt.imshow(grid.cpu().permute(1, 2, 0))