In [1]:
import matplotlib.pyplot as plt
import PIL
import numpy as np
    
def transparent_to_white(cur_img):
        # Replace transparent background (from slide slicing) with white background
        canvas = PIL.Image.new('RGBA', cur_img.size, (220,220,220,255)) # Empty canvas colour (r,g,b,a)
        canvas.paste(cur_img, mask=cur_img) # Paste the image onto the canvas, using it's alpha channel as mask
        canvas.thumbnail([cur_img.width, cur_img.height], PIL.Image.ANTIALIAS)
        return canvas

def build_tissue_mask(input_svs, target_mpp, tile_size):
    from openslide import OpenSlide as op
    import math
    import itertools
    
    def on_error_return():
        print input_svs
        return None
 
    slide = slide = op(input_svs)
    
    slide_dim_0 = slide.dimensions
    slide_mpp = float(slide.properties['aperio.MPP'])
    smallest_level = len(slide.level_downsamples) - 1
    
    if target_mpp < slide_mpp:
        print 'Slide mpp: %f, Target mpp: %f, skipping: %s' % (slide_mpp, target_mpp, input_svs)
        return on_error_return()
        
    factor_mpp =  target_mpp / slide_mpp # e.g 20x->40x : 1/(1 / 0.25) = 0.25 (need to take quarter the size of the patch)
    tile_size_level_0 = ((int)(factor_mpp*tile_size[0]),(int)(factor_mpp*tile_size[1]))
        
    print tile_size_level_0, slide_dim_0, tile_size
        
    avail_levels = [i for i, ds in enumerate(slide.level_downsamples) if slide_mpp*ds<=target_mpp]
    if len(avail_levels) == 0:
        print 'No avail levels: %s' % input_svs
        return on_error_return()
    
    target_level = max(avail_levels)
    target_factor_mpp = target_mpp / (slide.level_downsamples[target_level] * slide_mpp)
    tile_size_level_target = ((int)(target_factor_mpp*tile_size[0]),(int)(target_factor_mpp*tile_size[1]))
    
    smallest_level = len(slide.level_downsamples) - 1
    tile_size_smallest_factor = slide.level_downsamples[-1]
    tile_size_smallest = ((int) (tile_size_level_0[0] / tile_size_smallest_factor), (int) (tile_size_level_0[1] / tile_size_smallest_factor) )
    
    print tile_size_smallest
    
    col_loc = range(0,slide_dim_0[0],tile_size_level_0[0])
    row_loc = range(0,slide_dim_0[1],tile_size_level_0[1])
    
    all_locs = list(itertools.product(col_loc, row_loc)) # cartesian product
    
    img_size = len(col_loc), len(row_loc)
    thumb_img_size = img_size[0] * tile_size_smallest[0], img_size[1] * tile_size_smallest[1]
    mask_image = np.zeros(shape=img_size)
    
    thumb_img = PIL.Image.new('RGBA', thumb_img_size, (220,220,220,255)) # Empty canvas colour (r,g,b,a)
    full_thumb_img = slide.get_thumbnail(size=thumb_img_size)
    
    loc_list=[]
    img_list = []
    for (i,j) in all_locs:        
        samp_image = slide.read_region(location=(i,j), level=smallest_level, size=tile_size_smallest)
        canvas = transparent_to_white(samp_image)
                
        cur_img_avg = np.mean(np.array(canvas))
        
        is_tissue = cur_img_avg <= 220
        x,y = i/tile_size_level_0[0], j/tile_size_level_0[1]
        if is_tissue:
            loc_list.append((i,j))
            mask_image[x,y] = 255
            thumb_img.paste(canvas, mask=canvas, box=(x*tile_size_smallest[0],y*tile_size_smallest[1])) # Paste the image onto the canvas, using it's alpha channel as mask
    mask_image = PIL.Image.fromarray(np.uint8(mask_image.T))
    print len(all_locs)
    print mask_image.size[0] * mask_image.size[1]
    return loc_list, target_level, tile_size_level_target, tile_size_level_0, mask_image, thumb_img, full_thumb_img
    #return slide.get_thumbnail(size=mask_image.size)

