In [1]:
import torch
import pandas as pd
import numpy as np

from torch import nn
from torch.optim import Adam


from IPython.core.debugger import set_trace

from mnist_loader import mnist_data
from adaboost import Booster
from haar_features import HaarFeatureMaker


# Dataset object returned by the project-local `mnist_loader` module; the
# training cells further down wrap it in `torch.utils.data.DataLoader`s.
data = mnist_data()

        

## The code for the generator

Here we define the generator object, and also the code to train the generator


In [2]:
class GenerativeNet(torch.nn.Module):
    """
    A three hidden-layer generative neural network.

    A batch of latent vectors of width ``n_features`` is passed through three
    ``Linear`` + ``LeakyReLU(0.2)`` hidden layers (256, 512 and 1024 units)
    and a final ``Linear`` + ``Tanh`` layer, so every output value lies in
    [-1, 1].

    Args:
        n_features: Width of each latent input vector. Defaults to 100.
        n_out: Number of values produced per sample. Defaults to 784, i.e.
            one flattened 28x28 image.
    """
    def __init__(self, n_features=100, n_out=784):
        super(GenerativeNet, self).__init__()
        # Kept as plain attributes so callers can size their noise to match.
        self.n_features = n_features
        self.n_out = n_out

        # Layers are created in the same order as before so that parameter
        # initialisation under a fixed seed is unchanged, and the attribute
        # names are kept so existing state_dict keys still load.
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.LeakyReLU(0.2)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2)
        )

        self.out = nn.Sequential(
            nn.Linear(1024, n_out),
            nn.Tanh()
        )

    def forward(self, x):
        """Map a ``(batch, n_features)`` tensor to a ``(batch, n_out)`` tensor."""
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x
    
# Noise
def noise(size, n_features=100):
    """
    Draw a batch of standard-normal latent vectors for the generator.

    Args:
        size: Number of vectors (rows) to sample.
        n_features: Width of each vector. Defaults to 100, the generator's
            input width in this notebook.

    Returns:
        A ``(size, n_features)`` tensor. When CUDA is available the tensor is
        moved to the GPU, so the model consuming it must live there as well.
    """
    n = torch.randn(size, n_features)
    if torch.cuda.is_available():
        # Must be *called*: a bare `n.cuda` hands back the bound method
        # object instead of a tensor.
        return n.cuda()
    return n

def images_to_vectors(images):
    """Flatten a batch of images into one 784-element row per image.

    The leading (batch) dimension is kept; the result is a view of the input.
    """
    batch_size = images.size(0)
    flattened = images.view(batch_size, 784)
    return flattened

def vectors_to_images(vectors):
    """Reshape a batch of 784-element rows into ``(batch, 1, 28, 28)`` images.

    The leading (batch) dimension is kept; the result is a view of the input.
    """
    batch_size = vectors.size(0)
    image_shape = (batch_size, 1, 28, 28)
    return vectors.view(*image_shape)

# Loss function: binary cross-entropy between the oracle's scores and the
# all-ones target built in train_generator.
loss = nn.BCELoss()


def train_generator(optimizer, fake_data, oracle_function):
    """
    Run one generator update against a fixed discriminator oracle.

    Args:
        optimizer: Optimizer over the generator's parameters. Its gradients
            are cleared first and it is stepped after the backward pass.
        fake_data: Generator output, still attached to the autograd graph so
            the error can back-propagate into the generator.
        oracle_function: Callable that scores ``fake_data``. Its result is
            passed to ``nn.BCELoss``, so it must contain values in [0, 1].

    Returns:
        The scalar BCE error tensor for this step.
    """
    # Reset gradients
    optimizer.zero_grad()
    # Score the already-generated fake data with the oracle
    prediction = oracle_function(fake_data)
    # The generator wants every fake sample to be scored as real (1).
    # ones_like keeps the target's shape, dtype and device in step with the
    # prediction; a fixed (N, 1) CPU tensor made BCELoss raise whenever the
    # oracle returned e.g. an (N,) tensor.
    target = torch.ones_like(prediction)
    # Calculate error and backpropagate
    error = loss(prediction, target)
    error.backward()
    # Update weights with gradients
    optimizer.step()
    # Return error
    return error



