# Imports

In [3]:
import argparse
import random
import pickle as pkl

import dgl

from dataset import *
from dgl.data import GINDataset
from dgl.dataloading import GraphDataLoader
from evaluate_embeddings import evaluate_embedding
from model import InfoGraph

INFO - 2023-09-27 17:02:10,141 - utils - Note: NumExpr detected 32 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO - 2023-09-27 17:02:10,142 - utils - NumExpr defaulting to 8 threads.


# Functions

In [4]:
def collate(samples):
    """Batch a list of DGL graphs and tag each node with its graph's batch index.

    Parameters
    ----------
    samples : list of dgl.DGLGraph
        The graphs drawn by the dataloader for one batch (no labels attached).

    Returns
    -------
    dgl.DGLGraph
        The batched graph. ``ndata["graph_id"]`` holds, for every node, the
        position (0 .. len(samples) - 1) of the sample that node came from.
    """
    batch = dgl.batch(samples)
    per_graph_index = torch.arange(len(samples))
    # Repeat each graph's index once for every node belonging to that graph.
    batch.ndata["graph_id"] = dgl.broadcast_nodes(batch, per_graph_index)
    return batch

# Arguments

In [5]:
class args:
    """Hyper-parameter namespace for this notebook.

    Attributes are read straight off the class (e.g. ``args.lr``), the class
    itself is handed to ``build_model``, and ``num_features`` is overwritten
    once the training data has been loaded.
    """

    # -- run control -------------------------------------------------------
    seed = 42
    device = 0  # NOTE: the cell that creates `device` hard-codes CPU and does not read this
    max_epoch = 200
    warmup_steps = -1

    # -- architecture, optimiser and masking -------------------------------
    num_heads = 1
    num_out_heads = 1
    num_layers = 2
    num_hidden = 256
    residual = False
    in_drop = 0.2
    attn_drop = 0.1
    norm = None
    lr = 0.001
    weight_decay = 5e-4
    negative_slope = 0.2
    activation = "prelu"
    mask_rate = 0.3
    drop_edge_rate = 0.0
    replace_rate: float = 0.15

    # -- model family and loss ---------------------------------------------
    encoder = "gat"
    decoder = "gat"
    loss_fn = "sce"
    alpha_l = 2  # power coefficient applied in the "sce" loss
    optimizer = "adam"

    # -- "_f" settings, passed to pretrain alongside linear_prob -----------
    # (presumably the downstream / linear-probe stage — confirm in pretrain)
    max_epoch_f = 30
    lr_f = 0.001
    weight_decay_f = 0.0
    linear_prob = True
    concat_hidden = True
    num_features = 0  # placeholder; replaced with the node-feature dimension after loading data

In [6]:
# Unpack the `args` namespace into the module-level names later cells use.
# The device is fixed to CPU; `args.device` is not consulted here.
device = torch.device('cpu')

# run control and architecture
seed, max_epoch, max_epoch_f = args.seed, args.max_epoch, args.max_epoch_f
num_hidden, num_layers = args.num_hidden, args.num_layers
encoder_type, decoder_type = args.encoder, args.decoder
replace_rate = args.replace_rate

# optimisation
optim_type, loss_fn = args.optimizer, args.loss_fn
lr, weight_decay = args.lr, args.weight_decay
lr_f, weight_decay_f = args.lr_f, args.weight_decay_f

# evaluation switches
linear_prob, concat_hidden = args.linear_prob, args.concat_hidden

In [7]:
# Seed every RNG the notebook touches (Python, NumPy, PyTorch CPU and CUDA) from
# the configured seed rather than a duplicated literal, so runs are reproducible.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
# Fixed typo: the original assigned `cudnn.determinstic`, which only created a new,
# unused attribute and left cuDNN in its default (non-deterministic) mode.
torch.backends.cudnn.deterministic = True
# cuDNN autotuning may pick different kernels between runs; keep it off.
torch.backends.cudnn.benchmark = False

# Data

In [8]:
# Liver meshes of the subjects used for pre-training.
LIVER_MESH_DIR = "../../../../../../../vol/aimspace/users/wyo/organ_decimations_ply/2000/"
PRETRAIN_SUBJECTS = ["1000180", "1000071", "2901448"]

