## 不明白的点
+ 为什么采用的是全连接，而不是卷积和反卷积
+ 为什么Relu和Sigmod函数交叉使用
+ Discriminator最后一层的输出的维度为什么是1
+ 损失函数为什么是binary_cross_entropy，其他的损失函数行不行
+ GAN里面的博弈思想需要更详细的理解

In [24]:
import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

"""
输入一组噪声，利用反卷积生成一组tensor数据

"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [33]:
# class Generator(nn.Module):
#     def __init__(self):
#         super(Generator, self).__init__()

#         self.model = nn.Sequential(
#             nn.Linear(Z_dim, H_dim),
#             nn.ReLU(),
#             nn.Linear(H_dim,X_dim),
#             nn.Sigmoid()
#         )

#     def forward(self, input):
#         return self.model(input)
    
# class Discriminator(nn.Module):
#     def __init__(self):
#         super(Discriminator,self).__init__()
        
#         self.model = nn.Sequential(
#             nn.Linear(X_dim,H_dim),
#             nn.ReLU(),
#             nn.Linear(H_dim,1),
#             nn.Sigmoid()
#         )
    
#     def forward(self,input):
#         return self.model(input)

class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(100,128),
            nn.ReLU(),
            nn.Linear(128,28*28),
            nn.Sigmoid() # 挤压到32*32的图像
        )
    
    def forward(self,x):
        insize = x.size(0)
        output = self.layer(x)
        output = output.view(insize,28,28)
        return output

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3,6,5), # 28
            nn.ReLU(),
            nn.MaxPool2d(2,2),# 14
            nn.Conv2d(6,16,5),# 10
            nn.ReLU(),
            nn.MaxPool2d(2,2) # 5
        )
        self.layer2 = nn.Sequential(
            nn.Linear(16*5*5,120),
            nn.ReLU(),
            nn.Linear(120,1),
            nn.Sigmoid()
        )
    
    def forward(self,x):
        insize = x.size(0)
        output = self.layer1(x)
        output = output.view(insize,-1)
        output = self.layer2(output)
        return output

In [26]:
train_loader = data.DataLoader(
    dataset=datasets.MNIST(root='./data',download=False,train=True,transform=transforms.ToTensor()),
    batch_size = 64,
    shuffle=True
)
cifar_data = datasets.CIFAR10(root='./data',train=True,download=True,transform=transforms.ToTensor())
train_dataloader = data.DataLoader(dataset=cifar_data,shuffle=True,batch_size=128)

In [36]:
gen = Generator()
dis = Discriminator()

gen = gen.to(device)
dis = dis.to(device)

optim_gen = torch.optim.Adam(gen.parameters())
optim_dis = torch.optim.Adam(dis.parameters())

In [39]:
G_loss_run = 0.0
D_loss_run = 0.0
gen = gen.to(device)
dis = dis.to(device)

for epoch in range(30):
    for step,(x_data,y_data) in enumerate(train_dataloader):
        x_data = x_data.to(device)
        size = x_data.size()
        # noise 
        z = torch.randn(size[0],100)

        """
        Confused ???
        """
        ones_label = torch.ones(size[0],1)
        zeros_label = torch.zeros(size[0],1)

        """
        第一部分 更新判别网络
        """

        D_real = dis(x_data)
        D_fake = dis(gen(z))
        D_real_loss = F.binary_cross_entropy(D_real,ones_label)
        D_fake_loss = F.binary_cross_entropy(D_fake,zeros_label)

        D_loss = D_real_loss + D_fake_loss    
        optim_dis.zero_grad()
        D_loss.backward()
        optim_dis.step()

        """
        第二部分 更新生成网络
        生成网络的目的就在与不断调整网络结构使得生成出来的图片与真实图片不断接近
        """
        z = torch.randn(size[0],Z_dim)
        D_fake = dis(gen(z))
        G_loss = F.binary_cross_entropy(D_fake,ones_label)

        optim_gen.zero_grad()
        G_loss.backward()
        optim_gen.step()

        G_loss_run += G_loss.item()
        D_loss_run += D_loss.item()

        print('Epoch:{},   G_loss:{},    D_loss:{}'.format(1, G_loss_run/(step+1), D_loss_run/(step+1)))


Epoch:1,   G_loss:0.6579353213310242,    D_loss:1.3846793174743652
Epoch:1,   G_loss:0.6513187289237976,    D_loss:1.3857946991920471
Epoch:1,   G_loss:0.6476121544837952,    D_loss:1.3884308735529582
Epoch:1,   G_loss:0.6474071741104126,    D_loss:1.3934195339679718
Epoch:1,   G_loss:0.6496160387992859,    D_loss:1.3977505683898925
Epoch:1,   G_loss:0.6534068286418915,    D_loss:1.4004085461298625
Epoch:1,   G_loss:0.6584780216217041,    D_loss:1.4017420836857386
Epoch:1,   G_loss:0.6645385473966599,    D_loss:1.4015713334083557
Epoch:1,   G_loss:0.6714313957426283,    D_loss:1.4004259639316134
Epoch:1,   G_loss:0.6788182616233825,    D_loss:1.3983630895614625
Epoch:1,   G_loss:0.6863949353044684,    D_loss:1.3959266164086082
Epoch:1,   G_loss:0.6939402719338735,    D_loss:1.3931944767634075
Epoch:1,   G_loss:0.7014826444479135,    D_loss:1.3906580668229322
Epoch:1,   G_loss:0.7090623421328408,    D_loss:1.388281720025199
Epoch:1,   G_loss:0.7167040745417277,    D_loss:1.3859773317972

KeyboardInterrupt: 

In [12]:
fig = gen(torch.randn(1,100))
img = fig[0] # plt.imshow()只能接受3-D Tensor，所以也要用image[0]消去batch那一维
img = img.detach().numpy() # FloatTensor转为ndarray
img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
plt.imshow(img)
plt.show()

TypeError: Invalid dimensions for image data