In [1]:
import torch
import numpy as np
import argparse
import pickle
from collections import defaultdict
from pathlib import Path
from tqdm.auto import tqdm

from data_grn_processing import load_grn_dataset_dgl
from model_grn import GRNGNN, prediction_dgl
from utils import set_config_args, get_comp_g_edge_labels, get_comp_g_path_labels
from utils import src_tgt_khop_in_subgraph, eval_edge_mask_auc, eval_edge_mask_topk_path_hit
from explainer_grn import PaGELink



  from .autonotebook import tqdm as notebook_tqdm


cuda


In [2]:
# Input feature width of a DGL graph, taken from every ndata entry (not only 'feat').
def get_in_dim(mp_g):
    """
    Return the summed feature width over all 2-D node-data tensors of ``mp_g``.

    Each entry of ``mp_g.ndata`` whose shape has exactly two axes contributes
    the size of its second axis; entries of any other rank are ignored.

    Raises:
        ValueError: if ``mp_g.ndata`` contains no 2-D entry.
    """
    widths = [
        mp_g.ndata[name].shape[1]
        for name in mp_g.ndata.keys()
        if len(mp_g.ndata[name].shape) == 2  # keep only rank-2 tensors
    ]

    if len(widths) == 0:
        raise ValueError("No valid node features found in graph! Check ndata.")

    # Total in_dim is the sum of the individual feature widths.
    return sum(widths)



In [4]:
# Index of the CUDA device to use; the device-selection cell falls back to CPU
# when this is negative or CUDA is unavailable.
device_id = 1

In [7]:
# Run on the requested GPU only when CUDA is present and the index is non-negative;
# in every other case use the CPU.
use_gpu = torch.cuda.is_available() and device_id >= 0
device = torch.device('cuda', index=device_id) if use_gpu else torch.device('cpu')

print(device)

cuda:1


In [8]:
# --- Data: load the GRN dataset and move every processed split onto `device` ---
dataset_dir = "datasets"
dataset_name = "Ecoli1_basic_graph"
valid_ratio = 0.1
test_ratio = 0.2
g, processed_g = load_grn_dataset_dgl(dataset_dir, dataset_name, valid_ratio, test_ratio)
# The comprehension variable is deliberately not called `g`, so it cannot be
# confused with the full graph bound on the line above.
mp_g, train_pos_g, train_neg_g, val_pos_g, val_neg_g, test_pos_g, test_neg_g = [
    split.to(device) for split in processed_g
]

# get_in_dim raises a descriptive ValueError itself when the graph has no 2-D
# node feature, so no extra try/except is needed (the former `except KeyError`
# could not catch that error and its message wrongly referred to 'feat').
in_dim = get_in_dim(mp_g)

# --- Model hyper-parameters: single source of truth for everything below ---
hidden_dim_1 = 128
hidden_dim_2 = 64
out_dim = 32
af_val = "F.silu"
dec = "dot_sum"
num_layers = 4
num_epochs = 20
aggr = "sum"
var = "ChebConv"
saved_model_dir = "saved_models"
# Spelling of this variable name is kept as-is because other cells may refer to it.
saved_modal_name = "basic_data_Ecoli1_InSilicoSize100_model"

# Build the model from the variables above rather than from repeated literals,
# so that changing a hyper-parameter in one place is enough.
model = GRNGNN(
    in_dim,
    hidden1_channels=hidden_dim_1,
    hidden2_channels=hidden_dim_2,
    out_channels=out_dim,
    dec=dec,
    af_val=af_val,
    num_layers=num_layers,
    epoch=num_epochs,
    aggr=aggr,
    var=var,
).to(device)

# Restore the trained weights. `map_location=device` places the checkpoint on
# the device selected earlier: a bare 'cuda' would target the default GPU
# instead of `device`, and would fail outright on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint_path = Path(saved_model_dir) / f"{saved_modal_name}.pth"
state = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state)

