In [None]:
%pylab inline
%load_ext autoreload
%autoreload 2
from inpainting.dataset import Data, ResizeTransform, NoiseSampler
from gan.gan import Generator5Net, Discriminator5
from gan.trainer import GanTrainer
import torch
print(torch.__version__)
from torch.utils.data import DataLoader, random_split

import matplotlib.pyplot as plt
import pandas as pd
from inpainting.visualize import plot_batch
from inpainting.visualize import GanPlotLossCallback as PlotLossCallback
from inpainting import celeba_config as conf
from inpainting.visualizer import Visualizer
from performance.estimator import FIDEstimator

%matplotlib notebook

In [None]:
# Device used for training; the FID estimator may run on its own device
# (conf.ESTIMATOR_DEVICE) and otherwise shares the training device.
device = torch.device(conf.DEVICE)
estimator_device = (
    torch.device(conf.ESTIMATOR_DEVICE)
    if hasattr(conf, 'ESTIMATOR_DEVICE')
    else device
)

In [None]:

# Load the image dataset and split it into training / validation subsets.
# The split is seeded so it is identical across kernel restarts: without this,
# resuming from a checkpoint (conf.CONTINUE_TRAINING) would draw a fresh split
# and the new validation set would contain images the model already trained on.
TRAIN_FRACTION = 0.8
SPLIT_SEED = getattr(conf, 'SPLIT_SEED', 0)  # overridable from the config

transform = ResizeTransform()
data = Data(conf.DATA_PATH, transform)
train_size = int(TRAIN_FRACTION * len(data))
valid_size = len(data) - train_size
# random_split draws its permutation from torch's global RNG, so seeding here
# pins the split (this also seeds torch's RNG for the rest of the notebook).
torch.manual_seed(SPLIT_SEED)
train_data, valid_data = random_split(data, [train_size, valid_size])
train_loader = DataLoader(train_data, batch_size=conf.BATCH_SIZE, num_workers=conf.NUM_WORKERS, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=conf.BATCH_SIZE, num_workers=conf.NUM_WORKERS, shuffle=True)
print('Dataset size: ', len(data))
noise_sampler = NoiseSampler(conf.Z_SIZE)


In [None]:
# FID estimator, placed on the (possibly separate) estimator device.
estimator = FIDEstimator(
    noise_sampler,
    device=estimator_device,
)

In [None]:
# Reference distance between two sets of *real* images (validation split vs.
# training split). Generator scores computed later can be read against it.
N_REFERENCE_BATCHES = 100
real_valid, real_train = [], []
for _, (valid_batch,), (train_batch,) in zip(range(N_REFERENCE_BATCHES), valid_loader, train_loader):
    # Trim both batches to a common size so the two sets stay equally large,
    # e.g. when one loader yields a short final batch.
    n_common = min(valid_batch.shape[0], train_batch.shape[0])
    real_valid.append(valid_batch[:n_common])
    real_train.append(train_batch[:n_common])
assert real_valid, 'valid_loader / train_loader yielded no batches'
real_valid = np.concatenate(real_valid)
real_train = np.concatenate(real_train)
distance = estimator.distance(real_valid, real_train)
print(real_valid.shape[0])
print(distance)
# Free the (potentially large) image arrays; only `distance` is kept.
del valid_batch, train_batch, real_valid, real_train

In [None]:
# Build the generator (sized by the latent dimension) and the discriminator,
# both on the training device.
generator, discriminator = (
    Generator5Net(conf.Z_SIZE).to(device),
    Discriminator5().to(device),
)

In [None]:
# Wire networks, config, noise source and the monitoring helpers
# (sample visualizer + FID estimator) into the GAN trainer.
visualizer = Visualizer(conf, noise_sampler)
trainer = GanTrainer(
    generator,
    discriminator,
    conf,
    noise_sampler,
    visualizer=visualizer,
    estimator=estimator,
)

In [None]:
# Estimator score of the freshly initialised generator against the
# validation data, as a starting point before training.
score_before_training = estimator.score(generator, valid_loader)
print(score_before_training)

In [None]:
# Train the GAN, optionally resuming from the most recent checkpoint first.
N_EPOCHS = 10

if conf.CONTINUE_TRAINING:
    trainer.load_last_checkpoint()

trainer.train(train_loader, valid_loader, n_epochs=N_EPOCHS)

In [None]:

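# print(Z)  # NOTE(review): leftover debug cell — `Z` is only defined in a later cell (noise_sampler.sample()).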

In [None]:
# Optionally replace the current weights with the checkpoint from a given epoch.
# NOTE: with LOAD_MODEL = True this overwrites whatever the training cell above
# produced; set it to False to keep the freshly trained networks.
LOAD_MODEL = True
LOAD_EPOCH_N = 90
if LOAD_MODEL:
    for net, prefix in ((generator, 'generator'), (discriminator, 'discriminator')):
        checkpoint_path = conf.MODEL_PATH + '%s_%d.pth' % (prefix, LOAD_EPOCH_N)
        # map_location remaps tensors saved on another device (e.g. a different
        # GPU, or GPU when running on CPU) onto the device used here.
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))

In [None]:
# Show a few real training images. The (x + 1) / 2 rescale maps [-1, 1] to
# [0, 1] — presumably the transform outputs that range (TODO confirm).
for (batch,) in train_loader:
    images = batch.detach().cpu().numpy()
    plot_batch((images + 1) / 2, limit=6)
    break

In [None]:
# Draw one batch of noise and look at what the generator currently produces.
Z, = noise_sampler.sample()
# Inspection only: don't build an autograd graph for this forward pass.
with torch.no_grad():
    G_sample = generator(Z)
sample = G_sample.cpu().numpy()
print(discriminator.layer4.weight.cpu().detach().numpy())
plot_batch((sample + 1) / 2)

In [None]:
# Dump the weight tensor of the first sub-module of generator.layer3,
# e.g. to eyeball its scale.
w = generator.layer3[0].weight.detach().cpu().numpy()
print(w)

In [None]:
# Sanity check: compare discriminator outputs on one batch of real validation
# images against an equally sized batch of generated images.
with torch.no_grad():
    for X, in valid_loader:
        # DataLoader yields CPU tensors while the networks live on `device`;
        # move the batch so the forward pass also works when device is CUDA.
        X = X.to(device)
        # Size the noise by the actual batch (may be smaller than batch_size).
        Z, = noise_sampler.sample_batch(X.shape[0])
        Z = Z.to(device)
        G_sample = generator(Z)
        D_real, D_logit_real = discriminator(X)
        D_fake, D_logit_fake = discriminator(G_sample)
        print('D_real', D_real,'\n', 'D_fake', D_fake)
        print('D_logit_real', D_logit_real, '\n', 'D_logit_fake', D_logit_fake)
        break
    