In [2]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
latent_size = 64
batch_size = 100
sample_dir = '.'
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize(mean=(0.5,0.5,0.5),
    #                   std=(0.5,0.5,0.5) )
])

In [4]:
mnist = torchvision.datasets.MNIST(root='./data/',
                                   train=True,
                                   transform=transform,
                                   download=False)


data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)

In [5]:
hidden_size = 256
image_size = 784
'''-----------------判别器（Discriminator)-----------------'''
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())
D = D.to(device)
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)

In [6]:
'''-----------------生成器（Generator）-----------------'''
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())
G = G.to(device)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

In [None]:
import matplotlib.pyplot as plt
from torchvision import datasets, transforms,utils
num_epochs = 50
criterion = nn.BCELoss()
def denorm(x):
    out = (x+1)/2
    return out.clamp(0, 1) #clamp：取（0,1）范围内的数值，即“掐头去尾”

def reset_grad(): #梯度双清
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)
        print(images[10])
        real_labels = torch.ones(batch_size, 1).to(device)  # 定义真标签，全1，shape（100，1）
        fake_labels = torch.zeros(batch_size, 1).to(device)  # 定义假标签，全0，shape（100,1）
    
        '''-------------
        |   训练判别器   |
        --------------'''
        outputs = D(images) #前项
        d_loss_real = criterion(outputs, real_labels)  #损失计算
        real_score = outputs
        
        '''随机->生成器'''
        z = torch.randn(batch_size, latent_size).to(device)  # 设置随机种子z，shape(100,64)
        fake_images = G(z)  # 随机种子喂入生成器，形成假图
        
        '''假图->判别器'''
        outputs = D(fake_images)  # 利用判别器，计算假图输出
        d_loss_fake = criterion(outputs, fake_labels)  # 计算判别器假图输出和假标签损失
        fake_score = outputs
        
        '''判别优化'''
        d_loss = d_loss_fake+d_loss_real  # 计算损失和
        reset_grad()  # 梯度归零，在反向传播之前，使用optimizer将它要更新的所有张量的梯度清零(这些张量是模型可学习的权重)
        d_loss.backward()  # 梯度反向传播
        d_optimizer.step()  # 单步优化，调用optimizer的step函数更新所有参数


        '''-------------
        |   训练生成     |
        --------------'''       
        z = torch.randn(batch_size, latent_size).to(device)  # 生成随机种子
        fake_images = G(z)  # 随机种子喂入生成器
        outputs = D(fake_images)  # 判别器判别假图输出
        g_loss = criterion(outputs, real_labels)  # 判别器假图输出和真标签损失
        reset_grad()  # 梯度归零
        g_loss.backward()  # 反向传播
        g_optimizer.step()  # 优化
    
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}],step[{}/{}],d_loss:{:.4f},g_loss:{:.4f},D(x):{:.2f},D(G(z)):{:.2f}'
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), real_score.mean(), fake_score.mean()))
            # 每200个bacth，输出一次
    
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        img = utils.make_grid(images.cpu(),6)
        img = img.numpy().transpose(1,2,0) 
        plt.figure(figsize=(10,10))
        plt.imshow(img)
        plt.show()
    
    
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    img = utils.make_grid(fake_images.cpu(),6)
    img = img.detach().numpy().transpose(1,2,0) 
    plt.figure(figsize=(10,10))
    plt.imshow(img)
    plt.show()





tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 