In [None]:
# Reload edited modules (e.g. helpers/, models/) automatically before each cell executes
%load_ext autoreload
%autoreload 2

In [None]:
import os
import numpy as np
import tensorflow as tf
import tqdm

from collections import deque
from skimage.transform import resize

# Own libraries and modules
from helpers import loading, plotting, utils
from models import compression

In [None]:
# Training hyper-parameters shared by the data loader and the training loop
training = dict(
    n_epochs=3000,
    batch_size=32,
    patch_size=128,
)

In [None]:
# Compression budget for a single RGB patch.
# The patch size comes from the training config so the two values cannot drift apart.
patch_size = training['patch_size']
n_latent = 256        # dimensionality of the latent vector
n_latent_bytes = 2    # bytes used to store each latent value
# NOTE(review): the network instantiated further below uses n_latent=512, so these
# figures may not describe it - confirm which latent size is intended.

bitmap_size = patch_size * patch_size * 3                      # 8-bit RGB patch, in bytes
latent_size = n_latent * n_latent_bytes                        # latent representation, in bytes
compression_rate = bitmap_size / latent_size
compression_bpp = 8 * latent_size / (patch_size * patch_size)  # bits per pixel

print('Bitmap size: {:,d} bytes'.format(bitmap_size))
print('Latent size: {:,}-D'.format(n_latent))
print('Latent repr: {:,} bytes'.format(latent_size))
print('Compression rate: 1:{}'.format(compression_rate))
print('Compression Fi  : {:.2f} bpp'.format(compression_bpp))

## Load the dataset

