In [207]:
#automatically reload stuff
%load_ext autoreload
%autoreload 2
import Utils
import matplotlib.pyplot as plt
from SpatialPreprocessing import *
import json
import Formatting
from Constants import Const
import copy
import Metrics
import Models
import re
import Cluster
pd.set_option('display.max_columns', None)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [214]:
from warnings import simplefilter 
simplefilter(action='ignore', category=DeprecationWarning)

In [7]:
from Levenshtein import distance as levenshtein_distance
import collections

class SpellChecker():
    
    def __init__(self, keywords, aliases, max_edit_distance = .15,normalize_score = True):
        self.keywords = keywords #list
        self.aliases = aliases #dict 
        self.normalize_score = normalize_score
        self.positional_pairs = [('Lt','Rt'),('L_','R_'),('_L','_R')]
        self.max_edit_distance = max_edit_distance
            
    def word_distance(self,word1,word2):
        dist = 0
        #add an extra penalty if one is right and one is left
        for (x,y) in self.positional_pairs:
            if x in word1 and y in word2 or y in word1 and x in word2:
                dist += len(word1)
                break
        clean = lambda w: w.strip().lower().replace("_","")
        dist += levenshtein_distance(clean(word1),clean(word2))
        if self.normalize_score:
            dist = dist/max(len(str(word1)),len(str(word2)))
        return dist
        
    def best_spell_match(self,name,words):
        #compare a word with a list of words
        #get the closest word to source word based on edit distance
        best_match = None
        best_dist = np.inf
        for word in words:
            ld = self.word_distance(name,word)
            if ld < best_dist:
                best_dist = ld
                best_match = word
                if ld <= 0:
                    break
        return best_match, best_dist
    
    def spellcheck_df(self, df, cols=None,unique = True):
        df = df.copy()
        rename_dict = {}
        if cols is None:
            cols = list(df.columns)

        all_renames = {}
        for col in cols:
            matchwords = [i for i in self.keywords]
            col_words = list(np.unique(df[col].values.astype('str')))
            rename_dict = {}
            for cword in col_words:
                match,dist = self.best_spell_match(cword,matchwords)
                if dist < self.max_edit_distance:
                    if match != cword:
                        rename_dict[cword] = match
                        all_renames[cword] = match
                else:
                    aliasmatch, alias_dist = self.best_spell_match(cword, list(self.aliases.keys()))
                    if alias_dist < self.max_edit_distance:
                        target = self.aliases[aliasmatch]
                        rename_dict[cword] = target
                        all_renames[cword] = target
            df[col] = df[col].apply(lambda x: rename_dict.get(x,x))
        return df,all_renames

