# Dataset de torch

In [1]:
from torch_geometric.data import Dataset
import os
import torch
from torch_geometric.data import Batch, Data
import random
from torch_geometric.transforms import Compose
from transformers import AutoTokenizer
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch_geometric.data import Batch as GeoBatch
from torch.nn.utils.rnn import pad_sequence
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform


class MaxMinNormalization(BaseTransform):
    def __init__(self, max_values=None, min_values=None):
        """
        Initialize the normalization transform with optional max and min values.
        If not provided, they should be computed from the dataset.
        """
        self.max_values = max_values if max_values is not None else torch.tensor([76.03170776367188, 77.9359130859375, 88.72427368164062], dtype=torch.float)
        self.min_values = min_values if min_values is not None else torch.tensor([-73.90082550048828, -112.23554992675781, -79.38320922851562], dtype=torch.float)

    def __call__(self, data: Data) -> Data:
        """
        Apply min-max normalization to the node features.
        """
        data.x = (data.x - self.min_values) / (self.max_values - self.min_values)
        return data
    
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
import random


class MyLazyDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyLazyDataset, self).__init__(root, transform, pre_transform)
        self.transform = Compose([MaxMinNormalization()])

    @property
    def processed_file_names(self):
        # De manera similar, lista los archivos en el directorio 'processed'
        return os.listdir(os.path.join(self.root, 'processed'))
    
    @property
    def raw_file_names(self):
        # Lista los archivos en el directorio 'raw'
        return os.listdir(os.path.join(self.root, 'raw'))

    def len(self):
        return len(self.processed_file_names)
    
    def get(self, idx):
        
        subject = self.processed_file_names[idx]# Seleccionar un sujeto
        graphs = torch.load(os.path.join(self.processed_dir, subject))
        if self.transform:
            graphs = self.transform(graphs)

        return graphs


# Uso de tu clase de conjunto de datos
dataset = MyLazyDataset(root=r'C:\Users\pablo\GitHub\tfm_prg\tractoinferno_graphs\testset')

ModuleNotFoundError: No module named 'transformers'

In [9]:
first_element = dataset[0] # Esto carga el primer ejemplo de tu conjunto de datos

Cargando sujeto...
Sujeto cargado
Aplicando transformaciones...
Transformaciones aplicadas


In [8]:
from torch.utils.data import DataLoader
from torch_geometric.data import Batch as GeoBatch
from torch.nn.utils.rnn import pad_sequence
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform

