In [2]:
import os
import json
import math
import numpy as np 
import pandas as pd

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline 
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms

  set_matplotlib_formats('svg', 'pdf') # For export


In [3]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cuda:0


In [4]:
# Transformations applied on each image => only make them a tensor
transform = transforms.Compose([transforms.ToTensor()])

# Loading the training dataset. We need to split it into a training and validation part
dataset = ImageFolder("trafic_32", transform=transform)

FileNotFoundError: [WinError 3] System nie może odnaleźć określonej ścieżki: 'trafic_32'

In [None]:
train_loader = data.DataLoader(dataset, batch_size=256, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)

In [None]:
class Cond_Discriminator(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_of_classes):
        super(Cond_Discriminator, self).__init__()

        self.fc_1 = nn.Linear(input_dim + num_of_classes, hidden_dim)
        self.fc_2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_out  = nn.Linear(hidden_dim, 1)
        
        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x, y):
        x = torch.flatten(x, 1)
        x = torch.cat([x,y], 1)
        x = self.LeakyReLU(self.fc_1(x))
        x = self.LeakyReLU(self.fc_2(x))
        x = self.fc_out(x)
        return x

In [None]:
class Cond_Generator(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim, num_of_classes):
        super(Cond_Generator, self).__init__()
        self.fc_1 = nn.Linear(latent_dim+num_of_classes, hidden_dim)
        self.fc_2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_3 = nn.Linear(hidden_dim, output_dim)
        
        self.LeakyReLU = nn.LeakyReLU(0.2)
        
    def forward(self, x,y):
        x = torch.cat([x,y], 1)
        h     = self.LeakyReLU(self.fc_1(x))
        h     = self.LeakyReLU(self.fc_2(h))
        x_hat = torch.sigmoid(self.fc_3(h))
        x_hat = x_hat.view([-1, 3, 32, 32])
        return x_hat

In [None]:
# Models
latent_dim = 64
img_size = 32 * 32 * 3
cond_generator = Cond_Generator(latent_dim=latent_dim, hidden_dim=256, output_dim=img_size, num_of_classes=len(dataset.classes)).to(device)
cond_discriminator = Cond_Discriminator( hidden_dim=256, input_dim=img_size, num_of_classes=len(dataset.classes)).to(device)

# Optimizers
generator_optimizer = torch.optim.Adam(cond_generator.parameters(), lr=0.0001)
generator_scheduler = optim.lr_scheduler.ExponentialLR(optimizer=generator_optimizer, gamma=0.99)
discriminator_optimizer = torch.optim.Adam(cond_discriminator.parameters(), lr=0.0001)
discriminator_scheduler = optim.lr_scheduler.ExponentialLR(optimizer=discriminator_optimizer, gamma=0.99)

# loss
criterion = nn.MSELoss()

In [None]:
fixed_noise = torch.randn(16, latent_dim,device=device)
fixed_labels = torch.randint(len(dataset.classes),(16,),device=device)
fixed_labels = F.one_hot(fixed_labels, len(dataset.classes)).float()

## Trening

In [None]:
G_losses = []
D_losses = []
num_epochs = 201
for epoch in range(num_epochs):
    # For each batch in the dataloader
    discriminator_fake_acc = []
    discriminator_real_acc = []
    for i, data in enumerate(train_loader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        discriminator_optimizer.zero_grad()
        # Format batch
        real_images = data[0].to(device)
        y = data[1]
        y = F.one_hot(y, num_classes=len(dataset.classes)).to(device).float()
        b_size = real_images.size(0)
        label = torch.ones((b_size,), dtype=torch.float, device=device) # Setting labels for real images
        # Forward pass real batch through D
        output = cond_discriminator(real_images, y).view(-1)
        # Calculate loss on all-real batch
        error_discriminator_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        error_discriminator_real.backward()
        discriminator_real_acc.append(output.mean().item())

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, latent_dim,device=device)
        rand_y = torch.randint(len(dataset.classes),(b_size,),device=device)
        rand_y = F.one_hot(rand_y, len(dataset.classes)).float()
        # Generate fake image batch with Generator
        fake_images = cond_generator(noise, rand_y)
        label_fake = torch.zeros((b_size,), dtype=torch.float, device=device)
        # Classify all fake batch with Discriminator
        output = cond_discriminator(fake_images.detach(), rand_y).view(-1)
        # Calculate D's loss on the all-fake batch
        error_discriminator_fake = criterion(output, label_fake)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        error_discriminator_fake.backward()
        discriminator_fake_acc.append(output.mean().item())
        # Compute error of D as sum over the fake and the real batches
        error_discriminator = error_discriminator_real + error_discriminator_fake
        # Update D
        discriminator_optimizer.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
#         for _ in range(3):
        noise = torch.randn(b_size, latent_dim,device=device)
        rand_y = torch.randint(len(dataset.classes),(b_size,),device=device)
        rand_y = F.one_hot(rand_y, len(dataset.classes)).float()
        fake_images = cond_generator(noise, rand_y)
        generator_optimizer.zero_grad()
        label = torch.ones((b_size,), dtype=torch.float, device=device)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = cond_discriminator(fake_images, rand_y).view(-1)
        # Calculate G's loss based on this output
        error_generator = criterion(output, label)
        # Calculate gradients for G
        error_generator.backward()
        D_G_z2 = output.mean().item()
        # Update G
        generator_optimizer.step()

        # Output training stats
        # Save Losses for plotting later
        G_losses.append(error_generator.item())
        D_losses.append(error_discriminator.item())

    print(f"Epoch: {epoch}, discrimiantor fake error: {np.mean(discriminator_fake_acc):.3}, discriminator real acc: {np.mean(discriminator_real_acc):.3}")
    generator_scheduler.step()
    discriminator_scheduler.step()
    if epoch % 10 == 0:
        with torch.no_grad():
            fake = cond_generator(fixed_noise, fixed_labels).detach().cpu()
        grid = torchvision.utils.make_grid(fake)
        grid = grid.permute(1, 2, 0)
        plt.figure(figsize=(10,10))
        plt.title(f"Generations")
        plt.imshow(grid)
        plt.axis('off')
        plt.show()

## Generowanie wyników

In [None]:
def generate_images(number_of_images, clazz=None):
    fixed_noise = torch.randn(number_of_images, latent_dim, device=device)
    fixed_labels = torch.randint(len(dataset.classes), (number_of_images,), device=device)
    if clazz:
        fixed_labels = torch.tensor([clazz for _ in range(number_of_images)], device=device)
    fixed_labels = F.one_hot(fixed_labels, len(dataset.classes)).float()
    with torch.no_grad():
        fake = cond_generator(fixed_noise, fixed_labels).detach().cpu()
        grid = torchvision.utils.make_grid(fake)
        grid = grid.permute(1, 2, 0)
        plt.figure(figsize=(10,10))
        plt.title(f"Generations")
        plt.imshow(grid)
        plt.axis('off')
        plt.show()

In [None]:
generate_images(number_of_images=8, clazz=20)