This notebook creates a .h5 file containing labeled patches centered around nuclei

In [1]:
from instanseg.utils.data_download import create_raw_datasets_dir, create_processed_datasets_dir, download_and_extract
from pathlib import Path
import numpy as np
import fastremap
import os
import xml.etree.ElementTree as ET
from tqdm import tqdm
import torch

import matplotlib.pyplot as plt

from PIL import Image, ImageDraw

from instanseg.utils.utils import show_images, _move_channel_axis
# #aws s3 cp --no-sign-request s3://monkey-training/ ./ --recursive
monkey_dir = Path("../data") # Change this to the correct path to the root data 

files = sorted([
    f for f in os.listdir(monkey_dir / "annotations" / "xml")
    if f.startswith("A_") # Only take files from the first centre: centre A
    # if f.endswith(".xml") and not f.startswith(".") # Exclude non-xml files like the hidden .ipynb-checkpoints file
    if f.endswith("01.xml") and not f.startswith(".") # To test, only take the first image
]) # File path to xml folder of annotations

print(files)
label_ids = []
means_list = []
annotations_dict = {}


np.random.seed(0) # Add random seed for reproducability

for file in tqdm(files): # Loop through annotations

    split = np.random.choice(["train", "val"], p=[0.8, 0.2])

    # Get filepath of image corresponding to the annotation
    #print(file)
    img_pascpg_path = Path(monkey_dir) / ("images/pas-cpg/" + file.split(".")[0] + "_PAS_CPG.tif")
    img_pasdiagnostic_path = Path(monkey_dir) / ("images/pas-diagnostic/" + file.split(".")[0] + "_PAS_Diagnostic.tif")
   # img_pasoriginal_path = Path(monkey_dir) / ("images/pas-original/" + file.split(".")[0] + "_PAS_Original.tif")
    ihc_path = Path(monkey_dir) / ("images/ihc/" + file.split(".")[0] + "_IHC_CPG.tif")

    # Open the .tif files using TiffSlide
    from tiffslide import TiffSlide
    slidepascpg = TiffSlide(img_pascpg_path)
    slideihc = TiffSlide(ihc_path)

    # Get the actual filepath of the polygon annotation
    tree = ET.parse(monkey_dir/("annotations/xml/"+file))
    root = tree.getroot()  # Get the root of the XML

#     # if split == "val":
#     #     destination_img = "/home/cdt/Documents/Projects/monkey-challenge-instanseg/evaluation/validation_set/images/kidney-transplant-biopsy-wsi-pas/"
#     #     destination_mask = "/home/cdt/Documents/Projects/monkey-challenge-instanseg/evaluation/validation_set/images/tissue-mask/"
        
#     #     #move images to inference folder
#     #     import shutil
#     #     shutil.copy(monkey_dir / ("images/pas-cpg/" + file.split(".")[0] + "_PAS_CPG.tif"), destination_img)
#     #     shutil.copy(monkey_dir / ("images/tissue-masks/" + file.split(".")[0] + "_mask.tif"), destination_mask)
        
#     #     shutil.copy(monkey_dir / ("annotations/json/" + file.split(".")[0] + "_inflammatory-cells.json"), 
#     #     '/home/cdt/Documents/Projects/monkey-challenge-instanseg/evaluation/ground_truth')

#     #     shutil.copy(monkey_dir / ("annotations/json/" + file.split(".")[0] + "_lymphocytes.json"), 
#     #     '/home/cdt/Documents/Projects/monkey-challenge-instanseg/evaluation/ground_truth')