"""
#input_svs = '/var/shared/zelda-tcga/a32d95a7-66ec-441c-b40d-71b68adeb500/TCGA-A1-A0SK-01Z-00-DX1.A44D70FA-4D96-43F4-9DD7-A61535786297.svs'
input_svs = '/var/shared/zelda-tcga/cdbde7ab-3de0-40c9-a82c-0b40fba36a38/TCGA-S3-AA12-01Z-00-DX2.4F0A4F18-41C7-4497-A7B8-5DCE610E08AD.svs'
tile_size = (224, 224)
%time loc_list, target_level, tile_size_level_target, tile_size_level_0, mask_image, thumb_img, full_thumb_img = build_tissue_mask(input_svs, target_mpp=0.6, tile_size=tile_size)

print 'Target level: %d' % target_level
print 'Tile size target: %s' % str(tile_size_level_target)

fig=plt.figure(figsize=(12, 8))
fig.add_subplot(1,3,1)
plt.imshow(mask_image)
fig.add_subplot(1,3,2)
plt.imshow(thumb_img)
fig.add_subplot(1,3,3)
plt.imshow(full_thumb_img)
"""

"\n#input_svs = '/var/shared/zelda-tcga/a32d95a7-66ec-441c-b40d-71b68adeb500/TCGA-A1-A0SK-01Z-00-DX1.A44D70FA-4D96-43F4-9DD7-A61535786297.svs'\ninput_svs = '/var/shared/zelda-tcga/cdbde7ab-3de0-40c9-a82c-0b40fba36a38/TCGA-S3-AA12-01Z-00-DX2.4F0A4F18-41C7-4497-A7B8-5DCE610E08AD.svs'\ntile_size = (224, 224)\n%time loc_list, target_level, tile_size_level_target, tile_size_level_0, mask_image, thumb_img, full_thumb_img = build_tissue_mask(input_svs, target_mpp=0.6, tile_size=tile_size)\n\nprint 'Target level: %d' % target_level\nprint 'Tile size target: %s' % str(tile_size_level_target)\n\nfig=plt.figure(figsize=(12, 8))\nfig.add_subplot(1,3,1)\nplt.imshow(mask_image)\nfig.add_subplot(1,3,2)\nplt.imshow(thumb_img)\nfig.add_subplot(1,3,3)\nplt.imshow(full_thumb_img)\n"

In [2]:
import itertools

def slide_generator(input_svs, loc_list, target_level, tile_size_level_0, tile_size_level_target, tile_size):
    from openslide import OpenSlide as op
    
    slide = slide = op(input_svs)
    dim = slide.dimensions
    
    resize_factor = tuple([tile_size_level_0[i] / tile_size[i] for i in range(2)])
    
    for x,y in loc_list:
        #print x,y
        loc = x,y
        resize = tile_size
        tile = slide.read_region(loc, level=target_level, size=tile_size_level_target)
        tile_resized = transparent_to_white(tile).resize(resize)
        yield tile_resized


"""
%time patch_gen = slide_generator(input_svs, loc_list, target_level, tile_size_level_0, tile_size_level_target, tile_size)

patch_gen, patch_gen_clone = itertools.tee(patch_gen, 2)

%time patch_gen_clone.next()
"""

'\n%time patch_gen = slide_generator(input_svs, loc_list, target_level, tile_size_level_0, tile_size_level_target, tile_size)\n\npatch_gen, patch_gen_clone = itertools.tee(patch_gen, 2)\n\n%time patch_gen_clone.next()\n'

In [3]:
from keras.applications.imagenet_utils import preprocess_input

def normalize_input(X):
    #pixel_depth = 255.0
    #return (X - pixel_depth / 2) / pixel_depth
    return preprocess_input(X)

