In [1]:
%matplotlib inline
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt


In [2]:
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.deconv1 = nn.ConvTranspose2d(100, 1024, 4, stride=1)
        self.bn1 = nn.BatchNorm2d(1024)
        self.deconv2 = nn.ConvTranspose2d(1024, 512, 4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(512)
        self.deconv3 = nn.ConvTranspose2d(512, 256, 4, stride=2)
        self.bn3 = nn.BatchNorm2d(256)
        self.deconv4 = nn.ConvTranspose2d(256, 128, 4, stride=2)
        self.bn4 = nn.BatchNorm2d(128)
        self.deconv5 = nn.ConvTranspose2d(128, 1, 4, stride=2)


    def forward(self, x):
        x = F.relu(self.bn1(self.deconv1(x)))
        x = F.relu(self.bn2(self.deconv2(x)))
        x = F.relu(self.bn3(self.deconv3(x)))
        x = F.relu(self.bn4(self.deconv4(x)))
        x = F.tanh(self.deconv5(x))
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.ConvTranspose2d(1, 128, 4, stride=2)
        self.conv2 = nn.ConvTranspose2d(128, 256, 4, stride=2)
        self.bn2 = nn.BatchNorm2d(256)
        self.conv3 = nn.ConvTranspose2d(256, 512, 4, stride=2)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.ConvTranspose2d(512, 1024, 4, stride=2)
        self.bn4 = nn.BatchNorm2d(128)
        self.conv5 = nn.ConvTranspose2d(1024, 1, 4, stride=1)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.bn2(self.conv2(x)),0.2)
        x = F.leaky_relu(self.bn3(self.conv3(x)),0.2)
        x = F.leaky_relu(self.bn4(self.conv4(x)),0.2)
        x = F.sigmoid(self.conv5(x))
        return x



In [3]:
# training parameters
batch_size = 128
lr = 0.0002
train_epoch = 20

# data_loader
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

In [4]:
G = Generator()
D = Discriminator()

# Binary Cross Entropy loss
BCE_loss = nn.BCELoss()

# Adam optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lr)
D_optimizer = optim.Adam(D.parameters(), lr=lr)

zz = Variable(torch.rand((100,100, 1, 1)))


In [5]:
G_losses = []
D_losses = []
for epoch in range(100):
    for i, (x,_) in enumerate(train_loader):
        mini_batch = x.size()[0]
        y1 = torch.ones(mini_batch)
        y0 = torch.zeros(mini_batch)
        z1 = torch.randn((mini_batch, 100, 1, 1))
        z2 = torch.randn((mini_batch, 100, 1, 1))
        x,y0,y1  = Variable(x), Variable(y0), Variable(y1)
        z1, z2 = Variable(z1), Variable(z2)

        #training Generator
        G_optimizer.zero_grad()
        G_result = G(z1)
        D_prob = D(G_result)
        G_loss = BCE_loss(D_prob, y1)
        G_losses.append(G_loss.cpu().data.numpy()[0])
        G_loss.backward()
        G_optimizer.step()
        
        #training Discriminator
        D_optimizer.zero_grad()
        D_result_real = D(x)
        D_result_fake = D(G(z2))
        D_loss_real = BCE_loss(D_result_real, y1)
        D_loss_fake = BCE_loss(D_result_fake, y0)
        D_loss = D_loss_real + D_loss_fake
        D_losses.append(D_loss.cpu().data.numpy()[0])
        D_loss.backward()
        D_optimizer.step()        
        
#     visible = len(train_loader) // 3
#     if i%visible == visible - 1:
        print("epoch: {}, batch: {}".format(epoch, i))    
    
    print(sum(G_losses) / len(G_losses),sum(D_losses) / len(D_losses))
    if epoch % 10 == 0:
        noize_res = G(zz).cpu().data.view(100,28,28).numpy()/2 + 0.5
        plt.figure(figsize=(10,10))
        for j in range(10):
            for k in range(10):
                plt.subplot(10,10,j*10+k+1)
                plt.imshow(noize_res[j*10+k],cmap = 'gray')
        plt.show()    

            
        


RuntimeError: $ Torch: not enough memory: you tried to allocate 12GB. Buy new RAM! at /pytorch/torch/lib/TH/THGeneral.c:253