# Deep Learning for Business Applications course

## TOPIC 7: Intro to Generative Adversarial Networks

### 1. Libraries and parameters

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# select the compute device: CUDA GPU when torch reports one,
# otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print('device available:', DEVICE)

In [None]:
import os

# noise dimension and channel size
NOISE_DIM = 100  # length of the noise vector fed to the generator
CHANNELS = 128  # you may use 128, 256 or 512

# Training hyperparameters
NUM_EPOCHS = 20
BATCH_SIZE = 256
LR = .0002

# other
GEN_DATA = 'gens'  # directory where generated sample images are saved
# create the output directory from Python rather than with the `!mkdir`
# shell escape, so the cell also runs on Windows and as a plain script;
# exist_ok=True keeps re-running the cell harmless
os.makedirs(GEN_DATA, exist_ok=True)

### 2. Generator and discriminator

In [None]:
class Generator(nn.Module):
    """
    Generator network: maps a batch of noise vectors
    to a batch of single-channel 28x28 images.

    A linear layer expands the noise into a (channels, 7, 7) feature map;
    three transposed convolutions then produce 7x7, 14x14 and 28x28 maps,
    and the final Tanh bounds the output to [-1, 1].

    """
    def __init__(self, noise_dim, channels):
        super().__init__()
        self.noise_dim = noise_dim

        # channel widths of the two intermediate feature maps
        half = int(channels / 2)
        quarter = int(channels / 4)

        layers = [
            # noise -> flat vector -> (channels, 7, 7)
            nn.Linear(noise_dim, 7 * 7 * channels),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (channels, 7, 7)),
            # stride 1 keeps the 7x7 resolution
            nn.ConvTranspose2d(
                channels, half,
                kernel_size=5, stride=1, padding=2
            ),
            nn.BatchNorm2d(half),
            nn.ReLU(inplace=True),
            # stride 2 with output_padding=1 doubles 7x7 to 14x14
            nn.ConvTranspose2d(
                half, quarter,
                kernel_size=5, stride=2, padding=2, output_padding=1
            ),
            nn.BatchNorm2d(quarter),
            nn.ReLU(inplace=True),
            # 14x14 -> 28x28, single output channel
            nn.ConvTranspose2d(
                quarter, 1,
                kernel_size=5, stride=2, padding=2, output_padding=1
            ),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """
        Runs noise `x` of shape (batch, noise_dim) through the layer stack.

        """
        return self.main(x)

In [None]:
class Discriminator(nn.Module):
    """
    Discriminator network: maps a batch of single-channel 28x28 images
    to one raw score (logit) per image.

    Two stride-2 convolutions reduce 28x28 to 14x14 and then 7x7;
    the flattened features go through a single linear layer.
    No sigmoid is applied here.

    """
    def __init__(self, channels):
        super().__init__()

        # channel widths of the two convolutional feature maps
        quarter = int(channels / 4)
        half = int(channels / 2)

        layers = [
            # 28x28 -> 14x14
            nn.Conv2d(1, quarter, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(quarter),
            # 14x14 -> 7x7
            nn.Conv2d(
                quarter, half,
                kernel_size=5, stride=2, padding=2
            ),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(half),
            # (half, 7, 7) -> flat vector -> single logit
            nn.Flatten(),
            nn.Linear(7 * 7 * half, 1),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """
        Runs images `x` through the layer stack and returns the logits.

        """
        return self.main(x)

### 3. Data

In [None]:
# preprocessing: convert images to tensors, then normalize
# the single channel with mean 0.5 and std 0.5
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.5,), (0.5,))
transform = transforms.Compose([to_tensor, normalize])

# MNIST training split stored under ./data (download=True)
mnist_kwargs = dict(
    root='./data',
    train=True,
    transform=transform,
    download=True
)
train_dataset = torchvision.datasets.MNIST(**mnist_kwargs)

# shuffled mini-batches of BATCH_SIZE images
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)

# show the first training sample together with its label
sample_image, sample_label = train_dataset[0]
plt.figure(figsize=(3, 3))
plt.imshow(sample_image[0], cmap='gray')
plt.title('label of image: {}'.format(sample_label))

### 4. Training

#### 4.1. Objects and utilities

In [None]:
# build both networks and move each one to the selected device
generator = Generator(noise_dim=NOISE_DIM, channels=CHANNELS).to(DEVICE)
discriminator = Discriminator(channels=CHANNELS).to(DEVICE)

In [None]:
# loss function: binary cross-entropy computed on raw logits
criterion = nn.BCEWithLogitsLoss()

