# Imports and functions

In [None]:
%load_ext autoreload
%autoreload 2

# File management
import os
import glob
from tqdm import tqdm
from pathlib import Path

# Linear algebra and dataframes
import numpy as np  
import math
import pandas as pd

# Image processing
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import cm
from openslide import OpenSlide, lowlevel 
from shapely.geometry import Polygon
import cv2 as cv
from PIL import Image, ImageDraw
Image.MAX_IMAGE_PIXELS = 1000000000000

# File extensions
import h5py
import json
import xml.etree.ElementTree as ET

# Tools
import staintools
from dgl.data.utils import save_graphs
from histocartography.visualization import OverlayGraphVisualization
from histocartography.preprocessing import (  
    GaussianTissueMask,                 # extract tissue mask   
    ColorMergedSuperpixelExtractor,     # superpixel extractor 1
    SLICSuperpixelExtractor,            # superpixel extractor 2
    AugmentedDeepFeatureExtractor,      # feature extractor
    RAGGraphBuilder,                    # graph builder
    )

from time import sleep
from skimage.segmentation import morphological_chan_vese, morphological_geodesic_active_contour, mark_boundaries
from skimage import measure
from scipy import ndimage
from skimage import morphology
from scipy.ndimage.measurements import label


# size of region : cols, rows 
def create_image_patches(img, num_rows, num_columns):
    """
    Tile an image into non-overlapping patches of a fixed pixel size.

    Despite their names, ``num_rows`` and ``num_columns`` are used as the
    patch height and patch width **in pixels**: the image is walked in steps
    of ``num_rows`` vertically and ``num_columns`` horizontally, and every
    patch is ``img[y:y + num_rows, x:x + num_columns]``.

    Parameters
    ----------
    img : array indexed as ``img[y, x, ...]`` and exposing ``.shape``
        Image to tile.
    num_rows : int
        Height of each patch, in pixels. Must be positive.
    num_columns : int
        Width of each patch, in pixels. Must be positive.

    Returns
    -------
    list
        Patches in row-major order (left to right, then top to bottom).
        Only complete tiles are returned; leftover pixels on the right and
        bottom edges that do not fill a whole patch are dropped.

    Raises
    ------
    ValueError
        If ``num_rows`` or ``num_columns`` is not strictly positive.
    """
    if num_rows <= 0 or num_columns <= 0:
        raise ValueError(
            "num_rows and num_columns must be positive patch sizes, "
            f"got {num_rows} and {num_columns}"
        )

    height, width = img.shape[0], img.shape[1]
    tiles_down = height // num_rows        # number of complete tiles vertically
    tiles_across = width // num_columns    # number of complete tiles horizontally

    # The ranges stop at the last complete tile, so no clamping is required.
    return [
        img[top:top + num_rows, left:left + num_columns]
        for top in range(0, tiles_down * num_rows, num_rows)
        for left in range(0, tiles_across * num_columns, num_columns)
    ]

def normalize_image(image, ref, method, standartize_brightness):
    """
    Stain-normalize ``image`` towards the reference image ``ref`` with staintools.

    Parameters
    ----------
    image : array
        Image to normalize (presumably an RGB uint8 array, as staintools
        expects -- TODO confirm against the callers).
    ref : array
        Reference image the normalizer is fitted on.
    method : str
        One of ``'reinhard'``, ``'vahadane'`` or ``'macenko'``.
    standartize_brightness : bool
        When true, both ``image`` and ``ref`` go through
        ``staintools.LuminosityStandardizer.standardize`` first.

    Returns
    -------
    The normalized image returned by the normalizer's ``transform``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names (previously this
        surfaced as an unbound-variable error further down).
    """
    # Factories are lazy so an unsupported method is rejected before any
    # staintools work is done.
    normalizer_factories = {
        'reinhard': lambda: staintools.ReinhardColorNormalizer(),
        'vahadane': lambda: staintools.StainNormalizer(method='vahadane'),
        'macenko': lambda: staintools.StainNormalizer(method='macenko'),
    }
    if method not in normalizer_factories:
        raise ValueError(
            f"Unknown normalization method {method!r}; "
            f"expected one of {sorted(normalizer_factories)}"
        )

    if standartize_brightness:
        image = staintools.LuminosityStandardizer.standardize(image)
        ref = staintools.LuminosityStandardizer.standardize(ref)

    normalizer = normalizer_factories[method]()
    normalizer.fit(ref)
    norm_img = normalizer.transform(image)

    return norm_img

def save_superpixel_map(
        maps: np.ndarray,
        path: Path
):
    """
    Write one or several maps to an HDF5 file, gzip-compressed (level 9).

    A value that is not a tuple is stored as the single dataset
    ``default_key_0``; a tuple is stored element by element as
    ``default_key_0``, ``default_key_1``, ... The file at ``path`` is opened
    in ``"w"`` mode, so any existing file is replaced.

    NOTE: an identical definition of this function appears again further
    down in this file; the later one is the one bound at runtime.
    """
    key_prefix = "default_key"

    with h5py.File(path, "w") as h5_file:
        datasets = maps if isinstance(maps, tuple) else (maps,)

        index = 0
        for dataset in datasets:
            h5_file.create_dataset(
                f"{key_prefix}_{index}",
                data=dataset,
                compression="gzip",
                compression_opts=9,
            )
            index += 1


