## Encoders

In [None]:

    

# class GCNEncoder(nn.Module):
#     def __init__(self, in_channels, hidden_dim, out_channels, dropout, n_hidden_blocks):
#         super(GCNEncoder, self).__init__()
#         self.input_block = GraphConvBlock(in_channels, hidden_dim, dropout)
#         self.hidden_blocks = ModuleList(
#             [
#                 GraphConvBlock(hidden_dim, hidden_dim, dropout) 
#                 for _ in range(n_hidden_blocks - 1)])
#         self.output_block = GraphConvBlock(hidden_dim, out_channels, dropout)

#     def forward(self, data):
#         x, edge_index, batch = data.x, data.edge_index, data.batch
#         x = self.input_block(x, edge_index)
#         for layer in self.hidden_blocks:
#             x = layer(x, edge_index)
#         x = self.output_block(x, edge_index)
#         return global_mean_pool(x, batch) # (batch_size, out_channels)

In [3]:
import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv, BatchNorm, LayerNorm, global_mean_pool
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv,GraphConv, global_mean_pool, BatchNorm
from torch.nn import ModuleList
import torch.nn.functional as F

class GraphConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.0):
        super(GraphConvBlock, self).__init__()
        self.conv = GCNConv(in_channels, out_channels)
        self.bn = BatchNorm(out_channels)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None

    def forward(self, x, edge_index):
        # Apply graph convolution
        x = self.conv(x, edge_index)
        # Apply batch normalization
        x = self.bn(x)
        # Apply activation function
        x = self.activation(x)
        # Apply dropout if defined
        if self.dropout:
            x = self.dropout(x)
        return x

class GCNEncoder_v2(nn.Module):
    def __init__(self, in_channels, hidden_dim, out_channels, dropout=0.0, n_hidden_blocks=2):
        super(GCNEncoder_v2, self).__init__()
        self.input_block = GraphConvBlock(in_channels, hidden_dim, dropout)
        self.hidden_blocks = self._make_hidden_layers(hidden_dim, dropout, n_hidden_blocks)
        self.output_block = GraphConvBlock(hidden_dim, out_channels, dropout)
        self.attention_block = GATConv(hidden_dim, hidden_dim, heads=4, concat=False)
        self.layer_norm = LayerNorm(out_channels)

    def _make_hidden_layers(self, hidden_dim, dropout, n_hidden_blocks):
        layers = []
        for _ in range(n_hidden_blocks - 1):
            layers.append(GraphConvBlock(hidden_dim, hidden_dim, dropout))
        return nn.ModuleList(layers)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.input_block(x, edge_index)
        for layer in self.hidden_blocks:
            x = layer(x, edge_index)
        x = self.attention_block(x, edge_index)
        x = self.output_block(x, edge_index)
        x = self.layer_norm(x)
        return global_mean_pool(x, batch)  # (batch_size, out_channels)

class GraphConvBlock_v2(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.0):
        super(GraphConvBlock_v2, self).__init__()
        self.conv = GCNConv(in_channels, out_channels)
        self.bn = BatchNorm(out_channels)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None
        self.residual = (in_channels == out_channels)
        
    def forward(self, x, edge_index):
        identity = x
        x = self.conv(x, edge_index)
        x = self.bn(x)
        x = self.activation(x)
        if self.dropout:
            x = self.dropout(x)
        if self.residual:
            x += identity
        return x