# Explainer wrapped around the trained model. Its own settings (lr, alpha, beta,
# num_epochs, log) stay as literals: nothing visible here ties them to the
# model hyper-parameters above. Only the activation is shared.
pagelink = PaGELink(model,
                    lr=0.01,
                    alpha=1.0,
                    beta=1.0,
                    num_epochs=20,
                    log=True,
                    af_val=af_val).to(device)


In [None]:
from data_processing import load_dataset

# Location and split ratios for the "aug_citation" dataset.
dataset_dir = 'datasets'
dataset_name = "aug_citation"
valid_ratio = 0.1
test_ratio = 0.2

# NOTE: this rebinds `g`, `processed_g`, `dataset_name`, `dataset_dir` and the
# split ratios that the GRN loading cell also assigns.
loaded = load_dataset(dataset_dir, dataset_name, valid_ratio, test_ratio)
g, processed_g, pred_pair_to_edge_labels, pred_pair_to_path_labels = loaded


Type of pred_pair_to_edge_labels: <class 'dict'>
Type of pred_pair_to_path_labels: <class 'collections.defaultdict'>
Length of pred_pair_to_edge_labels: 4501
Length of pred_pair_to_path_labels: 4501
Sample keys in pred_pair_to_edge_labels: [(('author', 1), ('paper', np.int64(2217)))]
Sample keys in pred_pair_to_path_labels: [(('author', 1), ('paper', np.int64(2217)))]


In [None]:


# Summarise both label mappings: container type, number of entries, then one sample key.
summary_rows = (
    ("Type of pred_pair_to_edge_labels:", type(pred_pair_to_edge_labels)),
    ("Type of pred_pair_to_path_labels:", type(pred_pair_to_path_labels)),
    ("Length of pred_pair_to_edge_labels:", len(pred_pair_to_edge_labels)),
    ("Length of pred_pair_to_path_labels:", len(pred_pair_to_path_labels)),
    ("Sample keys in pred_pair_to_edge_labels:", list(pred_pair_to_edge_labels)[:1]),
    ("Sample keys in pred_pair_to_path_labels:", list(pred_pair_to_path_labels)[:1]),
)
for caption, shown in summary_rows:
    print(caption, shown)

In [20]:
# Show the first key of each label mapping (at most one element per list).
first_edge_key = list(pred_pair_to_edge_labels)[:1]
first_path_key = list(pred_pair_to_path_labels)[:1]
print("Sample keys in pred_pair_to_edge_labels:", first_edge_key)
print("Sample keys in pred_pair_to_path_labels:", first_path_key)

Sample keys in pred_pair_to_edge_labels: [(('author', 1), ('paper', np.int64(2217)))]
Sample keys in pred_pair_to_path_labels: [(('author', 1), ('paper', np.int64(2217)))]


In [21]:
# Build the (source, target) key for one predicted pair.
author_node = ('author', 1)
paper_node = ('paper', np.int64(2217))
key = (author_node, paper_node)

# Fetch the edge labels stored under that key and show them.
value = pred_pair_to_edge_labels[key]
print(value)


defaultdict(<class 'set'>, {('author', 'writes', 'paper'): (tensor([1, 1]), tensor([529, 212])), ('paper', 'in', 'fos'): (tensor([529, 212]), tensor([812, 812])), ('fos', 'of', 'paper'): (tensor([812]), tensor([2217]))})


In [25]:
# torch is already imported in the first cell; repeated so this cell runs on its own.
import torch

# Pair the two index tensors element-wise and collect the pairs in a set.
src_ids = torch.tensor([529, 212])
dst_ids = torch.tensor([812, 812])
print(set(zip(src_ids, dst_ids)))


{(tensor(212), tensor(812)), (tensor(529), tensor(812))}


In [22]:
# Build the (source, target) key for one predicted pair.
author_node = ('author', 1)
paper_node = ('paper', np.int64(2217))
key = (author_node, paper_node)

# Fetch the path labels stored under that key and show them.
value = pred_pair_to_path_labels[key]
print(value)