#     #     shutil.copy(monkey_dir / ("annotations/json/" + file.split(".")[0] + "_monocytes.json"), 
#     #     '/home/cdt/Documents/Projects/monkey-challenge-instanseg/evaluation/ground_truth')

    coords = []

    annotations_dict[file] = []

    # Iterate over each annotation and extract relevant information
    for annotation in root.findall('.//Annotation'):
        name = annotation.get('Name')
        part_of_group = annotation.get('PartOfGroup')
        _type = annotation.get('Type')
      
        if _type == "Polygon":
            coords_ROI = []
            for coordinate in annotation.findall('.//Coordinate'):
                x = float(coordinate.get('X'))
                y = float(coordinate.get('Y'))
                coords_ROI.append([x, y])

            coords_ROI = np.array(coords_ROI)

            x_min, y_min = coords_ROI.min(axis=0)
            x_max, y_max = coords_ROI.max(axis=0)
            bbox_width = int(x_max - x_min)
            bbox_height = int(y_max - y_min)

            # Read the bounding box from the slide
            rgb_data = slidepascpg.read_region(
                (int(x_min), int(y_min)),
                0,
                (bbox_width, bbox_height),
                as_array=True,
            )

            ihc_data = slideihc.read_region(
                (int(x_min), int(y_min)),
                0,
                (bbox_width, bbox_height),
                as_array=True,
            )


            mask = Image.new("L", (bbox_width, bbox_height), 0)
            polygon = coords_ROI - [x_min, y_min]  # Translate polygon to local bbox coordinates
            ImageDraw.Draw(mask).polygon(polygon.flatten().tolist(), outline=1, fill=1)
            # Convert the mask to a NumPy array
            binary_mask = np.array(mask)

            annotations_dict[file].append({ "split": split,
                                            "pas-cpg":rgb_data,
                                            "ihc":ihc_data,
                                            "polygon": coords_ROI, 
                                            "mask": binary_mask, 
                                            "bbox" : [x_min, y_min, x_max, y_max], 
                                            "dots" : []})

            #show_images(rgb_data)

    for annotation in root.findall('.//Annotation'):
        name = annotation.get('Name')
        part_of_group = annotation.get('PartOfGroup')
        _type = annotation.get('Type')
        
        if _type == "Dot":
            # Find the coordinates
            coordinates = annotation.find('.//Coordinate')
            x = int(float(coordinates.get('X')))
            y = int(float(coordinates.get('Y')))
            c = 0 if part_of_group == "lymphocytes" else 1

            for i,annotation in enumerate(annotations_dict[file]):
                if annotation["bbox"][0] < x < annotation["bbox"][2] and annotation["bbox"][1] < y < annotation["bbox"][3]:
                    annotations_dict[file][i]["dots"].append([y - annotation["bbox"][1] ,x - annotation["bbox"][0],c])

['A_P000001.xml']


100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.44it/s]


In [2]:
leukocytes_dots = 0
detected_leukocytes = 0

import os
from pytorch_utils import get_masked_patches, _to_ndim, _to_tensor_float32
from inference_class import _rescale_to_pixel_size
# from instanseg.instanseg import _to_tensor_float32, _rescale_to_pixel_size
import torchstain
from instanseg import InstanSeg

# pixel_size_precision = 0.01
# def _rescale_to_pixel_size(image: torch.Tensor, 
#                            requested_pixel_size: float, 
#                            model_pixel_size: float,
#                            mode: str = "bilinear") -> torch.Tensor:
    
#     original_dim = image.dim()

#     image = _to_ndim(image, 4)

#     scale_factor = requested_pixel_size / model_pixel_size

#     if not np.allclose(scale_factor,1, pixel_size_precision): #if you change this value, you MUST modify the whole_slide_image function.
#         image = interpolate(image, scale_factor=scale_factor, mode=mode)

#     return _to_ndim(image, original_dim)

# The filepath below is not used anywhere...
os.environ["INSTANSEG_BIOIMAGEIO_PATH"] = '/home/cdt/Documents/Projects/InstanSeg/instanseg_thibaut/instanseg/bioimageio_models/'
# Path where we want the .h5 files to go
os.environ['INSTANSEG_DATASET_PATH'] = "h5_datasets_test"
path = Path(os.environ['INSTANSEG_DATASET_PATH']).resolve()
print("Resolved path:", path)
print("Exists?", path.exists())
print("Is directory?", path.is_dir())

