# Compare SAM inference to pathologists

In [1]:
import pandas as pd
import numpy as np
import os
import os.path as osp
from pathlib import Path
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'


from histomics.utils.annotations.stats.annotation_stats import DatabaseAnnotationPuller
from histomics.data.constants import HIPE_OBSERVED_CLASSES
from histomics.metrics.segmentation_metrics import EnsembleDiceHoVerNet, PanopticQuality
from histomics.metrics.segmentation_nuclick_metrics import (
    InstanceBasedMaskMetricByClass,
    IoUByInstance,
)
from histomics.data.io.torch_dataset import TrainDataset

from histomics.data.transforms.label_transforms import (
    NuclickRandomCentroidLabelTransform,
)
from histomics.data.datasets.from_database.nuclick_validation_dataset import (
    NuclickValidationFromDataBase, NUCLICK_TILES_GT_QUERY
)
from histomics.engine.inference.nuclick.inferer_nuclick import NuclickInferer
from histomics.models.centroids_segmentation.nuclick import NuClick
from histomics.data.collate_fn import collate_fn_zip
from histomics.utils.annotations.stats.annotation_utils import plot_annotation_on_ax, plot_annotations_on_tile, plot_matching_annotators, plot_tile_with_multiple_annotators, list_to_single_annotations

