In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import  glob
from scipy.spatial import ConvexHull
from Constants import Const
import joblib
from pointcloud_utils import *
from scipy.spatial.distance import cdist, pdist, squareform
import matplotlib as mpl
import torchio as tio

import plotly.graph_objects as go
from pointcloud_utils import *
%matplotlib notebook

from ReaderWriter import BetterDicomReader, dicom_reader_from_ids
# from multiprocessing import cpu_count
# import typing
import os
# from queue import *
# from tqdm import tqdm
# from threading import Thread
import pickle
import simplejson
import open3d as o3d
import pydicom

In [4]:
def get_all_dicom_ids(root = '../data/DICOMs/R01/'):
    files = glob.glob(root+'*/')
    ids = []
    for f  in files:
        pid = f.replace(root,'').replace('/','')
        if pid.isnumeric():
            ids.append(int(pid))
        else:
            print('bad pid',pid)
    return ids
get_all_dicom_ids()

[1054079696,
 1072572079,
 1079757401,
 1087308891,
 1099927508,
 1103015749,
 1107681010,
 1108642427,
 1138096463,
 1143391354,
 1164546699,
 1167168001,
 1173728658,
 1178044145,
 1194116893,
 1235569621,
 1265845118,
 1269213210,
 1276736352,
 2774318802,
 2776638846,
 2803120404,
 2804890849,
 2815583275,
 2817169157,
 2843895295,
 2865890865,
 2867468906,
 2875588687,
 2883823678,
 2889102751,
 2894996073,
 2908060983,
 2929571068,
 2932807221,
 2939740989,
 2983776095,
 2989874876,
 3028333367,
 3035721150,
 3045110595,
 3045556976,
 3045918834,
 3049758970,
 3071491330,
 3077525807,
 3081060090,
 3099145083,
 3100266114,
 3108161676,
 3125978990,
 3151449158,
 3192749693,
 3192926363,
 3194171650,
 3196236080,
 3205005928,
 3225956079,
 3235750820,
 3239454379,
 3256350141,
 3283538700,
 3316448321,
 3321571409,
 3327830751,
 3346286228,
 3355839884,
 3366888610,
 3370207907,
 3383739521,
 3442548822,
 3514335131,
 3563849182,
 3694073920,
 3816960111,
 3832820671,
 4017119917,

In [5]:
def get_associations(dr):
    all_rois = dr.return_rois(print_rois=False)
    names = ['gtv','ctv','ptv']
    assoc = {k:v for k,v in Const.organ_associations.items()}
    for roi in all_rois:
        counts = 0
        for name in names:
            
            if name in roi:
                if counts < 1:
                    if 'node' in roi or 'nodal' in roi or 'gtvn' in roi or 'ln' in roi:
                        assoc[roi] = 'gtvn'
                    else:
                        assoc[roi] = name
                    counts +=1
                else:
                    print(roi)
    return assoc


def get_all_patients(dr=None,dicom_dir=None, rois = None, associations = None):
    #process dicom files for all patients
    #doesn't extract info yet
    if rois is None:
        rois = Const.organ_list[:] 
    gtvs = ['gtv','gtvn','ctv','ptv']
        
    if dr is None:
        dr = BetterDicomReader(dicom_dir)
    if associations is None:
        associations = get_associations(dr)
    
    dr.set_contour_names_and_associations(Contour_Names=rois,associations=associations)
    
    all_uids = dr.get_all_uids()
    plist = []
    roi_map = {i+1: roi for i,roi in enumerate(rois)}
    for i,uid in enumerate(all_uids):
        #dicom reader reads everthing but only processes the current UID
        dr.set_by_uid(uid)
        dr.get_images_and_mask()
        patient = dr.get_current_patient()
        #image slices as a stack
        patient['ArrayDicom'] = dr.ArrayDicom[:]
        #same shape as arraydicom, replaces values with the index corresponding to each roi
        patient['mask'] = dr.mask[:].astype(np.int32)
        #save the index -> roi dictionary for later
        patient['roi_mask_map'] = roi_map
        plist.append(patient)
        print(uid,'done',i,'of',len(all_uids))
    return plist, roi_map

#temp subset of organs to test with since it takes a while to run it all

#Get a list of dictionaries with processed patient info from the dicom
# dr = dicom_reader_from_ids([2677877484,2411034155])
# plist, roi_map = get_all_patients(dr)
# plist

In [6]:
def pointcloud_distance_batched(pc1,pc2,batch_size=200,metric='euclidean'):
    #get two pointcloud arrays, checks distance between each pair of points
    #returns the smallest distance (inter-organ distance)
    if pc1.shape[0] < 1 or pc2.shape[0] < 1:
        return False
    
    pc1_batches = np.array_split(pc1,batch_size,axis=0)
    min_dist = 100000
    for batch in pc1_batches:
        batch_dists = cdist(batch,pc2,metric)
        batch_min = batch_dists.min()
        if batch_min < min_dist:
            min_dist = batch_min 

    return min_dist

def points_to_spatial(points, location, gridsize):
    scale = np.array(gridsize)
    corner = np.array(location)
    newpoints = points*scale + corner
    return  newpoints

def get_roi_pointclouds(pdict, roi_map=None, scale = True, concavify=True,ascale=.5,**kwargs):
    #get locations of contours.  I may need to tweak this part if I want to get actual dose values?
    clouds = {}
    #get the corresponding image value of each point. will add stuff for dose later idk
    pixel_values = {}
    mask = pdict['mask']
    #get array locations
    scale =np.array( [pdict['Pixel_Spacing_X'],pdict['Pixel_Spacing_Y'],pdict["Slice_Thickness"]])
    if roi_map is None:
        roi_map = pdict['roi_mask_map']
    all_points = []
    for index, roi in roi_map.items():
        #should be x, y, z if I tranpose it (default it's slicexheightxwidth)
        contour = np.argwhere(mask.T == index)
        #skip bad values with < 4 points because the algorithms break
        if contour.shape[0] > 3 and concavify:
            #this should get the points on the surface of the object (concave hull) if parameters are good
            surface_indices = concave_hull_3d(contour,return_vertices = True,alpha_scales=[ascale],**kwargs)
            contour = contour[surface_indices]
        contour = points_to_spatial(contour,[0,0,0],scale)
        clouds[roi] = contour
        
    return clouds

def pc_dist_worker_v2(args):
    #wrapper to help parallelize pointcloud distance calclations
    if args[0] == args[1]:
        return 0
    clouds = args[2]
    c1  = clouds[args[0]]
    c2 = clouds[args[1]]    
    if len(c1) < 1 or len(c2) < 1:
        return 0
    return pointcloud_distance_batched(c1,c2)

def get_interorgan_distance_v2(roi_clouds, 
                             parallel=True,
                             **kwargs):
    #calc interorgan distances on a dict of {roi: pointcloud, roi2: pointcloud2} as an array
    #also returns the order of rois in the dict [r1, r2, r3,...] =>[ [d(r1,r1),d(r1,r2),d(r1,r3)], [d(r2,r1),d(r2,r2),d(r2,r3)],...]

    valid_rois = list(roi_clouds.keys())
    dists = {}
    for roi in valid_rois:
        if parallel and parallel != 0:
            dlist = joblib.Parallel(n_jobs=parallel)(joblib.delayed(pc_dist_worker_v2)((roi,r2,roi_clouds)) for r2 in valid_rois)
        else:
            dlist = [pc_dist_worker_v2((roi,r2,roi_clouds)) for r2 in valid_rois]     
        dists[roi] = dlist
    return dists

def get_distances(pids=None,dicom_root='../data/DICOMs/R01/',target_file='../data/r01_distances',overwrite=False):
    try:
        with open(target_file+'.json','r') as f:
            distances = simplejson.load(f)
    except:
        distances = {}
    if pids is None:
        pids = get_all_dicom_ids()
    finished = set([int(p) for p in distances.keys()])
    print(len(finished))
    for pid in pids:
        #it crashes when I try to process these idk why
#         if int(pid) in [2677877484,2340030369,1138096463,2875588687,5968002865,6060411302,6321133829,7143413136,8999014515,9479197119 ,1596568859]:
#             conitinue
        if not overwrite and int(pid) in finished:
            continue
        dr = dicom_reader_from_ids([pid],root=dicom_root)
        clouds = {}
        for roilist in [Const.organ_list[:],['gtv','gtvn'],['ctv','ptv']]:
            plist, _ = get_all_patients(dr,rois= roilist)
            pentry = plist[0]
            tempclouds = get_roi_pointclouds(pentry)
            for k,v in tempclouds.items():
                clouds[k] = v
        dists = get_interorgan_distance_v2(clouds)
        distances[pid] = dists
        try:
            with open(target_file+'.json','w') as f:
                simplejson.dump(distances,f,default=np_converter)
        except Exception as e:
            print('issue saving at pid',pid)
            print(e)
    return pd.DataFrame(distances)

get_distances(overwrite=True)

SyntaxError: invalid syntax (798899511.py, line 74)

In [None]:
len(set({1167168001, 4625578503, 4074385931, 2508553230, 1269213210, 2815583275, 4862920237, 2865890865, 2932807221, 4840069687, 4713240632, 2883823678, 3694073920, 3316448321, 2279280705, 3151449158, 4481544782, 3321571409, 3442548822, 1079757401, 1087308891, 1522246751, 3205005928, 2894996073, 2867468906, 3816960111, 3045918834, 1143391354, 1108642427, 3192749693, 3383739521, 3100266114, 1164546699, 3108161676, 4862586001, 3346286228, 4663235737, 3192926363, 3366888610, 3370207907, 4443664553, 3239454379, 4017119917, 1072572079, 3962148532, 1670301878, 2106963643, 2989874876, 1054079696, 2774318802, 3327830751, 2770898143, 2804890849, 4554821349, 7291736302, 3225956079, 3045556976, 1178044145, 1107681010, 3081060090, 3049758970, 2929571068, 3071491330, 3194171650, 2817169157, 5802597125, 5778599178, 3283538700, 2803120404, 1194116893, 4100593439, 5618650917, 3077525807, 3196236080, 2908060983, 3028333367, 5089502522, 2939740989, 3217148733, 3045110595, 1103015749, 4509480776, 5110518099, 5038138708, 2983776095, 1276736352, 5709953389, 3125978990, 3099145083, 2776638846, 1265845118, 4646292867, 5826154886, 5751912329, 3355839884, 1173728658, 5367270807, 3514335131, 5211387293, 2889102751, 3235750820, 3256350141, 3035721150, 3832820671, 3023016905, 1099927508, 1235569621, 3563849182, 4488837609, 1848656363, 1293745646, 4151569401, 2843895295}))

In [None]:
import numpy as np
np.array_split(np.zeros((10,10)),3)

In [None]:
test = (np.random.random((10,4))*10).astype(int)
test

In [None]:
np.argwhere(test == test.min())