In [1]:
from __future__ import division
import rawpy
import numpy as np
from matplotlib.pyplot import imshow
from scipy.misc import imread
import tensorflow as tf
import tensorflow.contrib.slim as slim
import glob
import os, time, scipy.io

In [2]:
# get ids of training set
input_dir = './dataset/HUAWEIMATE20/'
#input_dir = './dataset/NIKOND70/'
checkpoint_dir = './result/'
result_dir = './result/'
save_freq = 2
train_fns = glob.glob(input_dir + 'a*.dng')
train_ids = [int(os.path.basename(train_fn)[1:5]) for train_fn in train_fns]
#train_ids = train_ids[0:5]
# Initialize crop size
ps = 512

In [3]:
def relu(x):
    return tf.maximum(x * 0.2, x)

In [4]:
# data preprocessing
def pack_raw(raw):
    # pack Bayer image to 4 channels
    im = raw.raw_image_visible.astype(np.float32)
    print("black_level: ", raw.black_level_per_channel)
    # black level of NIKON D70 is 128, and the raw image is in 12 bits.
    im = np.maximum(im, 0) / 4096
    im = np.expand_dims(im, axis=2)
    img_shape = im.shape
    H = img_shape[0]
    W = img_shape[1]

    out = np.concatenate((im[0:H:2, 0:W:2, :],
                          im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :],
                          im[1:H:2, 0:W:2, :]), axis=2)
    return out

In [5]:
def upsample_and_concat(x1, x2, output_channels, in_channels):
    pool_size = 2
    deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))
    deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2), strides=[1, pool_size, pool_size, 1])

    deconv_output = tf.concat([deconv, x2], 3)
    deconv_output.set_shape([None, None, None, output_channels * 2])

    return deconv_output

In [6]:
def network(data):
    print("1")
    conv1 = slim.conv2d(data, 32, [3, 3], rate=1, activation_fn=relu, scope='g_conv1_1')
    conv1 = slim.conv2d(conv1, 32, [3, 3], rate=1, activation_fn=relu, scope='g_conv1_2')
    pool1 = slim.max_pool2d(conv1, [2, 2], padding='SAME')

    conv2 = slim.conv2d(pool1, 64, [3, 3], rate=1, activation_fn=relu, scope='g_conv2_1')
    conv2 = slim.conv2d(conv2, 64, [3, 3], rate=1, activation_fn=relu, scope='g_conv2_2')
    pool2 = slim.max_pool2d(conv2, [2, 2], padding='SAME')

    conv3 = slim.conv2d(pool2, 128, [3, 3], rate=1, activation_fn=relu, scope='g_conv3_1')
    conv3 = slim.conv2d(conv3, 128, [3, 3], rate=1, activation_fn=relu, scope='g_conv3_2')
    pool3 = slim.max_pool2d(conv3, [2, 2], padding='SAME')

    conv4 = slim.conv2d(pool3, 256, [3, 3], rate=1, activation_fn=relu, scope='g_conv4_1')
    conv4 = slim.conv2d(conv4, 256, [3, 3], rate=1, activation_fn=relu, scope='g_conv4_2')
    pool4 = slim.max_pool2d(conv4, [2, 2], padding='SAME')

    conv5 = slim.conv2d(pool4, 512, [3, 3], rate=1, activation_fn=relu, scope='g_conv5_1')
    conv5 = slim.conv2d(conv5, 512, [3, 3], rate=1, activation_fn=relu, scope='g_conv5_2')

    up6 = upsample_and_concat(conv5, conv4, 256, 512)
    conv6 = slim.conv2d(up6, 256, [3, 3], rate=1, activation_fn=relu, scope='g_conv6_1')
    conv6 = slim.conv2d(conv6, 256, [3, 3], rate=1, activation_fn=relu, scope='g_conv6_2')

    up7 = upsample_and_concat(conv6, conv3, 128, 256)
    conv7 = slim.conv2d(up7, 128, [3, 3], rate=1, activation_fn=relu, scope='g_conv7_1')
    conv7 = slim.conv2d(conv7, 128, [3, 3], rate=1, activation_fn=relu, scope='g_conv7_2')

    up8 = upsample_and_concat(conv7, conv2, 64, 128)
    conv8 = slim.conv2d(up8, 64, [3, 3], rate=1, activation_fn=relu, scope='g_conv8_1')
    conv8 = slim.conv2d(conv8, 64, [3, 3], rate=1, activation_fn=relu, scope='g_conv8_2')

    up9 = upsample_and_concat(conv8, conv1, 32, 64)
    conv9 = slim.conv2d(up9, 32, [3, 3], rate=1, activation_fn=relu, scope='g_conv9_1')
    conv9 = slim.conv2d(conv9, 32, [3, 3], rate=1, activation_fn=relu, scope='g_conv9_2')

    conv10 = slim.conv2d(conv9, 12, [1, 1], rate=1, activation_fn=None, scope='g_conv10')
    out = tf.depth_to_space(conv10, 2)
    return out



