[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/tiruota/WGAN-PyTorch/blob/master/train_gaussian_mixture/wgan_gaussian_mixture.ipynb)

In [None]:
%matplotlib inline

# WGAN Gaussian Mixture

GANの問題点の一つだった"mode collapse"をWGANでは回避できることを確認するnotebook

## どうやって確認するのか

平均をずらした正規分布の混合分布を真のデータに用いる．

ノイズzをGeneratorに入力し，その出力の分布が真の分布を捉えられるかどうかを確認する．

## 分布の生成・プロット

関数"gaussian_mixture_circle"は単一の円状，"gaussian_mixture_double_circle"は二重の円状に正規分布を配置した混合分布を生成する．

また以下のように，生成した分布を"plot_scatter"でプロットし，"plot_kde"で分布のカーネル密度推定を描画する．

- gaussian_mixture_circle

|scatter|KDE|
|:---:|:---:|
|![scatter_true](https://imgur.com/6NbbeMi.png)|![kde_true](https://imgur.com/b6AypjF.png)|

- gaussian_mixture_double_circle

|scatter|KDE|
|:---:|:---:|
|![scatter_true](https://imgur.com/jKl2QjV.png)|![kde_true](https://imgur.com/HX4oSGM.png)|

# 実装

In [None]:
import random
import torch
import numpy as np
import math
import pylab
import seaborn as sns

import torch.nn as nn
import torch.optim as optim
import torch.utils.data

from torch.autograd import Variable

## パラメータ設定

- **num_mixture** - 正規分布の個数

- **scale** - 混合正規分布が成す円の大きさ

- **std** - 正規分布の標準偏差

In [None]:
# --- Target distribution -------------------------------------------------
num_mixture = 8     # number of Gaussian components in the mixture
scale = 2.0         # size of the circle on which the component means are placed
std = 0.2           # spread parameter passed to the mixture sampling functions

# --- Data / network dimensions -------------------------------------------
batchSize = 100     # samples per training batch
nc = 2              # dimensionality of one data point
nz = 256            # dimensionality of the generator's input noise

# --- Training schedule ---------------------------------------------------
nepochs = 200       # number of epochs
iterate = 1000      # generator updates per epoch
n_critic = 5        # discriminator updates per generator update

# --- Weight clipping range applied to the discriminator ------------------
clamp_upper = 0.01
clamp_lower = -clamp_upper

# --- RMSprop learning rates ----------------------------------------------
lrD = 5e-5          # discriminator
lrG = 5e-5          # generator

# --- Hardware ------------------------------------------------------------
ngpu = 1            # DataParallel is only used when this is greater than 1

In [None]:
def gaussian_mixture_circle(batchsize, num_cluster=8, scale=1, std=1):
    """Sample 2-D points from Gaussians whose means lie evenly on one circle.

    Each sample first picks one of ``num_cluster`` components uniformly at
    random.  Component ``k`` has its mean at angle
    ``k * 2*pi / num_cluster - pi/2`` on a circle of radius ``scale``; the
    point is then drawn from an isotropic normal around that mean.

    Args:
        batchsize: Number of points to draw.
        num_cluster: Number of mixture components.
        scale: Radius of the circle carrying the component means.
        std: Standard deviation of every component (per coordinate).

    Returns:
        ``float32`` array of shape ``(batchsize, 2)``.
    """
    cluster_ids = np.random.randint(0, num_cluster, size=batchsize)
    angle_step = math.pi * 2 / num_cluster
    angles = cluster_ids * angle_step - math.pi / 2

    means = np.zeros((batchsize, 2), dtype=np.float32)
    means[:, 0] = np.cos(angles) * scale
    means[:, 1] = np.sin(angles) * scale

    # np.random.normal takes a standard deviation as its second argument,
    # so ``std`` is passed through unchanged (it used to be squared).
    return np.random.normal(means, std, (batchsize, 2)).astype(np.float32)

def gaussian_mixture_double_circle(batchsize, num_cluster=8, scale=1, std=1):
    """Sample 2-D points from Gaussians arranged on two concentric circles.

    Each sample first picks one of ``num_cluster`` components uniformly at
    random.  Component ``k`` sits at angle ``k * 2*pi / num_cluster - pi/2``.
    Odd-numbered components have their mean on a circle of radius ``scale``;
    even-numbered components have their mean halved, i.e. they lie on an
    inner circle of radius ``scale / 2``.

    Args:
        batchsize: Number of points to draw.
        num_cluster: Number of mixture components.
        scale: Radius of the outer circle.
        std: Standard deviation of every component (per coordinate).

    Returns:
        ``float32`` array of shape ``(batchsize, 2)``.
    """
    cluster_ids = np.random.randint(0, num_cluster, size=batchsize)
    angle_step = math.pi * 2 / num_cluster
    angles = cluster_ids * angle_step - math.pi / 2

    means = np.zeros((batchsize, 2), dtype=np.float32)
    means[:, 0] = np.cos(angles) * scale
    means[:, 1] = np.sin(angles) * scale

    # Pull the even-numbered components onto the inner circle (half radius).
    means[cluster_ids % 2 == 0] /= 2

    # np.random.normal takes a standard deviation as its second argument,
    # so ``std`` is passed through unchanged (it used to be squared).
    return np.random.normal(means, std, (batchsize, 2)).astype(np.float32)

In [None]:
def plot_kde(data, dir=None, filename="kde", color="Greens"):
    """Show a 2-D kernel density estimate of ``data`` on a 4x4-inch figure.

    Args:
        data: Indexable as ``data[:, 0]`` / ``data[:, 1]``; the two columns
            are used as x and y coordinates.
        dir: Currently unused (saving to disk is disabled).
        filename: Currently unused (saving to disk is disabled).
        color: Name of the seaborn palette / colormap used for the density
            and for the axes background.
    """
    # NOTE(review): `shade`, `n_levels` and the positional y argument look like
    # the older seaborn kdeplot API - confirm against the installed version.
    figure = pylab.gcf()
    figure.set_size_inches(4.0, 4.0)
    pylab.clf()

    # The first colour of the palette is reused as the axes background.
    palette = sns.color_palette(color, n_colors=256)
    background = palette[0]

    xs = data[:, 0]
    ys = data[:, 1]
    axes = sns.kdeplot(xs, ys, shade=True, cmap=color, n_levels=30, clip=[[-4, 4]] * 2)
    axes.set_facecolor(background)

    pylab.xlim(-4, 4)
    pylab.ylim(-4, 4)
    pylab.show()

def plot_scatter(data, dir=None, filename="scatter", color="blue"):
    """Show a scatter plot of ``data`` on a 16x16-inch figure.

    Args:
        data: Indexable as ``data[:, 0]`` / ``data[:, 1]``; the two columns
            are used as x and y coordinates.
        dir: Currently unused (saving to disk is disabled).
        filename: Currently unused (saving to disk is disabled).
        color: Marker colour.
    """
    figure = pylab.gcf()
    figure.set_size_inches(16.0, 16.0)
    pylab.clf()

    xs = data[:, 0]
    ys = data[:, 1]
    pylab.scatter(xs, ys, s=20, marker="o", edgecolors="none", color=color)

    # Fixed view window so that plots from different epochs are comparable.
    pylab.xlim(-4, 4)
    pylab.ylim(-4, 4)
    pylab.show()

In [None]:
# =======
# device
# =======
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

## Generator

中間層のTanhの有無で結果が大きく異なる．

多くは試せなかったが，Tanhなしでは分布を捉えられなかった．

In [None]:
class Generator(nn.Module):
    """Fully connected generator: noise vector in, data point out.

    The network is three linear layers (``nz -> 128 -> 128 -> nc``) with a
    Tanh after each of the first two; the output layer has no activation.

    Args:
        nz: Size of the input noise vector.
        nc: Size of the generated output vector.
    """

    def __init__(self, nz, nc):
        super().__init__()
        self.nz = nz
        self.nc = nc

        hidden = 128
        self.net = nn.Sequential(
            nn.Linear(nz, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, nc),
        )

    def forward(self, input):
        """Run ``input`` through the network and return the result."""
        return self.net(input)

class Discriminator(nn.Module):
    """Fully connected discriminator: data point in, one unbounded score out.

    The network is three linear layers (``nc -> 128 -> 128 -> 1``) with an
    in-place LeakyReLU (negative slope 0.2) after each of the first two; the
    output layer has no activation.

    Args:
        nc: Size of one input data point.
    """

    def __init__(self, nc):
        super().__init__()
        self.nc = nc

        hidden = 128
        self.net = nn.Sequential(
            nn.Linear(nc, hidden),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden, 1),
        )

    def forward(self, input):
        """Run ``input`` through the network and return the score(s)."""
        return self.net(input)

In [None]:
# ==========
# models
# ==========
# Build both networks and move them to the selected device.
generator = Generator(nz=nz, nc=nc).to(device)
discriminator = Discriminator(nc=nc).to(device)

# ==========
# optimizer
# ==========
# One RMSprop optimizer per network, each with its own learning rate.
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=lrD)
optimizer_G = optim.RMSprop(generator.parameters(), lr=lrG)