from segment_anything.sam_histomics.compute_metrics import get_inferer, get_dataset_from_database, export_masks_to_dataframe_annotations, compute_metrics_on_outputs, convert_df_annotations_to_masks

  warn(


# Load pathologist data from the ground_truth_tiles DB

In [None]:
# Pull every annotation stored in the ground-truth tiles database.
annot_puller = DatabaseAnnotationPuller(
    path_db="/home/owkin/project/database/ground_truth_tiles.db",
    remove_katharina_from_annotations=False,
)
df_annotations = annot_puller.get_annotations_full()

# Pass each `cell_type` entry through `list_to_single_annotations`.
df_annotations["cell_type"] = df_annotations["cell_type"].apply(list_to_single_annotations)

# Keep only the rows whose `coordinates` string starts with "POLYGON".
# `na=False` treats rows with missing coordinates as non-polygons; without it
# the mask contains NaN for those rows and `.loc` refuses it with a ValueError.
is_polygon = df_annotations["coordinates"].str.startswith("POLYGON", na=False)
df_annotations = df_annotations.loc[is_polygon]

In [None]:
# Summary table from the annotation puller (per the method name: tiles annotated
# several times — confirm). Later cells read its `nb_annotations` and
# `location_id` columns.
df_multiple_pathologists = annot_puller.get_multiple_annotated_tiles_df()

In [None]:
df_multiple_pathologists

In [None]:
df_annotations.loc[df_annotations["tile_id"] == 65331257]

In [None]:
# Keep only the entries that carry more than one annotation.
has_several_annotations = df_multiple_pathologists["nb_annotations"] > 1
df_multiple_pathologists = df_multiple_pathologists.loc[has_several_annotations]

In [None]:
# Renumber the rows from 0 after the filtering step so that the label-based
# lookup `.loc[0, "location_id"]` in the next cell resolves to the first row.
df_multiple_pathologists = df_multiple_pathologists.reset_index(drop=True)

In [None]:
# Match the annotations recorded at the first listed location, using an IoU
# threshold of 0.5.
first_location_id = df_multiple_pathologists.loc[0, "location_id"]
df_crossed_annotations = annot_puller.retrieve_and_match_annotations(
    df_annotations,
    location_id=first_location_id,
    threshold_iou=0.5,
)

In [None]:
df_crossed_annotations

In [None]:
df_crossed_annotations

In [None]:
# Restrict the annotations to the ones recorded under the "trialland" username.
is_trialland = df_annotations["cytomine_username"] == "trialland"
annotations_trialland = df_annotations.loc[is_trialland]

In [None]:
# Single tile picked by its hard-coded id, used as the example for the mask
# conversion in the next cell.
annotations_good_tile = annotations_trialland.loc[annotations_trialland["tile_id"] == 97435628]

In [None]:
# Convert the annotation rows of the selected tile into masks.
# NOTE(review): the exact return type/shape of the helper is not visible here — confirm.
masks = convert_df_annotations_to_masks(annotations_good_tile)


### Compute metrics on the masks from T.Rialland

In [None]:
# The metric computation below needs the dataset built from the database.
dataset = get_dataset_from_database()

# Metric objects keyed by name.
# "IoU_by_class" (InstanceBasedMaskMetricByClass over IoUByInstance with
# categories=HIPE_OBSERVED_CLASSES) is deliberately left out of this run.
metrics = {}
metrics["ensemble_dice"] = EnsembleDiceHoVerNet()
metrics["Panoptic quality"] = PanopticQuality()


In [None]:
# NOTE(review): `compute_masks_and_match_tile_ids` is neither imported nor defined
# in the visible cells, so this line raises NameError on a fresh kernel — confirm
# where the helper lives and import it.
# Presumably returns trialland's masks keyed by tile id (a later cell indexes the
# result with an integer tile id and reads `.shape`) — confirm.
outputs_trialland = compute_masks_and_match_tile_ids(df_annotations, df_multiple_pathologists, "trialland")

In [None]:
outputs_trialland

In [None]:
# Score trialland's masks against the dataset, one result per tile.
trialland_metrics = compute_metrics_on_outputs(
    outputs_trialland,
    dataset,
    metrics,
    tile_by_tile=True,
)

In [None]:
outputs_trialland[67912873].shape

In [None]:
trialland_metrics.mean()

## Infer SAM on ground truth tiles

In [None]:
# Build the inferer returned by the SAM helper module.
inferer = get_inferer()

# Rebuild the dataset here so this cell does not depend on the earlier one.
dataset = get_dataset_from_database()

# Run inference over the whole dataset; the result feeds the metric and export
# cells below.
outputs_dataset = inferer.infer_on_dataset(dataset)

In [None]:
# Metric objects keyed by name; this rebinds the `metrics` dict used earlier.
metrics = {}
metrics["ensemble_dice"] = EnsembleDiceHoVerNet()
# NOTE(review): `categories=[2]` is a hard-coded class id — confirm which class it denotes.
metrics["IoU_by_class"] = InstanceBasedMaskMetricByClass(
    IoUByInstance,
    categories=[2],
)
metrics["Panoptic quality"] = PanopticQuality()

# Score the SAM outputs against the dataset.
df_metrics = compute_metrics_on_outputs(outputs_dataset, dataset, metrics)

In [None]:
df_metrics

In [None]:
df_metrics.mean()

In [None]:
df_metrics.loc["3521541", "Panoptic quality"]

In [None]:
df_metrics.loc["3521541", "IoU_by_class"]

In [None]:
# Convert the SAM outputs into an annotation DataFrame tagged with the model
# name "saminference", so it can be concatenated with the pathologist rows.
df_sam_annotations = export_masks_to_dataframe_annotations(outputs_dataset, dataset, model_name="saminference")

In [None]:
df_sam_annotations

In [None]:
# Persist the SAM annotations; `index=False` leaves the DataFrame index out of
# the file. This CSV is read back in the plotting section further down.
df_sam_annotations.to_csv("/home/owkin/project/experiments/nuclick_sam_comp/sam_annotations.csv", index=False)

TODO:
- Remove tiles at the border

In [None]:
# Stack the pathologist rows and the SAM rows, keeping the union of their columns.
frames_with_sam = [df_annotations, df_sam_annotations]
df_annotations_with_sam = pd.concat(frames_with_sam, axis=0, join="outer")

In [None]:
df_annotations_with_sam

In [None]:
# Distinct location ids present in the combined table, as a plain Python list;
# the plotting loop below iterates over it.
list_location_ids = df_annotations_with_sam["location_id"].unique().tolist()

In [None]:
# Match annotations at one hard-coded location. `threshold_iou=0.0` is lower
# than the 0.5 used for the pathologist-vs-pathologist matching earlier in the
# notebook.
matching_sam_katharina = annot_puller.retrieve_and_match_annotations(df_annotations_with_sam, location_id=95145984, threshold_iou=0.0)

In [None]:
matching_sam_katharina.loc[matching_sam_katharina["username_matching"] == "kvonloga_saminference"]

In [None]:
# One figure per location, overlaying every annotator present at that location.
for loc_id in list_location_ids:
    at_this_location = df_annotations_with_sam["location_id"] == loc_id
    ax, all_labels = plot_tile_with_multiple_annotators(
        df_annotations_with_sam.loc[at_this_location]
    )
    # Legend handles come from the dict values, label texts from its keys;
    # the legend is anchored just outside the right edge of the axes.
    plt.legend(
        all_labels.values(),
        all_labels.keys(),
        bbox_to_anchor=(1.01, 1),
        loc="upper left",
    )
    plt.show()

# Run NuClick inference on Katharina's annotations as well

In [None]:
# Set up NuClick: database location, device, model, inferer and dataset.
# This cell rebinds `metrics` and `inferer`, which the SAM cells above also use.
root_db = "sqlite:////home/owkin/project/database/ground_truth_tiles.db"
device = "cuda"

metrics = {"ensemble_dice": EnsembleDiceHoVerNet()}

model = NuClick(
    nuclick_version="unet",
    weights_file="/home/owkin/project/experiments/nuclick/consep_pcns_train_unet_nuclick/checkpoint_epoch35.pth",
    device=device,
)

inferer = NuclickInferer(
    model,
    collate_fn_zip,
    batch_size=1,
    metrics=metrics,
    device=device,
    out_threshold=0.5,
    min_size=3,
)

label_transform = NuclickRandomCentroidLabelTransform(0.1)

our_dataset = NuclickValidationFromDataBase(
    root_path=root_db,
    debug=False,
    sql_query=NUCLICK_TILES_GT_QUERY,
)

In [None]:
# Run NuClick one tile at a time and keep the thresholded masks per tile id.
all_outputs_nuclick = {}

# Walk the tile ids of the dataset index directly instead of indexing by position.
for tile_id in our_dataset.dataframe.index:
    # The dataset subset and the inferer outputs are both addressed by the
    # string form of the tile id.
    tile_key = str(tile_id)

    dataset_this_tile = TrainDataset(
        our_dataset,
        subset=[tile_key],
        label_transform=label_transform,
    )

    outputs, _, _ = inferer.predict_and_score(
        dataset_this_tile, score=True, predict=True, batch_size=None
    )
    # Threshold the predicted masks at 0.5 and store them under the original id.
    all_outputs_nuclick[tile_id] = outputs[tile_key]["masks"] > 0.5

    

In [None]:
all_outputs_nuclick[3521541].max()

In [None]:
# Dataset over all tiles (no `subset`), handed to the export helper together
# with the per-tile NuClick masks.
dataset_valid = TrainDataset(
    our_dataset,
    label_transform=label_transform,
)

# Convert the NuClick masks into an annotation DataFrame tagged "nuclick".
df_annotations_nuclick = export_masks_to_dataframe_annotations(all_outputs_nuclick, dataset_valid, model_name="nuclick")

In [None]:
df_annotations_nuclick

In [None]:
# Persist the NuClick annotations next to the SAM ones; `index=False` leaves the
# DataFrame index out of the file.
df_annotations_nuclick.to_csv("/home/owkin/project/experiments/nuclick_sam_comp/nuclick_annotations.csv", index=False)

## Plot SAM and Nuclick annotations on ground truth tiles

In [None]:
# Reload the model annotations exported earlier in the notebook.
sam_annotations = pd.read_csv(
    "/home/owkin/project/experiments/nuclick_sam_comp/sam_annotations.csv"
)
nuclick_annotations = pd.read_csv(
    "/home/owkin/project/experiments/nuclick_sam_comp/nuclick_annotations.csv"
)

# Stack pathologist, SAM and NuClick rows, keeping the union of their columns.
frames_all_sources = [df_annotations, sam_annotations, nuclick_annotations]
all_annotations = pd.concat(frames_all_sources, axis=0, join="outer")

In [None]:
df_annotations.loc[df_annotations["cytomine_username"] == "kvonloga"]

In [None]:
all_annotations

In [None]:
# Distinct location ids that have SAM annotations, as a plain Python list; this
# rebinds the list built earlier from the combined table.
list_location_ids = sam_annotations["location_id"].unique().tolist()

In [None]:
list_location_ids

In [None]:
all_annotations["location_id"].value_counts()

In [None]:
# Same matching as before at the hard-coded location, now over the table that
# also contains the NuClick rows; `threshold_iou=0.0` is again lower than the
# 0.5 used for the pathologist-vs-pathologist matching.
matching_sam_katharina = annot_puller.retrieve_and_match_annotations(df_annotations=all_annotations, location_id=95145984, threshold_iou=0.0)

In [None]:
matching_sam_katharina["username_matching"].value_counts()

In [None]:
# One figure per location, overlaying pathologist, SAM and NuClick annotations.
for loc_id in list_location_ids:
    at_this_location = all_annotations["location_id"] == loc_id
    ax, all_labels = plot_tile_with_multiple_annotators(
        all_annotations.loc[at_this_location]
    )
    # Legend handles come from the dict values, label texts from its keys;
    # the legend is anchored just outside the right edge of the axes.
    plt.legend(
        all_labels.values(),
        all_labels.keys(),
        bbox_to_anchor=(1.01, 1),
        loc="upper left",
    )
    plt.show()

In [None]:
all_annotations.loc[all_annotations["tile_id"] == 3521541]

In [None]:
# Rows for one hard-coded tile id, copied so the selection is independent of
# `all_annotations`.
# NOTE(review): the filter is on `tile_id` only, not on the username — this is
# trialland-only just if tile ids are unique per annotator; confirm.
annot_trialland = all_annotations.loc[all_annotations["tile_id"] == 95242592].copy()

In [None]:
# Draw the selected annotations on their tile, with the tile box outlined.
ax = plot_annotations_on_tile(annot_trialland, draw_tile_box=True)

# The title repeats the hard-coded tile id used in the selection cell above.
ax.set_title("Tile 95242592 - T.Rialland")