## Imports

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.utils as vutils

## Data loaders

### SVHN Dataset

In [None]:
class SvhnDataset(Dataset):
    def __init__(self, image_size, split):
        self.split = split
        self.use_gpu = True if torch.cuda.is_available() else False

        self.svhn_dataset = self._create_dataset(image_size, split)
        self.label_mask = self._create_label_mask()

    def _create_dataset(self, image_size, split):
        normalize = transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5])
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            normalize])
        return datasets.SVHN(root='./svhn', download=True, transform=transform, split=split)

    def _is_train_dataset(self):
        return True if self.split == 'train' else False

    def _create_label_mask(self):
        if self._is_train_dataset():
            label_mask = torch.zeros(len(self.svhn_dataset)).float()
            label_mask[0:1000] = 1
            return label_mask
        return None

    def __len__(self):
        return len(self.svhn_dataset)

    def __getitem__(self, idx):
        data, label = self.svhn_dataset.__getitem__(idx)
        if self._is_train_dataset():
            return data, label, self.label_mask[idx]
        return data, label

### Get data loaders

In [None]:
def get_loader(image_size, batch_size):
    num_workers = 1

    svhn_train = SvhnDataset(image_size=image_size, split='train')
    svhn_test = SvhnDataset(image_size=image_size, split='test')

    svhn_loader_train = DataLoader(
        dataset=svhn_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    svhn_loader_test = DataLoader(
        dataset=svhn_test,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    return svhn_loader_train, svhn_loader_test

In [None]:
svhn_loader_train, _ = get_loader(image_size=32, batch_size=36)
image_iter = iter(svhn_loader_train)
images, _, _ = image_iter.next()

In [None]:
def view_images(images):
    assert(len(images) >= 36)
    fig, axes = plt.subplots(6, 6, sharex=True, sharey=True, figsize=(5,5))
    for idx, ax in enumerate(axes.flatten()):
        img = images[idx].numpy()
        img = img.transpose(1, 2, 0)
        img = ((img - img.min())*255 / (img.max() - img.min())).astype(np.uint8)
        ax.imshow(img, aspect='equal')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
    plt.subplots_adjust(wspace=0, hspace=0)

In [None]:
view_images(images)

## Model

### Conv, deconv helpers

In [None]:
def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    layers = []
    layers.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, bias=False))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)

def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    layers = []
    layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=False))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)

### GAN logits

In [None]:
class _ganLogits(nn.Module):
    def __init__(self):
        super(_ganLogits, self).__init__()
    
    def forward(self, class_logits):
        max_val, _ = torch.max(class_logits, 1, keepdim=True)
        stable_class_logits = class_logits - max_val
        max_val = torch.squeeze(max_val)
        gan_logits = torch.log(torch.sum(torch.exp(stable_class_logits), 1)) + max_val
        
        return gan_logits

### Generator

In [None]:
class _netG(nn.Module):
    '''
    GAN generator
    '''
    def __init__(self, nz, ngf, alpha, nc, use_gpu):
        super(_netG, self).__init__()
        self.use_gpu = use_gpu
        
        self.main = nn.Sequential(
            # noise is going into a convolution
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.LeakyReLU(alpha),
            # (ngf * 4) x 4 x 4
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.LeakyReLU(alpha),
            # (ngf * 2) x 8 x 8
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.LeakyReLU(alpha),
            # (ngf) x 16 x 16
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # (nc) x 32 x 32
        )

    def forward(self, inputs):
        if isinstance(inputs.data, torch.cuda.FloatTensor) and self.use_gpu:
            out = nn.parallel.data_parallel(self.main, inputs, range(1))
        else:
            out = self.main(inputs)
        return out

### Discriminator

