# imports

In [2]:
import os
import argparse
import ast
import pickle as pkl
from itertools import tee
import random
import sklearn
import numpy as np
import pandas as pd
import open3d as o3d
import torch
from torch import nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import train_test_split_edges, negative_sampling
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, Linear, LayerNorm
from torch_geometric.nn import global_mean_pool, global_max_pool
from torch_sparse import SparseTensor
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from sklearn.utils import class_weight

INFO - 2023-11-10 09:57:41,153 - instantiator - Created a temporary directory at /tmp/tmp9ttna07b
INFO - 2023-11-10 09:57:41,155 - instantiator - Writing /tmp/tmp9ttna07b/_remote_module_non_scriptable.py


# Arguments

In [3]:
path="../../../../../../vol/aimspace/users/wyo/registered_meshes/2000/"
organ="liver_mesh.ply"
output=False
save=False
    
    
# train=True
epochs=1
batchs=4
use_input_encoder=True
in_features=3
encoder_features=516
hidden_channels=[1024, 1024, 512, 512, 256, 256, 256]
num_conv_layers=4
num_classes=1
activation="ELU"
normalization=True
layer="gat"
dropout=0.005
lr=0.00007
weight_decay=0.002
optimizer="adam"

# Functions

In [65]:
def get_data(path, organ, wanted_label='disease'):
    registered_mesh = []
    labels_path = "../data/liver_diseases.csv"
    train_ids_path = "../data/NonNa_organs_split_train.txt"
    val_ids_path = "../data/NonNa_organs_split_val.txt"
    test_ids_path = "../data/NonNa_organs_split_test.txt"
    labels = pd.read_csv(labels_path, delimiter=",", dtype=str, index_col=0)    
    train_dirs = np.loadtxt(train_ids_path, delimiter=",", dtype=str)
    val_dirs = np.loadtxt(val_ids_path, delimiter=",", dtype=str)
    test_dirs = np.loadtxt(test_ids_path, delimiter=",", dtype=str)
    dirs = next(os.walk(path))[1]
    train_dataset = []
    val_dataset = []
    test_dataset = []
    errors = []
    #In Test
    dirs = dirs[:5]
    print(f'Number of samples used: {len(dirs)}', flush=True)

    body_fields = ["eid", "22407-2.0", "22408-2.0", "31-0.0"]
    full_ukbb_data = pd.read_csv("../../../../../../vol/aimspace/projects/ukbb/data/tabular/ukb668815_imaging.csv", usecols=body_fields)
    full_ukbb_data_new_names = {'22407-2.0':'VAT', '22408-2.0':'ASAT', '31-0.0':'sex'}
    full_ukbb_data = full_ukbb_data.rename(index=str, columns=full_ukbb_data_new_names)
    
    basic_features = pd.read_csv("../data/basic_features.csv")
    basic_features_new_names = {'21003-2.0':'age', '31-0.0':'sex', '21001-2.0':'bmi', '21002-2.0':'weight','50-2.0':'height'}
    basic_features = basic_features.rename(index=str, columns=basic_features_new_names)
    print(f'Number of samples used: {len(dirs)}, with label: {wanted_label}', flush=True)

    if(wanted_label == 'sex' or wanted_label == 'VAT' or wanted_label == 'ASAT'):
        features = full_ukbb_data
    else:
        features = basic_features
    
    for dir in dirs:
        registered_mesh = []
        try: 
            mesh = o3d.io.read_triangle_mesh(f'{path}{dir}/{organ}')
            vertices_data = np.asarray(mesh.vertices)
            triangles = np.asarray(mesh.triangles)
            vertices = torch.from_numpy(vertices_data).double()
            edges = []
            for triangle in triangles:
                edges.append([triangle[0], triangle[1]])
                edges.append([triangle[0], triangle[2]])
                edges.append([triangle[1], triangle[2]])
            edges_torch = [[],[]]
            edges = np.unique(np.array(edges), axis=0)
            for edge in edges:
                edges_torch[0].append(edge[0])
                edges_torch[1].append(edge[1])
            edges_torch = torch.from_numpy(np.asarray(edges_torch)).long()
            registered_mesh.append((vertices.type(torch.float32), edges_torch))

            # label = labels.loc[int(dir)].to_list()
            # label = [int(id) for id in label]
            # label_np = np.asarray(label)
            # label_tensor = torch.from_numpy(label_np)

            if(wanted_label == 'disease'):
                label = 1 if int(dir) in labels.index else 0
                label_tensor = label
                data = Data(x=registered_mesh[0][0], y=label_tensor, edge_index=registered_mesh[0][1], num_nodes= len(registered_mesh[0][0]))
                if(dir in train_dirs):
                    train_dataset.append(data)
                if(dir in val_dirs):
                    val_dataset.append(data)
                elif(dir in test_dirs):
                    test_dataset.append(data)
            else:
                cur_patient_feature = features[features['eid'] == int(dir)]
                if(len(cur_patient_feature[wanted_label]) == 1):
                    if(not pd.isnull(cur_patient_feature[wanted_label].item())):
                        cur_patient_feature_tensor = torch.tensor(cur_patient_feature[wanted_label].item())
                        data = Data(x=registered_mesh[0][0], y=cur_patient_feature_tensor, edge_index=registered_mesh[0][1], num_nodes= len(registered_mesh[0][0]))
                        if(dir in train_dirs):
                            train_dataset.append(data)
                        if(dir in val_dirs):
                            val_dataset.append(data)
                        elif(dir in test_dirs):
                            test_dataset.append(data)

        except:
            errors.append(dir)
            
    return train_dataset, val_dataset, test_dataset
    
