## Manual labeling for the following images

* Lipids 6 ml 055
* Lipids 30 ml 067
* IgG 1 ml 012
* IgG 20 ml 033
* IgG 30 ml 02

In [None]:
from src.filepath_util import read_image, read_masks_for_image
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd

%matplotlib inline

In [None]:
# Images that have manual labels. The cells below process image_filepaths[0];
# uncomment further entries to make them available.
image_filepaths = [
    "dataset/1000/Lipids 6ml 1000x_055.tif",
    # "dataset/1000/Lipids 30ml 1000x_067.tif",
    # "dataset/1000/IgG 1 ml 1000x_012.tif",
    # "dataset/1000/IgG 20 ml 1000x_033.tif",
    # "dataset/1000/RBCs_IgG 30 ml 1000x_2.tif",
]

In [None]:


from consts import LABEL_UNLABELED, LABEL_WRONG, MASK_ID_COLUMN, Y_COLUMN


# Overlay colour (RGB triplet) for each label name. The LABEL_UNLABELED and
# LABEL_WRONG sentinels from consts also have an entry here, even though
# create_color_by_mask_id never colours masks carrying them.
LABELS = {
    label: {"color": np.array(rgb)}
    for label, rgb in (
        (LABEL_UNLABELED, [204, 0, 255]),
        (LABEL_WRONG, [203, 255, 0]),
        ("red blood cell", [255, 0, 0]),
        ("spheroid cell", [0, 255, 102]),
        ("echinocyte", [0, 101, 255]),
    )
}

def create_color_by_mask_id(labels_df):
    """Build a ``{mask_id: colour}`` mapping from a per-mask label table.

    Each row of ``labels_df`` contributes its MASK_ID_COLUMN value as the key
    and the ``LABELS[...]["color"]`` array of its Y_COLUMN label as the value.
    Rows labelled LABEL_UNLABELED or LABEL_WRONG are left out, so those masks
    get no colour. A label missing from LABELS raises KeyError. If a mask id
    repeats, the last row wins.
    """
    excluded_labels = (LABEL_UNLABELED, LABEL_WRONG)
    return {
        row[MASK_ID_COLUMN]: LABELS[row[Y_COLUMN]]["color"]
        for _, row in labels_df.iterrows()
        if row[Y_COLUMN] not in excluded_labels
    }


In [None]:

def read_manual_labels(image_filepath):
    """Load the manual-label CSV that lives next to ``image_filepath``.

    The CSV path is the image path with its extension replaced by
    ``_manual_labels.csv``. The first line of the CSV is used as the header
    and its first column becomes the DataFrame index.
    """
    stem, _extension = os.path.splitext(image_filepath)
    return pd.read_csv(f"{stem}_manual_labels.csv", header=0, index_col=0)

In [None]:

from src.draw_util import MasksColorOptions, get_masks_img
from src.filepath_util import read_masks_features


# Load the first image plus its automatic masks and per-mask features, then
# colour each mask according to its automatic label.
image_filepath = image_filepaths[0]

raw_image = read_image(image_filepath)
masks = read_masks_for_image(image_filepath)
labels_df = read_masks_features(image_filepath)

color_by_mask_id = create_color_by_mask_id(labels_df)
# NOTE: `image` is rebound to the overlay returned by get_masks_img, so later
# cells that read `image` see the coloured overlay, not the raw image.
image = get_masks_img(
    masks,
    raw_image,
    MasksColorOptions.BY_LABEL,
    color_by_mask_id,
)

plt.imshow(image)

In [None]:
# Manual annotations for the same image: X/Y marker positions and a Counter value per row.
manual_labels_df = read_manual_labels(image_filepath)
manual_labels_df.head(n=20)

In [None]:

# Draw every manual marker on top of the mask-coloured image. The marker colour
# encodes the Counter value through the reversed blue-white-red colormap.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(image)
ax.scatter(
    manual_labels_df['X'],
    manual_labels_df['Y'],
    c=manual_labels_df['Counter'],
    cmap=plt.cm.bwr.reversed(),
    marker='o',
)
plt.show()

In [None]:

from src.filepath_util import get_rel_filepaths_from_subfolders
import torch
from segment_anything import SamPredictor, sam_model_registry


# Torch device for SAM inference.
DEVICE = "cuda"
# Not referenced in the cells visible here.
RESNET_BATCH_SIZE = 64

# Every *.pth file found under ./model/sam/ (including subfolders, judging by
# the helper's name) is a candidate SAM checkpoint.
SAM_CHECKPOINTS_FOLDER = os.path.normpath(os.path.join(".", "model", "sam"))
SAM_CHECKPOINT_FILEPATHS = get_rel_filepaths_from_subfolders(folder_path=SAM_CHECKPOINTS_FOLDER, extension="pth")

