## Imports

In [1]:
%matplotlib inline

import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

## Data loaders

### SVHN Dataset

In [2]:
class SvhnDataset(Dataset):
    def __init__(self, image_size, split):
        self.split = split
        self.use_gpu = True if torch.cuda.is_available() else False

        self.svhn_dataset = self._create_dataset(image_size, split)
        self.label_mask = self._create_label_mask()

    def _create_dataset(self, image_size, split):
        normalize = transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5])
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            normalize])
        return datasets.SVHN(root='./svhn', download=True, transform=transform, split=split)

    def _is_train_dataset(self):
        return True if self.split == 'train' else False

    def _create_label_mask(self):
        if self._is_train_dataset():
            label_mask = torch.zeros(len(self.svhn_dataset)).float()
            label_mask[0:1000] = 1
            return label_mask
        return None

    def __len__(self):
        return len(self.svhn_dataset)

    def __getitem__(self, idx):
        data, label = self.svhn_dataset.__getitem__(idx)
        if self._is_train_dataset():
            return data, label, self.label_mask[idx]
        return data, label

### Get data loaders

In [3]:
def get_loader(image_size, batch_size):
    num_workers = 1

    svhn_train = SvhnDataset(image_size=image_size, split='train')
    svhn_test = SvhnDataset(image_size=image_size, split='test')

    svhn_loader_train = DataLoader(
        dataset=svhn_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    svhn_loader_test = DataLoader(
        dataset=svhn_test,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    return svhn_loader_train, svhn_loader_test

In [None]:
svhn_loader_train, _ = get_loader(image_size=32, batch_size=36)
image_iter = iter(svhn_loader_train)
images, _, _ = image_iter.next()

In [None]:
def view_images(images):
    assert(len(images) >= 36)
    fig, axes = plt.subplots(6, 6, sharex=True, sharey=True, figsize=(5,5))
    for idx, ax in enumerate(axes.flatten()):
        img = images[idx].numpy()
        img = img.transpose(1, 2, 0)
        img = ((img - img.min())*255 / (img.max() - img.min())).astype(np.uint8)
        ax.imshow(img, aspect='equal')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
    plt.subplots_adjust(wspace=0, hspace=0)

In [None]:
view_images(images)

## Model

### Conv, deconv helpers

In [4]:
def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    layers = []
    layers.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, bias=False))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)

def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    layers = []
    layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=False))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)

### Generator

In [5]:
class _netG(nn.Module):
    '''
    GAN generator
    '''
    def __init__(self, num_noise_channels, size_mult, lrelu_alpha, num_output_channels):
        super(_netG, self).__init__()
        self.lrelu_alpha = lrelu_alpha

        # noise is going into a convolution
        self.deconv1 = deconv(
            c_in=num_noise_channels,
            c_out=size_mult * 4,
            k_size=4,
            stride=1,
            pad=0)
        # (size_mult * 4) x 4 x 4

        self.deconv2 = deconv(
            c_in=size_mult * 4,
            c_out=size_mult * 2,
            k_size=4)
        # (size_mult * 2) x 8 x 8

        self.deconv3 = deconv(
            c_in=size_mult * 2,
            c_out=size_mult * 1,
            k_size=4)
        # (size_mult) x 16 x 16

        self.deconv4 = deconv(
            c_in=size_mult,
            c_out=num_output_channels,
            k_size=4,
            bn=False)
        # (num_output_channels) x 16 x 16

    def forward(self, inputs):
        out = F.leaky_relu(self.deconv1(inputs), self.lrelu_alpha)
        out = F.leaky_relu(self.deconv2(out), self.lrelu_alpha)
        out = F.leaky_relu(self.deconv3(out), self.lrelu_alpha)
        out = F.tanh(self.deconv4(out))
        return out

### Discriminator

