# Training GAN
OpenAI proposed an improved GAN training procedure and applied it to the MNIST dataset. You can find the paper here: https://arxiv.org/abs/1606.03498. Someone else re-implemented the code in Chainer here: https://github.com/musyoku/improved-gan. However, that code is quite hard to understand, so I will first try to reproduce their results and understand what they did. The code is divided into general code that is used for all GAN applications and model-specific code, for example code used only for the MNIST model or for generating anime faces.

In [1]:
# Some dependencies
import math
import json
import numpy as np
import chainer, os, collections, six, math, random, time, copy,sys
from chainer import cuda, Variable, optimizers, serializers, function, optimizer, initializers
from chainer.utils import type_check
from chainer import functions as F
from chainer import links as L
# add the imported repository to the path, so we can always just import
sys.path.append(os.path.join(os.path.split(os.getcwd())[0],'improved-gan'))

## Params
They formalize the parameters of the discriminator, generator and classifier as classes. These classes are then passed as input to the general GAN code, which lets the same code serve the different applications.

In [2]:
# Base class
# Found in params.py
class Params():
    """Base class for attribute-bag hyper-parameter objects.

    Subclasses assign their default values as instance attributes in
    ``__init__``. ``from_dict`` then only overrides attributes that already
    exist, so unknown keys in an input dict are silently ignored.
    """

    def __init__(self, dict=None):
        # NOTE: the parameter name shadows the builtin ``dict``; it is kept so
        # callers passing ``dict=...`` by keyword keep working.
        if dict:
            self.from_dict(dict)

    def from_dict(self, dict):
        """Copy values from *dict* onto attributes that already exist on self."""
        # .items() works on Python 2 and 3 (iteritems() is Python 2 only).
        for attr, value in dict.items():
            if hasattr(self, attr):
                setattr(self, attr, value)

    def to_dict(self):
        """Return the instance attributes as a plain dict.

        Attribute values that themselves provide ``to_dict`` (e.g. nested
        Params) are converted recursively.
        """
        result = {}
        for attr, value in self.__dict__.items():
            if hasattr(value, "to_dict"):
                result[attr] = value.to_dict()
            else:
                result[attr] = value
        return result

    def dump(self):
        """Print every attribute as a tab-indented ``name: value`` line."""
        for attr, value in self.__dict__.items():
            print("\t{}: {}".format(attr, value))

# General GAN code (found in gan.py) :
# These params can be defined for a Discriminator class
class DiscriminatorParams(Params):
    """Default hyper-parameters for a discriminator network.

    An optional dict overrides the defaults (unknown keys are ignored).
    """

    def __init__(self, dict=None):
        self.ndim_input = 28 * 28
        self.ndim_output = 10
        self.weight_init_std = 1
        self.weight_initializer = "Normal"  # Normal, GlorotNormal or HeNormal
        self.nonlinearity = "elu"
        self.optimizer = "Adam"
        self.learning_rate = 0.001
        self.momentum = 0.5
        self.gradient_clipping = 10
        self.weight_decay = 0
        # Read by the model-building code; declared here so from_dict keeps it.
        self.use_weightnorm = False
        self.use_feature_matching = False
        self.use_minibatch_discrimination = False
        # Overrides must be applied after the defaults exist, because
        # from_dict only sets attributes that are already present.
        Params.__init__(self, dict)

# These params can be defined for a Generator class
class GeneratorParams(Params):
    """Default hyper-parameters for a generator network.

    An optional dict overrides the defaults (unknown keys are ignored).
    """

    def __init__(self, dict=None):
        self.ndim_input = 10
        self.ndim_output = 28 * 28
        self.distribution_output = "universal"  # universal, sigmoid or tanh
        self.weight_init_std = 1
        self.weight_initializer = "Normal"  # Normal, GlorotNormal or HeNormal
        self.nonlinearity = "relu"
        self.optimizer = "Adam"
        self.learning_rate = 0.001
        self.momentum = 0.5
        self.gradient_clipping = 10
        self.weight_decay = 0
        # Read by the model-building code; declared here so from_dict keeps it.
        self.use_weightnorm = False
        # Apply overrides last (see DiscriminatorParams).
        Params.__init__(self, dict)

# These params can be defined for a Classifier class
class ClassifierParams(Params):
    """Default hyper-parameters for a classifier network.

    An optional dict overrides the defaults (unknown keys are ignored).
    """

    def __init__(self, dict=None):
        self.ndim_input = 28 * 28
        self.ndim_output = 10
        self.weight_init_std = 1
        self.weight_initializer = "Normal"  # Normal, GlorotNormal or HeNormal
        self.nonlinearity = "elu"
        self.optimizer = "Adam"
        self.learning_rate = 0.001
        self.momentum = 0.5
        self.gradient_clipping = 10
        self.weight_decay = 0
        # Read by the model-building code; declared here so from_dict keeps it.
        self.use_weightnorm = False
        self.use_feature_matching = False
        self.use_minibatch_discrimination = False
        # Apply overrides last (see DiscriminatorParams).
        Params.__init__(self, dict)