# Wrap the networks in DataParallel only when CUDA is in use and more than
# one GPU was requested.
if cuda and ngpu > 1:
    generator = nn.DataParallel(generator, list(range(ngpu)))
    discriminator = nn.DataParallel(discriminator, list(range(ngpu)))

In [None]:
# ======
# train
# ======
# WGAN training with weight clipping: for every generator update the
# discriminator is updated n_critic times and its weights are clamped to
# [clamp_lower, clamp_upper] after each of its updates.
#
# Module-level names from the original cell (Tensor, iterations, samples_ture,
# z_fixed, ...) are left unchanged so that anything else referring to them
# keeps working.

# Legacy tensor-type alias; no longer used by the loop below.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


def _to_device(array):
    """Convert a float32 NumPy array into a tensor on the training device."""
    return torch.from_numpy(array).to(device)


def _sample_noise(count):
    """Draw ``count`` standard-normal noise vectors of size ``nz``."""
    return _to_device(np.random.normal(0, 1, (count, nz)).astype(np.float32))


iterations = 0  # grows by batchSize on every generator step
for epoch in range(nepochs):
    # Each epoch ends in eval mode (see the plotting step), so switch back.
    generator.train()
    discriminator.train()

    for i in range(iterate):
        iterations += batchSize

        # ========================
        # Train the discriminator
        # ========================
        for param in discriminator.parameters():
            param.requires_grad_(True)

        for n in range(n_critic):
            optimizer_D.zero_grad()

            # Real samples from the target mixture.
            samples_ture = _to_device(
                gaussian_mixture_double_circle(
                    batchsize=batchSize, num_cluster=num_mixture, scale=scale, std=std
                )
            )

            # Generated samples; no generator gradients are needed for the
            # discriminator update, so no graph is built here.
            z = _sample_noise(batchSize)
            with torch.no_grad():
                samples_fake = generator(z)

            # Both real and generated points are divided by `scale` before
            # being scored.
            real_validity = discriminator(samples_ture / scale)
            fake_validity = discriminator(samples_fake / scale)

            # Minimising this maximises mean(D(real)) - mean(D(fake)).
            lossD = -torch.sum(real_validity - fake_validity) / batchSize
            lossD.backward()
            optimizer_D.step()

            # Clip the discriminator weights.
            with torch.no_grad():
                for p in discriminator.parameters():
                    p.clamp_(clamp_lower, clamp_upper)

        # ====================
        # Train the generator
        # ====================
        # Freeze the discriminator so that only generator gradients are
        # computed during the backward pass.
        for param in discriminator.parameters():
            param.requires_grad_(False)

        z = _sample_noise(batchSize)

        optimizer_G.zero_grad()

        samples_fake = generator(z)
        # Minimising this maximises the mean discriminator score of the
        # generated samples.
        lossG = -torch.sum(discriminator(samples_fake / scale) / batchSize)
        lossG.backward()
        optimizer_G.step()

    # =====
    # Plot
    # =====
    # Visualise 10000 generated points at the end of every epoch.  Despite
    # its name, z_fixed is drawn afresh each epoch.
    generator.eval()
    z_fixed = _sample_noise(10000)
    with torch.no_grad():
        samples_fake = generator(z_fixed)

    generated = samples_fake.cpu().numpy()
    plot_scatter(generated)
    plot_kde(generated)