def count_patches(patches_paths: list):
    """
    Count background and object patches from their file names.

    A file whose name contains ``'bg_'`` counts as negative (background); one
    whose name contains ``'patch_'`` counts as positive. Only the last path
    component is inspected, so directory names containing either marker
    cannot misclassify a file.

    Counting stops at the first file matching neither marker: its name is
    printed as corrupted and the counts gathered so far are returned.

    Parameters
    ----------
    patches_paths : list of str
        ``'/'``-separated paths of the patch files.

    Returns
    -------
    tuple
        ``(nb_total, nb_neg, nb_pos)`` with ``nb_total = nb_neg + nb_pos``.
    """
    nb_pos = 0
    nb_neg = 0
    for patch in patches_paths:
        file_name = patch.split('/')[-1]
        if 'bg_' in file_name:
            nb_neg += 1
        elif 'patch_' in file_name:
            nb_pos += 1
        else:
            print(f"{file_name.split('.')[0]} is corrupted")
            break
    nb_total = nb_neg + nb_pos
    return nb_total, nb_neg, nb_pos

def save_superpixel_map(
        maps: np.ndarray,
        path: Path
):
    """
    Store superpixel map(s) in an HDF5 file using gzip compression level 9.

    ``maps`` may be a single object or a tuple of objects. Each element is
    written to its own dataset named ``default_key_<position>``; a non-tuple
    input is treated as a one-element tuple. ``path`` is opened in ``"w"``
    mode and therefore overwritten.

    NOTE: this repeats an identical definition found earlier in this file;
    being the later one, it is the definition in effect at runtime.
    """
    output_key = "default_key"

    with h5py.File(path, "w") as handle:
        if isinstance(maps, tuple):
            items = maps
        else:
            items = (maps,)

        for position in range(len(items)):
            handle.create_dataset(
                f"{output_key}_{position}",
                data=items[position],
                compression="gzip",
                compression_opts=9,
            )

def get_properties(array):
    """
    Return the labels of ``array`` and per-region shape statistics.

    Region properties are measured on ``array + 1``, so label 0 takes part in
    the measurement as label 1 (skimage's ``regionprops`` skips label 0).

    Returns
    -------
    tuple
        ``(values, areas, solidity, eccentricity)`` where ``values`` is
        ``np.unique(array)`` and the three others are arrays holding one
        entry per measured region.
    """
    labels_present = np.unique(array)
    regions = measure.regionprops(array + 1)

    def _gather(attribute):
        # One value per region, in the order regionprops returned them.
        return np.array([getattr(region, attribute) for region in regions])

    region_areas = _gather('area')
    region_solidity = _gather('solidity')
    region_eccentricity = _gather('eccentricity')
    return labels_present, region_areas, region_solidity, region_eccentricity

def swapping(arr):
    """
    Invert a 0/1 segmentation when the 1s outnumber the 0s.

    The zeros and ones are counted explicitly, so arrays holding a single
    value (all 0s, all 1s) or no element at all are handled instead of
    raising ``IndexError``. Values other than 0 and 1 are left untouched.

    Parameters
    ----------
    arr : integer or boolean numpy array
        Segmentation map (``arr ^ 1`` must be defined for its dtype).

    Returns
    -------
    numpy array
        ``arr`` itself when no swap is needed, otherwise a new array with 0s
        and 1s exchanged.
    """
    zeros = np.count_nonzero(arr == 0)
    ones = np.count_nonzero(arr == 1)
    if zeros < ones:
        print('Segmentation : swapping 0s and 1s')
        arr = np.where((arr == 0) | (arr == 1), arr ^ 1, arr)
    return arr

# Patch extraction

## Reading low level and the region using OpenSlide

In [None]:
SLIDE_ID = 0 # select the slide (index in the listing printed below)
PATCH_SIZE = 512 # select patch size (pixels)
TASK = 'tangles' # select 'plaques', 'tangles' or 'tangles2'
SLIDE_DIM_LVL = 0 # select the magnification level

main_path = '/localdrive10TB/users/pablo.mas/datasets/'

if TASK == 'plaques':
    slides_path = os.path.join(main_path, f'{TASK}/stratifiad/AT8_wsi/')
    labels_path = os.path.join(main_path, f'{TASK}/stratifiad/AT8_XML_annotations/')
elif TASK == 'tangles':
    slides_path = os.path.join(main_path, f'{TASK}/stratifiad/set1_slides/')
    labels_path = os.path.join(main_path, f'{TASK}/stratifiad/set1_annotations/')
elif TASK == 'tangles2':
    slides_path = os.path.join(main_path, f'{TASK}/stratifiad/Set 2 Tangles virtual slides/')
    labels_path = os.path.join(main_path, f'{TASK}/stratifiad/Set2 Tangles XML annotations/')
else:
    # Fail here: continuing would only crash later on the undefined paths.
    raise ValueError(
        f"ERROR - Please use a valid task ('plaques', 'tangles' or 'tangles2'), got {TASK!r}"
    )