def batch_generator(patch_generator, n_patches, batch_size):
    while True:
        batch_patches = []
        try:
            for i in range(n_patches*batch_size):
                batch_patches.append(np.array(patch_generator.next().convert('RGB')))
        except StopIteration:
            if len(batch_patches) < n_patches:
                raise StopIteration
        if len(batch_patches) % n_patches != 0:
            batch_patches = batch_patches[: - (len(batch_patches) % n_patches)]
        batch_arr = np.array(batch_patches)
        batch_new_shape = tuple([batch_arr.shape[0] / n_patches , n_patches] + list(batch_arr.shape[1:]))
        yield normalize_input(np.reshape(batch_arr, batch_new_shape))
    

batch_size = 100
n_patches = 3

"""
%time batch_gen = batch_generator(patch_gen, n_patches, batch_size)

batch_gen, batch_gen_clone = itertools.tee(batch_gen, 2)

%time __ = batch_gen_clone.next()
"""

Using TensorFlow backend.


'\n%time batch_gen = batch_generator(patch_gen, n_patches, batch_size)\n\nbatch_gen, batch_gen_clone = itertools.tee(batch_gen, 2)\n\n%time __ = batch_gen_clone.next()\n'

In [22]:
def read_svs_region(slide, loc, tile_size_level_0, tile_size, target_level, tile_size_level_target):
    
    dim = slide.dimensions
    
    resize_factor = tuple([tile_size_level_0[i] / tile_size[i] for i in range(2)])
    
    resize = tile_size
    tile = slide.read_region(loc, level=target_level, size=tile_size_level_target)
    tile_resized = transparent_to_white(tile).resize(resize).convert('RGB')
    return np.array(tile_resized)
                
def multi_thread_batch_gen(input_svs, loc_list, n_patches, batch_size, tile_size_level_0, tile_size, target_level, tile_size_level_target, max_workers=10):
    from concurrent.futures import ThreadPoolExecutor
    from openslide import OpenSlide as op
    
    def prepare_batch(batch_patches):
        if len(batch_patches) % n_patches != 0:
            batch_patches = batch_patches[: - (len(batch_patches) % n_patches)]
        batch_arr = np.array(batch_patches)
        batch_new_shape = tuple([batch_arr.shape[0] / n_patches , n_patches] + list(batch_arr.shape[1:]))
        return normalize_input(np.reshape(batch_arr, batch_new_shape))
        
    slide = op(input_svs)
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for loc in loc_list:
            f = executor.submit(read_svs_region, \
                slide, loc, tile_size_level_0, tile_size, target_level, tile_size_level_target)
            futures.append(f)
            
        batch_patches = []
        for i, f in enumerate(futures):
            patch = f.result()
            batch_patches.append(patch)
            if (i+1) % (n_patches*batch_size) == 0:
                yield prepare_batch(batch_patches)
                batch_patches = []
        if len(batch_patches) >= n_patches:
            yield prepare_batch(batch_patches)
            
"""
multi_batch_gen = multi_thread_batch_gen(input_svs, loc_list, n_patches, batch_size, tile_size_level_0, tile_size, target_level, tile_size_level_target)

multi_batch_gen, multi_batch_gen_clone = itertools.tee(multi_batch_gen, 2)

%time __ = multi_batch_gen_clone.next()
"""

'\nmulti_batch_gen = multi_thread_batch_gen(input_svs, loc_list, n_patches, batch_size, tile_size_level_0, tile_size, target_level, tile_size_level_target)\n\nmulti_batch_gen, multi_batch_gen_clone = itertools.tee(multi_batch_gen, 2)\n\n%time __ = multi_batch_gen_clone.next()\n'

In [6]:
import itertools

test_batch_gen = False

if test_batch_gen:
    print len(loc_list)
    batch_gen, batch_gen_clone = itertools.tee(multi_batch_gen, 2)
        
    counter = 0
    for b in batch_gen_clone:
        print len(b), b.shape, counter
        counter += 1
        if counter==5:
            break

In [7]:
import keras.backend as K
import keras
import matplotlib.pyplot as plt
from keras.layers import Reshape, Layer

