In [1]:
import cv2
import numpy as np
from cellmask_model import CellMaskModel
import os
import matplotlib.pyplot as plt
from data import import_images

  from .autonotebook import tqdm as notebook_tqdm


In [121]:
def get_encFeats_flat(model,image,type='cp'):
    pad_val = model.expand_div_256(image)[1]
    cp, mask, instance_mask, encFeats_cp, encFeats_mask = model.get_pred(image,0,encFeats=True)

    if type=='cp':
        encFeats = encFeats_cp
    elif type=='mask':
        encFeats = encFeats_mask

    encFeats_arr = []
    for enc in encFeats:
        arr = enc.detach()
        arr = np.expand_dims(arr,0)
        arr = np.expand_dims(arr,0)
        encFeats_arr.append(arr)

    stacked = model.stack_img(encFeats_arr,colrow=64)
    res = cv2.resize(stacked, dsize=(1280, 1280), interpolation=cv2.INTER_CUBIC)
    res = res[pad_val:-pad_val, pad_val:-pad_val]
    return res, instance_mask

def get_encFeats_flat_from_pred(model,instance_mask,encFeats,type='cp'):
    pad_val = model.expand_div_256(instance_mask)[1]

    encFeats_arr = []
    for enc in encFeats:
        arr = enc.detach()
        arr = np.expand_dims(arr,0)
        arr = np.expand_dims(arr,0)
        encFeats_arr.append(arr)

    stacked = model.stack_img(encFeats_arr,colrow=64)
    res = cv2.resize(stacked, dsize=(1280, 1280), interpolation=cv2.INTER_CUBIC)
    res = res[pad_val:-pad_val, pad_val:-pad_val]

    return res, instance_mask

def get_centers_of_ROIs(instance_mask):
    centers_instance_mask = []
    for i in range(1,max(instance_mask.flatten())+1):
        cell =  np.array(np.where(instance_mask == i,1,0))
        M = cv2.moments(np.float32(cell))
        center = (int(M["m10"] / M["m00"]),int(M["m01"] / M["m00"]))
        centers_instance_mask.append(center)
    return centers_instance_mask

def get_instance_encFeats_from_img(model,image,type='cp'):
    encFeats_flat, instance_mask = get_encFeats_flat(model,image,type=type)

    instance_encFeats = []
    for i in range(1,max(instance_mask.flatten())+1):
        instance_encFeats.append(np.delete(encFeats_flat,np.where(instance_mask.flatten()!=i)))

    cell_centers = get_centers_of_ROIs(instance_mask)
    return np.array(instance_encFeats), cell_centers

def get_instance_encFeats_from_encFeats(encFeats_flat,instance_mask,type='cp'):

    instance_encFeats = []
    for i in range(1,max(instance_mask.flatten())+1):
        first = instance_mask.flatten()
        second = encFeats_flat.flatten()
        mask = first == i
        rn = np.take(second, np.where(mask)[0])
        rn = np.extract(rn != 0, rn)
        instance_encFeats.append(rn)

    cell_centers = get_centers_of_ROIs(instance_mask)
    return np.array(instance_encFeats), cell_centers

def resize_arrays_to_fit_another(arrays_to_resize,array):
    arrays_shaped = []
    for i in arrays_to_resize:
        if i.shape < array.shape:
            pad_by = array.shape[0]-i.shape[0]
            new_arr = np.pad(i, (0, pad_by), 'constant')
            arrays_shaped.append(new_arr)
        elif i.shape > array.shape:
            arrays_shaped.append(i[:array.shape[0]])
        else:
            arrays_shaped.append(i)
    arrays_shaped = np.array(arrays_shaped)
    return arrays_shaped

def get_distance_between_cells(initial_cell_center,all_cell_centers):
    center_distances = []
    a = np.array(initial_cell_center)
    for cell_center in all_cell_centers:
        b = np.array(cell_center)
        center_distances.append(np.linalg.norm(a-b))
    return center_distances

def get_matching_cells(initial_cell,all_cells,cell_distances,radius=1000):
    all_cells_shaped = resize_arrays_to_fit_another(all_cells,initial_cell)
    cos_sim = np.dot(initial_cell, all_cells_shaped.T)/(np.linalg.norm(all_cells_shaped)*np.linalg.norm(initial_cell))

    #use the radius to limit the possible cells,
    #maybe input into the function an array of all the cells but just their distance to the others
    cell_distances = np.array(cell_distances)
    possible_matches = np.where(cell_distances < radius,cos_sim,0)

    match_index = np.argmax(possible_matches)
    #return idnex of max value in possible_matches

    return match_index


In [75]:
model = CellMaskModel()
model.import_model(os.getcwd() + '/saved_weights/cp_model', os.getcwd() + '/saved_weights/mask_model')

images_path_1059 = os.getcwd() + '/images_1059_0/'
images_1059 = import_images(images_path_1059,normalisation=True,num_imgs=5,format='.tiff')

instance_encFeats_0, cell_centers_0 = get_instance_encFeats_from_img(model,images_1059[0],type='cp')
instance_encFeats_1, cell_centers_1 = get_instance_encFeats_from_img(model,images_1059[1],type='cp')

