In [2]:
import torch.nn as nn
import numpy as np
import torch
import yaml
from torch.autograd import Variable
from torch.utils.data import DataLoader
from txt2image_dataset import Text2ImageDataset
from utils import Utils, Logger
from PIL import Image
import os

### Dataset

In [3]:
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

config

{'birds_images_path': 'data/cvpr2016_cub/images/',
 'birds_embedding_path': 'data/cub_icml/',
 'birds_text_path': 'data/cvpr2016_cub/cvpr2016_cub/text_c10/',
 'birds_dataset_path': 'data/cvpr2016_cub/text2image/birds.hdf5',
 'val_split_path': 'data/cvpr2016_cub/valclasses.txt',
 'train_split_path': 'data/cvpr2016_cub/trainclasses.txt',
 'test_split_path': 'data/cvpr2016_cub/testclasses.txt',
 'flowers_images_path': 'data/cvpr2016_flowers/images/',
 'flowers_embedding_path': 'data/flowers_icml/',
 'flowers_text_path': 'data/cvpr2016_flowers/text_c10/',
 'flowers_dataset_path': 'data/cvpr2016_flowers/text2image/flowers.hdf5',
 'flowers_val_split_path': 'data/cvpr2016_flowers/valclasses.txt',
 'flowers_train_split_path': 'data/cvpr2016_flowers/trainclasses.txt',
 'flowers_test_split_path': 'data/cvpr2016_flowers/testclasses.txt'}

### Vanilla GAN Generator and Discriminator

In [18]:
from models import gan

In [19]:
generator = gan.generator()
discriminator = gan.discriminator()
print(generator, discriminator)

generator(
  (netG): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
) 

### Load pretrained Generator and Discriminator if they exist

In [50]:
pre_trained_disc = 'results/gan/checkpoints_vanilla_gan/disc_50.pth'
pre_trained_gen = 'results/gan/checkpoints_vanilla_gan/gen_50.pth'

## initialize the generator and discriminator ##
generator = generator.cuda()
discriminator = discriminator.cuda()
################################################

def load_gen_and_disc(pre_trained_disc, pre_trained_gen):
    epochs_trained = int(pre_trained_disc.split("/")[-1].split('_')[-1].split('.')[0])
    generator.load_state_dict(torch.load(pre_trained_gen))
    discriminator.load_state_dict(torch.load(pre_trained_disc))
    return generator, discriminator, epochs_trained

epochs_trained = 0

if pre_trained_disc and pre_trained_gen:
    generator, discriminator, epochs_trained = load_gen_and_disc(pre_trained_disc, 
                                                 pre_trained_gen)
else:
    generator.apply(Utils.weights_init)
    discriminator.apply(Utils.weights_init)

### Training a vanilla GAN

In [51]:
noise_dim = 100
batch_size = 64
num_workers = 1
lr = 0.0002
beta1 = 0.5
num_epochs = 100
l1_coef = 50
l2_coef = 200

vis_screen = 'vanilla_gan'
checkpoints_path = 'checkpoints_vanilla_gan'
save_path = 'results/gan'
logger = Logger(vis_screen)

Setting up a new session...


In [52]:
criterion = nn.BCELoss()
l2_loss = nn.MSELoss()
l1_loss = nn.L1Loss()

optimG = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimD = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

