In [74]:
# Project: GANs — a convolutional GAN trained on Fashion-MNIST

In [75]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Check if GPU is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG (CPU and all CUDA devices) so weight initialisation, noise
# draws and DataLoader shuffling repeat across Restart & Run All. Some GPU
# kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)

In [76]:
def get_sample_image(generator, noise_dim, n_rows=10, img_size=28):
    """
    Generate a square grid of generator samples as a single 2-D numpy array.

    Draws ``n_rows * n_rows`` noise vectors of length ``noise_dim``. It runs them
    through ``generator`` without tracking gradients and tiles the resulting
    ``img_size x img_size`` images row-major into one array. Sample
    ``j * n_rows + k`` lands in grid row ``j``, column ``k``.

    Nothing is written to disk; the caller decides whether to save or display
    the grid.

    Args:
        generator: ``nn.Module`` mapping ``(N, noise_dim)`` noise to ``N``
            single-channel images. Its output must ``view`` to
            ``(N, img_size, img_size)``.
        noise_dim: length of each noise vector.
        n_rows: number of grid rows and columns (default 10, i.e. 100 samples).
        img_size: side length in pixels of each generated image (default 28).

    Returns:
        float64 numpy array of shape ``(n_rows * img_size, n_rows * img_size)``
        holding the raw generator outputs. No rescaling is applied.
    """
    n_samples = n_rows * n_rows
    # Place the noise wherever the generator's weights live. Fall back to the
    # notebook-level ``device`` for parameter-free generators.
    first_param = next(generator.parameters(), None)
    sample_device = first_param.device if first_param is not None else device

    with torch.no_grad():  # sampling only: no autograd graph needed
        noise = torch.randn(n_samples, noise_dim).to(sample_device)
        samples = generator(noise).view(n_samples, img_size, img_size)
    samples = samples.cpu().numpy()

    # (rows*cols, H, W) -> (rows, cols, H, W) -> (rows, H, cols, W) -> (rows*H, cols*W)
    grid = (
        samples.reshape(n_rows, n_rows, img_size, img_size)
        .transpose(0, 2, 1, 3)
        .reshape(n_rows * img_size, n_rows * img_size)
    )
    return grid.astype(np.float64)

