In [5]:
import cv2
import numpy as np
from cellmask_model import CellMaskModel
import os
import matplotlib.pyplot as plt
from data import import_images

In [47]:
def get_encFeats(model,image):
    pad_val = model.expand_div_256(image)[1]
    cp, mask, instance_mask, encFeats_cp, encFeats_mask = model.get_pred(image,0,encFeats=True)
    return encFeats_cp, encFeats_mask, instance_mask, pad_val

def get_encFeats_flat(model,image,type='cp'):
    encFeats_cp, encFeats_mask, instance_mask, pad_val = get_encFeats(model,image)

    if type=='cp':
        encFeats = encFeats_cp
    elif type=='mask':
        encFeats = encFeats_mask

    encFeats_arr = []
    for enc in encFeats:
        arr = enc.detach()
        arr = np.expand_dims(arr,0)
        arr = np.expand_dims(arr,0)
        encFeats_arr.append(arr)

    stacked = model.stack_img(encFeats_arr,colrow=64)
    res = cv2.resize(stacked, dsize=(1280, 1280), interpolation=cv2.INTER_CUBIC)
    res = res[pad_val:-pad_val, pad_val:-pad_val]
    return res, instance_mask

def get_instance_encFeats(model,image,type='cp'):
    encFeats_flat, instance_mask = get_encFeats_flat(model,image,type=type)

    instance_encFeats = []
    for i in range(1,max(instance_mask.flatten())+1):
        instance_encFeats.append(np.delete(encFeats_flat,np.where(instance_mask.flatten()!=i)))
    return np.array(instance_encFeats)

def resize_arrays_to_fit_another(arrays_to_resize,array):
    arrays_shaped = []
    for i in arrays_to_resize:
        if i.shape < array.shape:
            pad_by = array.shape[0]-i.shape[0]
            new_arr = np.pad(i, (0, pad_by), 'constant')
            arrays_shaped.append(new_arr)
        elif i.shape > array.shape:
            arrays_shaped.append(i[:array.shape[0]])
        else:
            arrays_shaped.append(i)
    arrays_shaped = np.array(arrays_shaped)
    return arrays_shaped


def get_centers_of_ROIs(instance_mask):
    centers_instance_mask = []
    for i in range(1,max(instance_mask.flatten())+1):
        cell =  np.array(np.where(instance_mask == i,1,0))
        M = cv2.moments(np.float32(cell))
        center = (int(M["m10"] / M["m00"]),int(M["m01"] / M["m00"]))
        centers_instance_mask.append(center)
    return centers_instance_mask

def get_distance_between_cells(initial_cell,all_cells,radius=1000):
    all_cells_shaped = resize_arrays_to_fit_another(all_cells,initial_cell)
    cos_sim = np.dot(initial_cell, all_cells_shaped.T)/(np.linalg.norm(all_cells_shaped)*np.linalg.norm(initial_cell))

    #use the radius to limit the possible cells,
    #maybe input into the function an array of all the cells but just their distance to the others

    return cos_sim

def get_center_distances(cell_centers,cell_num):
    center_distances = []
    a = np.array(cell_centers[cell_num])
    for cell_center in cell_centers:
        b = np.array(cell_center)
        center_distances.append(np.linalg.norm(a-b))
    return center_distances


In [33]:
model = CellMaskModel()
model.import_model(os.getcwd() + '/saved_weights/cp_model', os.getcwd() + '/saved_weights/mask_model')

images_path_1059 = os.getcwd() + '/images_1059_0/'
images_1059 = import_images(images_path_1059,normalisation=True,num_imgs=5,format='.tiff')

instance_encFeats_0 = get_instance_encFeats(model,images_1059[0],type='cp')
instance_encFeats_1 = get_instance_encFeats(model,images_1059[1],type='cp')

first_cell = instance_encFeats_0[0]
cos_sim_index = get_distance_between_cells(first_cell,instance_encFeats_1)

25
(1, 1, 256, 256)


  return np.array(instance_encFeats)


25
(1, 1, 256, 256)


In [36]:
cp, mask, instance_mask, encFeats_cp, encFeats_mask = model.get_pred(images_1059[0],channel=0,encFeats=True)

25
(1, 1, 256, 256)


In [48]:
first_cell = instance_encFeats_0[0]

cos_sim_index = get_distance_between_cells(first_cell,instance_encFeats_1)
cell_centers = get_centers_of_ROIs(instance_mask[0])

In [46]:
instance_mask_0 = instance_mask[0]

centers_instance_mask = []
for i in range(1,max(instance_mask_0.flatten())+1):
    cell = np.where(instance_mask_0 == i,1,0)
    M = cv2.moments(np.float32(cell))
    center = (int(M["m10"] / M["m00"]),int(M["m01"] / M["m00"]))
    centers_instance_mask.append(center)
#return centers_instance_mask


In [34]:
cos_sim_index

array([ 1.16496146e-01,  1.65883161e-04,  1.90051906e-02,  3.65387648e-03,
        4.30590953e-05,  3.22814323e-02,  2.58465745e-02,  1.15337325e-02,
        1.97315094e-04,  1.25500098e-01,  4.00840417e-02, -5.21657967e-05,
        2.56426167e-02,  1.90165807e-02,  2.21931897e-02,  1.97097901e-02,
        1.10066291e-02, -2.19487512e-04,  4.08802507e-03,  3.74424784e-03,
        1.41716609e-02,  6.16137274e-02,  6.68611079e-02,  1.05067826e-04,
        1.30939558e-02,  3.77351567e-02,  1.88660529e-02,  1.48693388e-02,
        4.72209789e-03,  2.71877218e-02, -2.25122808e-03,  1.21671811e-03,
       -3.73891066e-03,  1.60924103e-02,  2.23307237e-02,  1.21407537e-02,
        4.09740843e-02,  1.90758597e-04,  1.17141111e-02,  5.57612367e-02,
        2.04655733e-02,  3.69997183e-03,  2.43717896e-05,  2.48461642e-04,
        1.51670454e-02,  2.58691758e-02,  4.02631462e-02,  7.19938660e-03,
        5.70922093e-05,  3.86615433e-02,  3.27468365e-02,  3.61710414e-02,
       -7.09031709e-04,  

In [28]:
cos_sim = instance_encFeats_1[np.argmax(np.dot(instance_encFeats_0[0], instance_encFeats_1.T)/(np.linalg.norm(instance_encFeats_1)*np.linalg.norm(instance_encFeats_0[0])))]

ValueError: shapes (517,) and (197,) not aligned: 517 (dim 0) != 197 (dim 0)