In [129]:
class RadDataset():
    
    #to keep it all consistent with the other dataset
    file_header_renames = {
        'mean': 'mean_dose',
        'Volume': 'volume',
        'minGy': 'min_dose',
        'maxGy': 'max_dose',
#         'Structure': 'ROI'
    }
    
    # alliases for all the organs in the data because people just make up the abreviations
    organ_rename_dict = {
        'cricoid': 'Cricoid_cartilage',
         'cricopharyngeus': 'Cricopharyngeal_Muscle',
         'esophagus_u': 'Esophagus',
         'oral_cavity': 'Extended_Oral_Cavity',
         'musc_geniogloss': 'Genioglossus_M',
         'hardpalate': 'Hard_Palate',
         'bone_hyoid': 'Hyoid_bone',
         'musc_constrict_i': 'IPC',
         'lips_lower': 'Lower_Lip',
         'lips_upper': 'Upper_Lip',
         'musc_constrict_m': 'MPC',
         'musc_mgh_complex': 'Mylogeniohyoid_M',
         'palate_soft': 'Soft_Palate',
         'musc_constrict_s': 'SPC',
         'spinalcord_cerv': 'Spinal_Cord',
         'larynx_sg': 'Supraglottic_Larynx',
         'cartlg_thyroid': 'Thyroid_cartilage',
         'lens_r': 'Rt_Anterior_Seg_Eyeball',
         'lens_l': 'Lt_Anterior_Seg_Eyeball',
         'brachial_plex_r': 'Rt_Brachial_Plexus',
         'brachial_plex_l': 'Lt_Brachial_Plexus',
         'brac_plx_l': 'Lt_Brachial_Plexus',
         'brac_plx_r': 'Rt_Brachial_Plexus',
         'brachialplex_l': 'Lt_Brachial_Plexus',
         'brachialplex_r': 'Rt_Brachial_Plexus',
         'pterygoid_lat_r': 'Rt_Lateral_Pterygoid_M',
         'pterygoid_lat_l': 'Lt_Lateral_Pterygoid_M',
         'musc_masseter_r': 'Rt_Masseter_M',
         'musc_masseter_l': 'Lt_Masseter_M',
         'bone_mastoid_r': 'Rt_Mastoid',
         'bone_mastoid_l': 'Lt_Mastoid',
         'pterygoid_med_r': 'Rt_Medial_Pterygoid_M',
         'pterygoid_med_l': 'Lt_Medial_Pterygoid_M',
         'parotid_r': 'Rt_Parotid_Gland',
         'parotid_l': 'Lt_Parotid_Gland',
         'l_parotid': 'Lt_Parotid_Gland',
         'lparotid': 'Lt_Parotid_Gland',
         'rparotid': 'Rt_Parotid_Gland',
         'r_parotid': 'Rt_Parotid_Gland',
         'eye_r': 'Rt_Posterior_Seg_Eyeball',
         'eye_l': 'Lt_Posterior_Seg_Eyeball',
         'musc_sclmast_r': 'Rt_Sternocleidomastoid_M',
         'musc_sclmast_l': 'Lt_Sternocleidomastoid_M',
         'glnd_submand_r': 'Rt_Submandibular_Gland',
         'glnd_submand_l': 'Lt_Submandibular_Gland',
         'inferior_pharyngeal_constrictor': 'IPC',
         'inferior_constrictor': 'IPC',
         'inferior_constrictor_muscle': 'IPC',
        'superior_pharyngeal_constrictor': 'IPC',
         'superior_constrictor_muscle': 'SPC',
         'superior_constrictor': 'SPC',
         'larynx_roi': 'Larynx',
         'glottis': 'Glottic_Area',
         'cricopharyngeus_muscle': 'Cricopharyngeal_Muscle',
         'cavity_oral': 'Extended_Oral_Cavity'
    }
    
    def __init__(self, path = None, organ_list = None,max_missing_ratio = .3):
        if path  is None:
            path = Const.data_dir + 'Cohort_SMART2_530pts_(486pts).xlsx'
        if 'xlsx' in path:
            dvh_df = pd.read_excel(path,index_col=0)
        else: 
            dvh_df = pd.read_csv(path,index_col=0)
        if organ_list is None:
            self.organ_list = Const.organ_list
        self.max_missing = int(len(self.organ_list)*max_missing_ratio)
        self.dropped_organ_names = set([])
        self.spellchecker =  SpellChecker(Const.organ_list, 
                          RadDataset.organ_rename_dict)
        
        self.dvh_df = self.clean_dvh_df(dvh_df)
        self.all_patient_ids = sorted(self.dvh_df.id.values)
#         self.get_dvh_info(dvh_df)
        
#     def rename_gtv(self,name):
#         to_sub = [
#             ('primary','p'),
#             ('nodes','n'),
#             ('nodal','n'),
#             ('nodes','n'),
#             ('_ln','n'),
#             ('',''),
#         ]
#         name = re.sub(r'(\S+)_gtv(\S*)','gtv_\1\2',str(name),flags=re.IGNORECASE)
#         for k,v in to_sub:
#             name = name.replace(k,v)
#         gtv_name= re.sub(r'.*gtv.*([np]).*',r'_GTV\1',name)
#         if gtv_name == 'gtv':
#             gtv_name = '_GTVp'
#         return gtv_name
        
