In [None]:
%pylab inline
%load_ext autoreload
%autoreload 2
from inpainting.dataset import Data, ResizeTransform, ConditionSampler
from gan.conditional_gan import Generator5Net, Discriminator5
from gan.losses import GeneratorLoss, DiscriminatorLoss
from gan.trainer import GanTrainer
from torch.utils.data import DataLoader 
from torch.utils.data.dataset import Subset
import torch
import matplotlib.pyplot as plt
import pandas as pd
from inpainting.visualize import plot_batch
from inpainting.visualize import cGanPlotLossCallback, ConditionDescriber
from inpainting import cond_celeba_config as conf
from inpainting.visualizer import Visualizer
from performance.estimator import FIDEstimator

%matplotlib notebook

In [None]:
# Torch device used for models and tensors. To target a specific GPU, set
# conf.DEVICE to e.g. 'cuda:1' rather than calling torch.cuda.set_device.
device = torch.device(conf.DEVICE)

In [None]:

# Build the dataset, the noise/condition sampler and an 80/20 train/validation split.
#
# The split is seeded so that the SAME images land in the validation set on every
# run. Without this, resuming from a checkpoint (conf.CONTINUE_TRAINING) in a fresh
# kernel draws a new random split, and images the model already trained on leak
# into the validation set used for scoring.
# Note: torch.manual_seed also seeds the global torch RNG, so the model weight
# initialisation in later cells becomes reproducible as well.
SPLIT_SEED = 0
TRAIN_FRACTION = 0.8

transform = ResizeTransform()
data = Data(conf.DATA_PATH, transform, return_attr=True, conditions=conf.conditions)
noise_sampler = ConditionSampler(data, conf.Z_SIZE)
# Quick smoke run: data = Subset(data, range(100))
# (torch Subset does not expose Data.find_image, which later cells call.)

train_size = int(TRAIN_FRACTION * len(data))
valid_size = len(data) - train_size
torch.manual_seed(SPLIT_SEED)  # seed immediately before the split so nothing else shifts it
train_data, valid_data = torch.utils.data.random_split(data, [train_size, valid_size])

train_loader = DataLoader(train_data, batch_size=conf.BATCH_SIZE, num_workers=conf.NUM_WORKERS, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=conf.BATCH_SIZE, num_workers=conf.NUM_WORKERS, shuffle=True)
print('Dataset size: ', len(data))
print('y size: ', conf.Y_SIZE)


In [None]:
# Sample-quality estimator (presumably Frechet Inception Distance, per the class
# name), built from the shared noise/condition sampler and the run config.
estimator = FIDEstimator(
    noise_sampler,
    config=conf,
)

In [None]:
# Look up a dataset image with a chosen attribute combination and display it
# together with a text description of the requested condition vector.
cd = ConditionDescriber(conf.conditions)
target_attrs = dict(Male=False, Smiling=True, Young=True, Eyeglasses=False)
y = cd.create_y(**target_attrs)

idx = data.find_image(y)
img, y_new = data[idx]
print(y_new)  # attribute vector stored for the found image, for comparison with y

caption = cd.describe(y)
plot_batch([img], normalize=True, limit=1, descriptions=[caption])

In [None]:
# Conditional generator (noise size, condition size) and conditional discriminator
# (condition size), both moved to the selected device. Construction order is kept
# generator-first so RNG-dependent weight initialisation is unchanged.
gen_net = Generator5Net(conf.Z_SIZE, conf.Y_SIZE)
generator = gen_net.to(device)
disc_net = Discriminator5(conf.Y_SIZE)
discriminator = disc_net.to(device)

In [None]:
# Visualizer and trainer share the run config and the noise/condition sampler;
# the trainer also receives the estimator used for scoring.
visualizer = Visualizer(conf, noise_sampler)
trainer = GanTrainer(
    generator,
    discriminator,
    conf,
    noise_sampler,
    visualizer=visualizer,
    estimator=estimator,
)

