In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

import numpy as np
from matplotlib import pyplot as plt

In [None]:
# FashionMNIST training split.
# ToTensor maps pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to
# [-1, 1] so real images share the output range of the generator's final Tanh.
# Without it the discriminator can separate real from fake on pixel range alone.
train_sets = torchvision.datasets.FashionMNIST(
    root='./data/',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
)

In [None]:
# Mini-batch size for both the real-image loader and the noise batches in training.
batch_size = 128
# Reshuffle every epoch so the discriminator never sees batches in a fixed order.
train_loader = torch.utils.data.DataLoader(
    train_sets,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Display the backend selected in the previous cell (cuda or cpu).
device

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to single-channel 28x28 images.

    Input:  (N, z_dim, 1, 1) latent tensor, or flat (N, z_dim) noise, which is
            reshaped to (N, z_dim, 1, 1) before the first layer.
    Output: (N, 1, 28, 28) images with values in (-1, 1) (final Tanh).

    Spatial size through the transposed-conv stack: 1 -> 4 -> 7 -> 13 -> 26 -> 28.
    """

    def __init__(self, z_dim):
        super().__init__()
        # transpose convolution stack; the attribute name is kept unchanged so
        # existing state_dict keys (transpose_conv_5.*) remain loadable.
        self.transpose_conv_5 = nn.Sequential(
            # 1x1 -> 4x4
            nn.ConvTranspose2d(in_channels=z_dim, out_channels = 1024, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(1024),
            nn.ReLU(),

            # 4x4 -> 7x7
            nn.ConvTranspose2d(in_channels = 1024, out_channels = 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),

            # 7x7 -> 13x13
            nn.ConvTranspose2d(in_channels = 512, out_channels = 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),

            # 13x13 -> 26x26
            nn.ConvTranspose2d(in_channels = 256, out_channels = 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),

            # 26x26 -> 28x28, single output channel
            nn.ConvTranspose2d(in_channels = 128, out_channels = 1, kernel_size=3, stride=1, padding=0),
        )
        self.tanh = nn.Tanh()

    def forward(self, batch):
        """Generate images from latent noise; see the class docstring for shapes."""
        # Flat (N, z_dim) noise is lifted to a 1x1 spatial map so the
        # ConvTranspose2d/BatchNorm2d layers receive the 4-D input they require.
        if batch.dim() == 2:
            batch = batch[:, :, None, None]
        y = self.transpose_conv_5(batch)
        return self.tanh(y)

In [None]:
class Discrinator(nn.Module):
    """Convolutional discriminator scoring single-channel 28x28 images.

    Maps (batch, 1, 28, 28) to (batch, 1, 1, 1) probabilities in (0, 1); callers
    flatten the result with ``.view(-1)`` to obtain one score per image.
    Spatial size through the conv stack: 28 -> 12 -> 4 -> 3 -> 1.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride, padding) for each
        # Conv -> BatchNorm -> LeakyReLU stage, in order.
        hidden_stages = (
            (1, 64, 5, 2, 0),
            (64, 128, 5, 2, 0),
            (128, 256, 4, 1, 1),
        )
        layers = []
        for c_in, c_out, k, s, p in hidden_stages:
            layers.extend([
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=k,
                          stride=s, padding=p, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(),
            ])
        # Final conv collapses the 3x3 feature map to one pre-sigmoid score per image.
        layers.append(nn.Conv2d(in_channels=256, out_channels=1, kernel_size=4,
                                stride=2, padding=1, bias=False))
        self.conv = nn.Sequential(*layers)
        self.sigmod = nn.Sigmoid()

    def forward(self, batch):
        """Return sigmoid probabilities of shape (batch, 1, 1, 1)."""
        scores = self.conv(batch)
        return self.sigmod(scores)
        

