# Imports

In [1]:
import logging
from tqdm import tqdm
import numpy as np
import random
import pickle as pkl
import pandas as pd
import scipy.sparse as sp

import open3d as o3d

import dgl
from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling
from dgl.dataloading import GraphDataLoader

import torch
from torch.utils.data.sampler import SubsetRandomSampler
import torch_geometric
from torch_geometric.data import Data

from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.svm import SVC
from sklearn.metrics import f1_score, mean_squared_error, mean_absolute_error, r2_score

from graphmae.utils import (
    build_args,
    create_optimizer,
    set_random_seed,
    TBLogger,
    get_current_lr,
    load_best_configs,
)
from graphmae.datasets.data_util import load_graph_classification_dataset
from graphmae.models import build_model
from graphmae.evaluation import linear_probing_for_inductive_node_classiifcation, LogisticRegression

INFO - 2023-10-15 16:19:25,614 - instantiator - Created a temporary directory at /tmp/tmpi08_v8ai
INFO - 2023-10-15 16:19:25,616 - instantiator - Writing /tmp/tmpi08_v8ai/_remote_module_non_scriptable.py


# Functions

In [8]:
def triangle_mesh_to_adjacency_matrix(mesh):
    # Get the vertices and triangles of the mesh
    vertices = np.asarray(mesh.vertices)
    triangles = np.asarray(mesh.triangles)

    # Create an empty adjacency matrix
    n_vertices = len(vertices)
    adjacency_matrix = sp.lil_matrix((n_vertices, n_vertices), dtype=np.float32)

    # Iterate through the triangles and add edges to the adjacency matrix
    for tri in triangles:
        adjacency_matrix[tri[0], tri[1]] = 1.0
        adjacency_matrix[tri[1], tri[0]] = 1.0
        adjacency_matrix[tri[1], tri[2]] = 1.0
        adjacency_matrix[tri[2], tri[1]] = 1.0
        adjacency_matrix[tri[2], tri[0]] = 1.0
        adjacency_matrix[tri[0], tri[2]] = 1.0

    # Convert the adjacency matrix to a more efficient sparse matrix representation
    adjacency_matrix = adjacency_matrix.tocsr()
    
    return adjacency_matrix

#############################################################################################################

def open3d_to_dgl_graph(path, open3d_geometry):
    intensity_path = path.replace("registered_meshes","organ_decimations_ply")
    intensity_mesh = o3d.io.read_triangle_mesh(intensity_path)
    open3d_geometry.compute_vertex_normals()

    # Extract points, normals and adjacency information
    points = open3d_geometry.vertices
    adjacency_matrix = triangle_mesh_to_adjacency_matrix(open3d_geometry)
    # Create a DGL graph from the adjacency matrix
    dgl_graph = dgl.from_scipy(adjacency_matrix)

    # Add node features (e.g., point coordinates) to the DGL graph
    points_np = np.array(open3d_geometry.vertices)
    normals_np = np.array(open3d_geometry.vertex_normals)
    intensities_np = np.array(intensity_mesh.vertex_colors)
    features = np.concatenate((points_np, intensities_np, normals_np), axis=1)
    
    dgl_graph.ndata['feat'] = torch.tensor(features, dtype=torch.float32)

    return dgl_graph

#############################################################################################################