In [6]:
class _netD(nn.Module):
    '''
    GAN discruminator
    '''
    def __init__(self, size_mult, lrelu_alpha, number_channels, drop_rate, num_classes):
        super(_netD, self).__init__()
        self.drop_rate = drop_rate
        self.lrelu_alpha = lrelu_alpha
        self.size_mult = size_mult
        self.num_classes = num_classes

        # input is (number_channels) x 32 x 32
        self.conv1 = conv(
            c_in=number_channels,
            c_out=size_mult,
            k_size=3,
            bn=False
        )
        # (size_mult) x 16 x 16

        self.conv2 = conv(
            c_in=size_mult,
            c_out=size_mult,
            k_size=3,
        )
        # (size_mult) x 8 x 8

        self.conv3 = conv(
            c_in=size_mult,
            c_out=size_mult,
            k_size=3,
        )
        # (size_mult) x 4 x 4

        self.conv4 = conv(
            c_in=size_mult,
            c_out=size_mult * 2,
            k_size=3,
            stride=1
        )
        # (size_mult * 2) x 4 x 4

        self.conv5 = conv(
            c_in=size_mult * 2,
            c_out=size_mult * 2,
            k_size=3,
            stride=1
        )
        # (size_mult * 2) x 4 x 4

        self.conv6 = conv(
            c_in=size_mult * 2,
            c_out=size_mult * 2,
            k_size=3,
            stride=1,
            pad=0,
            bn=False
        )
        # (size_mult * 2) x 2 x 2

        self.features = nn.AvgPool2d(kernel_size=2)

        self.class_logits = nn.Linear(
            in_features=(size_mult * 2) * 1 * 1,
            out_features=num_classes)

    def forward(self, inputs):
        out = F.dropout2d(inputs, p=self.drop_rate/2.5)

        out = F.leaky_relu(self.conv1(out), self.lrelu_alpha)
        out = F.dropout2d(out, p=self.drop_rate)

        out = F.leaky_relu(self.conv2(out), self.lrelu_alpha)

        out = F.leaky_relu(self.conv3(out), self.lrelu_alpha)
        out = F.dropout2d(out, p=self.drop_rate)

        out = F.leaky_relu(self.conv4(out), self.lrelu_alpha)

        out = F.leaky_relu(self.conv5(out), self.lrelu_alpha)

        out = F.leaky_relu(self.conv6(out), self.lrelu_alpha)

        features = self.features(out)
        features = features.squeeze()

        class_logits = self.class_logits(features)

        # calculate gan logits
        max_val, _ = torch.max(class_logits, 1, keepdim=True)
        stable_class_logits = class_logits - max_val
        max_val = torch.squeeze(max_val)
        gan_logits = torch.log(torch.sum(torch.exp(stable_class_logits), 1)) + max_val

        out = F.softmax(class_logits, dim=0)

        return out, class_logits, gan_logits, features

In [19]:
import torch
import torch.nn as nn
from model import _netG, _netD
from torch import optim
from torch.autograd import Variable
import numpy as np

