In [1]:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sklearn.datasets

import torch
import torch.autograd as agd
import torch.nn as tchnn
import torch.nn.functional as F
import torch.optim as optim
import random

In [2]:
#some params
DIM = 512
FIXED_GEN = False
LAMBDA = .1
DISCRI_ITR = 5
BATCHSZ = 256
batchSz = BATCHSZ
TOT_GEN_ITR = 100000

In [3]:
#gaussian mixture dataset creation
class GMMsampler:
    
    def __init__( self, n_samples, n_components=1, weights=[1], mu=[np.array([0,0])], sig=[np.array([[1,0],[0,1]])]):
        self.n_components = n_components
        self.n_samples = n_samples
        self.weights = weights
        self.mu = mu
        self.sig = sig
        self.dim = mu[0].size
        self.data = np.empty([n_samples, self.dim])
        self.datacid = dict()
        
    def check_musig(self):
        shmu = self.mu[0].shape
        shsig = self.sig[0].shape
        if(not(all(m.shape==shmu for m in mu))):
            print('all mean vectors must be of same dimension')
        if(not(all(s.shape==shmu for s in sig))):
            print('all covariance matrix must be of same dimension')
    
    def gen_sample(self):
        for i in range(self.n_components):
            self.datacid[i] = []
        for i in range(self.n_samples):
            idx = np.random.choice(np.arange(0,self.n_components), p=(self.weights)/np.sum(self.weights))
            mu_,sig_ = self.mu[idx], self.sig[idx]
            self.data[i,:] = np.random.multivariate_normal(mu_, sig_)
            self.datacid[idx].append(self.data[i,:])
        for idx in range(self.n_components):
            self.datacid[idx] = np.array(self.datacid[idx]).reshape(-1,2)
            
    def plot_centers(self):
        plt.figure()
        for c in self.mu:
            plt.scatter(c[0], c[1])
        plt.show()
        
    def plot_data(self):
        plt.figure()
        for i in range(self.n_components):
            plt.scatter(self.datacid[i][:,0], self.datacid[i][:,1])
        plt.show()

scale = 2
m1 = np.array([-1, 0])*scale
m2 = np.array([1, 0])*scale
m3 = np.array([0, 1])*scale
m4 = np.array([0, -1])*scale
m5 = np.array([1/np.sqrt(2), 1/np.sqrt(2)])*scale
m6 = np.array([1/np.sqrt(2), -1/np.sqrt(2)])*scale
m7 = np.array([-1/np.sqrt(2), 1/np.sqrt(2)])*scale
m8 = np.array([-1/np.sqrt(2), -1/np.sqrt(2)])*scale

sig1 = np.eye(2)/1.414
sig2 = np.eye(2)/1.414
sig3 = np.eye(2)/1.414
sig4 = np.eye(2)/1.414
sig5 = np.eye(2)/1.414
sig6 = np.eye(2)/1.414
sig7 = np.eye(2)/1.414
sig8 = np.eye(2)/1.414

In [14]:
#models of generator and discriminator
class Generator(tchnn.Module):
    
    def __init__(self):
        super(Generator, self).__init__()
        self.L1 = tchnn.Linear(2, DIM)
        self.L2 = tchnn.Linear(DIM, DIM)
        self.L3 = tchnn.Linear(DIM, DIM)
        self.Ou = tchnn.Linear(DIM,2)
    def forward(self, noise, real_data):
        if FIXED_GEN:
            #print('here0')
            return noise + real_data
        else:
            x = F.relu(self.L1(noise))
            x = F.relu(self.L2(x))
            x = F.relu(self.L3(x))
            x = self.Ou(x)
            #print('here1')
            return x.view(-1,2)
    def name(self):
        return 'GENERATOR'