# `get_single_subject` (defined in the helper cell) reads
# `<LIVER_MESH_DIR><subject>/liver_mesh.ply`, converts it to a DGL graph and
# normalises self-loops. These are the same steps that were previously
# copy-pasted once per subject.
graphs = [
    get_single_subject(LIVER_MESH_DIR, "liver_mesh.ply", subject_id)
    for subject_id in PRETRAIN_SUBJECTS
]
print(graphs[0])
dgl_graph = graphs[-1]
dgl_graph

Graph(num_nodes=1150, num_edges=7120,
      ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={})


Graph(num_nodes=1086, num_edges=7064,
      ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={})

In [12]:
# One shuffled batch per epoch. `collate` writes a per-node "graph_id" so every
# node can be traced back to the mesh it came from.
dataloader = GraphDataLoader(graphs, collate_fn=collate, batch_size=3, shuffle=True, drop_last=False)
dataloader

<dgl.dataloading.dataloader.GraphDataLoader at 0x7f8a846d34c0>

In [13]:
# Sanity check: each batch should carry the per-node graph index written by `collate`.
for graph in dataloader:
    print(graph.ndata["graph_id"])

tensor([0, 0, 0,  ..., 2, 2, 2])


In [13]:
# Pre-processed liver data; its repr below looks like a PyG-style `Data` object.
# NOTE: pickle.load can execute arbitrary code — only load files this project created.
with open('../../data/gae/liver/data', 'rb') as f:
    data = pkl.load(f)
data

Data(x=[1073, 3], y=[2, 4534], num_nodes=1073, val_pos_edge_index=[2, 106], test_pos_edge_index=[2, 213], train_pos_edge_index=[2, 3622], train_neg_adj_mask=[1073, 1073], val_neg_edge_index=[2, 106], test_neg_edge_index=[2, 213])

In [14]:
# The input dimensionality is the width of the node-feature matrix.
# NOTE(review): num_classes is also set to the feature width, not a label count —
# confirm this is what `pretrain` expects.
args.num_features = num_classes = data.x.shape[1]
scheduler = logger = None
args.num_features

3

# Model

In [None]:
# `build_model` / `create_optimizer` are not defined in this notebook; presumably
# they come from the star import of `dataset`. Module.to() returns the module itself.
model = build_model(args).to(device)
optimizer = create_optimizer(optim_type, model, lr, weight_decay)

In [157]:
# Accumulators for evaluation results; no LR scheduler or logger is used.
acc_list, estp_acc_list = [], []
scheduler = logger = None

In [158]:
# Pre-train with the same list of graphs in all four data slots of `pretrain`,
# then move the trained model back to the CPU.
# (An older variant passed separate train/valid/test/eval-train dataloaders plus `data`.)
model = pretrain(
    model, (graphs,) * 4, optimizer, max_epoch, device, scheduler,
    num_classes, lr_f, weight_decay_f, max_epoch_f, linear_prob, logger,
).cpu()

2023-06-21 15:34:15,204 - INFO - start training..
# Epoch 100 | train_loss: 0.0004:  50%|█████     | 101/200 [00:24<00:34,  2.91it/s]

# IGNORE: --- Best ValAcc: 0.5383 in epoch 2, Early-stopping-TestAcc: 0.5383,  Final-TestAcc: 0.4947--- 


# Epoch 199 | train_loss: 0.0005: 100%|██████████| 200/200 [00:49<00:00,  4.08it/s]


# Encoding

In [2]:
# Load the whole pickled model object (later cells call methods such as `.embed`
# on it) and place its tensors on the CPU.
model = torch.load("../../models/liver_mesh.ply_graphmae.pt", map_location="cpu")

# Pooling Embeddings

In [1]:
############################################################################################

