In [1]:
import torch
import os
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)

from utils.dataloader import InsectsDataset, ToTensor
from utils.imgs import show_tensor_images
from utils.read import read_config
from models.conditional_GAN import Generator, Discriminator, get_noise, get_one_hot_labels, combine_vectors, get_input_dimensions


In [2]:
# Read the project configuration once; every setting below is derived from it.
config = read_config()

# Dataset locations: the 'augmented' folder holds an index CSV and an image directory.
base_path = os.path.join(config['base_data']['dest_dir'], 'augmented')
csv_file, root_dir = (
    os.path.join(base_path, leaf) for leaf in ('data_info.csv', 'images')
)

# Conditional-GAN settings, pulled out of their config section into plain names.
cfg_model = config['conditional_GAN']
img_shape = tuple(cfg_model['img_shape'])
(n_classes, n_epochs, z_dim,
 display_step, batch_size, lr, device) = (
    cfg_model[key]
    for key in ('n_classes', 'n_epochs', 'z_dim',
                'display_step', 'batch_size', 'lr', 'device')
)

# BCEWithLogitsLoss applies the sigmoid itself — assumes the discriminator
# returns raw logits (TODO confirm in models.conditional_GAN).
criterion = nn.BCEWithLogitsLoss()

In [3]:
# Echo the conditional-GAN settings so the values used for this run are recorded with the notebook.
cfg_model

{'img_shape': [3, 64, 64],
 'n_classes': 7,
 'n_epochs': 200,
 'z_dim': 64,
 'display_step': 500,
 'batch_size': 128,
 'lr': 0.0002,
 'device': 'cpu'}

In [4]:
# Training data: the augmented image set, with one-hot class vectors requested
# from the dataset so the GAN can be conditioned on the class.
dataset = InsectsDataset(
    csv_file=csv_file,
    root_dir=root_dir,
    transform=ToTensor(),
    return_one_hot=True,
)
# Reshuffled every epoch; the final batch may be smaller than batch_size.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [5]:
# Input sizes of the two networks, derived from the noise size, the image shape
# and the number of classes.
generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, img_shape, n_classes)

# Build both networks on the target device, then give each its own Adam optimizer.
gen = Generator(input_dim=generator_input_dim).to(device)
disc = Discriminator(im_chan=discriminator_im_chan).to(device)
gen_opt, disc_opt = (
    torch.optim.Adam(net.parameters(), lr=lr) for net in (gen, disc)
)