############################################################################################
#Generating GNN layers
def get_gnn_layers(num_conv_layers: int, hidden_channels, num_inp_features:int, 
                 gnn_layer, activation=nn.ReLU, normalization=None, dropout = None):
    """Creates GNN layers"""
    layers = nn.ModuleList()
    for i in range(num_conv_layers):
        if i == 0:
            layers.append(gnn_layer(num_inp_features, hidden_channels[i]))
            layers.append(activation())
            if normalization is not None:
                layers.append(normalization(hidden_channels[i]))
        else:
            layers.append(gnn_layer(hidden_channels[i-1], hidden_channels[i]))
            layers.append(activation())
            if normalization is not None:
                layers.append(normalization(hidden_channels[i]))
    return nn.ModuleList(layers)
    
############################################################################################
#Making multilayer perceptron layers 
def get_mlp_layers(channels: list, activation, output_activation=nn.Identity):
    """Define basic multilayered perceptron network."""
    layers = []
    *intermediate_layer_definitions, final_layer_definition = pairwise(channels)
    for in_ch, out_ch in intermediate_layer_definitions:
        intermediate_layer = nn.Linear(in_ch, out_ch)
        layers += [intermediate_layer, activation()]
    layers += [nn.Linear(*final_layer_definition), output_activation()]
    #print('Output activation ',output_activation)
    return nn.Sequential(*layers)
    
############################################################################################
#Iterate over all pairs of consecutive items in a list
def pairwise(iterable):
    """Iterate over all pairs of consecutive items in a list.
    Notes
    -----
        [s0, s1, s2, s3, ...] -> (s0,s1), (s1,s2), (s2, s3), ...
    """
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)
    
############################################################################################

def reduce_loss(loss, reduction):
    # none: 0, elementwise_mean:1, sum: 2
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
    elif reduction == "sum":
        return loss.sum()

############################################################################################

def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is mean, then average the loss by avg_factor
        if reduction == 'mean':
            # Avoid causing ZeroDivisionError when avg_factor is 0.0,
            # i.e., all labels of an image belong to ignore index.
            eps = torch.finfo(torch.float32).eps
            loss = loss.sum() / (avg_factor + eps)
        # if reduction is 'none', then do nothing, otherwise raise an error
        elif reduction != 'none':
            raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss

############################################################################################

def focal_loss(pred, target, alpha = 1, gamma = 2, weight=None, reduction='mean', avg_factor=None):
    target = target.type_as(pred)
    pt = (1 - pred) * target + pred * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

############################################################################################

