In [1]:
import keras
import keras.backend as K
from keras.layers import Input,Conv3D,MaxPooling3D,UpSampling3D,Lambda,Activation,concatenate
from keras.models import Model,Sequential
from keras.callbacks import CSVLogger
import tensorflow as tf
import numpy as np
import h5py
from _malis import malis_loss_weights
from malis_utils import mknhood3d,seg_to_affgraph,nodelist_from_shape
#from malis_loss import malis_loss_op

Using TensorFlow backend.


In [2]:
# Open the sample volume read-only. The handle is intentionally left open:
# `raw_data` and `neuron_ids` are indexed again when the training tiles are cut,
# and they are presumably lazy h5py datasets that need the file to stay open.
fa = h5py.File('sample_A_20160501.hdf','r')
raw_data = fa['volumes']['raw']
#data1 = np.expand_dims(np.expand_dims(fa['volumes']['raw'],axis = 0),axis = 0)
neuron_ids = fa['volumes']['labels']['neuron_ids']

# Neighbourhood handed to the affinity helpers. `e` is its number of rows and is
# used further down as the edge/channel count of the affinity arrays.
nhood = mknhood3d(1)
e = nhood.shape[0]
# Ground-truth affinity graph derived from the segment labels.
aff_gt = seg_to_affgraph(neuron_ids, nhood)   #(edge,z,y,x)
#aff_gt_label = np.expand_dims(aff_gt,axis=0)

# data_ch = data1[:,:,:124,:1248,:1248]  #(batch,channel=1,z,y,x)
# aff_gt_label = aff_gt_label[:,:,:124,:1248,:1248]  #(batch,edge=4,z,y,x)
# seg_gt = seg_gt[:124,:1248,:1248]   #(z,y,x)

In [3]:
# Cut the volume into a TILES_PER_AXIS x TILES_PER_AXIS grid of cubic tiles
# (tiling over y and x, full extent in z) so each tile is one training sample.
TILE = 125            # edge length of one tile in voxels; must equal the z extent
TILES_PER_AXIS = 10   # number of tiles taken along y and along x
N_TILES = TILES_PER_AXIS ** 2

data = np.zeros((N_TILES, TILE, TILE, TILE))
seg_gt = np.zeros((N_TILES, TILE, TILE, TILE))
aff_gt_label = np.zeros((N_TILES, e, TILE, TILE, TILE))
for i in range(TILES_PER_AXIS):
    for j in range(TILES_PER_AXIS):
        tile = i * TILES_PER_AXIS + j
        ys = slice(i * TILE, (i + 1) * TILE)
        xs = slice(j * TILE, (j + 1) * TILE)
        data[tile] = raw_data[:, ys, xs]
        seg_gt[tile] = neuron_ids[:, ys, xs]
        aff_gt_label[tile] = aff_gt[:, :, ys, xs]

In [4]:
# Crop every tile to CROP voxels per spatial axis and give the raw data and the
# segmentation a singleton channel axis. 124 is divisible by 4, so the two
# 2x pooling stages followed by two 2x up-sampling stages of the network
# reproduce the input size exactly.
CROP = 124
data_ch = data[:, np.newaxis, :CROP, :CROP, :CROP]  # (batch, channel=1, z, y, x)
aff_gt_label = aff_gt_label[:, :, :CROP, :CROP, :CROP]  # (batch, edge, z, y, x)
# Build the channel axis with a reshape keyed on the last three axes instead of
# expand_dims, so re-running this cell on the already-processed array is a no-op
# rather than stacking a second singleton axis.
seg_gt = seg_gt.reshape((seg_gt.shape[0], 1) + seg_gt.shape[-3:])[:, :, :CROP, :CROP, :CROP]  # (batch, 1, z, y, x)

----------------------------------------------------------------