In [53]:
def train_vanilla_gan(ds='flowers'):

    if ds == 'birds':
        dataset = Text2ImageDataset(config['birds_dataset_path'], split=0)
    if ds == 'flowers':
        dataset = Text2ImageDataset(config['flowers_dataset_path'], split=0)

    data_loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True)

    iteration = 0
    print("num epochs", epochs, "num iterations", len(data_loader), 'batch_size', batch_size)
    
    for epoch in range(epochs_trained, num_epochs+1):
        for sample in data_loader:
            iteration += 1
            right_images = sample['right_images']

            right_images = Variable(right_images.float()).cuda()

            real_labels = torch.ones(right_images.size(0))
            fake_labels = torch.zeros(right_images.size(0))

            smoothed_real_labels = torch.FloatTensor(Utils.smooth_label(real_labels.numpy(), -0.1))

            real_labels = Variable(real_labels).cuda()
            smoothed_real_labels = Variable(smoothed_real_labels).cuda()
            fake_labels = Variable(fake_labels).cuda()

            # Train the discriminator
            discriminator.zero_grad()
            outputs, activation_real = discriminator(right_images)
            # print(f"outputs shape {outputs.shape}, activation shape {activation_real.shape}")
            real_loss = criterion(outputs, smoothed_real_labels)
            real_score = outputs

            noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)
            # print(f"noise shape {noise.shape}")
            fake_images = generator(noise)
            # print(f"fake_images shape {fake_images.shape}")
            outputs, _ = discriminator(fake_images)
            fake_loss = criterion(outputs, fake_labels)
            fake_score = outputs

            d_loss = real_loss + fake_loss

            d_loss.backward()
            optimD.step()

            # Train the generator
            generator.zero_grad()

            noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)

            fake_images = generator(noise)

            outputs, activation_fake = discriminator(fake_images)
            _, activation_real = discriminator(right_images)

            activation_fake = torch.mean(activation_fake, 0)
            activation_real = torch.mean(activation_real, 0)

            g_loss = criterion(outputs, real_labels) \
                     + l2_coef * l2_loss(activation_fake, activation_real.detach()) \
                     + l1_coef * l1_loss(fake_images, right_images)

            g_loss.backward()
            optimG.step()
            
            if iteration % 5 == 0:
                logger.log_iteration_gan(epoch, d_loss, g_loss, real_score, fake_score)
                logger.draw(right_images, fake_images)

        logger.plot_epoch_w_scores(iteration)
        
        if epoch % 10 == 0:
            Utils.save_checkpoint(discriminator, generator, save_path, checkpoints_path, epoch)



In [None]:
train_vanilla_gan()