#Train function
def train(model, optimizer, dataloader, alpha, gamma, threshold = 0.5, loss_fn = nn.BCEWithLogitsLoss()):
    """Train network on training dataset."""
    model.train()
    cumulative_loss = 0.0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = loss_fn(out.squeeze(1), data.y.float())
        # print(f"output: {out.squeeze(1)}, gt: {data.y.float()}, loss: {loss}", flush=True)
        loss.backward()
        cumulative_loss += loss.item()
        optimizer.step()
    return cumulative_loss / len(dataloader)
    # model.train()
    # cumulative_loss = 0.0
    # for data in dataloader:
    #     data = data.to(device)      
    #     weights = torch.where(data.y.float() == 1.0, torch.tensor(5), torch.tensor(1))  

    #     output = model(data) 
    #     output = output.squeeze(1).flatten()
    #     prediction = (output > threshold).float() 
    #     prediction = torch.tensor(prediction, requires_grad=True)
    #     intermediate_losses = loss_fn(prediction, data.y.float())
    #     loss = torch.mean(weights*intermediate_losses)
    #     # loss = focal_loss(output, data.y.float(), alpha = alpha, gamma = gamma, weight=None, reduction='mean', avg_factor=None)
    #     cumulative_loss += loss.item()

    #     optimizer.zero_grad()
    #     loss.backward()
    #     optimizer.step()
    # return cumulative_loss / len(dataloader)
    
############################################################################################

#Validation function
def calculate_val_loss(model, dataloader, alpha, gamma, threshold = 0.5, loss_fn = nn.BCEWithLogitsLoss()):
    model.eval()
    cumulative_loss = 0.0
    for data in dataloader:
        data = data.to(device)
        out = model(data)
        loss = loss_fn(out.squeeze(1), data.y.float())
        cumulative_loss += loss.item()
    return cumulative_loss / len(dataloader)
    # model.eval()
    # cumulative_loss = 0.0
    # for data in dataloader:
    #     data = data.to(device)        
    #     weights = torch.where(data.y.float() == 1.0, torch.tensor(5), torch.tensor(1))    

    #     output = model(data) 
    #     output = output.squeeze(1).flatten()
    #     prediction = (output > threshold).float() 
    #     prediction = torch.tensor(prediction, requires_grad=True)
    #     intermediate_losses = loss_fn(prediction, data.y.float())
    #     loss = torch.mean(weights*intermediate_losses)
    #     # loss = focal_loss(output, data.y.float(), alpha = alpha, gamma = gamma, weight=None, reduction='mean', avg_factor=None)
    #     cumulative_loss += loss.item()
    
    # return cumulative_loss / len(dataloader)
    
############################################################################################

#Test function
def test(model, dataloader, alpha, gamma, threshold = 0.5, loss_fn = nn.BCEWithLogitsLoss()):    
    model.eval()
    prediction_accuracies = []
    prediction_f1 = []
    for data in dataloader:
        data = data.to(device)
        predictions = model(data)
        # predicted_class_labels = torch.nn.Sigmoid()(predictions)
        predicted_class_labels = predictions.squeeze(1)
        predicted_class_labels = torch.round(predicted_class_labels)

        correct_assignments = (predicted_class_labels == data.y.float()).sum()
        num_assignemnts = predicted_class_labels.shape[0]
        # print(f"output: {predictions}, pred: {predicted_class_labels}, gt: {data.y.float()}, correct_assignments: {correct_assignments}, num_assignemnts: {num_assignemnts}, Acc: {float(correct_assignments / num_assignemnts)}", flush=True)
        prediction_accuracies.append(float(correct_assignments / num_assignemnts))
        f1_score = sklearn.metrics.precision_score(predicted_class_labels.int().detach().cpu(), data.y.int().detach().cpu(), average='weighted')
        prediction_f1.append((f1_score))
    print(prediction_f1)
    return sum(prediction_accuracies) / len(dataloader), sum(prediction_f1) / len(dataloader)
    # model.eval()
    # cumulative_loss = 0.0
    # correct = 0
    # total = 0
    # for data in dataloader:
    #     data = data.to(device)        
    #     weights = torch.where(data.y.float() == 1.0, torch.tensor(5), torch.tensor(1))  

    #     output = model(data)
    #     output = output.squeeze(1).flatten()
    #     prediction = (output > threshold).float() 
    #     prediction = torch.tensor(prediction, requires_grad=True)
    #     intermediate_losses = loss_fn(prediction, data.y.float())
    #     print(f"pred: {prediction}, gt: {data.y.float()}, in_losses: {intermediate_losses}, final losses: {torch.mean(weights*intermediate_losses)}")
    #     loss = torch.mean(weights*intermediate_losses)
    #     # loss = focal_loss(prediction, data.y.float(), alpha = alpha, gamma = gamma, weight=None, reduction='mean', avg_factor=None)
    #     cumulative_loss += loss.item()

    #     correct += (prediction == data.y).sum().item()
    #     total += data.y.size(0)
    
    # loss = cumulative_loss / len(dataloader)
    # accuracy = correct / total
    # precision = sklearn.metrics.precision_score(prediction.int(), data.y.int())
    # recall = sklearn.metrics.recall_score(prediction.int(), data.y.int())
    # f1_score = sklearn.metrics.f1_score(prediction.int(), data.y.int())
    # ap = sklearn.metrics.average_precision_score(prediction.int(), data.y.int())
    # auc = sklearn.metrics.auc(prediction.int(), data.y.int())
    # return loss, accuracy, precision, recall, f1_score, ap, auc
    
