In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import matploitlib.pyplot as plt
import numpy as np

ModuleNotFoundError: No module named 'matploitlib'

In [3]:
EPOCHS = 300
BATCH_SIZE = 100
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
print("Device:", DEVICE)

Device: cuda


In [4]:
trainset = torchvision.datasets.FashionMNIST(
    "./.data",
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))
    ])
)
train_loader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [5]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.embed =  nn.Embedding(10, 10)
        self.model = nn.Sequential(
            nn.Linear(110, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 784),
            nn.Tanh(),
        )
        
    def forward(self, z, labels):
        c = self.embed(labels)
        x = torch.cat([z, c], 1)
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.embed = nn.Embedding(10, 10)
        self.model = nn.Sequential(
            nn.Linear(784 + 10, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x, labels):
        c = self.embed(labels)
        x = torch.cat([x, c], 1)
        return self.model(x)

In [6]:
D = Discriminator().to(DEVICE)
G = Generator().to(DEVICE)
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = optim.Adam(G.parameters(), lr=0.0002)

In [None]:
total_step = len(train_loader)
for epoch in range(EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(BATCH_SIZE, -1).to(DEVICE)
        real_labels = torch.ones(BATCH_SIZE, 1).to(DEVICE)
        fake_labels = torch.zeros(BATCH_SIZE, 1).to(DEVICE)
        
        labels = labels.to(DEVICE)
        outputs = D(images, labels)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        z = torch.randn(BATCH_SIZE, 100).to(DEVICE)
        g_label = torch.randint(0, 10, (BATCH_SIZE,)).to(DEVICE)
        fake_images = G(z, g_label)
        
        outputs = D(fake_images, g_label)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        d_loss = d_loss_real + d_loss_fake
        
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        fake_images = G(z, g_label)
        outputs = D(fake_images, g_label)
        g_loss = criterion(outputs, real_labels)

        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
    
    print(f"Epoch [{epoch}/{EPOCHS}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}\
            D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}")

Epoch [0/300], d_loss: 0.3098, g_loss: 3.8268            D(x): 0.89, D(G(z)): 0.13
Epoch [1/300], d_loss: 0.7070, g_loss: 2.9519            D(x): 0.81, D(G(z)): 0.24
Epoch [2/300], d_loss: 0.5622, g_loss: 3.5931            D(x): 0.91, D(G(z)): 0.24
Epoch [3/300], d_loss: 0.6224, g_loss: 2.6376            D(x): 0.88, D(G(z)): 0.23
Epoch [4/300], d_loss: 0.6928, g_loss: 2.8056            D(x): 0.79, D(G(z)): 0.14
Epoch [5/300], d_loss: 0.5979, g_loss: 2.5283            D(x): 0.85, D(G(z)): 0.24
Epoch [6/300], d_loss: 0.8238, g_loss: 1.8892            D(x): 0.80, D(G(z)): 0.31
Epoch [7/300], d_loss: 0.7647, g_loss: 1.7173            D(x): 0.72, D(G(z)): 0.28
Epoch [8/300], d_loss: 0.6474, g_loss: 2.2031            D(x): 0.80, D(G(z)): 0.24
Epoch [9/300], d_loss: 0.7032, g_loss: 2.0464            D(x): 0.74, D(G(z)): 0.20
Epoch [10/300], d_loss: 1.1289, g_loss: 1.5758            D(x): 0.72, D(G(z)): 0.38
Epoch [11/300], d_loss: 1.1821, g_loss: 1.5223            D(x): 0.74, D(G(z)): 0.41
Ep

Epoch [98/300], d_loss: 1.0603, g_loss: 1.1251            D(x): 0.65, D(G(z)): 0.39
Epoch [99/300], d_loss: 1.1074, g_loss: 1.0353            D(x): 0.62, D(G(z)): 0.38
Epoch [100/300], d_loss: 1.1944, g_loss: 0.9210            D(x): 0.57, D(G(z)): 0.42
Epoch [101/300], d_loss: 1.0131, g_loss: 1.1608            D(x): 0.68, D(G(z)): 0.39
Epoch [102/300], d_loss: 1.1689, g_loss: 1.0504            D(x): 0.56, D(G(z)): 0.35
Epoch [103/300], d_loss: 1.2978, g_loss: 0.9357            D(x): 0.55, D(G(z)): 0.43
Epoch [104/300], d_loss: 1.1532, g_loss: 0.9188            D(x): 0.59, D(G(z)): 0.42
Epoch [105/300], d_loss: 1.1794, g_loss: 1.0404            D(x): 0.59, D(G(z)): 0.40
Epoch [106/300], d_loss: 1.1261, g_loss: 1.0179            D(x): 0.59, D(G(z)): 0.39
Epoch [107/300], d_loss: 1.2410, g_loss: 0.9444            D(x): 0.58, D(G(z)): 0.43
Epoch [108/300], d_loss: 1.2071, g_loss: 1.0792            D(x): 0.58, D(G(z)): 0.38
Epoch [109/300], d_loss: 1.2559, g_loss: 1.0242            D(x): 0.

In [None]:
item_number = 9
z = torch.randn(1, 100).to(DEVICE)
g_label = torch.full((1,), item_number, dtype=torch.long).to(DEVICE)
sample_images = G(z, g_label)

sample_images_img = np.reshape(sample_images.data.cpu().numpy()
                               [0],(28, 28))
plt.imshow(sample_images_img, cmap = 'gray')
plt.show()