#     def get_organ_name(self,organ_name):
#         ldict = RadDataset.organ_rename_dict
#         oname = organ_name.lower()
#         if 'gtv' in oname:
#             return self.rename_gtv(oname)
#         if oname in ldict.keys():
#             return ldict.get(oname,oname)
#         return oname
        
    def clean_dvh_df(self, df, organ_rename_dict = None):
        df = df.rename(RadDataset.file_header_renames,axis=1)
        df = df[df.DicomType == "ORGAN"]
        
        #this maps words to words in the rename dict
        #somewhat weird because it can inverse the order but it re-fixes itself?
        #don't know how else to prevent bugs
        print([s for s in np.unique(df.Structure)if 'gtv' in s.lower()])
        spellchecked_df, _= self.spellchecker.spellcheck_df(df,['Structure'])
        df['ROI'] = spellchecked_df['Structure']
        df = df.drop(["DicomType"],axis=1)
        df = df.reset_index()
        df = df[df.ROI.isin(self.organ_list)] #only keep organs we car about
        df = df.drop_duplicates(subset=['id','ROI','volume'])
        df = df[df.mean_dose != 'error'] #idk what this is from
        df = self.filter_valid_patients(df.reset_index()) 
        df = self.add_missing_organs(df) 
        
        hist_cols = [c for c in df.columns if (re.match('[DV]\d+',c) is not None)]
        df[hist_cols] = df[hist_cols].astype('float16')
        return df.drop(['index'],axis=1)
            
    def add_patient_organs(self,pid,patient_df):
        pdf = patient_df.copy()
        rois = np.unique(patient_df.ROI.values)
        for organ in self.organ_list:
            if organ not in rois:
                entry = pd.Series([pid,organ,organ],index=['id','Structure','ROI'])
                pdf = pdf.append(entry,ignore_index=True)
        return pdf
    
    def add_missing_organs(self,df):
        dfs = []
        for pid,subdf in df.groupby('id'):
            subdf = self.add_patient_organs(pid,subdf).set_index("ROI")
            subdf = subdf.loc[self.organ_list]
            subdf = subdf.reset_index()
            dfs.append(subdf)
        return pd.concat(dfs)
            
    def filter_valid_patients(self,df):
        dfs = []
        flag = True
        for pid,subdf in df.groupby('id'):
            has_gtv = subdf.ROI.apply(lambda x: 'gtv' in x.lower())
            flag = flag & (has_gtv.sum() >= 0)
            
            rois = set(np.unique(subdf.ROI))
            n_rois = len(rois)
            n_missing = n_rois - len(self.organ_list)
            if n_missing != 0:
                print('patient ', pid, 'has', n_missing,'organs off')
            if n_rois != subdf.shape[0]:
                print('patient',pid,'has duplicate organs?')
            
            flag = flag & (n_missing < self.max_missing)
            if flag:
                dfs.append(subdf)
        return pd.concat(dfs)

    def filter_dvh_organs(self,dvh_df = None):
        if dvh_df is None:
            dvh_df = self.dvh_df
        valid_organs = set([o.lower() for o in self.organ_list])
        valid = dvh_df.ROI.apply(lambda x: x.lower() in valid_organs)
        dvh_df = dvh_df[valid]
        return dvh_df
    
rds = RadDataset()
rds.dvh_df.T

[]
patient  2.0 has -4 organs off
patient  6.0 has -4 organs off
patient  9.0 has -7 organs off
patient  10.0 has -4 organs off
patient  12.0 has -5 organs off
patient  15.0 has -4 organs off
patient  18.0 has -7 organs off
patient  19.0 has -4 organs off
patient  22.0 has -4 organs off
patient  24.0 has -4 organs off
patient  25.0 has -4 organs off
patient  28.0 has -6 organs off
patient  31.0 has -7 organs off
patient  32.0 has -4 organs off
patient  34.0 has -7 organs off
patient  35.0 has -45 organs off
patient  37.0 has -5 organs off
patient  38.0 has -4 organs off
patient  39.0 has -4 organs off
patient  40.0 has -5 organs off
patient  41.0 has -4 organs off
patient  42.0 has -4 organs off
patient  43.0 has -4 organs off
patient  44.0 has -4 organs off
patient  45.0 has -5 organs off
patient  46.0 has -6 organs off
patient  47.0 has -4 organs off
patient  48.0 has -4 organs off
patient  49.0 has -5 organs off
patient  50.0 has -5 organs off
patient  51.0 has -4 organs off
patient