def triangle_mesh_to_adjacency_matrix(mesh):
    """Build the symmetric vertex adjacency matrix of a triangle mesh.

    Every triangle (a, b, c) contributes the undirected edges a-b, b-c and c-a.
    Each edge is stored in both directions with value 1.0. An edge shared by
    several triangles still has value 1.0.

    Parameters
    ----------
    mesh : object with ``vertices`` and ``triangles`` attributes
        E.g. an ``open3d.geometry.TriangleMesh``. ``triangles`` must be
        convertible to an integer array of vertex-index triples.

    Returns
    -------
    scipy.sparse.csr_matrix
        float32 matrix of shape (n_vertices, n_vertices), in canonical form
        (sorted indices, no duplicates).
    """
    n_vertices = len(np.asarray(mesh.vertices))
    # reshape keeps an empty triangle list valid whatever shape asarray gives it.
    triangles = np.asarray(mesh.triangles, dtype=np.int64).reshape(-1, 3)

    # Vectorised replacement for the per-triangle lil_matrix assignments:
    # list each triangle's three directed edges, then mirror them.
    src = np.concatenate([triangles[:, 0], triangles[:, 1], triangles[:, 2]])
    dst = np.concatenate([triangles[:, 1], triangles[:, 2], triangles[:, 0]])
    rows = np.concatenate([src, dst])
    cols = np.concatenate([dst, src])
    values = np.ones(rows.shape[0], dtype=np.float32)

    adjacency_matrix = sp.coo_matrix(
        (values, (rows, cols)), shape=(n_vertices, n_vertices)
    ).tocsr()
    # Conversion sums duplicate entries (edges shared by two triangles). Make the
    # format canonical, then binarise so every stored edge is exactly 1.0.
    adjacency_matrix.sum_duplicates()
    adjacency_matrix.data[:] = 1.0

    return adjacency_matrix

#############################################################################################################

def open3d_to_dgl_graph(path, open3d_geometry, include_normals=False, include_intensities=False):
    """Convert an Open3D triangle mesh into a DGL graph with per-vertex features.

    The graph has one node per mesh vertex and an edge in each direction for
    every triangle edge (built by ``triangle_mesh_to_adjacency_matrix``).

    Previously this function always computed vertex normals (mutating the mesh)
    and always read a second "intensity" mesh from disk, then discarded both.
    Only the xyz coordinates were used as features. Those steps now run only
    when requested, so the default output is unchanged and the extra I/O is gone.

    Parameters
    ----------
    path : str
        Path the mesh was read from. Only used when ``include_intensities`` is
        set: ``"registered_meshes"`` is replaced by ``"organ_decimations_ply"``
        to locate the mesh whose vertex colours hold the intensities.
    open3d_geometry : open3d.geometry.TriangleMesh
        The mesh to convert. Vertex normals are computed on it in place only
        when ``include_normals`` is set.
    include_normals : bool, default False
        Append the vertex normals to the features.
    include_intensities : bool, default False
        Append the intensity mesh's vertex colours to the features.

    Returns
    -------
    dgl.DGLGraph
        Graph whose ``ndata['feat']`` is a float32 tensor. Its columns are xyz,
        followed by normals and/or intensities when requested.

    Raises
    ------
    ValueError
        If the intensity mesh's colour count does not match the vertex count.
    """
    adjacency_matrix = triangle_mesh_to_adjacency_matrix(open3d_geometry)
    dgl_graph = dgl.from_scipy(adjacency_matrix)

    # Order matches the previously commented-out (points, normals, intensities) layout.
    feature_blocks = [np.array(open3d_geometry.vertices)]

    if include_normals:
        open3d_geometry.compute_vertex_normals()
        feature_blocks.append(np.array(open3d_geometry.vertex_normals))

    if include_intensities:
        intensity_path = path.replace("registered_meshes", "organ_decimations_ply")
        intensity_mesh = o3d.io.read_triangle_mesh(intensity_path)
        intensities_np = np.array(intensity_mesh.vertex_colors)
        if intensities_np.shape[0] != feature_blocks[0].shape[0]:
            raise ValueError(
                f"intensity mesh {intensity_path!r} has {intensities_np.shape[0]} vertex colours, "
                f"expected {feature_blocks[0].shape[0]}"
            )
        feature_blocks.append(intensities_np)

    features = feature_blocks[0] if len(feature_blocks) == 1 else np.concatenate(feature_blocks, axis=1)
    dgl_graph.ndata['feat'] = torch.tensor(features, dtype=torch.float32)

    return dgl_graph

#############################################################################################################

