In [1]:
%matplotlib inline
import datetime
import glob
import os
import pickle

import h5py
import keras
import numpy
import skimage
import skimage.io
import skimage.transform
import skimage.util
import tensorflow
import tqdm
from keras.layers import *
from keras.models import *

Using TensorFlow backend.


In [2]:
# Ask TensorFlow to claim GPU memory on demand (allow_growth) instead of reserving all of it
# at start-up, then register this session as the one Keras runs its graph in.
config = tensorflow.ConfigProto(gpu_options=tensorflow.GPUOptions(allow_growth=True))
session = tensorflow.Session(config=config)
keras.backend.set_session(session)

In [3]:
def batch_sort_key(path):
    """Sort key that orders ``..._batch_N`` files by the integer ``N`` (so 10 follows 9, not 1).

    The key is the last ``_``-separated token of the path. Paths whose last token is not a number
    sort after all numbered ones, by text. The full path breaks ties, so the result never depends
    on the filesystem order returned by ``glob``.
    """
    suffix = path.rsplit('_', 1)[-1]
    if suffix.isdigit():
        return (0, int(suffix), path)
    return (1, suffix, path)


# Root directories can be overridden through environment variables; the defaults are the
# original hard-coded locations.
data_dir = os.environ.get("ENHANCER_DATA_DIR", "/home/santiago/Projects/Enhancer/data")
# Sorted numerically by batch number so that index i in load() maps to a predictable file.
xtrain = sorted(glob.glob(os.path.join(data_dir, "Imagenet32_train/*")), key=batch_sort_key)
ytrain = sorted(glob.glob(os.path.join(data_dir, "Imagenet64_train_part*/*")), key=batch_sort_key)
xval = os.path.join(data_dir, "Imagenet32_val/val_data")
yval = os.path.join(data_dir, "Imagenet64_val/val_data")
out_dir = os.environ.get("ENHANCER_OUT_DIR", "/home/santiago/Projects/Enhancer/checkpoints")