patient  423.0 has -6 organs off
patient  424.0 has -8 organs off
patient  425.0 has -4 organs off
patient  427.0 has -4 organs off
patient  428.0 has -5 organs off
patient  429.0 has -4 organs off
patient  430.0 has -5 organs off
patient  431.0 has -4 organs off
patient  433.0 has -4 organs off
patient  434.0 has -8 organs off
patient  447.0 has -2 organs off
patient  448.0 has -4 organs off
patient  450.0 has -4 organs off
patient  451.0 has -3 organs off
patient  453.0 has -4 organs off
patient  454.0 has -5 organs off
patient  463.0 has -4 organs off
patient  466.0 has -1 organs off
patient  467.0 has -4 organs off
patient  469.0 has -4 organs off
patient  470.0 has -5 organs off
patient  475.0 has -4 organs off
patient  486.0 has -4 organs off
patient  487.0 has -5 organs off
patient  488.0 has -4 organs off
patient  489.0 has -4 organs off
patient  494.0 has -4 organs off
patient  496.0 has -4 organs off
patient  497.0 has -4 organs off
patient  498.0 has -4 organs off
patient  5

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,36,37,38,39,40,41,42,43,44,45
ROI,Esophagus,Spinal_Cord,Lt_Brachial_Plexus,Rt_Brachial_Plexus,Cricopharyngeal_Muscle,Lt_thyroid_lobe,Rt_thyroid_lobe,Cricoid_cartilage,IPC,MPC,...,Extended_Oral_Cavity,Mandible,Hard_Palate,Lt_Posterior_Seg_Eyeball,Rt_Posterior_Seg_Eyeball,Lt_Anterior_Seg_Eyeball,Rt_Anterior_Seg_Eyeball,Lower_Lip,Upper_Lip,Glottic_Area
id,2,2,2,2,2,2,2,2,2,2,...,822,822,822,822,822,822,822,822,822,822
Structure,Esophagus_U,SpinalCord_Cerv,Brachial_Plex_L,Brachial_Plex_R,Cricopharyngeus,Lt_thyroid_lobe,Rt_thyroid_lobe,Cricoid,Musc_Constrict_I,Musc_Constrict_M,...,Oral_Cavity,Mandible,Hardpalate,Eye_L,Eye_R,Lens_L,Lens_R,Lips_Lower,Lips_Upper,Glottic_Area
volume,14279.5,12866.4,7139.74,5500.03,2698.52,,,8964.84,2900.39,1384.28,...,114498,67071.9,2670.29,8495.33,8804.32,339.508,354.767,1903.53,3667.83,1995.09
mean_dose,35.7206,23.8169,36.2478,42.9462,49.9044,,,48.6903,41.852,66.4098,...,35.4516,17.989,8.74754,0.78709,0.03987,0.691854,0.00123656,0.933136,1.33363,23.8146
min_dose,297,1768,341,1308,4070,,,2579,2834,5067,...,127,25,280,1,1,28,1,62,57,664
max_dose,5827,3008,6272,6034,5747,,,6842,5989,7380,...,6958,7348,2501,139,44,85,1,181,448,5211
V5,96.375,100,95.6875,100,100,,,100,100,100,...,90.75,37.0938,74.4375,0,0,0,0,0,0,100
V10,91.5625,100,80,100,100,,,100,100,100,...,79.75,35.5625,31.5,0,0,0,0,0,0,83.8125
V15,88.25,100,75.5,99,100,,,100,100,100,...,71.5,34.5625,12.7109,0,0,0,0,0,0,54.4062


In [134]:
def get_dvhcol_pos(x):
    m = re.match('[DV]'+'(\d+)',x)
    if m is not None:
        return int(m.group(1))
    else:
        return -1
    
def get_dvh_columns(dvh,key,steps=None):
    good_cols = []
    for c in dvh.columns:
        match = re.match(key+'(\d+)', c)
        if match is not None:
            use = True
            if steps is not None and match.group(1) is not None:
                value = int(match.group(1))
                if value not in steps:
                    use = False
            if use:
                good_cols.append(match.group(0))
    keys = sorted(good_cols, key = get_dvhcol_pos)
    pos = [get_dvhcol_pos(x) for x in good_cols]
    return keys, pos