def loss_msk_0(y_true, y_pred):
    '''Just another crossentropy'''
    loss = keras.losses.categorical_crossentropy(y_true, y_pred)
    mask = K.tf.reduce_sum(y_true, axis=-1) #K.tf.not_equal(K.tf.reduce_sum(y_true, axis=-1), 0 )
    loss_masked = loss * mask #loss_masked = K.tf.boolean_mask(loss, mask)
    return loss_masked

def acc_msk_0(x_true, x_pred):
    """Calculate accuracy, ignore (mask) 0 labels"""

    none_zero_count = K.sum(x_true)

    def calc_acc():
        trade_index = K.tf.not_equal(K.tf.reduce_sum(x_true, axis=-1), 0 )

        ##remove predictions that are 0
        x_true_tradeable = K.tf.boolean_mask(x_true, trade_index)
        x_pred_tradeable = K.tf.boolean_mask(x_pred, trade_index)

        #print x_true_tradeable.shape

        accuracy = K.mean(K.equal(K.tf.argmax(x_true_tradeable,-1), K.tf.argmax(x_pred_tradeable,-1)))
                
        accuracy = K.tf.cast(accuracy, K.tf.float32)
        return accuracy

    #K.eval(accuracy) will cause InvalidArgumentError if none_zero_count==0, add tensorflow condition so it won't be evaluated
    final_acc = K.tf.cond(K.tf.less(none_zero_count, 1e-7), lambda: 0.0, calc_acc)

    #return K.tf.stack([none_zero_count, final_acc]) #K.tf.stack([final_acc]) #
    return K.tf.stack([final_acc])

def relu6(x):
    return K.relu(x, max_value=6)

def make_one(tensor):
    tensor_comp = 1-tensor
    return K.concatenate([tensor, tensor_comp], axis=-1)

class FullReshape(Reshape):
    """Like keras Reshape layer but doing reshape on the batch dimension as well.
        Same as doing Lambda(x : K.reshape(x, K.tf.stack(target_shape))), but lambda layers are fragile to save\load"""
    def __init__(self, target_shape, **kwargs):
        super(FullReshape, self).__init__(target_shape, **kwargs)
        self.trainable = False
    def compute_output_shape(self, input_shape):
        return super(FullReshape, self).compute_output_shape( (None,)+input_shape )[1:]
    def call(self, inputs):
        return K.reshape(inputs, K.tf.stack(self.target_shape))
    def get_config(self):
        #print 'ReshapeLayerConfig'
        return super(FullReshape, self).get_config()
