In [1]:
#conda install -c conda-forge ipywidgets
#jupyter nbextension enable --py widgetsnbextension
#

In [2]:
import os

In [3]:
from utils.ObjectData import ObjectDataset
from net.Generator import Generator 
from net.Discriminator import Discriminator
from utils.VisdomLinePlotter import VisdomLinePlotter
from utils.utils import save_plot_voxels

  return f(*args, **kwds)


In [4]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader

In [5]:
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage

In [6]:
batch_size = 20
n_workers = 0
learning_rate_G = 0.002
learning_rate_D = 0.0002
beta_1 = 0.5
device = "cuda"
output_dir = "checkpoint_toilet/"
CKPT_PREFIX = "tlet"
SAVE_INTERVAL = 10
EPOCHS = 1000
PRINT_INTERVAL = 10

FAKE_IMG_FNAME = 'checkpoint_toilet/fake_sample_epoch_{:04d}'
REAL_IMG_FNAME = 'checkpoint_toilet/real_sample_epoch_{:04d}'
LOGS_FNAME = 'logs.tsv'

In [7]:
side_len = 32
z_dim = 32

In [8]:
# dataset
object_data = ObjectDataset("data/test_toilet.csv", side_len=side_len)
object_data_loader = DataLoader(object_data, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True)

In [9]:
# netowrks
netG = Generator(side_len, z_dim).to(device)
netD = Discriminator(side_len).to(device)

# criterion
bce = nn.BCELoss()

# optimizers
optimizerG = optim.Adam(netG.parameters(), lr=learning_rate_G, betas=(beta_1, 0.5))
optimizerD = optim.Adam(netD.parameters(), lr=learning_rate_D, betas=(beta_1, 0.5))

# misc
real_labels = torch.ones(batch_size, device=device)
fake_labels = torch.zeros(batch_size, device=device)
fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

# plotter
plotter = VisdomLinePlotter(env_name="3dgan_train")



In [10]:
def get_noise():
    return torch.randn(batch_size, z_dim, 1, 1, device=device)

In [11]:
def step(engine, batch):

        real = batch
        real = real.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # train with real
        output = netD(real)
        errD_real = bce(output, real_labels)
        D_x = output.mean().item()
        errD_real.backward()

        # get fake image from generator
        noise = get_noise()
        fake = netG(noise)

        # train with fake
        output = netD(fake.detach())
        errD_fake = bce(output, fake_labels)
        D_G_z1 = output.mean().item()

        errD_fake.backward()

        # gradient update
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Update generator. We want to make a step that will make it more likely that discriminator outputs "real"
        output = netD(fake)
        errG = bce(output, real_labels)
        D_G_z2 = output.mean().item()

        errG.backward()

        # gradient update
        optimizerG.step()

        return {
            'errD': errD.item(),
            'errG': errG.item(),
            'D_x': D_x,
            'D_G_z1': D_G_z1,
            'D_G_z2': D_G_z2
        }



In [12]:
trainer = Engine(step)
checkpoint_handler = ModelCheckpoint(output_dir, CKPT_PREFIX, n_saved=10, require_empty=False)
timer = Timer(average=True)

# attach running average metrics
monitoring_metrics = ['errD', 'errG', 'D_x', 'D_G_z1', 'D_G_z2']
RunningAverage(output_transform=lambda x: x['errD']).attach(trainer, 'errD')
RunningAverage(output_transform=lambda x: x['errG']).attach(trainer, 'errG')
RunningAverage(output_transform=lambda x: x['D_x']).attach(trainer, 'D_x')
RunningAverage(output_transform=lambda x: x['D_G_z1']).attach(trainer, 'D_G_z1')
RunningAverage(output_transform=lambda x: x['D_G_z2']).attach(trainer, 'D_G_z2')

# attach progress bar
pbar = ProgressBar()
pbar.attach(trainer, metric_names=monitoring_metrics)





In [13]:
@trainer.on(Events.ITERATION_COMPLETED)
def print_logs(engine):
    fname = os.path.join(output_dir, LOGS_FNAME)
    columns = ["iteration", ] + list(engine.state.metrics.keys())
    values = [str(engine.state.iteration), ] + \
             [str(round(value, 5)) for value in engine.state.metrics.values()]

    with open(fname, 'a') as f:
        if f.tell() == 0:
            print('\t'.join(columns), file=f)
        print('\t'.join(values), file=f)

    message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(epoch=engine.state.epoch,
                                                          max_epoch=EPOCHS,
                                                          i=(engine.state.iteration % len(object_data_loader)),
                                                          max_i=len(object_data_loader))
    for name, value in zip(columns, values):
        message += ' | {name}: {value}'.format(name=name, value=value)

    pbar.log_message(message)


# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EXCEPTION_RAISED)
def handle_exception(engine, e):
    if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
        engine.terminate()
        warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')

        create_plots(engine)
        checkpoint_handler(engine, {
            'netG_exception': netG,
            'netD_exception': netD
        })

    else:
        raise e

In [14]:
# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EPOCH_COMPLETED(every=PRINT_INTERVAL))
def save_fake_example(engine):
    fake = netG(fixed_noise).reshape(-1, side_len, side_len, side_len)
    plotter.plot_voxels(FAKE_IMG_FNAME.format(engine.state.epoch), fake[0].detach().cpu().numpy(), FAKE_IMG_FNAME.format(engine.state.epoch), savePLY=True)

# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EPOCH_COMPLETED(every=PRINT_INTERVAL))
def save_real_example(engine):
    img = engine.state.batch.reshape(-1, side_len, side_len, side_len)
    plotter.plot_voxels(REAL_IMG_FNAME.format(engine.state.epoch), img[0].detach().cpu().numpy(), REAL_IMG_FNAME.format(engine.state.epoch))


# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EPOCH_COMPLETED)
def print_times(engine):
    pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(engine.state.epoch, timer.value()))
    timer.reset()

# adding handlers using `trainer.on` decorator API
@trainer.on(Events.EPOCH_COMPLETED(every=PRINT_INTERVAL))
def create_plots(engine):
    fake = netG(fixed_noise).reshape(-1, side_len, side_len, side_len)
    save_plot_voxels(fake[0:10], FAKE_IMG_FNAME.format(engine.state.epoch), engine.state.epoch)

In [15]:
# adding handlers using `trainer.add_event_handler` method API
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
                          to_save={
                              'netG': netG,
                              'netD': netD
                          })

# automatically adding handlers via a special `attach` method of `Timer` handler
timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
             pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

<ignite.handlers.timing.Timer at 0x7fa0d42af6a0>

In [None]:
trainer.run(object_data_loader, EPOCHS)

  "Please ensure they have the same size.".format(target.size(), input.size()))


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

[1/1000][1/5] | iteration: 1 | errD: 1.39099 | errG: 1.52528 | D_x: 0.52706 | D_G_z1: 0.5276 | D_G_z2: 0.21763
[1/1000][2/5] | iteration: 2 | errD: 1.13614 | errG: 1.66921 | D_x: 0.59395 | D_G_z1: 0.44922 | D_G_z2: 0.19039
[1/1000][3/5] | iteration: 3 | errD: 0.85281 | errG: 1.86711 | D_x: 0.68365 | D_G_z1: 0.35822 | D_G_z2: 0.15862
[1/1000][4/5] | iteration: 4 | errD: 0.60796 | errG: 2.09063 | D_x: 0.75696 | D_G_z1: 0.25889 | D_G_z2: 0.12882
[1/1000][0/5] | iteration: 5 | errD: 0.44546 | errG: 2.29241 | D_x: 0.8067 | D_G_z1: 0.1883 | D_G_z2: 0.1057
Epoch 1 done. Time per batch: 3.987[s]


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

[2/1000][1/5] | iteration: 6 | errD: 0.22452 | errG: 2.66494 | D_x: 0.91043 | D_G_z1: 0.12067 | D_G_z2: 0.06961
[2/1000][2/5] | iteration: 7 | errD: 0.20295 | errG: 2.74096 | D_x: 0.91849 | D_G_z1: 0.10938 | D_G_z2: 0.06473
[2/1000][3/5] | iteration: 8 | errD: 0.17536 | errG: 2.84685 | D_x: 0.92424 | D_G_z1: 0.09023 | D_G_z2: 0.05848
[2/1000][4/5] | iteration: 9 | errD: 0.14833 | errG: 2.96054 | D_x: 0.93434 | D_G_z1: 0.07595 | D_G_z2: 0.05238
[2/1000][0/5] | iteration: 10 | errD: 0.12552 | errG: 3.08547 | D_x: 0.94282 | D_G_z1: 0.06358 | D_G_z2: 0.04641
Epoch 2 done. Time per batch: 5.278[s]


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

