In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [483]:
# Data params
data_mean = 4
data_stddev = 1.25

# Model params
g_input_size = 1     # Random noise dimension coming into generator, per output vector
g_hidden_size = 50   # Generator complexity
g_output_size = 1    # size of generated output vector
d_input_size = 100   # Minibatch size - cardinality of distributions
d_hidden_size = 50   # Discriminator complexity
d_output_size = 1    # Single dimension for 'real' vs. 'fake'
minibatch_size = d_input_size

d_learning_rate = 2e-4  # 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)
num_epochs = 30000
print_interval = 200
d_steps = 4  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator
g_steps = 1

In [484]:
# ### Uncomment only one of these
#(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)

print("Using data [%s]" % (name))

Using data [Data and variances]


In [485]:
# ##### DATA: Target data and generator input data
def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian

def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussiana

In [486]:
def Binarize(tensor,quant_mode='det'):
    if quant_mode=='det':
        return tensor.sign()
    else:
        return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)

class BinarizeLinear(nn.Linear):

    def __init__(self, *kargs, **kwargs):
        try:
            self.exempt = kwargs["exempt"]
            del kwargs["exempt"]
        except:
            self.exempt=False
        self.exempt=False
        super(BinarizeLinear, self).__init__(*kargs, **kwargs)
        

    def forward(self, input):
        if self.exempt:
#         if input.size(1) != self.excempt:
            input.data=Binarize(input.data)
        if not hasattr(self.weight,'org'):
            self.weight.org=self.weight.data.clone()
        self.weight.data=Binarize(self.weight.org)
#         print(self.weight.shape)
        out = nn.functional.linear(input, self.weight)
        if not self.bias is None:
            self.bias.org=self.bias.data.clone()
            out += self.bias.view(1, -1).expand_as(out)

        return out

In [487]:
# ##### MODELS: Generator model and discriminator model

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = BinarizeLinear(input_size, hidden_size, exempt=True) #1-50
        self.htanh1 = nn.Hardtanh()
#         self.bn1 = nn.BatchNorm1d(hidden_size)
        
        self.map2 = BinarizeLinear(hidden_size, hidden_size) #50-50
        self.htanh2 = nn.Hardtanh()
#         self.bn2 = nn.BatchNorm1d(hidden_size)
        
        self.map3 = BinarizeLinear(hidden_size, output_size) #50-1
#         self.logsoftmax=nn.LogSoftmax()
    def forward(self, x):
        x = self.htanh1(self.map1(x))
        x = self.htanh1(self.map2(x))
#         return self.logsoftmax(self.map3(x))
        return self.map3(x)

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = BinarizeLinear(input_size*2, hidden_size, exempt=True) #100-50
        print((input_size, hidden_size), self.map1.weight.shape)
        self.htanh1 = nn.Hardtanh()
#         self.bn1 = nn.BatchNorm1d(hidden_size)
        
        self.map2 = BinarizeLinear(hidden_size, hidden_size) #50-50
        self.htanh2 = nn.Hardtanh()
#         self.bn2 = nn.BatchNorm1d(hidden_size)
        
        self.map3 = BinarizeLinear(hidden_size, output_size) #50-1
#         self.logsoftmax=nn.LogSoftmax()
        self.htanh3 = nn.Hardtanh()
        
    def forward(self, x):
        x = self.htanh1(self.map1(x))
        x = self.htanh1(self.map2(x))
        return self.htanh3(self.map3(x))


In [488]:
def extract(v):
    return v.data.storage().tolist()

def stats(d):
    return [np.mean(d), np.std(d)]

def decorate_with_diffs(data, exponent):
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    return torch.cat([data, diffs], 1)

In [491]:
d_sampler = get_distribution_sampler(data_mean, data_stddev)
gi_sampler = get_generator_input_sampler()
G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
# D = Discriminator(input_size=d_input_func(d_input_size), hidden_size=d_hidden_size, output_size=d_output_size)
D = Discriminator(input_size=d_input_size, hidden_size=d_hidden_size, output_size=d_output_size)
criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)

(100, 50) torch.Size([50, 200])


In [492]:
for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        D.zero_grad()

        #  1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = (D(preprocess(d_real_data)) + 1)/2
#         d_real_decision = D(d_real_data)
#         print(d_real_decision)
        d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = true
        d_real_error.backward() # compute/store gradients, but don't change params

        #  1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = (D(preprocess(d_fake_data.t())) + 1)/2
#         d_fake_decision = D(d_fake_data.t())
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = fake
        
        d_fake_error.backward()
        for p in list(D.parameters()):
            if hasattr(p,'org'):
                p.data.copy_(p.org)
#         print(1)
        d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()
        for p in list(D.parameters()):
            if hasattr(p,'org'):
                p.org.copy_(p.data.clamp_(-1,1))

    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()

        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = (D(preprocess(g_fake_data.t())) + 1)/2
#         print(preprocess(g_fake_data.t()).shape, dg_fake_decision)
        #         dg_fake_decision = D(g_fake_data.t())
#         print(preprocess(g_fake_data.t()))
        g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # we want to fool, so pretend it's all genuine

        g_error.backward()
        for p in list(G.parameters()):
            if hasattr(p,'org'):
                p.data.copy_(p.org)
#         for i.grad in G.parameters():
#             print(i)
#         print(g_fake_data.grad)
        g_optimizer.step()  # Only optimizes G's parameters
#         print("################")
        for p in list(G.parameters()):
            if hasattr(p,'org'):
                p.org.copy_(p.data.clamp_(-1,1))
        
#         for i in G.parameters():
#             print(i)
            
    if epoch % print_interval == 0:
        print("%s: D: %s/%s G: %s (Real: %s, Fake: %s) " % (epoch,
                                                            extract(d_real_error)[0],
                                                            extract(d_fake_error)[0],
                                                            extract(g_error)[0],
                                                            stats(extract(d_real_data)),
                                                            stats(extract(d_fake_data))))

  "Please ensure they have the same size.".format(target.size(), input.size()))


0: D: 0.0/0.7810595631599426 G: 0.0 (Real: [4.033352342844009, 1.0754584511117078], Fake: [4.685608332157135, 2.4548669928369327]) 
200: D: 0.0/27.63102149963379 G: 27.63102149963379 (Real: [4.17116320848465, 1.3007859089802456], Fake: [4.7590951704606415, 2.7135804101719927]) 
400: D: 27.63102149963379/0.0 G: 27.63102149963379 (Real: [4.053149722963571, 1.2402774566304324], Fake: [3.41772410299629, 3.0358430610362093]) 
600: D: 0.0/27.63102149963379 G: 27.63102149963379 (Real: [3.815350051522255, 1.2130965964084066], Fake: [4.004249610081315, 3.023410910800886]) 
800: D: 0.0/27.63102149963379 G: 0.0 (Real: [4.139606406390667, 1.3132308233927945], Fake: [8.70237875878811, 3.16160790726547]) 
1000: D: 27.63102149963379/27.63102149963379 G: 0.0 (Real: [3.9563324177265167, 1.3848358407513288], Fake: [7.8920968319475655, 3.362056514264185]) 
1200: D: 0.0/0.0 G: 27.63102149963379 (Real: [4.027162272334099, 1.301330310122871], Fake: [7.692215462699533, 3.543574247039918]) 
1400: D: 0.0/0.0 G

KeyboardInterrupt: 