############################################################################################

#Optimizer
def build_optimizer(network, optimizer, learning_rate, weight_decay):
    if optimizer == "sgd":
        optimizer = torch.optim.SGD(network.parameters(),
                              lr=learning_rate, momentum=0.9)
    elif optimizer == "adam":
        optimizer = torch.optim.Adam(network.parameters(),
                               lr=learning_rate, weight_decay=weight_decay)
    return optimizer
    
############################################################################################

# Model

In [24]:
class GNN(torch.nn.Module):
    def __init__(self, in_features, hidden_channels, activation, normalization, num_classes, num_conv_layers=4, layer='gcn',
                 use_input_encoder=True, encoder_features=128, apply_batch_norm=True,
                 apply_dropout_every=True, dropout = 0):
        super(GNN, self).__init__()
        torch.manual_seed(42)
        
        self.fc = torch.nn.ModuleList()
        self.layer_type = layer
        self.num_classes = num_classes
        self.use_input_encoder = use_input_encoder
        self.apply_batch_norm = apply_batch_norm
        self.dropout = dropout
        self.normalization_bool = normalization
        self.activation = activation
        self.apply_dropout_every = apply_dropout_every

        if self.normalization_bool:
            self.normalization = LayerNorm
        else:
            self.normalization = None

        if self.use_input_encoder :
            self.input_encoder = get_mlp_layers(
                channels=[in_features, encoder_features],
                activation=nn.ELU,
            )
            in_features = encoder_features

        if layer == 'gcn':
            self.layers = get_gnn_layers(num_conv_layers, hidden_channels, num_inp_features=in_features,
                                        gnn_layer=GCNConv,activation=activation,normalization=self.normalization )
        elif layer == 'sageconv':
            self.layers = get_gnn_layers(num_conv_layers, hidden_channels,in_features,
                                        gnn_layer=SAGEConv,activation=activation,normalization=self.normalization )
        elif layer == 'gat':
            self.layers = get_gnn_layers(num_conv_layers, hidden_channels,in_features,
                                        gnn_layer=GATConv,activation=activation,normalization=self.normalization )     

        for i in range((len(hidden_channels)-num_conv_layers)):
            self.fc.append(Linear(hidden_channels[i+num_conv_layers-1], hidden_channels[i+num_conv_layers]))
        
        self.pred_layer = Linear(hidden_channels[len(hidden_channels)-1], self.num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if self.use_input_encoder:
            x = self.input_encoder(x)

        if self.normalization is None:
            for i, layer in enumerate(self.layers):
                # Each GCN consists 2 modules GCN -> Activation 
                # GCN send edge index
                if i% 2 == 0:
                    x = layer(x, edge_index)
                else:
                    x = layer(x)

                if self.apply_dropout_every:
                    x = F.dropout(x, p=self.dropout, training=self.training)
        else:
            for i, layer in enumerate(self.layers):
                # Each GCN consists 3 modules GCN -> Activation ->  Normalization 
                # GCN send edge index
                if i% 3 == 0:
                    x = layer(x, edge_index)
                else:
                    x = layer(x)

                if self.apply_dropout_every:
                    x = F.dropout(x, p=self.dropout, training=self.training)      

        # 2. Readout layer
        x = global_max_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=self.dropout, training=self.training)

       
        for i in range(len(self.fc)):
           x = self.fc[i](x)
           x = torch.tanh(x)
           x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.pred_layer(x)  
        x = torch.nn.Sigmoid()(x)

        return x
    
############################################################################################

# Training

In [25]:
torch_geometric.seed_everything(42)

registeration_path = path

#Model Parameters
activation = torch.nn.modules.activation.ELU
model_params = dict(
        use_input_encoder = use_input_encoder,
        in_features= in_features, 
        encoder_features = encoder_features,
        hidden_channels= hidden_channels,
        num_classes= num_classes,
        activation=activation,
        normalization = normalization,
        layer = layer,
        num_conv_layers = num_conv_layers,
        dropout = dropout)

    

