In [1]:
from torch import optim
import os
import torchvision.utils as vutils
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# Arguments
BATCH_SIZE = 64
Z_DIM = 100
LABEL_EMBED_SIZE = 5
NUM_CLASSES = 10
IMGS_TO_DISPLAY_PER_CLASS = 20
LOAD_MODEL = False

DB = 'SVHN'

CHANNELS = 3
EPOCHS = 100

# Directories for storing data, model and output samples
db_path = os.path.join('./data', DB)
os.makedirs(db_path, exist_ok=True)
model_path = os.path.join('./model', DB)
os.makedirs(model_path, exist_ok=True)
samples_path = os.path.join('./samples', DB)
os.makedirs(samples_path, exist_ok=True)

In [4]:
# Data loader
transform = transforms.Compose([transforms.Resize([32, 32]),
                                transforms.ToTensor(),
                                transforms.Normalize([0.5], [0.5])])


train_dataset = datasets.ImageFolder(root='TRAIN', transform=transform)

print(train_dataset)


data_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

print(data_loader)

#show labels of the dataset

print(train_dataset.classes)
print(len(train_dataset.classes))

Dataset ImageFolder
    Number of datapoints: 277
    Root location: TRAIN
    StandardTransform
Transform: Compose(
               Resize(size=[32, 32], interpolation=bilinear, max_size=None, antialias=warn)
               ToTensor()
               Normalize(mean=[0.5], std=[0.5])
           )
<torch.utils.data.dataloader.DataLoader object at 0x000002273B65C310>
['12', '13', '24', '38', '39', '44', '46', '49', '50', '6']
10


In [14]:
# Method for storing generated images
def generate_imgs(z, fixed_label, epoch=0):
    gen.eval()
    fake_imgs = gen(z, fixed_label)
    fake_imgs = (fake_imgs + 1) / 2
    fake_imgs_ = vutils.make_grid(fake_imgs, normalize=False, nrow=IMGS_TO_DISPLAY_PER_CLASS)
    vutils.save_image(fake_imgs_, os.path.join(samples_path, 'sample_' + str(epoch) + '.png'))


In [6]:
# Networks
def conv_block(c_in, c_out, k_size=4, stride=2, pad=1, use_bn=True, transpose=False):
    module = []
    if transpose:
        module.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, bias=not use_bn))
    else:
        module.append(nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=not use_bn))
    if use_bn:
        module.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*module)

In [7]:
class Generator(nn.Module):
    def __init__(self, z_dim=10, num_classes=10, label_embed_size=5, channels=3, conv_dim=64):
        super(Generator, self).__init__()
        self.label_embedding = nn.Embedding(num_classes, label_embed_size)
        self.tconv1 = conv_block(z_dim + label_embed_size, conv_dim * 4, pad=0, transpose=True)
        self.tconv2 = conv_block(conv_dim * 4, conv_dim * 2, transpose=True)
        self.tconv3 = conv_block(conv_dim * 2, conv_dim, transpose=True)
        self.tconv4 = conv_block(conv_dim, channels, transpose=True, use_bn=False)

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)

            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, label):
        x = x.reshape([x.shape[0], -1, 1, 1])
        label_embed = self.label_embedding(label)
        label_embed = label_embed.reshape([label_embed.shape[0], -1, 1, 1])
        x = torch.cat((x, label_embed), dim=1)
        x = F.relu(self.tconv1(x))
        x = F.relu(self.tconv2(x))
        x = F.relu(self.tconv3(x))
        x = torch.tanh(self.tconv4(x))
        return x


