## Import Libraries

In [None]:
import re
import matplotlib.pyplot as plt
from PIL import Image  # Install Pillow -> conda install anaconda::pillow or pip install pillow
import os
from skimage.io import imread, imshow  # Install scikit-image -> conda install scikit-image or pip install scikit-image
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import pandas as pd
import torch.nn as nn
from torch.autograd import Variable
from torch import autograd
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from time import time

## Define Settings and Paths

In [None]:
# General settings: every image is resized to IMG_WIDTH x IMG_HEIGHT and
# loaded in mini-batches of BATCH_SIZE.
IMG_WIDTH, IMG_HEIGHT = 75, 75
BATCH_SIZE = 16

# Dataset locations, relative to the current working directory.
train_dataset_path = 'data-students/TRAIN/'
test_dataset_path = 'data-students/TEST'

## Define the Discriminator Class

In [None]:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        n = 128
        
        # Define the discriminator blocks
        self.db_1 = self._discriminator_block(4, n, 3, 1, 0)
        self.db_2 = self._discriminator_block(n, 2 * n, 5, 2, 1)
        self.db_3 = self._discriminator_block(2 * n, 4 * n, 4, 2, 1)
        self.db_4 = self._discriminator_block(4 * n, 8 * n, 4, 2, 0)
        
        # Final layer to output a single value
        self.adv_layer = nn.Conv2d(8 * n, 1, 8, 1, 0)

        # Embedding for the label
        self.embedding = nn.Embedding(10, 24 * 24)
        self.transpose_embedding = nn.ConvTranspose2d(1, 1, 6, 3, 0)

    def _discriminator_block(self, in_filters, out_filters, kernel_size, stride, padding):
        """Creates a block for the discriminator with Conv2d, InstanceNorm2d, and LeakyReLU."""
        block = [
            nn.Conv2d(in_filters, out_filters, kernel_size=kernel_size, stride=stride, padding=padding), 
            nn.InstanceNorm2d(out_filters, 0.8),
            nn.LeakyReLU(0.2, inplace=True)
        ]
        return nn.Sequential(*block)

    def forward(self, img, label):
        # Embed the label and reshape it
        l = self.embedding(label)
        l = l.view(l.size(0), 1, 24, 24)
        l = self.transpose_embedding(l)
        
        # Concatenate the image and label embedding
        x = torch.cat([img, l], 1)
        
        # Apply discriminator blocks
        x = self.db_1(x)
        x = self.db_2(x)
        x = self.db_3(x)
        x = self.db_4(x)
        
        # Apply the final layer
        y = self.adv_layer(x)
        
        return y


## Define the Generator Class