slides = sorted(glob.glob(os.path.join(slides_path,'*.ndpi')))
annotations = sorted(glob.glob(os.path.join(labels_path,'*.xml')))

# List the available slides so SLIDE_ID can be chosen.
for i, slide_path in enumerate(slides):
    name = slide_path.split('/')[-1].split('.ndpi')[0]
    print(f'{i} : {name}')

name = slides[SLIDE_ID].split('/')[-1].split('.ndpi')[0]

remove_patches = False # for now it is used for WSI 11
slide_name = slides[SLIDE_ID]
label_name = annotations[SLIDE_ID]

# Slides and annotations are paired by position in the two sorted listings.
assert slide_name.split('/')[-1].split('.')[0] == label_name.split('/')[-1].split('.')[0], 'Error, slide and label do not match'

# Opening the slide image through the low-level API to read its properties
slide = lowlevel.open(slide_name)
try:
    keys = lowlevel.get_property_names(slide)
    val = lowlevel.get_property_value(slide,keys[-1])

    # These are important values for nm -> pixel conversion
    offsetX = int(lowlevel.get_property_value(slide, 'hamamatsu.XOffsetFromSlideCentre')) # THIS IS IN NANOMETERS!
    offsetY = int(lowlevel.get_property_value(slide, 'hamamatsu.YOffsetFromSlideCentre')) # THIS IS IN NANOMETERS!

    resX = float(lowlevel.get_property_value(slide, 'openslide.mpp-x')) # THIS IS IN MICRONS/PIXEL FOR LVL 0!
    resY = float(lowlevel.get_property_value(slide, 'openslide.mpp-y')) # THIS IS IN MICRONS/PIXEL FOR LVL 0!
finally:
    # Release the low-level handle before `slide` is rebound below.
    lowlevel.close(slide)

slide = OpenSlide(slide_name)

# Getting slide level dimensions
slide_levels = slide.level_dimensions

# Printing important information about the current slide
print("\n [INFO] The slide have ", len(slide_levels), " magnification levels:")
for i in range(len(slide_levels)):
    # NOTE(review): the displayed magnification assumes a x40 scan at level 0.
    print(f"   Level {i} (mag x{40/2**i}) with dimensions (in pixels) : {slide_levels[i]}.")




## Extracting coordinates from XML and conversion to pixels

In [None]:
''' 
From nano/micro meters to pixel language :) 
Larger value is X axis (-->)
'''
# Level-0 slide dimensions (pixels) multiplied by the per-axis resolution.
dimsSlide = np.array(slide_levels[0])*[resX,resY] # this is in micrometers :)
# Pixel coordinates of the slide centre at level 0.
centerPx_X, centerPx_Y = np.array(slide_levels[0])/2
# Ratio between level-0 and selected-level dimensions; only the second
# (Y) ratio is kept.
# NOTE(review): assumes both axes share the same downsampling ratio -- confirm.
_, factor = np.array(slide_levels[0])/np.array(slide_levels[SLIDE_DIM_LVL])

# Level-0 dimensions rescaled by `factor`.
sizeX, sizeY = np.array(slide_levels[0])/factor

# Loading the slide annotations from XML file
tree = ET.parse(label_name)
root = tree.getroot()

# Preparing the annotation container: a list of [type, Nx2 array of (x, y)]
labels = []

# Getting the annotation coordinates from the XML file.
# The tree is walked three levels deep; each element on the second level
# carries a 'Type' attribute and contains groups of vertices with 'X'/'Y'.
for boxes in root:
    for obejcts in boxes:
        type_object = int(obejcts.attrib['Type'])
        for vertices in obejcts:
            temp_obj = []
            for vertex in vertices:
                y_mm = float(vertex.attrib['Y']) # this is in millimeters!
                x_mm = float(vertex.attrib['X']) # this is in millimeters!
                # mm * 1000 -> micrometers; the slide offset is in nanometers,
                # so / 1000 brings it to micrometers as well.
                y_p_offset = (y_mm)*1000 - (abs(offsetY)/1000) # this is in micrometers!
                x_p_offset = (x_mm)*1000 - (abs(offsetX)/1000) # this is in micrometers!
                # Shift by the slide centre expressed in micrometers.
                y_newCenter = y_p_offset + int(centerPx_Y)*resY # this is in micrometers!
                x_newCenter = x_p_offset + int(centerPx_X)*resX # this is in micrometers!
                # micrometers -> level-0 pixels -> pixels at the selected level
                y = (y_newCenter/resY)/factor # pixels
                x = (x_newCenter/resX)/factor # pixels

                ''' Flip '''
                # Mirror along Y only (the X flip below is intentionally disabled).
                y = abs(sizeY - y)
                #x = sizeX - x

                temp_obj.append([round(x), round(y)])
            labels.append([type_object, np.array(temp_obj)])

In [None]:
# thm = slide.read_region((0, 0), SLIDE_DIM_LVL, slide_levels[SLIDE_DIM_LVL])

# print("\n[INFO] The whole-slide image with the gray matter annotation ( dim level", SLIDE_DIM_LVL, ")")
# plt.figure(figsize = (10,10))
# plt.imshow(thm)
# plt.plot(labels[0][1][:, 0], labels[0][1][:, 1])
# for obj, coordinates in labels:
#     plt.plot(coordinates[:, 0], coordinates[:, 1])
# plt.show()