In [7]:
# Create a session.
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
# Reserve memory in the flow for input and label.
input_image = tf.placeholder(tf.float32, [None, None, None, 4])
target_image = tf.placeholder(tf.float32, [None, None, None, 3])
output_image = network(input_image)
# Define the loss function.
G_loss = tf.reduce_mean(tf.abs(output_image - target_image)*255)

t_vars = tf.trainable_variables()
lr = tf.placeholder(tf.float32)
# Define optimizer for the flow.
G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss)
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())

ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt:
    print('loaded ' + ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)

input_images = [None] * len(train_ids)
target_images = [None] * len(train_ids)
learning_rate = 1e-4
g_loss = np.zeros((5000, 1))


for epoch in range(4001):
    if epoch > 2000:
        learning_rate = 1e-5
    i = 0
    for ind in np.random.permutation(len(train_ids)):
        if i == 5:
            break
        # get the path from image id
        train_id = train_ids[ind]
        in_files = glob.glob(input_dir + 'a%04d*.dng' % train_id)
        in_path = in_files[0]
        
        if input_images[ind] is None:
            raw = rawpy.imread(in_path)
            input_images[ind] = np.expand_dims(pack_raw(raw), axis=0)
            print("before precrocessing: ", input_images[ind].shape)
            im = raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
            target_images[ind] = np.expand_dims(np.float32(im / 65535.0), axis=0)
            print("after precrocessing: ", target_images[ind].shape)

        H = input_images[ind].shape[1]
        W = input_images[ind].shape[2]
        xx = np.random.randint(0, W - ps)
        yy = np.random.randint(0, H - ps)

        input_patch = input_images[ind][:, yy:yy + ps, xx:xx + ps, :]
        target_patch = target_images[ind][:, yy * 2:yy * 2 + ps * 2, xx * 2:xx * 2 + ps * 2, :]
        print('target_shape', target_patch.shape)

        if np.random.randint(2, size=1)[0] == 1:  # random flip
            input_patch = np.flip(input_patch, axis=1)
            target_patch = np.flip(target_patch, axis=1)
        if np.random.randint(2, size=1)[0] == 1:
            input_patch = np.flip(input_patch, axis=2)
            target_patch = np.flip(target_patch, axis=2)
        if np.random.randint(2, size=1)[0] == 1:  # random transpose
            input_patch = np.transpose(input_patch, (0, 2, 1, 3))
            target_patch = np.transpose(target_patch, (0, 2, 1, 3))
            
        input_patch = np.minimum(input_patch, 1.0)
        print('input_shape', input_patch.shape)
        print('target_shape', target_patch.shape)
            
        _, G_current, output = sess.run([G_opt, G_loss, output_image],
                                        feed_dict={input_image: input_patch, target_image: target_patch, lr: learning_rate})
        output = np.minimum(np.maximum(output, 0), 1)
        
        print('output_shape', output.shape)
        g_loss[ind] = G_current
        print('loss', G_current)
        print("%d Loss=%.3f" % (epoch, np.mean(g_loss[np.where(g_loss)])))
        if epoch % save_freq == 0:
            if not os.path.isdir(result_dir + '%04d' % epoch):
                os.makedirs(result_dir + '%04d' % epoch)
            print(target_patch*255)
            print(output*255)
            temp = np.concatenate((target_patch[0, :, :, :], output[0, :, :, :]), axis=1)
            scipy.misc.toimage(temp * 255, high=255, low=0, cmin=0, cmax=255).save(
                result_dir + '%04d/%05d_00_train.jpg' % (epoch, train_id))
        i += 1
    saver.save(sess, checkpoint_dir + 'model.ckpt')
            

1
Instructions for updating:
Colocations handled automatically by placer.
loaded ./result/model.ckpt
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from ./result/model.ckpt
black_level:  [0, 0, 0, 0]
before precrocessing:  (1, 1488, 1984, 4)
after precrocessing:  (1, 2976, 3968, 3)
target_shape (1, 1024, 1024, 3)
input_shape (1, 512, 512, 4)
target_shape (1, 1024, 1024, 3)
output_shape (1, 1024, 1024, 3)
loss 0.053660274
0 Loss=0.054
[[[[ 48.793774    38.684822     0.15564202]
   [ 48.793774    47.22568     46.396885  ]
   [ 47.287937    51.548637    66.54864   ]
   ...
   [ 32.505836    37.031128    30.77821   ]
   [ 45.287937    47.66926     41.400776  ]
   [ 45.027237    45.494164    50.945526  ]]

  [[ 39.233463    40.447468    43.346302  ]
   [ 44.618675    48.12451     54.653694  ]
   [ 40.42412     45.59922     56.4358    ]
   ...
   [ 33.41245     36.836575    28.92607   ]
   [ 33.59533     38.77821   

`toimage` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use Pillow's ``Image.fromarray`` directly instead.


black_level:  [0, 0, 0, 0]
before precrocessing:  (1, 1488, 1984, 4)
after precrocessing:  (1, 2976, 3968, 3)
target_shape (1, 1024, 1024, 3)
input_shape (1, 512, 512, 4)
target_shape (1, 1024, 1024, 3)


KeyboardInterrupt: 