In [None]:
class Generator(nn.Module):
    """Conditional generator mapping (noise, class label) to a 3x75x75 image.

    ``noise`` has shape (B, 100, 1, 1), as produced by
    ``make_random_noise_vector``. The label is embedded into 256 channels and
    concatenated with the noise, giving the 356-channel input of the first
    block. The spatial size grows 1 -> 4 -> 8 -> 18 -> 36 -> 73 -> 75. A final
    Tanh bounds every output value to (-1, 1).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.channels = 3  # number of output image channels
        width = 128

        # Creation order is kept fixed (it affects weight-init RNG consumption
        # and the state_dict layout).
        self.gb_1 = self._generator_block(100 + 256, 16 * width, 4, 1, 0)  # 1 -> 4
        self.gb_2 = self._generator_block(16 * width, 8 * width, 4, 2, 1)  # 4 -> 8
        self.gb_3 = self._generator_block(8 * width, 4 * width, 4, 2, 0)   # 8 -> 18
        self.gb_4 = self._generator_block(4 * width, 2 * width, 4, 2, 1)   # 18 -> 36
        self.gb_5 = self._generator_block(2 * width, width, 5, 2, 1)       # 36 -> 73

        # 73 -> 75, then squash to (-1, 1).
        final_conv = nn.ConvTranspose2d(
            in_channels=width,
            out_channels=self.channels,
            kernel_size=3,
            stride=1,
            padding=0,
        )
        self.projection_layer = nn.Sequential(final_conv, nn.Tanh())

        # 10 label classes, each mapped to a 256-dim conditioning vector.
        self.embedding = nn.Embedding(10, 256)

    def _generator_block(self, in_filters, out_filters, kernel_size, stride, padding):
        """Return ConvTranspose2d -> BatchNorm2d -> ReLU as one ``nn.Sequential``."""
        upconv = nn.ConvTranspose2d(in_filters, out_filters, kernel_size=kernel_size, stride=stride, padding=padding)
        # NOTE(review): the second positional argument of BatchNorm2d is
        # ``eps``, so this uses eps=0.8 (PyTorch's default is 1e-5). It may
        # have been intended as momentum; kept to preserve behaviour.
        norm = nn.BatchNorm2d(out_filters, eps=0.8)
        act = nn.ReLU(inplace=True)
        return nn.Sequential(upconv, norm, act)

    def forward(self, noise, label):
        # label -> (B, 256) -> (B, 256, 1, 1) so it can be stacked with the noise.
        cond = self.embedding(label)
        cond = cond.view(cond.size(0), cond.size(1), 1, 1)

        hidden = torch.cat([noise, cond], 1)
        for stage in (self.gb_1, self.gb_2, self.gb_3, self.gb_4, self.gb_5):
            hidden = stage(hidden)
        return self.projection_layer(hidden)

    def make_random_noise_vector(self, batch_size):
        """Sample standard-normal latent noise of shape (batch_size, 100, 1, 1) on the CPU."""
        latent_shape = (batch_size, 100, 1, 1)
        return torch.randn(latent_shape)


## Define the WGAN class

In [None]:
class WGAN(nn.Module):
    """Conditional WGAN-GP: a ``Generator`` trained against a ``Discriminator`` critic.

    Both networks are created in ``__init__`` and moved to ``device``
    (``'cuda'`` by default, as before). Sample images and checkpoints are
    written under ``output_dir``, which is created on demand.

    NOTE(review): ``train`` below overrides ``nn.Module.train(mode=True)``
    with a different signature. As a result, ``wgan.eval()`` and
    ``wgan.train(False)`` raise ``TypeError``. Toggle ``wgan.generator`` /
    ``wgan.discriminator`` modes directly instead. The name is kept for
    backward compatibility.
    """

    # Adam betas used in the WGAN-GP paper (Gulrajani et al., 2017, Alg. 1).
    adam_betas = (0.0, 0.9)
    # Directory receiving sample images and checkpoints.
    output_dir = 'output'

    def __init__(self, *args, device='cuda', **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.device = torch.device(device)

        self.generator = Generator().to(self.device)
        self.discriminator = Discriminator().to(self.device)

    def forward(self, label):
        """Generate one sample for ``label`` as an (H, W, C) numpy array in [0, 1].

        ``label`` is passed unchanged to the generator together with a single
        noise vector. Presumably a 1-element integer tensor on ``self.device``
        (e.g. ``torch.tensor([3])``); only the first generated image is returned.
        """
        # The result is converted to numpy, so no autograd graph is needed.
        with torch.no_grad():
            noise = self.generator.make_random_noise_vector(1).to(self.device)
            fake_image = self.generator(noise, label)
        fake_image = fake_image.cpu().numpy()
        fake_image = (fake_image + 1) / 2  # [-1, 1] -> [0, 1]
        return fake_image.transpose(0, 2, 3, 1)[0]  # (N, C, H, W) -> (H, W, C)

    def gradient_penalty(self, real_images, fake_images, label):
        """Return the WGAN-GP penalty mean((||grad_x D(x_hat, label)||_2 - 1)^2).

        ``x_hat`` is a per-sample random interpolation between ``real_images``
        and ``fake_images``. The returned scalar keeps its graph
        (``create_graph=True``), so it can be back-propagated into the critic.
        All tensors are used on the device of ``real_images``.
        """
        batch_size = real_images.size(0)
        eta = torch.rand(batch_size, 1, 1, 1, device=real_images.device)
        eta = eta.expand_as(real_images)

        # A fresh leaf tensor, so gradients are taken w.r.t. the interpolation only.
        interpolated = (eta * real_images + (1 - eta) * fake_images).detach().requires_grad_(True)

        prob_interpolated = self.discriminator(interpolated, label)

        gradients = autograd.grad(
            outputs=prob_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(prob_interpolated),
            create_graph=True,
            retain_graph=True,
        )[0]

        # Flatten per sample to take the L2 norm batch-wise.
        gradients = gradients.view(batch_size, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    @staticmethod
    def _set_requires_grad(module, flag):
        """Enable or disable gradient tracking for every parameter of ``module``."""
        for p in module.parameters():
            p.requires_grad = flag

    def train(self, dataset, n_epochs: int, batch_size: int, n_critic: int,
              lr_generator: float, lr_discriminator: float, gp_weight: float = 5.0):
        """Train the critic and generator with the WGAN-GP objective.

        Args:
            dataset: Iterable with ``len()`` yielding ``(images, labels)`` batches
                (e.g. a ``DataLoader``).
            n_epochs: Number of passes over ``dataset``.
            batch_size: Unused; kept for API compatibility. The effective batch
                size is read from each batch, so a smaller last batch works.
            n_critic: Critic updates per generator update (must be >= 1).
            lr_generator: Adam learning rate for the generator.
            lr_discriminator: Adam learning rate for the critic.
            gp_weight: Multiplier of the gradient penalty (previously hard-coded 5).

        Raises:
            ValueError: if ``n_critic < 1``.

        NOTE: the same real batch is reused for all ``n_critic`` critic steps;
        the WGAN-GP paper samples a fresh real batch for each step.
        """
        if n_critic < 1:
            raise ValueError(f'n_critic must be >= 1, got {n_critic}')

        optimizer_generator = optim.Adam(
            self.generator.parameters(), lr=lr_generator, betas=self.adam_betas
        )
        optimizer_discriminator = optim.Adam(
            self.discriminator.parameters(), lr=lr_discriminator, betas=self.adam_betas
        )
        os.makedirs(self.output_dir, exist_ok=True)

        for epoch in range(n_epochs):
            pbar = tqdm(enumerate(dataset), total=len(dataset))
            for i, (imgs, label) in pbar:
                label = label.to(self.device)
                imgs = imgs.to(self.device)
                batch_size = imgs.size(0)

                # ---- Critic updates ----
                self._set_requires_grad(self.discriminator, True)
                for _ in range(n_critic):
                    optimizer_discriminator.zero_grad()

                    noise = self.generator.make_random_noise_vector(batch_size).to(self.device)
                    d_real_loss = self.discriminator(imgs, label).mean()

                    # Detached: this step must not update the generator.
                    fake_images = self.generator(noise, label).detach()
                    d_fake_loss = self.discriminator(fake_images, label).mean()

                    w_d = d_real_loss - d_fake_loss  # Wasserstein estimate
                    gradient_penalty = self.gradient_penalty(imgs, fake_images, label)

                    d_loss = -w_d + gradient_penalty * gp_weight
                    d_loss.backward()
                    optimizer_discriminator.step()

                # ---- Generator update (critic frozen) ----
                self._set_requires_grad(self.discriminator, False)
                optimizer_generator.zero_grad()

                noise = self.generator.make_random_noise_vector(batch_size).to(self.device)
                fake_images = self.generator(noise, label)
                g_loss = -self.discriminator(fake_images, label).mean()
                g_loss.backward()
                optimizer_generator.step()

                pbar.set_description(
                    f'Epoch: {epoch}, Batch: {i}, '
                    f'Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}'
                )

                if i % 100 == 0:  # Save output every 100 batches
                    self._save_snapshot(epoch, i)

            if epoch % 10 == 0:
                self.save_images(epoch)
                self.save_models(epoch)

    def _save_snapshot(self, epoch, batch):
        """Save one label-0 sample to ``output_dir`` in its own, closed figure."""
        image = self.forward(torch.tensor([0], device=self.device))
        fig, ax = plt.subplots()
        ax.imshow(image)
        os.makedirs(self.output_dir, exist_ok=True)
        fig.savefig(os.path.join(self.output_dir, f'output_epoch_{epoch}_batch_{batch}.png'))
        plt.close(fig)  # avoid accumulating open figures during training

    def save_images(self, epoch):
        """Save a 5x5 grid of samples with random labels (0-9) to ``output_dir``."""
        with torch.no_grad():
            noise = self.generator.make_random_noise_vector(25).to(self.device)
            labels = torch.randint(0, 10, (25,)).to(self.device)
            fake_images = self.generator(noise, labels).cpu().numpy()
        fake_images = (fake_images + 1) / 2  # [-1, 1] -> [0, 1]

        fig, axes = plt.subplots(5, 5, figsize=(10, 10))
        for img, lbl, ax in zip(fake_images, labels.tolist(), axes.flat):
            ax.imshow(img.transpose(1, 2, 0))  # (C, H, W) -> (H, W, C)
            ax.axis('off')
            ax.set_title(f'Label: {lbl}')
        os.makedirs(self.output_dir, exist_ok=True)
        fig.savefig(os.path.join(self.output_dir, f'epoch_{epoch}.png'))
        plt.close(fig)

    def save_models(self, epoch):
        """Checkpoint both networks as ``<output_dir>/epoch_<epoch>_*.pth``."""
        os.makedirs(self.output_dir, exist_ok=True)
        self.save(os.path.join(self.output_dir, f'epoch_{epoch}'), save_generator=True, save_discriminator=True)

    def save(self, path, save_generator=True, save_discriminator=True):
        """Write state dicts to ``path + '_generator.pth'`` / ``path + '_discriminator.pth'``."""
        if save_generator:
            torch.save(self.generator.state_dict(), path + '_generator.pth')
        if save_discriminator:
            torch.save(self.discriminator.state_dict(), path + '_discriminator.pth')

    def load(self, path, load_generator=True, load_discriminator=True):
        """Load state dicts written by ``save``, mapping tensors onto ``self.device``."""
        if load_generator:
            self.generator.load_state_dict(torch.load(path + '_generator.pth', map_location=self.device))
        if load_discriminator:
            self.discriminator.load_state_dict(torch.load(path + '_discriminator.pth', map_location=self.device))


## Execution

In [None]:
# Paths and settings re-declared so this execution cell can run on its own.
train_dataset_path = 'data-students/TRAIN/'
IMG_WIDTH, IMG_HEIGHT = 75, 75
BATCH_SIZE = 16

def prepare_datasets(train_dataset_path, img_width, img_height, batch_size, valid_size=0.1, normalize=True):
    """Build stratified train/validation loaders from an ``ImageFolder`` directory.

    Args:
        train_dataset_path: Root directory containing one sub-folder per class.
        img_width: Target width every image is resized to.
        img_height: Target height every image is resized to.
        batch_size: Batch size of both loaders.
        valid_size: Fraction of samples held out for validation (default 0.1).
        normalize: If True (default), map pixel values from ``ToTensor``'s
            [0, 1] to [-1, 1] so real images share the range of the
            generator's Tanh output. Pass False for the previous [0, 1] tensors.

    Returns:
        ``(train_loader, valid_loader, targets)``. ``targets`` holds the class
        indices of the full dataset, before the split.
    """
    steps = [
        # torchvision's Resize expects (height, width).
        transforms.Resize((img_height, img_width)),
        transforms.ToTensor(),
    ]
    if normalize:
        # (x - 0.5) / 0.5 per channel. Assumes 3-channel input, which matches
        # ImageFolder's default RGB loader.
        steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    transform = transforms.Compose(steps)

    # Load dataset
    traffic_signals_dataset = datasets.ImageFolder(root=train_dataset_path, transform=transform)

    # Stratified split so both subsets keep the class proportions.
    train_idx, valid_idx = train_test_split(
        range(len(traffic_signals_dataset)),
        test_size=valid_size,
        shuffle=True,
        stratify=traffic_signals_dataset.targets,
    )

    train_subset = Subset(traffic_signals_dataset, train_idx)
    valid_subset = Subset(traffic_signals_dataset, valid_idx)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_subset, batch_size=batch_size, shuffle=False)

    return train_loader, valid_loader, traffic_signals_dataset.targets

def display_class_distribution(targets):
    """Print, for every distinct label in ``targets``, how many times it occurs."""
    class_counts = dict.fromkeys(set(targets), 0)
    for target in targets:
        class_counts[target] = class_counts[target] + 1
    print('Training class distribution:', class_counts)

def main():
    """Load the training data, report its class balance, then train the WGAN."""
    loaders = prepare_datasets(train_dataset_path, IMG_WIDTH, IMG_HEIGHT, BATCH_SIZE)
    train_loader, _valid_loader, training_targets = loaders  # validation split is unused here

    display_class_distribution(training_targets)

    training_config = dict(
        n_epochs=200,
        batch_size=BATCH_SIZE,
        n_critic=5,
        lr_generator=1e-4,
        lr_discriminator=1e-4,
    )
    wgan = WGAN()
    wgan.train(dataset=train_loader, **training_config)


if __name__ == '__main__':
    main()
