This article trains a GAN whose Generator produces a set of data points that follow a Gaussian distribution.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal

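The parameters of the Gaussian distribution are defined below; each sample is a series of 30 values drawn from that distribution.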

In [2]:
data_mean = 3.0
data_stddev = 0.4
Series_Length = 30

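The two functions below sample from a Gaussian distribution and from the uniform distribution on $[0, 1)$, respectively. Each returns a function that generates an $m\times n$ tensor. The former produces the real samples; the latter produces the noise fed into the Generator.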

In [3]:
def get_real_sampler(mu, sigma):
    dist = Normal(mu, sigma)
    return lambda m, n: dist.sample((m, n)).requires_grad_()

def get_noise_sampler():
    return lambda m, n: torch.rand(m, n).requires_grad_()

actual_data = get_real_sampler(data_mean, data_stddev)
noise_data = get_noise_sampler()
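
As a quick sanity check on the two samplers, the sketch below draws a large batch from each distribution and inspects shape and summary statistics. It calls `Normal` and `torch.rand` directly instead of the helper functions above; the batch size of 1000 and the fixed seed are assumptions made only for this illustration.

```python
import torch
from torch.distributions.normal import Normal

torch.manual_seed(0)  # assumption: fixed seed, only to make the check repeatable

# Real samples: 1000 series of length 30 drawn from N(3.0, 0.4)
real = Normal(3.0, 0.4).sample((1000, 30))
# Generator noise: 1000 vectors of length 20 drawn uniformly from [0, 1)
noise = torch.rand(1000, 20)

print(real.shape, real.mean().item(), real.std().item())
print(noise.shape, noise.min().item(), noise.max().item())
```

The empirical mean and standard deviation of `real` should land close to 3.0 and 0.4, and every entry of `noise` should lie in $[0, 1)$.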

Next, define the structure of the two sub-networks: the Generator and the Discriminator.

In [4]:
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.xfer = torch.nn.SELU()
        
    def forward(self, x):
        x = self.xfer(self.map1(x))
        x = self.xfer(self.map2(x))
        return self.xfer(self.map3(x))
    

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.elu = torch.nn.ELU()
 
    def forward(self, x):
        x = self.elu(self.map1(x))
        x = self.elu(self.map2(x))
        return torch.sigmoid(self.map3(x))
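
To see how tensors flow through the two networks, here is a minimal shape check. It uses `nn.Sequential` stand-ins (the names `g` and `d` are hypothetical) that mirror the layers above, with the layer sizes set in the next cell; it is a sketch for illustration, not a replacement for the classes.

```python
import torch
import torch.nn as nn

noise_dim, series_length = 20, 30  # same sizes as g_input_size and Series_Length

# Hypothetical stand-in for Generator: three linear layers, SELU after each
g = nn.Sequential(
    nn.Linear(noise_dim, 150), nn.SELU(),
    nn.Linear(150, 150), nn.SELU(),
    nn.Linear(150, series_length), nn.SELU(),
)
# Hypothetical stand-in for Discriminator: ELU hidden layers, sigmoid output
d = nn.Sequential(
    nn.Linear(series_length, 75), nn.ELU(),
    nn.Linear(75, 75), nn.ELU(),
    nn.Linear(75, 1), nn.Sigmoid(),
)

noise = torch.rand(15, noise_dim)  # a minibatch of 15 noise vectors
fake = g(noise)                    # one generated series per noise vector
decision = d(fake)                 # one probability per series
print(fake.shape, decision.shape)
```

Each noise vector becomes one series of length 30, and the Discriminator reduces each series to a single value between 0 and 1.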

Next, instantiate the networks and define the optimizers; the loss is written out by hand in the training loop:

In [5]:
g_input_size = 20    
g_hidden_size = 150  
g_output_size = Series_Length

d_input_size = Series_Length
d_hidden_size = 75   
d_output_size = 1

d_learning_rate = 3e-3
g_learning_rate = 8e-3

G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
D = Discriminator(input_size=d_input_size, hidden_size=d_hidden_size, output_size=d_output_size)

d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate ) 
g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate )
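
The training loop further down computes terms of the form `-torch.sum(torch.log(p))/batch`. As a minimal sketch, assuming discriminator outputs strictly inside $(0, 1)$, the block below shows that this expression is the binary cross-entropy against all-ones targets, so `nn.BCELoss` would be an equivalent way to write it. The tensor `decision` is made-up data for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch = 15

# Hypothetical discriminator outputs, kept strictly inside (0, 1)
decision = torch.rand(batch, 1) * 0.98 + 0.01

# Hand-written form, as used in the training loop
manual = -torch.sum(torch.log(decision)) / batch
# Built-in form: mean binary cross-entropy with target label 1
bce = nn.BCELoss()(decision, torch.ones_like(decision))

print(manual.item(), bce.item(), torch.allclose(manual, bce))
```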

Now start training, alternating between the Discriminator and the Generator in every epoch:

![title](./GAN_algo.png)

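Note: in the first few epochs the Generator is still weak, so $D(G(Z^{(i)}))\rightarrow 0$ and therefore $\log(1-D(G(Z^{(i)})))\rightarrow 0$. In this regime the term saturates and passes almost no gradient back to the Generator. To mitigate this, from here on I replace $\log(1-D(G(Z^{(i)})))$ with $-\log(D(G(Z^{(i)})))$, which turned out to work well.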
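
The saturation argument can be made concrete with a little arithmetic. Assuming the Discriminator output is $D = \sigma(a)$ for a logit $a$ (as in the sigmoid output layer above), the derivative of $\log(1-\sigma(a))$ with respect to $a$ is $-\sigma(a)$, while the derivative of $-\log(\sigma(a))$ is $\sigma(a)-1$. The sketch below evaluates both for a few logits; the helper names are hypothetical.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    # d/da log(1 - sigmoid(a)) = -sigmoid(a)
    return -sigmoid(a)

def grad_non_saturating(a):
    # d/da [-log(sigmoid(a))] = sigmoid(a) - 1
    return sigmoid(a) - 1.0

# Very negative logits correspond to D(G(z)) close to 0 (a weak Generator)
for a in (-8.0, -4.0, 0.0):
    print(a, sigmoid(a), grad_saturating(a), grad_non_saturating(a))
```

When $D(G(z))$ is near 0, the first gradient is close to 0 while the second stays close to $-1$ in magnitude 1, which is why the substituted loss keeps the Generator learning early on.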

In [6]:
num_epochs = 1500

d_minibatch_size = 15 
g_minibatch_size = 10

for epoch in range(num_epochs):
    
    """
        Train the Discriminator
    """
    D.zero_grad()
    # Train D on real samples
    real_data = actual_data(d_minibatch_size, d_input_size)
    real_decision = D(real_data)
    real_error = -torch.sum(torch.log(real_decision))/d_minibatch_size
    real_error.backward()
    
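    # Train D on generated samples; detach() keeps this backward pass out of G
    noise = noise_data(d_minibatch_size, g_input_size)
    fake_data = G(noise).detach()
    fake_decision = D(fake_data)
    # Minimizing log(D(G(z))) follows the substitution described above
    fake_error = torch.sum(torch.log(fake_decision))/d_minibatch_size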
    fake_error.backward()
    
    d_optimizer.step()
    
    
    """
        Train the Generator
    """
    G.zero_grad()
    # Train G with the non-saturating loss -log(D(G(z)))
    noise = noise_data(g_minibatch_size, g_input_size)
    fake_data = G(noise)
    fake_decision = D(fake_data)
    gen_loss = -torch.sum(torch.log(fake_decision))/g_minibatch_size
    gen_loss.backward()

    g_optimizer.step()
    
    if epoch % 50 == 0:
        print("Epoch %d.D Loss on real data %5.3f" % (epoch, real_error))
        print("Epoch %d.G Loss %5.3f" % (epoch, gen_loss))

Epoch 0.D Loss on real data 0.567
Epoch 0.G Loss 0.703
Epoch 50.D Loss on real data 0.094
Epoch 50.G Loss 0.559
Epoch 100.D Loss on real data 0.079
Epoch 100.G Loss 0.542
Epoch 150.D Loss on real data 0.115
Epoch 150.G Loss 0.596
Epoch 200.D Loss on real data 0.181
Epoch 200.G Loss 0.641
Epoch 250.D Loss on real data 0.223
Epoch 250.G Loss 0.571
Epoch 300.D Loss on real data 0.261
Epoch 300.G Loss 0.519
Epoch 350.D Loss on real data 0.318
Epoch 350.G Loss 0.467
Epoch 400.D Loss on real data 0.381
Epoch 400.G Loss 0.443
Epoch 450.D Loss on real data 0.257
Epoch 450.G Loss 0.169
Epoch 500.D Loss on real data 0.083
Epoch 500.G Loss 0.037
Epoch 550.D Loss on real data 0.038
Epoch 550.G Loss 0.009
Epoch 600.D Loss on real data 0.021
Epoch 600.G Loss 0.006
Epoch 650.D Loss on real data 0.015
Epoch 650.G Loss 0.003
Epoch 700.D Loss on real data 0.010
Epoch 700.G Loss 0.003
Epoch 750.D Loss on real data 0.008
Epoch 750.G Loss 0.001
Epoch 800.D Loss on real data 0.007
Epoch 800.G Loss 0.002
Epo