num epochs 50 num iterations 460 batch_size 64
Epoch: 50, d_loss= 0.411736, g_loss= 31.944408, D(X)= 0.923830, D(G(X))= 0.038027
Epoch: 50, d_loss= 0.396288, g_loss= 31.168949, D(X)= 0.892885, D(G(X))= 0.027885
Epoch: 50, d_loss= 0.379001, g_loss= 31.073532, D(X)= 0.883514, D(G(X))= 0.017081
Epoch: 50, d_loss= 0.391611, g_loss= 30.228327, D(X)= 0.835360, D(G(X))= 0.016893
Epoch: 50, d_loss= 0.415474, g_loss= 33.047977, D(X)= 0.941548, D(G(X))= 0.043398
Epoch: 50, d_loss= 0.522161, g_loss= 30.443277, D(X)= 0.947663, D(G(X))= 0.111672
Epoch: 50, d_loss= 0.440154, g_loss= 30.917885, D(X)= 0.794234, D(G(X))= 0.013487
Epoch: 50, d_loss= 0.392846, g_loss= 29.850954, D(X)= 0.903104, D(G(X))= 0.031707
Epoch: 50, d_loss= 0.384779, g_loss= 31.357220, D(X)= 0.875982, D(G(X))= 0.017722
Epoch: 50, d_loss= 0.384468, g_loss= 31.439552, D(X)= 0.856321, D(G(X))= 0.015384
Epoch: 50, d_loss= 0.447081, g_loss= 30.668814, D(X)= 0.772853, D(G(X))= 0.009365
Epoch: 50, d_loss= 0.380179, g_loss= 31.589989, D(X

Epoch: 51, d_loss= 0.395680, g_loss= 30.926514, D(X)= 0.899765, D(G(X))= 0.026821
Epoch: 51, d_loss= 0.377841, g_loss= 32.351639, D(X)= 0.882628, D(G(X))= 0.010979
Epoch: 51, d_loss= 0.403754, g_loss= 32.404549, D(X)= 0.829021, D(G(X))= 0.009282
Epoch: 51, d_loss= 0.391081, g_loss= 32.021221, D(X)= 0.931565, D(G(X))= 0.025638
Epoch: 51, d_loss= 0.383709, g_loss= 30.997368, D(X)= 0.942376, D(G(X))= 0.011630
Epoch: 51, d_loss= 0.388576, g_loss= 31.433226, D(X)= 0.912902, D(G(X))= 0.024296
Epoch: 51, d_loss= 0.364634, g_loss= 31.731073, D(X)= 0.912672, D(G(X))= 0.013858
Epoch: 51, d_loss= 0.380405, g_loss= 31.644180, D(X)= 0.937102, D(G(X))= 0.022170
Epoch: 51, d_loss= 0.408244, g_loss= 31.763086, D(X)= 0.956415, D(G(X))= 0.028758
Epoch: 51, d_loss= 0.391398, g_loss= 32.415588, D(X)= 0.924793, D(G(X))= 0.022891
Epoch: 51, d_loss= 0.382230, g_loss= 32.008438, D(X)= 0.850465, D(G(X))= 0.010153
Epoch: 51, d_loss= 0.381336, g_loss= 30.307907, D(X)= 0.896832, D(G(X))= 0.016780
Epoch: 51, d_los

Epoch: 52, d_loss= 0.394758, g_loss= 30.577419, D(X)= 0.908414, D(G(X))= 0.041539
Epoch: 52, d_loss= 0.424721, g_loss= 30.354744, D(X)= 0.790047, D(G(X))= 0.017012
Epoch: 52, d_loss= 0.422858, g_loss= 31.194889, D(X)= 0.873650, D(G(X))= 0.055780
Epoch: 52, d_loss= 0.407079, g_loss= 30.582806, D(X)= 0.914809, D(G(X))= 0.055298
Epoch: 52, d_loss= 0.394005, g_loss= 31.246262, D(X)= 0.859686, D(G(X))= 0.026433
Epoch: 52, d_loss= 0.377051, g_loss= 32.829109, D(X)= 0.866157, D(G(X))= 0.015096
Epoch: 52, d_loss= 0.455378, g_loss= 30.216137, D(X)= 0.795482, D(G(X))= 0.015604
Epoch: 52, d_loss= 0.414694, g_loss= 29.624636, D(X)= 0.911555, D(G(X))= 0.047976
Epoch: 52, d_loss= 0.431509, g_loss= 31.863228, D(X)= 0.905803, D(G(X))= 0.059725
Epoch: 52, d_loss= 0.383335, g_loss= 31.803410, D(X)= 0.845928, D(G(X))= 0.014220
Epoch: 52, d_loss= 0.381881, g_loss= 32.177902, D(X)= 0.911798, D(G(X))= 0.020952
Epoch: 52, d_loss= 0.390630, g_loss= 29.871265, D(X)= 0.869253, D(G(X))= 0.033932
Epoch: 52, d_los

Epoch: 53, d_loss= 0.350198, g_loss= 33.811302, D(X)= 0.902558, D(G(X))= 0.005154
Epoch: 53, d_loss= 0.357271, g_loss= 34.192856, D(X)= 0.879550, D(G(X))= 0.010413
Epoch: 53, d_loss= 0.396236, g_loss= 31.975424, D(X)= 0.825891, D(G(X))= 0.009633
Epoch: 53, d_loss= 0.409421, g_loss= 30.345968, D(X)= 0.781010, D(G(X))= 0.005264
Epoch: 53, d_loss= 0.400574, g_loss= 30.245157, D(X)= 0.943393, D(G(X))= 0.036676
Epoch: 53, d_loss= 0.366590, g_loss= 31.081520, D(X)= 0.838853, D(G(X))= 0.008790
Epoch: 53, d_loss= 0.365438, g_loss= 31.575748, D(X)= 0.923135, D(G(X))= 0.017847
Epoch: 53, d_loss= 0.364286, g_loss= 32.671787, D(X)= 0.881411, D(G(X))= 0.007058
Epoch: 53, d_loss= 0.355703, g_loss= 32.021133, D(X)= 0.919991, D(G(X))= 0.010467
Epoch: 53, d_loss= 0.406852, g_loss= 32.652710, D(X)= 0.946654, D(G(X))= 0.038275
Epoch: 53, d_loss= 0.438168, g_loss= 32.909138, D(X)= 0.968262, D(G(X))= 0.034783
Epoch: 53, d_loss= 0.370229, g_loss= 30.988680, D(X)= 0.928304, D(G(X))= 0.016675
Epoch: 53, d_los

Epoch: 54, d_loss= 0.358600, g_loss= 31.859898, D(X)= 0.895841, D(G(X))= 0.012143
Epoch: 54, d_loss= 0.389648, g_loss= 32.300034, D(X)= 0.950862, D(G(X))= 0.016329
Epoch: 54, d_loss= 0.376043, g_loss= 31.643848, D(X)= 0.933919, D(G(X))= 0.015965
Epoch: 54, d_loss= 0.423722, g_loss= 30.685696, D(X)= 0.771903, D(G(X))= 0.007729
Epoch: 54, d_loss= 0.357179, g_loss= 31.538078, D(X)= 0.870209, D(G(X))= 0.008940
Epoch: 54, d_loss= 0.392465, g_loss= 31.003199, D(X)= 0.833917, D(G(X))= 0.013068
Epoch: 54, d_loss= 0.370336, g_loss= 31.949642, D(X)= 0.874932, D(G(X))= 0.014379
Epoch: 54, d_loss= 0.396635, g_loss= 31.643984, D(X)= 0.922637, D(G(X))= 0.022948
Epoch: 54, d_loss= 0.384887, g_loss= 31.325624, D(X)= 0.832465, D(G(X))= 0.009545
Epoch: 54, d_loss= 0.369194, g_loss= 32.179817, D(X)= 0.893468, D(G(X))= 0.010731
Epoch: 54, d_loss= 0.417986, g_loss= 31.378521, D(X)= 0.953203, D(G(X))= 0.044576
Epoch: 54, d_loss= 0.376784, g_loss= 32.488777, D(X)= 0.851982, D(G(X))= 0.009066
Epoch: 54, d_los

Epoch: 55, d_loss= 0.401568, g_loss= 30.794741, D(X)= 0.809545, D(G(X))= 0.008090
Epoch: 55, d_loss= 0.354790, g_loss= 31.852953, D(X)= 0.896666, D(G(X))= 0.011212
Epoch: 55, d_loss= 0.367106, g_loss= 31.396816, D(X)= 0.895716, D(G(X))= 0.010531
Epoch: 55, d_loss= 0.380325, g_loss= 31.078012, D(X)= 0.942717, D(G(X))= 0.014320
Epoch: 55, d_loss= 0.359084, g_loss= 33.234989, D(X)= 0.891146, D(G(X))= 0.015080
Epoch: 55, d_loss= 0.435136, g_loss= 31.043936, D(X)= 0.898700, D(G(X))= 0.052782
Epoch: 55, d_loss= 0.391497, g_loss= 30.625065, D(X)= 0.834536, D(G(X))= 0.019749
Epoch: 55, d_loss= 0.456956, g_loss= 31.532259, D(X)= 0.947765, D(G(X))= 0.042338
Epoch: 55, d_loss= 0.386403, g_loss= 32.832542, D(X)= 0.938954, D(G(X))= 0.022708
Epoch: 55, d_loss= 0.376872, g_loss= 31.275614, D(X)= 0.940111, D(G(X))= 0.013575
Epoch: 55, d_loss= 0.371462, g_loss= 31.405006, D(X)= 0.910024, D(G(X))= 0.008792
Epoch: 55, d_loss= 0.377852, g_loss= 32.992245, D(X)= 0.922612, D(G(X))= 0.010417
Epoch: 55, d_los

Epoch: 56, d_loss= 0.361181, g_loss= 31.866756, D(X)= 0.890747, D(G(X))= 0.012238
Epoch: 56, d_loss= 0.409147, g_loss= 30.708654, D(X)= 0.826685, D(G(X))= 0.018211
Epoch: 56, d_loss= 0.446296, g_loss= 28.996819, D(X)= 0.772749, D(G(X))= 0.020713
Epoch: 56, d_loss= 0.413746, g_loss= 31.290905, D(X)= 0.946766, D(G(X))= 0.041754
Epoch: 56, d_loss= 0.395574, g_loss= 31.453518, D(X)= 0.876086, D(G(X))= 0.023741
Epoch: 56, d_loss= 0.402008, g_loss= 30.540464, D(X)= 0.837904, D(G(X))= 0.017976
Epoch: 56, d_loss= 0.372009, g_loss= 30.405197, D(X)= 0.890824, D(G(X))= 0.025718
Epoch: 56, d_loss= 0.381518, g_loss= 31.734444, D(X)= 0.939710, D(G(X))= 0.024398
Epoch: 56, d_loss= 0.395151, g_loss= 30.231722, D(X)= 0.844850, D(G(X))= 0.019270
Epoch: 56, d_loss= 0.371487, g_loss= 31.903238, D(X)= 0.857175, D(G(X))= 0.011640
Epoch: 56, d_loss= 0.369467, g_loss= 31.610317, D(X)= 0.920679, D(G(X))= 0.021701
Epoch: 56, d_loss= 0.385056, g_loss= 31.266138, D(X)= 0.913309, D(G(X))= 0.027193
Epoch: 56, d_los

Epoch: 57, d_loss= 0.372570, g_loss= 32.371952, D(X)= 0.863365, D(G(X))= 0.012605
Epoch: 57, d_loss= 0.357756, g_loss= 33.043312, D(X)= 0.873319, D(G(X))= 0.005980
Epoch: 57, d_loss= 0.349242, g_loss= 30.649353, D(X)= 0.890469, D(G(X))= 0.009744
Epoch: 57, d_loss= 0.355557, g_loss= 30.648693, D(X)= 0.899720, D(G(X))= 0.010942
Epoch: 57, d_loss= 0.370628, g_loss= 31.419792, D(X)= 0.938160, D(G(X))= 0.017203
Epoch: 57, d_loss= 0.361535, g_loss= 31.166656, D(X)= 0.861136, D(G(X))= 0.008520
Epoch: 57, d_loss= 0.421727, g_loss= 30.389904, D(X)= 0.759876, D(G(X))= 0.004373
Epoch: 57, d_loss= 0.359356, g_loss= 30.842251, D(X)= 0.876068, D(G(X))= 0.012842
Epoch: 57, d_loss= 0.355565, g_loss= 32.805386, D(X)= 0.895683, D(G(X))= 0.007186
Epoch: 57, d_loss= 0.369874, g_loss= 29.983219, D(X)= 0.846667, D(G(X))= 0.008107
Epoch: 57, d_loss= 0.393019, g_loss= 29.940218, D(X)= 0.945918, D(G(X))= 0.030641
Epoch: 57, d_loss= 0.356457, g_loss= 32.167786, D(X)= 0.863840, D(G(X))= 0.007591
Epoch: 57, d_los

In [None]:
def predict_vanilla_gan(ds='flowers'):
    
    if ds == 'birds':
        dataset = Text2ImageDataset(config['birds_dataset_path'], split=2)
    if ds == 'flowers':
        dataset = Text2ImageDataset(config['flowers_dataset_path'], split=2)

    data_loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False)
    
    for sample in data_loader:
        right_images = sample['right_images']
        right_embed = sample['right_embed']
        txt = sample['txt']

        if not os.path.exists('results/{0}'.format(save_path)):
            os.makedirs('results/{0}'.format(save_path))

        right_images = Variable(right_images.float()).cuda()
        right_embed = Variable(right_embed.float()).cuda()

        # Train the generator
        noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
        noise = noise.view(noise.size(0), 100, 1, 1)
        fake_images = generator(noise)

        logger.draw(right_images, fake_images)

        for image, t in zip(fake_images, txt):
            im = Image.fromarray(image.data.mul_(127.5).add_(127.5).byte().permute(1, 2, 0).cpu().numpy())
            t = t.strip()
            im.save('results/generated_images/{}.jpg'.format(t.replace(" ", "_")[:15]))
            print(t)