In [77]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing per-image sigmoid scores.

    Three stride-2 3x3 conv blocks (Conv -> BatchNorm -> LeakyReLU(0.2))
    downsample the input. ``AvgPool2d(4)`` then pools the remaining feature
    map, and a linear layer plus sigmoid produces the scores.

    For 28x28 inputs the spatial size goes 28 -> 14 -> 7 -> 4 -> 1, so the
    flattened feature vector has exactly 128 entries. Other input sizes will
    not match the 128-input linear layer.

    Args:
        in_channels: number of input image channels (1 for grayscale).
        num_classes: number of sigmoid outputs per image. The default of 1
            gives a single real/fake probability.
    """

    def __init__(self, in_channels=1, num_classes=1):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(4),
        )
        # Output width follows num_classes. It was previously hard-coded to 1,
        # which silently ignored the argument.
        self.fc = nn.Sequential(
            nn.Linear(128, num_classes),
            nn.Sigmoid(),
        )

    def forward(self, x, y=False):
        """Score a batch of images.

        Args:
            x: image batch, assumed ``(N, in_channels, 28, 28)``.
            y: ignored; accepted so callers may pass a label argument.

        Returns:
            Tensor of shape ``(N, num_classes)`` with values in ``[0, 1]``.
        """
        features = self.conv(x)
        features = features.view(features.size(0), -1)  # (N, 128, 1, 1) -> (N, 128)
        return self.fc(features)


In [78]:
class Generator(nn.Module):
    """Map noise vectors to single-channel 28x28 images with values in [-1, 1].

    A linear projection followed by ReLU lifts each noise vector to a 512x4x4
    feature map. Three transposed convolutions then upsample it
    4 -> 7 -> 14 -> 28, and a final Tanh bounds the pixel values.

    Args:
        input_size: length of the (flattened) noise vector.
        num_classes: unused by this module; kept for signature compatibility.
            Presumably the flattened output size (28 * 28 = 784); verify with
            callers.
    """

    def __init__(self, input_size=100, num_classes=784):
        super().__init__()

        def up_block(c_in, c_out, kernel):
            # Transposed conv (stride 2, padding 1) -> BatchNorm -> ReLU.
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        self.fc = nn.Sequential(nn.Linear(input_size, 512 * 4 * 4), nn.ReLU())
        self.conv = nn.Sequential(
            *up_block(512, 256, 3),  # 4x4   -> 7x7
            *up_block(256, 128, 4),  # 7x7   -> 14x14
            nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),  # 14x14 -> 28x28
            nn.Tanh(),
        )

    def forward(self, x, y=None):
        """Generate images from noise.

        Args:
            x: noise batch whose per-sample elements flatten to ``input_size``
                (e.g. ``(N, 100)`` or ``(N, 100, 1, 1)``).
            y: ignored; accepted so callers may pass a label argument.

        Returns:
            Tensor of shape ``(N, 1, 28, 28)`` with values in ``[-1, 1]``.
        """
        batch = x.size(0)
        hidden = self.fc(x.view(batch, -1))
        return self.conv(hidden.view(batch, 512, 4, 4))


In [79]:
# Build both networks (generator first) and move their parameters to `device`.
generator, discriminator = Generator().to(device), Discriminator().to(device)

In [79]:
# Convert images to tensors, then map each pixel p to (p - 0.5) / 0.5. A [0, 1]
# tensor therefore lands in [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [None]:
batch_size = 64
# Adam settings shared by both networks. beta1 = 0.5 (already used here) and
# lr = 2e-4 are the DCGAN recommendations; lr = 1e-3 tends to destabilise GAN training.
lr = 2e-4
adam_betas = (0.5, 0.999)

data = datasets.FashionMNIST(root='./data/', train=True, transform=transform, download=True)
# drop_last=True keeps every batch at exactly batch_size. The fixed-size
# real/fake target tensors used in training rely on that.
data_loader = DataLoader(dataset=data, batch_size=batch_size, shuffle=True, drop_last=True)

# The discriminator already ends in a Sigmoid, so BCELoss on probabilities.
loss_fn = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)


In [None]:
# Training hyper-parameters.
max_epochs = 50
step = 0       # global iteration counter, advanced by the training loop
n_critic = 1   # the generator is updated once every n_critic iterations
n_noise = 100  # noise length fed to the generator (Generator's input_size defaults to 100)

# Fixed BCE targets shaped (batch_size, 1): 1 marks "real", 0 marks "fake".
d_labels = torch.ones(batch_size, 1, device=device)
d_fakes = torch.zeros(batch_size, 1, device=device)

In [None]:
# Training loop.
# Each iteration updates the discriminator on one real and one generated
# batch. The generator is updated every `n_critic` iterations. `step` counts
# every iteration, and sample grids are written to `sample_dir`.
sample_dir = 'samples'

for epoch in range(max_epochs):
    for images, _ in data_loader:  # class labels are not used by this unconditional GAN
        real_images = images.to(device)

        # ---- Discriminator: push D(real) -> 1 and D(G(z)) -> 0 ----
        real_outputs = discriminator(real_images)
        d_real_loss = loss_fn(real_outputs, d_labels)

        fake_noise = torch.randn(batch_size, n_noise).to(device)
        fake_images = generator(fake_noise)
        # detach(): the discriminator loss must not backpropagate into the generator.
        fake_outputs = discriminator(fake_images.detach())
        d_fake_loss = loss_fn(fake_outputs, d_fakes)

        d_loss = d_real_loss + d_fake_loss

        discriminator.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator (every n_critic iterations): push D(G(z)) -> 1 ----
        if step % n_critic == 0:
            fake_outputs = discriminator(fake_images)
            g_loss = loss_fn(fake_outputs, d_labels)

            generator.zero_grad()
            # Clear the D gradients this backward pass also produces.
            discriminator.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            if step % 500 == 0:
                print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epochs, step, d_loss.item(), g_loss.item()))

            if step % 1000 == 0:
                generator.eval()
                with torch.no_grad():
                    img = get_sample_image(generator, n_noise)
                generator.train()
                # Map the Tanh range [-1, 1] back to [0, 1] before writing a PNG.
                grid = torch.from_numpy(img).float().unsqueeze(0) * 0.5 + 0.5
                os.makedirs(sample_dir, exist_ok=True)
                save_image(grid, os.path.join(sample_dir, 'step{:06d}.png'.format(step)))

        # Advance on EVERY iteration. Incrementing only inside the generator
        # branch freezes `step` at 1 when n_critic > 1, so the generator never
        # trains again.
        step += 1

In [None]:
# TODO: need to test