In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision 
from torchvision import transforms

In [2]:
# Display the installed PyTorch version string.
torch.__version__

'1.12.0+cpu'

## 数据准备

In [3]:
# Scale pixel values to [-1, 1] so real images share the range of the
# generator's tanh output: ToTensor() yields values in [0, 1], and
# Normalize(mean=0.5, std=0.5) then computes (x - 0.5) / 0.5.
# (A mean of 0 would instead map the data to [0, 2].)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [4]:
# MNIST training split stored under ./data, with `transform` applied to each
# image; download=True asks torchvision to fetch the files when needed.
train_ds = torchvision.datasets.MNIST(
    root='data',
    train=True,
    transform=transform,
    download=True,
)

In [5]:
# Mini-batches of 64 samples from the training set, reshuffled every epoch.
dataloader = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=64,
    shuffle=True,
)

In [6]:
# Pull a single batch to inspect the image tensor layout; labels are discarded.
imgs,_ = next(iter(dataloader))
imgs.shape

torch.Size([64, 1, 28, 28])

## 定义生成器
### 输入长度100的噪声（正态分布的随机数）
### 输出是（1，28，28）的图片
linear 1 : 100 - 256
linear 2 : 256 - 512
linear 3 : 512 - 28*28
reshape : 28*28 - (1,28,28)

In [7]:
class Generator(nn.Module):
    """Fully connected generator mapping noise vectors to 28x28 images.

    Architecture: Linear(100 -> 256) -> ReLU -> Linear(256 -> 512) -> ReLU
    -> Linear(512 -> 784) -> Tanh, so every output value lies in [-1, 1].
    """
    def __init__(self):
        super(Generator,self).__init__()
        self.main = nn.Sequential(
                                    nn.Linear(100,256),
                                    nn.ReLU(),
                                    nn.Linear(256,512),
                                    nn.ReLU(),
                                    nn.Linear(512,28*28),
                                    nn.Tanh()
        )
    def forward(self,x):
        """Generate a batch of images.

        Args:
            x: noise tensor whose last dimension is 100 (input size of the
                first linear layer).

        Returns:
            Tensor of shape (N, 1, 28, 28) in channel-first layout.
        """
        img = self.main(x)
        # Channel-first (N, 1, 28, 28) instead of (N, 28, 28, 1): this is the
        # layout PyTorch image tensors use and the one the notebook documents.
        img = img.view(-1,1,28,28)
        return img

## 定义判别器
### 输入为（1，28，28）的图片，输出为二分类的概率值，使用sigmoid激活函数 0-1
### BCELoss计算交叉熵损失
### 判别器一般推荐LeakyReLU

In [8]:
class Discriminator(nn.Module):
    """Fully connected binary classifier that scores 28x28 images.

    Architecture: Linear(784 -> 512) -> LeakyReLU -> Linear(512 -> 256)
    -> LeakyReLU -> Linear(256 -> 1) -> Sigmoid, giving one value in [0, 1]
    per image.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        # Kept under the attribute name `main` so parameter names are unchanged.
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each image to 784 values and return a tensor of shape (N, 1)."""
        flat = x.view(-1, 28 * 28)
        return self.main(flat)
    

### 初始化模型 优化器及损失计算函数

In [9]:
# Use the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [10]:
# Instantiate both networks and move their parameters to the selected device.
gen = Generator().to(device)
dis = Discriminator().to(device)

In [11]:
# One Adam optimizer per network so each can be stepped independently.
d_optim = optim.Adam(
    dis.parameters(),
    lr=0.001,
)
g_optim = optim.Adam(
    gen.parameters(),
    lr=0.001,
)

In [12]:
# Binary cross-entropy on probabilities; pairs with the discriminator's sigmoid output.
loss_fn = nn.BCELoss()

### 绘图函数

In [13]:
def gen_img_plt(model,test_input):
    """Run `model` on `test_input` and show up to 16 outputs in a 4x4 grid.

    Args:
        model: callable (e.g. the generator) invoked as model(test_input);
            its output is moved to the CPU and converted to a NumPy array
            with one image per sample.
        test_input: batch of inputs forwarded to `model`.

    Pixel values are rescaled with (x + 1) / 2 before display, i.e. they are
    assumed to lie in [-1, 1] (as a tanh output does).
    """
    # Inference only: no autograd graph is needed for the preview.
    with torch.no_grad():
        prediction = model(test_input).detach().cpu().numpy()
    # Drop singleton axes (e.g. the channel axis) but never the batch axis,
    # so a batch containing a single image is still indexable by sample.
    singleton_axes = tuple(ax for ax in range(1, prediction.ndim)
                           if prediction.shape[ax] == 1)
    if singleton_axes:
        prediction = np.squeeze(prediction, axis=singleton_axes)
    plt.figure(figsize=(4,4))
    # The grid has 16 cells; batches smaller than that leave cells empty
    # instead of raising IndexError.
    for i in range(min(16, prediction.shape[0])):
        plt.subplot(4,4,i+1)
        plt.imshow((prediction[i]+1)/2)
        plt.axis('off')
    plt.show()

In [14]:
# A batch of 16 noise vectors of length 100, created once on `device`.
test_input = torch.randn(16,100,device=device)

### GAN的训练

In [15]:
# Per-epoch average losses; the training loop appends one entry per epoch.
D_loss = []  # discriminator
G_loss = []  # generator

In [None]:
# Training loop: for every batch, update the discriminator once, then the
# generator once; after each epoch record the mean losses and preview samples.
for epoch in range(20):
    # Running sums of the per-batch losses, kept as plain Python floats so the
    # recorded history holds no tensors (and nothing tied to a CUDA device).
    d_epoch_loss = 0.0
    g_epoch_loss = 0.0
    count = len(dataloader)  # number of batches per epoch
    for step, (img, _) in enumerate(dataloader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size, 100, device=device)

        # ---- Discriminator update ----
        d_optim.zero_grad()
        real_output = dis(img)  # discriminator's prediction on real images
        # Real images are labelled 1.
        d_real_loss = loss_fn(real_output,
                              torch.ones_like(real_output))
        d_real_loss.backward()

        gen_img = gen(random_noise)
        # detach() cuts the autograd graph at the generator output, so this
        # backward pass computes gradients for the discriminator only and the
        # generator is not touched while the discriminator is being trained.
        fake_output = dis(gen_img.detach())  # prediction on generated images
        # Generated images are labelled 0.
        d_fake_loss = loss_fn(fake_output,
                              torch.zeros_like(fake_output))
        d_fake_loss.backward()
        d_loss = d_real_loss + d_fake_loss
        d_optim.step()

        # ---- Generator update ----
        g_optim.zero_grad()
        # No detach here: gradients must flow back through the discriminator
        # into the generator.
        fake_output = dis(gen_img)
        # The generator is rewarded when its images are classified as real (1).
        g_loss = loss_fn(fake_output,
                         torch.ones_like(fake_output))
        g_loss.backward()
        g_optim.step()

        # .item() extracts the scalar value as a Python float.
        d_epoch_loss += d_loss.item()
        g_epoch_loss += g_loss.item()

    # Mean loss per batch for this epoch.
    d_epoch_loss /= count
    g_epoch_loss /= count
    D_loss.append(d_epoch_loss)
    G_loss.append(g_epoch_loss)
    print('Epoch:',epoch+1)
    gen_img_plt(gen,test_input=test_input)

Epoch: 1
