In [1]:
#!/usr/bin/env python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [2]:
# --- Target distribution: real samples are Gaussian with these parameters ---
data_mean = 4
data_stddev = 1.25

# --- Network shapes ---
g_input_size = 1      # width of each random-noise vector fed to the generator
g_hidden_size = 50    # generator hidden-layer width
g_output_size = 1     # width of each generated output vector
d_input_size = 100    # number of samples the discriminator judges at once
d_hidden_size = 50    # discriminator hidden-layer width
d_output_size = 1     # a single 'real' vs. 'fake' score
minibatch_size = d_input_size

# --- Optimisation ---
d_learning_rate = 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)
num_epochs = 30000
print_interval = 200
# 'k' steps in the original GAN paper: the discriminator can be trained more
# often than the generator by raising d_steps.
d_steps = 1
g_steps = 1

# --- Discriminator input pipeline: enable exactly one variant ---
# Variant "Raw data" (the discriminator sees the samples unchanged):
#name = "Raw data"
#preprocess = lambda data: data
#d_input_func = lambda x: x
# Variant "Data and variances" (samples plus squared deviations, so the
# discriminator input is twice as wide):
name = "Data and variances"
preprocess = lambda data: decorate_with_diffs(data, 2.0)
d_input_func = lambda x: x * 2

print("Using data [%s]" % name)

Using data [Data and variances]


In [3]:
def get_distribution_sampler(mu, sigma):
    """Build a sampler for the Gaussian with mean ``mu`` and std-dev ``sigma``.

    Returns:
        A callable ``sample(n)`` that draws ``n`` values with NumPy's global
        random generator and returns them as a ``torch.Tensor`` of shape
        ``(1, n)``.
    """
    def sample(n):
        draws = np.random.normal(mu, sigma, (1, n))  # Gaussian
        return torch.Tensor(draws)

    return sample

def get_generator_input_sampler():
    """Build the noise source that feeds the generator.

    Returns:
        A callable ``sample(m, n)`` returning an ``(m, n)`` tensor drawn from
        the uniform distribution on [0, 1) -- deliberately _NOT_ Gaussian.
    """
    def sample(m, n):
        return torch.rand(m, n)

    return sample

# ##### MODELS: Generator model and discriminator model

class Generator(nn.Module):
    """Three-layer fully connected network used as the GAN generator.

    Args:
        input_size: number of features in each input row.
        hidden_size: width of both hidden layers.
        output_size: number of features in each output row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Apply ELU, then sigmoid, then a final linear layer with no activation.

        Args:
            x: tensor whose last dimension is ``input_size``.

        Returns:
            Tensor whose last dimension is ``output_size``; values are unbounded
            because the output layer is linear.
        """
        x = F.elu(self.map1(x))
        # torch.sigmoid is used instead of F.sigmoid, which PyTorch deprecated;
        # both compute the same values.
        x = torch.sigmoid(self.map2(x))
        return self.map3(x)

class Discriminator(nn.Module):
    """Three-layer fully connected network used as the GAN discriminator.

    Args:
        input_size: number of features in each input row.
        hidden_size: width of both hidden layers.
        output_size: number of scores produced per input row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Apply two ELU layers and squash the final layer with a sigmoid.

        Args:
            x: tensor whose last dimension is ``input_size``.

        Returns:
            Tensor whose last dimension is ``output_size`` with values in
            (0, 1), suitable as the input of ``nn.BCELoss``.
        """
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        # torch.sigmoid is used instead of F.sigmoid, which PyTorch deprecated;
        # both compute the same values.
        return torch.sigmoid(self.map3(x))

def extract(v):
    """Return the elements of ``v`` as a flat Python list.

    Args:
        v: a tensor (or Variable) of any shape, including a 0-dim scalar.

    Returns:
        A list of Python numbers in row-major order. A scalar yields a
        one-element list, so ``extract(loss)[0]`` reads a loss value.
    """
    # Flatten the tensor itself rather than dumping v.data.storage(): the
    # storage is the whole underlying buffer, so for a slice or other view it
    # would also return elements that do not belong to v.
    return v.data.contiguous().view(-1).tolist()

