This notebook explains how to evaluate nuclear segmentation models using REET2.0 on Google Colab. First, you must add the toolbox to your Google Drive.

Mount your Drive and make sure you are using a GPU (go to Runtime > Change runtime type and select a hardware accelerator).

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Move to the parent directory. This will depend on where you saved your toolbox.

In [None]:
%cd /content/drive/MyDrive/Final_Toolbox


We will start by demonstrating how to evaluate semantic segmentation models. First, load the model. You can select one of the models we trained for our experiments below.

In [None]:
import torch
import torch.nn as nn

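# These .pth files store whole pickled models, not state dicts. PyTorch >= 2.6
# defaults to weights_only=True, which rejects such files, so weights_only=False
# is passed explicitly. Only do this for checkpoints you trust.

#U-Net NA
# model = torch.load('segmentation/output/U-Net_NA.pth', weights_only=False)

#U-Net RA
model = torch.load('segmentation/output/U-Net_RA.pth', weights_only=False)

#U-Net Pix
# model = torch.load('segmentation/output/U_Net_Pix.pth', weights_only=False)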


Define the list of transforms you wish to evaluate:

In [None]:
transform_list = ["Pixel","Rotate", "Zoom Out", "HED Stain", "Mean", "Zoom In", "Crop", "Stain", "Blur"]
#transform_list = ["Pixel"]

The script below is the one we used to evaluate our U-Net.

reetoolbox.evaluator.Semantic_Segmentation_Evaluator takes a model, a list of directories containing test data, a transform, and the output shape of the model (evaluation_shape). This initialises the evaluator for the given transform. The transform parameters can be changed in reetoolbox/constants.py.


I have included a small sample of images in Consep_patches_540x540 to use for evaluation. The toolbox expects one NumPy array per image containing the image, the instance map and the type map, in that order.
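As a sketch of that layout, the helpers below assume a channel-last array of shape (H, W, 5): RGB in channels 0-2, the instance map in channel 3 and the type map in channel 4. This appears to match how HoVer-Net's patch-extraction script saves CoNSeP patches (the 540x540_164x164 directory name follows its naming), but make_patch and split_patch are hypothetical names, not part of reetoolbox, so check the toolbox's data loader before relying on this.

```python
import numpy as np

def make_patch(image: np.ndarray, inst_map: np.ndarray, type_map: np.ndarray) -> np.ndarray:
    """Stack an RGB image (H, W, 3), instance map (H, W) and type map (H, W) into (H, W, 5).

    Assumed layout: channels 0-2 = RGB, channel 3 = instance IDs, channel 4 = type IDs.
    int32 is used because instance IDs can exceed 255.
    """
    if image.shape[:2] != inst_map.shape or image.shape[:2] != type_map.shape:
        raise ValueError("image, inst_map and type_map must share height and width")
    return np.concatenate(
        [image, inst_map[..., None], type_map[..., None]], axis=-1
    ).astype(np.int32)

def split_patch(patch: np.ndarray):
    """Inverse of make_patch: recover (image, inst_map, type_map)."""
    image = patch[..., :3].astype(np.uint8)
    inst_map = patch[..., 3]
    type_map = patch[..., 4]
    return image, inst_map, type_map
```

To add your own images under this assumption, you would save each stacked array with np.save(path, make_patch(image, inst_map, type_map)) into one of the directories passed to the evaluator.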

We then call the evaluator's predict method to perform the evaluation. adversarial should be set to True. To display the model's performance on each transform, set display=True. The display can be buggy in Colab; if so, you can change the evaluator code to use plt.savefig instead of plt.show. scale_perturbation scales the pixel intensities to use the full pixel range; this is needed to see the Pixel perturbations.
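To illustrate why rescaling is needed, the sketch below applies a min-max rescale to a perturbation (the difference between a transformed and an original image), so that changes of only a few intensity levels span the whole 0-255 range. This is an assumption about the idea behind scale_perturbation, not the toolbox's code; scale_to_full_range is a hypothetical name.

```python
import numpy as np

def scale_to_full_range(x: np.ndarray, out_max: float = 255.0) -> np.ndarray:
    """Min-max rescale x so its minimum maps to 0 and its maximum to out_max."""
    x = x.astype(np.float64)
    lo, hi = x.min(), x.max()
    if hi == lo:
        # A constant array has no contrast to stretch; return zeros.
        return np.zeros_like(x)
    return (x - lo) / (hi - lo) * out_max
```

For example, scale_to_full_range(adv_image - orig_image).astype(np.uint8) turns a perturbation of a couple of intensity levels, which would look uniformly black, into a visible image.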

predict returns two dictionaries: one for the evaluation on the original (vanilla) images and one for the evaluation on the transformed images. All metrics are computed at the dataset level: TP, FP and FN are accumulated over the entire dataset before each metric is calculated, rather than averaging per-image scores.
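The difference between pooled and per-image metrics can be shown with a short sketch for binary nucleus masks, using IoU = TP / (TP + FP + FN) and Dice = 2TP / (2TP + FP + FN). The function names are hypothetical and this is not the toolbox's implementation; it only shows the accumulate-then-divide idea described above.

```python
import numpy as np

def confusion_counts(pred: np.ndarray, gt: np.ndarray):
    """Pixel-wise TP, FP, FN for a pair of binary masks (nonzero = nucleus)."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    tp = int(np.sum(pred & gt))
    fp = int(np.sum(pred & ~gt))
    fn = int(np.sum(~pred & gt))
    return tp, fp, fn

def dataset_iou_dice(pairs):
    """Accumulate TP/FP/FN over all (pred, gt) pairs, then compute IoU and Dice once."""
    total_tp = total_fp = total_fn = 0
    for pred, gt in pairs:
        tp, fp, fn = confusion_counts(pred, gt)
        total_tp += tp
        total_fp += fp
        total_fn += fn
    if total_tp + total_fp + total_fn == 0:
        # No nucleus pixels predicted or present anywhere: metrics are undefined.
        return float("nan"), float("nan")
    iou = total_tp / (total_tp + total_fp + total_fn)
    dice = 2 * total_tp / (2 * total_tp + total_fp + total_fn)
    return iou, dice
```

With one image segmented perfectly (50 nucleus pixels) and a second image whose single nucleus pixel is missed, the per-image IoUs are 1 and 0 (mean 0.5), whereas the pooled IoU is 50/51. Pooling stops images with very few nucleus pixels from dominating the score.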

Note: If you display the images, you may wish to restrict the number of images you are evaluating, as displaying slows evaluation down dramatically. I like to add a breakpoint in the display section of the predict function when I use display.
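One way to restrict the number of images without touching the toolbox is to copy a few patches into a separate directory and pass that directory to the evaluator instead. The sketch assumes the evaluator reads every .npy file in each listed directory; make_subset_dir and the destination path you choose are hypothetical.

```python
import shutil
from pathlib import Path

def make_subset_dir(src_dir: str, dst_dir: str, n: int, pattern: str = "*.npy") -> list:
    """Copy the first n files matching pattern (sorted by name) from src_dir into dst_dir."""
    src, dst = Path(src_dir), Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    chosen = sorted(src.glob(pattern))[:n]
    for f in chosen:
        shutil.copy2(f, dst / f.name)
    return [f.name for f in chosen]
```

For example, make_subset_dir("Consep_patches_540x540/valid/540x540_164x164", "Consep_patches_540x540/valid_subset", n=5) followed by passing ["Consep_patches_540x540/valid_subset"] to the evaluator would limit a display run to five patches.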

In [None]:
import reetoolbox.evaluator

for transform in transform_list:
    print(transform)
    # "evaluator" rather than "eval" avoids shadowing Python's built-in eval()
    evaluator = reetoolbox.evaluator.Semantic_Segmentation_Evaluator(
        model, ["Consep_patches_540x540/valid/540x540_164x164"], transform, evaluation_shape=(256, 256))
    orig_results_dict, adv_results_dict = evaluator.predict(adversarial=True, display=False, scale_perturbation=True)

    #SEMANTIC SEGMENTATION METRICS
    orig_ss_iou = orig_results_dict["pixel_wise_nucleus_IoU"]
    adv_ss_iou = adv_results_dict["pixel_wise_nucleus_IoU"]

    orig_ss_dice =  orig_results_dict["pixel_wise_nucleus_Dice"]
    adv_ss_dice = adv_results_dict["pixel_wise_nucleus_Dice"]

    orig_ss_acc =  orig_results_dict["all_type_pixel_accuracy"]
    adv_ss_acc = adv_results_dict["all_type_pixel_accuracy"]

    orig_ss_type_IoU = orig_results_dict["pixel_wise_type_IoU"]
    adv_ss_type_IoU = adv_results_dict["pixel_wise_type_IoU"]

    orig_ss_type_dice = orig_results_dict["pixel_wise_type_Dice"]
    adv_ss_type_dice = adv_results_dict["pixel_wise_type_Dice"]



    print("SEMANTIC SEGMENTATION METRICS")
    print(f"original pixel-wise nucleus IoU: {orig_ss_iou}")
    print(f"adversarial pixel-wise nucleus IoU: {adv_ss_iou}")
    print("")
    print(f"original pixel-wise nucleus Dice: {orig_ss_dice}")
    print(f"adversarial pixel-wise nucleus Dice: {adv_ss_dice}")
    print("")
    print(f"original pixel-wise accuracy (all types): {orig_ss_acc}")
    print(f"adversarial pixel-wise accuracy (all types): {adv_ss_acc}")
    print("")
    print(f"original pixel-wise IoU for each type: {orig_ss_type_IoU}")
    print(f"adversarial pixel-wise IoU for each type: {adv_ss_type_IoU}")
    print("")
    print(f"original pixel-wise Dice for each type: {orig_ss_type_dice}")
    print(f"adversarial pixel-wise Dice for each type: {adv_ss_type_dice}")

    #THE MEAN TYPE DICE (OVER TYPES 2, 3 AND 4) IS THE ONE WE USED FOR OUR EXPERIMENTS
    mean_dice = (adv_ss_type_dice[2] + adv_ss_type_dice[3] + adv_ss_type_dice[4]) / 3
    print(f"mean adversarial type Dice: {mean_dice}")
    print("\n\n")
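
To compare transforms at a glance, you could collect the mean type Dice from every iteration instead of only printing it. The helpers below are hypothetical (mean_type_dice, format_summary and the summary dictionary are not part of reetoolbox) and assume the per-type Dice values can be indexed by the integer type IDs 2, 3 and 4, exactly as the cell above does.

```python
def mean_type_dice(type_dice, type_ids=(2, 3, 4)):
    """Average per-type Dice over the selected type IDs (list, array or int-keyed dict)."""
    values = [float(type_dice[t]) for t in type_ids]
    return sum(values) / len(values)

def format_summary(summary):
    """Render {transform: (original_mean_dice, adversarial_mean_dice)} as a text table."""
    lines = [f"{'transform':<12}{'original':>10}{'adversarial':>13}{'drop':>8}"]
    for name, (orig, adv) in summary.items():
        lines.append(f"{name:<12}{orig:>10.3f}{adv:>13.3f}{orig - adv:>8.3f}")
    return "\n".join(lines)
```

Usage: create summary = {} before the loop, add summary[transform] = (mean_type_dice(orig_ss_type_dice), mean_type_dice(adv_ss_type_dice)) at the end of each iteration, and call print(format_summary(summary)) after the loop.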

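To evaluate HoVer-Net, we first need to load a model. Note that the way you load your weights may differ depending on whether you used parallelisation (nn.DataParallel) during training: a checkpoint saved from a parallelised model prefixes every parameter name with "module.", and the convert_pytorch_checkpoint call in the cell below is there to handle this.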
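The prefix handling can be sketched with a plain dictionary standing in for a state dict. This mirrors what HoVer-Net's convert_pytorch_checkpoint appears to do (strip "module." only when every key carries it), but strip_dataparallel_prefix is a hypothetical helper, not the function used in this notebook.

```python
def strip_dataparallel_prefix(state_dict, prefix="module."):
    """Remove the prefix nn.DataParallel adds to parameter names, if every key has it."""
    if state_dict and all(k.startswith(prefix) for k in state_dict):
        return {k[len(prefix):]: v for k, v in state_dict.items()}
    # Checkpoint was saved without DataParallel: return the keys unchanged.
    return dict(state_dict)
```

Keys like "module.conv.weight" become "conv.weight", which is what an unwrapped model expects in load_state_dict.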




In [None]:
import torch
import torch.nn as nn

#HoVer-Net NA
# weights = torch.load("nuc_inst_segmentation/output/HoVer-Net_NA.tar")["desc"]

#HoVer-Net RA
weights = torch.load("nuc_inst_segmentation/output/HoVer-Net_RA.tar")["desc"]

#HoVer-Net Pix
# weights = torch.load("nuc_inst_segmentation/output/HoVer-Net_Pix.tar")["desc"]


#LOAD THE WEIGHTS INTO THE MODEL
from nuc_inst_segmentation.hovernet.net_desc import create_model

net_desc = create_model(input_ch=3,
                        nr_types=5,
                        freeze=False,
                        mode="original").to("cuda")

from nuc_inst_segmentation.hovernet.utils import convert_pytorch_checkpoint

weights = convert_pytorch_checkpoint(weights)

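# strict=False silently skips mismatched keys, so print the report and check
# that it lists no missing_keys or unexpected_keys
load_feedback = net_desc.load_state_dict(weights, strict=False)
print(load_feedback)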


Now we can evaluate the HoVer-Net model. HoVer_Evaluator and its predict method work analogously to Semantic_Segmentation_Evaluator.

In [None]:
import reetoolbox.evaluator

for transform in transform_list:
    print(transform)
    # "evaluator" rather than "eval" avoids shadowing Python's built-in eval()
    evaluator = reetoolbox.evaluator.HoVer_Evaluator(net_desc, ["Consep_patches_540x540/valid/540x540_164x164"], transform)
    orig_results_dict, adv_results_dict = evaluator.predict(adversarial=True, display=False, scale_perturbation=True)


    #SEMANTIC SEGMENTATION METRICS
    orig_ss_iou = orig_results_dict["pixel_wise_nucleus_IoU"]
    adv_ss_iou = adv_results_dict["pixel_wise_nucleus_IoU"]

    orig_ss_dice =  orig_results_dict["pixel_wise_nucleus_Dice"]
    adv_ss_dice = adv_results_dict["pixel_wise_nucleus_Dice"]

    orig_ss_acc =  orig_results_dict["all_type_pixel_accuracy"]
    adv_ss_acc = adv_results_dict["all_type_pixel_accuracy"]

    orig_ss_type_IoU = orig_results_dict["pixel_wise_type_IoU"]
    adv_ss_type_IoU = adv_results_dict["pixel_wise_type_IoU"]

    orig_ss_type_dice = orig_results_dict["pixel_wise_type_Dice"]
    adv_ss_type_dice = adv_results_dict["pixel_wise_type_Dice"]

    mean_dice = (adv_ss_type_dice[2] + adv_ss_type_dice[3] + adv_ss_type_dice[4])/3
    print(f"mean type dice: {mean_dice}")

    #NUCLEUS INSTANCE SEGMENTATION METRICS
    orig_inst_dice = orig_results_dict["instance_wise_Dice"]
    adv_inst_dice = adv_results_dict["instance_wise_Dice"]

    orig_inst_sq = orig_results_dict["instance_wise_segmentation_quality"]
    adv_inst_sq = adv_results_dict["instance_wise_segmentation_quality"]

    orig_inst_pq = orig_results_dict["instance_wise_panoptic_quality"]
    adv_inst_pq = adv_results_dict["instance_wise_panoptic_quality"]

    #NUCLEUS INSTANCE TYPE CLASSIFICATION METRICS
    orig_inst_type_accuracy = orig_results_dict["instance_wise_type_classification_accuracy"]
    adv_inst_type_accuracy = adv_results_dict["instance_wise_type_classification_accuracy"]


    print("SEMANTIC SEGMENTATION METRICS")
    print(f"original pixel-wise nucleus IoU: {orig_ss_iou}")
    print(f"adversarial pixel-wise nucleus IoU: {adv_ss_iou}")
    print("")
    print(f"original pixel-wise nucleus Dice: {orig_ss_dice}")
    print(f"adversarial pixel-wise nucleus Dice: {adv_ss_dice}")
    print("")
    print(f"original pixel-wise accuracy (all types): {orig_ss_acc}")
    print(f"adversarial pixel-wise accuracy (all types): {adv_ss_acc}")
    print("")
    print(f"original pixel-wise IoU for each type: {orig_ss_type_IoU}")
    print(f"adversarial pixel-wise IoU for each type: {adv_ss_type_IoU}")
    print("")
    print(f"original pixel-wise Dice for each type: {orig_ss_type_dice}")
    print(f"adversarial pixel-wise Dice for each type: {adv_ss_type_dice}")
    print("\n\n")
    print("NUCLEUS INSTANCE SEGMENTATION METRICS")
    print(f"original instance-wise Dice: {orig_inst_dice}")
    print(f"adversarial instance-wise Dice: {adv_inst_dice}")
    print("")
    print(f"original instance-wise segmentation quality: {orig_inst_sq}")
    print(f"adversarial instance-wise segmentation quality: {adv_inst_sq}")
    print("")
    print(f"original instance-wise panoptic quality: {orig_inst_pq}")
    print(f"adversarial instance-wise panoptic quality: {adv_inst_pq}")
    print("\n\n")
    print("NUCLEUS INSTANCE TYPE CLASSIFICATION")
    print(f"original instance-wise type classification accuracy: {orig_inst_type_accuracy}")
    print(f"adversarial instance-wise type classification accuracy: {adv_inst_type_accuracy}")
    print("\n" * 6)
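
The instance-level panoptic quality printed above is conventionally defined as PQ = DQ × SQ: each ground-truth instance is matched to a prediction with IoU > 0.5, DQ = TP / (TP + 0.5 FP + 0.5 FN), and SQ is the mean IoU of the matched pairs. The sketch below computes this standard definition for a single pair of instance maps; panoptic_quality is a hypothetical helper, and reetoolbox's own implementation (including how it aggregates over the dataset) may differ.

```python
import numpy as np

def panoptic_quality(gt: np.ndarray, pred: np.ndarray, match_iou: float = 0.5):
    """Return (DQ, SQ, PQ) for two instance maps (0 = background, k > 0 = instance k).

    With match_iou >= 0.5, an IoU strictly above it guarantees each GT instance
    matches at most one prediction, so greedy thresholding gives a unique matching.
    """
    gt_ids = [i for i in np.unique(gt) if i != 0]
    pred_ids = [i for i in np.unique(pred) if i != 0]
    matched_ious, matched_pred = [], set()
    for g in gt_ids:
        g_mask = gt == g
        # Only predictions overlapping this GT instance can reach IoU > match_iou.
        for p in np.unique(pred[g_mask]):
            if p == 0 or p in matched_pred:
                continue
            p_mask = pred == p
            iou = np.logical_and(g_mask, p_mask).sum() / np.logical_or(g_mask, p_mask).sum()
            if iou > match_iou:
                matched_ious.append(float(iou))
                matched_pred.add(p)
                break
    tp = len(matched_ious)
    fn = len(gt_ids) - tp
    fp = len(pred_ids) - tp
    if tp == 0:
        # No matches: SQ is undefined, so report all three as 0.
        return 0.0, 0.0, 0.0
    dq = tp / (tp + 0.5 * fp + 0.5 * fn)
    sq = sum(matched_ious) / tp
    return dq, sq, dq * sq
```

For instance, one correctly matched nucleus with IoU 2/3, one missed nucleus and one spurious prediction give DQ = 0.5, SQ = 2/3 and PQ = 1/3.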