def to_dvh_array(df, key = 'D',dose_points = None, organ_list = None):
    #uses the cvh dataframe and key to return a 3d array of dvh values
    #pos is the array of values from the keys (e.g. V5 V10 ... => [5,10,...])
    #so we can do clustering or something idk
    keys, pos = get_dvh_columns(df,key,steps=dose_points)
    array = []
    for pid, subdf in df.groupby('id'):
        if organ_list is None:
            organ_list = list(subdf.ROI.values)
            assert(len(organ_list) == len(set(organ_list)))
        subdf = subdf.set_index('ROI').loc[organ_list,keys]
        array.append(subdf.values)
    array = np.stack(array)
    print(np.isnan(array).sum())
    return array, pos

dvh_array, dvh_pos = to_dvh_array(rds.dvh_df)

64860


In [194]:
dvh_widths = []
curr = 0
for i in np.arange(len(dvh_pos)):
    width = dvh_pos[i] - curr
    dvh_widths.append(width)
    curr = width
dvh_widths = np.array(dvh_widths)
dvh_widths

array([ 2,  3,  7,  8, 12, 13, 17, 18, 22, 23, 27, 28, 32, 33, 37, 38, 42,
       43, 47, 48, 49, 49, 50])

In [210]:
import Metrics 
sim_measures = {
    'jaccard': Metrics.Jaccard2d(),
    'euclidean': Metrics.Euclidean2D(),
    'em_weighted': Metrics.Wasserstein2d(steps=dvh_widths),
    'em': Metrics.Wasserstein2d(),
}
sims = {}
for sname, sm in  sim_measures.items():
    sim_matrix = sm.get_similarity_matrix(dvh_array)
    sims[sname] = sim_matrix
sims

  denominator = x.dot(x) + y.dot(y) - x.dot(y)


0.0 1.0
0.0 1.0
0.0 1.0
0.0 1.0


In [228]:
from sklearn.metrics import silhouette_score
results = []
for sname,sim in sims.items():
    for n in [2,3,4,5,6]:
        for link in ['ward','centroid','weighted','average']:
            try:
                entry = {'similarity':sname,'n':n,'link':link}
                clusterer = Cluster.SimilarityClusterer(n,link=link)
                clust = clusterer.fit_predict(sim)
                ss = silhouette_score(1/(1+sim),clust,metric='precomputed')
                entry['silhouette'] = ss
                results.append(entry)
            except Exception as e:
                pass
#                 print(link)
#                 print(e)
results = pd.DataFrame(results)
results.sort_values('silhouette',ascending=False).T

Unnamed: 0,0,1,2,3,14,10,12,13,8,15,11,9,5,6,4,7
link,ward,ward,ward,ward,ward,ward,ward,ward,ward,ward,ward,ward,ward,ward,ward,ward
n,3,4,5,6,5,5,3,4,3,6,6,4,4,5,3,6
silhouette,0.285968,0.285968,0.285968,0.285968,0.0448566,0.0406977,0.0400638,0.0391948,0.0388422,0.0365142,0.0359748,0.0322165,0.0123122,0.0100676,0.0092519,0.00871551
similarity,jaccard,jaccard,jaccard,jaccard,em,em_weighted,em,em,em_weighted,em,em_weighted,em_weighted,euclidean,euclidean,euclidean,euclidean


In [115]:
# symptom_df = pd.read_csv(Const.data_dir + 'patient_symptom_formatted.csv',index_col=0)
# symptom_df

#this is code for gettings gtvs from the older file. new one is currently missing them so I might merge them
#currently I cant actually merge them at all so not yet.
def rename_gtv(name):
    to_sub = [
        ('primary','p'),
        ('nodes','n'),
        ('nodal','n'),
        ('nodes','n'),
        ('_ln','n'),
        ('',''),
    ]
    name = name.lower()
    name = re.sub(r'(\S+)_gtv(\S*)','gtv_\1\2',str(name),flags=re.IGNORECASE)
    for k,v in to_sub:
        name = name.replace(k,v)
    gtv_name= re.sub(r'.*gtv.*([np]).*',r'_GTV\1',name)
    if gtv_name == 'gtv':
        gtv_name = '_GTVp'
    return gtv_name