## Generate patches & masks: ROI-guided.

In [None]:
# Creating mask for the WSI
# Two blank single-channel canvases the size of the slide at the selected
# level: one for the regions of interest (type 1), one for the annotated
# objects (type 2).
mask_ROI = Image.new('L', (int(sizeX), int(sizeY)), 0)
mask_obj = Image.new('L', (int(sizeX), int(sizeY)), 0)
roi_painter = ImageDraw.Draw(mask_ROI)
obj_painter = ImageDraw.Draw(mask_obj)

coords_list = []        # bounding boxes of the type-2 polygons
coords_region_lst = []  # bounding boxes of the type-1 polygons
for obj, coordinates in labels:
    if obj == 1: # these are the annotations by Lev.
        outline_points = [tuple(point) for point in coordinates.tolist()]
        roi_painter.polygon(outline_points, outline=1, fill=1)
        coords_region_lst.append(list(Polygon(coordinates).bounds))

    if obj == 2:
        outline_points = [tuple(point) for point in coordinates.tolist()]
        obj_painter.polygon(outline_points, outline=1, fill=1)
        coords_list.append(list(Polygon(coordinates).bounds))

# Bounding boxes of the regions and their (width, height) in pixels.
coords_region = list(coords_region_lst)
size = [(int(box[2]-box[0])+1, int(box[3]-box[1])+1) for box in coords_region_lst]
print(coords_region)
print(size)


def _to_display_range(value):
    # Stretch the 0/1 masks to 0/255.
    return value * 255


mask_ROI_WSI = mask_ROI.point(_to_display_range)
mask_obj_WSI = mask_obj.point(_to_display_range)

'''
# Save and plot just to check if it is working properly
mask_ROI_WSI.point(lambda i: i * 255).save('./name.png')
mask_obj_WSI.point(lambda i: i * 255).save('./name2.png')
'''
# plt.figure()
# plt.imshow(mask_ROI_WSI, cmap = "gray")
# #plt.xlim((coords_region[0], coords_region[2]))
# #plt.ylim((coords_region[3], coords_region[1]))
# plt.show()

# plt.figure()
# plt.imshow(mask_obj_WSI, cmap = "gray")
# #plt.xlim((coords_region[0], coords_region[2]))
# #plt.ylim((coords_region[3], coords_region[1]))
# plt.show()

In [None]:
'''
This code creates the region and its corresponding patches for the 4 corners of data augmentation.
'''

patchSize = [PATCH_SIZE, PATCH_SIZE] # [cols, rows]
save_path = os.path.join(main_path, f"{TASK}/{PATCH_SIZE}", slide_name.split('/')[-1].split('.ndpi')[0])
os.makedirs(save_path, exist_ok=True)
os.makedirs(os.path.join(save_path, "patches"), exist_ok=True)
os.makedirs(os.path.join(save_path, "masks"), exist_ok=True)

# mask_obj_WSI: has all the masks for the annotated objects and is the same size as the WSI.
# slide: is the original WSI. It is not loaded into RAM yet.
# labels: annotations from XML.
# coords_list: list of bounding boxes of annotated objects.


def _save_patch_and_mask(origin_x, origin_y, file_stem):
    """Save one slide patch and the matching annotation-mask crop.

    ``origin_x`` / ``origin_y`` give the top-left corner of the patch in the
    pixel frame of ``mask_obj_WSI``. The slide is read at that origin scaled
    by ``factor``; the mask is cropped at the truncated origin and always
    spans exactly ``patchSize`` pixels. Both images are written as
    ``<file_stem>.png`` under ``patches/`` and ``masks/``.
    """
    patch_width, patch_height = patchSize
    region = slide.read_region(
        (int(origin_x*factor), int(origin_y*factor)),
        SLIDE_DIM_LVL,
        (patch_width, patch_height),
    )
    left, top = int(origin_x), int(origin_y)
    mask_crop = mask_obj_WSI.crop((left, top, left + patch_width, top + patch_height))

    region.save(os.path.join(save_path, f'patches/{file_stem}.png'))
    mask_crop.save(os.path.join(save_path, f'masks/{file_stem}.png'))


for k, coords in enumerate(tqdm(coords_list)):
    # Snap the bounding box outwards to whole pixels (written back in place).
    x_min = int(math.floor(coords[0]))
    y_min = int(math.floor(coords[1]))
    x_max = int(math.ceil(coords[2]))
    y_max = int(math.ceil(coords[3]))
    coords[:] = [x_min, y_min, x_max, y_max]

    size_obj = [x_max - x_min + 1, y_max - y_min + 1]
    # Space left around the object once it is placed inside a patch.
    new_region_size = np.array(patchSize) - np.array(size_obj)

    stem = f'wsi{SLIDE_ID}_patch_{k:04}'

    # Patch with the object centred.
    _save_patch_and_mask(x_min - new_region_size[0]/2, y_min - new_region_size[1]/2, stem)

    # Create 4 additional samples per object/patch: the annotated object is
    # pushed into each corner of the patch in turn.
    corner_origins = {
        'c1': (x_min, y_min),
        'c2': (x_max - patchSize[0], y_min),
        'c3': (x_max - patchSize[0], y_max - patchSize[1]),
        'c4': (x_min, y_max - patchSize[1]),
    }
    for corner_tag, (corner_x, corner_y) in corner_origins.items():
        _save_patch_and_mask(corner_x, corner_y, f'{stem}_{corner_tag}')