In [5]:
class Discriminator(tchnn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        self.L1 = tchnn.Linear(2, DIM)
        self.L2 = tchnn.Linear(DIM, DIM)
        self.L3 = tchnn.Linear(DIM, DIM)
        self.Ou = tchnn.Linear(DIM,1)
    def forward(self, x):
            x = F.relu(self.L1(x))
            x = F.relu(self.L2(x))
            x = F.relu(self.L3(x))
            x = self.Ou(x)
            return x.view(-1)
    def name(self):
        return 'DISCRIMINATOR'

In [6]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [7]:
#training data generator
def traindata_gen():
    g = GMMsampler(batchSz, n_components=8, weights=[1, 1, 1, 1, 1, 1, 1, 1],
                  mu=[m1,m2,m3,m4,m5,m6,m7,m8], sig=[sig1,sig2,sig3,sig4,sig5,sig6,sig7,sig8])
    while(True):
        g.gen_sample()
        g.plot_centers()
        g.plot_data()
        yield g.data
d = traindata_gen()

In [8]:
#gradient penalty term in objective
def calc_gp(D, real_data, fake_data):
    α = torch.rand(BATCHSZ, 1)
    α = α.expand(real_data.size())
    α = α.cuda()
    
    interpolated = α*real_data + (1-α)*fake_data
    interpolated = interpolated.cuda()
    interpolated = agd.Variable(interpolated, requires_grad=True)
    
    D_interp = D(interpolated)
    
    gradients = agd.grad(outputs=D_interp, inputs=interpolated, grad_outputs=torch.ones(D_interp.size()).cuda()
                        ,create_graph=True, retain_graph=True, only_inputs=True)[0]
    gp = ((gradients.norm(2, dim=1) - 1)**2).mean()*LAMBDA
    return gp

In [22]:
#notaion similar to paper
G = Generator().cuda()
D = Discriminator().cuda()
#G.apply(weights_init)
#D.apply(weights_init)
#G = G.cuda()
#D = D.cuda()
print(G)
print(D)

Generator (
  (L1): Linear (2 -> 512)
  (L2): Linear (512 -> 512)
  (L3): Linear (512 -> 512)
  (Ou): Linear (512 -> 2)
)
Discriminator (
  (L1): Linear (2 -> 512)
  (L2): Linear (512 -> 512)
  (L3): Linear (512 -> 512)
  (Ou): Linear (512 -> 1)
)


In [23]:
v1 = agd.Variable(torch.randn(BATCHSZ,2).cuda())
v2 = agd.Variable(torch.randn(BATCHSZ,2).cuda())
G(v1,v2)

Variable containing:
-0.0077 -0.0064
-0.0013 -0.0212
-0.0336  0.0104
 0.0049  0.0097
 0.0071  0.0295
 0.0081  0.0316
 0.0077  0.0013
-0.0031  0.0363
-0.0190  0.0236
-0.0320 -0.0051
-0.0181  0.0250
-0.0587 -0.0158
-0.0066  0.0308
-0.0256  0.0168
-0.0358 -0.0027
-0.0269  0.0167
-0.0357  0.0018
-0.0182  0.0275
-0.0737 -0.0229
-0.0331  0.0162
-0.0446 -0.0098
-0.0399 -0.0025
 0.0016  0.0278
-0.0073  0.0290
-0.0124  0.0323
-0.0139  0.0312
-0.0559 -0.0137
-0.0309  0.0044
-0.0181  0.0214
 0.0001  0.0275
-0.0368 -0.0023
 0.0106  0.0265
-0.0635 -0.0254
 0.0112  0.0281
-0.0998 -0.0418
-0.0204  0.0091
-0.0082  0.0341
-0.0314  0.0182
-0.0010  0.0380
-0.0081  0.0216
-0.0397 -0.0026
-0.0112  0.0225
-0.0288  0.0079
-0.0501 -0.0039
-0.0086  0.0351
-0.0316  0.0095
-0.0729 -0.0208
-0.0007  0.0212
 0.0135 -0.0090
-0.0483  0.0079
-0.0253  0.0153
-0.0169  0.0347
-0.0188  0.0256
 0.0134 -0.0072
-0.0098  0.0327
-0.0375 -0.0008
-0.0116  0.0340
-0.0358  0.0019
 0.0181  0.0055
-0.0448  0.0034
-0.0479 -0.0209
-0.

In [24]:
optD = optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.9))
optG = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.9))

In [25]:
one = torch.FloatTensor([1])
onebar = one * -1
one = one.cuda()
onebar = onebar.cuda()

In [12]:
#to plot results
def generate_image(true_dist):
    """
    Generates and saves a plot of the true distribution, the generator, and the
    critic.
    """
    N_POINTS = 128
    RANGE = 3

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))

    points_v = agd.Variable(torch.Tensor(points), volatile=True).cuda()
    disc_map = D(points_v).cpu().data.numpy()

    noise = torch.randn(BATCHSZ, 2).cuda()
    
    noisev = agd.Variable(noise, volatile=True)
    
    true_dist_v = agd.Variable(torch.Tensor(true_dist).cuda())
    
    samples = G(noisev, true_dist_v).cpu().data.numpy()

    plt.clf()

    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())

    plt.scatter(true_dist[:, 0], true_dist[:, 1], c='orange', marker='+')
    if not FIXED_GEN:
        plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')

In [26]:
for itr in range(TOT_GEN_ITR):
    ############################
    # (1) Update D network
    ###########################
    for p in D.parameters():  # reset requires_grad
        p.requires_grad = True  # they are set to False below in netG update

    for iter_d in range(DISCRI_ITR):
        #print(iter_d)
        _data = next(d)
        real_data = torch.Tensor(_data)
        real_data = real_data.cuda()
        real_data_v = agd.Variable(real_data)
        #print(real_data_v)

        D.zero_grad()

        # train with real
        D_real = D(real_data_v)
        D_real = D_real.mean()
        D_real.backward(onebar)

        # train with fake
        noise = torch.randn(BATCHSZ, 2)
        noise = noise.cuda()
        noisev = agd.Variable(noise, volatile=True)  # totally freeze netG
        
        #noisev = agd.Variable(noise)
        gop = G(noisev, real_data_v)
        print('here3')
        fake = agd.Variable(gop.data)
        inputv = fake
        D_fake = D(inputv)
        D_fake = D_fake.mean()
        D_fake.backward(one)

        # train with gradient penalty
        gradient_penalty = calc_gp(D, real_data_v.data, fake.data)
        gradient_penalty.backward()

        D_cost = D_fake - D_real + gradient_penalty
        Wasserstein_D = D_real - D_fake
        optD.step()
        print('discri iter done ', iter_d)

    if not FIXED_GEN:
        ############################
        # (2) Update G network
        ###########################
        for p in D.parameters():
            p.requires_grad = False  # to avoid computation
        G.zero_grad()

        _data = next(d)
        real_data = torch.Tensor(_data)
        real_data = real_data.cuda()
        real_data_v = agd.Variable(real_data)

        noise = torch.randn(BATCHSZ, 2)
        noise = noise.cuda()
        noisev = agd.Variable(noise)
        fake = G(noisev, real_data_v)
        G = D(fake)
        G = G.mean()
        G.backward(onebar)
        G_cost = -G
        optG.step()
        print('gen iter done', itr)



here3
discri iter done  0
here3
discri iter done  1
here3
discri iter done  2
here3
discri iter done  3
here3
discri iter done  4
gen iter done 0


TypeError: 'Variable' object is not callable