In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os.path as op
import time

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import SVG
from keras import backend as K
from keras.callbacks import TensorBoard
from keras.utils.vis_utils import model_to_dot
from keras_tqdm import TQDMNotebookCallback
import tensorflow as tf
from tensorflow import set_random_seed
from tqdm import tqdm_notebook

from data import im_generator
from unet import unet

Using TensorFlow backend.


In [2]:
# Seed TensorFlow's graph-level RNG and NumPy's global RNG so that
# repeated runs of this notebook draw the same random numbers.
set_random_seed(1)
np.random.seed(1)

In [3]:
# print models as seen by tf
# from tensorflow.python.client import device_lib
# print(device_lib.list_local_devices())

In [4]:
# model = unet(input_size=(28, 28, 1), with_extra_sigmoid=False, n_layers=2)

In [5]:
# model.weights

In [6]:
# SVG(model_to_dot(model).create(prog='dot', format='svg'))

In [7]:
# Data / training configuration shared by every experiment below.
batch_size = 32
noise_std = 30  # forwarded to im_generator; units depend on its pixel scale — TODO confirm
source = 'cifar_grey'
validation_split = 0.1  # fraction of the training set held out for validation

# Dataset-specific sizes: number of training images and (square) image side length.
if 'cifar' in source:
    n_samples_train = 50000
    size = 32
elif 'mnist' in source:
    n_samples_train = 60000
    size = 28
else:
    # Fail here, with a clear message, instead of a NameError on `size` /
    # `n_samples_train` much later in the training cell.
    raise ValueError("Unknown data source %r: expected a CIFAR or MNIST source." % (source,))
n_samples_test = 10000

# Pass the `validation_split` variable (not a literal) so the generators'
# train/validation split always matches the steps_per_epoch / validation_steps
# that the training cell derives from the same variable.
im_gen_train = im_generator(mode='training', validation_split=validation_split, batch_size=batch_size, source=source, noise_std=noise_std)
im_gen_val = im_generator(mode='validation', validation_split=validation_split, batch_size=batch_size, source=source, noise_std=noise_std)
im_gen_test = im_generator(mode='testing', batch_size=batch_size, source=source, noise_std=noise_std)

In [8]:
# Experiment grid: each U-Net variant is trained at three depths (n_layers 2, 3, 4),
# with suffixes _1, _2, _3 for those depths. Each run id is also the TensorBoard
# log directory name (logs/<run_id>). The prefix is therefore taken from `source`
# (e.g. 'cifar_grey' -> 'cifar') so runs on different datasets never share a
# log directory. Insertion order (depth-major, then variant) is the training order.
params = {
    '{}_{}_{}'.format(source.split('_')[0], variant, depth_idx): dict({'n_layers': n_layers}, **variant_kwargs)
    for depth_idx, n_layers in enumerate((2, 3, 4), start=1)
    for variant, variant_kwargs in (
        ('classic', {}),
        ('without_relu_contracting', {'non_relu_contract': True}),
        ('aver_pool', {'pool': 'average'}),
    )
}

In [None]:
n_epochs = 50
# Batches per epoch for training and validation. Identical for every run,
# so they are computed once, outside the loop.
train_steps = int((1-validation_split) * n_samples_train / batch_size)
val_steps = int(validation_split * n_samples_train / batch_size)

# Start from a fresh Keras/TF session so graphs from earlier executions of this
# cell do not accumulate.
K.clear_session() 
for run_id in tqdm_notebook(params):
    run_params = params[run_id]
    print(run_id)
    # Build a fresh U-Net for this configuration; grey images -> 1 channel.
    model = unet(input_size=(size, size, 1), with_extra_sigmoid=False, **run_params)
    # One TensorBoard directory per run id: logs/<run_id>.
    log_dir = op.join('logs', run_id)
    tboard_cback = TensorBoard(
        log_dir=log_dir,
        batch_size=batch_size,
        histogram_freq=0,
        write_images=True,
        write_graph=True,
    )
    model.fit_generator(
        im_gen_train,
        epochs=n_epochs,
        steps_per_epoch=train_steps,
        validation_data=im_gen_val,
        validation_steps=val_steps,
        callbacks=[TQDMNotebookCallback(), tboard_cback],
        verbose=0,
    )
    # Reset the session before the next configuration builds its model.
    K.clear_session() 

HBox(children=(IntProgress(value=0, max=9), HTML(value='')))

cifar_classic_1
Instructions for updating:
Colocations handled automatically by placer.


  model = Model(input=inputs, outputs=output)


Instructions for updating:
Use tf.cast instead.


HBox(children=(IntProgress(value=0, description='Training', max=50, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=1406, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 1', max=1406, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 2', max=1406, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 3', max=1406, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 4', max=1406, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 5', max=1406, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 6', max=1406, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 7', max=1406, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 8', max=1406, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 9', max=1406, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 10', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 11', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 12', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 13', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 14', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 15', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 16', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 17', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 18', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 19', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 20', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 21', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 22', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 23', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 24', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 25', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 26', max=1406, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 27', max=1406, style=ProgressStyle(description_width='i…

In [None]:
# run_id = str(int(time.time()))
# print(run_id)
# log_dir = op.join('logs', run_id)
# tboard_cback = TensorBoard(
#     log_dir=log_dir, 
#     histogram_freq=1, 
#     batch_size=batch_size, 
#     write_graph=True, 
#     write_images=True, 
# )

In [None]:
# history = model.fit_generator(
#     im_gen_train, 
#     steps_per_epoch=int(0.9*5*1e4 / 32), 
#     epochs=5,
#     validation_data=next(im_gen_val),
#     validation_steps=int(0.1*5*1e4 / 32),
# #     validation_freq=2,
#     verbose=0,
# #     use_multiprocessing=True,
# #     callbacks=[TQDMNotebookCallback(), tboard_cback],
# )

In [None]:
# from numba import cuda
# cuda.select_device(0)
# cuda.close()

In [None]:
# plt.figure(figsize=(9, 5))
# for key, val in history.history.items():
#     plt.plot(np.log(val), label=key)
# plt.legend()

In [None]:
# batch_test_noisy, batch_test_gt = next(im_gen_test)

In [None]:
# batch_test_pred = model.predict(batch_test_noisy)

In [None]:
# K.eval(model.layers[-1].activation(-1e6))

In [None]:
# batch_test_pred.max()

In [None]:
# fig, axs = plt.subplots(32, 3, figsize=(14, 32*5))
# for example_id in range(32):
    
#     axs[example_id, 0].imshow(batch_test_gt[example_id, ..., 0], cmap='gray')
#     axs[example_id, 0].set_title("original image")
#     axs[example_id, 1].imshow(batch_test_noisy[example_id, ..., 0], cmap='gray')
#     axs[example_id, 1].set_title("noisy image")
#     axs[example_id, 2].imshow(batch_test_pred[example_id, ..., 0], cmap='gray')
#     axs[example_id, 2].set_title("denoised image")