def get_single_subject(path, organ, subject):
    """Load one subject's organ mesh and return it as a DGL graph.

    The mesh file is ``f'{path}{subject}/{organ}'``, so ``path`` should end with a
    separator. It is converted with ``open3d_to_dgl_graph``, and its self-loops are
    normalised so that every node has exactly one.
    """
    mesh_file = f'{path}{subject}/{organ}'
    graph = open3d_to_dgl_graph(mesh_file, o3d.io.read_triangle_mesh(mesh_file))
    # Remove any existing self-loops first so add_self_loop() cannot duplicate them.
    return graph.remove_self_loop().add_self_loop()

############################################################################################

In [14]:
import open3d as o3d
import numpy as np
import pandas as pd
import scipy.sparse as sp
import dgl
import torch
# Embed one subject's liver mesh with a saved model and run its decoder.
dgl_graph = get_single_subject("../../../../../../vol/aimspace/users/wyo/registered_meshes/2000/", "liver_mesh.ply", "1000071")
model_path = "../../models/liver_mesh.ply_graphmae_no_feat.pt"
model = torch.load(model_path, map_location=torch.device('cpu'))
# Fixed: switch to inference mode. A pickled module keeps the `training` flag it
# was saved with, so any dropout layers would otherwise stay active and make the
# embeddings non-deterministic.
model.eval()

# No gradients are needed for inference; this also avoids building an autograd graph.
with torch.no_grad():
    # NOTE(review): the meaning of embed's third positional argument lives in the
    # model class (not visible here) — True in this cell, False in the pooling cell.
    z = model.embed(dgl_graph, dgl_graph.ndata['feat'], True)
    rep, pred = model.decode(dgl_graph, z)

pred = pred.detach().cpu().numpy()
rep = rep.detach().cpu().numpy()
z = z.detach().cpu().numpy()

In [18]:
np.shape(z)  # node-level embedding: (num_nodes, embedding_dim) per the output below

(1087, 1280)

In [17]:
np.shape(rep)  # decoder representation, one row per node

(1087, 256)

In [23]:
# Mean-pool the per-node representations into a single graph-level vector.
graph_embedding = rep.mean(axis=0)
np.shape(graph_embedding)

(256,)

In [2]:
import open3d as o3d
import numpy as np
import pandas as pd
import scipy.sparse as sp
import dgl
import torch
# Embed the same subject with the pooling (gat_pool_mlp) model variant.
dgl_graph = get_single_subject("../../../../../../vol/aimspace/users/wyo/registered_meshes/2000/", "liver_mesh.ply", "1000071")
model_path = "../../models/liver_mesh.ply_graphmae_gat_pool_mlp.pt"
model = torch.load(model_path, map_location=torch.device('cpu'))
# Fixed: inference mode, so dropout saved in training mode cannot perturb the embedding.
model.eval()

with torch.no_grad():
    # With the third argument False this model returned a single (1, 3000) row
    # in the recorded run below.
    z = model.embed(dgl_graph, dgl_graph.ndata['feat'], False)

z = z.detach().cpu().numpy()

INFO - 2023-09-26 14:24:28,075 - utils - Note: NumExpr detected 32 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO - 2023-09-26 14:24:28,076 - utils - NumExpr defaulting to 8 threads.


In [12]:
print(np.shape(z))  # pooled graph embedding shape

(1, 3000)


In [16]:
import pickle as pkl
# Previously saved sort-pool latent vector for subject 1000071, for comparison with `z`.
latent_path = '../../../../../../vol/aimspace/users/wyo/latent_spaces/vertices_prediction_sort_pool_gmae/liver/1000071'
# NOTE: pickle.load can execute arbitrary code — only load files this project wrote.
# The `with` block closes the file, so no explicit fp.close() is needed.
with open(latent_path, "rb") as fp:
    x = pkl.load(fp)
x.shape

(1, 3000)

In [7]:
from dgl.nn import SortPooling
# SortPooling keeps k nodes, ordered by their last feature channel, and flattens
# their features, so x has shape (1, k * feat_dim).
sortpool = SortPooling(k=1000)
x = sortpool(dgl_graph, dgl_graph.ndata['feat'])
# Recover k from the flattened width. Divide by the actual feature dimension rather
# than a hard-coded 3, so this stays correct if more per-vertex features are added.
x.shape[1] / dgl_graph.ndata['feat'].shape[1]

1000.0

In [8]:
# Show the graph summary (node/edge counts and node-feature schema).
dgl_graph

Graph(num_nodes=1087, num_edges=7067,
      ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={})