In [1]:
import torch
from torch.autograd import Variable

In [2]:
def to_variable(x):
  """Place ``x`` on the GPU when CUDA is available and wrap it in a ``Variable``.

  Args:
    x: tensor to prepare for the networks.

  Returns:
    ``Variable(x)``, built from the CUDA copy of ``x`` if a GPU is present,
    otherwise from ``x`` itself.
  """
  placed = x.cuda() if torch.cuda.is_available() else x
  return Variable(placed)

In [67]:
import torchvision.transforms as transforms
# Preprocessing for MNIST: convert each image to a tensor, then shift/scale it
# with mean 0.5 and std 0.5 so real images share the [-1, 1] range of the
# generator's tanh output.
transform = transforms.Compose([
     transforms.ToTensor(),
     # Normalize is documented to take one value per channel as a sequence;
     # MNIST is single-channel, hence the 1-tuples.
     transforms.Normalize(mean=(0.5,), std=(0.5,))
])

In [68]:
from torchvision.datasets import MNIST
# Build the training split first, then the test split; both are stored under
# './' (downloaded if missing) and share the same preprocessing pipeline.
train_dataset, test_dataset = (
    MNIST(root = './', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [69]:
import torch.nn as nn
import torch.utils.data as Data
# Iterate over the training set in shuffled mini-batches of 100 images.
data_loader = Data.DataLoader(train_dataset, batch_size=100, shuffle=True)

In [56]:
class DisCriminator(nn.Module):
  """Convolutional discriminator producing one sigmoid score per image.

  Two stride-2 5x5 convolutions (each followed by LeakyReLU and 2D dropout)
  shrink a single-channel input to a 128x7x7 feature map (28x28 inputs give
  exactly that size); a linear layer plus sigmoid then maps it to one value
  in (0, 1).
  """

  def __init__(self):
    super().__init__()
    # Layers are created in forward order so initialisation order is stable.
    self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=2, padding=2, bias=True)
    self.leaky_relu = nn.LeakyReLU()
    self.dropout_2d = nn.Dropout2d(0.3)
    self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2, bias=True)
    self.linearl = nn.Linear(128 * 7 * 7, 1, bias=True)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Score a batch of images.

    Args:
      x: image batch with one channel.

    Returns:
      Tensor of shape (batch, 1) holding sigmoid outputs.
    """
    features = self.dropout_2d(self.leaky_relu(self.conv1(x)))
    features = self.dropout_2d(self.leaky_relu(self.conv2(features)))
    # Flatten the 128x7x7 feature map before the linear classifier.
    flat = features.view(-1, 128 * 7 * 7)
    return self.sigmoid(self.linearl(flat))

In [73]:
class Generator(nn.Module):
  """Generator that maps latent vectors to single-channel 28x28 images.

  A bias-free linear layer expands the latent vector to a 256x7x7 map, a 5x5
  convolution reduces it to 128 channels, and two stride-2 transposed
  convolutions upsample 7x7 -> 14x14 -> 28x28. The final tanh keeps outputs
  in [-1, 1].

  Args:
    latent_dim: size of the input noise vector.
    batchnorm: when true, batch normalisation follows the linear layer and
      the first two convolutional layers; when false those attributes are
      ``None`` and the step is skipped.
  """

  def __init__(self , latent_dim = 100 , batchnorm = True):
    super().__init__()
    self.latent_dim = latent_dim
    self.batchnorm = batchnorm

    # Projection of the noise vector onto a 256x7x7 feature map.
    self.linearl = nn.Linear(latent_dim, 7 * 7 * 256, bias=False)
    self.bn1d1 = nn.BatchNorm1d(256 * 7 * 7) if batchnorm else None
    self.leaky_relu = nn.LeakyReLU()

    # 256 -> 128 channels at constant 7x7 resolution.
    self.conv1 = nn.Conv2d(256, 128, kernel_size=5, stride=1, padding=2, bias=False)
    self.bn2d1 = nn.BatchNorm2d(128) if batchnorm else None

    # 7x7 -> 14x14.
    self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False)
    self.bn2d2 = nn.BatchNorm2d(64) if batchnorm else None

    # 14x14 -> 28x28, single output channel.
    self.conv3 = nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False)
    self.tanh = nn.Tanh()

  def _maybe_norm(self, norm_layer, tensor):
    """Apply ``norm_layer`` to ``tensor`` only when batch norm is enabled."""
    return norm_layer(tensor) if self.batchnorm else tensor

  def forward(self , x):
    """Generate images from latent vectors.

    Args:
      x: tensor of shape (batch, latent_dim).

    Returns:
      Tensor of shape (batch, 1, 28, 28) with values in [-1, 1].
    """
    hidden = self.leaky_relu(self._maybe_norm(self.bn1d1, self.linearl(x)))
    hidden = hidden.view((-1 , 256 ,7 ,7))
    hidden = self.leaky_relu(self._maybe_norm(self.bn2d1, self.conv1(hidden)))
    hidden = self.leaky_relu(self._maybe_norm(self.bn2d2, self.conv2(hidden)))
    return self.tanh(self.conv3(hidden))

In [74]:
# Instantiate both networks with their defaults: the generator takes a
# 100-dimensional latent vector and uses batch normalisation.
DCG = Generator()
DCD = DisCriminator()

In [75]:
# Put both networks on the GPU under the same condition to_variable uses for
# the input batches, so models and data end up on the same device.
if torch.cuda.is_available():
  DCD.cuda()
  DCG.cuda()

In [76]:
from torchvision.utils import save_image

In [77]:
def denorm(x):
  """Map values from the [-1, 1] range back to [0, 1].

  Computes ``(x + 1) / 2`` and clamps the result to [0, 1], so inputs outside
  [-1, 1] saturate at the nearest bound.
  """
  return ((x + 1) / 2).clamp(0, 1)

In [78]:
# Binary cross-entropy on the discriminator's sigmoid outputs.
loss_func = nn.BCELoss()

# One Adam optimizer per network (discriminator first, then generator), both
# with the same learning rate and betas.
dcd_opt, dcg_opt = (
    torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.5, 0.999))
    for net in (DCD, DCG)
)

In [81]:
import os

# The sample images below are written into ./data1; create the folder up front
# so the first epoch does not end in a missing-directory error.
os.makedirs("./data1", exist_ok=True)

# Adversarial training: for every mini-batch, update the discriminator on a
# real and a generated batch, then update the generator so that its samples
# are scored as real.
for epoch in range(50):
  for i, (images, _) in enumerate(data_loader):
    batch_size = images.size(0)
    images = to_variable(images)

    # BCE targets: 1 for real samples, 0 for generated samples.
    real_labels = to_variable(torch.ones(batch_size,1))
    fake_labels = to_variable(torch.zeros(batch_size,1))

    # ---- Discriminator step ----
    outputs = DCD(images)
    d_loss_real = loss_func(outputs , real_labels)
    real_score = outputs

    # Noise size follows the generator's own latent dimension.
    z = to_variable(torch.randn(batch_size , DCG.latent_dim))
    # detach(): this step only updates the discriminator, so there is no need
    # to backpropagate through the generator.
    fake_images = DCG(z).detach()
    outputs = DCD(fake_images)
    d_loss_fake = loss_func(outputs , fake_labels)
    fake_score = outputs

    d_loss = d_loss_real+d_loss_fake
    DCD.zero_grad()
    d_loss.backward()
    dcd_opt.step()

    # ---- Generator step ----
    z = to_variable(torch.randn(batch_size , DCG.latent_dim))
    fake_images = DCG(z)
    outputs = DCD(fake_images)

    # The generator is rewarded when the discriminator labels its samples real.
    g_loss = loss_func(outputs , real_labels)
    DCD.zero_grad()
    DCG.zero_grad()
    g_loss.backward()
    dcg_opt.step()

    if (i+1)%300 == 0:
      print("Epoch %d, batch %d, d_loss: %.4f , g_loss: %.4f,"
      " D(x): %.2f , D(G(z)): %.2f"
      %(epoch, i+1 , d_loss.item() , g_loss.item() ,
        real_score.mean().item() , fake_score.mean().item()))

  # End of epoch: save the last real batch once, and the last generated batch
  # every epoch (both mapped back to [0, 1] by denorm).
  if epoch == 0:
    images = images.view(images.size(0) , 1 , 28 , 28)
    save_image(denorm(images) , "./data1/real_images.png")
  fake_images = fake_images.detach().view(fake_images.size(0) , 1 , 28 , 28)
  save_image(denorm(fake_images), "./data1/fake_images-%d.png"%(epoch+1))


Epoch 0, batch 300, d_loss: 1.2461 , g_loss: 0.8178, D(x): 0.58 , D(G(z)): 0.47
Epoch 0, batch 600, d_loss: 1.2770 , g_loss: 0.9189, D(x): 0.56 , D(G(z)): 0.47
Epoch 1, batch 300, d_loss: 1.3925 , g_loss: 0.9700, D(x): 0.49 , D(G(z)): 0.46
Epoch 1, batch 600, d_loss: 1.3569 , g_loss: 0.9516, D(x): 0.55 , D(G(z)): 0.49
Epoch 2, batch 300, d_loss: 1.2084 , g_loss: 0.7795, D(x): 0.54 , D(G(z)): 0.41
Epoch 2, batch 600, d_loss: 1.2238 , g_loss: 0.9888, D(x): 0.58 , D(G(z)): 0.46
Epoch 3, batch 300, d_loss: 1.2902 , g_loss: 0.9517, D(x): 0.56 , D(G(z)): 0.47
Epoch 3, batch 600, d_loss: 1.3198 , g_loss: 0.9802, D(x): 0.51 , D(G(z)): 0.43
Epoch 4, batch 300, d_loss: 1.3147 , g_loss: 0.9847, D(x): 0.54 , D(G(z)): 0.44
Epoch 4, batch 600, d_loss: 1.2468 , g_loss: 0.8836, D(x): 0.57 , D(G(z)): 0.46
Epoch 5, batch 300, d_loss: 1.2195 , g_loss: 0.8719, D(x): 0.53 , D(G(z)): 0.39
Epoch 5, batch 600, d_loss: 1.1018 , g_loss: 1.0026, D(x): 0.60 , D(G(z)): 0.41
Epoch 6, batch 300, d_loss: 1.1668 , g_l