first_cell = instance_encFeats_0[0]
first_cell_center = cell_centers_0[0]
print(first_cell_center)
distance_between_cells_from_first = get_distance_between_cells(first_cell_center,cell_centers_1)
possible_matches = get_matching_cells(first_cell,instance_encFeats_1,distance_between_cells_from_first,radius=400)
print(possible_matches)

25
(1, 1, 256, 256)


  return np.array(instance_encFeats), cell_centers


25
(1, 1, 256, 256)
(518, 14)
9


In [21]:
#Get all the model predictions on all the images
model = CellMaskModel()
model.import_model(os.getcwd() + '/saved_weights/cp_model', os.getcwd() + '/saved_weights/mask_model')

images_path_1059 = os.getcwd() + '/images_1059_0/'
images_1059 = import_images(images_path_1059,normalisation=True,num_imgs=5,format='.tiff')

cps, masks, instance_masks, encFeats_cps, encFeats_masks = model.eval(images_1059,0,encFeats=True)
#Get all the matches of cells
#Maybe average out the encFeats_cp and encFeats_mask closeness matches and only then get the index

25
(1, 1, 256, 256)
25
(1, 1, 256, 256)
25
(1, 1, 256, 256)
25
(1, 1, 256, 256)
25
(1, 1, 256, 256)


In [123]:
for i in range(1,np.max(instance_masks[0]+1)):
    print(i)
    mask_of_one_cell = np.where(instance_masks[0] == i,1,0)
    instance_cells_to_match_with = instance_masks[1]
    encFeats_cp_flat_0, instance_mask_0 = get_encFeats_flat_from_pred(model,instance_masks[0],encFeats_cps[0])
    encFeats_cp_flat_1, instance_mask_1 = get_encFeats_flat_from_pred(model,instance_masks[1],encFeats_cps[1])
    instance_encFeats_0, cell_centers_0 = get_instance_encFeats_from_encFeats(encFeats_cp_flat_0,instance_masks[0],type='cp')
    instance_encFeats_1, cell_centers_1 = get_instance_encFeats_from_encFeats(encFeats_cp_flat_1,instance_masks[1],type='cp')
    

1


  return np.array(instance_encFeats), cell_centers


2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87


In [42]:
instance_encFeats_0, cell_centers_0 = get_instance_encFeats_from_encFeats(encFeats_cp_flat_0,instance_masks[0],type='cp')

  return np.array(instance_encFeats), cell_centers


In [47]:
import time

start = time.time()
time.sleep(1)
time.time() - start 

1.0062534809112549

In [67]:
res, instance_mask = get_encFeats_flat_from_pred(model,instance_masks[0],encFeats_cps[0])

In [119]:
encFeats_cp_flat_0, cell_centers = get_encFeats_flat_from_pred(model,instance_masks[0],encFeats_cps[0])

instance_encFeats = []
start = time.time()
for i in range(1,max(instance_masks[0].flatten())+1):
    #indexes_of_not_cell = np.squeeze(np.array(np.where(instance_masks[0].flatten()!=i)))
    #print(indexes_of_not_cell.shape)

    #print('ggg',encFeats_cp_flat_0[0].shape,indexes_of_not_cell.shape)
    #rn = np.delete(encFeats_cp_flat_0,indexes_of_not_cell)
    
    #WORKING
    first = instance_masks[0].flatten()
    second = encFeats_cp_flat_0.flatten()
    mask = first == i
    rn = np.take(second, np.where(mask)[0])
    rn = np.extract(rn != 0, rn)

    #print(encFeats_cp_flat_0[0].shape)
    #rn = np.where(instance_masks[0].flatten()!=i,encFeats_cp_flat_0[0].flatten(),0)
    #print(rn)

    instance_encFeats.append(rn)

    
print(time.time()-start)
#cell_centers = get_centers_of_ROIs(instance_masks[0])

1.79819917678833


In [35]:
print(type(encFeats_cps[0][0]))

<class 'torch.Tensor'>


In [33]:
arr = encFeats_cps[0]
arr.shape

AttributeError: 'list' object has no attribute 'detach'

In [None]:
plt.imshow()

In [70]:
possible_matches = get_matching_cells(first_cell,instance_encFeats_1,distance_between_cells_from_first,radius=200)
print(possible_matches)

[ 1.16496146e-01  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  1.15337325e-02
  0.00000000e+00  0.00000000e+00  0.00000000e+00 -5.21657967e-05
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00 -2.19487512e-04  0.00000000e+00  0.00000000e+00
  1.41716609e-02  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000

In [48]:
first_cell = instance_encFeats_0[0]

cos_sim_index = get_distance_between_cells(first_cell,instance_encFeats_1)
cell_centers = get_centers_of_ROIs(instance_mask[0])

In [53]:
cos_sim_index.shape

(197,)

In [28]:
cos_sim = instance_encFeats_1[np.argmax(np.dot(instance_encFeats_0[0], instance_encFeats_1.T)/(np.linalg.norm(instance_encFeats_1)*np.linalg.norm(instance_encFeats_0[0])))]

ValueError: shapes (517,) and (197,) not aligned: 517 (dim 0) != 197 (dim 0)