## Background extraction

In [None]:
# bg_mask = np.zeros((PATCH_SIZE, PATCH_SIZE))
# background_patch = []
# for coords_bbox, size_bbox in zip(coords_region, size):
#     mask_ROI_WSI = mask_ROI.crop((coords_bbox[0],coords_bbox[1],coords_bbox[0]+size_bbox[0],coords_bbox[1]+size_bbox[1]))
#     mask_obj_WSI = mask_obj.crop((coords_bbox[0],coords_bbox[1],coords_bbox[0]+size_bbox[0],coords_bbox[1]+size_bbox[1]))
#     region = slide.read_region((int(coords_bbox[0]*factor), int(coords_bbox[1]*factor)), SLIDE_DIM_LVL, (size_bbox[0], size_bbox[1]))

#     mask_ROI_WSI = (np.array(mask_ROI_WSI)>0).astype("uint8")
#     mask_obj_WSI = (np.array(mask_obj_WSI)>0).astype("uint8")

#     patch_ROI = create_image_patches(mask_ROI_WSI, PATCH_SIZE, PATCH_SIZE)
#     patch_obj = create_image_patches(mask_obj_WSI, PATCH_SIZE, PATCH_SIZE)
#     patch_region = create_image_patches(np.array(region), PATCH_SIZE, PATCH_SIZE)

#     save_obj = []
#     for i in tqdm(range(len(patch_ROI))):
#         pixelSumROI = np.sum(patch_ROI[i])
#         if pixelSumROI == PATCH_SIZE*PATCH_SIZE:
#             pixelSumObj = np.sum(patch_obj[i])
#             if pixelSumObj == 0:
#                 if patch_region[i][:,:,0:3].std() > 30.0: # get only background with few pixels of glass.
#                     background_patch.append(patch_region[i])
#                     Image.fromarray(patch_region[i]).save(os.path.join(save_path,f'patches/wsi{SLIDE_ID}_bg_{i:04}.png'))
#                     Image.fromarray(bg_mask).convert('L').save(os.path.join(save_path,f'masks/wsi{SLIDE_ID}_bg_{i:04}.png'))

        

In [None]:

# Slide identifiers: file name of every slide without its '.ndpi' suffix.
wsi_names = []
for slide_file in slides:
    wsi_names.append(slide_file.split('/')[-1].split('.ndpi')[0])

total_pos = total_neg = total_patches = 0

# Report the patches of every slide up to and including the selected one.
for name in wsi_names[:SLIDE_ID+1]:
    patch_pattern = os.path.join(main_path, f"{TASK}/{PATCH_SIZE}/{name}/patches/*")
    patches_paths = glob.glob(patch_pattern)
    patches_paths.sort()

    nb_total, nb_neg, nb_pos = count_patches(patches_paths)

    print(name)
    print(f'Number of positive patches (plaques) : {nb_pos}')
    print(f'Number of negative patches (backgrounds) : {nb_neg}')
    print('\n')

    total_pos = total_pos + nb_pos
    total_neg = total_neg + nb_neg
    total_patches = total_patches + nb_total

print(f'Total number of positive patches (plaques) : {total_pos}')
print(f'Total number of negative patches (backgrounds) : {total_neg}')
print(f'Total number of images : {total_patches}')

# Normalization

In [None]:
NORMALIZATION_TYPE = 'macenko'  # stain-normalization method passed to normalize_image
reference_path = os.path.join(main_path, "references/norm_reference_512x512.png")

# One folder per slide under <TASK>/<PATCH_SIZE>/, in sorted order.
wsi_paths = glob.glob(os.path.join(main_path, f'{TASK}/{PATCH_SIZE}/*'))
wsi_paths.sort()

# NOTE(review): the folder is picked by position, which matches SLIDE_ID only
# if the folders of all lower-indexed slides exist -- confirm.
wsi_path = wsi_paths[SLIDE_ID]
wsi_name = wsi_path.rpartition('/')[2]

In [None]:

print(f'Preprocessing record {wsi_name}')

print('    Creating folders')
for folder in (f'images_{NORMALIZATION_TYPE}', 'annotation_masks', 'tissue_masks', 'pickles'):
    os.makedirs(os.path.join(wsi_path, folder), exist_ok=True)

raw_patches = sorted(glob.glob(os.path.join(wsi_path, 'patches/*.png')))
raw_masks = sorted(glob.glob(os.path.join(wsi_path, 'masks/*.png')))

print('    Normalizing patches')
target = np.array(Image.open(reference_path))
for patch in tqdm(raw_patches):
    # Keep the first three channels only (drops any alpha channel).
    raw_array = np.array(Image.open(patch))[:, :, :3]
    norm_array = normalize_image(raw_array, target, NORMALIZATION_TYPE, True)
    norm_image = Image.fromarray(norm_array)
    norm_image.save(os.path.join(wsi_path, f'images_{NORMALIZATION_TYPE}', patch.split('/')[-1]))

