In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt
import numpy as np

In [2]:
EPOCHS = 300
BATCH_SIZE = 100
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
print("Device:", DEVICE)

Device: cuda


In [3]:
trainset = torchvision.datasets.FashionMNIST(
    "./.data",
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))
    ])
)
train_loader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./.data\FashionMNIST\raw\train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./.data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ./.data\FashionMNIST\raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./.data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./.data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ./.data\FashionMNIST\raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./.data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./.data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ./.data\FashionMNIST\raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./.data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./.data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ./.data\FashionMNIST\raw
Processing...




Done!


In [4]:
G = nn.Sequential(
    nn.Linear(64, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 784),
    nn.Tanh()
)
D = nn.Sequential(
    nn.Linear(784, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
    nn.Sigmoid()
)

In [5]:
D = D.to(DEVICE)
G = G.to(DEVICE)




In [6]:
criterion = nn.BCELoss()
g_optimizer = optim.Adam(G.parameters(), lr=0.0002)
d_optimizer = optim.Adam(D.parameters(), lr=0.0002)

In [None]:
total_step = len(train_loader)
for epoch in range(EPOCHS):
    for i, (images, _) in enumerate(train_loader):
        images = images.reshape(BATCH_SIZE, -1).to(DEVICE)
        
        real_labels = torch.ones(BATCH_SIZE, 1).to(DEVICE)
        fake_labels = torch.zeros(BATCH_SIZE, 1).to(DEVICE)
        
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        Z = torch.randn(BATCH_SIZE, 64).to(DEVICE)
        fake_images = G(Z)
        
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        d_loss = d_loss_real + d_loss_fake
        
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        fake_images = G(Z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)
        
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
    
    print(f"Epoch [{epoch}/{EPOCHS}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}\
            D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}")

z = torch.randn(BATCH_SIZE, 64).to(DEVICE)
fake_images = G(z)
fake_images_cpu = fake_images.data.cpu().numpy()
for i in range(10):
    fake_images_img = np.reshape(fake_images_cpu[i], (28, 28))
    plt.imshow(fake_images_img, cmap = 'gray')
    plt.show()

Epoch [0/300], d_loss: 0.0460, g_loss: 4.3398            D(x): 0.99, D(G(z)): 0.03
Epoch [1/300], d_loss: 0.1039, g_loss: 4.8549            D(x): 0.96, D(G(z)): 0.02
Epoch [2/300], d_loss: 0.0727, g_loss: 6.7098            D(x): 0.98, D(G(z)): 0.02
Epoch [3/300], d_loss: 0.0151, g_loss: 7.1764            D(x): 1.00, D(G(z)): 0.01
Epoch [4/300], d_loss: 0.0029, g_loss: 9.7375            D(x): 1.00, D(G(z)): 0.00
Epoch [5/300], d_loss: 0.1194, g_loss: 5.8360            D(x): 0.97, D(G(z)): 0.02
Epoch [6/300], d_loss: 0.0609, g_loss: 5.0308            D(x): 1.00, D(G(z)): 0.05
Epoch [7/300], d_loss: 0.1763, g_loss: 5.6296            D(x): 0.95, D(G(z)): 0.05
Epoch [8/300], d_loss: 0.1614, g_loss: 4.0350            D(x): 0.97, D(G(z)): 0.09
Epoch [9/300], d_loss: 0.0630, g_loss: 7.1268            D(x): 0.98, D(G(z)): 0.02
Epoch [10/300], d_loss: 0.0891, g_loss: 6.1211            D(x): 0.99, D(G(z)): 0.05
Epoch [11/300], d_loss: 0.3123, g_loss: 3.2148            D(x): 0.89, D(G(z)): 0.06
Ep

Epoch [98/300], d_loss: 0.9730, g_loss: 1.9541            D(x): 0.72, D(G(z)): 0.31
Epoch [99/300], d_loss: 0.8393, g_loss: 2.2985            D(x): 0.73, D(G(z)): 0.26
Epoch [100/300], d_loss: 0.8337, g_loss: 1.8316            D(x): 0.72, D(G(z)): 0.28
Epoch [101/300], d_loss: 0.7566, g_loss: 1.6707            D(x): 0.76, D(G(z)): 0.28
Epoch [102/300], d_loss: 0.9136, g_loss: 1.7160            D(x): 0.75, D(G(z)): 0.33
Epoch [103/300], d_loss: 1.1020, g_loss: 1.6599            D(x): 0.69, D(G(z)): 0.36
Epoch [104/300], d_loss: 0.8889, g_loss: 1.6211            D(x): 0.73, D(G(z)): 0.30
Epoch [105/300], d_loss: 0.7374, g_loss: 1.9914            D(x): 0.75, D(G(z)): 0.24
Epoch [106/300], d_loss: 0.7368, g_loss: 1.8260            D(x): 0.77, D(G(z)): 0.27
Epoch [107/300], d_loss: 0.8168, g_loss: 1.6188            D(x): 0.74, D(G(z)): 0.32
Epoch [108/300], d_loss: 0.7589, g_loss: 2.5369            D(x): 0.77, D(G(z)): 0.27
Epoch [109/300], d_loss: 1.0309, g_loss: 1.6662            D(x): 0.

Epoch [195/300], d_loss: 0.8232, g_loss: 1.4835            D(x): 0.75, D(G(z)): 0.34
Epoch [196/300], d_loss: 1.1519, g_loss: 1.4995            D(x): 0.60, D(G(z)): 0.32
Epoch [197/300], d_loss: 1.0437, g_loss: 1.4222            D(x): 0.66, D(G(z)): 0.33
Epoch [198/300], d_loss: 0.8576, g_loss: 1.5026            D(x): 0.70, D(G(z)): 0.30
Epoch [199/300], d_loss: 1.2579, g_loss: 1.3690            D(x): 0.57, D(G(z)): 0.34
Epoch [200/300], d_loss: 1.0117, g_loss: 1.3304            D(x): 0.67, D(G(z)): 0.36
Epoch [201/300], d_loss: 1.2116, g_loss: 1.1813            D(x): 0.63, D(G(z)): 0.39
Epoch [202/300], d_loss: 0.9519, g_loss: 1.1737            D(x): 0.69, D(G(z)): 0.37
Epoch [203/300], d_loss: 1.1167, g_loss: 1.5359            D(x): 0.59, D(G(z)): 0.31
Epoch [204/300], d_loss: 0.8348, g_loss: 1.5610            D(x): 0.69, D(G(z)): 0.27
Epoch [205/300], d_loss: 1.0868, g_loss: 1.5761            D(x): 0.60, D(G(z)): 0.32
Epoch [206/300], d_loss: 0.9590, g_loss: 1.4691            D(x): 