In [2]:
import os
import torch
from torch_geometric.data import Dataset, download_url
import os
import pandas as pd

- Here is some general information about the data in this folder.

1.  Bridge_eids_60520_87802.csv:<br>

- This file contains the bridging between the two sets of subject IDs.<br>
- The 'old' ID can be found in the column 'eid_60520' - and the 'new' ID in column 'eid_87802'.  
- The new ID corresponds to the folder names, the old ID to the subj_id saved in the data objects.  

***

2. basic_features.csv 

- This file contains some basic features of subjects. The columns stand for the following:
- 21003-2.0: age
- 31-0.0: sex (0: female, 1: male)
- 21001-2.0: BMI
- 21002-2.0: weight
- 50-2.0: standing height

# CREATE TRAIN VAL TEST SPLITS

In [10]:
from sklearn.model_selection import train_test_split
import os 
# Every entry of the organ-mesh folder is treated as one subject ID.
organ_mesh_ids = os.listdir('/data0/practical-wise2223/organ_mesh/organ_meshes')


In [15]:
# Number of entries found in the organ-mesh folder.
len(organ_mesh_ids)

30382

In [24]:
# `registered_ids` is listed here so that this comparison also works on a fresh
# kernel (the cell that originally defined it sits further down in the notebook).
registered_ids = os.listdir("/data0/practical-wise2223/organ_mesh/gendered_organ_registrations_ply/")
# Set equality is enough for this check: duplicates and ordering are irrelevant.
set(organ_mesh_ids) == set(registered_ids)
# Registered folder has 3 files missing.
# Those 3 files are most likely included in our training set.
# TODO: the original note ended mid-sentence ("There is a possibility that ...") - finish it.

False

In [21]:
# Folder holding the registered organ meshes; its entries are used as subject IDs.
registered_path = "/data0/practical-wise2223/organ_mesh/gendered_organ_registrations_ply/"
registered_ids = os.listdir(registered_path)

In [11]:
# Some organ meshes do not have BMI features.


30382

In [31]:
# Load the per-subject basic features table.
basic_features = pd.read_csv('/data0/practical-wise2223/organ_mesh/basic_features.csv')
# Keep only the rows in which every column has a value.
basic_features_nonnna = basic_features.dropna(axis=0, how='any')

In [41]:
# Unique subject IDs that have a complete feature row, converted to strings
# so that they can be compared with folder names.
valid_ids = [str(eid) for eid in set(basic_features_nonnna['eid'].values)]

In [46]:
# IDs that are both registered and have a complete feature row.
valid_registered_ids = set(registered_ids) & set(valid_ids)


In [48]:
# Number of IDs present in both the registered folder and the complete-feature table.
len(valid_registered_ids)

29348

In [51]:
# Target fractions of the full ID list (expected to sum to 1).
train_ratio = 0.75
validation_ratio = 0.15
test_ratio = 0.10
# Fixed seed so that re-running the notebook reproduces the same split.
split_seed = 42

# A set has no stable iteration order across kernel restarts, so sort the IDs
# first; otherwise the seed alone would not make the split reproducible.
all_ids = sorted(valid_registered_ids)

# First cut: train (75%) versus a held-out pool (25%).
x_train, x_holdout = train_test_split(
    all_ids, test_size=1 - train_ratio, random_state=split_seed
)

# Second cut: divide the held-out pool so that, relative to the full data set,
# validation is 15% and test is 10%.
x_val, x_test = train_test_split(
    x_holdout,
    test_size=test_ratio / (test_ratio + validation_ratio),
    random_state=split_seed,
)

print('X train shape : ', len(x_train), 'X Val shape ', len(x_val), 'X Test shape', len(x_test))

X train shape :  22011 X Val shape  4402 X Test shape 2935


In [52]:
def write_list_file(split_list, out_path='.', mode='train'):
    """Write one split to ``<out_path>/NonNa_organs_split_<mode>.txt``.

    Each item of ``split_list`` is converted with ``str`` and written on its
    own line (no trailing newline after the last item). An existing file of the
    same name is overwritten.

    Args:
        split_list: Iterable of IDs to write.
        out_path: Target folder; it is created if it does not exist yet.
        mode: Split name that is inserted into the file name.
    """
    # Create the folder up front so a fresh checkout does not fail on open().
    if out_path:
        os.makedirs(out_path, exist_ok=True)
    path = os.path.join(out_path, f'NonNa_organs_split_{mode}.txt')
    with open(path, 'w', encoding='utf-8') as f:
        f.write("\n".join(str(item) for item in split_list))


In [53]:
# Persist the three splits next to each other in ../data, one file per split.
for split_name, split_ids in (('train', x_train), ('val', x_val), ('test', x_test)):
    write_list_file(split_ids, out_path='../data', mode=split_name)

In [58]:
import pandas as pd 