def load_gtv_dvh(path = None):
    if path is None:
        path = Const.data_dir + 'Cohort_SMART_anonomyzed.xlsx'
    gtv_df = pd.read_excel(path)
    #only keep gtvs
    gtv_df = gtv_df[gtv_df.Structure.apply(lambda x: 'gtv' in x.lower())]
    #roi is the renamed version of structure
    gtv_df['ROI'] = gtv_df.Structure.apply(rename_gtv)
    gtv_df = gtv_df.drop(['DicomType'],axis=1)
    #properly renamed ones should be _GTVp or _GTVn. currently no _GTVnX
    gtv_df = gtv_df[gtv_df.ROI.apply(lambda x: '_gtv' in x.lower())]
    return gtv_df

gtv_df = load_gtv_dvh()
for pid,subdf in gtv_df.groupby('id'):
    print(pid)
    print(subdf)
    print()

6
    id Structure        Volume     mean minGy maxGy   V5  V10  V15  V20  ...  \
36   6       GTV  40762.023926  71.7218  6734  7423  100  100  100  100  ...   

      D70    D75    D80    D85    D90    D95    D97    D98   D99    ROI  
36  71.45  71.37  71.29  71.19  71.05  70.84  70.69  70.55  70.2  _GTVp  

[1 rows x 46 columns]

9
     id Structure        Volume     mean minGy maxGy       V5      V10  \
235   9       GTV  25781.455994  66.2363     1  6896  99.5761  99.5761   

         V15      V20  ...    D70    D75    D80    D85    D90    D95    D97  \
235  99.5761  99.5761  ...  66.85  65.93  64.78  63.69  62.94  62.08  61.86   

       D98    D99    ROI  
235  61.72  61.57  _GTVp  

[1 rows x 46 columns]

10
     id Structure        Volume     mean minGy maxGy   V5  V10  V15  V20  ...  \
405  10  GTVn_CEC   7358.093262  71.6713  7068  7362  100  100  100  100  ...   
406  10  GTVp_CEC  93455.200195  71.7998  7008  7343  100  100  100  100  ...   

       D70    D75    D80    D8

      id      Structure        Volume     mean minGy maxGy   V5  V10  V15  \
5522  95  GTV_NodePreop  11215.897751  66.5064  6531  6769  100  100  100   

      V20  ...    D70    D75    D80    D85    D90    D95    D97    D98   D99  \
5522  100  ...  66.29  66.21  66.13  66.06  65.96  65.87  65.81  65.77  65.7   

        ROI  
5522  _GTVp  

[1 rows x 46 columns]

97
      id      Structure        Volume     mean minGy maxGy   V5  V10  V15  \
5809  97            GTV   7003.784180  67.6239  6619  6989  100  100  100   
5810  97  GTV_PreopNode  38364.257812  63.7184  5964  6907  100  100  100   
5812  97       GTVp_CEC   7003.784180  67.6239  6619  6989  100  100  100   

      V20  ...    D70    D75    D80    D85    D90    D95    D97    D98    D99  \
5809  100  ...  67.35  67.28   67.2  67.11  66.97  66.73  66.61  66.52  66.42   
5810  100  ...  61.84  61.66  61.49  61.35  61.16  60.94  60.77  60.65  60.45   
5812  100  ...  67.35  67.28   67.2  67.11  66.97  66.73  66.61  66.52  66.42

        id         Structure         Volume     mean minGy maxGy       V5  \
10290  190  GTV_Preop__nodes    7106.781006  64.9083  6257  6846      100   
10291  190           GTV_exp  102094.573975  64.4678     1  6887  99.8951   

           V10      V15      V20  ...    D70    D75    D80    D85    D90  \
10290      100      100      100  ...  64.04  63.94  63.83  63.71  63.57   
10291  99.7861  99.7175  99.5965  ...     64  63.87  63.73  63.56  63.31   

         D95    D97    D98    D99    ROI  