def collate_function(batch):
    """Funcion para el DataLoader"""
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    TRACT_LIST = {
        'AF_L': {'id': 0, 'tract': 'arcuate fasciculus', 'side' : 'left', 'type': 'association'},
        'AF_R': {'id': 1, 'tract': 'arcuate fasciculus','side' : 'right', 'type': 'association'},
        'CC_Fr_1': {'id': 2, 'tract': 'corpus callosum, frontal lobe', 'side' : 'most anterior part of the frontal lobe', 'type': 'commissural'},
        'CC_Fr_2': {'id': 3, 'tract': 'corpus callosum, frontal lobe', 'side' : 'most posterior part of the frontal lobe','type': 'commissural'},
        'CC_Oc': {'id': 4, 'tract': 'corpus callosum, occipital lobe', 'side' : 'central', 'type': 'commissural'},
        'CC_Pa': {'id': 5, 'tract': 'corpus callosum, parietal lobe', 'side' : 'central', 'type': 'commissural'},
        'CC_Pr_Po': {'id': 6, 'tract': 'corpus callosum, pre/post central gyri', 'side' : 'central', 'type': 'commissural'},
        'CG_L': {'id': 7, 'tract': 'cingulum', 'side' : 'left', 'type': 'association'},
        'CG_R': {'id': 8, 'tract': 'cingulum', 'side' : 'right', 'type': 'association'},
        'FAT_L': {'id': 9, 'tract': 'frontal aslant tract', 'side' : 'left', 'type': 'association'},
        'FAT_R': {'id': 10, 'tract': 'frontal aslant tract', 'side' : 'right', 'type': 'association'},
        'FPT_L': {'id': 11, 'tract': 'fronto-pontine tract', 'side' : 'left', 'type': 'association'},
        'FPT_R': {'id': 12, 'tract': 'fronto-pontine tract', 'side' : 'right', 'type': 'association'},
        'FX_L': {'id': 13, 'tract': 'fornix', 'side' : 'left', 'type': 'commissural'},
        'FX_R': {'id': 14, 'tract': 'fornix', 'side' : 'right', 'type': 'commissural'},
        'IFOF_L': {'id': 15, 'tract': 'inferior fronto-occipital fasciculus', 'side' : 'left', 'type': 'association'},
        'IFOF_R': {'id': 16, 'tract': 'inferior fronto-occipital fasciculus', 'side' : 'right', 'type': 'association'},
        'ILF_L': {'id': 17, 'tract': 'inferior longitudinal fasciculus', 'side' : 'left', 'type': 'association'},
        'ILF_R': {'id': 18, 'tract': 'inferior longitudinal fasciculus', 'side' : 'right', 'type': 'association'},
        'MCP': {'id': 19, 'tract': 'middle cerebellar peduncle', 'side' : 'central', 'type': 'commissural'},
        'MdLF_L': {'id': 20, 'tract': 'middle longitudinal fasciculus', 'side' : 'left', 'type': 'association'},
        'MdLF_R': {'id': 21, 'tract': 'middle longitudinal fasciculus', 'side' : 'right', 'type': 'association'},
        'OR_ML_L': {'id': 22, 'tract': 'optic radiation, Meyer loop', 'side' : 'left', 'type': 'projection'},
        'OR_ML_R': {'id': 23, 'tract': 'optic radiation, Meyer loop', 'side' : 'right', 'type': 'projection'},
        'POPT_L': {'id': 24, 'tract': 'pontine crossing tract', 'side' : 'left', 'type': 'commissural'},
        'POPT_R': {'id': 25, 'tract': 'pontine crossing tract', 'side' : 'right', 'type': 'commissural'},
        'PYT_L': {'id': 26, 'tract': 'pyramidal tract', 'side' : 'left', 'type': 'projection'},
        'PYT_R': {'id': 27, 'tract': 'pyramidal tract', 'side' : 'right', 'type': 'projection'},
        'SLF_L': {'id': 28, 'tract': 'superior longitudinal fasciculus', 'side' : 'left', 'type': 'association'},
        'SLF_R': {'id': 29, 'tract': 'superior longitudinal fasciculus', 'side' : 'right', 'type': 'association'},
        'UF_L': {'id': 30, 'tract': 'uncinate fasciculus', 'side' : 'left', 'type': 'association'},
        'UF_R': {'id': 31, 'tract': 'uncinate fasciculus', 'side' : 'right', 'type': 'association'}
    }

    LABELS = {value["id"]: key for key, value in TRACT_LIST.items()}# Diccionario id -> Etiqueta
    caption_templates = [
            "A {type} fiber",
            "A {type} fiber on the {side} side",
            "{type} fiber on the {side} side",
            "A {type} fiber of the {tract}",
            "{type} fiber of the {tract}",
            "A {type} fiber of the {tract} on the {side} side",
            "{type} fiber of the {tract} on the {side} side",
            "{side} side",
            "{tract} tract",
            "{type} fiber",
            "The {type} fiber located in the {tract} tract",
            "This is a {type} fiber found on the {side} hemisphere",
            "Detailed view of a {type} fiber within the {tract}",
            "Observation of the {type} fiber, prominently on the {side} side",
            "The {tract} tract's remarkable {type} fiber",
            "Characteristics of a {type} fiber in the {tract} region",
            "Notable {type} fiber on the {side} hemisphere of the {tract}",
            "Insight into the {type} fiber's structure on the {side} side",
            "Exploring the complexity of the {type} fiber in the {tract}",
            "The anatomy of a {type} fiber on the {side} hemisphere",
            "The {tract} tract featuring a {type} fiber",
            "A comprehensive look at the {type} fiber, {side} orientation",
            "A closer look at the {type} fiber's path in the {tract}",
            "Unveiling the {type} fiber's role in the {tract} tract",
            "Decoding the structure of the {type} fiber on the {side}",
            "Highlighting the {type} fiber's significance in the {tract}",
            "The {type} fiber: A journey through the {tract} on the {side}",
            "A deep dive into the {type} fiber's dynamics in the {tract}",
            "The {type} fiber's contribution to {tract} tract functionality",
            "Mapping the {type} fiber's trajectory in the {tract} on the {side} side",
            "Navigating the intricate pathways of the {type} fiber within the {tract}",
            "The interplay of {type} fibers across the {side} hemisphere",
            "Traversing the {tract} with a {type} fiber",
            "The pivotal role of the {type} fiber in connecting the {tract}",
            "Showcasing the unique texture of {type} fibers in the {tract}",
            "Zooming in on the {type} fiber's impact on the {side} hemisphere",
            "The {type} fiber in the {tract}",
            "The {type} fiber as a conduit in the {tract} on the {side} side",
            "The {type} fiber's architectural marvel within the {tract}",
            "A journey alongside the {type} fiber through the {tract}",
            "The harmonious structure of the {type} fiber in the {tract}",
            "Unraveling the secrets of the {type} fiber in the {tract} tract",
            "The {type} fiber: A key player in {tract} dynamics",
            "Envisioning the {type} fiber's pathway in the {tract}",
            "The strategic placement of the {type} fiber in the {tract}",
            "Illuminating the {type} fiber's route through the {tract}",
            "The {type} fiber: An essential bridge within the {tract}",
            "Deciphering the network of {type} fibers in the {tract}",
            "Exploring the synergy between {type} fibers and the {tract}",
            "The {type} fiber's vital link in the neural network of the {tract}"
        ]

    # Extraer los labels de todos los grafos en el lote
    labels = [graph.y.item() for graph in batch]  # Asumiendo que `y` es el tensor de labels
    
    # Recuperar y tokenizar todos los captions necesarios en una sola llamada
    captions = [random.choice(caption_templates).format(**TRACT_LIST[LABELS[label]]) for label in labels]
    tokenized_texts_batch = tokenizer(captions, padding=True, truncation=True, return_tensors="pt")
    
    # Devolver el lote procesado. No es necesario devolver tokenized_texts_batch por separado
    return batch, tokenized_texts_batch # grafos, {'input_ids': padded_input_ids, 'attention_mask': padded_attention_masks}