In [None]:
def weights_init(m):
    """DCGAN-style initializer meant to be passed to ``nn.Module.apply``.

    Matching is by class name:
      * names containing "Conv" (Conv2d, ConvTranspose2d, ...): weight ~ N(0, 0.02);
        any conv bias is left untouched.
      * names containing "BatchNorm": weight ~ N(1, 0.02), bias = 0.
    Every other module is left as-is.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
# Seed torch's global RNG so weight init, DataLoader shuffling and noise draws
# repeat across runs (cuDNN kernels may still be nondeterministic on GPU).
SEED = 0
torch.manual_seed(SEED)

z_dim = 100  # length of the latent noise vector fed to the generator
netG = Generator(z_dim).to(device)
netD = Discrinator().to(device)

# Re-initialise conv / batch-norm weights DCGAN-style; the last line's return
# value (netD) is the cell's displayed output.
netG.apply(weights_init)
netD.apply(weights_init)

In [None]:
# Binary cross-entropy on the discriminator's sigmoid outputs, averaged over the batch.
criterion = nn.BCELoss(reduction='mean')

In [None]:
# BCE targets: 1.0 marks a real image, 0.0 a generated one.
real_label, fake_label = 1.0, 0.0

# One Adam per network; beta1 = 0.5 is the DCGAN paper's recommended setting.
lr = 0.001
optimizerD, optimizerG = (
    optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
    for net in (netD, netG)
)

In [None]:
num_epochs = 5
G_losses = []
D_losses = []
img_list = []  # per-epoch make_grid images of netG(fixed_noise), on CPU, for the "Fake Images" plot
iters = 0      # total optimisation steps taken across all epochs
# One fixed latent batch reused for every snapshot so epochs are directly comparable.
fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)

for epoch in range(num_epochs):
    for index, batch in enumerate(train_loader):
        data = batch[0].to(device)
        bs = data.size(0)  # the final batch of an epoch may be smaller than batch_size

        # ---- train D: push D(real) towards real_label and D(G(z)) towards fake_label ----
        z = torch.randn(bs, z_dim, 1, 1, device=device)
        f_imgs = netG(z)

        r_label = torch.full((bs,), real_label, device=device)
        f_label = torch.full((bs,), fake_label, device=device)

        # netD ends in a Sigmoid, so these are probabilities (what BCELoss expects).
        r_logit = netD(data).view(-1)
        # detach(): the D update must not backpropagate through the generator.
        f_logit = netD(f_imgs.detach()).view(-1)

        # compute loss
        r_loss = criterion(r_logit, r_label)
        f_loss = criterion(f_logit, f_label)
        loss_D = (r_loss + f_loss) / 2

        # update model
        netD.zero_grad()
        loss_D.backward()
        optimizerD.step()

        # ---- train G: fresh fakes, scored by the just-updated D against real_label ----
        z = torch.randn(bs, z_dim, 1, 1, device=device)
        f_imgs = netG(z)
        f_logit = netD(f_imgs).view(-1)

        # compute loss
        loss_G = criterion(f_logit, r_label)

        # update model (grads that leak into netD here are cleared by the next netD.zero_grad())
        netG.zero_grad()
        loss_G.backward()
        optimizerG.step()

        G_losses.append(loss_G.item())
        D_losses.append(loss_D.item())
        iters += 1

        if index % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
                  % (epoch, num_epochs, index, len(train_loader),
                     loss_D.item(), loss_G.item()))

    # End-of-epoch snapshot on the fixed noise; inference only, so no autograd graph.
    with torch.no_grad():
        fake = netG(fixed_noise).cpu()
    img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

In [None]:

# Per-iteration losses recorded by the training loop, G and D on shared axes.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
for series, name in ((G_losses, "G"), (D_losses, "D")):
    ax.plot(series, label=name)
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# Peek at one loader batch to confirm the (batch, channels, H, W) layout fed to netD.
# Fetch it here so the cell also works on a fresh kernel.
real_batch = next(iter(train_loader))
real_batch[0].shape

In [None]:
real_batch = next(iter(train_loader))

# make_grid(normalize=True) rescales to [0, 1] for display; it runs fine on the
# CPU batch, so there is no need to move the images to the GPU and back.
real_grid = vutils.make_grid(real_batch[0][:64], padding=5, normalize=True)

# Use the latest training snapshot when one was recorded; otherwise sample a
# fresh batch from netG so this cell does not fail on an empty img_list.
if img_list:
    fake_grid = img_list[-1]
else:
    with torch.no_grad():
        fake = netG(torch.randn(64, z_dim, 1, 1, device=device)).cpu()
    fake_grid = vutils.make_grid(fake, padding=5, normalize=True)

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(real_grid, (1, 2, 0)))

# Plot the fake images
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(fake_grid, (1, 2, 0)))
plt.show()

In [None]:
# Latent batch for ad-hoc sampling, sized from z_dim so it tracks the generator config.
x = torch.randn(128, z_dim, 1, 1, device=device)


In [None]:
# Inference only: skip autograd bookkeeping. squeeze(1) drops just the channel
# axis -> (N, 28, 28); a bare squeeze() would also drop N when N == 1.
with torch.no_grad():
    imgs = netG(x).squeeze(1)

In [None]:
# Sample a fresh grid of 128 images from the trained generator.
x = torch.randn(128, z_dim, 1, 1, device=device)
with torch.no_grad():  # inference only, no autograd graph needed
    imgs = netG(x)
grid = vutils.make_grid(imgs.cpu(), padding=5, normalize=True)

plt.figure(figsize=(20,20))
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.show()