In [4]:
def unpickle(file):
    """Open ``file`` in binary mode and return the single object pickled inside it.

    Security note: ``pickle.load`` can execute arbitrary code, so only call this on dataset
    files from a trusted source.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle)

In [5]:
def load(subset='train', batch=0):
    """Load one pickled 64x64 image file and build (32x32 input, 64x64 target) training pairs.

    Targets are the pickle's ``'data'`` array scaled to [0, 1]. Inputs are produced by
    downscaling each target by 2 in both spatial dimensions, so ``xtrain``/``xval`` are not
    read here.

    Args:
        subset: ``'train'`` selects ``ytrain[batch // 2]``; any other value loads ``yval``.
        batch: training batch index. NOTE(review): ``batch // 2`` makes consecutive indices
            share one file, so ``range(10)`` only ever reads ``ytrain[0:5]``. Confirm this is
            intended.

    Returns:
        ``(ximages, yimages)`` float32 arrays shaped ``(N, 32, 32, 3)`` and ``(N, 64, 64, 3)``.
    """
    if subset == 'train':
        yfile = ytrain[batch // 2]
    else:
        yfile = yval
    ydata = unpickle(yfile)['data'].astype(numpy.float32) / 255
    # Each row is channel-planar: 4096 R values, then 4096 G, then 4096 B. Interleave the planes
    # and reshape into channels-last (N, 64, 64, 3) images.
    yimages = numpy.dstack((ydata[:, :4096], ydata[:, 4096:8192], ydata[:, 8192:]))
    yimages = yimages.reshape((yimages.shape[0], 64, 64, 3))
    # float32 to match the targets; numpy.empty's float64 default would double memory use.
    ximages = numpy.empty((yimages.shape[0], 32, 32, 3), dtype=numpy.float32)
    for i in range(yimages.shape[0]):
        # resize() with a 2-D output shape leaves the channel axis untouched in every
        # scikit-image version. rescale(img, 0.5) also shrinks the channel axis on newer
        # releases unless the channel axis is declared explicitly.
        ximages[i] = skimage.transform.resize(yimages[i], (32, 32), mode='reflect')
    assert len(ximages) == len(yimages)
    return (ximages, yimages)

In [6]:
# Convolutional encoder-decoder mapping a 32x32 RGB image to a 64x64 RGB image.
# Encoder: three conv(3x3, ReLU) + 2x2 max-pool stages (32/64/128 filters) shrink 32x32 to 4x4.
# Decoder: four conv(3x3, ReLU) + 2x upsample stages (256/128/64/32 filters) grow 4x4 to 64x64.
# The extra decoder stage is what doubles the output resolution. Layers are created in the
# same order as an unrolled definition, so their auto-generated names are unchanged.
inputs = Input((32, 32, 3))

x = inputs
for filters in (32, 64, 128):
    x = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    x = MaxPooling2D((2, 2), padding='same')(x)
encoded = x

for filters in (256, 128, 64, 32):
    x = Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    x = UpSampling2D((2, 2))(x)
# Sigmoid keeps every output channel in [0, 1], the same range as the /255-scaled targets.
decoded = Conv2D(3, (3, 3), padding='same', activation='sigmoid')(x)

model = Model(inputs, decoded)

In [7]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 32, 32, 3)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 32, 32, 32)        896       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 16, 16, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 16, 16, 64)        18496     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 8, 8, 128)         73856     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 4, 4, 128)         0         
__________

In [8]:
# Per-pixel regression trained with MSE. 'accuracy' is not used: with a 3-channel output and
# MSE loss, Keras 2 resolves it to categorical accuracy (argmax over the RGB channels), which
# says nothing about reconstruction quality. Mean absolute error is reported instead.
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mean_absolute_error'])

In [9]:
# Resume from the newest checkpoint, if one exists. File names embed a timestamp, so the
# lexicographically last name is treated as the most recent. `alive` also seeds the
# checkpoint rotation performed by the training loop.
alive = sorted(glob.glob(os.path.join(out_dir, '*.hdf5')))
if alive:
    latest_checkpoint = alive[-1]
    print("Loading weights from {}...".format(latest_checkpoint))
    model.load_weights(latest_checkpoint)

Loading weights from /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-23_23:38:47.222014.hdf5...


In [10]:
# Train by cycling over the training batch indices. After each batch, a timestamped checkpoint
# is written and only the newest KEEP_CHECKPOINTS files listed in `alive` are kept on disk.
NUM_BATCHES = 10        # batch indices passed to load(); load() maps index i to ytrain[i // 2]
EPOCHS_PER_BATCH = 10
FIT_BATCH_SIZE = 512
KEEP_CHECKPOINTS = 10
MAX_ROUNDS = None       # None = loop until the kernel is interrupted; an int bounds the run

os.makedirs(out_dir, exist_ok=True)
rounds_done = 0
while MAX_ROUNDS is None or rounds_done < MAX_ROUNDS:
    for i in range(NUM_BATCHES):
        print("Loading batch {}...".format(i))
        data = load(subset='train', batch=i)
        model.fit(data[0], data[1], batch_size=FIT_BATCH_SIZE, epochs=EPOCHS_PER_BATCH)
        del data  # release this batch's arrays before the next one is loaded
        # The timestamp is taken at save time. strftime always emits microseconds, whereas
        # str(datetime) drops them when they are zero, which would break the lexicographic
        # ordering the resume cell relies on.
        stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S.%f')
        filename = os.path.join(out_dir, "checkpoint_{}.hdf5".format(stamp))
        model.save_weights(filename)
        alive.append(filename)
        # Prune down to the cap; `while` also handles extra checkpoints found at start-up.
        while len(alive) > KEEP_CHECKPOINTS:
            try:
                os.remove(alive.pop(0))
            except FileNotFoundError:
                pass  # already deleted externally; don't abort a long training run over it
        print("Wrote", filename)
    rounds_done += 1

Loading batch 0...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_01:19:26.561837.hdf5
Loading batch 1...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_01:30:59.407137.hdf5
Loading batch 2...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_01:42:13.593739.hdf5
Loading batch 3...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_01:53:30.053416.hdf5
Loading batch 4...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/

Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_02:38:21.406958.hdf5
Loading batch 8...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_02:49:32.574535.hdf5
Loading batch 9...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_03:00:40.399453.hdf5
Loading batch 0...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_03:11:52.325569.hdf5
Loading batch 1...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-

Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_04:07:33.193194.hdf5
Loading batch 6...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_04:18:40.160579.hdf5
Loading batch 7...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_04:29:47.999684.hdf5
Loading batch 8...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_04:40:55.521041.hdf5
Loading batch 9...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/san

Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_05:25:26.978745.hdf5
Loading batch 3...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_05:36:34.093872.hdf5
Loading batch 4...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_05:47:41.995224.hdf5
Loading batch 5...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_05:58:49.630341.hdf5
Loading batch 6...
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Wrote /home/santiago/Projects/Enhancer/checkpoints/checkpoint_2017-08-24_06:10:03

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2862, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-10-551c3aecc65c>", line 7, in <module>
    model.fit(data[0], data[1], batch_size=512, epochs=10)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1430, in fit
    initial_epoch=initial_epoch)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1079, in _fit_loop
    outs = f(ins_batch)
  File "/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py", line 2268, in __call__
    **self.session_kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 789, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 997, in _run
    feed_dict_string, options, run_metadata)
  File "/usr/local/lib/python3.

KeyboardInterrupt: 

In [None]:
# Validation pairs: 64x64 targets read from `yval`, with 32x32 inputs made by downscaling them.
val_data = load(subset='val')

In [None]:
# Evaluate on the validation pairs and label each score with its name from model.metrics_names
# (loss first, then the compiled metrics) instead of printing a bare list.
val_scores = model.evaluate(val_data[0], val_data[1], batch_size=512)
if not isinstance(val_scores, list):  # evaluate() returns a bare scalar when no metrics are compiled
    val_scores = [val_scores]
print(", ".join("{}={:.6f}".format(name, score)
                for name, score in zip(model.metrics_names, val_scores)))

In [None]:
# Model outputs for every validation input, computed in batches of 512 with a progress bar.
predictions = model.predict(val_data[0], batch_size=512, verbose=1)

In [None]:
# Full-size image to upscale tile by tile. Pixels are converted to float32 and divided by 255,
# which maps 8-bit values into [0, 1] (assumes an 8-bit image).
wallpaper = numpy.asarray(skimage.io.imread("/home/santiago/Pictures/Wallpapers/412057.jpg"),
                          dtype=numpy.float32) / 255

In [None]:
# Split the wallpaper into non-overlapping 32x32x3 tiles (the model's input size).
# view_as_blocks raises unless every axis is an exact multiple of the block shape, so first
# normalize to exactly 3 channels, then crop the bottom/right edges to a multiple of the tile.
TILE = 32
if wallpaper.ndim == 2:
    # Grayscale image: replicate the single channel into R, G and B.
    wallpaper_rgb = numpy.repeat(wallpaper[:, :, numpy.newaxis], 3, axis=2)
else:
    # Keep the first three channels, which drops an alpha channel if one is present.
    wallpaper_rgb = wallpaper[:, :, :3]
tile_rows = wallpaper_rgb.shape[0] // TILE
tile_cols = wallpaper_rgb.shape[1] // TILE
wallpaper_rgb = numpy.ascontiguousarray(wallpaper_rgb[:tile_rows * TILE, :tile_cols * TILE, :])
blocks = skimage.util.view_as_blocks(wallpaper_rgb, (TILE, TILE, 3))
print(blocks.shape)  # (tile_rows, tile_cols, 1, TILE, TILE, 3)

In [None]:
# Upscale every tile in one batched predict() call instead of one call per tile, then stitch
# the (rows, cols) grid of upscaled tiles back into a single image.
rows, cols = blocks.shape[0], blocks.shape[1]
# blocks is (rows, cols, 1, h, w, 3); flatten the grid into a batch of (h, w, 3) tiles.
tiles = blocks.reshape((rows * cols * blocks.shape[2],) + blocks.shape[-3:])
upscaled = model.predict(tiles, batch_size=512)
out_h, out_w, out_c = upscaled.shape[1:]
# upscaled[r * cols + c] is tile (r, c). Reorder the axes to (row, y, col, x, channel) so that
# flattening yields result[r*out_h + y, c*out_w + x, :].
result = (upscaled.reshape((rows, cols, out_h, out_w, out_c))
          .transpose(0, 2, 1, 3, 4)
          .reshape((rows * out_h, cols * out_w, out_c)))

In [None]:
# Write the stitched, upscaled image to disk.
# NOTE(review): the results/ directory is presumably expected to exist already; confirm.
skimage.io.imsave("/home/santiago/Projects/Enhancer/results/super.png", result)