In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [2]:
# Data params: the "real" 1-D Gaussian the generator has to imitate.
data_mean = 4
data_stddev = 1.25

# Reproducibility: seed every RNG this notebook draws from (numpy for the real
# samples, torch for generator noise and weight init) so Restart & Run All
# follows the same training trajectory each time.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Model params
g_input_size = 1     # Random noise dimension coming into generator, per output vector
g_hidden_size = 50   # Generator complexity
g_output_size = 1    # size of generated output vector
d_input_size = 100   # Minibatch size - cardinality of distributions
d_hidden_size = 50   # Discriminator complexity
d_output_size = 1    # Single dimension for 'real' vs. 'fake'
minibatch_size = d_input_size  # D scores a whole minibatch as one input vector

# Optimisation params
d_learning_rate = 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)  # Adam (beta1, beta2)
num_epochs = 30000
print_interval = 200
d_steps = 1  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator
g_steps = 1

# Input representation for D. Uncomment only one of these.
#   preprocess   - applied to a (1, n) batch of samples before it is fed to D
#   d_input_func - maps n to the width of D's input after preprocessing
#                  (x * 2 because decorate_with_diffs appends one diff per value)
# decorate_with_diffs is defined in a later cell; the lambda only looks the name
# up when it is called during training, after that cell has run.
#(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)

In [3]:
print("Using data [{}]".format(name))

Using data [Data and variances]


In [4]:
def get_distribution_sampler(mu, sigma):
    """Build a sampler for the "real" data, a Gaussian with mean ``mu`` and std ``sigma``.

    The returned callable takes a sample count ``n``. It draws from numpy's
    global RNG and returns a float32 tensor of shape ``(1, n)``.
    """
    def sample(n):
        draws = np.random.normal(mu, sigma, (1, n))  # float64 ndarray, shape (1, n)
        return torch.Tensor(draws)                   # legacy ctor -> float32 copy
    return sample

In [5]:
def get_generator_input_sampler():
    """Build the generator's noise sampler.

    The returned callable takes ``(m, n)`` and returns an ``(m, n)`` tensor of
    uniform samples on [0, 1). The noise is deliberately uniform, not Gaussian
    like the real data.
    """
    def sample(m, n):
        return torch.rand(m, n)
    return sample

In [6]:
class Generator(nn.Module):
    """Map generator noise to fake samples.

    Layers: Linear -> ELU -> Linear -> sigmoid -> Linear. There is no
    activation on the last layer, so generated values are unbounded.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Names map1..map3 are kept so parameter / state_dict keys stay the same.
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        first = F.elu(self.map1(x))
        second = torch.sigmoid(self.map2(first))
        return self.map3(second)

In [7]:
class Discriminator(nn.Module):
    """Score each input row with a value in [0, 1], interpreted as "real".

    Layers: Linear -> ELU -> Linear -> ELU -> Linear -> sigmoid.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Names map1..map3 are kept so parameter / state_dict keys stay the same.
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        for hidden_layer in (self.map1, self.map2):
            x = F.elu(hidden_layer(x))
        return torch.sigmoid(self.map3(x))


In [8]:
def extract(v):
    """Return the elements of tensor ``v`` as a flat Python list, in row-major order.

    A 0-d tensor (e.g. a loss) gives a one-element list. ``reshape(-1)``
    walks the tensor's own logical elements, so strided or offset views are
    handled correctly. ``storage().tolist()`` would instead dump the entire
    underlying buffer, and ``Tensor.storage()`` is deprecated in PyTorch 2.x.
    """
    return v.detach().cpu().reshape(-1).tolist()

def stats(d):
    """Return ``[mean, std]`` of ``d``, using numpy's default population std (ddof=0)."""
    mean_value = np.mean(d)
    spread = np.std(d)
    return [mean_value, spread]