def pretrain(model, dataloaders, optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger=None):
    logging.info("start training..")
    train_loader, val_loader, test_loader, eval_train_loader = dataloaders

    epoch_iter = tqdm(range(max_epoch))

    if isinstance(train_loader, list) and len(train_loader) ==1:
        train_loader = [train_loader[0].to(device)]
        eval_train_loader = train_loader
    if isinstance(val_loader, list) and len(val_loader) == 1:
        val_loader = [val_loader[0].to(device)]
        test_loader = val_loader

    for epoch in epoch_iter:
        model.train()
        loss_list = []

        for subgraph in train_loader:
            subgraph = subgraph.to(device)
            loss, loss_dict = model(subgraph, subgraph.ndata["feat"])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())

        if scheduler is not None:
            scheduler.step()

        train_loss = np.mean(loss_list)
        epoch_iter.set_description(f"# Epoch {epoch} | train_loss: {train_loss:.4f}")
        if logger is not None:
            loss_dict["lr"] = get_current_lr(optimizer)
            logger.note(loss_dict, step=epoch)

        if epoch == (max_epoch//2):
            # print(model)
            evaluate(model, (eval_train_loader, val_loader, test_loader), num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob, mute=True)
    return model

#############################################################################################################

def evaluate(model, loaders, num_classes, lr_f, weight_decay_f, max_epoch_f, device, linear_prob=True, mute=False):
    model.eval()
    if linear_prob:
        if len(loaders[0]) > 1:
            x_all = {"train": [], "val": [], "test": []}
            y_all = {"train": [], "val": [], "test": []}

            with torch.no_grad():
                for key, loader in zip(["train", "val", "test"], loaders):
                    for subgraph in loader:
                        subgraph = subgraph.to(device)
                        feat = subgraph.ndata["feat"]
                        x = model.embed(subgraph, feat)
                        # print(f'latent space: {x}')
                        # print(f'latent space shape: {x.shape}')
                        x_all[key].append(x)
                        y_all[key].append(subgraph.ndata["feat"])  
            in_dim = x_all["train"][0].shape[1]
            encoder = LogisticRegression(in_dim, num_classes)
            num_finetune_params = [p.numel() for p in encoder.parameters() if  p.requires_grad]
            if not mute:
                print(f"num parameters for finetuning: {sum(num_finetune_params)}")
                # torch.save(x.cpu(), "feat.pt")
            
            encoder.to(device)
            optimizer_f = create_optimizer("adam", encoder, lr_f, weight_decay_f)
            final_acc, estp_acc = mutli_graph_linear_evaluation(encoder, x_all, y_all, optimizer_f, max_epoch_f, device, mute)
            return final_acc, estp_acc
        else:
            x_all = {"train": None, "val": None, "test": None}
            y_all = {"train": None, "val": None, "test": None}

            with torch.no_grad():
                for key, loader in zip(["train", "val", "test"], loaders):
                    for subgraph in loader:
                        subgraph = subgraph.to(device)
                        feat = subgraph.ndata["feat"]
                        x = model.embed(subgraph, feat)
                        mask = subgraph.ndata[f"{key}_mask"]
                        x_all[key] = x[mask]
                        y_all[key] = subgraph.ndata["label"][mask]  
            in_dim = x_all["train"].shape[1]
            
            encoder = LogisticRegression(in_dim, num_classes)
            encoder = encoder.to(device)
            optimizer_f = create_optimizer("adam", encoder, lr_f, weight_decay_f)

            x = torch.cat(list(x_all.values()))
            y = torch.cat(list(y_all.values()))
            num_train, num_val, num_test = [x.shape[0] for x in x_all.values()]
            num_nodes = num_train + num_val + num_test
            train_mask = torch.arange(num_train, device=device)
            val_mask = torch.arange(num_train, num_train + num_val, device=device)
            test_mask = torch.arange(num_train + num_val, num_nodes, device=device)
            
            final_acc, estp_acc = linear_probing_for_inductive_node_classiifcation(encoder, x, y, (train_mask, val_mask, test_mask), optimizer_f, max_epoch_f, device, mute)
            return final_acc, estp_acc
    else:
        raise NotImplementedError
    
#############################################################################################################

def mutli_graph_linear_evaluation(model, feat, labels, optimizer, max_epoch, device, mute=False):
    criterion = torch.nn.BCEWithLogitsLoss()

    best_val_acc = 0
    best_val_epoch = 0
    best_val_test_acc = 0

    if not mute:
        epoch_iter = tqdm(range(max_epoch))
    else:
        epoch_iter = range(max_epoch)

    for epoch in epoch_iter:
        model.train()
        for x, y in zip(feat["train"], labels["train"]):
            out = model(None, x)
            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3)
            optimizer.step()

        with torch.no_grad():
            model.eval()
            val_out = []
            test_out = []
            for x, y in zip(feat["val"], labels["val"]):
                val_pred = model(None, x)
                val_out.append(val_pred)
            val_out = torch.cat(val_out, dim=0).cpu().numpy()
            val_label = torch.cat(labels["val"], dim=0).cpu().numpy()
            # val_out = np.where(val_out >= 0.0, 1.0, 0.0)

            for x, y in zip(feat["test"], labels["test"]):
                test_pred = model(None, x)# 
                test_out.append(test_pred)
            test_out = torch.cat(test_out, dim=0).cpu().numpy()
            test_label = torch.cat(labels["test"], dim=0).cpu().numpy()
            # test_out = np.where(test_out >= 0.0, 1.0, 0.0)

            # val_acc = f1_score(val_label, val_out, average="micro")
            # test_acc = f1_score(test_label, test_out, average="micro")

            # mse = mean_squared_error(val_label, val_out)
            val_acc = mean_absolute_error(val_label, val_out)
            test_acc = mean_absolute_error(val_label, val_out)
            # r2 = r2_score(y_true, y_pred)
        
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            best_val_epoch = epoch
            best_val_test_acc = test_acc

        if not mute:
            epoch_iter.set_description(f"# Epoch: {epoch}, train_loss:{loss.item(): .4f}, val_acc:{val_acc}, test_acc:{test_acc: .4f}")

    if mute:
        print(f"# IGNORE: --- Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch}, Early-stopping-TestAcc: {best_val_test_acc:.4f},  Final-TestAcc: {test_acc:.4f}--- ")
    else:
        print(f"--- Best ValAcc: {best_val_acc:.4f} in epoch {best_val_epoch}, Early-stopping-TestAcc: {best_val_test_acc:.4f}, Final-TestAcc: {test_acc:.4f} --- ")

    return test_acc, best_val_test_acc