print('    Generating annotation masks')
images_macenko = sorted(glob.glob(os.path.join(wsi_path, f'images_{NORMALIZATION_TYPE}', '*.png')))

for image in tqdm(images_macenko):
    # Only the file name of the normalized image is needed here; the mask
    # with the same name is rescaled by 1/255 and stored as an 'L' image.
    mask = np.array(Image.open(os.path.join(wsi_path, 'masks', image.split('/')[-1])))
    new_mask = mask/255
    img = Image.fromarray(new_mask).convert('L')
    img.save(os.path.join(wsi_path, 'annotation_masks', image.split('/')[-1]))

print('    Generating tissue masks')
tissue_detector = GaussianTissueMask(downsampling_factor=2, sigma=5)
for image in tqdm(images_macenko):
    patch = np.array(Image.open(image))[:, :, :3]
    tissue_mask = tissue_detector.process(patch)
    img = Image.fromarray(tissue_mask)
    img.save(os.path.join(wsi_path, 'tissue_masks', image.split('/')[-1]))
    

In [None]:
from sklearn.model_selection import train_test_split, KFold

images = glob.glob(os.path.join(wsi_path, f'images_{NORMALIZATION_TYPE}', '*.png'))
images.sort()
pickles_dir = os.path.join(wsi_path, 'pickles')

# --- Image-level labels, derived from the file names -----------------------
df_image_level_annotations = pd.DataFrame()
df_image_level_annotations['name'] = [x.split('/')[-1].split('.')[0] for x in images]
image_names = df_image_level_annotations['name']
df_image_level_annotations['benign'] = np.where(image_names.str.contains('bg_'), 1, 0)
df_image_level_annotations['plaque'] = np.where(image_names.str.contains('patch_'), 1, 0)
df_image_level_annotations = df_image_level_annotations.set_index('name')
df_image_level_annotations.to_pickle(os.path.join(pickles_dir, 'image_level_annotations.pickle'))

# --- Image paths ------------------------------------------------------------
df_images = pd.DataFrame()
df_images['name'] = [x.split('/')[-1].split('.')[0] for x in images]
df_images['image_path'] = images
df_images = df_images.set_index('name')
df_images.to_pickle(os.path.join(pickles_dir, 'images.pickle'))

# --- Annotation-mask paths: same names, folder swapped in the path ---------
df_annotation_mask = df_images['image_path'].copy().rename('annotation_mask_path')
df_annotation_mask = df_annotation_mask.str.replace(f'images_{NORMALIZATION_TYPE}', 'annotation_masks')
df_annotation_mask.to_pickle(os.path.join(pickles_dir, 'annotation_masks.pickle'))

# --- Partition folders ------------------------------------------------------
partition_dir = os.path.join(wsi_path, 'partition')
for folder in (partition_dir,
               os.path.join(partition_dir, 'Test'),
               os.path.join(partition_dir, 'Train')):
    if not os.path.isdir(folder):
        os.mkdir(folder)

df_partition = pd.DataFrame()
df_partition['image_id'] = df_image_level_annotations.index

# Hold out 10% of the images; both CSVs of this split live under 'Test'.
X_train, X_test = train_test_split(df_partition, test_size=0.1, random_state=42)
X_train = X_train.reset_index(drop=True)
X_test = X_test.reset_index(drop=True)

X_train.to_csv(os.path.join(partition_dir, 'Test', 'Train.csv'), index=False)
X_test.to_csv(os.path.join(partition_dir, 'Test', 'Test.csv'), index=False)

# 4-fold split of the training part, one 'Val<k>' folder per fold.
kf = KFold(n_splits=4)

for fold_number, (train_index, val_index) in enumerate(kf.split(X_train), start=1):
    fold_dir = os.path.join(partition_dir, 'Train', f'Val{fold_number}')
    if not os.path.isdir(fold_dir):
        os.mkdir(fold_dir)

    X_t = X_train.iloc[train_index].reset_index(drop=True)
    X_v = X_train.iloc[val_index].reset_index(drop=True)

    X_t.to_csv(os.path.join(fold_dir, 'Train.csv'), index=False)
    X_v.to_csv(os.path.join(fold_dir, 'Val.csv'), index=False)

# Graph preprocessing

## Superpixel extraction

In [None]:
# Normalized images of the selected slide, in sorted order.
image_paths = glob.glob(os.path.join(wsi_path, f"images_{NORMALIZATION_TYPE}/*"))
image_paths.sort()

preprocess_root = os.path.join(wsi_path, 'preprocess')
superpixel_path = os.path.join(preprocess_root, 'superpixels')
viz_path = os.path.join(preprocess_root, 'superpixel_viz')

# Create the output folders (parent first) when they are missing.
for folder in (preprocess_root, viz_path, superpixel_path):
    if not os.path.isdir(folder):
        os.mkdir(folder)

