In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import random

In [None]:
import numpy as np

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torchvision import transforms

In [None]:
from matplotlib import pyplot as plt

In [None]:
# ToTensor turns each image into a float tensor in [0, 1]; Normalize with mean = std = 0.5
# on every channel then computes (x - 0.5) / 0.5, i.e. rescales pixels to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)
data = CIFAR10(root='cifar', transform=transform, download=True)

In [None]:
# Sanity check: number of samples in the CIFAR-10 dataset object loaded above (shown as the cell output).
len(data)

In [None]:
def get_batch(data, batch_size):
    """Sample ``batch_size`` (image, label) pairs uniformly at random, with replacement.

    Args:
        data: indexable, sized dataset whose items are ``(image, label)`` pairs, where
            ``image`` is tensor-like and all images share one shape.
        batch_size: number of samples to draw.

    Returns:
        ``(images, labels)``: the images stacked along a new leading dimension and the
        labels as a 1-D tensor. For ``batch_size == 0`` both are empty tensors.

    Raises:
        ValueError: if ``data`` is empty (no index can be drawn).
    """
    # Index each item exactly once. A transformed dataset applies its transform on every
    # access, so a separate lookup for the image and for the label would double the work.
    samples = [data[random.randrange(len(data))] for _ in range(batch_size)]
    if not samples:
        # Keep the historical empty-batch result: torch.tensor([]) for both outputs.
        return torch.tensor([]), torch.tensor([])
    images, labels = zip(*samples)
    return torch.stack([torch.as_tensor(image) for image in images]), torch.tensor(labels)

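## Training

Train the MLP generator against the MLP critic, showing one generated sample at the end of each epoch.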

In [None]:
class CifarGenerator(nn.Module):
    """Fully connected generator mapping noise vectors to 3x32x32 images in [-1, 1].

    Args:
        latent_dim: width of the input noise vector (default 100).
        hidden_dim: width of both hidden layers (default 8192). Note that the default
            makes the middle layer alone hold 8192 * 8192 weights.
    """

    def __init__(self, latent_dim=100, hidden_dim=8192):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.01, True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.01, True),
            nn.Linear(hidden_dim, 3 * 32 * 32),
            # Tanh bounds outputs to [-1, 1], the same range the data is normalised to.
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise of shape (batch, latent_dim) to images of shape (batch, 3, 32, 32)."""
        return self.main(x).reshape(x.shape[0], 3, 32, 32)


class CifarDiscriminator(nn.Module):
    """Fully connected critic giving each 3x32x32 image a single raw score.

    There is no sigmoid on the last layer, so scores are unbounded real numbers,
    not probabilities.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(3 * 32 * 32, 512), nn.ReLU(True), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.ReLU(True), nn.Dropout(0.3),
            nn.Linear(256, 1),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample to 3*32*32 features and return scores of shape (batch, 1)."""
        flat = x.reshape(x.size(0), 3 * 32 * 32)
        return self.main(flat)

In [None]:
# Seed the RNGs this notebook draws from (weight init, DataLoader shuffling and noise use
# torch's global RNG; get_batch uses `random`) so Restart & Run All reproduces a CPU run.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

generator = CifarGenerator()
discriminator = CifarDiscriminator()

In [None]:
# Mini-batch size used by the training DataLoader below.
batch = 100

In [None]:
# Wasserstein-style training: the critic loss E[D(fake)] - E[D(real)] is only meaningful when
# D is (approximately) 1-Lipschitz, which the gradient penalty below enforces softly.
LATENT_DIM = 100        # must match the generator's input width (nn.Linear(100, ...))
GP_WEIGHT = 10.0        # gradient-penalty weight (lambda) from WGAN-GP
DIVERSITY_WEIGHT = 0.3  # weight of the per-pixel batch-std bonus in the generator loss
NUM_EPOCHS = 100
LOG_EVERY = 100         # print losses every N batches instead of every batch


def _gradient_penalty(critic, real, fake):
    """Return mean((||d critic(x_hat) / d x_hat||_2 - 1)^2) over random real/fake interpolates.

    ``real`` and ``fake`` must have the same shape. ``fake`` is detached, so the penalty
    only produces gradients for the critic's parameters.
    """
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device, dtype=real.dtype)
    interpolated = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(interpolated)
    # create_graph=True so the penalty itself can be backpropagated into the critic.
    (grads,) = torch.autograd.grad(scores.sum(), interpolated, create_graph=True)
    return ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()


def _to_displayable(image):
    """Map a (3, H, W) tensor from the [-1, 1] Tanh range to an (H, W, 3) array in [0, 1]."""
    return ((image.detach().cpu() + 1) / 2).clamp(0, 1).permute(1, 2, 0).numpy()


discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=2e-4)
generator_optimizer = optim.Adam(generator.parameters(), lr=2e-4)

data_loader = torch.utils.data.DataLoader(data, batch_size=batch, shuffle=True)

for epoch in range(NUM_EPOCHS):
    for batch_id, (x, _) in enumerate(data_loader):
        # Size the noise by the real batch so a short final batch still pairs 1:1 with x
        # (required by the gradient-penalty interpolation).
        batch_of_noise = torch.randn(x.size(0), LATENT_DIM)

        # ---- critic step ----
        discriminator_optimizer.zero_grad()
        batch_of_generated = generator(batch_of_noise)
        # detach(): the critic update must not compute gradients for the generator.
        batch_of_generated_discrimination = discriminator(batch_of_generated.detach())
        batch_of_real_discrimination = discriminator(x)
        discriminator_loss = (
            batch_of_generated_discrimination.mean()
            - batch_of_real_discrimination.mean()
            + GP_WEIGHT * _gradient_penalty(discriminator, x, batch_of_generated)
        )
        discriminator_loss.backward()
        discriminator_optimizer.step()

        # ---- generator step ----
        # The generator has not changed since batch_of_generated was produced, so its output
        # (and autograd graph) is reused instead of running a second forward pass.
        generator_optimizer.zero_grad()
        batch_of_generated_discrimination = discriminator(batch_of_generated)
        generator_loss = (
            -batch_of_generated_discrimination.mean()
            - DIVERSITY_WEIGHT * batch_of_generated.std(dim=0).sum()
        )
        generator_loss.backward()
        generator_optimizer.step()

        if batch_id % LOG_EVERY == 0:
            print(f'------{batch_id}:D:{discriminator_loss.item()}:G:{generator_loss.item()}---')

    # log
    print(f'---{epoch}:D:{discriminator_loss.item()}:G:{generator_loss.item()}---')
    # Generator outputs live in [-1, 1]; rescale to [0, 1] before display, otherwise
    # negative values wrap around when converted to 8-bit pixels.
    plt.imshow(_to_displayable(batch_of_generated[0]))
    plt.title(f'Epoch {epoch}')
    plt.axis('off')
    plt.show()