In [11]:
from dipy.io.streamline import load_trk
from torch_geometric.data import Data, Dataset
from torch_geometric.transforms import BaseTransform

from dataset_handlers import HCP_handler

In [12]:
# One handler per split; all three are built before any data is requested.
train_ds_handler, valid_ds_handler, testset_ds_handler = (
    HCP_handler(path='/app/dataset/HCP_105', scope=split_name)
    for split_name in ('trainset', 'validset', 'testset')
)

# Per-split data returned by each handler's get_data().
train_tract_list, valid_tract_list, testset_tract_list = (
    handler.get_data()
    for handler in (train_ds_handler, valid_ds_handler, testset_ds_handler)
)

# Size of every split, one line each.
print(
    f'train_tract_list: {len(train_tract_list)}',
    f'valid_tract_list: {len(valid_tract_list)}',
    f'testset_tract_list: {len(testset_tract_list)}',
    sep='\n',
)

# Full content of every split, one line each.
print(
    f'train_tract_list: {train_tract_list!s}',
    f'valid_tract_list: {valid_tract_list!s}',
    f'testset_tract_list: {testset_tract_list!s}',
    sep='\n',
)

{0: 'AF_left', 1: 'AF_right', 2: 'ATR_left', 3: 'ATR_right', 4: 'CA', 5: 'CC_1', 6: 'CC_2', 7: 'CC_3', 8: 'CC_4', 9: 'CC_5', 10: 'CC_6', 11: 'CC_7', 12: 'CC', 13: 'CG_left', 14: 'CG_right', 15: 'CST_left', 16: 'CST_right', 17: 'MLF_left', 18: 'MLF_right', 19: 'FPT_left', 20: 'FPT_right', 21: 'FX_left', 22: 'FX_right', 23: 'ICP_left', 24: 'ICP_right', 25: 'IFO_left', 26: 'IFO_right', 27: 'ILF_left', 28: 'ILF_right', 29: 'MCP', 30: 'OR_left', 31: 'OR_right', 32: 'POPT_left', 33: 'POPT_right', 34: 'SCP_left', 35: 'SCP_right', 36: 'SLF_I_left', 37: 'SLF_I_right', 38: 'SLF_II_left', 39: 'SLF_II_right', 40: 'SLF_III_left', 41: 'SLF_III_right', 42: 'STR_left', 43: 'STR_right', 44: 'UF_left', 45: 'UF_right', 46: 'T_PREF_left', 47: 'T_PREF_right', 48: 'T_PREM_left', 49: 'T_PREM_right', 50: 'T_PREC_left', 51: 'T_PREC_right', 52: 'T_POSTC_left', 53: 'T_POSTC_right', 54: 'T_PAR_left', 55: 'T_PAR_right', 56: 'T_OCC_left', 57: 'T_OCC_right', 58: 'ST_FO_left', 59: 'ST_FO_right', 60: 'ST_PREF_left', 6

In [16]:
# Show the tract names (the keys of the handler's TRACT_LIST mapping) as a plain list.
print(list(train_ds_handler.TRACT_LIST.keys()))

['AF_left', 'AF_right', 'ATR_left', 'ATR_right', 'CA', 'CC_1', 'CC_2', 'CC_3', 'CC_4', 'CC_5', 'CC_6', 'CC_7', 'CC', 'CG_left', 'CG_right', 'CST_left', 'CST_right', 'MLF_left', 'MLF_right', 'FPT_left', 'FPT_right', 'FX_left', 'FX_right', 'ICP_left', 'ICP_right', 'IFO_left', 'IFO_right', 'ILF_left', 'ILF_right', 'MCP', 'OR_left', 'OR_right', 'POPT_left', 'POPT_right', 'SCP_left', 'SCP_right', 'SLF_I_left', 'SLF_I_right', 'SLF_II_left', 'SLF_II_right', 'SLF_III_left', 'SLF_III_right', 'STR_left', 'STR_right', 'UF_left', 'UF_right', 'T_PREF_left', 'T_PREF_right', 'T_PREM_left', 'T_PREM_right', 'T_PREC_left', 'T_PREC_right', 'T_POSTC_left', 'T_POSTC_right', 'T_PAR_left', 'T_PAR_right', 'T_OCC_left', 'T_OCC_right', 'ST_FO_left', 'ST_FO_right', 'ST_PREF_left', 'ST_PREF_right', 'ST_PREM_left', 'ST_PREM_right', 'ST_PREC_left', 'ST_PREC_right', 'ST_POSTC_left', 'ST_POSTC_right', 'ST_PAR_left', 'ST_PAR_right', 'ST_OCC_left', 'ST_OCC_right']


# Recuento de fibras 

In [19]:
# Hacer un recuento de las fibras por tracto en el dataset
# crear un dataframe que recuente para cada sujeto el numero de fibras que tiene cada tracto
import pandas as pd
from tqdm import tqdm

def generate_streamline_count_dataframe(data, categories, output_file):
    """Count the streamlines of every tract for every subject and save the table as CSV.

    Args:
        data: Iterable of subject dicts. Each dict provides ``'subject'`` (the
            subject identifier) and ``'tracts'`` (an iterable of path objects
            pointing at ``.trk`` files; each must expose ``.stem``).
        categories: Tract names used as the count columns of the table.
        output_file: Destination of the CSV file (written without the index).

    Returns:
        pandas.DataFrame with one row per subject: a ``'subject'`` column
        followed by one column per entry of ``categories``. A tract that a
        subject does not have is reported as 0. Files whose tract name is not
        in ``categories`` are counted but do not appear in the table.
    """
    columns = ['subject', *categories]
    records = []

    progress_bar = tqdm(data)
    for subject_dict in progress_bar:
        # Label the bar before the slow loading starts so it names the subject
        # that is actually being processed.
        progress_bar.set_description(f'Processing {subject_dict["subject"]}')

        # Start every category at zero so tracts missing for this subject
        # end up as 0 instead of NaN.
        row = dict.fromkeys(categories, 0)
        row['subject'] = subject_dict['subject']

        for tract in subject_dict['tracts']:
            # The tract name is the file stem up to its first dot.
            tract_name = tract.stem.split('.')[0]
            tractogram = load_trk(str(tract), 'same', bbox_valid_check=False)
            row[tract_name] = len(tractogram.streamlines)

        records.append(row)

    # Build the frame once from all records instead of growing it row by row.
    df = pd.DataFrame(records, columns=columns)
    df.to_csv(output_file, index=False)

    return df

# Disabled on purpose (`if False`): flip the condition to regenerate the
# per-split streamline-count CSV files.
if False:
    categories = list(train_ds_handler.TRACT_LIST.keys())
    for split_tracts, csv_path in (
        (train_tract_list, '/app/code/trainset_streamline_count.csv'),
        (valid_tract_list, '/app/code/validset_streamline_count.csv'),
        (testset_tract_list, '/app/code/testset_streamline_count.csv'),
    ):
        df = generate_streamline_count_dataframe(split_tracts, categories, csv_path)

# Subsampling y generación de grafos

In [20]:
import pathlib2 as pathlib
from nibabel import load
from dipy.io.streamline import load_trk
import numpy as np
import torch
from torch_geometric.data import Data, Batch
import random
from tqdm import tqdm

class Graph_generator:
    """Convert the tractography files of each subject into per-streamline graphs.

    Every streamline becomes one ``Data`` graph: its points are the nodes and
    consecutive points are joined by an edge in both directions. All graphs of
    a subject are collated into a single batch and saved as
    ``<output_dir>/<split>/<subject>.pt``.
    """

    def __init__(self,
                 output_dir:str,
                 ds_handler,
                 max_streamlines:int = 8000):
        """
        Args:
            output_dir: Existing directory under which the per-split folders
                are written.
            ds_handler: Dataset handler; only its ``get_label_from_tract``
                method is used by this class.
            max_streamlines: Maximum number of streamlines kept per tract.
                Larger tracts are randomly subsampled (global ``random`` state)
                down to this amount. ``None`` disables subsampling.

        Raises:
            FileNotFoundError: If ``output_dir`` does not exist.
        """
        output_path = pathlib.Path(output_dir)
        if not output_path.exists():
            raise FileNotFoundError(f"El directorio {output_dir} no existe.")

        self.output_dir = output_path
        self.ds_handler = ds_handler
        self.max_streamlines = max_streamlines

    @staticmethod
    def _chain_edge_index(num_nodes:int) -> torch.Tensor:
        """Return the bidirectional edge index of a chain of ``num_nodes`` nodes.

        The result always has shape ``(2, 2 * (num_nodes - 1))`` (``(2, 0)`` for
        fewer than two nodes): all forward edges ``i -> i+1`` first, followed
        by all backward edges ``i+1 -> i``.
        """
        num_links = max(num_nodes - 1, 0)
        start = torch.arange(num_links, dtype=torch.long)
        end = start + 1
        return torch.stack((torch.cat((start, end)), torch.cat((end, start))))

    def _subsample(self, streamlines):
        """Randomly keep at most ``self.max_streamlines`` streamlines."""
        limit = self.max_streamlines
        if limit is None or len(streamlines) <= limit:
            return streamlines
        kept_positions = random.sample(range(len(streamlines)), limit)
        return [streamlines[i] for i in kept_positions]

    def generate_graphs_from_subject(self, subject_dict:dict) -> "Batch":
        """
        Build one graph per streamline of a subject and save them as one file.

        Args:
            subject_dict: Dict with the keys ``"subject"`` (subject identifier,
                used as the output file name), ``"tracts"`` (paths to the
                subject's ``.trk`` files) and ``"subject_split"`` (name of the
                split subfolder the file is written to).

        Returns:
            The batch holding every generated graph of the subject; the same
            object is saved to ``<output_dir>/<split>/<subject>.pt``.
        """
        subject_id = subject_dict["subject"]
        tracts = subject_dict["tracts"]
        split = subject_dict["subject_split"]

        subject_graphs = []
        for tract in tracts:
            # The tract name (file stem up to the first dot) selects the class label.
            text_label = tract.stem.split(".")[0]
            label = self.ds_handler.get_label_from_tract(text_label)
            label = torch.tensor(label, dtype = torch.long)

            tractogram = load_trk(str(tract), 'same', bbox_valid_check=False)
            streamlines = self._subsample(tractogram.streamlines)

            for streamline in streamlines:
                # Node features are the streamline points, cast to float32.
                nodes = torch.from_numpy(streamline).float()
                graph = Data(x = nodes,
                             edge_index = self._chain_edge_index(nodes.size(0)),
                             y = label)
                subject_graphs.append(graph)

        # Collate the individual graphs into a single batch object.
        subject_graphs = Batch.from_data_list(subject_graphs)

        # Create the split folder itself (not just output_dir) before saving,
        # otherwise torch.save fails when the split folder is missing.
        split_dir = self.output_dir.joinpath(split)
        split_dir.mkdir(parents=True, exist_ok=True)
        output_path = str(split_dir.joinpath(f"{subject_id}.pt"))
        torch.save(subject_graphs, output_path)

        return subject_graphs

    def generate_graphs_from_subjects(self, subjects_list:list) -> None:
        """
        Generate and save the graphs of every subject in ``subjects_list``,
        showing a progress bar.
        """
        for subject in tqdm(subjects_list):
            self.generate_graphs_from_subject(subject)

In [21]:
# Generate and save the graphs of the three splits, in train/valid/test order,
# each with a generator bound to that split's handler.
for split_handler, split_subjects in (
    (train_ds_handler, train_tract_list),
    (valid_ds_handler, valid_tract_list),
    (testset_ds_handler, testset_tract_list),
):
    split_generator = Graph_generator(output_dir='/app/dataset/HCP_105_processed', ds_handler=split_handler)
    split_generator.generate_graphs_from_subjects(split_subjects)

100%|██████████| 63/63 [30:59<00:00, 29.51s/it]
100%|██████████| 21/21 [10:12<00:00, 29.16s/it]
100%|██████████| 21/21 [10:25<00:00, 29.78s/it]


# Generación del dataset

In [None]:
import os.path as osp
from os import listdir

class MaxMinNormalization(BaseTransform):
    """Min-max scale the node features ``data.x`` with fixed per-column bounds.

    Every feature column is mapped as ``(x - min) / (max - min)``, so values
    inside ``[min, max]`` land in ``[0, 1]``.
    """

    def __init__(self,
                 max_values = (75, 82.5, 97.5),
                 min_values = (-77, -120.5, -81.5)):
        """
        Args:
            max_values: Upper bound of every feature column.
            min_values: Lower bound of every feature column.

        The defaults are slightly widened versions of these values noted by
        the original author (presumably the extremes observed in the dataset;
        confirm before relying on them for other data):
            max: [74.99879455566406, 82.36431884765625, 97.47947692871094]
            min: [-76.92510986328125, -120.4773941040039, -81.27867126464844]

        Raises:
            ValueError: If any column has ``max == min``, which would make the
                scaling divide by zero.
        """
        self.max_values = torch.tensor(max_values, dtype=torch.float)
        self.min_values = torch.tensor(min_values, dtype=torch.float)

        # A zero-width range would silently fill the features with inf/NaN.
        if bool(((self.max_values - self.min_values) == 0).any()):
            raise ValueError("max_values and min_values must differ in every column.")

    def __call__(self, data: Data) -> Data:
        """
        Scale ``data.x`` and return the same ``data`` object.

        ``data.x`` is replaced on the object that was passed in; the bounds are
        broadcast against its last dimension.
        """
        # Move the bounds next to the features so data on another device
        # does not trigger a device-mismatch error.
        device = data.x.device
        lower = self.min_values.to(device)
        upper = self.max_values.to(device)

        data.x = (data.x - lower) / (upper - lower)
        return data


class FiberGraphDataset(Dataset):
    """Dataset with one item per subject file (``*.pt``) found directly in ``root``.

    Each file is expected to hold the batch of streamline graphs of one
    subject. An item is a random, class-balanced subset of those graphs,
    collated into a single batch.
    """

    def __init__(self,
                 root,
                 transform = None,
                 pre_transform = None,
                 samples_per_class = 1500):
        """
        Args:
            root: Folder that directly contains the per-subject ``.pt`` files.
            transform: Optional transform; the base class applies it to every
                item returned through indexing.
            pre_transform: Forwarded to the base class unchanged.
            samples_per_class: Maximum number of graphs kept per class label
                in every item.
        """
        # Stored first so it is available no matter what the base constructor triggers.
        self.samples_per_class = samples_per_class
        super().__init__(root, transform, pre_transform)

    @property
    def processed_dir(self):
        """Folder holding the processed files: the dataset root itself."""
        return osp.join(self.root)

    @property
    def processed_file_names(self):
        """Absolute paths of the ``.pt`` files in ``processed_dir``, sorted.

        Sorting keeps the index -> subject mapping stable, because the order
        returned by ``listdir`` is arbitrary.
        """
        folder = self.processed_dir
        return sorted(
            osp.join(folder, name)
            for name in listdir(folder)
            if name.endswith('.pt')
        )

    def len(self):
        """Number of subject files in the dataset."""
        return len(self.processed_file_names)

    def get(self, idx):
        """Load subject ``idx`` and return a class-balanced batch of its graphs.

        The transform is intentionally NOT applied here: ``Dataset.__getitem__``
        already applies ``self.transform`` to the value returned by ``get``, so
        applying it here as well would transform every item twice.
        """
        subject_path = self.processed_file_names[idx]
        # NOTE: torch.load unpickles the file; only load files produced by a
        # trusted pipeline.
        data_batch = torch.load(subject_path)

        graphs = data_batch.to_data_list()
        selected_indices = self.sample_graphs_by_class(graphs,
                                                       num_samples = self.samples_per_class)

        # Re-collate the chosen graphs so the item is a batch again; transforms
        # that read ``data.x`` can then be applied to it directly.
        return Batch.from_data_list([graphs[i] for i in selected_indices])

    def sample_graphs_by_class(self, graphs, num_samples):
        """Pick at most ``num_samples`` graph positions per class label.

        Args:
            graphs: Sequence of graphs whose ``y`` holds a single class label.
            num_samples: Maximum number of graphs to keep for each label.

        Returns:
            List of positions into ``graphs``, grouped by ascending label.
            Classes with more than ``num_samples`` graphs are randomly
            subsampled (global ``random`` state); smaller classes are kept
            whole, in their original order.
        """
        # Group graph positions by label; labels are discovered from the data
        # rather than assumed to lie in a fixed range.
        positions_by_label = {}
        for position, graph in enumerate(graphs):
            positions_by_label.setdefault(graph.y.item(), []).append(position)

        sampled_graphs_idx = []
        for label in sorted(positions_by_label):
            members = positions_by_label[label]
            # Only sample when the class is large enough; random.sample raises
            # if asked for more elements than the population holds.
            if len(members) > num_samples:
                members = random.sample(members, num_samples)
            sampled_graphs_idx.extend(members)

        return sampled_graphs_idx