def stats(d):
    """Summarise ``d`` as the two-element list ``[mean, standard deviation]``."""
    centre = np.mean(d)
    spread = np.std(d)  # NumPy default: population std-dev (ddof=0)
    return [centre, spread]

def decorate_with_diffs(data, exponent):
    """Append each element's powered deviation from its row mean.

    Args:
        data: 2-D tensor; the mean is taken along dimension 1 (per row).
        exponent: power applied to the deviations (2.0 gives squared
            deviations).

    Returns:
        Tensor with twice as many columns as ``data``: the original values
        followed by ``(data - row_mean) ** exponent``.
    """
    # The mean is computed from data.data, so autograd treats it as a constant:
    # gradients flow through ``data`` but not through the mean.
    row_mean = torch.mean(data.data, 1, keepdim=True)
    # Broadcasting the (rows, 1) mean centres every row on its own mean
    # (not on row 0's) and avoids allocating a separate ones tensor that
    # might not match data's device or dtype.
    diffs = torch.pow(data - Variable(row_mean), exponent)
    return torch.cat([data, diffs], 1)

# Data sources: real samples from the target Gaussian, uniform noise for G.
d_sampler = get_distribution_sampler(data_mean, data_stddev)
gi_sampler = get_generator_input_sampler()

# Networks. D's input width goes through d_input_func so that it matches the
# width of the preprocessed data.
G = Generator(g_input_size, g_hidden_size, g_output_size)
D = Discriminator(d_input_func(d_input_size), d_hidden_size, d_output_size)

# Loss, and one Adam optimizer per network so each step updates only its own parameters.
criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)

# Alternating GAN training: each epoch runs d_steps discriminator updates
# followed by g_steps generator updates, and logs every print_interval epochs.
for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        D.zero_grad()

        #  1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        # Targets are built with the decision's own size: nn.BCELoss warns about
        # mismatched input/target sizes and current PyTorch rejects them.
        real_target = Variable(torch.ones(d_real_decision.size()))  # ones = true
        d_real_error = criterion(d_real_decision, real_target)
        d_real_error.backward() # compute/store gradients, but don't change params

        #  1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        # G emits one sample per row; .t() lays the whole minibatch out as a single
        # row so D judges the sample set at once, like the (1, n) real data.
        d_fake_decision = D(preprocess(d_fake_data.t()))
        fake_target = Variable(torch.zeros(d_fake_decision.size()))  # zeros = fake
        d_fake_error = criterion(d_fake_decision, fake_target)
        d_fake_error.backward()
        d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()

        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        # we want to fool, so pretend it's all genuine
        g_target = Variable(torch.ones(dg_fake_decision.size()))
        g_error = criterion(dg_fake_decision, g_target)

        # This also accumulates gradients in D's parameters; they are discarded by
        # D.zero_grad() at the start of the next discriminator step.
        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters

    if epoch % print_interval == 0:
        # Losses from the last D/G step of this epoch, plus [mean, std] of the
        # most recent real and generated sample sets.
        print("%s: D: %s/%s G: %s (Real: %s, Fake: %s) " % (epoch,
                                                            extract(d_real_error)[0],
                                                            extract(d_fake_error)[0],
                                                            extract(g_error)[0],
                                                            stats(extract(d_real_data)),
                                                            stats(extract(d_fake_data))))

  "Please ensure they have the same size.".format(target.size(), input.size()))