## Discriminator oracle

Here we define the discriminator oracle function. The main object of interest is the `Booster` class, which implements the AdaBoost algorithm on Haar features. The booster has a function which generates a tensor object that can be used as the final discriminator in torch format.

In [3]:



# Shared Haar feature maker, constructed once with 28 and reused by every
# call to discriminator_oracle.
haars = HaarFeatureMaker(28)


def discriminator_oracle(real_data, fake_data):
    """
    Fit a fresh booster to separate real samples from fake ones.

    Real samples are labelled 1 and fake samples 0. The two arrays are
    stacked along the first axis (real first) before being handed to
    ``Booster`` together with the shared ``haars`` feature maker.

    Returns:
        ``get_tensor`` of the trained booster's group hypothesis; the
        attribute is returned as-is, without being called.
    """
    samples = np.vstack((real_data, fake_data))
    n_real = real_data.shape[0]

    # Start with every sample labelled fake, then mark the leading real rows.
    labels = np.zeros(samples.shape[0])
    labels[:n_real] = 1.0

    booster = Booster(samples, labels, haars)
    booster.train(5)
    return booster._group_hypothesis.get_tensor

## Time for training

In [4]:
# Batch sizes: the generator is updated on batches of 100, while each
# discriminator oracle is fitted on a batch of only 10 real images.
generator_batch_size = 100
discriminator_batch_size = 10
# Two independently shuffled loaders over the same dataset, one per batch size.
data_loader_gen = torch.utils.data.DataLoader(data, batch_size=generator_batch_size, shuffle=True)
data_loader_dis = torch.utils.data.DataLoader(data, batch_size=discriminator_batch_size, shuffle=True)

# Create a generator (left on the default device; it is never moved explicitly)
generator = GenerativeNet()

# Optimizer for generator: Adam with lr=0.0002 and betas=(0.5, 0.999)
g_optimizer = Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [None]:
# Number of epochs. One epoch = fit one fresh discriminator oracle, then make
# one full pass of generator updates over data_loader_gen.
num_epochs = 200


# Each epoch grabs one small batch of real data for the discriminator.
# zip() with range() stops after num_epochs rounds; previously num_epochs was
# never read and the loop ran over every batch in data_loader_dis.
for epoch, (real_discrim_batch, _) in zip(range(num_epochs), data_loader_dis):

    # The loader yields [N, 1, 28, 28]; the boosting code needs [N, 28, 28].
    real_data_discrim = real_discrim_batch.view((real_discrim_batch.size(0), 28, 28)).numpy()

    # Generate the same number of fake samples. They are only training data
    # for the oracle, so no autograd graph is needed, and .cpu() keeps the
    # numpy conversion valid if the tensors live on a GPU.
    with torch.no_grad():
        fake_vectors = generator(noise(real_discrim_batch.size(0)))

    # The generator returns flat vectors, but the boosting algorithm requires
    # 28x28 images.
    fake_data_discrim = fake_vectors.detach().cpu().numpy().reshape((-1, 28, 28))

    # Train the boosting algorithm
    oracle_predictor = discriminator_oracle(real_data_discrim, fake_data_discrim)

    # Now we train the generator according to the function received from the
    # oracle. real_batch is used only for its batch size.
    # NOTE(review): the oracle was fitted on [N, 28, 28] arrays but is given
    # the flat generator output here -- confirm the predictor accepts that.
    g_error = None
    for n_batch, (real_batch, _) in enumerate(data_loader_gen):
        fake_data = generator(noise(real_batch.size(0)))
        g_error = train_generator(g_optimizer, fake_data, oracle_predictor)

    # Display progress: report the last generator error of the epoch. The
    # guard covers an empty data_loader_gen, where no update was made.
    if g_error is not None:
        print("epoch {}/{}: generator error {:.4f}".format(epoch + 1, num_epochs, g_error.item()))

    # TODO: display sample images and save model checkpoints. The earlier
    # commented-out code for this relied on `display`, `test_noise`, `logger`
    # and `discriminator`, none of which are defined in this notebook.