# In [None]:
import torch
from torch import nn, autograd
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import json

# Run on a CUDA GPU when one is available; otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Generator architecture
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to a flattened image.

    The network is a stack of Linear + LeakyReLU layers (256, 512, 1024 units)
    followed by a Linear projection to 3*64*64 values and a Tanh, so every
    output element lies in [-1, 1].

    Args:
        z_dim: Size of the latent input vector.
    """

    def __init__(self, z_dim):
        super().__init__()
        widths = (z_dim, 256, 512, 1024)
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.LeakyReLU())
        # Output layer: one value per pixel/channel of a 64x64 RGB image
        layers.append(nn.Linear(widths[-1], 3 * 64 * 64))
        layers.append(nn.Tanh())
        self.gen = nn.Sequential(*layers)

    def forward(self, input):
        """Return the generated batch, shape (batch, 3*64*64)."""
        return self.gen(input)

# Discriminator (critic) architecture
class Discriminator(nn.Module):
    """Fully connected critic scoring flattened 64x64 RGB images.

    Linear + LeakyReLU layers shrink the 3*64*64 input through 1024, 512 and
    256 units; a final Linear layer emits one unbounded score per sample
    (there is no sigmoid on the output).
    """

    def __init__(self):
        super().__init__()
        sizes = (3 * 64 * 64, 1024, 512, 256)
        blocks = []
        for n_in, n_out in zip(sizes, sizes[1:]):
            blocks += [nn.Linear(n_in, n_out), nn.LeakyReLU()]
        self.disc = nn.Sequential(*blocks, nn.Linear(sizes[-1], 1))

    def forward(self, input):
        """Return the critic score for each sample, shape (batch, 1)."""
        return self.disc(input)
# Report which device was selected for training
print(f"Using device: {device}")

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculate the gradient penalty loss for WGAN-GP.

    Random per-sample interpolations between real and fake samples are fed
    through ``D``; the penalty is the mean squared deviation of the L2 norm of
    ``D``'s gradient (with respect to those interpolations) from 1.

    Args:
        D: Callable critic taking a batch of samples and returning one score
            (or more) per sample.
        real_samples: Batch of real samples; the first dimension is the batch.
            Any number of trailing dimensions is supported.
        fake_samples: Batch of generated samples with the same shape as
            ``real_samples``.

    Returns:
        Scalar tensor holding the penalty. It is built with
        ``create_graph=True`` so it can be back-propagated into ``D``'s
        parameters.
    """
    # The penalty must only train D, so cut any graph attached to the inputs.
    real_samples = real_samples.detach()
    fake_samples = fake_samples.detach()

    n_samples = real_samples.size(0)
    # One mixing coefficient per sample, shaped (N, 1, ..., 1) so it broadcasts
    # over samples of any rank (flattened vectors as well as image tensors).
    alpha_shape = (n_samples,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)

    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)

    # Seed the backward pass with ones matching D's output (shape, dtype, device).
    grad_outputs = torch.ones_like(d_interpolates)
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    # reshape (not view) so non-contiguous gradients are handled as well.
    gradients = gradients.reshape(n_samples, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# Hyperparameters
z_dim = 100           # size of the generator's latent input vector
learning_rate = 1e-4
batch_size = 64
num_epochs = 500
lambda_gp = 10        # weight of the gradient-penalty term in the discriminator loss

# Image preprocessing: resize and center-crop to 64x64, convert to a tensor,
# then normalize every RGB channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Load the image dataset and wrap it in a shuffling, batched loader.
# num_workers is the number of data-loading worker processes; tune as needed.
num_workers = 4
dataset = datasets.ImageFolder('/kaggle/input/art-portraits', transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

# Instantiate the generator and the discriminator on the selected device.
gen = Generator(z_dim).to(device)
disc = Discriminator().to(device)
# To resume from saved weights, load them here, e.g.:
# gen.load_state_dict(torch.load("/kaggle/input/epoch2016/gerador2016.pth"))
# disc.load_state_dict(torch.load("/kaggle/input/epoch2016/discriminador2016.pth"))

# Adam optimizers with beta1 = 0.0 and beta2 = 0.9 for both networks.
gen_opt = torch.optim.Adam(
    gen.parameters(), lr=learning_rate, betas=(0.0, 0.9)
)
disc_opt = torch.optim.Adam(
    disc.parameters(), lr=learning_rate, betas=(0.0, 0.9)
)

# Per-epoch loss history (one entry appended per epoch by the training loop).
d_losses = []
g_losses = []

# Training loop (WGAN-GP): the discriminator is updated on every batch, the
# generator only on every n_critic-th batch of an epoch.

# Number of discriminator updates per generator update.
n_critic = 5

# With fewer than n_critic batches per epoch the generator would never be
# updated and there would be no generator loss to log, so fail early with a
# clear message instead of a NameError at the end of the first epoch.
if len(dataloader) < n_critic:
    raise ValueError(
        f"dataloader yields {len(dataloader)} batches per epoch, but at least "
        f"{n_critic} are required for the generator to be trained"
    )

for epoch in range(num_epochs):
    # The DataLoader configured during setup is reused as-is: iterating it again
    # starts a fresh (reshuffled) pass and keeps its num_workers setting.
    for batch_idx, (real_samples, _) in enumerate(dataloader, start=1):
        # Flatten each image to a vector and move the batch to the training device
        real_samples = real_samples.view(real_samples.size(0), -1).to(device)
        real_batch_size = real_samples.size(0)  # the last batch may be smaller

        # ---- Discriminator update ----
        latent_space_samples = torch.randn((real_batch_size, z_dim), device=device)
        # The generator is not trained in this step, so no graph is needed for it
        with torch.no_grad():
            generated_samples = gen(latent_space_samples)

        disc_opt.zero_grad()

        real_validity = disc(real_samples)
        fake_validity = disc(generated_samples)
        gradient_penalty = compute_gradient_penalty(disc, real_samples, generated_samples)
        # Wasserstein critic loss plus the weighted gradient penalty
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

        d_loss.backward()
        disc_opt.step()

        # ---- Generator update (every n_critic-th batch) ----
        if batch_idx % n_critic == 0:
            gen_opt.zero_grad()
            # Re-generate from the same latent batch, this time tracking gradients
            gen_samples = gen(latent_space_samples)
            # The generator is rewarded for high discriminator scores
            g_loss = -torch.mean(disc(gen_samples))
            g_loss.backward()
            gen_opt.step()

    # Log and record the most recent discriminator / generator losses of the epoch
    d_loss_value = d_loss.item()
    g_loss_value = g_loss.item()
    print(f"Epoch: {epoch} Loss D.: {d_loss_value} Loss G.: {g_loss_value}")

    d_losses.append(d_loss_value)
    g_losses.append(g_loss_value)

    # Checkpoint the weights and the loss history every 5 epochs.
    # File names carry the zero-based epoch index.
    if (epoch + 1) % 5 == 0:
        torch.save(gen.state_dict(), f"gerador{epoch}.pth")
        torch.save(disc.state_dict(), f"discriminador{epoch}.pth")
        print(f"-> Guardei até á epoch:{epoch} ")
        with open(f"Dloss{epoch}.json", "w") as d_loss_file:
            json.dump(d_losses, d_loss_file)
        with open(f"Gloss{epoch}.json", "w") as g_loss_file:
            json.dump(g_losses, g_loss_file)
        

# Collect the per-epoch losses into a DataFrame (one column per network)
history = {
    'Loss D': d_losses,
    'Loss G': g_losses,
}
history_df = pd.DataFrame(history)

# Draw both loss curves on a wide figure
fig = plt.figure(figsize=(15, 4))
ax = sns.lineplot(data=history_df)
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.set_title("Model Learning Curve")
plt.show()

def test_generator(gen, z_dim, num_samples):
    """Sample images from the generator and save them as a grid image.

    Draws ``num_samples`` latent vectors, runs them through ``gen`` without
    tracking gradients, and writes the resulting images to "GAN_images.png"
    in a near-square grid (5x5 for 25 samples). Unused grid cells stay blank.

    Args:
        gen: Generator module; its output is reshaped to (N, 3, 64, 64).
        z_dim: Size of the latent vector expected by ``gen``.
        num_samples: Number of images to generate; must be at least 1.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    # Sample on whatever device the generator's weights live on (CPU if it has none)
    first_param = next(gen.parameters(), None)
    sample_device = first_param.device if first_param is not None else torch.device("cpu")
    z = torch.randn(num_samples, z_dim, device=sample_device)

    # Generate in eval mode, then restore the mode the caller left the model in
    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            gen_imgs = gen(z).view(-1, 3, 64, 64).cpu().numpy()
    finally:
        gen.train(was_training)

    # Channels-last layout for imshow, then map values from [-1, 1] to [0, 1]
    # (assumes the generator output is in [-1, 1], e.g. a Tanh output layer).
    gen_imgs = np.transpose(gen_imgs, (0, 2, 3, 1))
    gen_imgs = np.clip((gen_imgs + 1) / 2, 0.0, 1.0)

    # Smallest near-square grid that fits every sample
    n_cols = int(np.ceil(np.sqrt(num_samples)))
    n_rows = -(-num_samples // n_cols)  # ceiling division
    # squeeze=False keeps axs two-dimensional even for a single row or column
    fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
    for idx, ax in enumerate(axs.flat):
        if idx < num_samples:
            ax.imshow(gen_imgs[idx])
        ax.axis('off')
    fig.savefig("GAN_images.png")
    plt.close(fig)

# Save a grid of 25 generated samples from the trained generator
test_generator(gen, z_dim, num_samples=25)