def weights_init(m):
    """Re-initialise a single module in place; meant to be passed to ``Module.apply``.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights are drawn from a normal
      distribution with mean 0 and std 0.02; biases are left as constructed.
    * ``nn.BatchNorm2d`` with learnable parameters: scale weights are drawn from
      the same distribution and shift biases are set to 0.
    * Any other module, and BatchNorm layers built with ``affine=False``, are
      left untouched.

    Args:
        m: The module currently visited by ``apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.affine:
        # With affine=False the layer's weight and bias are None, so there is
        # nothing to initialise and the init helpers would fail on None.
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
# apply() runs weights_init on every submodule and returns the module it was
# called on, so gen and disc keep pointing at the same, now re-initialised, networks.
gen, disc = gen.apply(weights_init), disc.apply(weights_init)

In [None]:
cur_step = 0
generator_losses = []
discriminator_losses = []

# Spatial size of the per-class label planes appended to every image. Taken from
# the configured image shape so the checks below do not hard-code 64x64.
img_height, img_width = img_shape[1], img_shape[2]
# Number of consecutive steps averaged into one point of the loss curves.
step_bins = 20

for epoch in range(n_epochs):
    for batch in tqdm(dataloader):
        # Move the batch of real images to the training device.
        real = batch['image'].to(device)
        cur_batch_size = len(real)

        # One-hot class vectors for the batch. Only the singleton axis at position 1
        # is dropped: a bare squeeze() would also drop the batch axis whenever a
        # batch holds a single sample and break the indexing on the next line.
        # as_tensor avoids the copy-construct warning torch.tensor() emits when
        # the collated value is already a tensor.
        one_hot_labels = torch.as_tensor(batch['one_hot']).squeeze(1).to(device)
        # Add two trailing axes and tile them so each class becomes a full
        # (height, width) plane that can be stacked with the image channels.
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = image_one_hot_labels.repeat(1, 1, img_height, img_width)

        ### Update discriminator ###
        disc_opt.zero_grad()
        # Noise for exactly as many samples as the current batch holds.
        fake_noise = get_noise(cur_batch_size, z_dim, device=device)

        # Condition the generator by joining the noise with the class vectors.
        noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
        fake = gen(noise_and_labels)

        # Sanity checks on the generator input and output.
        assert len(fake) == len(real)
        assert tuple(noise_and_labels.shape) == (cur_batch_size, fake_noise.shape[1] + one_hot_labels.shape[1])
        assert tuple(fake.shape) == (cur_batch_size, *img_shape)

        # Discriminator inputs: images stacked with their label planes. The fakes
        # are detached so this update does not backpropagate into the generator.
        fake_image_and_labels = combine_vectors(fake.detach(), image_one_hot_labels)
        real_image_and_labels = combine_vectors(real, image_one_hot_labels)
        disc_fake_pred = disc(fake_image_and_labels)
        disc_real_pred = disc(real_image_and_labels)

        # Sanity checks on the discriminator inputs and predictions.
        assert tuple(fake_image_and_labels.shape) == (
            cur_batch_size, fake.shape[1] + image_one_hot_labels.shape[1], img_height, img_width)
        assert tuple(real_image_and_labels.shape) == (
            cur_batch_size, real.shape[1] + image_one_hot_labels.shape[1], img_height, img_width)
        assert len(disc_real_pred) == len(real)
        # Real and fake inputs must not be identical.
        assert torch.any(fake_image_and_labels != real_image_and_labels)
        assert tuple(fake_image_and_labels.shape) == tuple(real_image_and_labels.shape)
        assert tuple(disc_fake_pred.shape) == tuple(disc_real_pred.shape)

        # Fakes should be scored 0 and reals 1; the two halves are averaged.
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        # No retain_graph: this graph ends at the detached fakes, and the generator
        # step below runs its own forward pass through the discriminator.
        disc_loss.backward()
        disc_opt.step()

        # Keep track of the discriminator loss of every step.
        discriminator_losses.append(disc_loss.item())

        ### Update generator ###
        gen_opt.zero_grad()

        # Same fakes, this time attached to the generator graph.
        fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
        disc_fake_pred = disc(fake_image_and_labels)
        # The generator is rewarded when its fakes are scored as real.
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the generator loss of every step.
        generator_losses.append(gen_loss.item())

        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            disc_mean = sum(discriminator_losses[-display_step:]) / display_step
            print(f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            # Average the loss history in chunks of step_bins steps; the trailing
            # partial chunk is dropped so the history reshapes cleanly.
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Discriminator Loss"
            )
            plt.legend()
            plt.show()
        elif cur_step == 0:
            print("Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!")
        cur_step += 1

HBox(children=(FloatProgress(value=0.0, max=118.0), HTML(value='')))



Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!


In [8]:
# Separate dataset instance, built with the same arguments as the training
# dataset, for the inspection cells that follow.
ds = InsectsDataset(
    csv_file=csv_file,
    root_dir=root_dir,
    transform=ToTensor(),
    return_one_hot=True,
)

In [16]:
# Spot-check the one-hot label of a single sample; note it comes back as a
# 2-D (1, n_classes) array. Index 9000 looks like an arbitrary pick — NOTE(review): confirm.
ds[9000]['one_hot']

array([[1, 0, 0, 0, 0, 0, 0]])

In [7]:
# Table stored on the dataset object (presumably the parsed data_info.csv — verify in utils.dataloader).
# NOTE(review): this cell's execution count precedes the cell that creates `ds`; check it runs on a fresh kernel.
df = ds.df_data

In [47]:
import pandas as pd
# One indicator column per distinct value of the 'especie' column.
dummies = pd.get_dummies(df['especie'])
dummies

Unnamed: 0,Bemisia tabaci,Macrolophus pygmaeus,Nesidiocoris tenuis,brevicoryne brassicae,liriomyza huidobrensis,prodiplosis longifila,trips tabaci
0,0,0,0,0,1,0,0
1,0,0,0,0,1,0,0
2,0,0,0,0,1,0,0
3,0,0,0,0,1,0,0
4,0,0,0,0,1,0,0
...,...,...,...,...,...,...,...
14989,0,0,1,0,0,0,0
14990,0,0,1,0,0,0,0
14991,0,0,1,0,0,0,0
14992,0,0,1,0,0,0,0


In [48]:
from sklearn.preprocessing import LabelBinarizer


In [49]:
# Encoder that turns class names into one-hot rows.
y =LabelBinarizer()

In [51]:
# Learn the set of species names. fit() returns the binarizer itself, so `y`
# still refers to the same (now fitted) encoder.
y = y.fit(df['especie'])

In [62]:
# Encode the same label twice to inspect the layout of a multi-row result.
a = y.transform(['Macrolophus pygmaeus', 'Macrolophus pygmaeus'])

In [64]:
# The encoder returns a NumPy array, not a tensor.
type (a)

numpy.ndarray

In [65]:
# Show the encoded rows: one row per input label, one column per class.
a

array([[0, 1, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0, 0]])

In [71]:
# Convert the NumPy one-hot rows to a tensor.
# NOTE(review): rebinding `a` makes this cell non-idempotent — re-running it
# passes a tensor back into torch.tensor().
a = torch.tensor(a)

In [72]:
# The integer array converts to an int64 tensor, not a float one.
a.dtype

torch.int64

In [73]:
# Show the tensor version of the encoded rows.
a

tensor([[0, 1, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0]])

In [11]:
def calc_o(i, s, k, p=0):
    """Output length of a transposed convolution along one spatial axis.

    Evaluates ``s * (i - 1) + k - 2 * p``, which is the ConvTranspose2d size
    rule for dilation 1 and no output padding.

    Args:
        i: Input length.
        s: Stride.
        k: Kernel size.
        p: Padding applied on each side. The default of 0 gives the un-padded
            formula ``s * (i - 1) + k``.

    Returns:
        The output length.
    """
    return s * (i - 1) + k - 2 * p

In [20]:
# Chain the size rule s*(i-1)+k layer by layer, starting from a length-1 input.
# Arguments are (input size, stride, kernel size); the result of each layer
# feeds the next one.
# NOTE(review): a..e are throwaway globals, and `a` overwrites the tensor
# created in the label-encoding cells.
a = calc_o(1,2,3)
b = calc_o(a,2,3)
c = calc_o(b,2,3)
d = calc_o(c,2,4)
e = calc_o(d,2,2)
# Size after the fifth and last layer.
e

64

In [24]:
# Size after the fourth layer of the chain.
# NOTE(review): `d` is assigned by both size-chain cells, so the value shown
# depends on which of them ran last.
d

32

In [15]:
def calc_o_conv(i, s, k, p=0):
    """Output length of a strided convolution along one spatial axis.

    Evaluates ``(i + 2 * p - k) // s + 1``, which is the Conv2d size rule for
    dilation 1.

    Args:
        i: Input length.
        s: Stride.
        k: Kernel size.
        p: Padding applied on each side. The default of 0 gives the un-padded
            formula ``(i - k) // s + 1``.

    Returns:
        The output length.
    """
    return (i + 2 * p - k) // s + 1

In [22]:
# Chain the size rule (i-k)//s + 1 layer by layer, starting from a length-64 input.
# Arguments are (input size, stride, kernel size); the last layer uses stride 3.
# NOTE(review): reuses the globals a..d from the cells above.
a = calc_o_conv(64,2,4)
b = calc_o_conv(a,2,4)
c = calc_o_conv(b,2,4)
d = calc_o_conv(c,3,4)
# Size after the fourth and last layer.
d

1

In [26]:
# Re-display the final size of the chain.
# NOTE(review): `d` is assigned by both size-chain cells, so the value shown
# depends on which of them ran last.
d

1