# nnUNet WSI inference pipeline

By default, nnUNet runs a sliding window with half overlap whenever the given patch is larger than the model's patch size.
As a result, pixels near the border of the patch receive only 1 or 2 predictions (no overlap or a single overlap),
while every pixel in the interior receives 4 predictions. In this version of nnUNet WSI inference, we crop this border off the sampled patch and write only the inner part.
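
The prediction counts above can be checked with a small NumPy sketch. It assumes the sampled patch size is a multiple of half the model patch size, so the windows land exactly on half-patch steps. `coverage_map` is a hypothetical helper for illustration, not part of nnUNet.

```python
import numpy as np

def coverage_map(sampled_size, model_patch_size):
    """Count how many half-overlapping windows cover each pixel."""
    step = model_patch_size // 2
    counts = np.zeros((sampled_size, sampled_size), dtype=int)
    # Window start positions along one axis (assumes exact half-patch steps)
    starts = range(0, sampled_size - model_patch_size + 1, step)
    for r in starts:
        for c in starts:
            counts[r:r + model_patch_size, c:c + model_patch_size] += 1
    return counts

counts = coverage_map(1280, 512)
border = 512 // 2
inner = counts[border:-border, border:-border]
print(np.unique(counts))  # [1 2 4]: corners, edges, interior
print(np.unique(inner))   # [4]: the part that is written
```

Corners are covered once, edges twice, and only the interior gets the full 4 predictions, which is why the border of half a model patch is cropped off.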

This approach becomes more efficient as larger patches are sampled, because the ratio of inner area to border area increases.
To avoid running inference on large empty regions, we check whether empty rows and columns can be removed while preserving the half overlap of nnUNet's sliding-window approach.
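
The row/column check can be illustrated along a single axis. This is a simplified sketch, not the notebook's `get_trim_indexes`: `trim_bounds_1d` is a hypothetical helper, and returning `None` for a fully empty mask is an assumption of this sketch.

```python
import numpy as np

def trim_bounds_1d(mask_1d, block):
    """Return (start, stop) that keeps all tissue plus one empty block per side."""
    starts = list(range(0, len(mask_1d), block))
    ends = [s + block for s in starts]
    is_empty = [not mask_1d[s:e].any() for s, e in zip(starts, ends)]
    if all(is_empty):
        return None  # assumption: the caller decides what to do with empty patches
    lead = is_empty.index(False)         # empty blocks before the first tissue block
    trail = is_empty[::-1].index(False)  # empty blocks after the last tissue block
    start = starts[max(lead - 1, 0)]     # keep one empty block for the overlap
    stop = ends[::-1][max(trail - 1, 0)]
    return start, stop

mask = np.zeros(1280, dtype=bool)
mask[600:700] = True                     # tissue only in the middle block
bounds = trim_bounds_1d(mask, 256)
print(bounds)  # (256, 1024)
```

Because the bounds always fall on multiples of half the model patch size, the trimmed region still lines up with the half-overlap grid.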

In addition to inference, we calculate the mean disagreement of the 5 folds for every output pixel, as described in the MIDL paper.
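
One way to read this disagreement measure is as the cross-entropy between each fold's softmax and the ensemble mean, averaged over the folds. The NumPy sketch below follows that reading. It approximates the PyTorch-based function further down rather than reproducing it exactly, and `mean_disagreement` is a hypothetical name.

```python
import numpy as np

def mean_disagreement(softmax_list, eps=1e-8):
    """Per-pixel cross-entropy of each fold against the ensemble mean, averaged over folds."""
    softmax_stack = np.stack(softmax_list)     # (folds, H, W, classes)
    softmax_mean = softmax_stack.mean(axis=0)  # (H, W, classes)
    ce = -(softmax_mean[None] * np.log(softmax_stack + eps)).sum(axis=-1)
    return ce.mean(axis=0)                     # (H, W)

# Two folds, one pixel, two classes
agree = [np.array([[[1.0, 0.0]]]), np.array([[[1.0, 0.0]]])]
disagree = [np.array([[[1.0, 0.0]]]), np.array([[[0.0, 1.0]]])]
u_agree = mean_disagreement(agree)
u_disagree = mean_disagreement(disagree)
print(u_agree.shape, float(u_agree[0, 0]) < float(u_disagree[0, 0]))
```

Pixels where the folds produce the same confident prediction score close to zero, while conflicting predictions score high.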

The loop is constructed so that multiple scripts can run in parallel without interfering. It does so by creating a '\<name\>_runtime.txt' file once the loop starts inference on a new WSI + tissue mask match. If this txt file exists, the match has already been processed or is currently being processed. Consequently, if a run failed, you'll need to delete its txt file to run it again.
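
A minimal sketch of this marker-file check, using only the standard library. The helper names `claimed_stems` and `stems_to_run` are hypothetical, and the temporary folder stands in for the real output folder.

```python
import tempfile
from pathlib import Path

MARKER_SUFFIX = "_runtime.txt"

def claimed_stems(output_folder):
    """Stems of all matches that already have a marker file."""
    return {p.name[:-len(MARKER_SUFFIX)]
            for p in Path(output_folder).iterdir()
            if p.name.endswith(MARKER_SUFFIX)}

def stems_to_run(image_paths, output_folder):
    """Image stems that no script has claimed yet."""
    claimed = claimed_stems(output_folder)
    return [Path(p).stem for p in image_paths if Path(p).stem not in claimed]

image_paths = ["images/103S.tif", "images/111S.tif"]
with tempfile.TemporaryDirectory() as tmp:
    marker = Path(tmp) / "103S_runtime.txt"
    marker.write_text("12.3")                   # 103S is claimed
    todo_before = stems_to_run(image_paths, tmp)
    marker.unlink()                             # deleting the marker re-enables the run
    todo_after = stems_to_run(image_paths, tmp)

print(todo_before)  # ['111S']
print(todo_after)   # ['103S', '111S']
```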

While this is the approach we used for our TIGER performance, there is still quite some room for improvement. We are currently working on an improved version that is more readable and simpler in construction.

This notebook also auto-generates the config and files YAMLs, which may seem complex but comes in handy when debugging.

## Download example files
These are 2 images and their 2 corresponding tissue masks.

##### *aws s3 cp --no-sign-request s3://tiger-training/wsibulk/images/103S.tif /add/your/path/here/wsibulk/images/103S.tif*

##### *aws s3 cp --no-sign-request s3://tiger-training/wsibulk/tissue-masks/103S_tissue.tif /add/your/path/here/wsibulk/tissue-masks/103S_tissue.tif*

##### *aws s3 cp --no-sign-request s3://tiger-training/wsibulk/images/111S.tif /add/your/path/here/wsibulk/images/111S.tif*

##### *aws s3 cp --no-sign-request s3://tiger-training/wsibulk/tissue-masks/111S_tissue.tif /add/your/path/here/wsibulk/tissue-masks/111S_tissue.tif*

# Imports

In [2]:
from wholeslidedata.iterators import create_batch_iterator
from wholeslidedata.accessories.asap.imagewriter import WholeSlideMaskWriter
from tqdm.notebook import tqdm
from pathlib import Path
import numpy as np
from nnunet.training.model_restore import load_model_and_checkpoint_files
import os
import torch
from wholeslidedata.samplers.utils import fit_data
import yaml
from wholeslidedata.image.wholeslideimage import WholeSlideImage
import time



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet

nnUNet_raw_data_base is not defined and nnU-Net can only be used on data for which preprocessed files are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up properly.
RESULTS_FOLDER is not defined and nnU-Net cannot be used for training or inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information on how to set this up.


# Functions and utils 

In [3]:
def norm_01(x_batch): # Use this for models trained on 0-1 scaled data
    x_batch = x_batch / 255
    x_batch = x_batch.transpose(3, 0, 1, 2)
    return x_batch

def z_norm(x_batch): # use this for default nnunet models, using z-score normalized data
    mean = x_batch.mean(axis=(-2,-1), keepdims=True)
    std = x_batch.std(axis=(-2,-1), keepdims=True)
    x_batch = ((x_batch - mean) / (std + 1e-8))
    x_batch = x_batch.transpose(3, 0, 1, 2)
    return x_batch


def ensemble_softmax_list(trainer, params, x_batch):
    softmax_list = []
    for p in params:
        trainer.load_checkpoint_ram(p, False)
        softmax_list.append(
            trainer.predict_preprocessed_data_return_seg_and_softmax(x_batch.astype(np.float32), verbose=False,
                                                                     do_mirroring=False, mirror_axes=[])[
                -1].transpose(1, 2, 3, 0).squeeze())
    return softmax_list


def array_to_formatted_tensor(array):
    array = np.expand_dims(array.transpose(2, 0, 1), 0)
    return torch.tensor(array)


def softmax_list_and_mean_to_uncertainty(softmax_list, softmax_mean):
    loss = torch.nn.CrossEntropyLoss(reduction='none')
    uncertainty_loss_per_pixel_list = []
    for softmax in softmax_list:
        log_softmax = np.log(softmax + 0.00000001)
        uncertainty_loss_per_pixel = loss(array_to_formatted_tensor(log_softmax),
                                          array_to_formatted_tensor(softmax_mean))
        uncertainty_loss_per_pixel_list.append(uncertainty_loss_per_pixel)
    uncertainty = torch.cat(uncertainty_loss_per_pixel_list).mean(dim=0)
    return uncertainty

def get_trim_indexes(y_batch):
    """
    Using the y_mask / tissue-background mask we can check if there are
    full empty rows and columns with a width of half the model patch size.
    We check this in half model patch size increments because otherwise
    we screw up the half overlap approach from nnunet (resulting in inconsistent
    overlap thoughout the WSI).
    We will still need 1 row or column that is empty to make sure the parts that
    do have tissue have 4x overlap
    """
    y = y_batch[0]
    r_is_empty = [not y[start:end].any() for start, end in zip(half_patch_size_start_idxs, half_patch_size_end_idxs)]
    c_is_empty = [not y[:, start:end].any() for start, end in zip(half_patch_size_start_idxs, half_patch_size_end_idxs)]

    empty_rs_top = 0
    for r in r_is_empty:
        if r == True:
            empty_rs_top += 1  # count empty rows
        else:
            trim_top_half_idx = empty_rs_top - 1  # should always include a single empty row, since we need the overlap
            trim_top_half_idx = np.clip(trim_top_half_idx, 0, None)  # cannot select regions outside sampled patch
            trim_top_idx = half_patch_size_start_idxs[trim_top_half_idx]
            break

    empty_rs_bottom = 0
    for r in r_is_empty[::-1]:
        if r == True:
            empty_rs_bottom += 1
        else:
            trim_bottom_half_idx = empty_rs_bottom - 1
            trim_bottom_half_idx = np.clip(trim_bottom_half_idx, 0, None)
            trim_bottom_idx = half_patch_size_end_idxs[::-1][trim_bottom_half_idx]  # reverse index
            break

    empty_cs_left = 0
    for c in c_is_empty:
        if c == True:
            empty_cs_left += 1
        else:
            trim_left_half_idx = empty_cs_left - 1
            trim_left_half_idx = np.clip(trim_left_half_idx, 0, None)
            trim_left_idx = half_patch_size_start_idxs[trim_left_half_idx]
            break

    empty_cs_right = 0
    for c in c_is_empty[::-1]:
        if c == True:
            empty_cs_right += 1
        else:
            trim_right_half_idx = empty_cs_right - 1
            trim_right_half_idx = np.clip(trim_right_half_idx, 0, None)
            trim_right_idx = half_patch_size_end_idxs[::-1][trim_right_half_idx]
            break

    # print(trim_top_half_idx, trim_bottom_half_idx, trim_left_half_idx, trim_right_half_idx)
    # print(trim_top_idx, trim_bottom_idx, trim_left_idx, trim_right_idx)
    return trim_top_idx, trim_bottom_idx, trim_left_idx, trim_right_idx


def find_matches(img_folder, match_folder, img_extension='', match_extension='', match_contains='', exact_match=False):
    img_files = [item for item in os.listdir(img_folder) if
                 os.path.isfile(os.path.join(img_folder, item)) and item.endswith(img_extension)]

    match_files = [item for item in os.listdir(match_folder) if os.path.isfile(os.path.join(match_folder, item))]

    # Match and optional filter
    if exact_match:
        assert match_extension != '', 'exact_match needs a match extension to verify if img stem + match_extension == match filename'
        img_match_paths = [(os.path.join(img_folder, img_file),  # add wsi folder in front
                            os.path.join(match_folder, match_file))  # add xml folder in front
                           for img_file in img_files  # loop over wsi folder files
                           for match_file in match_files  # loop over filtered annotation files
                           if Path(match_file).name.startswith(
                Path(img_file).stem)  # only return if bg file starts with img file name (without suffix)
                           and (match_file == Path(img_file).stem + match_extension)  # exact match
                           and match_contains in match_file]  # only return if match contains match_contains

    else:
        img_match_paths = [(os.path.join(img_folder, img_file),  # add wsi folder in front
                            os.path.join(match_folder, match_file))  # add xml folder in front
                           for img_file in img_files  # loop over wsi folder files
                           for match_file in match_files  # loop over filtered annotation files
                           if Path(match_file).name.startswith(
                Path(img_file).stem)  # only return if bg file starts with img file name (without suffix)
                           and match_contains in match_file]  # only return if match contains match_contains

    # Checks and prints
    if len(img_match_paths) > 0:
        matched_img_paths, _ = zip(*img_match_paths)
        matched_img_files = [Path(img_path).name for img_path in matched_img_paths]
        unmatched = [img_file for img_file in img_files if img_file not in matched_img_files]
        matches = [img_file for img_file in matched_img_files if
                   img_file in img_files]  # this captures multiple matches per img_file
        if len(unmatched) > 0:
            print(f'{len(unmatched)} files were not matched:\n', unmatched)
        else:
            print('All image files were matched')
        if len(set(matches)) < len(matches):
            print('Some matched image files have multiple matches')
            raise Exception('Ambiguous')
        if len(matches) == len(set(matches)):
            print('Each matched image file has a single match')
    else:
        print('No matches')
    print('\n')
    return img_match_paths

def make_files_yml_and_return_matches_to_run(matches, files_yml_output_path, output_folder):
    """
    The pipeline is constructed so that a <name>_runtime.txt file is created once a Python script has started working on a WSI + mask match. Such a match should not be processed again by another Python script running in parallel (and does not need to be copied to the local machine, as is the case on our computing cluster).
    """
    runtime_stems = [file[:-12] for file in os.listdir(output_folder) if file.endswith('_runtime.txt')]
    imgs, _ = zip(*matches)
    img_stems = [Path(file).stem for file in imgs]
    matches_to_run_idx = [i for i in range(len(img_stems)) if img_stems[i] not in runtime_stems]
    matches_to_run = [matches[i] for i in matches_to_run_idx]
    
    yaml_file = {"validation": []}
    for wsi, wsa in matches_to_run:
        # print(wsi, wsa)
        # 1/0
        yaml_file["validation"].append({"wsi": {"path": str(wsi)},
                                        "wsa": {"path": str(wsa)}})
    with open(files_yml_output_path, 'w') as f:
        yaml.dump(yaml_file, f)
        print('CREATED FILES YAML:', files_yml_output_path)

    return matches_to_run

def get_closest_spacing(spacing_value):
    possible_spacings = [0.25, 0.5, 1, 2, 4, 8, 16]
    closest = min(possible_spacings, key=lambda x:abs(x-spacing_value))
    return closest

# Load model
### CHANGE YOUR MODEL AND INPUT NORMALIZATION HERE

In [4]:
model_base_path = Path('example_model_base/nnUNetTrainerV2_BN_pathology_DA_ignore0_hed005__nnUNet_RGB_scaleTo_0_1_bs8_ps512')
trainer_name = Path(model_base_path).name
print(model_base_path, '\n')
print(trainer_name, '\n')

folds = (0, 1, 2, 3, 4)
mixed_precision = None
checkpoint_name = "model_best"

trainer, params = load_model_and_checkpoint_files(str(model_base_path), folds, mixed_precision=mixed_precision,
                                                  checkpoint_name=checkpoint_name)
norm = norm_01 # z_norm # Select the right input normalization here!

example_model_base\nnUNetTrainerV2_BN_pathology_DA_ignore0_hed005__nnUNet_RGB_scaleTo_0_1_bs8_ps512 

nnUNetTrainerV2_BN_pathology_DA_ignore0_hed005__nnUNet_RGB_scaleTo_0_1_bs8_ps512 

using the following model files:  ['example_model_base\\nnUNetTrainerV2_BN_pathology_DA_ignore0_hed005__nnUNet_RGB_scaleTo_0_1_bs8_ps512\\fold_0\\model_best.model', 'example_model_base\\nnUNetTrainerV2_BN_pathology_DA_ignore0_hed005__nnUNet_RGB_scaleTo_0_1_bs8_ps512\\fold_1\\model_best.model', 'example_model_base\\nnUNetTrainerV2_BN_pathology_DA_ignore0_hed005__nnUNet_RGB_scaleTo_0_1_bs8_ps512\\fold_2\\model_best.model', 'example_model_base\\nnUNetTrainerV2_BN_pathology_DA_ignore0_hed005__nnUNet_RGB_scaleTo_0_1_bs8_ps512\\fold_3\\model_best.model', 'example_model_base\\nnUNetTrainerV2_BN_pathology_DA_ignore0_hed005__nnUNet_RGB_scaleTo_0_1_bs8_ps512\\fold_4\\model_best.model']


# Configs and prep
### CHANGE YOUR PATHS AND CONFIGURATIONS HERE

In [5]:
# Double-check the output folder name: this is where we will store the WSI inference files
output_folder = Path('example_output_folder')
os.makedirs(output_folder, exist_ok=True)

# We now make the config and file yamls dynamically
yml_output_folder = os.path.join(output_folder, 'example_files_and_config_ymls')
os.makedirs(yml_output_folder, exist_ok=True)
files_yml_output_filename = "wsi_borderless_inference_files.yml"
config_yml_output_filename = "wsi_borderless_inference_config.yml"
files_yml_output_path = os.path.join(yml_output_folder, files_yml_output_filename)
config_yml_output_path = os.path.join(yml_output_folder, config_yml_output_filename)

### Make files yaml

In [6]:
# Enter the folder with the images/WSIs and the tissue/background masks here. At the top of this notebook you can download 2 example files; enter the images and tissue-masks folders below
wsibulk_folder = "/your/path/to/wsibulk"
TIGER_test_wsi = os.path.join(wsibulk_folder, 'images')
TIGER_test_bg = os.path.join(wsibulk_folder, 'tissue-masks')
image_anno_TIGER = find_matches(TIGER_test_wsi, TIGER_test_bg)

matches = image_anno_TIGER # this should be a list of tuples with paths from wsi, mask matches: [(wsi_path1, mask_path1), (wsi_path2, mask_path2)]
matches = make_files_yml_and_return_matches_to_run(matches, files_yml_output_path, output_folder) # this removes files from the yaml that are currently being processed or have been processed already

All image files were matched
Each matched image file has a single match


CREATED FILES YAML: example_output_folder\configs\wsi_borderless_inference_files.yml


### Make config yaml
This one is already filled in for the example model that comes with the repo, but you may need to adjust these settings for your own models. One thing you do need to check is whether the data must be copied to a local drive (for example, when it is stored on a network drive). In that case, set copy_data to True and specify copy_path.

One more note: high CPU usage may cause the dataloader to hang. If that happens, decrease sampler_patch_size to reduce CPU usage. This was the case in our Grand Challenge submission for the TIGER challenge, where we reduced it all the way to 1280 (2.5 * model_patch_size).
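
The price of a smaller sampler_patch_size is that a smaller share of each sampled patch is actually written. A rough estimate, assuming a border of half the model patch size is cropped on all four sides (`written_fraction` is a hypothetical helper):

```python
def written_fraction(sampler_patch_size, model_patch_size):
    """Fraction of the sampled patch area that ends up in the output tile."""
    half = model_patch_size // 2
    # Same requirement as in the config cell: needed for a correct half overlap
    assert sampler_patch_size % half == 0
    output_patch_size = sampler_patch_size - 2 * half
    return (output_patch_size / sampler_patch_size) ** 2

for size in (1280, 4096):
    print(size, round(written_fraction(size, 512), 3))
```

With a model patch size of 512, roughly a third of a 1280 patch is written, against about three quarters of a 4096 patch.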

In [27]:
copy_data = True
if copy_data == True:
    copy_path = '/home/user/data'

In [29]:
# Added hints for construction of config file if you want to do it yourself
spacing = 0.5 # spacing for batch_shape spacing and annotation_parser output_spacing (leave its processing spacing on 4 or higher)

model_patch_size = 512 # input size of model (should be square)
half_model_patch_size=int(model_patch_size/2)

sampler_patch_size = 1280  # use this as batch shape; 1280 = 2.5 * model_patch_size (alternative: 4096 = 8 * model_patch_size)
assert sampler_patch_size % (model_patch_size/2) == 0 # needed for correct half overlap

# due to half overlap there is half the model patch size without overlap on all 4 sides of the sampled patch
output_patch_size = sampler_patch_size - 2 * half_model_patch_size # use this as your annotation_parser shape
# Note that for this approach we need a CenterPointSampler, and center: True in the patch_sampler and patch_label_sampler

tissue_mask_spacing = get_closest_spacing(WholeSlideImage(matches[0][1], backend='asap').spacings[0]) #this takes the first match's mask file, and takes its most detailed spacing
tissue_mask_ratio = tissue_mask_spacing/spacing


template = str(Path('tiger_example_iterator_templates/WSI_inference_config_remote_data_template.yml')) if copy_data == True else str(Path('tiger_example_iterator_templates/WSI_inference_config_local_data_template.yml'))

with open(template) as f:
    config_yml_str = str(yaml.safe_load(f))

replace_dict = {
    "'auto_files_yml'" : files_yml_output_path,
    "'auto_sampler_patch_size'" : sampler_patch_size,
    "'auto_spacing'" : spacing,
    "'auto_tissue_mask_spacing'" : tissue_mask_spacing,
    "'auto_tissue_mask_ratio'" : tissue_mask_ratio,
    "'auto_output_patch_size'" : output_patch_size
}

print('\nAuto configuring CONFIG YAML. Replacing template placeholders:')
for k, v in replace_dict.items():
    print('\t', k, v)
    config_yml_str = config_yml_str.replace(k, str(v))

config_yml = yaml.safe_load(config_yml_str)

with open(config_yml_output_path, 'w') as f:
    yaml.dump(config_yml, f)
    print('CREATED CONFIG YAML:', config_yml_output_filename)


Auto configuring CONFIG YAML. Replacing template placeholders:
	 'auto_files_yml' example_output_folder\configs\wsi_borderless_inference_files.yml
	 'auto_sampler_patch_size' 1280
	 'auto_spacing' 0.5
	 'auto_tissue_mask_spacing' 0.5
	 'auto_tissue_mask_ratio' 1.0
	 'auto_output_patch_size' 768
CREATED CONFIG YAML: wsi_borderless_inference_config.yml


#### Some variable settings you can ignore

In [22]:
mode = "validation"
context = 'spawn' if os.name == "nt" else 'fork'
image_path = None
previous_file_key = None
files_exist_already = None  # not sure if needed
plot = False
# following is later used to check if we can remove big empty parts of the sampled patch before inference
sampler_patch_size_range = list(range(sampler_patch_size))
half_patch_size_start_idxs = sampler_patch_size_range[0::half_model_patch_size]
half_patch_size_end_idxs = [idx + half_model_patch_size for idx in half_patch_size_start_idxs]


# Run
The loop could be simplified quite a bit; we'll release an improved version in our v2 code release!
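
Stripped of file handling, each iteration trims the sampled patch, predicts, pastes the prediction back into a full-size array, and writes only the centre. The sketch below walks through those steps with NumPy only. `dummy_predict` is a stand-in for the 5-fold nnUNet ensemble, `center_crop` is a plain centre crop standing in for `fit_data`, and the trim bounds and centre point are made-up example values.

```python
import numpy as np

def center_crop(array, size):
    """Plain centre crop (stand-in for fit_data in this sketch)."""
    top = (array.shape[0] - size) // 2
    left = (array.shape[1] - size) // 2
    return array[top:top + size, left:left + size]

def dummy_predict(patch):
    """Stand-in for the ensemble: label 1 wherever the patch is non-zero."""
    return (patch.sum(axis=-1) > 0).astype(np.uint8)

sampler_patch_size, model_patch_size = 1280, 512
output_patch_size = sampler_patch_size - model_patch_size  # 768

x = np.zeros((sampler_patch_size, sampler_patch_size, 3), dtype=np.uint8)
x[600:700, 600:700] = 200                        # a small piece of "tissue"
top, bottom, left, right = 256, 1024, 256, 1024  # example trim bounds

# Predict on the trimmed region only, then restore the full patch layout
pred_trimmed = dummy_predict(x[top:bottom, left:right])
pred_full = np.zeros((sampler_patch_size, sampler_patch_size), dtype=np.uint8)
pred_full[top:bottom, left:right] = pred_trimmed

# Keep only the inner part that received the full overlap
pred_inner = center_crop(pred_full, output_patch_size)

# Convert the sampled centre point to the upper-left corner of the tile
center_x, center_y = 5000, 7000
col = int(center_x - output_patch_size / 2)
row = int(center_y - output_patch_size / 2)
print(pred_inner.shape, (col, row))  # (768, 768) (4616, 6616)
```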

In [24]:
training_iterator = create_batch_iterator(mode=mode,
                                          context=context,
                                          user_config=config_yml_output_path,
                                          presets=('slidingwindow',),
                                          cpus=4,
                                          number_of_batches=-1,
                                          return_info=True)

for x_batch, y_batch, info in tqdm(training_iterator):
    ### Get image data, check if new image, save previous image if there was one, if new image create new image file
    sample_reference = info['sample_references'][0]['reference']
    current_file_key = sample_reference.file_key
    if current_file_key != previous_file_key:  # if starting a new image
        if previous_file_key != None and files_exist_already != True:  # if there was a previous image, and the previous image did not exist already (can also be None)
            wsm_writer.save()  # save previous mask
            wsu_writer.save()  # save previous uncertainty
            # Save runtime
            end_time = time.time()
            run_time = end_time - start_time
            with open(output_folder / (image_path.stem + '_runtime.txt'), "w") as text_file:
                text_file.write(str(run_time))
        # Getting file settings and path, doing check if exists already
        with training_iterator.dataset.get_wsi_from_reference(sample_reference) as wsi:
            image_path = wsi.path
            shape = wsi.shapes[wsi.get_level_from_spacing(spacing)]
            real_spacing = wsi.get_real_spacing(spacing)
        wsm_path = output_folder / (image_path.stem + '_nnunet.tif')
        wsu_path = output_folder / (image_path.stem + '_uncertainty.tif')
        if os.path.isfile(wsm_path) and os.path.isfile(wsu_path):
            files_exist_already = True  # this means we can skip this whole loop for this file key, checked above '### Prep, predict and uncertainty'
            previous_file_key = current_file_key
            print(f'[SKIPPING] files for {image_path.stem} exist already')
            continue  # continue to next batch
        else:
            files_exist_already = False
        # Create new writer and file
        start_time = time.time()
        wsm_writer = WholeSlideMaskWriter()  # whole slide mask
        wsu_writer = WholeSlideMaskWriter()  # whole slide uncertainty
        # Create files
        wsm_writer.write(path=wsm_path, spacing=real_spacing, dimensions=shape,
                         tile_shape=(output_patch_size, output_patch_size))
        wsu_writer.write(path=wsu_path, spacing=real_spacing,
                         dimensions=shape, tile_shape=(output_patch_size, output_patch_size))

    ### If this file already exists, skip
    if files_exist_already:
        continue

    ### Trim check
    trim_top_idx, trim_bottom_idx, trim_left_idx, trim_right_idx = get_trim_indexes(y_batch)
    x_batch_maybe_trimmed = x_batch[:, trim_top_idx : trim_bottom_idx, trim_left_idx: trim_right_idx, :]

    ### Prep, predict and uncertainty
    prep = norm(x_batch_maybe_trimmed)

    softmax_list = ensemble_softmax_list(trainer, params, prep)
    softmax_mean = np.array(softmax_list).mean(0)
    pred_output_maybe_trimmed = softmax_mean.argmax(axis=-1)

    uncertainty = softmax_list_and_mean_to_uncertainty(softmax_list, softmax_mean)
    uncertainty_output_maybe_trimmed = np.array((uncertainty.clip(0, 4) / 4 * 255).int()) 

    ### Reconstruct possible trim
    pred_output = np.zeros((sampler_patch_size, sampler_patch_size))
    pred_output[trim_top_idx : trim_bottom_idx, trim_left_idx: trim_right_idx] = pred_output_maybe_trimmed

    uncertainty_output = np.zeros((sampler_patch_size, sampler_patch_size))
    uncertainty_output[trim_top_idx: trim_bottom_idx, trim_left_idx: trim_right_idx] = uncertainty_output_maybe_trimmed

    # Only write inner part
    pred_output_inner = fit_data(pred_output, [output_patch_size, output_patch_size])
    uncertainty_output_inner = fit_data(uncertainty_output, [output_patch_size, output_patch_size])
    y_batch_inner = fit_data(y_batch[0], [output_patch_size, output_patch_size]).astype('int64')
    
    ### Get patch point
    point = info['sample_references'][0]['point']
    c, r = point.x - output_patch_size/2, point.y - output_patch_size/2 # from middle point to upper left point of tile to write
    
    ### Write tile and set previous file key for next loop check
    wsm_writer.write_tile(tile=pred_output_inner, coordinates=(int(c), int(r)), mask=y_batch_inner)
    wsu_writer.write_tile(tile=uncertainty_output_inner, coordinates=(int(c), int(r)), mask=y_batch_inner)
    previous_file_key = current_file_key

wsm_writer.save()  # if done save last image
wsu_writer.save()  # if done save last image

# Save runtime
end_time = time.time()
run_time = end_time - start_time
with open(output_folder / (image_path.stem + '_runtime.txt'), "w") as text_file:
    text_file.write(str(run_time))

training_iterator.stop()
print("DONE")