# Special thanks to 
https://www.chinahadoop.cn/course/1327

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
BATCH_SIZE = 500 
NOISE_DIM = 5

In [3]:
def get_gaussian_dist(mu, sigma): # the real data
    temp = np.random.normal(mu, sigma, size=BATCH_SIZE)[:, np.newaxis]
    return torch.from_numpy(temp).float()

In [4]:
def extract(v):
    return v.detach().storage().tolist()
def stats(d):
    return [np.mean(d), np.std(d)]

In [5]:
# Generator 
G = nn.Sequential(                      
    nn.Linear(NOISE_DIM, 128),            
    nn.ReLU(),
    nn.Linear(128, 1),                  
)

# Discriminator 
D = nn.Sequential(
    nn.Linear(1, 128),     
    nn.ReLU(),
    nn.Linear(128, 1),     
    nn.Sigmoid(),
)

In [None]:
opt_D = torch.optim.Adam(D.parameters(), lr=0.0001)
opt_G = torch.optim.Adam(G.parameters(), lr=0.0001)

In [None]:
for step in range(10000):
    # train discrimintor
    d_real_data = get_gaussian_dist(5, 2)    # real data
    noise = torch.randn(BATCH_SIZE, NOISE_DIM)  # random noise
    d_fake_data = G(noise)                      # fake data from G (generated from random ideas)

    prob_real_decision = D(d_real_data)          # D try to increase this prob
    prob_fake_decision = D(d_fake_data.detach()) # D try to reduce this prob
    
    D_loss = - torch.mean(torch.log(prob_real_decision) + torch.log(1. - prob_fake_decision))
    opt_D.zero_grad()
    D_loss.backward()
    opt_D.step()

    # train generator 
    noise = torch.randn(BATCH_SIZE, NOISE_DIM)  # random noise
    g_fake_data = G(noise)                      # fake data from G (generated from random ideas)
    prob_fake_decision = D(g_fake_data)         # G try to increase this prob
    
    G_loss = torch.mean(torch.log(1. - prob_fake_decision))
    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    if step % 500 == 0:  # plotting
        print("Epoch %s: ; Real Dist (%s),  Fake Dist (%s) " %
                  (step, stats(extract(d_real_data)), stats(extract(g_fake_data))))

Epoch 0: ; Real Dist ([4.891124389514327, 1.993818500989288]),  Fake Dist ([-0.17570002839714288, 0.1492427485260771]) 
Epoch 500: ; Real Dist ([5.043834386639297, 2.0627448433238538]),  Fake Dist ([1.9699697268009186, 0.5897737314037829]) 
Epoch 1000: ; Real Dist ([5.009938806027174, 2.0473615917828276]),  Fake Dist ([4.845271273374557, 1.4156075536506205]) 
Epoch 1500: ; Real Dist ([5.088246856942773, 2.07399483126232]),  Fake Dist ([6.482142136573792, 1.8230713297809369]) 
Epoch 2000: ; Real Dist ([5.112281850039959, 1.9519776036228214]),  Fake Dist ([5.268023253440857, 1.5748381451724114]) 
Epoch 2500: ; Real Dist ([5.04068314383924, 1.9228489605375532]),  Fake Dist ([4.727532731056213, 1.5001497336720422]) 
Epoch 3000: ; Real Dist ([5.102851587474346, 2.0479261284382475]),  Fake Dist ([5.306155563354492, 1.703597770612663]) 
Epoch 3500: ; Real Dist ([4.983626407340169, 2.0956825773737378]),  Fake Dist ([5.04322944021225, 1.6618326131988692]) 
Epoch 4000: ; Real Dist ([4.8540205473

In [None]:
print("Plotting the generated distribution...")
values = extract(g_fake_data)
plt.hist(values, bins=50,color="red")
plt.xlabel('Value')
plt.ylabel('Count')
#plt.title('Histogram of Generated Distribution')
plt.grid(True)
plt.show()

In [None]:
print("Plotting the generated distribution...")
values = extract(d_real_data)
plt.hist(values, bins=50)
plt.xlabel('Value')
plt.ylabel('Count')
plt.title('Histogram of Real Distribution')
plt.grid(True)
plt.show()