class OrganMeshDataset(Dataset):
    """Dataset of per-subject organ meshes labelled with the subject's sex.

    Each sample is loaded from ``<root>/<subject_id>/<organ>_mesh.pt`` and gets
    ``data.y`` set to the ``sex`` value of the matching row (``eid`` equal to
    the subject ID) of the basic features table.

    Args:
        root: Folder containing one sub-folder per subject ID.
        basic_feats_path: CSV file with an ``eid`` column and the raw feature
            columns, which are renamed to readable names on load.
        bridge_path: CSV file bridging the two subject-ID sets. It is loaded
            into ``bridge_organ_df`` but not used for lookups by this class.
        mode: One of ``'train'``, ``'val'`` or ``'test'``.
        organ: Organ name used to build the mesh file name.
        split_path: Folder containing ``<split_prefix>_<mode>.txt`` with one
            subject ID per line. If ``None``, all entries of ``root`` are used.
        num_samples: If given, keep only the first ``num_samples`` IDs of the
            selected split.
        transform, pre_transform, pre_filter: Forwarded to the base ``Dataset``.
        split_prefix: File-name prefix of the split files. The default matches
            the previously hard-coded ``organs_split``.
    """

    def __init__(self, root, basic_feats_path, bridge_path, mode='train', organ='liver', split_path = None,
                 num_samples = None, transform=None, pre_transform=None, pre_filter=None,
                 split_prefix='organs_split'):
        assert mode in ['train', 'val', 'test'], f'Unknown mode: {mode!r}'
        super().__init__(root, transform, pre_transform, pre_filter)

        self.root = root
        self.organ = organ

        if split_path is not None:
            split_file = os.path.join(split_path, f'{split_prefix}_{mode}.txt')
            with open(split_file) as f:
                # Strip line endings (and skip blank lines): the IDs are used
                # as folder names, so a trailing "\n" would break the path.
                self.organ_mesh_ids = [line.strip() for line in f if line.strip()]
        else:
            # No split file given: fall back to every subject folder in root.
            self.organ_mesh_ids = sorted(os.listdir(root))

        if num_samples is not None:
            # Sub-sample within the selected split instead of replacing it
            # with arbitrary folders, which would mix train/val/test subjects.
            self.organ_mesh_ids = self.organ_mesh_ids[:num_samples]

        self.basic_feats_path = basic_feats_path
        self.bridge_path = bridge_path

        self.basic_features = pd.read_csv(basic_feats_path)
        # Column 50-2.0 is the standing height (not a weight).
        new_names = {'21003-2.0':'age', '31-0.0':'sex', '21001-2.0':'bmi', '21002-2.0':'weight','50-2.0':'standing_height'}
        self.basic_features = self.basic_features.rename(index=str, columns=new_names)
        self.bridge_organ_df = pd.read_csv(bridge_path)

    def len(self):
        """Return the number of subject IDs in this split."""
        return len(self.organ_mesh_ids)

    def get(self, idx):
        """Load the mesh of the ``idx``-th subject and attach its sex as ``y``.

        Raises:
            ValueError: If the features table does not contain exactly one row
                for the subject.
        """
        selected_patient = self.organ_mesh_ids[idx]
        mesh_path = os.path.join(self.root, selected_patient, f'{self.organ}_mesh.pt')
        # NOTE: torch.load unpickles the file - only use it on trusted data.
        data = torch.load(mesh_path)

        patient_features = self.basic_features[self.basic_features['eid'] == int(selected_patient)]
        if len(patient_features) != 1:
            raise ValueError(
                f'Expected exactly one feature row for subject {selected_patient}, '
                f'found {len(patient_features)}'
            )

        # Label of the data is currently the sex of the subject.
        data.y = patient_features['sex'].item()
        return data
    


In [59]:
# Folder with one sub-folder per subject; passed to the dataset as `root`.
root = '/data0/practical-wise2223/organ_mesh/organ_meshes'
# CSV with the basic per-subject features.
basic_feat_path = '/data0/practical-wise2223/organ_mesh/basic_features.csv'
# CSV bridging the old and new subject IDs.
bridge_path = '/data0/practical-wise2223/organ_mesh/Bridge_eids_60520_87802.csv'
# Folder in which the dataset looks for its split files.
split_path = '/data0/practical-wise2223/organ_mesh/data/'


In [60]:
# One dataset per split, all sharing the same data and feature files.
# NOTE(review): the class reads `organs_split_<mode>.txt` from `split_path`,
# whereas the split cell above writes `NonNa_organs_split_<mode>.txt` to
# '../data' - confirm that these are the intended split files.
train_dataset, val_dataset, test_dataset = (
    OrganMeshDataset(root, basic_feat_path, bridge_path, mode=split_name, split_path=split_path)
    for split_name in ('train', 'val', 'test')
)

In [61]:
# Display the training dataset; the repr shows the class name and a count.
train_dataset

OrganMeshDataset(22786)

In [62]:
# Display the validation dataset; the repr shows the class name and a count.
val_dataset

OrganMeshDataset(4557)

In [63]:
# Display the test dataset; the repr shows the class name and a count.
test_dataset

OrganMeshDataset(3039)