In [None]:
# Optionally resume via trainer.load_last_checkpoint(), then run training.
N_EPOCHS = 200

if conf.CONTINUE_TRAINING:
    trainer.load_last_checkpoint()

trainer.train(train_loader, valid_loader, n_epochs=N_EPOCHS)

In [None]:
# Optionally restore generator/discriminator weights saved at a given epoch and
# switch both to eval mode.
# NOTE: with LOAD_MODEL = True this replaces whatever weights the training cell
# above left in `generator` / `discriminator`.
LOAD_MODEL = True
LOAD_EPOCH_N = 40
if LOAD_MODEL:
    generator_path = conf.MODEL_PATH + 'generator_%d.pth' % (LOAD_EPOCH_N,)
    discriminator_path = conf.MODEL_PATH + 'discriminator_%d.pth' % (LOAD_EPOCH_N,)
    # map_location remaps tensors saved on another device (e.g. a GPU) onto the
    # current conf.DEVICE; without it, loading GPU checkpoints on CPU fails.
    generator.load_state_dict(torch.load(generator_path, map_location=device))
    discriminator.load_state_dict(torch.load(discriminator_path, map_location=device))
    generator.eval()
    discriminator.eval()

In [None]:
# Score the current generator against the validation split.
fid_score = estimator.score(generator, valid_loader)
print(fid_score)

In [None]:
# Sanity check on a single example: compare discriminator outputs and the
# discriminator loss for a real image vs. a generated one with the same conditions.


def scalar(v):
    """Return the Python number held by a single-element tensor.

    Replaces np.asscalar, which was deprecated in NumPy 1.16 and removed in 1.23.
    """
    return v.detach().cpu().item()


# Attribute order: 'Male', 'Smiling', 'Young', 'Eyeglasses', 'Wearing_Hat'
y = cd.create_y(Male=False, Smiling=True, Young=True, Eyeglasses=False, Wearing_Hat=False)
idx = data.find_image(y)
img, y_new = data[idx]
y = y.astype(np.float32)

discriminator_loss = DiscriminatorLoss(label_smoothing=0.25)

# Inference only: no autograd graph is needed for these diagnostics.
with torch.no_grad():
    Z = torch.normal(mean=torch.zeros(1, generator.z_size)).to(device)
    X = torch.tensor(img[np.newaxis, :]).to(device)
    Y = torch.tensor(y[np.newaxis]).to(device)  # add batch dim; avoids slow list-of-ndarray conversion

    G_sample = generator(Z, Y)
    D_real, D_logit_real = discriminator(X, Y)
    D_fake, D_logit_fake = discriminator(G_sample, Y)

    d_loss = discriminator_loss(D_logit_real, D_logit_fake)
    # Reference point: loss if the real-image logit were a confident 10.0.
    d_loss_confident_real = discriminator_loss(torch.tensor([10.0]).to(device), D_logit_fake)

print("D_real: ", scalar(D_real), "D_fake: ", scalar(D_fake))
print("D_logit_real: ", scalar(D_logit_real), "D_logit_fake: ", scalar(D_logit_fake))
print("Discriminator loss: ", scalar(d_loss))
print(scalar(d_loss_confident_real))

# Show the real image next to the generated one, both captioned with the condition.
sample = np.concatenate((img[None, :, :, :], G_sample.detach().cpu().numpy()))
plot_batch(sample, normalize=True, descriptions=[cd.describe(y)] * 2)

In [None]:
# Inspect the weights of the first module in discriminator.layer4.
layer4_head = discriminator.layer4[0]
w = layer4_head.weight.data.cpu().numpy()
print(w.shape)
# NOTE(review): the split at column 5 presumably separates the weights applied to
# the 5-attribute condition vector from the rest — confirm against Discriminator5.
print(w[:, :5])
print(w[:, 5:])