In [34]:
class MalisWeights(object):
    """Compute per-edge MALIS weights for a fixed volume shape and neighbourhood.

    The methods are meant to run eagerly (they are called through
    ``tf.py_function``). All masking and bookkeeping is done on numpy arrays,
    because eager tensors are immutable and Python ``==`` on tensors is not
    element-wise in every TensorFlow version; the result is handed back as a
    float32 tensor.
    """

    def __init__(self, output_shape, neighborhood):
        """Store the shape/neighbourhood and precompute the edge node lists.

        Args:
            output_shape: spatial shape of the segmentation volume.
            neighborhood: edge offsets; one row per affinity channel.
        """
        self.output_shape = np.asarray(output_shape)
        self.neighborhood = np.asarray(neighborhood)
        self.edge_list = nodelist_from_shape(self.output_shape, self.neighborhood)

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as an ndarray (accepts eager tensors and array-likes)."""
        if hasattr(value, 'numpy'):
            return value.numpy()
        return np.asarray(value)

    def get_edge_weights(self, affs, gt_affs, gt_seg, gt_aff_mask, gt_seg_unlabelled):
        """Return the sum of the negative-pass and positive-pass edge weights.

        Args:
            affs: predicted affinities, first axis = edges.
            gt_affs: ground-truth affinities, same shape as ``affs``.
            gt_seg: ground-truth segmentation.
            gt_aff_mask: per-edge mask (1 = use); an empty array disables masking.
            gt_seg_unlabelled: per-voxel mask (0 = unlabelled); an empty array
                disables the relabelling step.

        Returns:
            float32 tensor of edge weights, each pass normalised to sum to 1
            when it has any non-zero weight.
        """
        seg = self._to_numpy(gt_seg)
        unlabelled = self._to_numpy(gt_seg_unlabelled)

        # replace the unlabelled-object area with a new unique ID
        if unlabelled.size > 0:
            # .numpy() may alias the tensor's buffer, so work on a private copy.
            seg = seg.copy()
            seg[unlabelled == 0] = seg.max() + 1

        assert affs.shape[0] == len(self.neighborhood)

        weights_neg = self.malis_pass(affs, gt_affs, seg, gt_aff_mask, pos=0)
        weights_pos = self.malis_pass(affs, gt_affs, seg, gt_aff_mask, pos=1)

        return weights_neg + weights_pos

    def malis_pass(self, affs, gt_affs, gt_seg, gt_aff_mask, pos):
        """Run one MALIS pass and return its normalised weights.

        Args:
            affs, gt_affs, gt_seg, gt_aff_mask: as in ``get_edge_weights``;
                tensors and numpy arrays are both accepted.
            pos: 1 for the positive pass, 0 for the negative pass.

        Returns:
            float32 tensor shaped ``(edges,) + output_shape``.
        """
        gt_affs_np = self._to_numpy(gt_affs)
        mask_np = self._to_numpy(gt_aff_mask)
        seg_np = self._to_numpy(gt_seg)
        has_mask = mask_np.size > 0

        # Edges whose ground truth belongs to the *other* pass.
        other_pass = gt_affs_np == (1 - pos)
        if has_mask:
            constraint_edges = np.logical_and(other_pass, mask_np == 1)
        else:
            constraint_edges = other_pass

        # Work on a copy of the affinities and pin the constrained edges:
        #   positive pass (pos == 1): affs[gt_affs == 0] = 0
        #   negative pass (pos == 0): affs[gt_affs == 1] = 1
        # astype() copies, so the caller's data is never modified.
        pass_affs = self._to_numpy(affs).astype(np.float32)
        pass_affs[constraint_edges] = 1.0 if pos == 0 else 0.0

        weights = malis_loss_weights(
            seg_np.astype(np.int32).flatten(),
            self.edge_list[0].flatten(),
            self.edge_list[1].flatten(),
            pass_affs.flatten(),
            pos)

        weights = weights.reshape((-1,) + tuple(self.output_shape))
        assert weights.shape[0] == len(self.neighborhood)
        weights = weights.astype(np.float32)

        # '1-pos' samples don't contribute in the 'pos' pass
        weights[other_pass] = 0

        # masked-out samples don't contribute
        if has_mask:
            weights[mask_np == 0] = 0

        # Normalise so the pass sums to 1; skipped when it has no weight at all.
        num_pairs = weights.sum()
        if num_pairs > 0:
            weights = weights / num_pairs

        return tf.convert_to_tensor(weights, dtype=tf.float32)

In [35]:
def malis_weights_op(
        affs,
        gt_affs,
        gt_seg,
        neighborhood,
        gt_aff_mask=None,
        gt_seg_unlabelled=None,
        name=None):
    """Wrap ``MalisWeights.get_edge_weights`` in a ``tf.py_function`` op.

    Args:
        affs: predicted affinities tensor (first axis = edges).
        gt_affs: ground-truth affinities tensor.
        gt_seg: ground-truth segmentation tensor; its static shape is used as
            the output shape of the weight computation.
        neighborhood: edge offsets, one row per affinity channel.
        gt_aff_mask: optional affinity mask; ``None`` becomes an empty tensor,
            which the weight computation treats as "no mask".
        gt_seg_unlabelled: optional unlabelled mask; ``None`` becomes an empty
            tensor, which disables relabelling.
        name: optional op name.

    Returns:
        float32 tensor of MALIS edge weights.
    """
    if gt_aff_mask is None:
        gt_aff_mask = tf.zeros((0,))
    if gt_seg_unlabelled is None:
        gt_seg_unlabelled = tf.zeros((0,))

    output_shape = gt_seg.shape.as_list()

    # Named so that it does not shadow the notebook-level `malis_weights` function.
    weight_computer = MalisWeights(output_shape, neighborhood)

    # The bound method already carries the instance, so no lambda is needed.
    weights = tf.py_function(
        weight_computer.get_edge_weights,
        [affs, gt_affs, gt_seg, gt_aff_mask, gt_seg_unlabelled],
        [tf.float32],
        name=name)

    # Tout was given as a list, so py_function returns a one-element list.
    return weights[0]

In [9]:
def malis_loss_op(
        affs,
        gt_affs,
        gt_seg,
        neighborhood,
        gt_aff_mask=None,
        gt_seg_unlabelled=None,
        name=None):
    """Scalar MALIS loss: squared affinity error weighted by the MALIS edge weights.

    All arguments are forwarded unchanged to ``malis_weights_op``; ``affs`` and
    ``gt_affs`` are additionally used for the per-edge squared error.

    Returns:
        Scalar tensor, the sum over all edges of ``weight * (gt_affs - affs)**2``.
    """
    edge_weights = malis_weights_op(
        affs, gt_affs, gt_seg, neighborhood,
        gt_aff_mask, gt_seg_unlabelled, name)

    squared_error = tf.square(tf.subtract(gt_affs, affs))
    weighted_error = tf.multiply(edge_weights, squared_error)

    return tf.reduce_sum(weighted_error)

----------------------------------------

In [5]:
def malis_weights(affinity_pred, affinity_gt, seg_gt, nhood,
                 unrestrict_neg=False):
    """Compute positive-pass and negative-pass MALIS counts for one volume.

    Args:
        affinity_pred: predicted affinities, first axis = edges.
        affinity_gt: ground-truth affinities, same shape as ``affinity_pred``.
        seg_gt: ground-truth segmentation with the spatial shape of the affinities.
        nhood: edge offsets, one row per affinity channel.
        unrestrict_neg: if True the negative pass uses the raw prediction;
            otherwise it uses ``max(prediction, ground truth)``.

    Returns:
        ``(pos_counts, neg_counts)``, each reshaped to the shape of
        ``affinity_pred``.
    """
    nhood = np.array(nhood)

    # Convert first so the shape is a plain tuple of ints even when the
    # argument arrives as a tensor (e.g. through tf.py_function).
    affinity_pred = np.ascontiguousarray(affinity_pred, dtype=np.float32)
    sh = affinity_pred.shape
    vol_sh = sh[1:]

    # The node lists depend only on the volume shape and the neighbourhood, so
    # keep them on the function object; a dict created inside the call would
    # be empty every time and never produce a cache hit.
    edgelist_cache = malis_weights.__dict__.setdefault('_edgelist_cache', {})
    key = (vol_sh, nhood.shape, nhood.dtype.str, nhood.tobytes())

    if key in edgelist_cache:
        node1, node2 = edgelist_cache[key]
    else:
        node1, node2 = nodelist_from_shape(vol_sh, nhood)
        node1, node2 = node1.ravel(), node2.ravel()
        edgelist_cache[key] = (node1, node2)

    affinity_gt = np.ascontiguousarray(affinity_gt, dtype=np.float32).ravel()
    affinity_pred = affinity_pred.ravel()
    seg_gt = np.ascontiguousarray(seg_gt, dtype=np.int32).ravel()

    # MALIS, positive pass: prediction clipped from above by the ground truth.
    edge_weights_pos = np.minimum(affinity_pred, affinity_gt)
    pos_counts = malis_loss_weights(seg_gt,
                                    node1,
                                    node2,
                                    edge_weights_pos,
                                    1)

    # MALIS, negative pass.
    if unrestrict_neg:
        edge_weights_neg = affinity_pred
    else:
        edge_weights_neg = np.maximum(affinity_pred, affinity_gt)

    neg_counts = malis_loss_weights(seg_gt,
                                    node1,
                                    node2,
                                    edge_weights_neg,
                                    0)

    pos_counts = pos_counts.reshape(sh)
    neg_counts = neg_counts.reshape(sh)

    return pos_counts, neg_counts

In [6]:
def MALIS_loss(seg_gt):
    """Build a Keras loss function that applies MALIS to a single-sample batch.

    Args:
        seg_gt: tensor holding the ground-truth segmentation for the sample
            being scored (here the model's second input). It is reshaped to
            ``(z, y, x)``, so the batch size must be 1.

    Returns:
        ``loss(y_true, y_pred)`` returning a scalar tensor. Relies on the
        notebook globals ``e`` (edge count), ``nhood`` and ``malis_weights``.
    """

    def loss(y_true, y_pred):
        # y_pred is laid out as (batch, edge, z, y, x).
        z, y, x = K.int_shape(y_pred)[2:5]

        # Drop the batch axis; this only works for batch_size == 1.
        new_y_true = K.reshape(y_true, (e, z, y, x))
        new_y_pred = K.reshape(y_pred, (e, z, y, x))
        new_seg = K.reshape(seg_gt, (z, y, x))

        pos_t, neg_t = tf.py_function(
            malis_weights,
            [new_y_pred, new_y_true, new_seg, nhood],
            [tf.int32, tf.int32])

        # py_function outputs carry no static shape; malis_weights reshapes
        # both results to the shape of the prediction it was given.
        pos_t.set_shape((e, z, y, x))
        neg_t.set_shape((e, z, y, x))
        pos_t = tf.cast(pos_t, tf.float32)
        neg_t = tf.cast(neg_t, tf.float32)

        # Squared-error form, matching malis_loss_op: edges counted in the
        # positive pass are pulled towards affinity 1, edges counted in the
        # negative pass towards affinity 0. The counts come from numpy code and
        # therefore act as constants with respect to the gradient.
        pos_term = tf.reduce_sum(pos_t * tf.square(1.0 - new_y_pred))
        neg_term = tf.reduce_sum(neg_t * tf.square(new_y_pred))

        return pos_term + neg_term

    return loss

In [7]:
# The arrays are laid out as (batch, channel, z, y, x). State that on every
# layer instead of relying on the machine-wide Keras `image_data_format` setting.
DATA_FORMAT = 'channels_first'

inputShape = (data_ch.shape[1], data_ch.shape[2], data_ch.shape[3], data_ch.shape[4])

inputs = Input(shape=(inputShape))
# Second input: carries the ground-truth segmentation into the loss only; it is
# not connected to any layer.
input_seg = Input(shape=(inputShape))

# Encoder: two conv + ReLU + 2x max-pool stages.
conv_block_1 = Conv3D(32, (3, 3, 3), strides=(1, 1, 1), padding='same',
                      data_format=DATA_FORMAT)(inputs)
conv_block_1 = Activation('relu')(conv_block_1)
pool_block_1 = MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                            data_format=DATA_FORMAT)(conv_block_1)

conv_block_2 = Conv3D(64, (3, 3, 3), strides=(1, 1, 1), padding='same',
                      data_format=DATA_FORMAT)(pool_block_1)
conv_block_2 = Activation('relu')(conv_block_2)
pool_block_2 = MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                            data_format=DATA_FORMAT)(conv_block_2)

# Decoder: two 2x up-sampling + conv stages back to the input resolution.
up_block_1 = UpSampling3D((2, 2, 2), data_format=DATA_FORMAT)(pool_block_2)
up_block_1 = Conv3D(64, (3, 3, 3), strides=(1, 1, 1), padding='same',
                    data_format=DATA_FORMAT)(up_block_1)
up_block_2 = UpSampling3D((2, 2, 2), data_format=DATA_FORMAT)(up_block_1)
up_block_2 = Conv3D(32, (3, 3, 3), strides=(1, 1, 1), padding='same',
                    data_format=DATA_FORMAT)(up_block_2)

# One sigmoid output channel per edge of the neighbourhood.
conv_block_10 = Conv3D(e, (1, 1, 1), strides=(1, 1, 1), padding='same',
                       data_format=DATA_FORMAT)(up_block_2)
outputs = Activation('sigmoid')(conv_block_10)

model = Model(inputs=[inputs,input_seg], outputs=outputs)
# NOTE(review): absolute, user-specific path; consider a path relative to the notebook.
csv_logger = CSVLogger('/home/haicu/ruolin.shen/projects/malis/malis/training.log')
#tbCallBack = TensorBoard(log_dir='./Graph', histogram_freq=0, write_graph=True, write_images=True,update_freq='batch')
model.compile(optimizer='adadelta', loss=MALIS_loss(input_seg))
model.summary()
# batch_size must stay 1: the loss reshapes each batch to a single volume.
model.fit([data_ch,seg_gt],aff_gt_label,epochs=3,verbose=1,batch_size=1,callbacks=[csv_logger])


Tensor("loss/activation_3_loss/loss/Reshape:0", shape=(4, 124, 124, 124), dtype=float32) Tensor("loss/activation_3_loss/loss/Reshape_1:0", shape=(4, 124, 124, 124), dtype=float32) Tensor("loss/activation_3_loss/loss/Reshape_2:0", shape=(124, 124, 124), dtype=float32)
Tensor("loss/activation_3_loss/loss/Cast:0", dtype=float32) Tensor("loss/activation_3_loss/loss/Cast_1:0", dtype=float32)
Tensor("loss/activation_3_loss/loss/Sum:0", shape=(), dtype=float32)
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 1, 124, 124, 124)  0         
_________________________________________________________________
conv3d_1 (Conv3D)            (None, 32, 124, 124, 124) 896       
_________________________________________________________________
activation_1 (Activation)    (None, 32, 124, 124, 124) 0         
______________________________________________________________

KeyboardInterrupt: 