class Solver:
    def __init__(self, svhn_loader_train, svhn_loader_test, batch_size):
        self.nz = 100
        self.real_image_size = (3, 32, 32)
        self.lrelu_alpha = 1e-2
        self.drop_rate = .5
        self.g_size_mult = 32
        self.d_size_mult = 64
        self.num_classes = 10
        self.use_gpu = True if torch.cuda.is_available() else False
        self.learning_rate = 3e-3
        self.beta1 = .5
        self.svhn_loader_train = svhn_loader_train
        self.svhn_loader_test = svhn_loader_test
        self.epochs = 25
        self.batch_size = batch_size

        self.generator, self.discriminator = self._build_model()
        self.g_optimizer, self.d_optimizer = self._create_optimizers()

    def _build_model(self):
        generator = _netG(
            self.nz, self.g_size_mult, self.lrelu_alpha,
            self.real_image_size[0])
        generator.apply(self._weights_init)
        # TODO: load weights from file if it exists

        discriminator = _netD(
            self.d_size_mult, self.lrelu_alpha, self.real_image_size[0],
            self.drop_rate, self.num_classes)
        discriminator.apply(self._weights_init)
        # TODO: load weights from file if it exists

        if self.use_gpu:
            generator = generator.cuda()
            discriminator = discriminator.cuda()

        return generator, discriminator

    def _weights_init(self, module):
        '''
        Custom weights initialization called on generator and discriminator
        '''
        classname = module.__class__.__name__
        if classname.find('Conv') != -1:
            module.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            module.weight.data.normal_(1.0, 0.02)
            module.bias.data.fill_(0)

    def _create_optimizers(self):
        g_params = list(self.generator.parameters())
        d_params = list(self.discriminator.parameters())

        g_optimizer = optim.Adam(g_params, self.learning_rate)
        d_optimizer = optim.Adam(d_params, self.learning_rate)

        return g_optimizer, d_optimizer

    def _to_var(self, x):
        if self.use_gpu:
            x = x.cuda()
        return Variable(x)

    def _one_hot(self, x):
        ones = torch.sparse.torch.eye(self.num_classes)
        one_hot = ones.index_select(0, x.data.cpu())
        if self.use_gpu:
            one_hot = one_hot.cuda()
        return Variable(one_hot)

    def train(self):
        svhn_iter = iter(self.svhn_loader_train)
        iter_per_epoch = len(svhn_iter)
        print(iter_per_epoch)

        d_gan_criterion = nn.BCEWithLogitsLoss()
        
        noise = torch.FloatTensor(self.batch_size, self.nz, 1, 1)

        for epoch in range(1, self.epochs + 1):
            masked_correct = 0
            num_samples = 0
            total_count_samples = 0
            loop_count = 0

            for _, data in enumerate(self.svhn_loader_train):
                # load svhn dataset
                svhn_data, svhn_labels, label_mask = data
                svhn_data = self._to_var(svhn_data)
                svhn_labels = self._to_var(svhn_labels).long().squeeze()
                label_mask = self._to_var(label_mask).float().squeeze()

                # -------------- train discriminator --------------

                # train with real images
                self.d_optimizer.zero_grad()

                # d_out == softmax(d_class_logits)
                d_out, d_class_logits_on_data, d_gan_logits_real, d_sample_features = self.discriminator(svhn_data)
                d_gan_labels_real = self._to_var(torch.ones_like(d_gan_logits_real.data))
                d_gan_loss_real = d_gan_criterion(
                    d_gan_logits_real,
                    d_gan_labels_real)

                # train with fake images
                noise.resize_(self.batch_size, self.nz, 1, 1).normal_(0, 1)
                noise_var = self._to_var(noise)
                fake = self.generator(noise_var)

                # call detach() to avoid backprop for generator here
                _, _, d_gan_logits_fake, _ = self.discriminator(fake.detach())

                d_gan_labels_fake = self._to_var(torch.zeros_like(d_gan_logits_fake.data))
                d_gan_loss_fake = d_gan_criterion(
                    d_gan_logits_fake,
                    d_gan_labels_fake)

                d_gan_loss = d_gan_loss_real + d_gan_loss_fake

                # d_out == softmax(d_class_logits)
                # see https://stackoverflow.com/questions/34240703/whats-the-difference-between-softmax-and-softmax-cross-entropy-with-logits/39499486#39499486
                svhn_labels_one_hot = self._one_hot(svhn_labels)
                d_class_loss_entropy = -torch.sum(svhn_labels_one_hot * torch.log(d_out), dim=1)
                
                # d_class_loss_entropy = d_class_criterion(
                #     d_class_logits_on_data,
                #     self._one_hot(svhn_labels)
                # )

                d_class_loss_entropy = d_class_loss_entropy.squeeze()
                delim = torch.max(torch.Tensor([1.0, torch.sum(label_mask.data)]))
                d_class_loss = torch.sum(label_mask * d_class_loss_entropy) / delim
                
                d_loss = d_gan_loss + d_class_loss
                d_loss.backward()
                self.d_optimizer.step()

                # -------------- update generator --------------
                
                self.g_optimizer.zero_grad()

                # call discriminator again to do backprop for generator here
                _, _, _, d_data_features = self.discriminator(fake)
                
                # Here we set `g_loss` to the "feature matching" loss invented by Tim Salimans at OpenAI.
                # This loss consists of minimizing the absolute difference between the expected features
                # on the data and the expected features on the generated samples.
                # This loss works better for semi-supervised learning than the tradition GAN losses.
                data_features_mean = torch.mean(d_data_features, dim=0)
                sample_features_mean = torch.mean(d_sample_features.detach(), dim=0)
                
                g_loss = torch.mean(torch.abs(data_features_mean - sample_features_mean))

                _, pred_class = torch.max(d_class_logits_on_data, 1)
                eq = torch.eq(svhn_labels, pred_class)
                correct = torch.sum(eq.float())
                masked_correct += torch.sum(label_mask * eq.float())
                num_samples += torch.sum(label_mask)

                g_loss.backward()
                self.g_optimizer.step()

                total_count_samples += len(svhn_labels)
                loop_count += 1
                # if loop_count%10 == 0:
                #    print('Training:\tepoch {}/{}\tdiscr. gan loss {}\tdiscr. class loss {}\tgen loss {}\tsamples {}/{}'.
                #            format(epoch, self.epochs, d_gan_loss.data[0], d_class_loss.data[0], g_loss.data[0], 
                #                total_count_samples, len(self.svhn_loader_train)))
                    
            accuracy = masked_correct.data[0]/max(1.0, num_samples.data[0])
            print('Training:\tepoch {}/{}\taccuracy {}'.format(epoch, self.epochs, accuracy))

            total_count_samples = 0
            correct = 0
            num_samples = 0
            loop_count = 0
            for _, data in enumerate(self.svhn_loader_test):
                # load svhn dataset
                svhn_data, svhn_labels = data
                svhn_data = self._to_var(svhn_data)
                svhn_labels = self._to_var(svhn_labels).long().squeeze()

                # -------------- train discriminator --------------

                # train with real images
                d_out, _, _, _ = self.discriminator(svhn_data)
                _, pred_idx = torch.max(d_out.data, 1)
                eq = torch.eq(svhn_labels.data, pred_idx)
                correct += torch.sum(eq.float())
                num_samples += len(svhn_labels)
                
                total_count_samples += len(svhn_labels)
                loop_count += 1
                # if loop_count%10 == 0:
                #    print('Test:\tepoch {}/{}\tsamples {}/{}'.format(
                #        epoch, self.epochs, total_count_samples, len(self.svhn_loader_test)))
                
            accuracy = correct/max(1.0, 1.0 * num_samples)
            print('Test:\tepoch {}/{}\taccuracy {}'.format(epoch, self.epochs, accuracy))

            # TODO: save checkpoints and the best model weights