10290  63.44  63.36  63.23  63.04  _GTVn  
10291  61.69  57.06  51.35  39.02  _GTVp  

[2 rows x 46 columns]

193
        id       Structure       Volume     mean minGy maxGy   V5  V10  V15  \
10340  193  GTV_Preop_Node  5779.266357  61.5437  5529  6313  100  100  100   

       V20  ...    D70    D75   D80    D85    D90    D95    D97    D98    D99  \
10340  100  ...  61.27  61.19  61.1  60.99  60.87  60.71  60.56  60.44  60.22   

         ROI  
10340  _GTVn  

[1 rows x 46 columns]

194
  

[2 rows x 46 columns]

260
        id Structure        Volume     mean minGy maxGy   V5  V10  V15  V20  \
13569  260       GTV  34153.747559  71.4965  6976  7304  100  100  100  100   
13570  260  GTV_Node   4680.175781  71.8174  7048  7292  100  100  100  100   

       ...    D70    D75    D80    D85    D90    D95    D97    D98    D99  \
13569  ...  71.24  71.16  71.09     71  70.89  70.75  70.66   70.6  70.49   
13570  ...   71.6  71.54  71.48  71.41  71.29  71.15  71.07  70.99   70.9   

         ROI  
13569  _GTVp  
13570  _GTVn  

[2 rows x 46 columns]

262
        id   Structure       Volume      mean minGy maxGy V5 V10 V15 V20  ...  \
13657  262         GTV  3671.125031   0.28042    22    32  0   0   0   0  ...   
13658  262  GTV_Node_1  3806.140137  0.246174    17    44  0   0   0   0  ...   

        D70   D75   D80   D85   D90   D95   D97   D98   D99    ROI  
13657  0.29  0.29  0.28  0.27  0.26  0.25  0.24  0.24  0.24  _GTVp  
13658  0.23  0.22  0.22  0.21   0.2   0.2  0.19 

[2 rows x 46 columns]

367
        id Structure        Volume     mean minGy maxGy       V5      V10  \
17111  367       GTV   9485.244751  46.0389  1399  7414      100      100   
17112  367  GTV_Node  65361.022949  55.5791     1  7626  87.1863  86.1153   

           V15      V20  ...    D70    D75    D80    D85    D90    D95   D97  \
17111  99.5174  91.1522  ...  34.36  30.76  27.07  23.34  20.54  18.37  17.2   
17112  85.3041  84.5308  ...  60.93  55.08  42.87  16.64    0.1   0.02  0.01   

         D98   D99    ROI  
17111  16.66  15.7  _GTVp  
17112   0.01  0.01  _GTVn  

[2 rows x 46 columns]

368
        id Structure        Volume     mean minGy maxGy   V5  V10  V15  V20  \
17161  368  GTV_Node  11497.497559  61.5152  5519  7388  100  100  100  100   

       ...    D70    D75    D80    D85    D90    D95   D97    D98    D99  \
17161  ...  58.73  58.42  58.15  57.86  57.53  57.17  56.9  56.74  56.57   

         ROI  
17161  _GTVn  

[1 rows x 46 columns]

369
        id Structu

        id    Structure         Volume     mean minGy maxGy   V5  V10  V15  \
19848  467          GTV  111141.815186   71.006  3978  7481  100  100  100   
19849  467     GTV_Node   79002.685547  70.8651  4970  7481  100  100  100   
19850  467  GTV_Primary    9183.197021  70.7264  6567  7376  100  100  100   

       V20  ...    D70    D75    D80    D85    D90    D95    D97    D98  \
19848  100  ...     71  70.85  70.65  70.32  69.69   67.8  65.32   62.5   
19849  100  ...  70.92  70.74  70.48  69.94  69.01  66.36  64.11  62.62   
19850  100  ...  70.03  69.84  69.65  69.41  69.19   68.9  68.71  68.51   

         D99    ROI  
19848  57.89  _GTVp  
19849  59.89  _GTVn  
19850  67.84  _GTVp  

[3 rows x 46 columns]