In [8]:
class Discriminator(nn.Module):
    def __init__(self, num_classes=10, channels=3, conv_dim=64):
        super(Discriminator, self).__init__()
        self.image_size = 32
        self.label_embedding = nn.Embedding(num_classes, self.image_size*self.image_size)
        self.conv1 = conv_block(channels + 1, conv_dim, use_bn=False)
        self.conv2 = conv_block(conv_dim, conv_dim * 2)
        self.conv3 = conv_block(conv_dim * 2, conv_dim * 4)
        self.conv4 = conv_block(conv_dim * 4, 1, k_size=4, stride=1, pad=0, use_bn=False)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0.0, 0.02)

            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, label):
        alpha = 0.2
        label_embed = self.label_embedding(label)
        label_embed = label_embed.reshape([label_embed.shape[0], 1, self.image_size, self.image_size])
        x = torch.cat((x, label_embed), dim=1)
        x = F.leaky_relu(self.conv1(x), alpha)
        x = F.leaky_relu(self.conv2(x), alpha)
        x = F.leaky_relu(self.conv3(x), alpha)
        x = torch.sigmoid(self.conv4(x))
        return x.squeeze()

In [9]:
gen = Generator(z_dim=Z_DIM, num_classes=NUM_CLASSES, label_embed_size=LABEL_EMBED_SIZE, channels=CHANNELS)
dis = Discriminator(num_classes=NUM_CLASSES, channels=CHANNELS)


# Teste unitário para verificar a saída do gerador
test_noise = torch.randn(1, Z_DIM)
test_label = torch.LongTensor([1])  # Exemplo de label
gen_output = gen(test_noise, test_label)
print("Output shape from generator:", gen_output.shape)

# Teste unitário para o discriminador
dis_output = dis(gen_output, test_label)
print("Output from discriminator:", dis_output.shape)


Output shape from generator: torch.Size([1, 3, 32, 32])
Output from discriminator: torch.Size([])


In [10]:
gen = Generator(z_dim=Z_DIM, num_classes=NUM_CLASSES, label_embed_size=LABEL_EMBED_SIZE, channels=CHANNELS)
dis = Discriminator(num_classes=NUM_CLASSES, channels=CHANNELS)

# Load previous model   
if LOAD_MODEL:
    gen.load_state_dict(torch.load(os.path.join(model_path, 'gen.pkl')))
    dis.load_state_dict(torch.load(os.path.join(model_path, 'dis.pkl')))
    
# Model Summary
print("------------------Generator------------------")
print(gen)
print("------------------Discriminator------------------")
print(dis)

# Define Optimizers
g_opt = optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=2e-5)
d_opt = optim.Adam(dis.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=2e-5)

# Loss functions
loss_fn = nn.BCELoss()

# Fix images for viz
fixed_z = torch.randn(IMGS_TO_DISPLAY_PER_CLASS*NUM_CLASSES, Z_DIM)
fixed_label = torch.arange(0, NUM_CLASSES)
fixed_label = torch.repeat_interleave(fixed_label, IMGS_TO_DISPLAY_PER_CLASS)


# GPU Compatibility
is_cuda = torch.cuda.is_available()
if is_cuda:
    gen, dis = gen.cuda(), dis.cuda()
    real_label, fake_label = real_label.cuda(), fake_label.cuda()
    fixed_z, fixed_label = fixed_z.cuda(), fixed_label.cuda()

total_iters = 0
max_iter = len(data_loader)

