# Special thanks to https://www.chinahadoop.cn/course/1327

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Number of samples per batch: used for both the real draws and the rows of
# generator noise.
BATCH_SIZE = 500 
# Length of each noise vector fed to the generator's first linear layer.
NOISE_DIM = 5

In [3]:
def get_gaussian_dist(mu, sigma, batch_size=None):  # the real data
    """Draw one batch of "real" samples from a normal distribution.

    Args:
        mu: Mean of the distribution.
        sigma: Standard deviation of the distribution.
        batch_size: Number of samples to draw. When omitted, the module-level
            ``BATCH_SIZE`` is used, which matches the previous behaviour.

    Returns:
        A float32 tensor of shape ``(batch_size, 1)``.
    """
    n_samples = BATCH_SIZE if batch_size is None else batch_size
    # Draw a flat vector, then add a trailing axis so every sample is a row.
    samples = np.random.normal(mu, sigma, size=n_samples)[:, np.newaxis]
    # NumPy produces float64; the networks work in float32.
    return torch.from_numpy(samples).float()

In [4]:
def extract(v):
    """Return the elements of tensor ``v`` as a flat Python list.

    The values are read through the tensor itself rather than its backing
    storage, so views and slices yield only their own elements, in row-major
    order. The tensor is detached first, so tensors that require grad are
    accepted.

    Args:
        v: A ``torch.Tensor`` of any shape.

    Returns:
        A flat ``list`` of Python scalars with ``v.numel()`` entries.
    """
    return v.detach().reshape(-1).tolist()
def stats(d):
    """Summarise ``d`` as a two-element list: ``[mean, standard deviation]``.

    Both figures come from NumPy, so the deviation is the population one
    (NumPy's default ``ddof=0``).
    """
    mean_value = np.mean(d)
    std_value = np.std(d)
    return [mean_value, std_value]

In [5]:
# Generator: turns a NOISE_DIM-long noise vector into a single scalar sample
# through one hidden ReLU layer of 128 units.
G = nn.Sequential(*[
    nn.Linear(in_features=NOISE_DIM, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=1),
])

# Discriminator: scores a single scalar sample with a probability in (0, 1)
# through one hidden ReLU layer of 128 units and a final sigmoid.
D = nn.Sequential(*[
    nn.Linear(in_features=1, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=1),
    nn.Sigmoid(),
])

In [None]:
# One Adam optimiser per network, both at the same learning rate.
opt_D = torch.optim.Adam(params=D.parameters(), lr=1e-4)
opt_G = torch.optim.Adam(params=G.parameters(), lr=1e-4)

In [None]:
# Lower bound applied to every value passed to torch.log below. A sigmoid
# output can saturate to exactly 0 or 1 in float32, and log(0) = -inf would
# turn the loss (and then the weights) into inf/NaN. Values above the bound
# pass through the clamp unchanged, gradients included.
LOG_EPS = 1e-7

for step in range(10000):
    # ---- train discriminator ----
    d_real_data = get_gaussian_dist(5, 2)       # real data
    noise = torch.randn(BATCH_SIZE, NOISE_DIM)  # random noise
    d_fake_data = G(noise)                      # fake data produced by G from the noise

    prob_real_decision = D(d_real_data)           # D tries to increase this prob
    # detach(): the discriminator step must not back-propagate into G.
    prob_fake_decision = D(d_fake_data.detach())  # D tries to reduce this prob

    # D maximises log D(x) + log(1 - D(G(z))), i.e. minimises its negative.
    D_loss = -torch.mean(
        torch.log(torch.clamp(prob_real_decision, min=LOG_EPS))
        + torch.log(torch.clamp(1. - prob_fake_decision, min=LOG_EPS))
    )
    opt_D.zero_grad()
    D_loss.backward()
    opt_D.step()

    # ---- train generator ----
    noise = torch.randn(BATCH_SIZE, NOISE_DIM)  # fresh random noise
    g_fake_data = G(noise)                      # fake data produced by G from the noise
    prob_fake_decision = D(g_fake_data)         # G tries to increase this prob

    # G minimises log(1 - D(G(z))). Only opt_G steps here; gradients left on
    # D are cleared by opt_D.zero_grad() at the next discriminator update.
    G_loss = torch.mean(torch.log(torch.clamp(1. - prob_fake_decision, min=LOG_EPS)))
    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    if step % 500 == 0:  # progress report: [mean, std] of the real and fake batches
        print("Epoch %s: ; Real Dist (%s),  Fake Dist (%s) " %
              (step, stats(extract(d_real_data)), stats(extract(g_fake_data))))

Epoch 0: ; Real Dist ([5.067422164037824, 2.0268136991888213]),  Fake Dist ([-0.13762627686187626, 0.12559447968408524]) 
Epoch 500: ; Real Dist ([4.95735705780983, 2.00384641364993]),  Fake Dist ([3.0464577887058257, 0.880591239461629]) 
Epoch 1000: ; Real Dist ([4.990398664072156, 2.0437309877125798]),  Fake Dist ([5.893177792072296, 1.576236947167336]) 
Epoch 1500: ; Real Dist ([5.097402470998466, 1.97503666138329]),  Fake Dist ([6.415912767410278, 1.7493719966267285]) 
Epoch 2000: ; Real Dist ([4.977614680558443, 2.073433243487078]),  Fake Dist ([4.587025060653686, 1.3611078646869337]) 
Epoch 2500: ; Real Dist ([5.020027194902301, 1.9687778726397451]),  Fake Dist ([4.95946099948883, 1.519033190055754]) 
Epoch 3000: ; Real Dist ([5.2128319461643695, 1.844861419961821]),  Fake Dist ([5.041665683746338, 1.713292325384186]) 
Epoch 3500: ; Real Dist ([5.0327643441818655, 1.9598375931970633]),  Fake Dist ([5.118369703292847, 2.227701643320257]) 
Epoch 4000: ; Real Dist ([4.92052979561686

In [None]:
# Histogram (red) of the last batch produced by the generator.
print("Plotting the generated distribution...")
values = extract(g_fake_data)
plt.hist(values, bins=50, color="red")
plt.title('Histogram of Generated Distribution')
plt.xlabel('Value')
plt.ylabel('Count')
plt.grid(True)
plt.show()

In [None]:
# Histogram of the last batch of real samples, for comparison with the
# generator's output.
# The status message previously said "generated distribution" (copied from
# the generator plot), although this plot shows the real data.
print("Plotting the real distribution...")
values = extract(d_real_data)
plt.hist(values, bins=50)
plt.xlabel('Value')
plt.ylabel('Count')
plt.title('Histogram of Real Distribution')
plt.grid(True)
plt.show()