In [1]:
import torch
from torch import nn
from torchvision import transforms as tfs
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt
import random
%matplotlib inline

In [2]:
def process_img(x):
    x = tfs.ToTensor()(x)
    x = (x - 0.5) / 0.5
    return x

In [3]:
train_set = MNIST('./mnist', train = True, download = True, transform = process_img)

In [4]:
x = np.random.randint(1,60000,[1, 200])

In [5]:
train_labeled = []

In [6]:
for i in x[0]:
    img = train_set[i]
    img = list(img)
    train_labeled.append(img)

In [7]:
NOISE_DIM = 96
class generator(nn.Module):
    def __init__(self, noise_dim=NOISE_DIM):
        super(generator, self).__init__()
        self.fc = nn.Sequential(
                    nn.Linear(noise_dim, 1024),
                    nn.ReLU(True),
                    nn.Linear(1024, 1024),
                    nn.ReLU(True),
                    nn.Linear(1024, 7 * 7 * 128),
                    nn.ReLU(True))
        
        self.conv = nn.Sequential(
                    nn.ConvTranspose2d(128, 64, 4, 2, padding = 1),
                    nn.ReLU(True),
                    nn.BatchNorm2d(64),
                    nn.ConvTranspose2d(64, 1, 4, 2, padding = 1),
                    nn.Tanh())
        
    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.shape[0], 128, 7, 7)
        x = self.conv(x)
        return x

In [8]:
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2)
        )
        
        self.fc = nn.Sequential(
                    nn.Linear(1024, 800),
                    nn.ReLU(True),
                    nn.Linear(800, 200),
                    nn.ReLU(True),
                    nn.Linear(200, 11))
        
    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

In [9]:
bce_loss = nn.BCEWithLogitsLoss()

def generator_loss(logits_fake):
    size = logits_fake.shape[0]
    true_labels = torch.ones(size, 1).float().cuda()
    loss = bce_loss(logits_fake[:,10].view(20, 1), true_labels)
    return loss

In [10]:
def discriminator_loss(labeled, real_label, logits_real, logits_fake):
    size = logits_real.shape[0]
    true_labels = torch.ones(size, 1)
    true_labels = true_labels.float().cuda()
    false_labels = torch.zeros(size, 1)
    false_labels = false_labels.float().cuda()
    loss1 = nn.CrossEntropyLoss()
    loss11 = loss1(labeled[:,0:10], real_label)
    loss2 = bce_loss(logits_real[:,10].view(20, 1), true_labels) + bce_loss(logits_fake[:,10].view(20, 1), false_labels)
    return loss11 + loss2

In [11]:
def discriminator_loss1(logits_real, logits_fake, label):
    size = logits_real.shape[0]
    true_labels = torch.ones(size, 1)
    ture_labels = true_labels.float()
    false_labels = torch.zeros(size, 1)
    false_labels = false_labels.float()
    loss1 = nn.CrossEntropyLoss()
    loss11 = loss1(logits_real[:, 0:10], label)
    loss2 = bce_loss(logits_real[:,10].view(20, 1), true_labels) + bce_loss(logits_fake[:,10].view(20, 1), false_labels)
    return loss11 + loss2

In [12]:
train_labeled_data = DataLoader(train_labeled, batch_size = 20,shuffle = True)

In [13]:
train_unlabeled_data = DataLoader(train_set, batch_size = 20, shuffle = True)

In [14]:
test_set = MNIST('./mnist', train = False, transform = process_img)
test_data = DataLoader(test_set, batch_size = 64, shuffle = False)

In [15]:
G_net = generator().cuda()
D_net = discriminator().cuda()
generator_optim = torch.optim.Adam(G_net.parameters(), lr = 1e-4, betas = (0.5, 0.999))
discriminator_optim = torch.optim.Adam(D_net.parameters(), lr = 1e-4, betas = (0.5, 0.999))

In [16]:
def train_gan(discriminator, generator, discriminator_loss, generator_loss, discriminator_optim, generator_optim, 
              noise_size=96, num_epochs=10):
    iter_count = 0
    for epoch in range(num_epochs):
        for im, label in train_labeled_data:
            for x, _ in train_unlabeled_data:
                bs = x.shape[0]
                labeled = 0
                real_label = 0
            
                real_data = x.cuda()
                logits_real = discriminator(real_data)
            
                sample_noise = (torch.rand(bs, noise_size) - 0.5)/ 0.5
                g_fake_seed = sample_noise.cuda()
                fake_image = generator(g_fake_seed)
                logits_fake = discriminator(fake_image)
                        
                im = im.cuda()
                labeled = discriminator(im)
                d_error = discriminator_loss(labeled, label.cuda(), logits_real, logits_fake)
                discriminator_optim.zero_grad()
                d_error.backward(retain_graph = True)
                discriminator_optim.step()
            
                g_fake_seed = sample_noise.cuda()
                fake_image = generator(g_fake_seed)
                g_logits_fake = discriminator(fake_image)
                g_error = generator_loss(g_logits_fake)
                generator_optim.zero_grad()
                g_error.backward()
                generator_optim.step()
        iter_count += 1
            
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_error.data.item(), g_error.data.item()))
        

In [17]:
train_gan(D_net, G_net, discriminator_loss, generator_loss, discriminator_optim, generator_optim, num_epochs = 10)

Iter: 1, D: 0.06529, G:4.046
Iter: 2, D: 0.1085, G:2.572
Iter: 3, D: 0.4235, G:2.245
Iter: 4, D: 0.7058, G:2.393
Iter: 5, D: 0.2183, G:2.189
Iter: 6, D: 0.6098, G:2.532
Iter: 7, D: 0.2188, G:2.145
Iter: 8, D: 0.3296, G:2.394
Iter: 9, D: 0.4154, G:2.531
Iter: 10, D: 0.3921, G:2.732


In [18]:
for e in range(3):
    eval_acc = 0
    for im, label in test_data:
        im = im.cuda()
        label = label.cuda()
        out = D_net(im)
        _,pred = out[:,0:10].max(1)
        num_correct = (pred == label).sum().data.item()
        acc = num_correct / im.shape[0]
        eval_acc += acc
    print('acc:{:.6f}'.format(eval_acc/len(test_data)))

acc:0.817775
acc:0.817775
acc:0.817775


In [19]:
torch.save(discriminator, './discriminator.pth')

  "type " + obj.__name__ + ". It won't be checked "


In [20]:
torch.save(generator, './generator.pth')

  "type " + obj.__name__ + ". It won't be checked "


In [21]:
generator1 = torch.load('./generator.pth')

In [23]:
torch.save(G_net,'./generator.pth')

In [24]:
torch.save(D_net, './discriminator.pth')

In [25]:
torch.load('generator.pth')

generator(
  (fc): Sequential(
    (0): Linear(in_features=96, out_features=1024, bias=True)
    (1): ReLU(inplace)
    (2): Linear(in_features=1024, out_features=1024, bias=True)
    (3): ReLU(inplace)
    (4): Linear(in_features=1024, out_features=6272, bias=True)
    (5): ReLU(inplace)
  )
  (conv): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): Tanh()
  )
)