# model
model = GNN(**model_params)

# move to GPU (if available)
device = 'cpu'
model = model.to(device)

train_dataset, val_dataset, test_dataset = get_data(path, organ)

healthy_train_dataset = []
healthy_val_dataset = []
healthy_test_dataset = []
unhealthy_train_dataset = []
unhealthy_val_dataset = []
unhealthy_test_dataset = []

count = [0,0,0,0,0,0]
for i in train_dataset:
    count[i.y] = count[i.y] + 1
    if(i.y==0):
        healthy_train_dataset.append(i)
    else:
        unhealthy_train_dataset.append(i)
for i in val_dataset:
    count[i.y+2] = count[i.y+2] + 1
    if(i.y==0):
        healthy_val_dataset.append(i)
    else:
        unhealthy_val_dataset.append(i)
for i in test_dataset:
    count[i.y+4] = count[i.y+4] + 1
    if(i.y==0):
        healthy_test_dataset.append(i)
    else:
        unhealthy_test_dataset.append(i)
print(f"Original Distrubution: {count}", flush=True)

healthy_train_dataset = random.sample(healthy_train_dataset, count[1])
healthy_val_dataset = random.sample(healthy_val_dataset, count[3])
healthy_test_dataset = random.sample(healthy_test_dataset, count[5])

train_dataset = healthy_train_dataset + unhealthy_train_dataset
val_dataset = healthy_val_dataset + unhealthy_val_dataset
test_dataset = healthy_test_dataset + unhealthy_test_dataset

count = [0,0,0,0,0,0]
for i in train_dataset:
    count[i.y] = count[i.y] + 1
for i in val_dataset:
    count[i.y+2] = count[i.y+2] + 1
for i in test_dataset:
    count[i.y+4] = count[i.y+4] + 1
print(f"Used Distrubution: {count}", flush=True)


train_loader = DataLoader(dataset = train_dataset, batch_size=batchs, shuffle=True )
valid_loader = DataLoader(dataset = val_dataset, batch_size=batchs, shuffle=True)
test_loader = DataLoader(dataset = test_dataset, batch_size=batchs, shuffle=True)

Number of samples used: 382, 62, 40
Original Distrubution: [376, 6, 60, 2, 39, 1]
Used Distrubution: [6, 6, 2, 2, 1, 1]


In [26]:
def warn(*args, **kwargs):
    pass
import warnings
warnings.warn = warn

In [40]:
# inizialize the optimizer
optimizer = build_optimizer(model, optimizer, lr, weight_decay)
alpha, gamma, threshold = 0.1, 2, 0.3

# loss_fn = nn.BCELoss(reduction='none')
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(2.0, device=device))
    
for epoch in range(1, epochs + 1):
    train_loss = train(model, optimizer, train_loader, alpha, gamma, threshold, loss_fn)
    val_loss = calculate_val_loss(model, valid_loader, alpha, gamma, threshold, loss_fn)

    # test_loss, test_accuracy, p, r, f1, ap, auc = test(model, test_loader, alpha, gamma, threshold, loss_fn)
    acc, f1_score = test(model, test_loader, alpha, gamma, threshold, loss_fn)

    # print(f"train_loss: {train_loss}, val_loss: {val_loss}, test_loss: {test_loss}, test_accuracy: {test_accuracy}, p: {p}, r: {r}, f1: {f1}, ap: {ap}, auc: {auc}")
    print(f"train_loss: {train_loss}, val_loss: {val_loss}, acc: {acc}, f1: {f1_score}")

[1.0]
train_loss: 0.9571395715077718, val_loss: 0.9568238258361816, acc: 0.5, f1: 1.0


# Prediction

In [41]:
test_data = test_dataset[1]
model.eval()
output = model(test_data)

output = output.squeeze(1).flatten()
output = torch.nn.Sigmoid()(output)
prediction = (output > threshold).float()  
print(prediction.int())
print(test_data.y)

tensor([1], dtype=torch.int32)
1


In [42]:
from collections import Counter
print(Counter(np.asarray(prediction)))
print(Counter(np.asarray([test_data.y])))

Counter({1.0: 1})
Counter({1: 1})


In [43]:
labels_path = "../data/liver_diseases.csv"
labels = pd.read_csv(labels_path, delimiter=",", dtype=str, index_col=0) 
dirs = next(os.walk(path))[1]
gt = []
for dir in dirs:
    label = 1 if int(dir) in labels.index else 0
    gt.append(label)