[3/1000][1/5] | iteration: 11 | errD: 0.08532 | errG: 3.33279 | D_x: 0.96174 | D_G_z1: 0.04516 | D_G_z2: 0.03584
[3/1000][2/5] | iteration: 12 | errD: 0.07982 | errG: 3.40299 | D_x: 0.96414 | D_G_z1: 0.04231 | D_G_z2: 0.03347
[3/1000][3/5] | iteration: 13 | errD: 0.07242 | errG: 3.50047 | D_x: 0.9692 | D_G_z1: 0.04019 | D_G_z2: 0.03051
[3/1000][4/5] | iteration: 14 | errD: 0.06476 | errG: 3.60418 | D_x: 0.9722 | D_G_z1: 0.0358 | D_G_z2: 0.028
[3/1000][0/5] | iteration: 15 | errD: 0.05869 | errG: 3.71991 | D_x: 0.97567 | D_G_z1: 0.03338 | D_G_z2: 0.02526
Epoch 3 done. Time per batch: 5.611[s]


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

[4/1000][1/5] | iteration: 16 | errD: 0.04573 | errG: 3.97101 | D_x: 0.98118 | D_G_z1: 0.02631 | D_G_z2: 0.01897
[4/1000][2/5] | iteration: 17 | errD: 0.04118 | errG: 4.03621 | D_x: 0.98228 | D_G_z1: 0.02298 | D_G_z2: 0.01786
[4/1000][3/5] | iteration: 18 | errD: 0.03771 | errG: 4.13127 | D_x: 0.98365 | D_G_z1: 0.02096 | D_G_z2: 0.01635
[4/1000][4/5] | iteration: 19 | errD: 0.03366 | errG: 4.24139 | D_x: 0.98546 | D_G_z1: 0.0188 | D_G_z2: 0.01464
[4/1000][0/5] | iteration: 20 | errD: 0.02952 | errG: 4.3535 | D_x: 0.98685 | D_G_z1: 0.01613 | D_G_z2: 0.01309
Epoch 4 done. Time per batch: 3.483[s]


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

[5/1000][1/5] | iteration: 21 | errD: 0.02239 | errG: 4.56365 | D_x: 0.98977 | D_G_z1: 0.01203 | D_G_z2: 0.01062
[5/1000][2/5] | iteration: 22 | errD: 0.02145 | errG: 4.62096 | D_x: 0.99037 | D_G_z1: 0.01171 | D_G_z2: 0.00996
[5/1000][3/5] | iteration: 23 | errD: 0.01975 | errG: 4.71273 | D_x: 0.99116 | D_G_z1: 0.01081 | D_G_z2: 0.00911
[5/1000][4/5] | iteration: 24 | errD: 0.01795 | errG: 4.81369 | D_x: 0.99193 | D_G_z1: 0.0098 | D_G_z2: 0.00832
[5/1000][0/5] | iteration: 25 | errD: 0.01629 | errG: 4.92502 | D_x: 0.99276 | D_G_z1: 0.00897 | D_G_z2: 0.00741
Epoch 5 done. Time per batch: 3.471[s]


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

[6/1000][1/5] | iteration: 26 | errD: 0.01249 | errG: 5.14611 | D_x: 0.99418 | D_G_z1: 0.00663 | D_G_z2: 0.00595
[6/1000][2/5] | iteration: 27 | errD: 0.01186 | errG: 5.21042 | D_x: 0.99453 | D_G_z1: 0.00636 | D_G_z2: 0.00554
[6/1000][3/5] | iteration: 28 | errD: 0.01079 | errG: 5.29992 | D_x: 0.99501 | D_G_z1: 0.00577 | D_G_z2: 0.00509
[6/1000][4/5] | iteration: 29 | errD: 0.00968 | errG: 5.40534 | D_x: 0.99554 | D_G_z1: 0.00519 | D_G_z2: 0.00458
[6/1000][0/5] | iteration: 30 | errD: 0.00872 | errG: 5.51287 | D_x: 0.996 | D_G_z1: 0.0047 | D_G_z2: 0.00413
Epoch 6 done. Time per batch: 5.290[s]


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

[7/1000][1/5] | iteration: 31 | errD: 0.00812 | errG: 5.71282 | D_x: 0.99685 | D_G_z1: 0.00495 | D_G_z2: 0.0034
[7/1000][2/5] | iteration: 32 | errD: 0.00728 | errG: 5.76695 | D_x: 0.99704 | D_G_z1: 0.00429 | D_G_z2: 0.00332
[7/1000][3/5] | iteration: 33 | errD: 0.02062 | errG: 5.83759 | D_x: 0.99724 | D_G_z1: 0.01595 | D_G_z2: 0.00336
[7/1000][4/5] | iteration: 34 | errD: 0.0373 | errG: 5.89929 | D_x: 0.99747 | D_G_z1: 0.02531 | D_G_z2: 0.00356
[7/1000][0/5] | iteration: 35 | errD: 0.03843 | errG: 5.95087 | D_x: 0.99751 | D_G_z1: 0.02649 | D_G_z2: 0.00372
Epoch 7 done. Time per batch: 5.963[s]


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