instanseg_script = torch.jit.load("../instanseg.pt")
brightfield_nuclei = InstanSeg(instanseg_script, verbosity = 0)

patch_size = 128
destination_pixel_size = 0.5 # 2420... # WE MIGHT WANT TO EXPERIMENT WITH SETTING THIS TO A LOWER VALUE --> TO INCREASE RESOLUTION
rescale_output = False if destination_pixel_size == 0.5 else True

image_types  = ["cpg"]#, "ihc"]

for image_type in image_types:

  if image_type == "cpg":
    image_key  = "pas-cpg"
  else:
    image_key = "ihc"

  device = "cpu"

  np.random.seed(0)
  import h5py
  with h5py.File(Path(os.environ['INSTANSEG_DATASET_PATH']) / f"monkey_{image_type}_oneslide_metadata.h5", "w") as f: # Setting the name for the .h5 file

      f.attrs['class_names'] = str({"0": "lymphocytes", "1": "monocytes", "2" : "other"})  # Convert to string since HDF5 attributes must be simple types
      f.attrs['pixel_size'] = destination_pixel_size

      for split in ['train', 'val']:
          f.create_dataset(f"{split}/data", shape=(0, 4, patch_size, patch_size),
          dtype=np.uint8, maxshape=(None, 4, patch_size, patch_size),
          chunks=(1, 4, patch_size, patch_size),)
          f.create_dataset(f"{split}/labels", shape=(0, 1), dtype=np.uint8, maxshape=(None, 1))
          f.create_dataset(f"{split}/metadata", shape=(0,), dtype=h5py.string_dtype(), maxshape=(None,), chunks=(1,))

      for file in tqdm(annotations_dict.keys()):
          split = annotations_dict[file][0]["split"]

          for annotation in annotations_dict[file]:

              array = _to_tensor_float32(annotation["pas-cpg"])

              labels , input_tensor = brightfield_nuclei.eval_medium_image(array,
              pixel_size = 0.2420, rescale_output = rescale_output, seed_threshold = 0.05, tile_size= 1024)

              dots = torch.tensor(annotation["dots"]).to(device)
              dots[:,:2] = dots[:,:2] * 0.2420 / destination_pixel_size

              mask = _rescale_to_pixel_size(_to_tensor_float32(annotation["mask"]), 0.2420, destination_pixel_size).to(device)
              
              labels = labels.to(device) * torch.tensor(mask).bool()
              canvas = torch.zeros_like(labels)
              dots = torch.tensor(dots, dtype=torch.long)
              canvas[:,:,dots[:,0],dots[:,1]] = dots[:,2].float() + 1
              monocytes = labels * torch.isin(labels,labels * (canvas == 2).float()).float()
              lymphocytes = labels * torch.isin(labels,labels * (canvas == 1).float()).float()
              other_cells = (labels * ~torch.isin(labels,labels * (canvas > 0).float())).float()

              img_tensor = _rescale_to_pixel_size(_to_tensor_float32(annotation[image_key]), 0.2420, destination_pixel_size).byte().to(device)

              assert img_tensor.shape[-2:] == labels.shape[-2:]
              detected_leukocytes += len(torch.unique(monocytes + lymphocytes)) - 1
              leukocytes_dots += len(dots)


              if len(torch.unique(monocytes)) > 1:
                crops,masks,coords = get_masked_patches(monocytes,img_tensor, patch_size=patch_size)
                crops = (crops).to(torch.uint8)
                masks = (masks).to(torch.uint8)
                x_monocytes =(torch.cat((crops,masks),dim= 1))
                y_monocytes = torch.zeros(len(x_monocytes),dtype = torch.long) + 1
                metadata_monocytes = [
                        {
                            "file": file,
                            "annotation_idx": annotation.get("id", None),
                            "class": "monocyte",
                            "coord_row": int(coord[0].item()),
                            "coord_col": int(coord[1].item())
                        }
                        for coord in coords
                    ]
              else:
                x_monocytes = torch.zeros(0,4,patch_size,patch_size).to(device)
                y_monocytes = torch.zeros(0,dtype = torch.long) + 1
                metadata_monocytes = []


              if len(torch.unique(lymphocytes)) > 1:
                crops,masks,coords = get_masked_patches(lymphocytes,img_tensor, patch_size=patch_size)
                crops = (crops).to(torch.uint8)
                masks = (masks).to(torch.uint8)
                x_lymphocytes =(torch.cat((crops,masks),dim= 1))
                y_lymphocytes = torch.zeros(len(x_lymphocytes),dtype = torch.long) + 0
                metadata_lymphocytes = [
                        {
                            "file": file,
                            "annotation_idx": annotation.get("id", None),
                            "class": "lymphocyte",
                            "coord_row": int(coord[0].item()),
                            "coord_col": int(coord[1].item())
                        }
                        for coord in coords
                    ]
              else:
                x_lymphocytes = torch.zeros(0,4,patch_size,patch_size).to(device)
                y_lymphocytes = torch.zeros(0,dtype = torch.long) + 0
                metadata_lymphocytes = []

              if len(torch.unique(other_cells)) > 1:
                crops,masks,coords = get_masked_patches(other_cells,img_tensor, patch_size=patch_size)
                crops = (crops).to(torch.uint8)
                masks = (masks).to(torch.uint8)
                x_other =(torch.cat((crops,masks),dim= 1))
                y_other = torch.zeros(len(x_other),dtype = torch.long) + 2
                metadata_other = [
                        {
                            "file": file,
                            "annotation_idx": annotation.get("id", None),
                            "class": "lymphocyte",
                            "coord_row": int(coord[0].item()),
                            "coord_col": int(coord[1].item())
                        }
                        for coord in coords
                    ]  
              else:
                x_other = torch.zeros(0,4,patch_size,patch_size).to(device)
                y_other = torch.zeros(0,dtype = torch.long) + 2
                metadata_other = []

              x = torch.cat((x_monocytes,x_lymphocytes,x_other),dim = 0)
              y = torch.cat((y_monocytes,y_lymphocytes,y_other),dim = 0).numpy()[:,None]
              metadata_combined = metadata_monocytes + metadata_lymphocytes + metadata_other

              if len(x) != len(y):
                    pdb.set_trace()

              data_ds = f[f"{split}/data"]
              labels_ds = f[f"{split}/labels"]
              metadata_ds = f[f"{split}/metadata"]

              data_ds.resize((data_ds.shape[0] + x.shape[0],) + x.shape[1:])
              data_ds[-x.shape[0]:, ...] = (x).cpu().numpy().astype(np.uint8)
              labels_ds.resize((labels_ds.shape[0] + y.shape[0],) + y.shape[1:])
              labels_ds[-y.shape[0]:, ...] = y.astype(np.uint8)

              # Write metadata
              metadata_ds.resize((metadata_ds.shape[0] + len(metadata_combined),))
              metadata_ds[-len(metadata_combined):] = [str(m) for m in metadata_combined]
          

  undetected_percent = ( leukocytes_dots - detected_leukocytes) / leukocytes_dots
  print(f"Detected {detected_leukocytes} out of {leukocytes_dots} dots. { 100 - undetected_percent * 100:.2f}% detected")

Resolved path: /vol/csedu-nobackup/course/IMC037_aimi/group12/notebooks/h5_datasets_test
Exists? True
Is directory? True


  return forward_call(*args, **kwargs)
  labels = labels.to(device) * torch.tensor(mask).bool()
  dots = torch.tensor(dots, dtype=torch.long)
100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:44<00:00, 44.79s/it]


Detected 1127 out of 1239 dots. 90.96% detected