## Sequentials
The `sequential` package implements a lot of general neural-network functionality that supports the GAN model, for example a deconvolutional layer and weight normalization (https://arxiv.org/abs/1602.07868). I will not discuss all of this code in detail since it's quite extensive.

One important class is the Sequential class, which implements a sequence of neural network layers. It is loaded into a chain before optimizing.

In [3]:
import sequential

## General GAN
The code below shows the general code that implements a GAN given the params defined above and a model for the discriminator and generator.

In [4]:
class Sequential(sequential.Sequential):
    """
    Sequential formalizes a sequence of neural network layers.

    Calling it runs every link in order and returns ``(output, activations)``,
    where *activations* holds the output of each ``ActivationFunction`` link.
    Dropout links get ``train=not test``; BatchNormalization links get ``test``.

    NOTE: a later cell re-imports ``Sequential`` from the ``sequential``
    package, which replaces this definition in the notebook namespace.
    """
    def __call__(self, x, test=False):
        collected = []
        for layer in self.links:
            if isinstance(layer, sequential.functions.dropout):
                x = layer(x, train=not test)
                continue
            if isinstance(layer, chainer.links.BatchNormalization):
                x = layer(x, test=test)
                continue
            x = layer(x)
            # Keep intermediate activations (used e.g. for feature matching).
            if isinstance(layer, sequential.functions.ActivationFunction):
                collected.append(x)
        return x, collected

# Helpers that turn plain dicts (e.g. configs loaded from JSON) into
# attribute-style objects.
class Object(object):
    """Empty attribute container used as a lightweight namespace."""
    pass


def to_object(dict):
    """Return an ``Object`` whose attributes are the key/value pairs of *dict*.

    The parameter name shadows the builtin ``dict``; it is kept so existing
    keyword callers keep working.
    """
    obj = Object()
    # .items() works on Python 2 and 3; iteritems() is Python 2 only.
    for key, value in dict.items():
        setattr(obj, key, value)
    return obj

class GAN():
    """A discriminator/generator pair plus the helpers used to train them."""

    def __init__(self, params_discriminator, params_generator):
        """
        As an input a GAN gets two arguments: a dictionary for the discriminator and a dictionary for the generator.
        Both have two items with the keys "config" and "model".
        The "config" value is a plain dict of hyper-parameters (e.g. produced by one of the Params classes' to_dict()).
        The "model" value is a neural network converted to a dictionary via the Sequential implementation.
        """
        self.params_discriminator = copy.deepcopy(params_discriminator)
        self.config_discriminator = to_object(params_discriminator["config"])

        self.params_generator = copy.deepcopy(params_generator)
        self.config_generator = to_object(params_generator["config"])

        self.build_discriminator()
        self.build_generator()
        self._gpu = False

    def build_discriminator(self):
        """Build the discriminator chain from its model dict and set up its optimizer."""
        self.discriminator = sequential.chain.Chain()
        self.discriminator.add_sequence(sequential.from_dict(self.params_discriminator["model"]))
        config = self.config_discriminator
        self.discriminator.setup_optimizers(config.optimizer, config.learning_rate, config.momentum)

    def build_generator(self):
        """Build the generator chain from its model dict and set up its optimizer."""
        self.generator = sequential.chain.Chain()
        self.generator.add_sequence(sequential.from_dict(self.params_generator["model"]))
        # The generator must use its own optimizer settings, not the discriminator's.
        config = self.config_generator
        self.generator.setup_optimizers(config.optimizer, config.learning_rate, config.momentum)

    def update_learning_rate(self, lr):
        """Set the same learning rate *lr* on both the discriminator and the generator."""
        self.discriminator.update_learning_rate(lr)
        self.generator.update_learning_rate(lr)

    def to_gpu(self):
        """Move both networks to the current GPU and remember that GPU mode is on."""
        self.discriminator.to_gpu()
        self.generator.to_gpu()
        self._gpu = True

    @property
    def gpu_enabled(self):
        """True only if to_gpu() was called and CUDA is available."""
        if not cuda.available:
            return False
        return self._gpu

    @property
    def xp(self):
        """Array module to use: cupy when GPU mode is enabled, otherwise numpy."""
        if self.gpu_enabled:
            return cuda.cupy
        return np

    def to_variable(self, x):
        """Wrap *x* in a chainer Variable (moved to the GPU if enabled); Variables pass through unchanged."""
        if not isinstance(x, Variable):
            x = Variable(x)
            if self.gpu_enabled:
                x.to_gpu()
        return x

    def to_numpy(self, x):
        """Return the host (numpy) array behind *x*, unwrapping Variables and copying GPU arrays to the CPU."""
        if isinstance(x, Variable):
            x = x.data
        if isinstance(x, cuda.ndarray):
            x = cuda.to_cpu(x)
        return x

    def get_batchsize(self, x):
        """Return the size of the leading (batch) axis of *x*."""
        return x.shape[0]

    def zero_grads(self):
        """Reset the gradients held by both optimizers.

        NOTE(review): ``optimizer_discriminator`` and ``optimizer_generative_model``
        are never assigned in this class (optimizers are created through
        ``setup_optimizers`` on the chains), so unless they are set from outside
        this raises AttributeError as written. Confirm the intended API.
        """
        self.optimizer_discriminator.zero_grads()
        self.optimizer_generative_model.zero_grads()

    def sample_z(self, batchsize=1):
        """Draw a (batchsize, ndim_input) float32 batch of latent codes from U(-1, 1).

        ``ndim_input`` is taken from the generator config; the result is a numpy
        array (to_variable moves it to the GPU when needed).
        """
        ndim_z = self.config_generator.ndim_input
        # uniform
        z_batch = np.random.uniform(-1, 1, (batchsize, ndim_z)).astype(np.float32)
        # gaussian alternative:
        # z_batch = np.random.normal(0, 1, (batchsize, ndim_z)).astype(np.float32)
        return z_batch

    def generate_x(self, batchsize=1, test=False, as_numpy=False):
        """Generate *batchsize* samples from freshly drawn latent codes (see sample_z / generate_x_from_z)."""
        return self.generate_x_from_z(self.sample_z(batchsize), test=test, as_numpy=as_numpy)

    def generate_x_from_z(self, z_batch, test=False, as_numpy=False):
        """Run the generator on *z_batch*; the batch size is taken from z_batch.

        Returns a Variable, or a numpy array when *as_numpy* is True.
        The generator's intermediate activations are discarded.
        """
        z_batch = self.to_variable(z_batch)
        x_batch, _ = self.generator(z_batch, test=test, return_activations=True)
        if as_numpy:
            return self.to_numpy(x_batch)
        return x_batch

    def discriminate(self, x_batch, test=False, apply_softmax=True):
        """Run the discriminator on *x_batch*.

        Returns ``(output, activations)``. *output* is the softmax over the
        discriminator's outputs when *apply_softmax* is True, otherwise the raw
        logits; *activations* are the intermediate activations it reports.
        """
        x_batch = self.to_variable(x_batch)
        prob, activations = self.discriminator(x_batch, test=test, return_activations=True)
        if apply_softmax:
            prob = F.softmax(prob)
        return prob, activations

    def backprop_discriminator(self, loss):
        """Backpropagate *loss* and update the discriminator via its chain."""
        self.discriminator.backprop(loss)

    def backprop_generator(self, loss):
        """Backpropagate *loss* and update the generator via its chain."""
        self.generator.backprop(loss)

    def compute_kld(self, p, q):
        """Return the per-row KL divergence KL(p || q) as a (batch, 1) Variable.

        p and q are summed over axis 1, so rows are expected to be probability
        distributions; the 1e-16 terms guard against log(0).
        """
        return F.reshape(F.sum(p * (F.log(p + 1e-16) - F.log(q + 1e-16)), axis=1), (-1, 1))

    def get_unit_vector(self, v):
        """L2-normalize each row of the numpy array *v* in place and return it.

        Assumes a 2-D (batch, features) array (the norm is reshaped to a column);
        1e-16 avoids division by zero for all-zero rows.
        """
        v /= (np.sqrt(np.sum(v ** 2, axis=1)).reshape((-1, 1)) + 1e-16)
        return v

    def compute_lds(self, x, xi=10, eps=1, Ip=1):
        """Return the negative local distributional smoothness of the discriminator at *x*.

        In the style of virtual adversarial training: *Ip* power-iteration steps
        (perturbation scale *xi*) estimate the direction d that most changes the
        softmax output, then ``-KL(p(y|x) || p(y|x + eps * d))`` is returned as a
        (batch, 1) Variable.

        NOTE(review): kld.backward() inside the loop also accumulates gradients
        into the discriminator parameters; presumably they are cleared before
        the next update. Confirm.
        """
        x = self.to_variable(x)
        y1, _ = self.discriminate(x, apply_softmax=True)
        # The clean prediction is treated as a constant target.
        y1.unchain_backward()
        d = self.to_variable(self.get_unit_vector(np.random.normal(size=x.shape).astype(np.float32)))

        for _ in range(Ip):
            y2, _ = self.discriminate(x + xi * d, apply_softmax=True)
            kld = F.sum(self.compute_kld(y1, y2))
            kld.backward()
            # The gradient w.r.t. d gives the next (normalized) perturbation direction.
            d = self.to_variable(self.get_unit_vector(self.to_numpy(d.grad)))

        y2, _ = self.discriminate(x + eps * d, apply_softmax=True)
        return -self.compute_kld(y1, y2)

    def load(self, dir=None):
        """Load generator/discriminator weights from <dir>/generator.hdf5 and <dir>/discriminator.hdf5.

        Raises ValueError if *dir* is None.
        """
        if dir is None:
            raise ValueError("dir must be specified")
        self.generator.load(os.path.join(dir, "generator.hdf5"))
        self.discriminator.load(os.path.join(dir, "discriminator.hdf5"))

    def save(self, dir=None):
        """Save generator/discriminator weights to <dir>/generator.hdf5 and <dir>/discriminator.hdf5.

        The directory is created if needed. Raises ValueError if *dir* is None.
        """
        if dir is None:
            raise ValueError("dir must be specified")
        try:
            os.mkdir(dir)
        except OSError:
            # Most likely the directory already exists; a real problem will
            # surface when the weights are written below.
            pass
        self.generator.save(os.path.join(dir, "generator.hdf5"))
        self.discriminator.save(os.path.join(dir, "discriminator.hdf5"))


## MNIST
Next, we will look at how they use these general classes to implement MNIST training.
### Args
The provided code assumes that it is run from the terminal and uses args.py to parse the command-line arguments. Since I work in a notebook, I solve this by manually crafting an args object.

In [6]:
# Make the helper modules that live next to the MNIST training script importable
# (presumably mnist_tools, visualizer and progress, which are imported below).
sys.path.append(os.path.join(os.path.split(os.getcwd())[0], 'improved-gan/train_mnist'))

# Stand-in for the command-line arguments normally parsed by args.py.
args = to_object({
    "model_dir": 'mnist',       # JSON model definitions, weights and result.csv go here
    "gpu_device": 0,            # -1 keeps everything on the CPU
    "seed": None,               # None lets numpy seed itself from system entropy
    "plot_dir": 'mnist-plot',   # tiled images of generated samples go here
    "num_labeled": 100,         # number of labeled training examples
})


### Model
This cell defines the generator and the discriminator. On the first run it saves their params and model definitions to JSON files; on later runs it loads them from these files again. This makes it easy to adapt the model to your needs.

The discriminator gets a 28x28 image flattened to a 784x1 vector as input and outputs a 10x1 vector: one value for each of the classes. It consists of three fully connected hidden layers: 784x1000x500x250x10.

The generator gets a 50x1 latent input vector (a random variable) and produces a 784-dimensional output (the 28x28 image). It has two hidden layers: 50x500x500x784.

In [7]:
from sequential import Sequential
from sequential.layers import Linear, BatchNormalization, MinibatchDiscrimination
from sequential.functions import Activation, dropout, gaussian_noise, softmax

# Make sure the model directory exists: the JSON model definitions, the
# weights and result.csv are all stored there.
try:
    os.mkdir(args.model_dir)
except OSError:
    # Most likely the directory already exists.
    pass

# data
image_width = 28
image_height = image_width
ndim_latent_code = 50  # dimensionality of the generator's latent input z

# specify discriminator
discriminator_sequence_filename = args.model_dir + "/discriminator.json"

if os.path.isfile(discriminator_sequence_filename):
    # Reuse the saved config/model so the architecture matches saved weights.
    print("loading {}".format(discriminator_sequence_filename))
    with open(discriminator_sequence_filename, "r") as f:
        try:
            params = json.load(f)
        except Exception as e:
            # Keep the underlying parse error instead of discarding it.
            raise Exception("could not load {}: {}".format(discriminator_sequence_filename, e))
else:
    # The semi-supervised discriminator is a 10-way classifier, hence ClassifierParams.
    config = ClassifierParams()
    config.ndim_input = image_width * image_height
    config.ndim_output = 10
    config.weight_init_std = 1
    config.weight_initializer = "GlorotNormal"
    config.use_weightnorm = False
    config.nonlinearity = "softplus"
    config.optimizer = "Adam"
    config.learning_rate = 0.001
    config.momentum = 0.5
    config.gradient_clipping = 10
    config.weight_decay = 0
    config.use_feature_matching = True
    config.use_minibatch_discrimination = False

    # input noise -> 1000 -> 500 -> 250 -> ndim_output, with Gaussian noise
    # injected before every nonlinearity
    discriminator = Sequential(weight_initializer=config.weight_initializer, weight_init_std=config.weight_init_std)
    discriminator.add(gaussian_noise(std=0.3))
    discriminator.add(Linear(config.ndim_input, 1000, use_weightnorm=config.use_weightnorm))
    discriminator.add(gaussian_noise(std=0.5))
    discriminator.add(Activation(config.nonlinearity))
    # discriminator.add(BatchNormalization(1000))
    discriminator.add(Linear(None, 500, use_weightnorm=config.use_weightnorm))
    discriminator.add(gaussian_noise(std=0.5))
    discriminator.add(Activation(config.nonlinearity))
    # discriminator.add(BatchNormalization(500))
    discriminator.add(Linear(None, 250, use_weightnorm=config.use_weightnorm))
    discriminator.add(gaussian_noise(std=0.5))
    discriminator.add(Activation(config.nonlinearity))
    # discriminator.add(BatchNormalization(250))
    if config.use_minibatch_discrimination:
        discriminator.add(MinibatchDiscrimination(None, num_kernels=50, ndim_kernel=5))
    discriminator.add(Linear(None, config.ndim_output, use_weightnorm=config.use_weightnorm))
    # no softmax() here: the training code works on logits and
    # GAN.discriminate applies softmax itself when asked to
    discriminator.build()

    params = {
        "config": config.to_dict(),
        "model": discriminator.to_dict(),
    }

    with open(discriminator_sequence_filename, "w") as f:
        json.dump(params, f, indent=4, sort_keys=True, separators=(',', ': '))

discriminator_params = params

# specify generator
generator_sequence_filename = args.model_dir + "/generator.json"

if os.path.isfile(generator_sequence_filename):
    # Reuse the saved config/model so the architecture matches saved weights.
    print("loading {}".format(generator_sequence_filename))
    with open(generator_sequence_filename, "r") as f:
        try:
            params = json.load(f)
        except Exception as e:
            # Narrower than a bare except (which also swallowed KeyboardInterrupt)
            # and keeps the underlying parse error in the message.
            raise Exception("could not load {}: {}".format(generator_sequence_filename, e))
else:
    config = GeneratorParams()
    config.ndim_input = ndim_latent_code
    config.ndim_output = image_width * image_height
    # tanh output matches the [-1, 1] range the training images are scaled to
    config.distribution_output = "tanh"
    config.use_weightnorm = False
    config.weight_init_std = 1
    config.weight_initializer = "GlorotNormal"
    config.nonlinearity = "relu"
    config.optimizer = "Adam"
    config.learning_rate = 0.001
    config.momentum = 0.5
    config.gradient_clipping = 10
    config.weight_decay = 0

    # generator: z -> 500 -> 500 -> ndim_output, batch-normalized hidden layers
    generator = Sequential(weight_initializer=config.weight_initializer, weight_init_std=config.weight_init_std)
    generator.add(Linear(config.ndim_input, 500, use_weightnorm=config.use_weightnorm))
    generator.add(BatchNormalization(500))
    generator.add(Activation(config.nonlinearity))
    generator.add(Linear(None, 500, use_weightnorm=config.use_weightnorm))
    generator.add(BatchNormalization(500))
    generator.add(Activation(config.nonlinearity))
    generator.add(Linear(None, config.ndim_output, use_weightnorm=config.use_weightnorm))
    if config.distribution_output == "sigmoid":
        generator.add(Activation("sigmoid"))
    if config.distribution_output == "tanh":
        generator.add(Activation("tanh"))
    generator.build()

    params = {
        "config": config.to_dict(),
        "model": generator.to_dict(),
    }

    with open(generator_sequence_filename, "w") as f:
        json.dump(params, f, indent=4, sort_keys=True, separators=(',', ': '))

generator_params = params

# Build both networks from their param dicts, then restore saved weights.
# NOTE(review): weights are loaded unconditionally; on a very first run the
# .hdf5 files do not exist yet. Confirm the chain's load() tolerates that.
gan = GAN(discriminator_params, generator_params)
gan.load(args.model_dir)

# gpu_device == -1 means "stay on the CPU".
if not args.gpu_device == -1:
    cuda.get_device(args.gpu_device).use()
    gan.to_gpu()


loading mnist/discriminator.json
loading mnist/generator.json
loading mnist/generator.hdf5 ...
loading mnist/discriminator.hdf5 ...


## Dataset
Creates a semi-supervised dataset

In [8]:
import mnist_tools
def load_train_images():
    """Return the MNIST training data exactly as ``mnist_tools`` provides it."""
    train_data = mnist_tools.load_train_images()
    return train_data


def load_test_images():
    """Return the MNIST test data exactly as ``mnist_tools`` provides it."""
    test_data = mnist_tools.load_test_images()
    return test_data

def binarize_data(x):
    """Stochastically binarize *x*: each entry becomes 1.0 with probability x, else 0.0.

    Returns a float32 array with the same shape as *x*.
    """
    noise = np.random.uniform(size=x.shape)
    return (noise < x).astype(np.float32)


def create_semisupervised(images, labels, num_validation_data=10000, num_labeled_data=100, num_types_of_label=10,
                          seed=0):
    """Split (images, labels) into labeled, unlabeled and validation sets.

    The indices are shuffled with a private ``np.random.RandomState(seed)`` and
    walked in that order. Each example is sorted as follows:

    * It becomes labeled training data while its label has fewer than
      ``num_labeled_data // num_types_of_label`` labeled examples.
    * Otherwise it becomes unlabeled training data until that set holds
      ``len(images) - num_validation_data - num_labeled_data`` examples.
    * Everything else becomes validation data.

    The global numpy RNG is not touched, so a caller's ``np.random.seed(...)``
    stays in effect.

    Returns:
        (training_labeled_x, training_labels, training_unlabeled_x,
         validation_x, validation_labels), all Python lists.

    Raises:
        Exception: if there are fewer than num_validation_data + num_labeled_data images.
    """
    if len(images) < num_validation_data + num_labeled_data:
        raise Exception("len(images) < num_validation_data + num_labeled_data")

    num_data_per_label = num_labeled_data // num_types_of_label
    num_unlabeled_data = len(images) - num_validation_data - num_labeled_data

    # A private generator gives the same shuffle as seeding the global RNG,
    # without resetting the caller's global seed afterwards.
    indices = np.arange(len(images))
    np.random.RandomState(seed).shuffle(indices)

    training_labeled_x = []
    training_labels = []
    training_unlabeled_x = []
    validation_x = []
    validation_labels = []
    num_taken_per_label = {}

    for index in indices:
        label = labels[index]
        taken = num_taken_per_label.get(label, 0)
        if taken < num_data_per_label:
            num_taken_per_label[label] = taken + 1
            training_labeled_x.append(images[index])
            training_labels.append(label)
        elif len(training_unlabeled_x) < num_unlabeled_data:
            training_unlabeled_x.append(images[index])
        else:
            validation_x.append(images[index])
            validation_labels.append(label)

    return training_labeled_x, training_labels, training_unlabeled_x, validation_x, validation_labels


def sample_labeled_data(images, labels, batchsize, ndim_x, ndim_y, binarize=True):
    """Draw *batchsize* distinct labeled examples at random.

    Returns ``(image_batch, label_onehot_batch, label_id_batch)``:

    * ``image_batch`` is a (batchsize, ndim_x) float32 array. Pixels are divided
      by 255 (assumes 0..255 input), optionally binarized, then mapped from
      [0, 1] to [-1, 1].
    * ``label_onehot_batch`` is a (batchsize, ndim_y) float32 one-hot matrix.
    * ``label_id_batch`` is a (batchsize,) int32 vector of class ids.
    """
    picked = np.random.choice(np.arange(len(images), dtype=np.int32), size=batchsize, replace=False)

    stacked = np.asarray([images[k] for k in picked], dtype=np.float32)
    image_batch = (stacked / 255.0).reshape((batchsize, ndim_x))

    label_id_batch = np.asarray([labels[k] for k in picked], dtype=np.int32).reshape((batchsize,))
    label_onehot_batch = np.zeros((batchsize, ndim_y), dtype=np.float32)
    label_onehot_batch[np.arange(batchsize), label_id_batch] = 1

    if binarize:
        image_batch = binarize_data(image_batch)
    # [0, 1] -> [-1, 1]
    return image_batch * 2.0 - 1.0, label_onehot_batch, label_id_batch


def sample_unlabeled_data(images, batchsize, ndim_x, binarize=True):
    """Draw *batchsize* distinct images at random, ignoring labels.

    Returns a (batchsize, ndim_x) float32 array. Pixels are divided by 255
    (assumes 0..255 input), optionally binarized, then mapped from [0, 1]
    to [-1, 1].
    """
    picked = np.random.choice(np.arange(len(images), dtype=np.int32), size=batchsize, replace=False)
    stacked = np.asarray([images[k] for k in picked], dtype=np.float32)
    image_batch = (stacked / 255.0).reshape((batchsize, ndim_x))
    if binarize:
        image_batch = binarize_data(image_batch)
    # [0, 1] -> [-1, 1]
    return image_batch * 2.0 - 1.0



### training

In [9]:
import sys, os
import numpy as np
import visualizer
from progress import Progress
import pandas as pd
sys.path.append(os.path.split(os.getcwd())[0])


def plot(filename="gen"):
    """Save a tiled image of 100 generated samples to ``args.plot_dir``.

    Samples are generated in test mode and mapped from [-1, 1] (presumably
    the generator's tanh output range) to [0, 1] before tiling as 28x28 images.
    """
    try:
        os.mkdir(args.plot_dir)
    except OSError:
        # Most likely the directory already exists.
        pass

    x_fake = gan.generate_x(100, test=True, as_numpy=True)
    x_fake = (x_fake + 1.0) / 2.0
    visualizer.tile_binary_images(x_fake.reshape((-1, 28, 28)), dir=args.plot_dir, filename=filename)


In [10]:
def get_learning_rate_for_epoch(epoch):
    """Step schedule: 1e-3 before epoch 10, 3e-4 before epoch 50, 1e-4 afterwards."""
    for upper_bound, rate in ((10, 0.001), (50, 0.0003)):
        if epoch < upper_bound:
            return rate
    return 0.0001


def main():
    """Train the semi-supervised GAN on MNIST.

    Uses the module-level ``gan`` and ``args`` objects. Each epoch runs
    ``num_trains_per_epoch`` iterations. An iteration first updates the
    discriminator on the supervised + unsupervised losses, then updates the
    generator on the adversarial loss (plus feature matching when enabled).

    After every epoch the weights are saved to ``args.model_dir``. Validation
    accuracy is computed and appended to ``<model_dir>/result.csv``. Generated
    samples are plotted after epoch 1 and every ``plot_interval`` epochs.
    """
    # load MNIST images
    images, labels = load_train_images()

    # config
    discriminator_config = gan.config_discriminator

    # settings
    # _l -> labeled
    # _u -> unlabeled
    # _g -> generated
    max_epoch = 1000
    num_trains_per_epoch = 500
    plot_interval = 5
    batchsize_l = 100
    batchsize_u = 100
    batchsize_g = batchsize_u

    # seed (None lets numpy/cupy seed themselves from system entropy)
    np.random.seed(args.seed)
    if args.gpu_device != -1:
        cuda.cupy.random.seed(args.seed)

    # validation accuracy per epoch, written to result.csv
    csv_results = []

    # create semi-supervised split
    num_validation_data = 10000
    num_validation_segment = 500  # validation data is scored in chunks of this size
    num_labeled_data = args.num_labeled
    # Labeled batches are sampled without replacement, so they cannot be
    # larger than the labeled set itself.
    batchsize_l = min(batchsize_l, num_labeled_data)

    training_images_l, training_labels_l, training_images_u, validation_images, validation_labels = create_semisupervised(
        images, labels, num_validation_data, num_labeled_data, discriminator_config.ndim_output, seed=args.seed)
    print(training_labels_l)

    # training
    progress = Progress()
    # Epochs are numbered 1..max_epoch inclusive, matching the "Epoch n/max_epoch" display.
    for epoch in range(1, max_epoch + 1):
        progress.start_epoch(epoch, max_epoch)
        sum_loss_supervised = 0
        sum_loss_unsupervised = 0
        sum_loss_adversarial = 0
        sum_dx_labeled = 0
        sum_dx_unlabeled = 0
        sum_dx_generated = 0

        gan.update_learning_rate(get_learning_rate_for_epoch(epoch))

        for t in range(num_trains_per_epoch):
            # sample from data distribution
            images_l, _, label_ids_l = sample_labeled_data(training_images_l, training_labels_l,
                                                           batchsize_l,
                                                           discriminator_config.ndim_input,
                                                           discriminator_config.ndim_output,
                                                           binarize=False)
            images_u = sample_unlabeled_data(training_images_u, batchsize_u, discriminator_config.ndim_input,
                                             binarize=False)
            # Generated samples are constants for the discriminator update.
            images_g = gan.generate_x(batchsize_g)
            images_g.unchain_backward()

            # supervised loss: K-class cross entropy on the labeled batch
            py_x_l, activations_l = gan.discriminate(images_l, apply_softmax=False)
            loss_supervised = F.softmax_cross_entropy(py_x_l, gan.to_variable(label_ids_l))

            log_zx_l = F.logsumexp(py_x_l, axis=1)
            log_dx_l = log_zx_l - F.softplus(log_zx_l)
            dx_l = F.sum(F.exp(log_dx_l)) / batchsize_l

            # unsupervised loss
            # D(x) = Z(x) / {Z(x) + 1}, where Z(x) = \sum_{k=1}^K exp(l_k(x))
            # softplus(x) := log(1 + exp(x))
            # logD(x) = logZ(x) - log(Z(x) + 1)
            #         = logZ(x) - log(exp(log(Z(x))) + 1)
            #         = logZ(x) - softplus(logZ(x))
            # 1 - D(x) = 1 / {Z(x) + 1}
            # log{1 - D(x)} = log1 - log(Z(x) + 1)
            #               = -log(exp(log(Z(x))) + 1)
            #               = -softplus(logZ(x))
            py_x_u, _ = gan.discriminate(images_u, apply_softmax=False)
            log_zx_u = F.logsumexp(py_x_u, axis=1)
            log_dx_u = log_zx_u - F.softplus(log_zx_u)
            dx_u = F.sum(F.exp(log_dx_u)) / batchsize_u
            loss_unsupervised = -F.sum(log_dx_u) / batchsize_u  # minimize negative logD(x)
            py_x_g, _ = gan.discriminate(images_g, apply_softmax=False)
            log_zx_g = F.logsumexp(py_x_g, axis=1)
            # Normalize by the size of the generated batch it was summed over.
            loss_unsupervised += F.sum(F.softplus(log_zx_g)) / batchsize_g  # minimize negative log{1 - D(x)}

            # update discriminator
            gan.backprop_discriminator(loss_supervised + loss_unsupervised)

            # adversarial loss
            images_g = gan.generate_x(batchsize_g)
            py_x_g, activations_g = gan.discriminate(images_g, apply_softmax=False)
            log_zx_g = F.logsumexp(py_x_g, axis=1)
            log_dx_g = log_zx_g - F.softplus(log_zx_g)
            dx_g = F.sum(F.exp(log_dx_g)) / batchsize_g
            loss_adversarial = -F.sum(log_dx_g) / batchsize_g  # minimize negative logD(x)

            # feature matching
            # NOTE(review): the Improved-GAN paper matches the *batch means* of the
            # features; this compares features per sample between unrelated real
            # and generated examples. Confirm this is intended.
            if discriminator_config.use_feature_matching:
                features_true = activations_l[-1]
                features_true.unchain_backward()
                if batchsize_l != batchsize_g:
                    images_g = gan.generate_x(batchsize_l)
                    _, activations_g = gan.discriminate(images_g, apply_softmax=False)
                features_fake = activations_g[-1]
                loss_adversarial += F.mean_squared_error(features_true, features_fake)

            # update generator
            gan.backprop_generator(loss_adversarial)

            sum_loss_supervised += float(loss_supervised.data)
            sum_loss_unsupervised += float(loss_unsupervised.data)
            sum_loss_adversarial += float(loss_adversarial.data)
            sum_dx_labeled += float(dx_l.data)
            sum_dx_unlabeled += float(dx_u.data)
            sum_dx_generated += float(dx_g.data)
            if t % 10 == 0:
                progress.show(t, num_trains_per_epoch, {})

        gan.save(args.model_dir)

        # validation: score the whole validation set in equally sized segments
        images_v, _, label_ids_v = sample_labeled_data(validation_images, validation_labels,
                                                       num_validation_data, discriminator_config.ndim_input,
                                                       discriminator_config.ndim_output, binarize=False)
        num_segments = num_validation_data // num_validation_segment
        sum_accuracy = 0
        for images_seg, label_ids_seg in zip(np.split(images_v, num_segments), np.split(label_ids_v, num_segments)):
            y_distribution, _ = gan.discriminate(images_seg, apply_softmax=True, test=True)
            accuracy = F.accuracy(y_distribution, gan.to_variable(label_ids_seg))
            sum_accuracy += float(accuracy.data)
        validation_accuracy = sum_accuracy / num_segments

        progress.show(num_trains_per_epoch, num_trains_per_epoch, {
            "loss_l": sum_loss_supervised / num_trains_per_epoch,
            "loss_u": sum_loss_unsupervised / num_trains_per_epoch,
            "loss_g": sum_loss_adversarial / num_trains_per_epoch,
            "dx_l": sum_dx_labeled / num_trains_per_epoch,
            "dx_u": sum_dx_unlabeled / num_trains_per_epoch,
            "dx_g": sum_dx_generated / num_trains_per_epoch,
            "accuracy": validation_accuracy,
        })

        # write accuracy to csv
        csv_results.append([epoch, validation_accuracy, progress.get_total_time()])
        data = pd.DataFrame(csv_results)
        data.columns = ["epoch", "accuracy", "min"]
        data.to_csv("{}/result.csv".format(args.model_dir))

        if epoch % plot_interval == 0 or epoch == 1:
            plot(filename="epoch_{}_time_{}min".format(epoch, progress.get_total_time()))




In [11]:
# Start training; it runs until max_epoch unless interrupted (the recorded run below was stopped manually).
main()

loading images ... (60000 / 60000)
[9, 4, 3, 9, 7, 4, 4, 9, 8, 9, 2, 7, 7, 1, 6, 1, 4, 0, 7, 1, 9, 0, 3, 3, 4, 7, 3, 7, 8, 3, 7, 6, 0, 6, 1, 9, 4, 5, 2, 5, 6, 2, 0, 6, 7, 2, 6, 7, 1, 6, 5, 3, 0, 1, 8, 6, 4, 3, 6, 7, 5, 3, 6, 3, 0, 1, 3, 8, 1, 1, 1, 4, 8, 9, 4, 5, 8, 0, 8, 9, 0, 2, 5, 8, 8, 5, 5, 8, 9, 9, 5, 0, 4, 5, 0, 2, 2, 2, 2, 2]
Epoch 1/1000

KeyboardInterrupt: 