def sam_model_version(sam_checkpoint_filepath):
    """Infer the SAM backbone key (used with sam_model_registry) from a checkpoint path.

    The path is searched for "sam_vit_b", "sam_vit_h" and "sam_vit_l", in that
    order. The first match yields "vit_b", "vit_h" or "vit_l". Returns None
    when no marker is present.
    """
    markers_to_versions = (
        ("sam_vit_b", "vit_b"),
        ("sam_vit_h", "vit_h"),
        ("sam_vit_l", "vit_l"),
    )
    return next(
        (version for marker, version in markers_to_versions if marker in sam_checkpoint_filepath),
        None,
    )

# Load the first SAM checkpoint found and embed the current image once, so the
# prediction cell below can prompt it repeatedly.
if DEVICE == "cuda":
    # Free cached-but-unused GPU memory (e.g. from a previous run) before loading SAM.
    torch.cuda.empty_cache()

if not SAM_CHECKPOINT_FILEPATHS:
    raise FileNotFoundError(
        f"No SAM checkpoint (*.pth) found under {SAM_CHECKPOINTS_FOLDER}"
    )

sam_checkpoint_filepath = SAM_CHECKPOINT_FILEPATHS[0]
sam_version = sam_model_version(sam_checkpoint_filepath)
if sam_version is None:
    # Without this check the failure is an opaque KeyError on sam_model_registry[None].
    raise ValueError(
        f"Cannot infer SAM model type (vit_b/vit_h/vit_l) from checkpoint name: {sam_checkpoint_filepath}"
    )

sam = sam_model_registry[sam_version](checkpoint=sam_checkpoint_filepath)
sam.to(device=DEVICE)

predictor = SamPredictor(sam)
# At this point `image` holds the coloured mask overlay from get_masks_img, not
# the original image. Embed the raw image instead, so SAM's segmentation is not
# influenced by the painted overlay colours.
predictor.set_image(read_image(image_filepath))

In [None]:
x_list = manual_labels_df['X'].tolist()
y_list = manual_labels_df['Y'].tolist()
labels = manual_labels_df['Counter'].tolist()  # not used in this cell

# Optional negative prompting. When set to a pixel offset (an earlier experiment
# used 250), every other manual point whose |dx| and |dy| are both below the
# offset is also passed to SAM as a negative (label 0) prompt.
# None disables it: each prompt is then a single positive point, as before.
NEGATIVE_PROMPT_MAX_OFFSET = None

# NOTE: this rebinds `masks`. It previously held the read_masks_for_image result;
# from here on it holds one SAM prediction per manual point, in row order.
masks = []

for i, (x, y) in enumerate(zip(x_list, y_list)):
    input_points = [[x, y]]
    input_labels = [1]

    if NEGATIVE_PROMPT_MAX_OFFSET is not None:
        for other_i, (other_x, other_y) in enumerate(zip(x_list, y_list)):
            if (
                other_i != i
                and abs(x - other_x) < NEGATIVE_PROMPT_MAX_OFFSET
                and abs(y - other_y) < NEGATIVE_PROMPT_MAX_OFFSET
            ):
                input_points.append([other_x, other_y])
                input_labels.append(0)

    # multimask_output=True asks SAM for several candidate masks per prompt.
    # The scores and low-resolution logits it also returns are discarded.
    mask, _, _ = predictor.predict(
        point_coords=np.array(input_points),
        point_labels=np.array(input_labels),
        multimask_output=True,
    )
    masks.append(mask)


In [None]:
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a single-colour RGBA layer with alpha 0.6.

    The colour is a fixed blue (30, 144, 255)/255 by default. With
    ``random_color=True`` it is drawn from ``np.random.random(3)``. The last two
    dimensions of ``mask`` are taken as (height, width), and each pixel value
    scales the RGBA colour, so False/0 pixels come out fully transparent.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    )
    height, width = mask.shape[-2:]
    ax.imshow(mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1))

k = 15  # row of manual_labels_df (and index into `masks`) to inspect

point = manual_labels_df.iloc[k]
# Pin the colour scale to the full range of Counter values. When a single point
# is scattered with a cmap and no vmin/vmax, matplotlib autoscales to
# vmin == vmax, which normalises every value to 0. The marker then always got
# the bottom colormap colour, whatever its Counter value.
counter_vmin = manual_labels_df['Counter'].min()
counter_vmax = manual_labels_df['Counter'].max()

# One figure per candidate mask SAM returned for this prompt point.
for candidate_mask in masks[k]:
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(candidate_mask, plt.gca(), random_color=True)
    plt.scatter(
        [point['X']],
        [point['Y']],
        c=[point['Counter']],
        cmap=plt.cm.bwr.reversed(),
        vmin=counter_vmin,
        vmax=counter_vmax,
        marker='o',
    )
    plt.axis('off')
    plt.show()