In [5]:
import os

import torch
import torchvision
import torch.nn as nn
from torchvision import  transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import time

start_time = time.time()


device=device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100

sample_dir = 'D:/data'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5],std=[0.5])
])

mnist = torchvision.datasets.MNIST(root='D:/data',
                                   train=True,
                                   transform=transform,
                                   download=True)
data_loader = DataLoader(dataset=mnist,
                         batch_size=batch_size,
                        shuffle=True)

D=nn.Sequential(
    nn.Linear(image_size,hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size,hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size,1),
    nn.Sigmoid()
)


G=nn.Sequential(
    nn.Linear(latent_size,hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size,hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size,image_size),
    nn.Tanh()
)

D = D.to(device)
G = G.to(device)

criterion=nn.BCELoss()

d_opt=torch.optim.Adam(D.parameters(),lr=0.0002)
g_opt=torch.optim.Adam(G.parameters(),lr=0.0002)

def denorm(x):
    out=(x+1)/2
    return out.clamp(0,1)

def re_grad():
    d_opt.zero_grad()
    g_opt.zero_grad()

total_step=len(data_loader)

for epoch in range(num_epochs):
    for i,(images,labels) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        outputs=D(images)
        d_loss_real=criterion(outputs,real_labels)

        real_score=outputs

        z=torch.randn(batch_size,latent_size).to(device)

        fake_images=G(z)


        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss=d_loss_fake+d_loss_real

        re_grad()

        d_loss.backward()

        d_opt.step()

        z=torch.randn(batch_size,latent_size).to(device)

        fake_images=G(z)
        outputs=D(fake_images)

        g_loss=criterion(outputs,real_labels)

        re_grad()

        g_loss.backward()

        g_opt.step()

        if (i + 1) % 200 == 0 :
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch, num_epochs, i + 1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))
    if (epoch + 1) == 1 :
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

        # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 1)))

torch.save(G.state_dict(), 'D:/G.pt')
torch.save(D.state_dict(), 'D:/D.pt')
end_time = time.time()
execution_time = end_time - start_time

print('时间为：',execution_time)




Epoch [0/200], Step [200/600], d_loss: 0.0380, g_loss: 4.3761, D(x): 0.99, D(G(z)): 0.03
Epoch [0/200], Step [400/600], d_loss: 0.0532, g_loss: 6.1023, D(x): 0.97, D(G(z)): 0.02
Epoch [0/200], Step [600/600], d_loss: 0.0458, g_loss: 4.8051, D(x): 0.98, D(G(z)): 0.02
Epoch [1/200], Step [200/600], d_loss: 0.0310, g_loss: 6.0029, D(x): 0.99, D(G(z)): 0.01
Epoch [1/200], Step [400/600], d_loss: 0.2351, g_loss: 3.4980, D(x): 0.90, D(G(z)): 0.05
Epoch [1/200], Step [600/600], d_loss: 0.1568, g_loss: 7.6516, D(x): 0.96, D(G(z)): 0.09
Epoch [2/200], Step [200/600], d_loss: 0.5237, g_loss: 2.5571, D(x): 0.82, D(G(z)): 0.15
Epoch [2/200], Step [400/600], d_loss: 0.2069, g_loss: 4.5404, D(x): 0.94, D(G(z)): 0.09
Epoch [2/200], Step [600/600], d_loss: 0.6847, g_loss: 3.1632, D(x): 0.86, D(G(z)): 0.31
Epoch [3/200], Step [200/600], d_loss: 0.9319, g_loss: 1.8768, D(x): 0.82, D(G(z)): 0.36
Epoch [3/200], Step [400/600], d_loss: 1.3164, g_loss: 4.5750, D(x): 0.78, D(G(z)): 0.34
Epoch [3/200], Step [

Epoch [30/200], Step [600/600], d_loss: 0.4565, g_loss: 4.1811, D(x): 0.86, D(G(z)): 0.07
Epoch [31/200], Step [200/600], d_loss: 0.3227, g_loss: 2.8308, D(x): 0.91, D(G(z)): 0.11
Epoch [31/200], Step [400/600], d_loss: 0.2546, g_loss: 4.8864, D(x): 0.91, D(G(z)): 0.06
Epoch [31/200], Step [600/600], d_loss: 0.5045, g_loss: 4.8381, D(x): 0.84, D(G(z)): 0.12
Epoch [32/200], Step [200/600], d_loss: 0.2846, g_loss: 4.1228, D(x): 0.92, D(G(z)): 0.12
Epoch [32/200], Step [400/600], d_loss: 0.2835, g_loss: 3.9550, D(x): 0.92, D(G(z)): 0.11
Epoch [32/200], Step [600/600], d_loss: 0.4019, g_loss: 3.6037, D(x): 0.87, D(G(z)): 0.10
Epoch [33/200], Step [200/600], d_loss: 0.2474, g_loss: 4.1444, D(x): 0.93, D(G(z)): 0.09
Epoch [33/200], Step [400/600], d_loss: 0.2466, g_loss: 3.4604, D(x): 0.90, D(G(z)): 0.07
Epoch [33/200], Step [600/600], d_loss: 0.5793, g_loss: 2.2675, D(x): 0.95, D(G(z)): 0.28
Epoch [34/200], Step [200/600], d_loss: 0.3996, g_loss: 3.2407, D(x): 0.95, D(G(z)): 0.19
Epoch [34/

Epoch [61/200], Step [400/600], d_loss: 0.6814, g_loss: 2.2926, D(x): 0.83, D(G(z)): 0.25
Epoch [61/200], Step [600/600], d_loss: 0.5814, g_loss: 2.0921, D(x): 0.86, D(G(z)): 0.22
Epoch [62/200], Step [200/600], d_loss: 0.6261, g_loss: 2.6960, D(x): 0.78, D(G(z)): 0.16
Epoch [62/200], Step [400/600], d_loss: 0.5426, g_loss: 2.7437, D(x): 0.84, D(G(z)): 0.21
Epoch [62/200], Step [600/600], d_loss: 0.6266, g_loss: 2.1382, D(x): 0.79, D(G(z)): 0.22
Epoch [63/200], Step [200/600], d_loss: 0.5674, g_loss: 2.9002, D(x): 0.81, D(G(z)): 0.15
Epoch [63/200], Step [400/600], d_loss: 0.5054, g_loss: 2.5707, D(x): 0.82, D(G(z)): 0.16
Epoch [63/200], Step [600/600], d_loss: 0.7020, g_loss: 2.5435, D(x): 0.72, D(G(z)): 0.13
Epoch [64/200], Step [200/600], d_loss: 0.8878, g_loss: 1.8565, D(x): 0.79, D(G(z)): 0.34
Epoch [64/200], Step [400/600], d_loss: 0.5012, g_loss: 2.9842, D(x): 0.81, D(G(z)): 0.12
Epoch [64/200], Step [600/600], d_loss: 0.6495, g_loss: 2.2167, D(x): 0.78, D(G(z)): 0.18
Epoch [65/

KeyboardInterrupt: 