## Generate class-slice indexing table for experiments


### Overview

This notebook sets up experiments that simulate few-shot image segmentation scenarios.

Input: pre-processed images and their ground-truth labels

Output: a `json` file for class-slice indexing, mapping each class name to, for every case id, the list of slice indices that contain that class

In [1]:
%reset
%load_ext autoreload
%autoreload 2
import numpy as np
import os
import glob
import SimpleITK as sitk
import sys
import json
sys.path.insert(0, '../../dataloaders/')
import niftiio as nio

In [2]:
# Directory holding the pre-processed volumes. Set the MEDICAL_DATA_DIR environment
# variable to point elsewhere instead of editing this absolute local path.
DATA_DIR = os.environ.get("MEDICAL_DATA_DIR", "/home/user01/medical-data/norm")

# Glob patterns for the image volumes and their ground-truth label volumes
# (files named image_<id>.nii.gz / label_<id>.nii.gz).
IMG_BNAME = os.path.join(DATA_DIR, "image_*.nii.gz")
SEG_BNAME = os.path.join(DATA_DIR, "label_*.nii.gz")

In [3]:
def case_id(path):
    """Return the integer case id parsed from a '<prefix>_<id>.nii.gz' path (e.g. 'label_13.nii.gz' -> 13)."""
    return int(path.split("_")[-1].split(".nii.gz")[0])

# Sort numerically by case id; a plain string sort would place '10' before '2'.
imgs = sorted(glob.glob(IMG_BNAME), key=case_id)
segs = sorted(glob.glob(SEG_BNAME), key=case_id)

# Fail loudly instead of silently writing an empty class map later on.
assert segs, f"no label volumes matched {SEG_BNAME}"


In [4]:
# Sanity check: image paths should appear in ascending numeric case-id order.
imgs

['/home/user01/medical-data/norm/image_1.nii.gz',
 '/home/user01/medical-data/norm/image_2.nii.gz',
 '/home/user01/medical-data/norm/image_3.nii.gz',
 '/home/user01/medical-data/norm/image_5.nii.gz',
 '/home/user01/medical-data/norm/image_8.nii.gz',
 '/home/user01/medical-data/norm/image_10.nii.gz',
 '/home/user01/medical-data/norm/image_13.nii.gz',
 '/home/user01/medical-data/norm/image_15.nii.gz',
 '/home/user01/medical-data/norm/image_19.nii.gz',
 '/home/user01/medical-data/norm/image_20.nii.gz',
 '/home/user01/medical-data/norm/image_21.nii.gz',
 '/home/user01/medical-data/norm/image_22.nii.gz',
 '/home/user01/medical-data/norm/image_31.nii.gz',
 '/home/user01/medical-data/norm/image_32.nii.gz',
 '/home/user01/medical-data/norm/image_33.nii.gz',
 '/home/user01/medical-data/norm/image_34.nii.gz',
 '/home/user01/medical-data/norm/image_36.nii.gz',
 '/home/user01/medical-data/norm/image_37.nii.gz',
 '/home/user01/medical-data/norm/image_38.nii.gz',
 '/home/user01/medical-data/norm/ima

In [5]:
# Sanity check: label paths, sorted with the same numeric case-id key as `imgs`.
segs

['/home/user01/medical-data/norm/label_1.nii.gz',
 '/home/user01/medical-data/norm/label_2.nii.gz',
 '/home/user01/medical-data/norm/label_3.nii.gz',
 '/home/user01/medical-data/norm/label_5.nii.gz',
 '/home/user01/medical-data/norm/label_8.nii.gz',
 '/home/user01/medical-data/norm/label_10.nii.gz',
 '/home/user01/medical-data/norm/label_13.nii.gz',
 '/home/user01/medical-data/norm/label_15.nii.gz',
 '/home/user01/medical-data/norm/label_19.nii.gz',
 '/home/user01/medical-data/norm/label_20.nii.gz',
 '/home/user01/medical-data/norm/label_21.nii.gz',
 '/home/user01/medical-data/norm/label_22.nii.gz',
 '/home/user01/medical-data/norm/label_31.nii.gz',
 '/home/user01/medical-data/norm/label_32.nii.gz',
 '/home/user01/medical-data/norm/label_33.nii.gz',
 '/home/user01/medical-data/norm/label_34.nii.gz',
 '/home/user01/medical-data/norm/label_36.nii.gz',
 '/home/user01/medical-data/norm/label_37.nii.gz',
 '/home/user01/medical-data/norm/label_38.nii.gz',
 '/home/user01/medical-data/norm/lab

In [7]:
# Class names indexed by label value: voxel value i in a label volume is LABEL_NAME[i].
LABEL_NAME = ["BG", "LIVER", "RK", "LK", "SPLEEN"]

# Minimum number of pixels of a class that a slice must contain to be indexed under
# that class. Use >100 when training with manual annotations for more stable training.
MIN_TP = 1

# Output file, written to the same directory as the label volumes.
fid = os.path.join(os.path.dirname(SEG_BNAME), f'classmap_{MIN_TP}.json')

# Case ids as strings (e.g. 'label_13.nii.gz' -> '13'); they become the inner JSON keys.
pids = [seg.split("_")[-1].split(".nii.gz")[0] for seg in segs]

# classmap[class_name][case_id] -> list of slice indices containing that class.
classmap = {lb: {pid: [] for pid in pids} for lb in LABEL_NAME}

for seg, pid in zip(segs, pids):
    lb_vol = nio.read_nii_bysitk(seg)
    # Slices are taken along axis 0 (presumably the z axis of the SimpleITK array -- confirm in niftiio).
    for slc in range(lb_vol.shape[0]):
        lb_slice = lb_vol[slc, ...]
        for cls, name in enumerate(LABEL_NAME):
            # Count pixels of *this* class only. Summing the label values of the whole
            # slice would also count other organs' pixels (weighted by their label id),
            # so MIN_TP would not act as a per-class threshold.
            n_pos = np.count_nonzero(lb_slice == cls)
            # `n_pos > 0` keeps absent classes out even if MIN_TP is set to 0.
            if n_pos > 0 and n_pos >= MIN_TP:
                classmap[name][pid].append(slc)
    print(f'pid {str(pid)} finished!')

# The `with` block closes the file; no explicit close() is needed.
with open(fid, 'w') as fopen:
    json.dump(classmap, fopen)
    

pid 1 finished!
pid 2 finished!
pid 3 finished!
pid 5 finished!
pid 8 finished!
pid 10 finished!
pid 13 finished!
pid 15 finished!
pid 19 finished!
pid 20 finished!
pid 21 finished!
pid 22 finished!
pid 31 finished!
pid 32 finished!
pid 33 finished!
pid 34 finished!
pid 36 finished!
pid 37 finished!
pid 38 finished!
pid 39 finished!


In [8]:
# Re-save the class-slice index to `fid` (same path and content as the write in the previous cell).
# The `with` block closes the file on exit, so no explicit close() is needed.
with open(fid, 'w') as fopen:
    json.dump(classmap, fopen)