# Arguments

In [6]:
class args(object):
    seed= 42
    device= 0
    max_epoch=200
    warmup_steps=-1

    num_heads=1
    num_out_heads= 1 
    num_layers= 2
    num_hidden= 256
    residual= False
    in_drop= 0.2
    attn_drop= 0.1
    norm= None
    lr= 0.001
    weight_decay= 5e-4
    negative_slope= 0.2
    activation= "prelu"
    mask_rate= 0.3
    drop_edge_rate= 0.0
    replace_rate: float = 0.15

    encoder= "gat"
    decoder= "gat"
    loss_fn= "sce"
    alpha_l= 2 #pow coefficient for sce loss
    optimizer = "adam"
    
    max_epoch_f= 30
    lr_f= 0.001
    weight_decay_f= 0.0
    linear_prob= True
    concat_hidden = True
    num_features = 0

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
seed = args.seed
max_epoch = args.max_epoch
max_epoch_f = args.max_epoch_f
num_hidden = args.num_hidden
num_layers = args.num_layers
encoder_type = args.encoder
decoder_type = args.decoder
replace_rate = args.replace_rate

optim_type = args.optimizer 
loss_fn = args.loss_fn

lr = args.lr
weight_decay = args.weight_decay
lr_f = args.lr_f
weight_decay_f = args.weight_decay_f
linear_prob = args.linear_prob
concat_hidden = args.concat_hidden

In [8]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.determinstic = True

# Data

In [12]:
graphs = []
mesh = o3d.io.read_triangle_mesh("../../../local_data/organ_decimations_ply/2000/1000180/liver_mesh.ply")
dgl_graph = open3d_to_dgl_graph(mesh)
dgl_graph = dgl_graph.remove_self_loop()
dgl_graph = dgl_graph.add_self_loop()
print(dgl_graph)
graphs.append(dgl_graph)