def decorate_with_diffs(data, exponent):
    """Append each value's deviation from its row mean, raised to ``exponent``.

    Args:
        data: 2-D tensor of shape ``(rows, n)``. In this notebook it is always
            a single ``(1, n)`` batch of samples.
        exponent: power applied to ``data - row_mean``. With 2.0 the appended
            half holds per-value squared deviations.

    Returns:
        Tensor of shape ``(rows, 2 * n)``: ``data`` followed by
        ``(data - row_mean) ** exponent`` along dim 1.

    The mean is taken from a detached copy, so no gradient flows through it.
    Gradients still flow through ``data`` in both halves.
    """
    # keepdim=True keeps the mean as (rows, 1) so it broadcasts across columns.
    # Current PyTorch drops the reduced dim by default, which broke the old
    # ``mean.tolist()[0][0]`` indexing with a TypeError.
    mean = torch.mean(data.detach(), dim=1, keepdim=True)
    diffs = torch.pow(data - mean, exponent)
    return torch.cat([data, diffs], 1)

# Samplers: real data (Gaussian) and generator noise (uniform).
d_sampler = get_distribution_sampler(data_mean, data_stddev)
gi_sampler = get_generator_input_sampler()

# Networks. G is built before D so weight initialisation consumes the torch
# RNG in the same order. D's input width is d_input_func(d_input_size)
# because `preprocess` may widen each batch (e.g. by appending diffs).
G = Generator(
    input_size=g_input_size,
    hidden_size=g_hidden_size,
    output_size=g_output_size,
)
D = Discriminator(
    input_size=d_input_func(d_input_size),
    hidden_size=d_hidden_size,
    output_size=d_output_size,
)

# Binary cross entropy on D's sigmoid output:
# https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
criterion = nn.BCELoss()

# One Adam optimizer per network, so each step() touches only that network.
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)

In [9]:
for epoch in range(num_epochs):
    for _ in range(d_steps):
        # 1. Train D on real + fake samples.
        D.zero_grad()

        #  1A: Train D on real data (target 1 = real).
        d_real_data = d_sampler(d_input_size)  # (1, d_input_size)
        d_real_decision = D(preprocess(d_real_data))
        # The target must match the decision's shape (1, 1): current nn.BCELoss
        # raises ValueError when given a (1,)-shaped target.
        d_real_error = criterion(d_real_decision, torch.ones_like(d_real_decision))
        d_real_error.backward()  # compute/store gradients, but don't change params

        #  1B: Train D on fake data (target 0 = fake).
        d_gen_input = gi_sampler(minibatch_size, g_input_size)
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        # .t(): (minibatch_size, 1) -> (1, minibatch_size), one row like the real batch.
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, torch.zeros_like(d_fake_decision))
        d_fake_error.backward()
        d_optimizer.step()  # Only optimizes D's parameters; changes based on stored gradients from backward()

    for _ in range(g_steps):
        # 2. Train G on D's response. backward() also writes gradients into D's
        # parameters, but only g_optimizer steps here; D.zero_grad() clears them
        # before D's next update.
        G.zero_grad()

        gen_input = gi_sampler(minibatch_size, g_input_size)
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        # We want to fool D, so label the fakes as genuine.
        g_error = criterion(dg_fake_decision, torch.ones_like(dg_fake_decision))

        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters

    if epoch % print_interval == 0:
        print("%s: D: %s/%s G: %s (Real: %s, Fake: %s) " % (epoch,
                                                            d_real_error.item(),
                                                            d_fake_error.item(),
                                                            g_error.item(),
                                                            stats(extract(d_real_data)),
                                                            stats(extract(d_fake_data))))