[[(('author', 'writes', 'paper'), 1, 529), (('paper', 'in', 'fos'), 529, 812), (('fos', 'of', 'paper'), 812, 2217)], [(('author', 'writes', 'paper'), 1, 212), (('paper', 'in', 'fos'), 212, 812), (('fos', 'of', 'paper'), 812, 2217)]]


In [13]:


# Run the trained model on the message-passing graph without tracking gradients.
with torch.no_grad():
    pred, pos = prediction_dgl(model, mp_g, af_val, dec)

# Show, in order: the graph summary, number of predictions, their container type, and `pos`.
for shown in (mp_g, len(pred), type(pred), pos):
    print(shown)


Graph(num_nodes=100, num_edges=173,
      ndata_schemes={'wildtype': Scheme(shape=(1,), dtype=torch.float32), 'id': Scheme(shape=(1,), dtype=torch.float32)}
      edata_schemes={'KD': Scheme(shape=(1,), dtype=torch.float32), 'KO': Scheme(shape=(1,), dtype=torch.float32)})
173
<class 'numpy.ndarray'>
[ 1.40120630e+03  3.55243848e+03  5.89308301e+03  1.02057373e+04
  1.19388418e+04  1.40031113e+04  1.61383174e+04 -2.08246914e+04
  2.00330586e+04  2.18680156e+04  2.39251426e+04  2.58888828e+04
  2.76326660e+04  2.97433340e+04  3.16176836e+04  3.35535117e+04
  3.57002969e+04  3.76867383e+04  3.96445039e+04  4.15748906e+04
  4.62188312e+05  2.22560450e+06 -2.00040922e+05  9.22621094e+04
  1.35639781e+05  1.71262285e+04  3.25263031e+05  3.38506938e+05
  3.51278406e+05  3.63046062e+05  3.76481812e+05  3.89492875e+05
  4.01435000e+05  4.14375844e+05  4.25605312e+05 -1.16599078e+05
  3.41255188e+01 -2.35964102e+04  4.62188312e+05  7.46156172e+04
 -7.68391250e+04  6.37151211e+04 -9.34809766e+04 

In [None]:

# Pick one positive test edge and extract its k-hop computation subgraph.
test_src_nids, test_tgt_nids = test_pos_g.edges()
comp_graphs = defaultdict(list)
comp_g_labels = defaultdict(list)
num_hops = 2
i = 4  # index of the test edge to inspect

# Get the k-hop subgraph around the (source, target) pair.
src_nid, tgt_nid = test_src_nids[i], test_tgt_nids[i]
(comp_g_src_nid,
 comp_g_tgt_nid,
 comp_g,
 comp_g_feat_nids,
 comp_g_eids) = src_tgt_khop_in_subgraph(src_nid, tgt_nid, mp_g, num_hops)

# Predictions are computed on the full message-passing graph `mp_g`.
with torch.no_grad():
    pred, pos = prediction_dgl(model, mp_g, af_val, dec)

# Locate the src -> tgt edge among the edges of the SAME graph that produced
# `pred`. The earlier version built the mask from `comp_g.edges()` with
# subgraph-local node IDs, so the mask neither had the length of `pred` nor
# referred to the same edges.
edge_index = torch.stack(mp_g.edges(), dim=0).cpu().numpy()
src_tgt_pair = np.array([[int(src_nid)], [int(tgt_nid)]])

# Assumes `pred` holds one entry per edge of `mp_g`, in `mp_g.edges()` order
# (its length matches the edge count in the earlier output) -- fail loudly if not.
if len(pred) != edge_index.shape[1]:
    raise ValueError(
        f"prediction length {len(pred)} does not match the number of edges "
        f"{edge_index.shape[1]} in mp_g"
    )

# Boolean mask that is True exactly where an edge equals (src_nid, tgt_nid).
mask = np.all(edge_index == src_tgt_pair, axis=0)

# Keep the computation subgraph only when the edge exists and is predicted positive.
if mask.sum() > 0 and pred[mask][0]:
    src_tgt = (int(src_nid), int(tgt_nid))
    comp_graphs[src_tgt] = [comp_g_src_nid, comp_g_tgt_nid, comp_g, comp_g_feat_nids]