class SumLayer(Layer):
    def __init__(self, axis, keep_dims=False, **kwargs):
        super(SumLayer, self).__init__(**kwargs)
        self.axis = axis
        self.keep_dims = keep_dims
        self.trainable = False
    def compute_output_shape(self, input_shape):
        ls = input_shape[:self.axis]
        if self.keep_dims:
            ls += (1,)
        ls += input_shape[self.axis+1:]
        return ls #K.tf.stack(ls)
    def call(self, inputs):
        return K.sum(inputs, axis=self.axis, keepdims=self.keep_dims)
    
    def get_config(self):
        #print 'SumLayerConfig'
        config = {'axis': self.axis, 'keep_dims':self.keep_dims}
        base_config = super(SumLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

custom_objects_scope_dict = {'relu6': relu6, 'acc_msk_0':acc_msk_0, 'loss_msk_0':loss_msk_0, 
                             'make_one':make_one,
                             'FullReshape':FullReshape,
                             'SumLayer':SumLayer,
                             'keras':keras}

logs_dir = '/var/shared/hist_project/Data/logs'
checkpoint_file = 'exp_mobilenetv2_0.6mpp_3patches_sigmoid_att_cnn_new_er_eq_slidev3_dx_bn_preproc_val_e0.h5'
#checkpoint_file = 'exp_mobilenetv2_0.6mpp_3patches_sigmoid_att_cnn_new_er_eq_slidev3_dx_bn_preproc_e69_freezeenc_10p_val_e0.h5'

model = keras.models.load_model(logs_dir+"/"+checkpoint_file, 
                            custom_objects=custom_objects_scope_dict)

In [8]:
def get_model_layer_outputs(model, batch, layers_ids=None):
    inp = model.input                                           # input placeholder
    if layers_ids == None:
        layers_ids = range(len(model.layers))[1:]
    outputs = [model.layers[layer_idx].output for layer_idx in layers_ids]          # all layer outputs
    functor = K.function([inp, K.learning_phase()], outputs )   # evaluation function

    layer_outs = functor([batch, 0.])
    return list((model.layers[l_id].name, out) for l_id, out in  zip(layers_ids, layer_outs))

In [9]:
import math
from vis.utils import utils

def get_final_frames_layer_outs(model, loc_list, batch_gen, batch_size, n_patches):
    target_layer_ids = [utils.find_layer_idx(model, s) for s in ('ER_Status_patch_sigmoid','ER_Status_att_probs_tan')]

    print 'number of patches: %d' % len(loc_list)
    
    total_batches = int(math.ceil(1.0 * len(loc_list) / batch_size/ n_patches))
    print 'Getting output for %d batches. Batch_Size = %d' % (total_batches, batch_size)

    all_frames_layer_outs = []
    for i, batch_i in enumerate(batch_gen):
        print 'batch %d...' % i,
        batch_frames_layer_outs = get_model_layer_outputs(model, batch_i, target_layer_ids)
        all_frames_layer_outs.append(batch_frames_layer_outs)
        print 'Done'

    print 'Merging...',

    if len(all_frames_layer_outs) != total_batches:
        print 'WARNING: expected %d batches but %d found' % (total_batches, len(all_frames_layer_outs))
    
    # Merge
    final_frames_layer_outs = []
    for layer_id in range(len(target_layer_ids)):
        final_layer_out = None
        for batch_id in range(len(all_frames_layer_outs)):
            name, layer_out = all_frames_layer_outs[batch_id][layer_id]
            if final_layer_out is None:
                final_layer_out = layer_out
            else:
                final_layer_out = np.concatenate([final_layer_out, layer_out],axis=0)
        final_frames_layer_outs.append( (name, final_layer_out) )
    print 'Done'
    return final_frames_layer_outs

"""
final_frames_layer_outs = get_final_frames_layer_outs(model, loc_list, multi_batch_gen, batch_size, n_patches)
print final_frames_layer_outs[0][1].shape
"""

'\nfinal_frames_layer_outs = get_final_frames_layer_outs(model, loc_list, multi_batch_gen, batch_size, n_patches)\nprint final_frames_layer_outs[0][1].shape\n'

In [10]:
def cam_from_out_vector(name, out, n_patches, resize=True, verbose=False):
    import math
    print name, out.shape
    assert len(out.shape) == 3
    num_examples = len(out)
    roi_size = int(math.sqrt(out.shape[1]/n_patches))
    out_reshaped = out.reshape([num_examples, n_patches, roi_size, roi_size, 1])
    print out_reshaped.shape
    batch_heat_patches = []
    for i in range(num_examples): #num_examples
        example_heat_patches = []
        for p in range(n_patches):
            cam = out_reshaped[i,p,:,:,0]
            #heat_patches.append(cam)
            #INTER_NEAREST, INTER_LINEAR (default)
            cam_resized = cv2.resize(cam, (224, 224), interpolation=cv2.INTER_NEAREST) if resize else cam
            example_heat_patches.append( cam_resized )
        batch_heat_patches.append(example_heat_patches)
        if verbose:
            print [np.sum(s) for s in np.split(out_reshaped[i], n_patches, axis=0)]
    
    return np.array(batch_heat_patches)

def merge_heat_maps(heat_patches_dict, key1, key2):
    new_heat_patches = []
    for patch1, patch2 in zip(heat_patches_dict[key1], heat_patches_dict[key2]):
        new_heat_patches.append(patch1*patch2)
    return np.array(new_heat_patches)

def create_heatmap_dict(final_frames_layer_outs, resize, negative, n_patches):
    heat_patches_dict = {}
    for name, out in final_frames_layer_outs:
        #softmax , sigmoid, multiply, probs_tan
        if 'tan' in name:
            heat_patches_dict[name] = cam_from_out_vector(name, (out+1)/2.0, n_patches, resize=resize)
        elif 'sigmoid' in name:
            if negative:
                out = 1.0 - out
            heat_patches_dict[name] = cam_from_out_vector(name, out, n_patches, resize=resize)
        w,h = heat_patches_dict[name].shape[-2:]
        heat_patches_dict[name] = heat_patches_dict[name].reshape(-1, w, h)
        print heat_patches_dict[name].shape

    heat_patches_dict['final'] = merge_heat_maps(heat_patches_dict, *heat_patches_dict.keys())
    return heat_patches_dict

"""
resize = negative = False

heat_patches_dict = create_heatmap_dict(final_frames_layer_outs, resize, negative, n_patches)
"""

'\nresize = negative = False\n\nheat_patches_dict = create_heatmap_dict(final_frames_layer_outs, resize, negative, n_patches)\n'

In [11]:
def build_heatmap_mask(input_svs, loc_list, tile_size_dim_0, frame_heatmaps, output_dir=None):
    from openslide import OpenSlide as op
    from PIL import ImageDraw, Image
    from PIL import PngImagePlugin
    import cv2
    import os
    
    slide = op(input_svs) # Open slide, this is the slow part
    
    dim = slide.dimensions
    heatmap_shape = frame_heatmaps[0].shape
    
    # pixel in mask to fullsize image.
    resize_factor = (1.0 * tile_size_dim_0[0] / heatmap_shape[0], 1.0 * tile_size_dim_0[1] / heatmap_shape[1])
    
    mask_image_size = (int(dim[0] / resize_factor[0]), int(dim[1]/ resize_factor[1]))
    tile_size_mask = (tile_size_dim_0[0] / resize_factor[0], tile_size_dim_0[1] / resize_factor[1])
    
    print dim, tile_size_dim_0, heatmap_shape
    print resize_factor, mask_image_size
    
    heatmap_img = Image.new('L', mask_image_size, 0)
    draw = ImageDraw.Draw(heatmap_img)
    
    for f_id, (f_loc, f_draw) in enumerate(zip(loc_list, frame_heatmaps)):
        patch_dim_0_loc = f_loc
        
        patch_loc = patch_dim_0_loc[0] / resize_factor[0], patch_dim_0_loc[1] / resize_factor[1]
        patch_end_loc = patch_loc[0] + tile_size_mask[0], patch_loc[1] + tile_size_mask[1]
        
        patch_loc = (int)(patch_loc[0]), (int)(patch_loc[1])
        patch_end_loc = (int)(patch_end_loc[0]), (int)(patch_end_loc[1])
        
        crop_area = (patch_loc[0], patch_loc[1], patch_end_loc[0], patch_end_loc[1]) #( left, top, right, bottom )
        
        #heatmap
        # draw square
        #draw.rectangle(xy=[(crop_area[0],crop_area[1]),(crop_area[2],crop_area[3])], outline='black')
            
        rect_loc = (crop_area[0],crop_area[1])
        rect_size = (crop_area[2]-crop_area[0], crop_area[3]-crop_area[1]) #right-left, bottom-top
        f_img = cv2.resize(f_draw, rect_size)
        #print np.max(f_img), np.min(f_img)
        f_mask = Image.fromarray(np.uint8(f_img*255.0), mode='L')
        
        draw.bitmap(xy=rect_loc,bitmap=f_mask, fill=255) #fill=(5,82,5,255)) #'green'

    # Save mask to png file
    if output_dir:
        filename = os.path.basename(input_svs)[:-4]
        output_file = output_dir + '/' + filename + '.mask.png'
        meta_dict = {'svs-full-size': str(dim), 'resize-factor': str(resize_factor) }
        reserved = set(['interlace', 'gamma', 'dpi', 'transparency', 'aspect'])
        # undocumented class
        meta = PngImagePlugin.PngInfo()

        # copy metadata into new object
        for k,v in meta_dict.iteritems():
            if k in reserved: continue
            meta.add_text(k, v, 0)

        heatmap_img.save(fp=output_file,format='png',compress_level=4,pnginfo=meta)  #compression (0,9). Default 6. 9=highest & slowest
    
    return heatmap_img

"""
heatmap_img = build_heatmap_mask(input_svs, loc_list, tile_size_level_0, heat_patches_dict['final'][:], output_dir='/var/shared/hist_project/Data/logs')
heatmap_img
"""

"\nheatmap_img = build_heatmap_mask(input_svs, loc_list, tile_size_level_0, heat_patches_dict['final'][:], output_dir='/var/shared/hist_project/Data/logs')\nheatmap_img\n"

In [12]:
def build_final_mask_image(mask_image, heatmap_img):
    target_size = heatmap_img.size
    print target_size

    mask_image_resized = mask_image.resize(target_size)
    B, G, R = np.array(mask_image_resized), np.array(heatmap_img), np.zeros(target_size).T

    print R.shape, G.shape, B.shape
    rgb = np.uint8(np.stack([R, G, B], axis=2))
    rgbimg = PIL.Image.fromarray(rgb, mode='RGB')
    return rgbimg

"""
build_final_mask_image(mask_image, heatmap_img)
"""

'\nbuild_final_mask_image(mask_image, heatmap_img)\n'

In [23]:
def build_mask_from_svs_file(model, input_svs):
    batch_size = 100
    n_patches = 3
    tile_size=(224,224) 
    target_mpp=0.6
    resize = negative = False
    
    # Build tissue mask
    loc_list, target_level, tile_size_level_target, tile_size_level_0, mask_image, thumb_img, full_thumb_img = build_tissue_mask(input_svs, target_mpp=target_mpp, tile_size=tile_size)
    
    # Build patch batch gen
    
    # 1 Thread
    ####patch_gen = slide_generator(input_svs, loc_list, target_level, tile_size_level_0, tile_size_level_target, tile_size)
    ####batch_gen = batch_generator(patch_gen, n_patches, batch_size)
    # N Threads
    batch_gen = multi_thread_batch_gen(input_svs, loc_list, n_patches, batch_size, tile_size_level_0, tile_size, target_level, tile_size_level_target, max_workers=10)
    
    # Run Model on patches
    final_frames_layer_outs = get_final_frames_layer_outs(model, loc_list, batch_gen, batch_size, n_patches)
    
    # Create heatmap dict
    heatmap_dict = create_heatmap_dict(final_frames_layer_outs, resize, negative, n_patches)
    
    # Create heatmap mask
    heatmap_img = build_heatmap_mask(input_svs, loc_list, tile_size_level_0, heatmap_dict['final'][:])

    # Merge tissue mask and heatmap mask
    final_mask_img = build_final_mask_image(mask_image, heatmap_img)
    
    return final_mask_img

"""
input_svs = '/var/shared/zelda-tcga/cdbde7ab-3de0-40c9-a82c-0b40fba36a38/TCGA-S3-AA12-01Z-00-DX2.4F0A4F18-41C7-4497-A7B8-5DCE610E08AD.svs'
final_mask_img = build_mask_from_svs_file(model, input_svs)
final_mask_img
"""

"\ninput_svs = '/var/shared/zelda-tcga/cdbde7ab-3de0-40c9-a82c-0b40fba36a38/TCGA-S3-AA12-01Z-00-DX2.4F0A4F18-41C7-4497-A7B8-5DCE610E08AD.svs'\nfinal_mask_img = build_mask_from_svs_file(model, input_svs)\nfinal_mask_img\n"

In [14]:
# Find list of files to slice

import pandas as pd
import distutils.dir_util

target_mpp = 0.6
group_name = 'TCGA'

def find_file_recursive(folder, file_ext='*.svs'):
    import fnmatch
    import os

    matches = []
    for root, dirnames, filenames in os.walk(folder, followlinks=True):
        for filename in fnmatch.filter(filenames, file_ext):
            matches.append(os.path.join(root, filename))
    return matches

def findnth(haystack, needle, n):
    """Find Nth occurance of string in another string"""
    parts= haystack.split(needle, n+1)
    if len(parts)<=n+1:
        return -1
    return len(haystack)-len(parts[-1])-len(needle)

def prepare_data(level1_folder, group):
    import os
    data = {}
    svs_files = find_file_recursive(level1_folder, '*.svs')
    for full_filename in svs_files:
        filename = os.path.basename(full_filename)
        if group not in filename:
            print 'file without case, skipping : %s' % filename
            continue
        if group=='TCGA':
            case_barcode = filename[:findnth(filename,'-',2)]
        elif group=='GTEX':
            case_barcode = filename[:findnth(filename,'-',1)]
        else:
            raise Error('Unknown group name %s' % group)
        patient_slides = data.get(case_barcode,[])
        patient_slides.append(full_filename)
        data[case_barcode] = patient_slides
    return data

def filenames_no_ext(fullname_list):
    return [filename_no_ext(f) for f in fullname_list]
def filename_no_ext(f):
    return os.path.basename(f)[:-4]

# Flatten function
flatten = lambda l: [item for sublist in l for item in sublist]

data = prepare_data('/var/shared/zelda-tcga', group_name)

print 'patients: %d, slides: %d' % (len(data), len(flatten(data.values())))

files_to_slice = flatten([data[k] for k in set(data.keys())])
print 'files to slice : %d' % len(files_to_slice)

patients: 9565, slides: 26150
files to slice : 26150


In [15]:
import os

level2_folder = '/var/shared/hist_project/Data/Level2'
out_folder = '/var/shared/hist_project/Data/0.6mpp_masks'
distutils.dir_util.mkpath(out_folder) # Create folders if missing

print out_folder

files_tiff_exist = [os.path.basename(f)[:-9] for f in os.listdir(out_folder)]
files_tiff_exist = set(files_tiff_exist)
print 'patches exist (tif) : %d' % len(files_tiff_exist)

files_to_slice_new = [f for f in files_to_slice if filename_no_ext(f) not in files_tiff_exist]
print 'delta to slice: %d' % len(files_to_slice_new)

/var/shared/hist_project/Data/0.6mpp_masks
patches exist (tif) : 328
delta to slice: 25819


In [16]:
import pandas as pd

hit_map = {}

to_patch_csv = pd.read_csv('/home/maor/to_patch_df.csv')

to_patch_csv = to_patch_csv.loc[to_patch_csv['PrimarySite']=='Breast']
to_patch_csv = to_patch_csv.loc[to_patch_csv['Sample_type']=='Tumor']

to_patch_set = set(to_patch_csv['Slide_name'])
files_to_slice_filtered = []
for full_svs_file in files_to_slice_new:
    slide_name = os.path.basename(full_svs_file)[:-4]
    if slide_name in to_patch_set:
        files_to_slice_filtered.append(full_svs_file)
        if slide_name not in hit_map:
            hit_map[slide_name] = []
        hit_map[slide_name].append(full_svs_file)
        
dup_map = {k:v for k,v in hit_map.iteritems() if len(v)> 1}
print len(to_patch_set), len(files_to_slice_filtered)
#print dup_map # downloaded duplicate slides...
print 'Duplicate svs files: %d' % len(dup_map)

2697 2393
Duplicate svs files: 1


In [None]:
def mask_all_svs(model, svs_files, output_dir):
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    for svs_file in svs_files:
        mask_filename = os.path.basename(svs_file)[:-4] + '.mask.png'
        
        mask_img =  build_mask_from_svs_file(model, svs_file)
        
        output_file = output_dir + '/' + mask_filename
        mask_img.save(fp=output_file,format='png',compress_level=4)
        
mask_all_svs(model, files_to_slice_filtered, out_folder)

(536, 536) (108493, 88844) (224, 224)
(16, 16)