## Solver

## Main

In [20]:
image_size = 32
batch_size = 64

In [21]:
svhn_loader_train, svhn_loader_test = get_loader(image_size, batch_size)
solver = Solver(svhn_loader_train, svhn_loader_test, batch_size)

Using downloaded and verified file: ./svhn/train_32x32.mat
Using downloaded and verified file: ./svhn/test_32x32.mat


In [22]:
solver.train()

1145
Training:	epoch 1/25	accuracy 0.096
Test:	epoch 1/25	accuracy 0.08428088506453596
Training:	epoch 2/25	accuracy 0.064
Test:	epoch 2/25	accuracy 0.10517824216349109
Training:	epoch 3/25	accuracy 0.066
Test:	epoch 3/25	accuracy 0.10003073140749846
Training:	epoch 4/25	accuracy 0.064
Test:	epoch 4/25	accuracy 0.0806699446834665
Training:	epoch 5/25	accuracy 0.063
Test:	epoch 5/25	accuracy 0.11086355255070682
Training:	epoch 6/25	accuracy 0.064
Test:	epoch 6/25	accuracy 0.11140135218192994
Training:	epoch 7/25	accuracy 0.062
Test:	epoch 7/25	accuracy 0.10337277197295636
Training:	epoch 8/25	accuracy 0.065
Test:	epoch 8/25	accuracy 0.107559926244622
Training:	epoch 9/25	accuracy 0.066
Test:	epoch 9/25	accuracy 0.11328365089121081
Training:	epoch 10/25	accuracy 0.062
Test:	epoch 10/25	accuracy 0.14393822987092808
Training:	epoch 11/25	accuracy 0.066
Test:	epoch 11/25	accuracy 0.1840043023970498
Training:	epoch 12/25	accuracy 0.065
Test:	epoch 12/25	accuracy 0.18984326982175784
Training:

Process Process-9:
Process Process-47:
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 36, in _worker_loop
    r = index_queue.get()
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 343, in get
    res = self._reader.recv_bytes()
  File "/usr/lib/python3.5/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 42, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", li

KeyboardInterrupt: 