------------------Generator------------------
Generator(
  (label_embedding): Embedding(10, 5)
  (tconv1): Sequential(
    (0): ConvTranspose2d(105, 256, kernel_size=(4, 4), stride=(2, 2), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (tconv2): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (tconv3): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (tconv4): Sequential(
    (0): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
)
------------------Discriminator------------------
Discriminator(
  (label_embedding): Embedding(10, 1024)
  (conv1): Sequential(
    (0): Conv2d(4, 64, kernel_siz

In [11]:
# Training
for epoch in range(EPOCHS):
    gen.train()
    dis.train()

    for i, data in enumerate(data_loader):

        total_iters += 1

        x_real, x_label = data
        batch_size = x_real.size(0)  # Tamanho do batch atual

        # Ajustando o tamanho de z_fake para corresponder ao tamanho do batch atual
        z_fake = torch.randn(batch_size, Z_DIM)
        
        real_label = torch.ones(batch_size)
        fake_label = torch.zeros(batch_size)

        if is_cuda:
            x_real = x_real.cuda()
            x_label = x_label.cuda()
            z_fake = z_fake.cuda()
                
        # Generate fake data
        x_fake = gen(z_fake, x_label)

        # Train Discriminator
        fake_out = dis(x_fake.detach(), x_label)
        real_out = dis(x_real.detach(), x_label)
        
        d_loss = (loss_fn(fake_out, fake_label) + loss_fn(real_out, real_label)) / 2

        d_opt.zero_grad()
        d_loss.backward()
        d_opt.step()

        # Train Generator
        fake_out = dis(x_fake, x_label)
        g_loss = loss_fn(fake_out, real_label)

        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()

        if i % 50 == 0:
            print("Epoch: " + str(epoch + 1) + "/" + str(EPOCHS)
                  + "\titer: " + str(i) + "/" + str(max_iter)
                  + "\ttotal_iters: " + str(total_iters)
                  + "\td_loss:" + str(round(d_loss.item(), 4))
                  + "\tg_loss:" + str(round(g_loss.item(), 4))
                  )

    if (epoch + 1) % 5 == 0:
        torch.save(gen.state_dict(), os.path.join(model_path, 'gen.pkl'))
        torch.save(dis.state_dict(), os.path.join(model_path, 'dis.pkl'))

        # generate_imgs(fixed_z, fixed_label, epoch=epoch + 1)

Epoch: 1/100	iter: 0/5	total_iters: 1	d_loss:0.8353	g_loss:2.5427
Epoch: 2/100	iter: 0/5	total_iters: 6	d_loss:0.2418	g_loss:2.2637
Epoch: 3/100	iter: 0/5	total_iters: 11	d_loss:0.1678	g_loss:2.9562
Epoch: 4/100	iter: 0/5	total_iters: 16	d_loss:0.1194	g_loss:3.1491
Epoch: 5/100	iter: 0/5	total_iters: 21	d_loss:0.1049	g_loss:3.973
Epoch: 6/100	iter: 0/5	total_iters: 26	d_loss:0.1029	g_loss:3.4266
Epoch: 7/100	iter: 0/5	total_iters: 31	d_loss:0.4683	g_loss:6.3576
Epoch: 8/100	iter: 0/5	total_iters: 36	d_loss:0.1043	g_loss:3.8823
Epoch: 9/100	iter: 0/5	total_iters: 41	d_loss:0.1255	g_loss:3.812
Epoch: 10/100	iter: 0/5	total_iters: 46	d_loss:0.185	g_loss:3.9073
Epoch: 11/100	iter: 0/5	total_iters: 51	d_loss:0.1935	g_loss:3.9918
Epoch: 12/100	iter: 0/5	total_iters: 56	d_loss:0.1908	g_loss:3.221
Epoch: 13/100	iter: 0/5	total_iters: 61	d_loss:0.2245	g_loss:4.4102
Epoch: 14/100	iter: 0/5	total_iters: 66	d_loss:0.1508	g_loss:3.442
Epoch: 15/100	iter: 0/5	total_iters: 71	d_loss:0.2247	g_loss:5.4

In [12]:
def generate_imgs_for_label_after_training(z, label, gen_model, epoch=0):
    gen_model.eval()
    
    # Crie um tensor de labels com a label desejada
    fixed_label = torch.full((z.size(0),), label, dtype=torch.long)
    
    # Gere as imagens com a label específica
    with torch.no_grad():
        fake_imgs = gen_model(z, fixed_label)
    fake_imgs = (fake_imgs + 1) / 2
    
    # Crie a grade de imagens
    fake_imgs_grid = vutils.make_grid(fake_imgs, normalize=False, nrow=IMGS_TO_DISPLAY_PER_CLASS)
    
    # Imprima a imagem gerada
    print(fake_imgs_grid)


In [15]:
generate_imgs(fixed_z, fixed_label)

In [20]:
label_to_generate = 8
generate_imgs_for_label_after_training(fixed_z, label_to_generate, gen, epoch=epoch + 1)