class ProjectionHead_v2(nn.Module):
    def __init__(self, in_features, projection_dim):
        super(ProjectionHead_v2, self).__init__()
        self.projection = nn.Linear(in_features, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        x = self.projection(x)
        x = self.gelu(x)
        x = self.fc(x)
        x = self.layer_norm(x)
        return x
    
class ClassifierHead_v2(nn.Module):
    def __init__(self, in_features, num_classes):
        super(ClassifierHead_v2, self).__init__()
        self.fc = nn.Linear(in_features, num_classes)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.fc(x)
        x = self.softmax(x)
        return x

class SiameseGraphNetwork(nn.Module):
    def __init__(self, encoder, projection_head, classifier, normalize=True):
        super(SiameseGraphNetwork, self).__init__()
        self.encoder = encoder
        self.projection_head = projection_head
        self.classifier = classifier
        self.normalize = normalize

    def forward(self, graph):
        x_1 = self.encoder(graph)
        x_1 = self.projection_head(x_1)
        
        if self.normalize:
            x1_norm = F.normalize(x_1, p=2, dim=1)

        c1 = self.classifier(x1_norm)
        return x1_norm, c1

# Dimensiones
in_channels = 3  # Número de características de entrada
hidden_dim = 128  # Dimensión de las capas ocultas
out_channels = 64  # Dimensión de la salida del encoder
projection_dim = 128  # Dimensión del embedding proyectado
num_classes = 10  # Número de clases para la clasificación

# Arquitectura
n_hidden_blocks = 3  # Número de bloques ocultos

# Dropout y Regularización
dropout = 0.2  # Tasa de dropout


# Configuración de la Pérdida
# triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
model = SiameseGraphNetwork(
    encoder = GCNEncoder_v2(in_channels, hidden_dim, out_channels, dropout, n_hidden_blocks),
    projection_head = ProjectionHead_v2(out_channels, projection_dim),
    classifier = ClassifierHead_v2(projection_dim, num_classes)
)

print(model)
# Guardar el modelo
torch.save(model.state_dict(), '/app/pruebas/prueba_model_v2.pth')

SiameseGraphNetwork(
  (encoder): GCNEncoder_v2(
    (input_block): GraphConvBlock(
      (conv): GCNConv(3, 128)
      (bn): BatchNorm(128)
      (activation): GELU(approximate='none')
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (hidden_blocks): ModuleList(
      (0-1): 2 x GraphConvBlock(
        (conv): GCNConv(128, 128)
        (bn): BatchNorm(128)
        (activation): GELU(approximate='none')
        (dropout): Dropout(p=0.2, inplace=False)
      )
    )
    (output_block): GraphConvBlock(
      (conv): GCNConv(128, 64)
      (bn): BatchNorm(64)
      (activation): GELU(approximate='none')
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (attention_block): GATConv(128, 128, heads=4)
    (layer_norm): LayerNorm(64, affine=True, mode=graph)
  )
  (projection_head): ProjectionHead_v2(
    (projection): Linear(in_features=64, out_features=128, bias=True)
    (gelu): GELU(approximate='none')
    (fc): Linear(in_features=128, out_features=128, bias=True)
    (layer

In [None]:
import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

exported_program = torch.export.export(MyModule(), torch.randn(5))

torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')

In [17]:
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, GATConv, BatchNorm, LayerNorm, global_mean_pool
import torch.nn as nn
import torch.nn.functional as F

class GraphConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.0):
        super(GraphConvBlock, self).__init__()
        self.conv = GCNConv(in_channels, out_channels)
        self.bn = BatchNorm(out_channels)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None

    def forward(self, x, edge_index):
        x = self.conv(x, edge_index)
        x = self.bn(x)
        x = self.activation(x)
        if self.dropout:
            x = self.dropout(x)
        return x

class GCNEncoder_v2(nn.Module):
    def __init__(self, in_channels, hidden_dim, out_channels, dropout=0.0, n_hidden_blocks=2):
        super(GCNEncoder_v2, self).__init__()
        self.input_block = GraphConvBlock(in_channels, hidden_dim, dropout)
        self.hidden_blocks = self._make_hidden_layers(hidden_dim, dropout, n_hidden_blocks)
        self.output_block = GraphConvBlock(hidden_dim, out_channels, dropout)
        self.attention_block = GATConv(hidden_dim, hidden_dim, heads=4, concat=False)
        self.layer_norm = LayerNorm(out_channels)

    def _make_hidden_layers(self, hidden_dim, dropout, n_hidden_blocks):
        layers = []
        for _ in range(n_hidden_blocks - 1):
            layers.append(GraphConvBlock(hidden_dim, hidden_dim, dropout))
        return nn.ModuleList(layers)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.input_block(x, edge_index)
        for layer in self.hidden_blocks:
            x = layer(x, edge_index)
        x = self.attention_block(x, edge_index)
        x = self.output_block(x, edge_index)
        x = self.layer_norm(x)
        return global_mean_pool(x, batch)  # (batch_size, out_channels)

class ProjectionHead_v2(nn.Module):
    def __init__(self, in_features, projection_dim):
        super(ProjectionHead_v2, self).__init__()
        self.projection = nn.Linear(in_features, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        x = self.projection(x)
        x = self.gelu(x)
        x = self.fc(x)
        x = self.layer_norm(x)
        return x

class ClassifierHead_v2(nn.Module):
    def __init__(self, in_features, num_classes):
        super(ClassifierHead_v2, self).__init__()
        self.fc = nn.Linear(in_features, num_classes)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.fc(x)
        x = self.softmax(x)
        return x

class SiameseGraphNetwork(nn.Module):
    def __init__(self, encoder, projection_head, classifier, normalize=True):
        super(SiameseGraphNetwork, self).__init__()
        self.encoder = encoder
        self.projection_head = projection_head
        self.classifier = classifier
        self.normalize = normalize

    def forward(self, x, edge_index, batch):
        graph = Data(x=x, edge_index=edge_index, batch=batch)
        x_1 = self.encoder(graph)
        x_1 = self.projection_head(x_1)
        
        if self.normalize:
            x1_norm = F.normalize(x_1, p=2, dim=1)
        else:
            x1_norm = x_1

        c1 = self.classifier(x1_norm)
        return x1_norm, c1

# Dimensiones del modelo
in_channels = 3  # Número de características de entrada
hidden_dim = 128  # Dimensión de las capas ocultas
out_channels = 64  # Dimensión de la salida del encoder
projection_dim = 128  # Dimensión del embedding proyectado
num_classes = 10  # Número de clases para la clasificación
dropout = 0.2  # Tasa de dropout
n_hidden_blocks = 3  # Número de bloques ocultos

# Cargar el modelo
model = SiameseGraphNetwork(
    encoder=GCNEncoder_v2(in_channels, hidden_dim, out_channels, dropout, n_hidden_blocks),
    projection_head=ProjectionHead_v2(out_channels, projection_dim),
    classifier=ClassifierHead_v2(projection_dim, num_classes)
)

# Generar un lote de datos aleatorios
def generate_random_graph(num_nodes, num_node_features, num_edges):
    x = torch.randn((num_nodes, num_node_features), dtype=torch.float)
    edge_index = torch.randint(0, num_nodes, (2, num_edges), dtype=torch.long)
    batch = torch.zeros(num_nodes, dtype=torch.long)  # Asumiendo un solo grafo en el lote
    return Data(x=x, edge_index=edge_index, batch=batch)

num_graphs = 5
num_nodes = 10
num_node_features = in_channels
num_edges = 20

graphs = [generate_random_graph(num_nodes, num_node_features, num_edges) for _ in range(num_graphs)]
loader = DataLoader(graphs, batch_size=1, shuffle=True)

# Obtener un batch del DataLoader
batch = next(iter(loader))

# Preparar los datos de entrada para el modelo
inputs = (batch.x, batch.edge_index, batch.batch)

# Exportar el modelo
exported_program = torch.export.export(model, args=inputs)

# Guardar el modelo exportado
torch.export.save(exported_program, 'exported_program.pt2')



Unsupported: call_function UserDefinedClassVariable() [] {'x': TensorVariable(), 'edge_index': TensorVariable(), 'batch': TensorVariable()}

from user code:
   File "/tmp/ipykernel_29230/3296968102.py", line 83, in forward
    graph = Data(x=x, edge_index=edge_index, batch=batch)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


In [13]:
import model_explorer
import torch
from torch_geometric.data import Data, Batch, DataLoader
# Generate some random Data graphs (torch geometric data object) and edges.
def generate_random_graph(num_nodes, num_node_features, num_edges):
    x = torch.randn((num_nodes, num_node_features), dtype=torch.float)
    edge_index = torch.randint(0, num_nodes, (2, num_edges), dtype=torch.long)
    batch = torch.zeros(num_nodes, dtype=torch.long)  # Asumiendo un solo grafo en el lote
    return Data(x=x, edge_index=edge_index, batch=batch)

# Crear un lote de grafos
num_graphs = 4
num_nodes = 10
num_node_features = 3
num_edges = 20
graphs = [generate_random_graph(num_nodes, num_node_features, num_edges) for _ in range(num_graphs)]
batch = Batch.from_data_list(graphs)

# Crear un DataLoader
loader = DataLoader(graphs, batch_size=1, shuffle=True)

batch = next(iter(loader))
inputs = (batch, )





In [14]:
# Load saved model.


# Prepare a PyTorch model and its inputs.
# model = torch.load('/app/pruebas/prueba_model_v2.pth').eval()
ep = torch.export.export(model, args=inputs)

# Visualize.
model_explorer.visualize_pytorch('siamese_graph_network', exported_program=ep, port=8888)

Unsupported: 'inline in skipfiles: Mapping.__contains__ | __contains__ /opt/conda/lib/python3.10/_collections_abc.py, skipped according skipfiles.SKIP_DIRS'

from user code:
   File "/tmp/ipykernel_29230/1115990740.py", line 106, in forward
    x_1 = self.encoder(graph)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_29230/1115990740.py", line 41, in forward
    x, edge_index, batch = data.x, data.edge_index, data.batch
  File "/opt/conda/lib/python3.10/site-packages/torch_geometric/data/data.py", line 950, in x
    return self['x'] if 'x' in self._store else None

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


In [10]:
import model_explorer

model_explorer.visualize('/app/pruebas/prueba_model_v2.pth', port=8888)

Loading extensions...
! Failed to load extension module ".builtin_tflite_flatbuffer_adapter":
/opt/conda/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /opt/conda/lib/python3.10/site-packages/ai_edge_model_explorer_adapter/_pywrap_convert_wrapper.so)

! Failed to load extension module ".builtin_tflite_mlir_adapter":
/opt/conda/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /opt/conda/lib/python3.10/site-packages/ai_edge_model_explorer_adapter/_pywrap_convert_wrapper.so)

! Failed to load extension module ".builtin_tf_mlir_adapter":
/opt/conda/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /opt/conda/lib/python3.10/site-packages/ai_edge_model_explorer_adapter/_pywrap_convert_wrapper.so)

! Failed to load extension module ".builtin_tf_direct_adapter":
/opt/conda/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /opt/conda/lib/python3.10/site-packages/ai_edge_model_explorer_adapter

Loading "original-fs" failed
Error: Cannot find module 'original-fs'
Require stack:
- /root/.vscode-server/bin/b1c0a14de1414fcdaa400695b4db1c0799bc3124/out/server-cli.js
[90m    at Module._resolveFilename (node:internal/modules/cjs/loader:1145:15)[39m
[90m    at Module._load (node:internal/modules/cjs/loader:986:27)[39m
[90m    at Module.require (node:internal/modules/cjs/loader:1233:19)[39m
[90m    at require (node:internal/modules/helpers:179:18)[39m
    at i (/root/.vscode-server/bin/b1c0a14de1414fcdaa400695b4db1c0799bc3124/out/server-cli.js:3:98)
    at r.load (/root/.vscode-server/bin/b1c0a14de1414fcdaa400695b4db1c0799bc3124/out/server-cli.js:2:1637)
    at h.load (/root/.vscode-server/bin/b1c0a14de1414fcdaa400695b4db1c0799bc3124/out/server-cli.js:1:13958)
    at u (/root/.vscode-server/bin/b1c0a14de1414fcdaa400695b4db1c0799bc3124/out/server-cli.js:3:9338)
    at Object.errorback (/root/.vscode-server/bin/b1c0a14de1414fcdaa400695b4db1c0799bc3124/out/server-cli.js:3:9457)
 

Stopping server...


In [None]:
class GCNEncoder(nn.Module):
    def __init__(self, in_channels, hidden_dim, out_channels, dropout, n_hidden_blocks):
        super(GCNEncoder, self).__init__()
        self.input_block = GraphConvBlock(in_channels, hidden_dim, dropout)
        self.hidden_blocks = ModuleList([GraphConvBlock(hidden_dim, hidden_dim, dropout) for _ in range(n_hidden_blocks - 1)])
        self.output_block = GraphConvBlock(hidden_dim, out_channels, dropout)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.input_block(x, edge_index)
        for layer in self.hidden_blocks:
            x = layer(x, edge_index)
        x = self.output_block(x, edge_index)
        return global_mean_pool(x, batch) # (batch_size, out_channels)





In [None]:

#================================================PROJECTION HEAD=====================================================
class ProjectionHead(nn.Module):
    """
    Proyección de las embeddings de texto a un espacio de dimensión reducida.
    """
    def __init__(
        self,
        embedding_dim,# Salida del modelo de lenguaje (768)
        projection_dim, # Dimensión de la proyección (256)
        # dropout=0.1
    ):
        super(ProjectionHead, self).__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        # self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        # x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x


#================================================CLASSIFIER HEADs=====================================================
class ClassifierHead(nn.Module):
    """
    Capa FC con activación softmax para que clasifique la clase.
    """
    def __init__(self, projection_dim, n_classes):
        super(ClassifierHead, self).__init__()
        self.fc = nn.Linear(projection_dim, n_classes)
        self.softmax = nn.LogSoftmax(dim=1)
    
    def forward(self, x):
        x = self.fc(x)
        return self.softmax(x)


#================================================MODEL======================================================
class SiameseGraphNetwork(nn.Module):
    def __init__(self, encoder, projection_head, classifier):
        super(SiameseGraphNetwork, self).__init__()
        self.encoder = encoder
        self.projection_head = projection_head
        self.classifier = classifier

    def forward(self, graph):
        x_1 = self.encoder(graph)
        x_1 = self.projection_head(x_1)
        x1_norm = F.normalize(x_1, p=2, dim=1)

        c1 = self.classifier(x1_norm)
        return x1_norm, c1
    

#==================================================================================================================

import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import MLP, GINConv, global_add_pool

class GINEncoder(torch.nn.Module):# GRaph Isomorphism Network
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            mlp = MLP([in_channels, hidden_channels, hidden_channels])
            self.convs.append(GINConv(nn=mlp, train_eps=False))
            in_channels = hidden_channels

        self.mlp = MLP([hidden_channels, hidden_channels, out_channels],
                       norm=None, dropout=0.5)

    def forward(self, data, batch_size):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        
        x = global_add_pool(x, batch, size=batch_size)
        return self.mlp(x)
    

#================================================MODEL======================================================
from torch_geometric.nn import GATConv, global_mean_pool

class GATEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GATEncoder, self).__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=8, concat=True)
        self.conv2 = GATConv(hidden_channels * 8, out_channels, heads=8, concat=False)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.elu(self.conv1(x, edge_index))
        x = F.elu(self.conv2(x, edge_index))
        graph_embedding = global_mean_pool(x, batch)
        return graph_embedding


In [1]:
import numpy as np
import random
import torch
from torch.utils.data import DataLoader
from torchmetrics.classification import (MulticlassAccuracy, 
                                         MulticlassF1Score,
                                         MulticlassAUROC,
                                         MulticlassPrecision,
                                         MulticlassRecall)
from tqdm import tqdm
from dataset_handlers import (HCPHandler, 
                            TractoinfernoHandler,
                            FiberCupHandler)

from streamline_datasets import (MaxMinNormalization,
                                TestDataset, collate_test_ds)

from encoders import (SiameseGraphNetwork, GCNEncoder, 
                      GATEncoder, ProjectionHead, ClassifierHead)

# from custom_metrics import get_dice_metrics

# from torch.utils.tensorboard import SummaryWriter
# Comando para lanzar tensorboard en el navegador local a traves del puerto 8888 reenviado por ssh:
# tensorboard --logdir=runs/embedding_visualization --host 0.0.0.0 --port 8888


# import pandas as pd
import warnings
warnings.filterwarnings('ignore')




# Habilitar TensorFloat32 para una mejor performance en operaciones de multiplicación de matrices
torch.set_float32_matmul_precision('high')




class CFG:
    def __init__(self):
        self.seed = 42
      
        self.batch_size = 4096
        self.encoder = "GCN" # Las opciones son "GAT" o "GCN" o "HGPSL"
        self.dataset = "Tractoinferno" # "Tractoinferno o "FiberCup" o "HCP_105"

        if self.dataset == "HCP_105":
            self.ds_path = "/app/dataset/HCP_105"
            self.pretrained_model_path = "/app/pretrained_models/encoder_HCP_105.pt"
            self.n_classes = 72 # 72 tractos o 71 tractos sin CC

        elif self.dataset == "Tractoinferno":
            self.ds_path = "/app/dataset/Tractoinferno/tractoinferno_preprocessed_mni"
            self.pretrained_model_path = "/app/trained_models/checkpoint_Tractoinferno_GCN_512_0.pth"
            self.n_classes = 32
        
        elif self.dataset == "FiberCup":
            self.ds_path = "/app/dataset/Fibercup"
            self.pretrained_model_path = "/app/pretrained_models/encoder_fibercup.pt"
            self.n_classes = 7

        self.embedding_projection_dim = 512
        
    
cfg = CFG()




def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Establecer la semilla
seed_everything(cfg.seed)


# Cargar las rutas de los sujetos de entrenamiento, validación y test
if cfg.dataset == "HCP_105":
    handler = HCPHandler(path = cfg.ds_path, scope = "testset")
    test_data = handler.get_data()
    

elif cfg.dataset == "Tractoinferno":
    handler = TractoinfernoHandler(path = cfg.ds_path, scope = "testset")
    test_data = handler.get_data()

elif cfg.dataset == "FiberCup":
    handler = FiberCupHandler(path = cfg.ds_path, scope = "testset")
    test_data = handler.get_data()




# Crear el modelo, la función de pérdida y el optimizador
if cfg.encoder == "GAT":# Graph Attention Network
    encoder = GATEncoder(in_channels = 3, 
                         hidden_channels = 16, 
                         out_channels = 256)

elif cfg.encoder == "GCN":# Graph Convolutional Network
    encoder = GCNEncoder(in_channels = 3, 
                         hidden_dim = 128, 
                         out_channels = 128, 
                         dropout = 0.15, 
                         n_hidden_blocks = 2)
    





# Crear el modelo
model = SiameseGraphNetwork(
    encoder = encoder,
    projection_head = ProjectionHead(embedding_dim = 128, 
                                     projection_dim = cfg.embedding_projection_dim),
    classifier = ClassifierHead(projection_dim = cfg.embedding_projection_dim, 
                                n_classes = cfg.n_classes)
).cuda()

model

{0: 'AF_L', 1: 'AF_R', 2: 'CC_Fr_1', 3: 'CC_Fr_2', 4: 'CC_Oc', 5: 'CC_Pa', 6: 'CC_Pr_Po', 7: 'CG_L', 8: 'CG_R', 9: 'FAT_L', 10: 'FAT_R', 11: 'FPT_L', 12: 'FPT_R', 13: 'FX_L', 14: 'FX_R', 15: 'IFOF_L', 16: 'IFOF_R', 17: 'ILF_L', 18: 'ILF_R', 19: 'MCP', 20: 'MdLF_L', 21: 'MdLF_R', 22: 'OR_ML_L', 23: 'OR_ML_R', 24: 'POPT_L', 25: 'POPT_R', 26: 'PYT_L', 27: 'PYT_R', 28: 'SLF_L', 29: 'SLF_R', 30: 'UF_L', 31: 'UF_R'}


SiameseGraphNetwork(
  (encoder): GCNEncoder(
    (input_block): GraphConvBlock(
      (conv): GCNConv(3, 128)
      (bn): BatchNorm(128)
      (relu): LeakyReLU(negative_slope=0.01)
      (dropout): Dropout(p=0.15, inplace=False)
    )
    (hidden_blocks): ModuleList(
      (0): GraphConvBlock(
        (conv): GCNConv(128, 128)
        (bn): BatchNorm(128)
        (relu): LeakyReLU(negative_slope=0.01)
        (dropout): Dropout(p=0.15, inplace=False)
      )
    )
    (output_block): GraphConvBlock(
      (conv): GCNConv(128, 128)
      (bn): BatchNorm(128)
      (relu): LeakyReLU(negative_slope=0.01)
      (dropout): Dropout(p=0.15, inplace=False)
    )
  )
  (projection_head): ProjectionHead(
    (projection): Linear(in_features=128, out_features=512, bias=True)
    (gelu): GELU(approximate='none')
    (fc): Linear(in_features=512, out_features=512, bias=True)
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (classifier): ClassifierHead(
    (fc): Linea