In [1]:
import torch
from torch import nn
from torchvision import transforms as tfs
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt
import random
%matplotlib inline

In [2]:
def process_img(x):
    x = tfs.ToTensor()(x)
    x = (x - 0.5) / 0.5
    return x

In [3]:
train_set = MNIST('./mnist', train = True, download = True, transform = process_img)

In [4]:
x = np.random.randint(1, 10000, [1,20])

In [5]:
train_labeled = []

In [6]:
for i in x[0]:
    img = train_set[i]
    img = list(img)
    train_labeled.append(img)

In [7]:
NOISE_DIM = 96
class generator(nn.Module):
    def __init__(self, noise_dim=NOISE_DIM):
        super(generator, self).__init__()
        self.fc = nn.Sequential(
                    nn.Linear(noise_dim, 1024),
                    nn.ReLU(True),
                    nn.Linear(1024, 1024),
                    nn.ReLU(True),
                    nn.Linear(1024, 7 * 7 * 128),
                    nn.ReLU(True))
        
        self.conv = nn.Sequential(
                    nn.ConvTranspose2d(128, 64, 4, 2, padding = 1),
                    nn.ReLU(True),
                    nn.BatchNorm2d(64),
                    nn.ConvTranspose2d(64, 1, 4, 2, padding = 1),
                    nn.Tanh())
        
    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.shape[0], 128, 7, 7)
        x = self.conv(x)
        return x

In [8]:
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2)
        )
        
        self.fc = nn.Sequential(
                    nn.Linear(1024, 800),
                    nn.ReLU(True),
                    nn.Linear(800, 200),
                    nn.ReLU(True),
                    nn.Linear(200, 11))
        
    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

In [9]:
bce_loss = nn.BCEWithLogitsLoss()

def generator_loss(logits_fake):
    size = logits_fake.shape[0]
    true_labels = torch.ones(size, 1).float().cuda()
    loss = bce_loss(logits_fake[:,10].view(20, 1), true_labels)
    return loss

In [10]:
def discriminator_loss(labeled, real_label, logits_real, logits_fake):
    size = logits_real.shape[0]
    true_labels = torch.ones(size, 1)
    true_labels = true_labels.float().cuda()
    false_labels = torch.zeros(size, 1)
    false_labels = false_labels.float().cuda()
    loss1 = nn.CrossEntropyLoss()
    loss11 = loss1(labeled[:,0:10], real_label)
    loss2 = bce_loss(logits_real[:,10].view(20, 1), true_labels) + bce_loss(logits_fake[:,10].view(20, 1), false_labels)
    return loss11 + loss2

In [11]:
train_labeled_data = DataLoader(train_labeled, batch_size = 20)

In [12]:
for img, label in train_labeled_data:
    print(img.shape)
    print(label.shape)

torch.Size([20, 1, 28, 28])
torch.Size([20])


In [13]:
train_unlabeled_data = DataLoader(train_set, batch_size = 20, shuffle = True)

In [14]:
test_set = MNIST('./mnist', train = False, transform = process_img)
test_data = DataLoader(test_set, batch_size = 64, shuffle = False)

In [15]:
G_net = generator().cuda()
D_net = discriminator().cuda()
generator_optim = torch.optim.Adam(G_net.parameters(), lr = 1e-4, betas = (0.5, 0.999))
discriminator_optim = torch.optim.Adam(D_net.parameters(), lr = 1e-4, betas = (0.5, 0.999))

In [16]:
def train_gan(discriminator, generator, discriminator_loss, generator_loss, discriminator_optim, generator_optim, 
              noise_size=96, num_epochs=10):
    iter_count = 0
    for epoch in range(num_epochs):
        for x,_ in train_unlabeled_data:
            bs = x.shape[0]
            labeled = 0
            real_label = 0
            
            real_data = x.cuda()
            logits_real = discriminator(real_data)
            
            sample_noise = (torch.rand(bs, noise_size) - 0.5)/ 0.5
            g_fake_seed = sample_noise.cuda()
            fake_image = generator(g_fake_seed)
            logits_fake = discriminator(fake_image)
            
            for img, label in train_labeled_data:
                img = img.cuda()
                label = label.cuda()
                labeled = discriminator(img)
                real_label = label
            d_error = discriminator_loss(labeled, real_label, logits_real, logits_fake)
            discriminator_optim.zero_grad()
            d_error.backward(retain_graph = True)
            discriminator_optim.step()
            
            g_fake_seed = sample_noise.cuda()
            fake_image = generator(g_fake_seed)
            g_logits_fake = discriminator(fake_image)
            g_error = generator_loss(g_logits_fake)
            generator_optim.zero_grad()
            g_error.backward()
            generator_optim.step()
            
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_error.data.item(), g_error.data.item()))
    

In [17]:
train_gan(D_net, G_net, discriminator_loss, generator_loss, discriminator_optim, generator_optim, num_epochs = 10)

Iter: 0, D: 0.1126, G:3.866
Iter: 0, D: 0.04591, G:4.026
Iter: 0, D: 0.2255, G:4.943
Iter: 0, D: 0.2923, G:3.113
Iter: 0, D: 0.001428, G:17.75
Iter: 0, D: 0.3236, G:4.409
Iter: 0, D: 0.03789, G:9.205
Iter: 0, D: 0.008513, G:4.872
Iter: 0, D: 0.00649, G:5.669
Iter: 0, D: 0.01966, G:5.766


In [18]:
for e in range(3):
    eval_acc = 0
    for im, label in test_data:
        im = im.cuda()
        label = label.cuda()
        out = D_net(im)
        _,pred = out[:,0:10].max(1)
        num_correct = (pred == label).sum().data.item()
        acc = num_correct / im.shape[0]
        eval_acc += acc
    print('acc:{:.6f}'.format(eval_acc/len(test_data)))

acc:0.520999
acc:0.520999
acc:0.520999