for subject in dataset:
    for graph_batch, text_batch in DataLoader(subject, batch_size=512, shuffle=True, num_workers=0, drop_last=True, collate_fn=collate_function):
        print(text_batch)
        break
    break



{'input_ids': tensor([[  101,  1996,  6143,  ...,     0,     0,     0],
        [  101,  6583,  5737,  ...,     0,     0,     0],
        [  101, 11131,  1996,  ...,     0,     0,     0],
        ...,
        [  101,  1037,  4012,  ...,     0,     0,     0],
        [  101,  1996,  2523,  ...,     0,     0,     0],
        [  101,  6459,  1997,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}


In [6]:
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
from torch_geometric.loader import DataLoader

class MaxMinNormalization(BaseTransform):
    def __init__(self, max_values=None, min_values=None):
        """
        Initialize the normalization transform with optional max and min values.
        If not provided, they should be computed from the dataset.
        """
        self.max_values = max_values if max_values is not None else torch.tensor([76.03170776367188, 77.9359130859375, 88.72427368164062], dtype=torch.float)
        self.min_values = min_values if min_values is not None else torch.tensor([-73.90082550048828, -112.23554992675781, -79.38320922851562], dtype=torch.float)

    def __call__(self, data: Data) -> Data:
        """
        Apply min-max normalization to the node features.
        """
        data.x = (data.x - self.min_values) / (self.max_values - self.min_values)
        return data


    

    

#Entrenar el modelo utilizando la gpu y tensorboard

# print(f"Using device: {device}")
# model.to(device)
# for epoch in range(2):
#     for i, (graph_data, text_data) in enumerate(dataloader):
#         graph_data = graph_data.to(device)
#         text_data = {key: val.to(device) for key, val in text_data.items()}
#         loss = model(graph_data, text_data)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         writer.add_scalar('Loss/train', loss, epoch * len(dataloader) + i)
#         print(f"\r Epoch {epoch}, Iteration {i}, Loss {loss}")
# writer.flush()
# writer.close()

# Crear un transform customizado para generar captions
# transform = Compose([MaxMinNormalization()])

# Crear un DataLoader que aplica el transform customizado
for subject in dataset:
    for batch in subject:
        print(batch)
        dataloader = DataLoader(batch, batch_size=2048, shuffle=True, collate_fn=)#, transform=transform
        for batch in dataloader:
            print(batch)
            
    break
            

Cargando sujeto...
Sujeto cargado
Aplicando transformaciones...
Transformaciones aplicadas
('x', tensor([[0.4940, 0.5919, 0.4754],
        [0.4942, 0.5919, 0.4755],
        [0.4943, 0.5919, 0.4755],
        ...,
        [0.4981, 0.5919, 0.4762],
        [0.4982, 0.5919, 0.4762],
        [0.4984, 0.5918, 0.4762]]))
['x', tensor([[0.4940, 0.5919, 0.4754],
        [0.4942, 0.5919, 0.4755],
        [0.4943, 0.5919, 0.4755],
        ...,
        [0.4981, 0.5919, 0.4762],
        [0.4982, 0.5919, 0.4762],
        [0.4984, 0.5918, 0.4762]])]
('edge_index', tensor([[       0,        1,        2,  ..., 24849858, 24849859, 24849860],
        [       1,        2,        3,  ..., 24849857, 24849858, 24849859]]))


TypeError: expected Tensor as element 1 in argument 0, but got str

In [None]:
# Formato del dataset

# grafos (lista de Data), captions (lista de textos)


# Codigo para generar un batch 




In [2]:
from torch_geometric.loader import DataListLoader, DataLoader

# def custom_collate_fn(batch):
#     graphs = [item['graph'] for item in batch]
#     input_ids = [item['text']['input_ids'].squeeze(0) for item in batch]
#     attention_masks = [item['text']['attention_mask'].squeeze(0) for item in batch]
#     padded_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
#     padded_attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0)
#     batched_graphs = GeoBatch.from_data_list(graphs)
#     return batched_graphs, {'input_ids': padded_input_ids, 'attention_mask': padded_attention_masks}



    
    

Cargando sujeto...
Sujeto cargado
Aplicando normalización...
Normalizacion aplicada
Generando captions...
Captions generadas


AttributeError: 'str' object has no attribute 'stores'

In [None]:
import torch
from torch_geometric.nn import GCNConv, BatchNorm
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool
from tqdm.notebook import tqdm

class GraphClassifier(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GraphClassifier, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, 32)
        self.conv3 = GCNConv(32, 256)
        self.conv4 = GCNConv(256, 512)
        self.fc1 = torch.nn.Linear(512, 256)
        self.fc2 = torch.nn.Linear(256, 64)
        self.fc3 = torch.nn.Linear(64, 32)
        self.fc = torch.nn.Linear(32, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = BatchNorm(x.size()[1])(x)
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = BatchNorm(x.size()[1])(x)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = BatchNorm(x.size()[1])(x)
        x = F.relu(self.conv3(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = BatchNorm(x.size()[1])(x)
        x = F.relu(self.conv4(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = BatchNorm(x.size()[1])(x)
        x = global_mean_pool(x, batch)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = BatchNorm(x.size()[1])(x)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, training=self.training)
        x = BatchNorm(x.size()[1])(x)
        x = F.relu(self.fc3(x))
        x = F.dropout(x, training=self.training)
        x = BatchNorm(x.size()[1])(x)
        return F.log_softmax(x, dim=1)



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = GraphClassifier(3, 32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(1):
    for subject in dataset:
        for batch in tqdm(DataLoader(subject, batch_size=128, shuffle=True)):
            data = batch.to(device)
            optimizer.zero_grad()
            out = model(data)
            loss = F.nll_loss(out, data.y)
            # print(loss.item())
            loss.backward()
            optimizer.step()

        print('Epoch: {:03d}, Loss: {:.5f}'.format(epoch, loss.item()))


In [None]:
# Iterar sobre el conjunto de datos
for data in dataset:
    for batch in DataListLoader(data, batch_size=128, shuffle=True):
        

In [None]:
# dataloader de torch_geometric
from torch_geometric.loader import DataLoader
loader = DataLoader(dataset, batch_size=1, shuffle=True)

for batch in loader:
    print(batch.shape)
    print(batch.num_graphs)

In [None]:
class AddCaptionTransform(BaseTransform):
    def __init__(self, tokenize_data=True):
        """
        Initialize the transform with the tokenizer and the tract list.
        """
        self.tokenize_data = tokenize_data

        if self.tokenize_data:
            self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
        else:
            self.tokenizer = None


        # Your TRACT_LIST and caption_templates can be defined here
        self.TRACT_LIST = {
            'AF_L': {'id': 0, 'tract': 'arcuate fasciculus', 'side' : 'left', 'type': 'association'},
            'AF_R': {'id': 1, 'tract': 'arcuate fasciculus','side' : 'right', 'type': 'association'},
            'CC_Fr_1': {'id': 2, 'tract': 'corpus callosum, frontal lobe', 'side' : 'most anterior part of the frontal lobe', 'type': 'commissural'},
            'CC_Fr_2': {'id': 3, 'tract': 'corpus callosum, frontal lobe', 'side' : 'most posterior part of the frontal lobe','type': 'commissural'},
            'CC_Oc': {'id': 4, 'tract': 'corpus callosum, occipital lobe', 'side' : 'central', 'type': 'commissural'},
            'CC_Pa': {'id': 5, 'tract': 'corpus callosum, parietal lobe', 'side' : 'central', 'type': 'commissural'},
            'CC_Pr_Po': {'id': 6, 'tract': 'corpus callosum, pre/post central gyri', 'side' : 'central', 'type': 'commissural'},
            'CG_L': {'id': 7, 'tract': 'cingulum', 'side' : 'left', 'type': 'association'},
            'CG_R': {'id': 8, 'tract': 'cingulum', 'side' : 'right', 'type': 'association'},
            'FAT_L': {'id': 9, 'tract': 'frontal aslant tract', 'side' : 'left', 'type': 'association'},
            'FAT_R': {'id': 10, 'tract': 'frontal aslant tract', 'side' : 'right', 'type': 'association'},
            'FPT_L': {'id': 11, 'tract': 'fronto-pontine tract', 'side' : 'left', 'type': 'association'},
            'FPT_R': {'id': 12, 'tract': 'fronto-pontine tract', 'side' : 'right', 'type': 'association'},
            'FX_L': {'id': 13, 'tract': 'fornix', 'side' : 'left', 'type': 'commissural'},
            'FX_R': {'id': 14, 'tract': 'fornix', 'side' : 'right', 'type': 'commissural'},
            'IFOF_L': {'id': 15, 'tract': 'inferior fronto-occipital fasciculus', 'side' : 'left', 'type': 'association'},
            'IFOF_R': {'id': 16, 'tract': 'inferior fronto-occipital fasciculus', 'side' : 'right', 'type': 'association'},
            'ILF_L': {'id': 17, 'tract': 'inferior longitudinal fasciculus', 'side' : 'left', 'type': 'association'},
            'ILF_R': {'id': 18, 'tract': 'inferior longitudinal fasciculus', 'side' : 'right', 'type': 'association'},
            'MCP': {'id': 19, 'tract': 'middle cerebellar peduncle', 'side' : 'central', 'type': 'commissural'},
            'MdLF_L': {'id': 20, 'tract': 'middle longitudinal fasciculus', 'side' : 'left', 'type': 'association'},
            'MdLF_R': {'id': 21, 'tract': 'middle longitudinal fasciculus', 'side' : 'right', 'type': 'association'},
            'OR_ML_L': {'id': 22, 'tract': 'optic radiation, Meyer loop', 'side' : 'left', 'type': 'projection'},
            'OR_ML_R': {'id': 23, 'tract': 'optic radiation, Meyer loop', 'side' : 'right', 'type': 'projection'},
            'POPT_L': {'id': 24, 'tract': 'pontine crossing tract', 'side' : 'left', 'type': 'commissural'},
            'POPT_R': {'id': 25, 'tract': 'pontine crossing tract', 'side' : 'right', 'type': 'commissural'},
            'PYT_L': {'id': 26, 'tract': 'pyramidal tract', 'side' : 'left', 'type': 'projection'},
            'PYT_R': {'id': 27, 'tract': 'pyramidal tract', 'side' : 'right', 'type': 'projection'},
            'SLF_L': {'id': 28, 'tract': 'superior longitudinal fasciculus', 'side' : 'left', 'type': 'association'},
            'SLF_R': {'id': 29, 'tract': 'superior longitudinal fasciculus', 'side' : 'right', 'type': 'association'},
            'UF_L': {'id': 30, 'tract': 'uncinate fasciculus', 'side' : 'left', 'type': 'association'},
            'UF_R': {'id': 31, 'tract': 'uncinate fasciculus', 'side' : 'right', 'type': 'association'}
        }

        self.LABELS = {value["id"]: key for key, value in self.TRACT_LIST.items()}# Diccionario id -> Etiqueta

        self.caption_templates = [
            "A {type} fiber",
            "A {type} fiber on the {side} side",
            "{type} fiber on the {side} side",
            "A {type} fiber of the {tract}",
            "{type} fiber of the {tract}",
            "A {type} fiber of the {tract} on the {side} side",
            "{type} fiber of the {tract} on the {side} side",
            "{side} side",
            "{tract} tract",
            "{type} fiber",
            "The {type} fiber located in the {tract} tract",
            "This is a {type} fiber found on the {side} hemisphere",
            "Detailed view of a {type} fiber within the {tract}",
            "Observation of the {type} fiber, prominently on the {side} side",
            "The {tract} tract's remarkable {type} fiber",
            "Characteristics of a {type} fiber in the {tract} region",
            "Notable {type} fiber on the {side} hemisphere of the {tract}",
            "Insight into the {type} fiber's structure on the {side} side",
            "Exploring the complexity of the {type} fiber in the {tract}",
            "The anatomy of a {type} fiber on the {side} hemisphere",
            "The {tract} tract featuring a {type} fiber",
            "A comprehensive look at the {type} fiber, {side} orientation",
            "A closer look at the {type} fiber's path in the {tract}",
            "Unveiling the {type} fiber's role in the {tract} tract",
            "Decoding the structure of the {type} fiber on the {side}",
            "Highlighting the {type} fiber's significance in the {tract}",
            "The {type} fiber: A journey through the {tract} on the {side}",
            "A deep dive into the {type} fiber's dynamics in the {tract}",
            "The {type} fiber's contribution to {tract} tract functionality",
            "Mapping the {type} fiber's trajectory in the {tract} on the {side} side",
            "Navigating the intricate pathways of the {type} fiber within the {tract}",
            "The interplay of {type} fibers across the {side} hemisphere",
            "Traversing the {tract} with a {type} fiber",
            "The pivotal role of the {type} fiber in connecting the {tract}",
            "Showcasing the unique texture of {type} fibers in the {tract}",
            "Zooming in on the {type} fiber's impact on the {side} hemisphere",
            "The {type} fiber in the {tract}",
            "The {type} fiber as a conduit in the {tract} on the {side} side",
            "The {type} fiber's architectural marvel within the {tract}",
            "A journey alongside the {type} fiber through the {tract}",
            "The harmonious structure of the {type} fiber in the {tract}",
            "Unraveling the secrets of the {type} fiber in the {tract} tract",
            "The {type} fiber: A key player in {tract} dynamics",
            "Envisioning the {type} fiber's pathway in the {tract}",
            "The strategic placement of the {type} fiber in the {tract}",
            "Illuminating the {type} fiber's route through the {tract}",
            "The {type} fiber: An essential bridge within the {tract}",
            "Deciphering the network of {type} fibers in the {tract}",
            "Exploring the synergy between {type} fibers and the {tract}",
            "The {type} fiber's vital link in the neural network of the {tract}"
        ]


    def get_caption(self, data: Data)-> Data:
        # print(data.y) -> tensor([ 0,  0,  0,  ..., 29, 29, 29])

        captions = []

        for label in data.y:
            info = self.TRACT_LIST[self.LABELS[label.item()]]
            caption = random.choice(self.caption_templates).format(**info)
            captions.append(caption)

        if self.tokenize_data:
            return self.tokenizer(captions, return_tensors="pt", truncation=True, padding=True)
        else:
            return captions

    def __call__(self, data: Data) -> Data:
        """
        Add a caption to the data object.
        """
        data.caption = self.get_caption(data)
        return data