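This example synthesizes data with a polynomial generator, following the polynomial neural networks paper. Here the target distribution is a 2D astroid curve (see `create_astroid` below), not images.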

The code below (a shortened version of the original code) trains a polynomial generator in a GAN framework.

That is, the generator is a polynomial expansion built from fully-connected layers and Hadamard (element-wise) products; it includes no activation functions.

Additional examples can be found in the official repo:

 https://github.com/grigorisg9gr/polynomial_nets

In [None]:
from os.path import join, isdir, isfile
import numpy as np
from functools import partial
# # chainer related imports.
try:
    import chainer
except ImportError as e:
    print('This notebook depends on chainer, please install it to proceed.')
    raise ImportError(e)
import chainer.functions as F
import chainer.links as L
from chainer import Variable
from chainer import training
from chainer.training import extensions

# Define the dataset class

In [None]:
def create_astroid(n_samples, alpha=1, **kwargs):
    """Sample ``n_samples`` points on an astroid (x = alpha*sin^3 t, y = alpha*cos^3 t).

    The cubes are written with the triple-angle identities
    ``sin^3 t = (3 sin t - sin 3t) / 4`` and ``cos^3 t = (3 cos t + cos 3t) / 4``.

    :param n_samples: int; number of evenly spaced parameter values t.
    :param alpha: float; scale of the astroid. It also sets the parameter
        range t in [-4*alpha, 4*alpha].
        NOTE(review): for alpha < pi/4 that range is shorter than one full
        period (2*pi), so only part of the curve is sampled.
    :param kwargs: ignored.
    :return: float32 array of shape (n_samples, 2) holding the (x, y) pairs.
    """
    half_span = alpha * 4
    params = np.linspace(-half_span, half_span, num=n_samples)
    quarter = alpha / 4
    xs = quarter * (3 * np.sin(params) - np.sin(3 * params))
    ys = quarter * (3 * np.cos(params) + np.cos(3 * params))
    # stack as rows (2, n) and transpose to (n, 2).
    stacked = np.stack((xs, ys), axis=0)
    return stacked.astype(np.float32).T


class SyntheticDataset(chainer.dataset.DatasetMixin):
    """Shuffled astroid point cloud served as ``(point, 1)`` examples.

    Each stored point is a float32 array of shape ``(2, 1, 1)``. The label
    returned with every example is the constant ``1``.
    """

    def __init__(self, n_samples, seed=0, **kwargs):
        """
        :param n_samples: int; number of points drawn from the astroid.
        :param seed: int; seed for NumPy's *global* RNG. This call reseeds it,
            which also affects any later ``np.random`` use.
        :param kwargs: forwarded to ``create_astroid`` (e.g. ``alpha``).
        """
        np.random.seed(seed)
        points = create_astroid(n_samples=n_samples, **kwargs)
        order = np.random.permutation(range(points.shape[0]))
        # # shuffle, then add two trailing singleton axes: (n, 2) -> (n, 2, 1, 1).
        self.base = points[order].reshape(points.shape + (1, 1))

    def __len__(self):
        """Number of stored points."""
        return self.base.shape[0]

    def get_example(self, i):
        """Return the i-th point together with the constant label 1."""
        point = self.base[i]
        return point, 1

In [None]:
# # # # uncomment below to view the data distribution.
# synn1 = SyntheticDataset(300)
# pts = np.array([synn1.get_example(i)[0][:, 0, 0] for i in range(len(synn1))])
# import matplotlib.pyplot as plt
# %matplotlib inline
# plt.scatter(pts[:, 0], pts[:, 1])

# Define the generator and discriminator classes

In [None]:
class FCDiscriminator(chainer.Chain):
    """Fully-connected discriminator with ReLU activations.

    Layer layout, with ``n_l = len(layer_d) + 2``:
      l1: n_input -> n_hidden
      l2: n_hidden -> n_hidden2
      l3 .. l{n_l-1}: layer_d[i] -> layer_d[i + 1]
      l{n_l}: layer_d[-1] -> n_out (no activation)
      lin: n_out -> 1 (the real/fake score)
    """

    def __init__(self, n_input=11, n_out=4, n_hidden=8, n_hidden2=5, sn=False,
                 layer_d=None, **kwargs):
        """
        :param n_input: int; input dimensionality (inputs are flattened by the FC layers).
        :param n_out: int; width of the representation fed to the final ``lin`` layer.
        :param n_hidden: int; output width of l1.
        :param n_hidden2: int; output width of l2.
        :param sn: bool; apply spectral normalization to l1 .. l{n_l}.
            The final ``lin`` layer is never normalized.
        :param layer_d: list of int; required. Widths of the layers after l2;
            layer_d[0] must equal n_hidden2.
        :param kwargs: ignored.
        """
        # # if layer_d is defined it overrides n_hidden[i] with i >= 3.
        # # Note: layer_d should include the value of n_hidden2 (l2 feeds l3);
        # # the check is skipped when n_hidden2 == 10.
        assert (layer_d is None) or n_hidden2 == 10 or layer_d[0] == n_hidden2
        assert layer_d is not None
        w = chainer.initializers.GlorotUniform()
        super(FCDiscriminator, self).__init__()
        # # ``SNLinear`` is not defined in this notebook (using it raised NameError),
        # # so spectral normalization is applied through Chainer's link hook.
        Linear = self._spectral_linear if sn else L.Linear
        self.n_l = n_l = len(layer_d) + 2

        with self.init_scope():
            self.l1 = Linear(n_input, n_hidden, initialW=w)
            self.l2 = Linear(n_hidden, n_hidden2, initialW=w)

            # # iterate over all layers (till the last) and save in self.
            for l in range(3, n_l):
                # # define the input and the output names.
                ni, no = layer_d[l - 3], layer_d[l - 3 + 1]
                setattr(self, 'l{}'.format(l), Linear(ni, no, initialW=w))

            # # save the last layer.
            ni = layer_d[n_l - 3]
            setattr(self, 'l{}'.format(n_l), Linear(ni, n_out, initialW=w))
            # # add the binary classification layer in the end.
            self.lin = L.Linear(n_out, 1, initialW=w)
        self.activ = F.relu

    @staticmethod
    def _spectral_linear(*args, **kwargs):
        """Build an ``L.Linear`` link with a spectral-normalization hook on its weight."""
        from chainer.link_hooks import SpectralNormalization
        link = L.Linear(*args, **kwargs)
        link.add_hook(SpectralNormalization())
        return link

    def __call__(self, x, y=None, **kwargs):
        """Return the (batch, 1) discriminator score for ``x``; ``y`` and kwargs are ignored."""
        h = x
        # # loop over the layers (l1 .. l{n_l-1}), each followed by ReLU.
        for l in range(1, self.n_l):
            h = self.activ(getattr(self, 'l{}'.format(l))(h))
        # # last layer (no activation).
        repres = getattr(self, 'l{}'.format(self.n_l))(h)
        return self.lin(repres)


In [None]:
def sample_continuous(dim, batchsize, distribution='normal', xp=np):
    """Draw a float32 batch of prior samples with shape (batchsize, dim).

    :param dim: int; dimensionality of each sample.
    :param batchsize: int; number of samples.
    :param distribution: 'normal' (standard Gaussian) or 'uniform' (on [-1, 1)).
    :param xp: array module used for sampling (numpy, or cupy on GPU).
    :raises NotImplementedError: for any other ``distribution``.
    """
    shape = (batchsize, dim)
    if distribution == 'uniform':
        draw = xp.random.uniform(-1, 1, shape)
    elif distribution == 'normal':
        draw = xp.random.randn(*shape)
    else:
        raise NotImplementedError
    return draw.astype(xp.float32)


class FCProductRecursiveGenerator(chainer.Chain):
    def __init__(self, dim_z=2, channels=None, power_poly=None, n_out=2, use_bias=True,
                 distribution='uniform', use_bn=False, derivfc=0):
        """
        Polynomial generator with fully-connected layers.

        It implements a product (composition) of polynomials: the output of each
        polynomial is the input of the next one. Each polynomial is expanded with
        the recursion of model 1 of the PolyGAN paper:
            h_1 = U_1 x,    h_n = (U_n x) * h_{n-1} + h_{n-1},   n = 2 .. N_i,
        where x is the polynomial's input, U_n are FC layers and ``*`` is the
        Hadamard product. A final FC layer maps to ``n_out`` dimensions.
        :param dim_z: int; the dimensions of the input noise (the prior distribution samples).
        :param channels: list of int; each element is the width of the FC layers of the
            respective polynomial. Each polynomial uses a constant width to keep the code
            simple. Defaults to [3, 10].
        :param power_poly: list of int; each element is the power (i.e. approximation N_i)
            of the respective polynomial. Must have the same length as ``channels``.
            Defaults to [6, 8].
        :param n_out: int; the dimensions of the output.
        :param use_bias: bool; whether to use bias in the polynomial FC layers (the final
            layer always has a bias).
        :param distribution: str; the prior distribution ('normal' or 'uniform').
        :param use_bn: bool; whether to apply batch normalization after each recursion
            step n >= 2. A ``bn{i}_1`` link is also created but is never applied.
        :param derivfc: currently unused; stored for interface compatibility.
        """
        # # avoid mutable default arguments; these are the previous defaults.
        if channels is None:
            channels = [3, 10]
        if power_poly is None:
            power_poly = [6, 8]
        super(FCProductRecursiveGenerator, self).__init__()
        Linear = L.Linear
        # # Save several attributes from the provided arguments.
        assert isinstance(dim_z, int)
        self.dim_z = dim_z
        self.distribution = distribution
        self.use_bn = use_bn
        # # copy, so later mutation of the caller's lists cannot desync the model.
        self.channels = list(channels)
        # # whether to use bias in the FC layers.
        self.use_bias = use_bias
        self.power_poly = list(power_poly)
        assert len(self.power_poly) == len(self.channels)
        self.derivfc = derivfc
        # # the input size to the current polynomial; initialize on dimz.
        input_current_poly = dim_z

        with self.init_scope():
            bn1 = partial(L.BatchNormalization, use_gamma=True, use_beta=False)
            # # iterate over all the polynomials (length of channels many).
            for id_poly, (width, power) in enumerate(zip(self.channels, self.power_poly)):
                assert isinstance(width, int)
                # # same channels for every layer of this polynomial.
                channels_poly = [width] * (power + 1)
                # # a resize layer is needed when the incoming width differs.
                needs_resize = input_current_poly != channels_poly[0]
                setattr(self, 'has_rsz{}'.format(id_poly), needs_resize)
                if needs_resize:
                    setattr(self, 'resize{}'.format(id_poly),
                            Linear(input_current_poly, channels_poly[0], nobias=not use_bias))
                # # now build the current polynomial (id_poly).
                for l in range(1, power + 1):
                    c1 = channels_poly[l]
                    setattr(self, 'l{}_{}'.format(id_poly, l), Linear(c1, c1, nobias=not use_bias))
                    if use_bn:
                        setattr(self, 'bn{}_{}'.format(id_poly, l), bn1(c1))
                # # update the channels for the next polynomial.
                input_current_poly = int(channels_poly[-1])
            # # save the last layer (only for the last).
            self.last = Linear(input_current_poly, n_out)

    def __call__(self, batchsize, z=None, **kwargs):
        """Generate samples of shape (batch, n_out, 1, 1).

        :param batchsize: int; number of samples drawn from the prior when ``z`` is None.
        :param z: optional noise of shape (batch, dim_z); overrides ``batchsize``.
        """
        if z is None:
            z = sample_continuous(self.dim_z, batchsize, distribution=self.distribution, xp=self.xp)
        # # input_poly: the input variable to each polynomial; for the
        # # first, simply z, i.e. the noise vector.
        input_poly = z + 0
        # # iterate over all the polynomials (length of channels many).
        for id_poly in range(len(self.channels)):
            # # ensure that the channels from previous polynomial are of the
            # # appropriate size.
            if getattr(self, 'has_rsz{}'.format(id_poly)):
                input_poly = getattr(self, 'resize{}'.format(id_poly))(input_poly)
            h = getattr(self, 'l{}_1'.format(id_poly))(input_poly)

            # # loop over the current polynomial layers and compute the
            # # output (for this polynomial).
            for layer in range(2, self.power_poly[id_poly] + 1):
                # # step 1: perform the hadamard product (plus skip term).
                z1 = getattr(self, 'l{}_{}'.format(id_poly, layer))(input_poly)
                h = z1 * h + h
                # # step 2: normalize representations.
                if self.use_bn:
                    h = getattr(self, 'bn{}_{}'.format(id_poly, layer))(h)
            # # update the input for the next polynomial.
            input_poly = h + 0
        # # last layer; reading input_poly (not h) also works with zero polynomials.
        h = self.last(input_poly)
        if len(h.shape) == 2:
            h = F.reshape(h, (h.shape[0], h.shape[1], 1, 1))
        return h

# Define the hinge losses and the updater class

In [None]:
# Hinge Loss
def loss_hinge_dis(dis_fake, dis_real):
    """Discriminator hinge loss: mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))."""
    real_term = F.mean(F.relu(1. - dis_real))
    fake_term = F.mean(F.relu(1. + dis_fake))
    return real_term + fake_term


def loss_hinge_gen(dis_fake):
    """Generator hinge loss: the negated mean discriminator score on fake samples."""
    mean_score = F.mean(dis_fake)
    return -mean_score


class Updater(chainer.training.StandardUpdater):
    """GAN updater using the hinge losses.

    Each ``update_core`` call performs ``n_dis`` discriminator steps. A single
    generator step (on ``n_gen_samples`` fresh samples) follows the first
    discriminator step.
    """

    def __init__(self, gen, dis, opt_gen, opt_dis, iterator, *args, n_dis=1,
                 n_gen_samples=512, **kwargs):
        """
        :param gen, dis: generator and discriminator links.
        :param opt_gen, opt_dis: their optimizers (already set up).
        :param iterator: iterator over the real dataset.
        :param n_dis: int; discriminator steps per update.
        :param n_gen_samples: int; batch size of the generator step.
        :param kwargs: forwarded to ``StandardUpdater`` (e.g. ``device``).
        """
        self.loss_dis = loss_hinge_dis
        self.loss_gen = loss_hinge_gen
        self.n_gen_samples = n_gen_samples
        self.gen, self.dis = gen, dis
        self.opt_gen, self.opt_dis = opt_gen, opt_dis
        self.iterator = iterator
        self.n_dis = n_dis
        # # forward extra StandardUpdater options instead of silently dropping
        # # them; iterator/optimizer always come from the explicit arguments.
        kwargs['iterator'] = iterator
        kwargs['optimizer'] = {'opt_gen': opt_gen, 'opt_dis': opt_dis}
        super(Updater, self).__init__(*args, **kwargs)

    def _generate_samples(self, n_gen_samples=None):
        """Sample ``n_gen_samples`` (default ``self.n_gen_samples``) outputs from the generator."""
        if n_gen_samples is None:
            n_gen_samples = self.n_gen_samples
        return self.gen(n_gen_samples)

    def get_batch(self, xp):
        """Fetch the next real batch as a float32 Variable on ``xp``'s device."""
        batch = next(self.iterator)
        # # each example is (point, label); only the point is used.
        x = [np.asarray(example[0]).astype('f') for example in batch]
        return Variable(xp.asarray(x))

    def update_core(self):
        gen = self.gen
        dis = self.dis
        xp = gen.xp
        for i in range(self.n_dis):
            x_real = self.get_batch(xp)
            batchsize = len(x_real)
            dis_real = dis(x_real)
            x_fake = self._generate_samples(n_gen_samples=batchsize)
            dis_fake = dis(x_fake)
            # # cut the graph so the discriminator loss does not backprop into gen.
            x_fake.unchain_backward()

            fake_arr, real_arr = dis_fake.array, dis_real.array
            chainer.reporter.report({'dis_fake': fake_arr.mean()})
            chainer.reporter.report({'dis_real': real_arr.mean()})

            loss_dis = self.loss_dis(dis_fake=dis_fake, dis_real=dis_real)
            dis.cleargrads()
            loss_dis.backward()
            self.opt_dis.update()
            loss_dis.unchain_backward()
            chainer.reporter.report({'loss_dis': loss_dis.array})
            del loss_dis

            # # one generator step, right after the first discriminator step.
            if i == 0:
                x_fake = self._generate_samples()
                dis_fake = dis(x_fake)
                loss_gen = self.loss_gen(dis_fake=dis_fake)
                assert not xp.isnan(loss_gen.array)
                gen.cleargrads()
                loss_gen.backward()
                self.opt_gen.update()
                loss_gen.unchain_backward()
                chainer.reporter.report({'loss_gen': loss_gen.array})
                del loss_gen


# Prepare the models for training

In [None]:
batch = 128
n_iters = 30000
cuda = chainer.cuda.available

dataset = SyntheticDataset(n_samples=200000)
iterator = chainer.iterators.SerialIterator(dataset, batch)

# # define the instances of the generator and the discriminator.
dis = FCDiscriminator(n_input=2, n_out=2, n_hidden=20, n_hidden2=20, layer_d=[20, 20, 20, 20, 20])
gen = FCProductRecursiveGenerator(channels=[3, 10], power_poly=[6, 8])

if cuda:
    gen.to_gpu()
    dis.to_gpu()

# # define the optimizers for the generator and the discriminator.
opt_gen = chainer.optimizers.Adam(alpha=0.00015, beta1=0., beta2=0.9)
opt_gen.setup(gen)
opt_dis = chainer.optimizers.Adam(alpha=0.0001, beta1=0., beta2=0.9)
opt_dis.setup(dis)

# # initialize the trainer and the updater.
updater = Updater(gen, dis, opt_gen, opt_dis, iterator, n_dis=3)
trainer = training.Trainer(updater, (n_iters, 'iteration'))
# # add extensions to the training (auxiliary for the user).
trainer.extend(extensions.LogReport(trigger=(200, 'iteration')))
# # only keys the Updater actually reports; the previous 'synth_qual' key was
# # never reported, so its column was always empty.
report_keys = ['iteration', 'loss_dis', 'loss_gen', 'dis_real', 'dis_fake']
trainer.extend(extensions.PrintReport(report_keys), trigger=(200, 'iteration'))
trainer.extend(extensions.ProgressBar(update_interval=200))

In [None]:
m1 = 'Generator params: {}. Discriminator params: {}.'
print(m1.format(gen.count_params(), dis.count_params()))
# Run the training
print('start training')
trainer.run()
# # export the last model. ``extensions.snapshot_object`` only builds a trainer
# # extension (it must be registered via ``trainer.extend``), so calling it here
# # saved nothing; serialize the generator directly into the trainer's output dir.
from chainer import serializers
gen_path = join(trainer.out, '{}_best.npz'.format(gen.__class__.__name__))
serializers.save_npz(gen_path, gen)
print('saved generator to {}'.format(gen_path))

# Synthesize points from the trained model

In [None]:
# # plt was only imported in a commented-out cell; import it here.
import matplotlib.pyplot as plt

# # sample without recording a backprop graph and in test mode, then move the
# # result to host memory (it is a cupy array when the model lives on the GPU).
with chainer.using_config('train', False), chainer.no_backprop_mode():
    aa = chainer.backends.cuda.to_cpu(gen(10000).array)
# # samples have shape (n, 2, 1, 1): plot the two output coordinates.
plt.scatter(aa[:, 0, 0, 0], aa[:, 1, 0, 0])
plt.show()