0: D: 0.8460698127746582/0.5943595767021179 G: 0.813252866268158 (Real: [3.9521866011619569, 1.1685733866570043], Fake: [-0.48025260657072066, 0.0070735080882127618]) 
200: D: 0.010584743693470955/0.23989149928092957 G: 1.5469396114349365 (Real: [3.8880642449855802, 1.3031401983026045], Fake: [0.48656365692615511, 0.029472201957884633]) 
400: D: 0.01170569472014904/0.1212318018078804 G: 2.165919780731201 (Real: [4.0753459572792057, 1.1031088556396209], Fake: [0.48296793580055236, 0.065159438157828584]) 
600: D: 0.04766295477747917/0.05856797844171524 G: 2.85894775390625 (Real: [4.052794470787048, 1.1097557495287973], Fake: [0.56469883948564525, 0.19945158076355821]) 
800: D: 0.005996947176754475/0.07891087979078293 G: 2.874392509460449 (Real: [3.8947348964214323, 1.1025138207848162], Fake: [0.79311690978705884, 0.4237904713837935]) 
1000: D: 0.0009975639404729009/0.16946099698543549 G: 3.50296688079834 (Real: [3.8404748439788818, 1.3049847662925498], Fake: [1.3054529586806893, 0.950777

10000: D: 0.977918803691864/0.6114557981491089 G: 0.7739893794059753 (Real: [3.9476002049446106, 1.2074999923630061], Fake: [4.0048085027933125, 1.338086053606151]) 
10200: D: 0.8323512077331543/0.5435066223144531 G: 1.0546272993087769 (Real: [4.0780276829004292, 1.2168515767545938], Fake: [3.9539078521728515, 1.0627023436737408]) 
10400: D: 0.5068973898887634/0.3478733003139496 G: 0.6366167068481445 (Real: [3.8530531804263592, 1.2614667597477862], Fake: [3.9821322321891786, 1.4585048001342096]) 
10600: D: 0.5325584411621094/0.46303293108940125 G: 1.7383157014846802 (Real: [3.9648299539089202, 1.2399059782174002], Fake: [4.060488053560257, 1.1730191072188487]) 
10800: D: 0.6437920331954956/0.5035157799720764 G: 0.9933136701583862 (Real: [4.1032256758213039, 1.1671453504601785], Fake: [4.1737215477228169, 1.2012757545519264]) 
11000: D: 0.42675310373306274/0.5508109331130981 G: 1.0152722597122192 (Real: [3.8773564541339876, 1.1895635232595614], Fake: [3.8669850456714632, 1.2400226154293

19800: D: 0.07039639353752136/0.2768682837486267 G: 0.9615154266357422 (Real: [4.1379947936534878, 1.2100385606994994], Fake: [4.3604865014553074, 1.2507012790237906]) 
20000: D: 0.5961591601371765/0.07158263027667999 G: 1.5370358228683472 (Real: [3.9547063088417054, 1.3284698806316206], Fake: [4.1566352355480198, 1.230018501840062]) 
20200: D: 0.0005776762263849378/0.10584354400634766 G: 2.0180163383483887 (Real: [4.0141260910034182, 1.3362087487636503], Fake: [3.9344013535976412, 1.305578022792699]) 
20400: D: 1.4720218181610107/0.11203619837760925 G: 1.4525229930877686 (Real: [4.083313555717468, 1.2401686177512399], Fake: [4.24296635389328, 1.1287874351016316]) 
20600: D: 0.064852774143219/0.3406282067298889 G: 2.3255927562713623 (Real: [4.0159021782875062, 1.2846850916581454], Fake: [4.0977309095859527, 1.1508431761155871]) 
20800: D: 0.12489984929561615/0.07633918523788452 G: 1.0566996335983276 (Real: [4.0571343487501146, 1.1251583803834062], Fake: [4.1039616334438325, 1.154747529

29600: D: 0.7010936141014099/0.7517508864402771 G: 0.6966797709465027 (Real: [3.9535645699501036, 1.1206787041972104], Fake: [3.3924783504009248, 1.2858441422878188]) 
29800: D: 0.7213623523712158/0.585443377494812 G: 0.8109933137893677 (Real: [4.0313089126348496, 1.3256953417746111], Fake: [4.4494288516044618, 1.2259823092119997]) 