469
        id  Structure        Volume     mean minGy maxGy   V5  V10  V15  V20  \
19908  469        GTV  13759.613037  61.0908  5853  6507  100  100  100  100   
19909  469   GTV_Node  10410.308838   62.225  5914  6399  100  100  100  100   
19910  469  GTV_Preop  24251.9

True

In [81]:
# def input_dvh_vols()
def get_dvh_columns(dvh,key,steps=None):
    good_cols = []
    for c in dvh.columns:
        match = re.match(key+'(\d+)', c)
        if match is not None:
            use = True
            if steps is not None and match.group(1) is not None:
                value = int(match.group(1))
                if value not in steps:
                    use = False
            else:
                use = True
            if use:
                good_cols.append(match.group(0))
    return sorted(good_cols, key = lambda x: int(re.match(key+'(\d+)',x).group(1)))

get_dvh_columns(rds.dvh_df,'D')

['D2',
 'D5',
 'D10',
 'D15',
 'D20',
 'D25',
 'D30',
 'D35',
 'D40',
 'D45',
 'D50',
 'D55',
 'D60',
 'D65',
 'D70',
 'D75',
 'D80',
 'D85',
 'D90',
 'D95',
 'D97',
 'D98',
 'D99']

In [15]:
dose_points = [2,10,20,30,40,50,60,70,80,90,99]
dose_dvh, dose_missing = rds.get_dvh_array('D',dose_points,patient_ids = symptom_df.index.values)
(dose_dvh.shape,(dose_missing == 0).sum())

[nan nan nan nan nan nan nan nan nan nan nan]


KeyError: 'ROI'

In [5]:
import torch.nn.functional as F

class DvhAutoEncoder(nn.Module):
    
    def __init__(self,
                input_size,
                embedding_dim = 20,
                pre_embedding_dim = 100,
                conv_channels = 2, 
                init_dropout = .2,
                embedding_dropout = .2,
                penult_dropout = .2
                ):
        
        super().__init__()
            
        self.input_size = input_size
        self.n_channels = input_size[-1]
        self.n_organs = input_size[1]
        
        self.init_dropout = init_dropout
        self.embedding_dim =  embedding_dim
        self.pre_embedding_dim = pre_embedding_dim
        self.conv_channels = conv_channels
        
        self.embedding_dropout = embedding_dropout
        self.penult_dropout = penult_dropout
        self.conv_droput = nn.Dropout2d(p=self.init_dropout)
        self.flatten = nn.Flatten().cuda()
        
        self.encoder = self.init_encoder()
        self.decoder = self.init_decoder()
        
    def init_encoder(self):
        conv_in = nn.Conv1d(self.n_channels,self.conv_channels,
                            padding='same')
        #I think this is what the size should be
        conv_dim = self.conv_channels * self.n_organs
        fc_in_1 = nn.Linear(conv_dim, self.pre_embedding_dim)
        embedding_layer = nn.Linear(self.pre_embedding_dim,self.embedding_dim )
        encoder = nn.Sequential(
            self.conv_droput,
            conv_in,
            self.flatten,
            nn.Relu(),
            fc_in_1,
            nn.Relu(),
            embedding_layer,
            nn.Dropout(p=self.embedding_dropout),
            nn.LazyBatchNorm1d(),
        )
        return encoder
    
    def init_decoder(self):
        #idk
        fc_out_1 = nn.LazyLiner(self.pre_embedding_dim)
        conv_dim = self.conv_channels * self.n_organs
        fc_out_2 = nn.Linear(self.pre_embedding_dim)
        conv_out = nn.ConvTranspose1d(self.pre_embedding_dim)
        
    def encode(self,x):
        #I think I need to figure out alignment here
        pass
    
    def decode(self,x):
        pass
    
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        if torch.cuda.is_available():
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)
        
    def forward(self,xin):
        #we keep in nans in the data so I can ignore them in the loss function
        x = torch.nan_to_num(xin)
        #channel should be the 2nd dim for conv networks, I use 3rd usually
        x = x.permute(0,2,1)
        x = self.encode(x)
        x = self.decode(x)
        return x
    
volume_dvh

NameError: name 'nn' is not defined