In [7]:
import torch
import torch.nn as nn
import tqdm

In [8]:
class Generator(nn.Module):
    """Fully connected GAN generator: noise vector -> flattened sample in [-1, 1].

    Layer stack (held in ``self.network``, an ``nn.Sequential``):
    Linear -> ReLU -> Linear -> ReLU -> Linear -> Tanh.
    The closing Tanh bounds every output element to [-1, 1]; real samples shown
    to the discriminator should presumably be scaled to the same range.

    Args:
        input_dimension: length of each noise vector ``z`` (default 100).
        hidden_dimension: width of both hidden layers (default 1200).
        output_dimension: length of each generated sample (default 28*28 = 784;
            presumably a flattened 28x28 image).
    """

    def __init__(self, input_dimension = 100, hidden_dimension = 1200, output_dimension = 28*28):
        super().__init__()
        sizes = [input_dimension, hidden_dimension, hidden_dimension, output_dimension]
        last_idx = len(sizes) - 2
        layers = []
        # One Linear per consecutive (fan_in, fan_out) pair; ReLU between layers, Tanh at the end.
        for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Tanh() if idx == last_idx else nn.ReLU())
        self.network = nn.Sequential(*layers)

    def forward(self, z):
        """Map noise ``z`` (last dimension == input_dimension) to generated samples."""
        return self.network(z)


In [9]:
class Discrimator(nn.Module):
    """Fully connected GAN discriminator: sample -> probability that it is real.

    Layer stack (held in ``self.network``, an ``nn.Sequential``):
    Linear -> LeakyReLU -> Linear -> LeakyReLU -> Linear -> Sigmoid.
    The closing Sigmoid yields values in (0, 1), which is what ``nn.BCELoss``
    expects as input.

    Args:
        input_dimension: length of each (flattened) input sample (default 28*28 = 784).
        hidden_dimension: width of both hidden layers (default 240).
        output_dimension: number of outputs per sample (default 1).
        negative_slope: slope of both LeakyReLU activations for negative inputs
            (default 0.2, the previously hard-coded value).
    """

    def __init__(self, input_dimension = 28*28, hidden_dimension = 240, output_dimension = 1, negative_slope = 0.2):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dimension, hidden_dimension),
            nn.LeakyReLU(negative_slope = negative_slope),
            nn.Linear(hidden_dimension, hidden_dimension),
            nn.LeakyReLU(negative_slope = negative_slope),
            nn.Linear(hidden_dimension, output_dimension),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-sample probabilities; ``x``'s last dimension must equal input_dimension."""
        return self.network(x)


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
Discriminator = Discrimator

In [10]:
def sample_noise(batch_size, noise_dimension = 100, device = "cpu"):
    """Draw a minibatch of generator inputs ``z`` from a standard normal prior.

    Previously a ``pass`` stub returning ``None``, which made ``generator(z)`` fail.

    Args:
        batch_size: number of noise vectors (m) to draw.
        noise_dimension: length of each vector; the default of 100 matches
            ``Generator``'s default ``input_dimension``.
        device: device on which to allocate the tensor.

    Returns:
        float32 tensor of shape (batch_size, noise_dimension).
    """
    return torch.randn(batch_size, noise_dimension, device = device)

def sample_minibatch(batch_size):
    """Return a minibatch of ``batch_size`` real training examples.

    Not implemented yet: no dataset is wired up in this notebook. Raising here
    (instead of silently returning ``None``) makes the missing piece obvious
    rather than surfacing later as an opaque error inside the discriminator.
    Callers such as a training loop presumably expect a float tensor whose
    last dimension matches the discriminator's input — TODO confirm when implementing.

    Raises:
        NotImplementedError: always, until a data source is provided.
    """
    raise NotImplementedError("sample_minibatch: no data source configured; implement real-data sampling")

In [11]:
# Training Loop
def training_loop(generator, discrimator, generator_optimizer, discriminator_optimizer, epochs, k = 1, batch_size = 100, device = "cpu"):
    """Train a GAN with alternating discriminator / generator updates.

    Each outer iteration runs ``k`` discriminator steps followed by one
    generator step. An "epoch" here is one such outer iteration, not a full
    pass over a dataset.

    Args:
        generator: module mapping noise ``z`` to fake samples.
        discrimator: module mapping samples to probabilities in (0, 1); it must
            end in a sigmoid because ``nn.BCELoss`` expects probabilities.
        generator_optimizer: optimizer over the generator's parameters.
        discriminator_optimizer: optimizer over the discriminator's parameters.
        epochs: number of outer iterations.
        k: discriminator steps per generator step.
        batch_size: passed to ``sample_noise`` and ``sample_minibatch``.
        device: device that sampled ``z`` and ``x`` are moved to before use.

    Returns:
        dict with ``"generator_loss"`` (one float per iteration) and
        ``"discrimator_loss"`` (``k`` floats per iteration).
    """
    # One loss module, *called* on (prediction, target). The previous
    # nn.BCELoss(pred, target) invoked the constructor instead of computing a loss.
    criterion = nn.BCELoss()
    training_loss = {"generator_loss" : [], "discrimator_loss" : []}

    for _ in tqdm.tqdm(range(epochs)):

        # discriminator: k steps pushing D(x) -> 1 and D(G(z)) -> 0
        for _ in range(k):
            z = sample_noise(batch_size).to(device)
            x = sample_minibatch(batch_size).to(device)

            # detach: this step only updates D, so skip backprop through G
            fake_pred = discrimator(generator(z).detach()).squeeze(-1)
            real_pred = discrimator(x).squeeze(-1)

            # *_like targets follow each prediction's shape/device/dtype, so a
            # real batch shorter than batch_size (e.g. the last one) still works
            fake_data_loss = criterion(fake_pred, torch.zeros_like(fake_pred))
            real_data_loss = criterion(real_pred, torch.ones_like(real_pred))
            actual_loss = (fake_data_loss + real_data_loss) / 2.0 # the paper uses a sum but I'm using a mean here (more typical in modern ML)

            discriminator_optimizer.zero_grad()
            actual_loss.backward()
            discriminator_optimizer.step()

            training_loss["discrimator_loss"].append(actual_loss.item())

        # generator: targets of 1 train G to maximize log D(G(z)).
        # This backward also fills D's grads; they are cleared by
        # discriminator_optimizer.zero_grad() before D's next update.
        z = sample_noise(batch_size).to(device)
        gen_pred = discrimator(generator(z)).squeeze(-1)
        fake_data_loss_gen = criterion(gen_pred, torch.ones_like(gen_pred))
        generator_optimizer.zero_grad()
        fake_data_loss_gen.backward()
        generator_optimizer.step()
        training_loss["generator_loss"].append(fake_data_loss_gen.item())

    return training_loss