## Train a GAN CLS

In [None]:
noise_dim = 100
batch_size = 64
num_workers = 1
lr = 0.0002
beta1 = 0.5
num_epochs = 100
l1_coef = 50
l2_coef = 200

vis_screen = 'gan_cls'
checkpoints_path = 'checkpoints_gan_cls'
save_path = 'results'
logger = Logger(vis_screen)

In [None]:
from models.gan_cls import generator, discriminator

generator = generator().cuda()
discriminator = discriminator().cuda()

pre_trained_disc = None
pre_trained_gen = None

if pre_trained_disc:
    discriminator.load_state_dict(torch.load(pre_trained_disc))
else:
    discriminator.apply(Utils.weights_init)

if pre_trained_gen:
    generator.load_state_dict(torch.load(pre_trained_gen))
else:
    generator.apply(Utils.weights_init)

optimD = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
optimG = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))


In [None]:
def train_gan_cls(ds='flowers', cls=True, pre_trained_disc=None, pre_trained_gen=None):
    criterion = nn.BCELoss()
    l2_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()
    iteration = 0
    
    if ds == 'birds':
        dataset = Text2ImageDataset(config['birds_dataset_path'], split=0)
    if ds == 'flowers':
        dataset = Text2ImageDataset(config['flowers_dataset_path'], split=0)

    data_loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True)

    iteration = 0
    
    print("num iterations", len(data_loader), 'batch_size', batch_size)

    for epoch in range(num_epochs):
        for sample in data_loader:
            iteration += 1
            right_images = sample['right_images']
            right_embed = sample['right_embed']
            wrong_images = sample['wrong_images']

            right_images = Variable(right_images.float()).cuda()
            right_embed = Variable(right_embed.float()).cuda()
            wrong_images = Variable(wrong_images.float()).cuda()

            real_labels = torch.ones(right_images.size(0))
            fake_labels = torch.zeros(right_images.size(0))

            # ======== One sided label smoothing ==========
            # Helps preventing the discriminator from overpowering the
            # generator adding penalty when the discriminator is too confident
            # =============================================
            smoothed_real_labels = torch.FloatTensor(Utils.smooth_label(real_labels.numpy(), -0.1))

            real_labels = Variable(real_labels).cuda()
            smoothed_real_labels = Variable(smoothed_real_labels).cuda()
            fake_labels = Variable(fake_labels).cuda()

            # Train the discriminator
            discriminator.zero_grad()
            outputs, activation_real = discriminator(right_images, right_embed)
            real_loss = criterion(outputs, smoothed_real_labels)
            real_score = outputs

            if cls:
                outputs, _ = discriminator(wrong_images, right_embed)
                wrong_loss = criterion(outputs, fake_labels)
                wrong_score = outputs

            noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = generator(right_embed, noise)
            outputs, _ = discriminator(fake_images, right_embed)
            fake_loss = criterion(outputs, fake_labels)
            fake_score = outputs

            d_loss = real_loss + fake_loss

            if cls:
                d_loss = d_loss + wrong_loss

            d_loss.backward()
            optimD.step()

            # Train the generator
            generator.zero_grad()
            noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
            noise = noise.view(noise.size(0), 100, 1, 1)
            fake_images = generator(right_embed, noise)
            outputs, activation_fake = discriminator(fake_images, right_embed)
            _, activation_real = discriminator(right_images, right_embed)

            activation_fake = torch.mean(activation_fake, 0)
            activation_real = torch.mean(activation_real, 0)


            #======= Generator Loss function============
            # This is a customized loss function, the first term is the regular cross entropy loss
            # The second term is feature matching loss, this measure the distance between the real and generated
            # images statistics by comparing intermediate layers activations
            # The third term is L1 distance between the generated and real images, this is helpful for the conditional case
            # because it links the embedding feature vector directly to certain pixel values.
            #===========================================
            g_loss = criterion(outputs, real_labels) \
                     + l2_coef * l2_loss(activation_fake, activation_real.detach()) \
                     + l1_coef * l1_loss(fake_images, right_images)

            g_loss.backward()
            optimG.step()

            if iteration % 5 == 0:
                logger.log_iteration_gan(epoch,d_loss, g_loss, real_score, fake_score)
                logger.draw(right_images, fake_images)

        logger.plot_epoch_w_scores(epoch)

        if (epoch) % 10 == 0:
            Utils.save_checkpoint(discriminator, generator, checkpoints_path, save_path, epoch)