In [None]:
class IPDataset:
    """Training/validation image data for the compression experiments.

    Training data are whole images loaded with `loading.load_images`; random
    square patches are cut from their 'y' arrays on demand. Validation data are
    fixed patches extracted once with `loading.load_patches`.
    """

    def __init__(self, data_directory, randomize=False, load='xy', val_patch_size=128, val_n_patches=10):
        """Discover and load the dataset.

        Args:
            data_directory: directory scanned by `loading.discover_files`.
            randomize: forwarded to `loading.discover_files`.
            load: which arrays to load, forwarded to both loaders (e.g. 'xy').
            val_patch_size: validation patch size; `val_patch_size // 2` is what is
                passed to `loading.load_patches` as its `patch_size`.
            val_n_patches: forwarded to `loading.load_patches` as `n_patches`.
        """
        self.files = {}
        self.files['training'], self.files['validation'] = loading.discover_files(data_directory, randomize=randomize)

        self.data = {
            'training': loading.load_images(self.files['training'], data_directory=data_directory, load=load),
            # NOTE(review): presumably load_patches' patch_size refers to the 'x' arrays and the
            # 'y' patches come out twice as large (hence the // 2) - confirm against helpers.loading.
            'validation': loading.load_patches(self.files['validation'], data_directory=data_directory, patch_size=val_patch_size // 2, n_patches=val_n_patches, load=load, discard_flat=True)
        }

        # NOTE(review): the spatial size is read from 'x' while training crops are cut from 'y';
        # confirm both share the same height/width (otherwise only part of 'y' is ever sampled).
        self.H, self.W = self.data['training']['x'].shape[1:3]

    def __getitem__(self, key):
        """Return the data dict of the 'training' or 'validation' split.

        Raises:
            KeyError: for any other key. (`object` has no `__getitem__`, so the
                previous `super()` delegation raised AttributeError instead.)
        """
        if key in ('training', 'validation'):
            return self.data[key]
        raise KeyError(key)

    def next_training_batch(self, batch_id, batch_size, patch_size):
        """Cut one random square patch from each image of a training batch.

        Args:
            batch_id: batch index; images [batch_id * batch_size, (batch_id + 1) * batch_size)
                of the training 'y' array are used.
            batch_size: number of patches to return.
            patch_size: side length of each patch in pixels (must not exceed H or W).

        Returns:
            float32 array of shape (batch_size, patch_size, patch_size, 3) holding the
            'y' pixel values divided by 255.
        """
        patch_size = int(patch_size)
        images = self.data['training']['y']
        batch_x = np.zeros((batch_size, patch_size, patch_size, 3), dtype=np.float32)
        for b in range(batch_size):
            # randint's upper bound is exclusive; +1 so the last valid offset (W - patch_size)
            # can be drawn and a patch as large as the image does not raise
            xx = np.random.randint(0, self.W - patch_size + 1)
            yy = np.random.randint(0, self.H - patch_size + 1)
            batch_x[b, :, :, :] = images[batch_id * batch_size + b, yy:yy + patch_size, xx:xx + patch_size, :].astype(np.float64) / (2**8 - 1)
        return batch_x

    def next_validation_batch(self, batch_id, batch_size):
        """Return `batch_size` consecutive validation patches as a float32 array.

        Unlike `next_training_batch`, values are not rescaled here.
        NOTE(review): this assumes `load_patches` already returns 'y' scaled to
        [0, 1] - confirm, otherwise validation and training inputs differ in range.
        """
        patches = self.data['validation']['y']
        patch_size = patches.shape[1]
        batch_x = np.zeros((batch_size, patch_size, patch_size, 3), dtype=np.float32)
        for b in range(batch_size):
            # Assignment casts straight to float32 (np.float was removed in NumPy 1.24)
            batch_x[b, :, :, :] = patches[batch_id * batch_size + b]
        return batch_x
        

In [None]:
# Load the training/validation data captured with the selected camera
camera_name = "Nikon D90"
data_directory = os.path.join('./data/raw/nip_training_data/', camera_name)

data = IPDataset(data_directory)

# Report the shape of the 'y' arrays for each split
for split in ('training', 'validation'):
    print('{} : {}'.format(split, data[split]['y'].shape))

## Define the Deep Compression Network Models

In [None]:
class AutoencoderDCN(compression.DCN):
    """Convolutional autoencoder used as a deep compression network (DCN).

    Encoder: `n_layers` stride-2 convolutions, then a dense layer producing the
    `n_latent`-D latent vector. Decoder: a dense layer back to the flattened
    encoder size, a reshape to the encoder's feature-map shape, then `n_layers`
    stride-2 transposed convolutions ending in a 3-channel sigmoid output.
    """

    def construct_model(self):
        """Build the network and its L2 reconstruction objective.

        Returns:
            (y, lr, loss, adam, opt, latent): reconstruction tensor, learning-rate
            placeholder, `tf.nn.l2_loss` of the reconstruction error, the Adam
            optimizer, its minimize op, and the latent tensor.
        """
        with tf.name_scope('dcn'):

            activation = tf.nn.leaky_relu
            last_activation = tf.nn.sigmoid

            # Report the latent size this instance is actually built with
            print('Building Deep Compression Network with d-latent={}'.format(self.n_latent))

            net = self.x
            print('net size: {}'.format(net.shape))

            # Encoder: stride-2 convolutions, scaling the filter count by n_fscale after each one
            n_filters = self.n_filters

            for r in range(self.n_layers):
                net = tf.contrib.layers.conv2d(net, n_filters, self.kernel, stride=2, activation_fn=activation, scope='dcn{}/conv_{}'.format(self.label, r))
                print('net size: {}'.format(net.shape))
                n_filters *= self.n_fscale

            # Flatten and get latent representation
            flat = tf.contrib.layers.flatten(net, scope='dcn{}/flat_{}'.format(self.label, 0))
            print('net size: {}'.format(flat.shape))

            latent = tf.contrib.layers.fully_connected(flat, self.n_latent, activation_fn=activation, scope='dcn{}/dense_{}'.format(self.label, 0))
            print('net size: {}'.format(latent.shape))

            # Decoder: expand the latent back to the flattened encoder size ...
            inet = tf.contrib.layers.fully_connected(latent, int(flat.shape[-1]), activation_fn=activation, scope='dcn{}/dense_{}'.format(self.label, 1))
            print('net size: {}'.format(inet.shape))
            # ... and fold it into the encoder's feature-map shape. This must reshape `inet`:
            # reshaping `net` would feed the encoder output straight into the decoder and
            # bypass the latent bottleneck entirely.
            inet = tf.reshape(inet, tf.shape(net), name='dcn{}/reshape_{}'.format(self.label, 0))
            print('net size: {}'.format(inet.shape))

            # Transposed convolutions; the last one outputs 3 channels through a sigmoid
            for r in range(self.n_layers):
                inet = tf.contrib.layers.conv2d_transpose(inet, 3 if r == self.n_layers - 1 else n_filters, self.kernel, stride=2,
                                                          activation_fn=last_activation if r == self.n_layers - 1 else activation,
                                                          scope='dcn{}/tconv_{}'.format(self.label, r))
                print('net size: {}'.format(inet.shape))
                n_filters = n_filters // self.n_fscale

            y = inet

        with tf.name_scope('dcn{}_optimization'.format(self.label)):
            lr = tf.placeholder(tf.float32, name='dcn_learning_rate')
            loss = tf.nn.l2_loss(self.x - y)
            adam = tf.train.AdamOptimizer(learning_rate=lr)
            opt = adam.minimize(loss, var_list=self.parameters)

        return y, lr, loss, adam, opt, latent

In [None]:
class ResDCN(compression.DCN):
    """Deep compression network with an extra 64-filter stem convolution.

    The input passes through a stride-1 stem convolution, then `n_layers` stride-2
    convolutions and a dense `n_latent`-D latent layer. The decoder mirrors this with
    a dense layer, a reshape, and `n_layers` stride-2 transposed convolutions ending
    in a 3-channel sigmoid output.
    NOTE(review): despite the name, no residual/skip connections are built here.
    """

    def construct_model(self):
        """Build the network and its L2 reconstruction objective.

        Returns:
            (y, lr, loss, adam, opt, latent): reconstruction tensor, learning-rate
            placeholder, `tf.nn.l2_loss` of the reconstruction error, the Adam
            optimizer, its minimize op, and the latent tensor.
        """
        with tf.name_scope('dcn'):

            # `tf.nn.ResDCN` does not exist (AttributeError); use the same activation as AutoencoderDCN
            activation = tf.nn.leaky_relu
            last_activation = tf.nn.sigmoid

            # Report the latent size this instance is actually built with
            print('Building Deep Compression Network with d-latent={}'.format(self.n_latent))

            net = self.x
            print('net size: {}'.format(net.shape))

            n_filters = self.n_filters

            # Stem convolution. It needs its own scope, because `conv_0` is used by the first
            # encoder layer below and would collide. It also uses stride 1, so that the decoder's
            # n_layers stride-2 transposed convolutions restore the input resolution the loss compares against.
            net = tf.contrib.layers.conv2d(net, 64, self.kernel, stride=1, activation_fn=activation, scope='dcn{}/conv_in'.format(self.label))
            print('net size: {}'.format(net.shape))

            # Encoder: stride-2 convolutions, scaling the filter count by n_fscale after each one
            for r in range(self.n_layers):
                net = tf.contrib.layers.conv2d(net, n_filters, self.kernel, stride=2, activation_fn=activation, scope='dcn{}/conv_{}'.format(self.label, r))
                print('net size: {}'.format(net.shape))
                n_filters *= self.n_fscale

            # Flatten and get latent representation
            flat = tf.contrib.layers.flatten(net, scope='dcn{}/flat_{}'.format(self.label, 0))
            print('net size: {}'.format(flat.shape))

            latent = tf.contrib.layers.fully_connected(flat, self.n_latent, activation_fn=activation, scope='dcn{}/dense_{}'.format(self.label, 0))
            print('net size: {}'.format(latent.shape))

            # Decoder: expand the latent back to the flattened encoder size ...
            inet = tf.contrib.layers.fully_connected(latent, int(flat.shape[-1]), activation_fn=activation, scope='dcn{}/dense_{}'.format(self.label, 1))
            print('net size: {}'.format(inet.shape))
            # ... and fold it into the encoder's feature-map shape. Reshaping `net` here would
            # bypass the latent bottleneck entirely.
            inet = tf.reshape(inet, tf.shape(net), name='dcn{}/reshape_{}'.format(self.label, 0))
            print('net size: {}'.format(inet.shape))

            # Transposed convolutions; the last one outputs 3 channels through a sigmoid
            for r in range(self.n_layers):
                inet = tf.contrib.layers.conv2d_transpose(inet, 3 if r == self.n_layers - 1 else n_filters, self.kernel, stride=2,
                                                          activation_fn=last_activation if r == self.n_layers - 1 else activation,
                                                          scope='dcn{}/tconv_{}'.format(self.label, r))
                print('net size: {}'.format(inet.shape))
                n_filters = n_filters // self.n_fscale

            y = inet

        with tf.name_scope('dcn{}_optimization'.format(self.label)):
            lr = tf.placeholder(tf.float32, name='dcn_learning_rate')
            loss = tf.nn.l2_loss(self.x - y)
            adam = tf.train.AdamOptimizer(learning_rate=lr)
            opt = adam.minimize(loss, var_list=self.parameters)

        return y, lr, loss, adam, opt, latent

## Create DCN instance

In [None]:
# Build the compression network in its own graph and session
graph = tf.Graph()
sess = tf.Session(graph=graph)

# Patch size comes from the shared notebook config instead of a second hard-coded 128.
# NOTE(review): n_latent=512 differs from the n_latent used for the compression estimates
# in the config cell (256) - confirm which value is intended.
dcn = AutoencoderDCN(sess, graph, patch_size=patch_size, n_latent=512, n_layers=3, n_fscale=1, n_filters=16)

print(dcn.summary())
print('Compression stats:', dcn.compression_stats())

## Training

In [None]:
# Preview the resize augmentation: cut 256 px patches and downscale them to patch_size.
# Use an explicit batch index - `batch_id` is only defined once the training loop has run,
# so relying on it fails with NameError on a fresh kernel.
preview_batch_id = 0
batch_x = data.next_training_batch(preview_batch_id, training['batch_size'], 256)

print(batch_x.shape)
batch_t = np.zeros((batch_x.shape[0], patch_size, patch_size, 3), dtype=np.float32)

for i in range(len(batch_x)):
    batch_t[i] = resize(batch_x[i], [patch_size, patch_size], anti_aliasing=True)

print(batch_t.shape)

f = plotting.imsc(batch_x[0:8], ncols=8, figwidth=20)
f = plotting.imsc(batch_t[0:8], ncols=8, figwidth=20)

In [None]:
# Initialize the network's variables
dcn.init()

# Compute the number of available batches
n_batches = data['training']['y'].shape[0] // training['batch_size']
v_batches = data['validation']['y'].shape[0] // training['batch_size']
assert n_batches > 0, 'fewer training images than one batch'

loss = {'training': [], 'validation': []}
loss_ma = deque(maxlen=n_batches)   # training losses of the most recent n_batches steps
loss_va = deque(maxlen=v_batches)   # validation losses of the most recent v_batches batches

# Configure data augmentation (probability of applying each transform to a batch)
augmentation_probs = {
    'resize': 0.0,
    'flip_h': 0.0,
    'flip_v': 0.0
}

with tqdm.tqdm(total=training['n_epochs'], ncols=120, desc='Train') as pbar:

    for epoch in range(training['n_epochs']):

        # Iterate through batches of the training data
        for batch_id in range(n_batches):

            # Pick a random larger patch size that is resized back to patch_size below.
            # It must be a plain int: np.random.choice(..., 1) returned a 1-element array,
            # which is then used as an array dimension.
            if np.random.uniform() < augmentation_probs['resize']:
                current_patch = int(np.random.randint(128, 256))
            else:
                current_patch = patch_size

            # Sample next batch
            batch_x = data.next_training_batch(batch_id, training['batch_size'], current_patch)

            # If rescaling needed, apply
            if patch_size != current_patch:
                batch_t = np.zeros((batch_x.shape[0], patch_size, patch_size, 3), dtype=np.float32)
                for i in range(len(batch_x)):
                    batch_t[i] = resize(batch_x[i], [patch_size, patch_size], anti_aliasing=True)
                batch_x = batch_t

            # Data augmentation - random horizontal / vertical flip of the whole batch
            if np.random.uniform() < augmentation_probs['flip_h']: batch_x = batch_x[:, :, ::-1, :]
            if np.random.uniform() < augmentation_probs['flip_v']: batch_x = batch_x[:, ::-1, :, :]

            # Make a training step (learning rate 1e-4)
            loss_value = dcn.training_step(batch_x, 1e-4)
            loss_ma.append(loss_value)

        # Iterate through batches of the validation data
        for batch_id in range(v_batches):
            batch_x = data.next_validation_batch(batch_id, training['batch_size'])
            batch_y = dcn.process(batch_x)
            # Same definition as tf.nn.l2_loss (sum of squares / 2) used as the training
            # objective, so both curves are on the same scale (a plain L2 norm was not).
            # NOTE(review): assumes dcn.training_step reports that l2_loss value - confirm.
            loss_value = np.sum(np.square(batch_x - batch_y)) / 2
            loss_va.append(loss_value)

        # Record average values for the whole epoch
        loss['training'].append(np.mean(loss_ma))
        loss['validation'].append(np.mean(loss_va))

        # Show the latest epoch's values (not the mean over the entire history)
        pbar.set_postfix(loss=loss['training'][-1], loss_v=loss['validation'][-1])
        pbar.update(1)

In [None]:
import matplotlib.pyplot as plt
from helpers import utils

# Training vs. validation loss per epoch: curves smoothed with utils.ma_conv (n=11),
# plus the raw per-epoch values as faint dots drawn in the same colour as their curve
series = (('training', 'C0'), ('validation', 'C1'))

fig, ax = plt.subplots(figsize=(20, 4))
for split, color in series:
    ax.semilogy(utils.ma_conv(loss[split], n=11), color=color)
for split, color in series:
    ax.semilogy(loss[split], '.', alpha=0.3, color=color)
# Legend entries map onto the first two lines drawn, i.e. the smoothed curves
ax.legend(['train', 'valid'], loc='upper right')
ax.set_xlabel('epoch')
ax.set_ylabel('loss');

In [None]:
batch_id = (1 + batch_id) % n_batches  # advance the leftover loop index, wrapping modulo n_batches
batch_x = data.next_training_batch(batch_id, batch_size=training['batch_size'], patch_size=patch_size)
fig = plotting.imsc(batch_x[0:8], figwidth=20, ncols=8)

In [None]:
# Run the current batch through the network; display 8 inputs, then their reconstructions
batch_y = dcn.process(batch_x)
f = plotting.imsc(batch_x[:8], figwidth=20, ncols=8)   # input patches
f = plotting.imsc(batch_y[:8], figwidth=20, ncols=8)   # reconstructions

In [None]:
from helpers import plotting

# Histogram of latent values for the first 4 samples of the current batch
batch_z = dcn.compress(batch_x)
batch_z = batch_z[:4]

fig, axes = plotting.sub(len(batch_z), ncols=10, figwidth=20)

# zip stops at the shorter sequence, so this also works if plotting.sub returns more
# axes than samples (indexing batch_z[i] for every axis would raise IndexError then)
for ax, z in zip(axes, batch_z):
    ax.hist(z, bins=30)
    ax.set_yticks([])