mesh = o3d.io.read_triangle_mesh("../../../local_data/organ_decimations_ply/2000/1000071/liver_mesh.ply")
dgl_graph = open3d_to_dgl_graph(mesh)
dgl_graph = dgl_graph.remove_self_loop()
dgl_graph = dgl_graph.add_self_loop()
graphs.append(dgl_graph)

mesh = o3d.io.read_triangle_mesh("../../../local_data/organ_decimations_ply/2000/2901448/liver_mesh.ply")
dgl_graph = open3d_to_dgl_graph(mesh)
dgl_graph = dgl_graph.remove_self_loop()
dgl_graph = dgl_graph.add_self_loop()
graphs.append(dgl_graph)
dgl_graph

# decoder_g = pre_use_g.clone()
# array_zeros = np.zeros((np.asarray(pre_use_g.ndata["feat"]).shape[0], np.asarray(pre_use_g.ndata["feat"]).shape[1]))
# decoder_g.ndata['feat'] = torch.tensor(array_zeros, dtype=torch.float32)
# dgl_graph.ndata.pop('feat')

Graph(num_nodes=1100, num_edges=7046,
      ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={})


Graph(num_nodes=1079, num_edges=7071,
      ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={})

In [13]:
with open('../../data/gae/liver/data', 'rb') as f:
        data = pkl.load(f)
data

Data(x=[1073, 3], y=[2, 4534], num_nodes=1073, val_pos_edge_index=[2, 106], test_pos_edge_index=[2, 213], train_pos_edge_index=[2, 3622], train_neg_adj_mask=[1073, 1073], val_neg_edge_index=[2, 106], test_neg_edge_index=[2, 213])

In [14]:
args.num_features = data.x.shape[1]
num_classes = data.x.shape[1]
scheduler = None
logger = None
args.num_features

3

# Model

In [None]:
model = build_model(args)
model.to(device)
optimizer = create_optimizer(optim_type, model, lr, weight_decay)

In [157]:
acc_list = []
estp_acc_list = []
scheduler = None
logger = None

In [158]:
model = pretrain(model, (graphs, graphs, graphs, graphs), optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger)
# model = pretrain(model, (train_dataloader, valid_dataloader, test_dataloader, eval_train_dataloader, data), optimizer, max_epoch, device, scheduler, num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger)
model = model.cpu()

# model = model.to(device)
# model.eval()

2023-06-21 15:34:15,204 - INFO - start training..
# Epoch 100 | train_loss: 0.0004:  50%|█████     | 101/200 [00:24<00:34,  2.91it/s]

# IGNORE: --- Best ValAcc: 0.5383 in epoch 2, Early-stopping-TestAcc: 0.5383,  Final-TestAcc: 0.4947--- 


# Epoch 199 | train_loss: 0.0005: 100%|██████████| 200/200 [00:49<00:00,  4.08it/s]


# Encoding

In [2]:
model = torch.load("../../models/liver_mesh.ply_graphmae.pt", map_location=torch.device('cpu'))


# subgraph = ""
# feat = ""

# x = model.embed(subgraph, feat)
# model = model.to(device)

# model.eval()

# Pooling Embeddings

In [1]:
############################################################################################

def triangle_mesh_to_adjacency_matrix(mesh):
    # Get the vertices and triangles of the mesh
    vertices = np.asarray(mesh.vertices)
    triangles = np.asarray(mesh.triangles)

    # Create an empty adjacency matrix
    n_vertices = len(vertices)
    adjacency_matrix = sp.lil_matrix((n_vertices, n_vertices), dtype=np.float32)

    # Iterate through the triangles and add edges to the adjacency matrix
    for tri in triangles:
        adjacency_matrix[tri[0], tri[1]] = 1.0
        adjacency_matrix[tri[1], tri[0]] = 1.0
        adjacency_matrix[tri[1], tri[2]] = 1.0
        adjacency_matrix[tri[2], tri[1]] = 1.0
        adjacency_matrix[tri[2], tri[0]] = 1.0
        adjacency_matrix[tri[0], tri[2]] = 1.0

    # Convert the adjacency matrix to a more efficient sparse matrix representation
    adjacency_matrix = adjacency_matrix.tocsr()
    
    return adjacency_matrix

