In [21]:
# This program processes the vertebrae ground truth masks into normalised sets of 3 input points and 1 target point which can be used to train the point predictor

import os
import numpy as np
from skimage.measure import regionprops
import matplotlib.pyplot as plt

VERTEBRA_DICT = {'C2':1,'C3':2,'C4':3,'C5':4,'C6':5,'C7':6,'T1':7,'T12':8,'L1':9,'L2':10,'L3':11,'L4':12,'L5':13,'S1':14}
data_path = os.path.join('..','data','NHANES2','Vertebrae')

mode = 'test'

if mode == 'train':
    with open(os.path.join(data_path,'..','data_split','train.txt'),'r') as f:
        id_list = [line[0:6] for line in f.readlines()]
elif mode == 'val':
    with open(os.path.join(data_path,'..','data_split','val.txt'),'r') as f:
        id_list = [line[0:6] for line in f.readlines()]
elif mode == 'test':
    with open(os.path.join(data_path,'..','data_split','test.txt'),'r') as f:
        id_list = [line[0:6] for line in f.readlines()]

gt_path = os.path.join(data_path,'gts')
gt_filenames = [filename for filename in os.listdir(gt_path) if filename.split('_')[0] in id_list]

In [22]:
centroid_dict = {}

for filename in gt_filenames:
    
    id = filename.split('_')[0]
    level = filename.split('_')[1].split('.')[0]

    if id not in centroid_dict.keys():
        centroid_dict[id] = {}

    mask = np.load(os.path.join(gt_path,filename)).astype(np.uint8)
    h,w = mask.shape[-2:]

    c_y,c_x = regionprops(mask)[0]['centroid']
    #c_x,c_y = c_x/w,c_y/h

    centroid_dict[id][VERTEBRA_DICT[level]] = (c_x,c_y)

In [3]:
input_points = []
output_points = []

for id,dict in centroid_dict.items():
    
    sorted_points = [dict[level] for level in sorted(dict.keys())]
    for i in range(len(sorted_points)-3):
        points = sorted_points[i:i+4]

        input_points.append(points[:-1])
        input_points.append(points[:0:-1])

        output_points.append(points[-1])
        output_points.append(points[0])

In [13]:
np_input_points = np.array(input_points)
np_output_points = np.array(output_points)

os.makedirs(os.path.join(data_path,'..','point_predictor_data'),exist_ok=True)

np.save(os.path.join(data_path,'..','point_predictor_data',mode+'_input.npy'),np_input_points)
np.save(os.path.join(data_path,'..','point_predictor_data',mode+'_output.npy'),np_output_points)

NameError: name 'input_points' is not defined

In [23]:
import pickle

os.makedirs(os.path.join(data_path,'points'),exist_ok=True)

pickle_dict = {}

for id,v_dict in centroid_dict.items():
    n=0
    for vid,centroid in v_dict.items():
        level = next(key for key, value in VERTEBRA_DICT.items() if value == vid)
        v_key = id+'_'+level
        pickle_dict[v_key] = (int(centroid[0]),int(centroid[1]))
        bg_key = id+'_BG'+str(n)
        n+=1
        pickle_dict[bg_key] = (0,0)
    
with open(os.path.join(data_path,'points',mode+'.pkl'),'wb') as f:
    pickle.dump(pickle_dict,f)

In [9]:
with open(os.path.join(data_path,'points',mode+'.pkl'),'rb') as f:
    loaded = pickle.load(f)

print(loaded)

{'C00166_C2': (773, 1035), 'C00166_BG0': (0, 0), 'C00166_C3': (673, 1147), 'C00166_BG1': (0, 0), 'C00166_C4': (600, 1202), 'C00166_BG2': (0, 0), 'C00166_C5': (515, 1249), 'C00166_BG3': (0, 0), 'C00166_C6': (448, 1311), 'C00166_BG4': (0, 0), 'C00166_C7': (368, 1376), 'C00166_BG5': (0, 0), 'C00174_C2': (1197, 529), 'C00174_BG0': (0, 0), 'C00174_C3': (1093, 636), 'C00174_BG1': (0, 0), 'C00174_C4': (1020, 719), 'C00174_BG2': (0, 0), 'C00174_C5': (948, 810), 'C00174_BG3': (0, 0), 'C00174_C6': (879, 888), 'C00174_BG4': (0, 0), 'C00179_C2': (1017, 918), 'C00179_BG0': (0, 0), 'C00179_C3': (912, 1021), 'C00179_BG1': (0, 0), 'C00179_C4': (833, 1089), 'C00179_BG2': (0, 0), 'C00179_C5': (743, 1142), 'C00179_BG3': (0, 0), 'C00179_C6': (668, 1197), 'C00179_BG4': (0, 0), 'C00197_C2': (996, 943), 'C00197_BG0': (0, 0), 'C00197_C3': (895, 1062), 'C00197_BG1': (0, 0), 'C00197_C4': (828, 1143), 'C00197_BG2': (0, 0), 'C00197_C5': (758, 1222), 'C00197_BG3': (0, 0), 'C00197_C6': (704, 1305), 'C00197_BG4': (0

In [10]:
print(os.getcwd())

/gpfs3/well/papiez/users/saa032/projects/Medical-SAM-Adapter
