In [242]:
import torch
from torch import nn
from torch.autograd import Variable
import numpy as np
import random

representation_size = 2
n_samples = 2000

data = []
for i in range(n_samples//2):
    data.append([0,1])
    data.append([1,0])
data = np.array(data)

In [243]:
random.shuffle(data)

In [244]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        #generating distribution
        self.gen1 = nn.Linear(representation_size, 10)
        self.gen2 = nn.Linear(10, 10)
        self.gen3 = nn.Linear(10, input_size)
        
        #functions
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
    def generate(self):
        mean = torch.zeros(input_size)
        z = torch.randn_like(mean)
        g1 = self.relu(self.gen1(z))
        g2 = self.relu(self.gen2(g1))
        g3 = self.gen3(g2)
        return g3
    
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        #discriminator
        self.disc1 = nn.Linear(input_size, 10)
        self.disc2 = nn.Linear(10, 10)
        self.disc3 = nn.Linear(10, 1)
        
        #functions
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
    
    def discriminate(self, x):
        d1 = self.relu(self.disc1(x))
        d2 = self.relu(self.disc2(d1))
        prob = self.sigmoid(self.disc3(d2))
        return prob

In [245]:
gen = Generator()
disc = Discriminator()
optimizer_gen = torch.optim.Adam(gen.parameters(), lr=1e-3)
optimizer_disc = torch.optim.Adam(disc.parameters(), lr=1e-3)

def lossf(Dx, Dz):
    return -torch.sum(torch.log(Dx) + torch.log(1 - Dz))

def lossgen(Z):
    return -torch.sum(torch.log(1-Z))

def train(batch_size, epochs, data):
    for epoch in range(epochs):
        random.shuffle(data)
        disc_loss = 0
        gen_loss = 0
        for i in range(batch_size):
            optimizer_gen.zero_grad()
            
            data_point = data[i]
            generated_point = gen.generate()
            data_point = Variable(torch.tensor(data_point).float())
            
            #optimize discriminator
            optimizer_disc.zero_grad()
            dgen = disc.discriminate(generated_point)
            dreal = disc.discriminate(data_point)
            
            
            loss = lossf(dgen, dreal)
            loss.backward(retain_graph = True)
            optimizer_disc.step()
            disc_loss += loss
            
            #optimize generator
            lossg = lossgen(disc.discriminate(generated_point))
            lossg.backward()
            optimizer_gen.step()
            gen_loss += lossg
        print('EPOCH {}, DISCRIMINATOR LOSS {}, GENERATOR LOSS {}'.format(epoch, disc_loss, gen_loss))

In [246]:
train(1000, 5, data)

EPOCH 0, DISCRIMINATOR LOSS 1284.3839111328125, GENERATOR LOSS 772.8784790039062
EPOCH 1, DISCRIMINATOR LOSS 1223.7919921875, GENERATOR LOSS 795.8995361328125
EPOCH 2, DISCRIMINATOR LOSS 1206.81884765625, GENERATOR LOSS 839.5524291992188
EPOCH 3, DISCRIMINATOR LOSS 1189.7274169921875, GENERATOR LOSS 850.2138671875
EPOCH 4, DISCRIMINATOR LOSS 1172.271240234375, GENERATOR LOSS 878.3260498046875


In [284]:
disc.discriminate(torch.tensor([0.5,0.0]))

tensor([0.8575], grad_fn=<SigmoidBackward>)

In [282]:
gen.generate()

tensor([0.0333, 1.0474], grad_fn=<AddBackward0>)

In [240]:
data