[8/1000][1/5] | iteration: 36 | errD: 0.02223 | errG: 6.03215 | D_x: 0.99748 | D_G_z1: 0.01726 | D_G_z2: 0.00365
[8/1000][2/5] | iteration: 37 | errD: 0.01522 | errG: 6.03084 | D_x: 0.99718 | D_G_z1: 0.01115 | D_G_z2: 0.00376
[8/1000][3/5] | iteration: 38 | errD: 0.01116 | errG: 6.0715 | D_x: 0.99728 | D_G_z1: 0.00781 | D_G_z2: 0.00353
[8/1000][4/5] | iteration: 39 | errD: 0.00898 | errG: 6.15001 | D_x: 0.99741 | D_G_z1: 0.00606 | D_G_z2: 0.00311
[8/1000][0/5] | iteration: 40 | errD: 0.00704 | errG: 6.25837 | D_x: 0.99763 | D_G_z1: 0.0045 | D_G_z2: 0.00253
Epoch 8 done. Time per batch: 4.358[s]


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

[9/1000][1/5] | iteration: 41 | errD: 0.00428 | errG: 6.48185 | D_x: 0.99798 | D_G_z1: 0.00226 | D_G_z2: 0.0018
[9/1000][2/5] | iteration: 42 | errD: 0.00426 | errG: 6.57863 | D_x: 0.99808 | D_G_z1: 0.00234 | D_G_z2: 0.00161
[9/1000][3/5] | iteration: 43 | errD: 0.0038 | errG: 6.69322 | D_x: 0.99831 | D_G_z1: 0.00211 | D_G_z2: 0.00139
[9/1000][4/5] | iteration: 44 | errD: 0.00317 | errG: 6.8069 | D_x: 0.99854 | D_G_z1: 0.00171 | D_G_z2: 0.00124
[9/1000][0/5] | iteration: 45 | errD: 0.00298 | errG: 6.9302 | D_x: 0.99856 | D_G_z1: 0.00154 | D_G_z2: 0.00111
Epoch 9 done. Time per batch: 5.178[s]


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

[10/1000][1/5] | iteration: 46 | errD: 0.00232 | errG: 7.18727 | D_x: 0.99891 | D_G_z1: 0.00123 | D_G_z2: 0.00077
[10/1000][2/5] | iteration: 47 | errD: 0.00203 | errG: 7.25883 | D_x: 0.99903 | D_G_z1: 0.00106 | D_G_z2: 0.00073
[10/1000][3/5] | iteration: 48 | errD: 0.00203 | errG: 7.36788 | D_x: 0.99899 | D_G_z1: 0.00102 | D_G_z2: 0.00066
[10/1000][4/5] | iteration: 49 | errD: 0.00175 | errG: 7.48984 | D_x: 0.99916 | D_G_z1: 0.00091 | D_G_z2: 0.00058
[10/1000][0/5] | iteration: 50 | errD: 0.00145 | errG: 7.59026 | D_x: 0.99929 | D_G_z1: 0.00074 | D_G_z2: 0.00054
Epoch 10 done. Time per batch: 4.515[s]


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

[11/1000][1/5] | iteration: 51 | errD: 0.00108 | errG: 7.80408 | D_x: 0.99948 | D_G_z1: 0.00056 | D_G_z2: 0.00042
[11/1000][2/5] | iteration: 52 | errD: 0.001 | errG: 7.87871 | D_x: 0.99952 | D_G_z1: 0.00052 | D_G_z2: 0.00039


ERROR:ignite.engine.engine.Engine:Current run is terminating due to exception: .


NameError: name 'warnings' is not defined

ERROR:tornado.general:Uncaught exception in ZMQStream callback
Traceback (most recent call last):
  File "/home/pablo/miniconda3/envs/facenet/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 432, in _run_callback
    callback(*args, **kwargs)
  File "/home/pablo/miniconda3/envs/facenet/lib/python3.6/site-packages/tornado/stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/pablo/miniconda3/envs/facenet/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/home/pablo/miniconda3/envs/facenet/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/home/pablo/miniconda3/envs/facenet/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 421, in execute_request
    self._abort_queues()
  File "/home/pablo/miniconda3/envs/facenet/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 636, in 