In [None]:
for image_path in tqdm(image_paths):
    image_name = image_path.split('/')[-1].split('.')[0]
    image = np.array(Image.open(image_path))
    # The mask is only used by the (disabled) visualisation code below.
    mask = np.array(Image.open(image_path.replace(f"images_{NORMALIZATION_TYPE}",  "masks")))

    # fig, ax = plt.subplots(5, 2, figsize=(10, 20))
    # ax = np.ravel(ax)

    # # Image and ground truth
    # ax[0].imshow(image)
    # ax[0].set_title(f'Normalized image ({NORMALIZATION_TYPE})')
    # ax[1].imshow(mask)
    # ax[1].set_title('Ground truth annotation (Lev)')

    #  Segmentation (first channel only)
    # s0 = morphological_chan_vese(image[:, :, 0], iterations=10, smoothing=1)
    # s1 = morphological_chan_vese(image[:, :, 1], iterations=10, smoothing=1)
    # s2 = morphological_chan_vese(image[:, :, 2], iterations=10, smoothing=1)
    # s = s0+s1+s2
    # arr = np.where(s==0, 1, 0)
    arr = morphological_chan_vese(image[:, :, 0], iterations=10, smoothing=1)
    # ax[2].imshow(arr)
    # ax[2].set_title('Segmentation')

    # Make 0 the majority value. Zeros and ones are counted explicitly so a
    # uniform segmentation (a single value) no longer raises IndexError.
    zeros = np.count_nonzero(arr == 0)
    ones = np.count_nonzero(arr == 1)
    if zeros < ones:
        # print('Segmentation : swapping 0s and 1s')
        arr = np.where((arr==0)|(arr==1), arr^1, arr)

    # Padding (removed again before saving)
    arr = np.pad(arr, 10, 'constant', constant_values=0)
    # ax[3].imshow(arr)
    # ax[3].set_title('Padding')

    # Binary closing
    arr = ndimage.binary_closing(arr, iterations=2)
    # ax[4].imshow(arr)
    # ax[4].set_title('Binary closing')

    # Area cleaning 1
    arr = morphology.remove_small_objects(arr, min_size=2000)
    # ax[5].imshow(arr)
    # ax[5].set_title('Area cleaning (small objects)')

    # Area cleaning 2
    arr = morphology.remove_small_holes(arr, area_threshold=600)
    # ax[6].imshow(arr)
    # ax[6].set_title('Area cleaning (small holes)')

    # Labelling
    superpixel = label(arr)[0]
    # ax[7].imshow(superpixel)
    # ax[7].set_title('Superpixel')

    # Make sure the largest region carries label 0: exchange label 0 and the
    # largest label, using 10000 as a temporary placeholder.
    values, areas, solidity, eccentricity = get_properties(superpixel)
    if np.argmax(areas) != 0:
        print("Labelling : reindexing background")
        superpixel = np.where(superpixel==0, 10000, superpixel)
        superpixel = np.where(superpixel==np.argmax(areas), 0, superpixel)
        superpixel = np.where(superpixel==10000, np.argmax(areas), superpixel)


    # Solidity cleaning: regions with solidity below 0.33 are merged into the
    # largest region.
    values, areas, solidity, eccentricity = get_properties(superpixel)

    for elem in values[solidity < 0.33]:
        superpixel = np.where(superpixel==elem, np.argmax(areas), superpixel)

    # Renumber the remaining labels consecutively, starting at 0.
    for index, elem in enumerate(np.unique(superpixel)):
        superpixel = np.where(superpixel==elem, index, superpixel)

    # Strip the padding and shift the labels so they start at 1.
    superpixel = superpixel[10:-10, 10:-10] + 1
    save_superpixel_map(superpixel, os.path.join(superpixel_path, image_name + '.h5'))

    # ax[8].imshow(superpixel)
    # ax[8].set_title('Solidity cleaning')

    # ax[9].imshow(mark_boundaries(image, superpixel, color=(0, 255, 0)))
    # plt.tight_layout()
    # plt.savefig(os.path.join(viz_path, f"{image_path.split('/')[-1].split('.')[0]}_1.png"), dpi=300)
    # plt.close()

    # fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    # ax.imshow(mark_boundaries(mark_boundaries(image, mask, color=(0, 0, 255)), superpixel, color=(0, 255, 0)))

    # green_patch = mpatches.Patch(color='green', label='MorphChanVese Superpixel')
    # blue_patch = mpatches.Patch(color='blue', label='Lev Annotation')
    # plt.legend(bbox_to_anchor=(0.70, 1.1), handles=[green_patch, blue_patch])
    # plt.savefig(os.path.join(viz_path, f"{image_path.split('/')[-1].split('.')[0]}_2.png"), dpi=300)
    # plt.close()


In [None]:
# Settings of the deep feature extractor applied to the superpixels.
feature_extractor_options = dict(
    architecture='mobilenet_v2',
    num_workers=8,
    # rotations=[0, 90, 180, 270],
    # flips=['n', 'h'],
    patch_size=9,
    resize_size=224,
    batch_size=256,
)
feature_extractor = AugmentedDeepFeatureExtractor(**feature_extractor_options)

# Settings of the region-adjacency-graph builder.
graph_builder_options = dict(
    nr_annotation_classes=3,
    annotation_background_class=2,
    add_loc_feats=False,
)
graph_builder = RAGGraphBuilder(**graph_builder_options)