In [None]:
train_gan_cls()

In [None]:
def predict(ds='flowers'):
    
    if ds == 'birds':
        dataset = Text2ImageDataset(config['birds_dataset_path'], split=2)
    if ds == 'flowers':
        dataset = Text2ImageDataset(config['flowers_dataset_path'], split=2)

    data_loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False)
    
    for sample in data_loader:
        right_images = sample['right_images']
        right_embed = sample['right_embed']
        txt = sample['txt']

        if not os.path.exists('results/{0}'.format(save_path)):
            os.makedirs('results/{0}'.format(save_path))

        right_images = Variable(right_images.float()).cuda()
        right_embed = Variable(right_embed.float()).cuda()

        # Train the generator
        noise = Variable(torch.randn(right_images.size(0), 100)).cuda()
        noise = noise.view(noise.size(0), 100, 1, 1)
        fake_images = generator(right_embed, noise)

        logger.draw(right_images, fake_images)

        for image, t in zip(fake_images, txt):
            im = Image.fromarray(image.data.mul_(127.5).add_(127.5).byte().permute(1, 2, 0).cpu().numpy())
            t = t.strip()
            t = t.replace(".", '')
            im.save('results/gan_cls/generated_images/{}.jpg'.format(t.replace(" ", "_")[:100]))
            print(t)

In [None]:
predict()