# generator and discriminator each get their own Adam optimizer
# with the same learning rate and betas
adam_settings = {'lr': LR, 'betas': (.5, .999)}
generator_optimizer = optim.Adam(generator.parameters(), **adam_settings)
discriminator_optimizer = optim.Adam(
    discriminator.parameters(), **adam_settings
)

In [None]:
def generate_and_save_images(model, epoch, noise, path):
    """
    Generates images with some input noise
    then plots and saves generated images.

    The images are drawn on a square grid: 4x4 for up to 16 images,
    enlarged automatically when more noise vectors are passed. The figure
    is saved as `{path}/img_at_epoch_{epoch+1:03d}.png`, shown and closed.

    Args:
        model: generator module, called as `model(noise)`; its output
            must be viewable as (N, 28, 28).
        epoch: zero-based epoch index; `epoch + 1` is used in the file name.
        noise: batch of noise vectors passed to the model.
        path: directory to save the image into (it is not created here).

    """
    # remember the current mode so that generating samples does not
    # leave the model switched to eval() as a side effect
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            fake_images = model(noise).cpu()
    finally:
        model.train(was_training)
    fake_images = fake_images.view(fake_images.size(0), 28, 28)

    # smallest square grid (at least 4x4) that fits every image;
    # a fixed 4x4 grid makes plt.subplot fail for more than 16 images
    num_images = fake_images.size(0)
    grid_side = 4
    while grid_side * grid_side < num_images:
        grid_side += 1

    fig = plt.figure(figsize=(grid_side, grid_side))
    for i in range(num_images):
        plt.subplot(grid_side, grid_side, i + 1)
        plt.imshow(fake_images[i], cmap='gray')
        plt.axis('off')
    plt.savefig(f'{path}/img_at_epoch_{epoch+1:03d}.png')
    plt.show()
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

#### 4.2. Training loop

In [None]:
for epoch in range(NUM_EPOCHS):
    for i, (real_images, _) in enumerate(train_loader):
        real_images = real_images.to(DEVICE)
        batch_size = real_images.size(0)

        # discriminator, part 1: real images with target 1
        discriminator_optimizer.zero_grad()
        real_labels = torch.ones(batch_size, 1, device=DEVICE)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)
        real_loss.backward()

        # discriminator, part 2: generated images with target 0;
        # detach() keeps this backward pass out of the generator
        generator.train()
        noise = torch.randn(batch_size, NOISE_DIM, device=DEVICE)
        fake_images = generator(noise)
        fake_labels = torch.zeros(batch_size, 1, device=DEVICE)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)
        fake_loss.backward()
        # one optimizer step on the gradients accumulated from both parts
        discriminator_optimizer.step()

        # generator: the same fake images, now with target 1, so the
        # generator is rewarded when the discriminator scores them as real
        generator_optimizer.zero_grad()
        fake_labels = torch.ones(batch_size, 1, device=DEVICE)
        fake_outputs = discriminator(fake_images)
        gen_loss = criterion(fake_outputs, fake_labels)
        gen_loss.backward()
        generator_optimizer.step()

        # every 10th batch: overwrite the progress line in place
        if i % 10 == 0:
            progress = (
                f'epoch [{epoch + 1}/{NUM_EPOCHS}], step [{i + 1}/{len(train_loader)}] | '
                f'discriminator loss: {real_loss.item() + fake_loss.item():.4f} | '
                f'generator loss: {gen_loss.item():.4f}     '
            )
            print(progress, end='\r')

    # epoch summary, using the losses of the last batch
    summary = (
        f'epoch [{epoch + 1}/{NUM_EPOCHS}] done | '
        f'discriminator loss: {real_loss.item() + fake_loss.item():.4f} | '
        f'generator loss: {gen_loss.item():.4f}               '
    )
    print(summary)

    # every 5th epoch: plot and save 16 images from fresh noise
    if (epoch + 1) % 5 == 0:
        test_noise = torch.randn(16, NOISE_DIM, device=DEVICE)
        generate_and_save_images(generator, epoch, test_noise, path=GEN_DATA)

### <font color='red'>HOME ASSIGNMENT  (Option #2)</font>

You have to run a few experiments with our toy GAN:
1. Run the training process with different numbers of epochs. Find the epoch at which the generated images start to closely resemble real digits (let's call it the 'border' epoch).
2. Try changing the `CHANNELS` parameter. How has the 'border' epoch changed? What about the training time? Why is this happening?
3. __ADVANCED (not necessary):__ plot graphs of the generator and discriminator losses by epoch. HINT: you may re-use code from the `topic_03_finetuning` notebook.