In [1]:
%reload_ext autoreload
%autoreload 2

import tensorflow as tf
from keras import backend as K
import subtle.utils.io as suio
import matplotlib.pyplot as plt
import keras
import numpy as np
plt.set_cmap('gray')
plt.rcParams['figure.figsize'] = (10, 8)

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

from keras.applications.vgg19 import VGG19
from keras.applications.imagenet_utils import preprocess_input as vgg_preprocess
from keras.models import Model

def extract_image_patches(x, ksizes, ssizes, padding='same', data_format='channels_last'):
    """Split a channels-last 4-D tensor into patches exposed as separate axes.

    Args:
        x: 4-D tensor ``(batch, rows, cols, channels)``; the channel count must
            be statically known (it is read with ``K.int_shape``).
        ksizes: ``(k_rows, k_cols)`` patch size.
        ssizes: ``(s_rows, s_cols)`` stride between patch origins.
        padding: ``'same'`` or ``'valid'`` in any letter case; it is upper-cased
            because ``tf.extract_image_patches`` only accepts 'SAME' / 'VALID'.
        data_format: Currently ignored -- ``x`` is always treated as
            channels-last (the kernel/stride vectors below assume NHWC).

    Returns:
        6-D tensor ``(batch, out_rows, out_cols, k_rows, k_cols, channels)``.
    """
    ch_in = K.int_shape(x)[-1]
    kernel = [1, ksizes[0], ksizes[1], 1]
    strides = [1, ssizes[0], ssizes[1], 1]
    rates = [1, 1, 1, 1]

    patches = tf.extract_image_patches(x, kernel, strides, rates, padding.upper())

    # The last axis holds k_rows * k_cols * channels with channels varying
    # fastest, so it splits into (k_rows * k_cols, channels).
    _, out_rows, out_cols, flat = K.int_shape(patches)
    patches = tf.reshape(patches, [-1, out_rows, out_cols, flat // ch_in, ch_in])

    # (..., k_rows*k_cols, channels) -> (..., channels, k_rows*k_cols)
    # -> (..., channels, k_rows, k_cols)
    patches = tf.transpose(patches, [0, 1, 2, 4, 3])
    patches = tf.reshape(patches, [-1, out_rows, out_cols, ch_in, ksizes[0], ksizes[1]])

    # Move channels back to the last axis: (..., k_rows, k_cols, channels).
    return K.permute_dimensions(patches, [0, 1, 2, 4, 5, 3])

def ssim_loss(y_true, y_pred, kernel=(3, 3), k1=.01, k2=.03, kernel_size=3, max_value=1.):
    """DSSIM loss ``mean((1 - SSIM) / 2)`` over non-overlapping patches.

    Both inputs are cut into ``kernel``-sized patches using ``kernel`` as the
    stride as well (so patches do not overlap) and 'VALID' padding. For every
    patch, mean / variance / covariance are taken over all pixels and channels
    of the patch and the SSIM formula is applied.

    This function only builds symbolic ops (no ``K.eval``), so it can be used
    as a Keras loss on placeholder inputs.

    Args:
        y_true: Reference image tensor; must be channels-last 4-D because it is
            passed to ``extract_image_patches``.
        y_pred: Predicted image tensor with the same shape as ``y_true``.
        kernel: ``(rows, cols)`` patch size, also used as the stride.
        k1, k2: SSIM stabilisation constants, ``C1 = (k1 * L)^2`` and
            ``C2 = (k2 * L)^2``.
        kernel_size: Unused; kept only for call-site compatibility.
        max_value: Dynamic range ``L`` of the pixel values.

    Returns:
        Scalar tensor: mean DSSIM over all patches and batch items.
    """
    # SSIM stabilisation constants.
    cc1 = (k1 * max_value) ** 2
    cc2 = (k2 * max_value) ** 2

    # Shape-preserving reshape that keeps the batch dimension dynamic (-1).
    y_true = K.reshape(y_true, [-1] + list(K.int_shape(y_true)[1:]))
    y_pred = K.reshape(y_pred, [-1] + list(K.int_shape(y_pred)[1:]))

    # (batch, out_rows, out_cols, k_rows, k_cols, channels), stride == kernel.
    patches_true = extract_image_patches(y_true, kernel, kernel, 'VALID', K.image_data_format())
    patches_pred = extract_image_patches(y_pred, kernel, kernel, 'VALID', K.image_data_format())

    # Flatten each patch (pixels x channels) into the last axis.
    bs, w, h, c1, c2, c3 = K.int_shape(patches_pred)
    patches_true = K.reshape(patches_true, [-1, w, h, c1 * c2 * c3])
    patches_pred = K.reshape(patches_pred, [-1, w, h, c1 * c2 * c3])

    # Per-patch means.
    u_true = K.mean(patches_true, axis=-1)
    u_pred = K.mean(patches_pred, axis=-1)

    # Per-patch variances.
    var_true = K.var(patches_true, axis=-1)
    var_pred = K.var(patches_pred, axis=-1)

    # Per-patch covariance: E[XY] - E[X]E[Y].
    covar_true_pred = K.mean(patches_true * patches_pred, axis=-1) - u_true * u_pred

    # SSIM, then DSSIM averaged over everything.
    ssim = (2 * u_true * u_pred + cc1) * (2 * covar_true_pred + cc2)
    denom = (K.square(u_true) + K.square(u_pred) + cc1) * (var_pred + var_true + cc2)
    ssim /= denom

    return K.mean((1.0 - ssim) / 2.0)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
Using TensorFlow backend.


<Figure size 432x288 with 0 Axes>

In [2]:
def _to_vgg_input(slice_1c, resize_shape):
    """Turn a single-channel slice into a preprocessed 3-channel VGG input."""
    if resize_shape > 0:
        # For 512x512 images, VGG-19 creates some grid artifacts because the
        # original network is trained with 224x224 images
        slice_1c = tf.image.resize(slice_1c, (resize_shape, resize_shape))
    # Replicate the grey channel three times to fake an RGB image.
    slice_3c = K.concatenate([slice_1c, slice_1c, slice_1c])
    return vgg_preprocess(slice_3c, mode='caffe')


def perceptual_loss(y_true, y_pred, img_shape, resize_shape):
    """VGG19 feature-space MSE between ``y_true`` and ``y_pred``, averaged over slices.

    Every index of the last axis is treated as one single-channel image. Each
    such slice is optionally resized, replicated to 3 channels, preprocessed
    with Keras' ImageNet ``preprocess_input`` (``mode='caffe'``) and passed
    through a frozen VGG19 truncated at ``block3_conv3``. The per-slice MSE of
    the two feature maps is averaged over all slices.

    Adapted from https://bit.ly/2HTb4t9

    NOTE(review): a new VGG19 with ImageNet weights is built on every call.

    Args:
        y_true: Reference tensor; presumably (batch, H, W, slices) -- TODO confirm.
        y_pred: Prediction tensor with the same layout as ``y_true``.
        img_shape: ``input_shape`` given to VGG19, e.g. ``(512, 512, 3)``.
            NOTE(review): when ``resize_shape > 0`` this presumably should be
            ``(resize_shape, resize_shape, 3)`` -- confirm.
        resize_shape: If > 0, resize each slice to
            ``(resize_shape, resize_shape)`` before VGG; 0 disables resizing.

    Returns:
        Scalar tensor: mean over slices of the per-slice feature MSE.
    """
    num_slices = int(y_pred.shape[-1])
    print('num slices', num_slices)

    vgg = VGG19(include_top=False, weights='imagenet', input_shape=img_shape)
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False

    loss_vals = []
    for idx in range(num_slices):
        feat_true = loss_model(_to_vgg_input(K.expand_dims(y_true[..., idx]), resize_shape))
        feat_pred = loss_model(_to_vgg_input(K.expand_dims(y_pred[..., idx]), resize_shape))
        loss_vals.append(K.mean(K.square(feat_true - feat_pred)))

    return tf.math.reduce_mean(tf.stack(loss_vals))

In [7]:
# Load two axial slices (150 and 155) of the same patient; index 0/1/2 of
# each array is used below as pre / low / full.
# NOTE(review): absolute, user-specific path -- consider making it configurable.
SLICE_DIR = '/home/srivathsa/projects/studies/gad/stanford/preprocess/slices/Patient_0088/ax'

data = np.load(os.path.join(SLICE_DIR, '150.npy'))
data2 = np.load(os.path.join(SLICE_DIR, '155.npy'))

pre, low, full = data[0], data[1], data[2]
full2 = data2[2]


def _as_batch_tensor(img):
    """Wrap a 2-D image as a float32 Keras constant of shape (1, H, W, 1)."""
    return K.constant(img[None, ..., None].astype(np.float32))


ip1, ip2, ip3, ip4 = (_as_batch_tensor(img) for img in (pre, low, full, full2))

print(ip1.shape)

(1, 512, 512, 1)


In [10]:
# Perceptual distance between the `full` images of slices 150 (ip3) and 155 (ip4).
# Derive the VGG input shape from the data rather than hard-coding 512x512.
vgg_input_shape = tuple(K.int_shape(ip3)[1:3]) + (3,)
vgg_loss = perceptual_loss(ip3, ip4, img_shape=vgg_input_shape, resize_shape=0)
print(K.eval(vgg_loss))

num slices 1
689.9607


In [None]:
# Evaluate the perceptual-loss tensor again (re-runs the graph) and keep the scalar.
vgg_val = K.eval(vgg_loss)

In [None]:
# Display the evaluated perceptual loss.
vgg_val

In [None]:
# Build 3-channel copies of ip1 / ip2 and apply the same ImageNet 'caffe'
# preprocessing that perceptual_loss uses ('caffe' is also Keras' default mode).
ip1_3c, ip2_3c = (K.concatenate([t, t, t]) for t in (ip1, ip2))
ip1_pp, ip2_pp = (vgg_preprocess(t, mode='caffe') for t in (ip1_3c, ip2_3c))

In [None]:
# Inspect channel 0 of the VGG-preprocessed ip1 (`pre`) slice.
ip1_img = K.eval(ip1_pp)

fig, ax = plt.subplots()
im = ax.imshow(ip1_img[0, ..., 0])
fig.colorbar(im, ax=ax)
ax.set_title('ip1 (pre) after vgg_preprocess, channel 0');

In [None]:
# Inspect channel 0 of the VGG-preprocessed ip2 (`low`) slice.
ip2_img = K.eval(ip2_pp)

fig, ax = plt.subplots()
im = ax.imshow(ip2_img[0, ..., 0])
fig.colorbar(im, ax=ax)
ax.set_title('ip2 (low) after vgg_preprocess, channel 0');