In [None]:
class _netD(nn.Module):
    '''
    GAN discruminator
    '''
    def __init__(self, ndf, alpha, nc, drop_rate, num_classes, use_gpu):
        super(_netD, self).__init__()
        self.use_gpu = use_gpu
        
        self.main = nn.Sequential(
            nn.Dropout2d(drop_rate/2.5),
            
            # input is (number_channels) x 32 x 32
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(alpha),
            nn.Dropout2d(drop_rate),
            # (ndf) x 16 x 16
            nn.Conv2d(ndf, ndf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf),
            nn.LeakyReLU(alpha),
            # (ndf) x 8 x 8
            nn.Conv2d(ndf, ndf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf),
            nn.LeakyReLU(alpha),
            nn.Dropout2d(drop_rate),
            # (ndf) x 4 x 4
            nn.Conv2d(ndf, ndf * 2, 3, 1, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(alpha),
            # (ndf * 2) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 2, 3, 1, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(alpha),
            # (ndf * 2) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 2, 3, 1, 0, bias=False),
            nn.LeakyReLU(alpha),
            # (ndf * 2) x 2 x 2
        )
        
        self.features = nn.AvgPool2d(kernel_size=2)

        self.class_logits = nn.Linear(
            in_features=(ndf * 2) * 1 * 1,
            out_features=num_classes)
        
        self.gan_logits = _ganLogits()
        
        self.softmax = nn.Softmax(dim=0)

    def forward(self, inputs):
        if isinstance(inputs.data, torch.cuda.FloatTensor) and self.use_gpu:
            out = nn.parallel.data_parallel(self.main, inputs, range(1))
        else:
            out = self.main(inputs)

        features = self.features(out)
        features = features.squeeze()

        class_logits = self.class_logits(features)

        gan_logits = self.gan_logits(class_logits)
        
        out = self.softmax(class_logits)

        return out, class_logits, gan_logits, features

## Solver

In [None]:
import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
import numpy as np
import os

class Solver:
    def __init__(self, svhn_loader_train, svhn_loader_test, batch_size):
        self.nz = 100
        self.nc = 3
        self.alpha = 0.2
        self.drop_rate = .5
        self.ngf = 32
        self.ndf = 64
        self.num_classes = 10
        self.use_gpu = True if torch.cuda.is_available() else False
        self.learning_rate = 0.0002
        self.beta1 = .5
        self.svhn_loader_train = svhn_loader_train
        self.svhn_loader_test = svhn_loader_test
        self.epochs = 25
        self.batch_size = batch_size
        self.out_dir = './train_out'

        self.netG, self.netD = self._build_model()
        self.g_optimizer, self.d_optimizer = self._create_optimizers()

    def _build_model(self):
        netG = _netG(
            self.nz, self.ngf, self.alpha,
            self.nc, self.use_gpu)
        netG.apply(self._weights_init)
        print(netG)
        # TODO: load weights from file if it exists

        netD = _netD(
            self.ndf, self.alpha, self.nc,
            self.drop_rate, self.num_classes, self.use_gpu)
        netD.apply(self._weights_init)
        print(netD)
        # TODO: load weights from file if it exists

        if self.use_gpu:
            netG = netG.cuda()
            netD = netD.cuda()

        return netG, netD

    def _weights_init(self, module):
        '''
        Custom weights initialization called on netG and netD
        '''
        classname = module.__class__.__name__
        if classname.find('Conv') != -1:
            module.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            module.weight.data.normal_(1.0, 0.02)
            module.bias.data.fill_(0)

    def _create_optimizers(self):
        g_params = list(self.netG.parameters())
        d_params = list(self.netD.parameters())

        g_optimizer = optim.Adam(g_params, self.learning_rate, betas=(self.beta1, 0.999))
        d_optimizer = optim.Adam(d_params, self.learning_rate, betas=(self.beta1, 0.999))

        return g_optimizer, d_optimizer

    def _to_var(self, x):
        if self.use_gpu:
            x = x.cuda()
        return Variable(x)

    def _one_hot(self, x):
        label_numpy = x.data.cpu().numpy()
        label_onehot = np.zeros((label_numpy.shape[0], self.num_classes))
        label_onehot[np.arange(label_numpy.shape[0]), label_numpy] = 1
        label_onehot = self._to_var(torch.FloatTensor(label_onehot))
        return label_onehot
    
    def _reset_grad(self):
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def train(self):
        if not os.path.exists(self.out_dir):
            os.makedirs(self.out_dir)
        
        d_gan_criterion = nn.BCEWithLogitsLoss()
        d_gan_class_criterion = nn.BCEWithLogitsLoss()
        
        noise = torch.FloatTensor(self.batch_size, self.nz, 1, 1)
        
        fixed_noise = torch.FloatTensor(self.batch_size, self.nz, 1, 1).normal_(0, 1)
        fixed_noise = self._to_var(fixed_noise)
        
        d_gan_labels_real = torch.LongTensor(batch_size)
        d_gan_labels_fake = torch.LongTensor(batch_size)

        for epoch in range(1, self.epochs + 1):
            masked_correct = 0
            num_samples = 0
            
            for i, data in enumerate(self.svhn_loader_train):
                # load svhn dataset
                svhn_data, svhn_labels, label_mask = data
                svhn_data = self._to_var(svhn_data)
                svhn_labels = self._to_var(svhn_labels).long().squeeze()
                label_mask = self._to_var(label_mask).float().squeeze()

                # -------------- train netD --------------

                self._reset_grad()

                # train with real images
                # d_out == softmax(d_class_logits)
                d_out, d_class_logits_on_data, d_gan_logits_real, d_sample_features = self.netD(svhn_data)
                d_gan_labels_real.resize_as_(svhn_labels.data.cpu()).fill_(1)
                d_gan_labels_real_var = self._to_var(d_gan_labels_real).float()
                d_gan_loss_real = d_gan_criterion(
                    d_gan_logits_real,
                    d_gan_labels_real_var)
                
                # train with fake images
                noise.resize_(batch_size, self.nz, 1, 1).normal_(0, 1)
                noise_var = self._to_var(noise)
                fake = self.netG(noise_var)

                # call detach() to avoid backprop for netG here
                _, _, d_gan_logits_fake, _ = self.netD(fake.detach())
                d_gan_labels_fake.resize_(batch_size).fill_(0)
                d_gan_labels_fake_var = self._to_var(d_gan_labels_fake).float()
                d_gan_loss_fake = d_gan_criterion(
                    d_gan_logits_fake,
                    d_gan_labels_fake_var)

                d_gan_loss = d_gan_loss_real + d_gan_loss_fake

                # d_out == softmax(d_class_logits)
                # see https://stackoverflow.com/questions/34240703/whats-the-difference-between-softmax-and-softmax-cross-entropy-with-logits/39499486#39499486
                svhn_labels_one_hot = self._one_hot(svhn_labels)
                d_class_loss_entropy = -torch.sum(svhn_labels_one_hot * torch.log(d_out), dim=1)
                
                d_class_loss_entropy = d_class_loss_entropy.squeeze()
                # delim = torch.max(torch.Tensor([1.0, torch.sum(label_mask.data)]))
                # d_class_loss = torch.sum(label_mask * d_class_loss_entropy) / delim
                numpy_labels = svhn_labels.data.cpu().numpy()
                d_class_loss = torch.sum(d_class_loss_entropy) / numpy_labels.shape[0]
                
                d_loss = d_gan_loss + d_class_loss
                
                d_loss.backward()
                self.d_optimizer.step()

                # -------------- update netG --------------
                
                self._reset_grad()

                # call netD again to do backprop for netG here
                noise.resize_(batch_size, self.nz, 1, 1).normal_(0, 1)
                noise_var = self._to_var(noise)
                fake = self.netG(noise_var)

                _, _, _, d_data_features = self.netD(fake)
                
                # Here we set `g_loss` to the "feature matching" loss invented by Tim Salimans at OpenAI.
                # This loss consists of minimizing the absolute difference between the expected features
                # on the data and the expected features on the generated samples.
                # This loss works better for semi-supervised learning than the tradition GAN losses.
                data_features_mean = torch.mean(d_data_features, dim=0).squeeze()
                sample_features_mean = torch.mean(d_sample_features.detach(), dim=0).squeeze()
                
                g_loss = torch.mean(torch.abs(data_features_mean - sample_features_mean))

                g_loss.backward()
                self.g_optimizer.step()

                _, pred_class = torch.max(d_class_logits_on_data, 1)
                eq = torch.eq(svhn_labels, pred_class)
                correct = torch.sum(eq.float())
                # masked_correct += torch.sum(label_mask * eq.float())
                masked_correct += correct
                # num_samples += torch.sum(label_mask)
                num_samples += numpy_labels.shape[0]
                
                if i % 200 == 0:
                    print('Training:\tepoch {}/{}\tdiscr. gan loss {}\tdiscr. class loss {}\tgen loss {}\tsamples {}/{}'.
                        format(epoch, self.epochs, d_gan_loss.data[0], d_class_loss.data[0], g_loss.data[0], 
                               i + 1, len(self.svhn_loader_train)))
                    real_cpu, _, _ = data
                    vutils.save_image(real_cpu,
                            '%s/real_samples.png' % self.out_dir,
                            normalize=True)
                    fake = self.netG(fixed_noise)
                    vutils.save_image(fake.data,
                            '%s/fake_samples_epoch_%03d.png' % (self.out_dir, epoch),
                            normalize=True)
                    
            # accuracy = masked_correct.data[0]/max(1.0, num_samples.data[0])
            print('Training:\tepoch {}/{}\taccuracy {}/{}'.format(epoch, self.epochs, masked_correct, num_samples))

            correct = 0
            num_samples = 0
            for i, data in enumerate(self.svhn_loader_test):
                # load svhn dataset
                svhn_data, svhn_labels = data
                svhn_data = self._to_var(svhn_data)
                svhn_labels = self._to_var(svhn_labels).long().squeeze()

                # -------------- train netD --------------

                # train with real images
                d_out, d_class_logits, _, _ = self.netD(svhn_data)
                _, pred_idx = torch.max(d_class_logits.data, 1)
                eq = torch.eq(svhn_labels.data, pred_idx)
                correct += torch.sum(eq.float())
                num_samples += len(svhn_labels)
                
                if i % 50 == 0:
                    print('Test:\tepoch {}/{}\tsamples {}/{}'.format(
                        epoch, self.epochs, i + 1, len(self.svhn_loader_test)))
                
            # accuracy = correct/max(1.0, 1.0 * num_samples)
            print('Test:\tepoch {}/{}\taccuracy {}/{}'.format(epoch, self.epochs, correct, num_samples))

            # do checkpointing
            torch.save(self.netG.state_dict(), '%s/netG_epoch_%d.pth' % (self.out_dir, epoch))
            torch.save(self.netD.state_dict(), '%s/netD_epoch_%d.pth' % (self.out_dir, epoch))

## Main

In [None]:
image_size = 32
batch_size = 64

In [None]:
svhn_loader_train, svhn_loader_test = get_loader(image_size, batch_size)
solver = Solver(svhn_loader_train, svhn_loader_test, batch_size)

In [57]:
solver.train()

Training:	epoch 1/25	discr. gan loss 2.534649133682251	discr. class loss 4.167726516723633	gen loss 0.0419347807765007	samples 1/1145
Training:	epoch 1/25	discr. gan loss 1.090012550354004	discr. class loss 4.240676403045654	gen loss 0.3187779188156128	samples 201/1145
Training:	epoch 1/25	discr. gan loss 1.0721180438995361	discr. class loss 4.199171543121338	gen loss 0.3505549430847168	samples 401/1145
Training:	epoch 1/25	discr. gan loss 1.047226905822754	discr. class loss 3.856647491455078	gen loss 0.17716744542121887	samples 601/1145
Training:	epoch 1/25	discr. gan loss 0.9562832713127136	discr. class loss 3.523123264312744	gen loss 0.3336448669433594	samples 801/1145
Training:	epoch 1/25	discr. gan loss 0.8121185302734375	discr. class loss 3.1991424560546875	gen loss 0.17021872103214264	samples 1001/1145
Training:	epoch 1/25	accuracy Variable containing:
 24270
[torch.cuda.FloatTensor of size 1 (GPU 0)]
/73257
Test:	epoch 1/25	samples 1/407
Test:	epoch 1/25	samples 51/407
Test:	ep

Training:	epoch 7/25	discr. gan loss 0.2913511097431183	discr. class loss 3.0248942375183105	gen loss 0.3762364089488983	samples 1001/1145
Training:	epoch 7/25	accuracy Variable containing:
 42309
[torch.cuda.FloatTensor of size 1 (GPU 0)]
/73257
Test:	epoch 7/25	samples 1/407
Test:	epoch 7/25	samples 51/407
Test:	epoch 7/25	samples 101/407
Test:	epoch 7/25	samples 151/407
Test:	epoch 7/25	samples 201/407
Test:	epoch 7/25	samples 251/407
Test:	epoch 7/25	samples 301/407
Test:	epoch 7/25	samples 351/407
Test:	epoch 7/25	samples 401/407
Test:	epoch 7/25	accuracy 16266.0/26032
Training:	epoch 8/25	discr. gan loss 0.4859289526939392	discr. class loss 2.816208839416504	gen loss 0.31552788615226746	samples 1/1145
Training:	epoch 8/25	discr. gan loss 0.4358625113964081	discr. class loss 2.603207588195801	gen loss 0.49591678380966187	samples 201/1145
Training:	epoch 8/25	discr. gan loss 0.39033907651901245	discr. class loss 2.712826728820801	gen loss 0.22160223126411438	samples 401/1145
Traini

Process Process-39:
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 36, in _worker_loop
    r = index_queue.get()
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 343, in get
    res = self._reader.recv_bytes()
  File "/usr/lib/python3.5/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/usr/lib/python3.5/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/usr/lib/python3.5/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
KeyboardInterrupt


KeyboardInterrupt: 