0: D: 0.2612496018409729/0.7550960779190063 G: 0.6319496035575867 (Real: [3.881776454001665, 1.3058554984265063], Fake: [0.34714002758264539, 0.013925168078822258]) 
200: D: -1.000088900582341e-12/0.1943722516298294 G: 1.7466143369674683 (Real: [3.919493629038334, 1.278808678596596], Fake: [-0.24533017694950104, 0.0012270249925156239]) 
400: D: -1.000088900582341e-12/0.10825973004102707 G: 2.294783115386963 (Real: [3.7743106368929147, 1.5686935209310764], Fake: [-0.39478518813848495, 0.017708721492327557]) 
600: D: -1.000088900582341e-12/0.05873845890164375 G: 2.8944711685180664 (Real: [3.9866126239299775, 1.2762855689696464], Fake: [-0.34019537091255186, 0.078327143105248118]) 
800: D: -1.000088900582341e-12/0.054303526878356934 G: 3.2106664180755615 (Real: [3.9703823983669282, 1.209226298433937], Fake: [-0.18055458210408687, 0.2166505761596022]) 
1000: D: 0.0009939244482666254/0.03329313173890114 G: 3.267967939376831 (Real: [4.1688681280612947, 1.1199689775010326], Fake: [0.019086868

10000: D: 0.9485693573951721/0.8106628656387329 G: 0.7467162609100342 (Real: [3.7736936104297638, 1.172259436712225], Fake: [4.1009891617298129, 1.0948647583559137]) 
10200: D: 0.5122779011726379/0.5562213063240051 G: 0.6588876843452454 (Real: [4.0858793282508854, 1.3087501835893263], Fake: [3.9643237501382829, 1.2205015387764857]) 
10400: D: 0.5323874950408936/0.3354381322860718 G: 1.3748964071273804 (Real: [3.9118615794181824, 1.2032777235712033], Fake: [4.3512736088037487, 1.2216805645132136]) 
10600: D: 0.5690218210220337/0.4919187128543854 G: 0.689799427986145 (Real: [3.9542882061004638, 1.1695150378502559], Fake: [3.9697003972530367, 1.116197163708049]) 
10800: D: 0.20576399564743042/0.5444735884666443 G: 0.6415402293205261 (Real: [4.0184838759899142, 1.2941799754148864], Fake: [4.0245869338512419, 1.2580973766562631]) 
11000: D: 0.31965500116348267/1.0085173845291138 G: 0.44921875 (Real: [4.1759560453891753, 1.2899029860735498], Fake: [4.1455021870136264, 1.2485395237450279]) 
1

19800: D: 0.1778448224067688/0.25703558325767517 G: 1.0865757465362549 (Real: [3.899229212999344, 1.2730725259531681], Fake: [3.7362002617120744, 1.3196759059842063]) 
20000: D: 0.04035596922039986/1.7935454845428467 G: 2.845966339111328 (Real: [3.8619731092453002, 1.1810351638044498], Fake: [4.1703265231847766, 1.1102070887110032]) 
20200: D: 0.07320480793714523/0.03559655696153641 G: 1.4348039627075195 (Real: [3.8897855523601175, 1.2898105227760008], Fake: [3.8613371509313583, 1.4116380289757369]) 
20400: D: 0.07747320830821991/0.2461923211812973 G: 0.20697583258152008 (Real: [4.0408905553817753, 1.2708659522625199], Fake: [4.3441728645563122, 1.145649639806926]) 
20600: D: 1.7431493997573853/0.2421710342168808 G: 2.1173717975616455 (Real: [4.2202457667887208, 1.2938268300700155], Fake: [4.6763345324993137, 1.0764294393848022]) 
20800: D: 0.24189963936805725/0.1331222951412201 G: 2.1456825733184814 (Real: [3.9748358857631683, 1.1635641792253912], Fake: [4.4862626981735225, 1.13199640

29600: D: 0.8818122148513794/0.8872649073600769 G: 0.7833473086357117 (Real: [3.8258704704046251, 1.1866414491810533], Fake: [4.5922258043289181, 1.200524216590622]) 
29800: D: 0.7342708706855774/0.8494189977645874 G: 0.5542294383049011 (Real: [4.068523369431496, 1.2442353027278581], Fake: [3.207109948396683, 1.1192780112382186]) 