In [None]:
spx_paths = sorted(glob.glob(os.path.join(superpixel_path, '*.h5')))
image_paths = sorted(image_paths)
mask_paths = sorted(glob.glob(os.path.join(wsi_path, 'annotation_masks/*')))
# The three listings are paired by position, so they must have equal length.
assert len(spx_paths) == len(image_paths) == len(mask_paths)

# Create preprocessing folders
graphs_dir = os.path.join(wsi_path, 'preprocess/graphs')
graphs_viz_dir = os.path.join(wsi_path, 'preprocess/graphs_viz')
os.makedirs(graphs_dir, exist_ok=True)
os.makedirs(graphs_viz_dir, exist_ok=True)

from tqdm import trange

# Slice of the sorted file list to process. The stop index is clamped to the
# number of available files so a shorter listing no longer raises IndexError.
GRAPH_START_INDEX = 470
GRAPH_STOP_INDEX = 795
graph_stop_index = min(GRAPH_STOP_INDEX, len(spx_paths))

visualizer = OverlayGraphVisualization()

for i in tqdm(range(GRAPH_START_INDEX, graph_stop_index)):
    image_name = spx_paths[i].split('/')[-1].split('.')[0]
    spx_path = spx_paths[i]
    image = np.array(Image.open(image_paths[i]))
    mask = np.array(Image.open(mask_paths[i]))

    # Read the whole superpixel map, then close the file before the heavy work.
    with h5py.File(spx_path, 'r') as f:
        superpixels = f['default_key_0'][:, :]

    features = feature_extractor.process(image, superpixels)
    graph = graph_builder.process(superpixels, features, mask) # Compute graph
    save_graphs(filename=os.path.join(graphs_dir, image_name + '.bin'), g_list=[graph])

    canvas = visualizer.process(image, graph, instance_map=superpixels) # Create graph visualization
    canvas.save(os.path.join(graphs_viz_dir, image_name + '.png'))



In [None]:
SPX_SIZE = 128
BLUR_SIZE = 0
PSIZE = 3

# NOTE(review): `nb_preproc` (name of the preprocessing run folder) is not
# assigned in this cell; it must be defined before running it -- confirm
# where it is supposed to come from.
preproc_dir = os.path.join(wsi_path, f'preprocess/{nb_preproc}')
spx_dir = os.path.join(preproc_dir, 'superpixels')
graphs_dir = os.path.join(preproc_dir, 'graphs')
graphs_viz_dir = os.path.join(preproc_dir, 'graphs_viz')

# Create preprocessing folders (parents included)
for folder in (spx_dir, graphs_dir, graphs_viz_dir):
    os.makedirs(folder, exist_ok=True)

# Instantiate preprocessing operators
spx_extractor = SLICSuperpixelExtractor(
    superpixel_size=SPX_SIZE,
    blur_kernel_size=BLUR_SIZE,
    compactness=20,
    downsampling_factor=1
    )

feature_extractor = AugmentedDeepFeatureExtractor(
    architecture='mobilenet_v2',
    num_workers=8,
    rotations=[0, 90, 180, 270],
    flips=['n', 'h'],
    patch_size=PSIZE,
    resize_size=224,
    batch_size=256
    )

graph_builder = RAGGraphBuilder(
    nr_annotation_classes=3,
    annotation_background_class=2,
    add_loc_feats=False
    )

visualizer = OverlayGraphVisualization()

# Preprocessing
for image_path in tqdm(images_macenko):
    image_name = os.path.basename(image_path).split('.')[0]
    image = np.array(Image.open(image_path))
    # Follow NORMALIZATION_TYPE rather than a hard-coded 'images_macenko',
    # so the mask is found for every normalization method.
    mask_path = image_path.replace(f'images_{NORMALIZATION_TYPE}', 'annotation_masks')
    mask = np.array(Image.open(mask_path))

    superpixels = spx_extractor.process(image) # Extract superpixels
    save_superpixel_map(superpixels, os.path.join(spx_dir, image_name + '.h5'))

    features = feature_extractor.process(image, superpixels) # Extract features from superpixels

    graph = graph_builder.process(superpixels, features, mask) # Compute graph
    save_graphs(filename=os.path.join(graphs_dir, image_name + '.bin'), g_list=[graph])

    canvas = visualizer.process(image, graph, instance_map=superpixels) # Create graph visualization
    canvas.save(os.path.join(graphs_viz_dir, image_name + '.png'))

# Save the parameters into a JSON file.
# These attributes are left out of the dump; a filtered copy is used so the
# extractor object itself is not modified.
excluded_attributes = ('transforms', 'device', 'patch_feature_extractor')
feature_extractor_params = {
    key: value
    for key, value in feature_extractor.__dict__.items()
    if key not in excluded_attributes
}

params = {'spx_extractor': spx_extractor.__dict__,
          'feature_extractor': feature_extractor_params,
          'graph_builder': graph_builder.__dict__}

# The path is an f-string so the file lands in the run folder, not in a
# literal '{nb_preproc}' directory.
with open(os.path.join(wsi_path, f'preprocess/{nb_preproc}/parameters.json'), 'w') as fp:
    json.dump(params, fp)