#############################################################################################################

def open3d_to_dgl_graph(path, open3d_geometry):
    intensity_path = path.replace("registered_meshes","organ_decimations_ply")
    intensity_mesh = o3d.io.read_triangle_mesh(intensity_path)
    open3d_geometry.compute_vertex_normals()

    # Extract points, normals and adjacency information
    points = open3d_geometry.vertices
    adjacency_matrix = triangle_mesh_to_adjacency_matrix(open3d_geometry)
    # Create a DGL graph from the adjacency matrix
    dgl_graph = dgl.from_scipy(adjacency_matrix)

    # Add node features (e.g., point coordinates) to the DGL graph
    points_np = np.array(open3d_geometry.vertices)
    normals_np = np.array(open3d_geometry.vertex_normals)
    intensities_np = np.array(intensity_mesh.vertex_colors)
    # features = np.concatenate((points_np, normals_np, intensities_np), axis=1)
    features = points_np
    
    dgl_graph.ndata['feat'] = torch.tensor(features, dtype=torch.float32)

    return dgl_graph

#############################################################################################################

def get_single_subject(path, organ, subject):

    mesh = o3d.io.read_triangle_mesh(f'{path}{subject}/{organ}')
    dgl_graph = open3d_to_dgl_graph(f'{path}{subject}/{organ}', mesh)
    dgl_graph = dgl_graph.remove_self_loop()
    dgl_graph = dgl_graph.add_self_loop()

    return dgl_graph

############################################################################################

In [14]:
import open3d as o3d
import numpy as np
import pandas as pd
import scipy.sparse as sp
import dgl
import torch
dgl_graph = get_single_subject("../../../../../../vol/aimspace/users/wyo/registered_meshes/2000/", "liver_mesh.ply", "1000071")
dgl_graph = dgl_graph
model_path = "../../models/liver_mesh.ply_graphmae_no_feat.pt"
model = torch.load(model_path, map_location=torch.device('cpu'))

z = model.embed(dgl_graph, dgl_graph.ndata['feat'], True)

rep, pred = model.decode(dgl_graph, z)
pred = pred.detach().cpu().numpy()
rep = rep.detach().cpu().numpy()

z = z.detach().cpu().numpy()

In [18]:
z.shape

(1087, 1280)

In [17]:
rep.shape

(1087, 256)

In [23]:
graph_embedding = np.mean(rep, axis=0)
graph_embedding.shape

(256,)

In [2]:
import open3d as o3d
import numpy as np
import pandas as pd
import scipy.sparse as sp
import dgl
import torch
dgl_graph = get_single_subject("../../../../../../vol/aimspace/users/wyo/registered_meshes/2000/", "liver_mesh.ply", "1000071")
dgl_graph = dgl_graph
model_path = "../../models/liver_mesh.ply_graphmae_gat_pool_mlp.pt"
model = torch.load(model_path, map_location=torch.device('cpu'))

z = model.embed(dgl_graph, dgl_graph.ndata['feat'], False)

z = z.detach().cpu().numpy()

INFO - 2023-09-26 14:24:28,075 - utils - Note: NumExpr detected 32 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO - 2023-09-26 14:24:28,076 - utils - NumExpr defaulting to 8 threads.


In [12]:
print(z.shape)

(1, 3000)


In [16]:
import pickle as pkl
with open(f'../../../../../../vol/aimspace/users/wyo/latent_spaces/vertices_prediction_sort_pool_gmae/liver/1000071', "rb") as fp:
    x = pkl.load(fp)
fp.close()
x.shape

(1, 3000)

In [7]:
from dgl.nn import SortPooling
sortpool = SortPooling(k=1000)
x= sortpool(dgl_graph, dgl_graph.ndata['feat']) 
x.shape[1]/3

1000.0

In [8]:
dgl_graph

Graph(num_nodes=1087, num_edges=7067,
      ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={})