print(Counter(np.asarray(gt)))

Counter({0: 29891, 1: 488})


In [44]:
count = [0,0,0,0,0,0]
for i in train_dataset:
    count[i.y] = count[i.y] + 1
for i in val_dataset:
    count[i.y+2] = count[i.y+2] + 1
for i in test_dataset:
    count[i.y+4] = count[i.y+4] + 1

count

[6, 6, 2, 2, 1, 1]

# Tests

In [30]:
def get_data(path="../../../../../../vol/aimspace/users/wyo/registered_meshes/2000/", organ="liver_mesh.ply", label="height", save=False):
    registered_mesh = []
    test_ids_path = "../data/NonNa_organs_split_test.txt"
    test_dirs = np.loadtxt(test_ids_path, delimiter=",", dtype=str)
    dirs = next(os.walk(path))[1]
    train_dataset = []
    test_dataset = []

    #In Test
    dirs = dirs[:5]
    body_fields = ["eid", "22407-2.0", "22408-2.0", "31-0.0"]
    full_ukbb_data = pd.read_csv("../../../../../../vol/aimspace/projects/ukbb/data/tabular/ukb668815_imaging.csv", usecols=body_fields)
    full_ukbb_data_new_names = {'22407-2.0':'VAT', '22408-2.0':'ASAT', '31-0.0':'sex'}
    full_ukbb_data = full_ukbb_data.rename(index=str, columns=full_ukbb_data_new_names)
    
    basic_features = pd.read_csv("../data/basic_features.csv")
    basic_features_new_names = {'21003-2.0':'age', '31-0.0':'sex', '21001-2.0':'bmi', '21002-2.0':'weight','50-2.0':'height'}
    basic_features = basic_features.rename(index=str, columns=basic_features_new_names)
    print(f'Number of samples used: {len(dirs)}, with label: {label}', flush=True)

    if(label == 'sex' or label == 'VAT' or label == 'ASAT'):
        features = full_ukbb_data
    else:
        features = basic_features
    
    for dir in dirs:
        registered_mesh = []
        mesh = o3d.io.read_triangle_mesh(f'{path}{dir}/{organ}')
    
        vertices_data = np.asarray(mesh.vertices)
        triangles = np.asarray(mesh.triangles)
        vertices = torch.from_numpy(vertices_data).double()
        edges = []
        for triangle in triangles:
            edges.append([triangle[0], triangle[1]])
            edges.append([triangle[0], triangle[2]])
            edges.append([triangle[1], triangle[2]])
            
        edges_torch = [[],[]]
        edges =np.unique(np.array(edges), axis=0)
        for edge in edges:
            edges_torch[0].append(edge[0])
            edges_torch[1].append(edge[1])
    
        edges_torch = torch.from_numpy(np.asarray(edges_torch)).long()

        cur_patient_feature = features[features['eid'] == int(dir)]
        if(len(cur_patient_feature[label]) == 1):
            if(not pd.isnull(cur_patient_feature[label].item())):
                cur_patient_feature_tensor = torch.tensor(cur_patient_feature[label].item())
                registered_mesh.append((vertices.type(torch.float32), edges_torch, cur_patient_feature_tensor.type(torch.float32)))
                data = Data(x=registered_mesh[0][0], y=registered_mesh[0][2], edge_index=registered_mesh[0][1], num_nodes= len(registered_mesh[0][0]))
                print(data, flush=True)
                if(dir in test_dirs):
                    test_dataset.append(data)
                else:
                    train_dataset.append(data)
                # data = train_test_split_edges(data)
                # print(data)
    
    if(save):
        with open(f'../data/infograph/{organ}/data', 'wb') as f:
            pkl.dump(train_dataset, f)

    return train_dataset, test_dataset

In [31]:
x, y = get_data()

Number of samples used: 5, with label: height
Data(x=[1087, 3], edge_index=[2, 4555], y=178.0, num_nodes=1087)
Data(x=[1062, 3], edge_index=[2, 4504], y=188.0, num_nodes=1062)
Data(x=[1091, 3], edge_index=[2, 4487], y=164.0, num_nodes=1091)
Data(x=[1086, 3], edge_index=[2, 4567], y=181.0, num_nodes=1086)
Data(x=[1086, 3], edge_index=[2, 4522], y=163.0, num_nodes=1086)
