In [2]:
from dig.xgraph.method import SubgraphX
from dig.xgraph.method.subgraphx import find_closest_node_result

from model import GNN
from pretrain_joao import graphcl

from torch_geometric.data import DataLoader, Data
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [25]:
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from loader import BioDataset
from dataloader import DataLoaderFinetune
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

class GNN_graphpred_for_x(torch.nn.Module):
    """
    Extension of GIN to incorporate edge information by concatenation.

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        num_tasks (int): number of tasks in multi-task learning scenario
        drop_ratio (float): dropout rate
        JK (str): last, concat, max or sum.
        graph_pooling (str): sum, mean, max, attention, set2set
        
    See https://arxiv.org/abs/1810.00826
    JK-net: https://arxiv.org/abs/1806.03536
    """
    def __init__(self, num_layer, emb_dim, num_tasks, JK = "last", drop_ratio = 0, graph_pooling = "mean", gnn_type = "graphsage"):
        super(GNN_graphpred_for_x, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        #if self.num_layer < 2:
        #    raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn = GNN(num_layer, emb_dim, JK, drop_ratio, gnn_type = gnn_type)

        #Different kind of graph pooling
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn = torch.nn.Linear(100, 1))
        else:
            raise ValueError("Invalid graph pooling type.")
        self.linear = torch.nn.Linear(self.emb_dim, 100)
        self.graph_pred_linear = torch.nn.Linear(100, self.num_tasks)
        self.softmax = torch.nn.Softmax(dim=1)

    #def from_pretrained(self, model_file):
    #    self.gnn.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))

    def forward(self, x, edge_index, do_visualize=False):
        #x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        
        node_representation = self.gnn(x, edge_index, None)
        device = torch.device('cuda:4')
        pooled = self.pool(node_representation, torch.LongTensor([0]*116).to(device))
        
        #graph_rep = torch.cat([pooled, center_node_rep], dim = 1)
        graph_rep = self.linear(pooled)
        
        #if do_visualize:
        #    visualize(graph_rep, data.y)      
  
        return self.softmax(self.graph_pred_linear(graph_rep))


In [26]:
num_layer = 2
emb_dim = 128
num_tasks = 2
dropout_ratio = 0.5
graph_pooling = 'mean'
gnn_type = 'graphsage'

In [27]:
model = GNN_graphpred_for_x(num_layer, emb_dim, num_tasks, JK = "last", drop_ratio=dropout_ratio, graph_pooling=graph_pooling, gnn_type = gnn_type)
model_graphcl = graphcl(model.gnn, emb_dim)
model.gnn = model_graphcl.gnn

In [28]:
PATH = '/nasdata3/kyj/graphcl/GraphCL_Automated/transferLearning_MoleculeNet_PPI/bio/fintune_weight/'
FILE = 'graphsage_none_lr1e-4_NYU_epoch300_batch800_s_decay0_layer2_5fold_dim128_fc100_fold5_seed0.pt'
model.load_state_dict(torch.load(PATH+FILE))

<All keys matched successfully>

In [29]:
print(model)

GNN_graphpred_for_x(
  (gnn): GNN(
    (gnns): ModuleList(
      (0): GraphSAGEConv(
        (linear): Linear(in_features=128, out_features=128, bias=True)
        (edge_encoder): Linear(in_features=9, out_features=128, bias=True)
        (linear1): Linear(in_features=232, out_features=128, bias=True)
      )
      (1): GraphSAGEConv(
        (linear): Linear(in_features=128, out_features=128, bias=True)
        (edge_encoder): Linear(in_features=9, out_features=128, bias=True)
      )
    )
  )
  (linear): Linear(in_features=128, out_features=100, bias=True)
  (graph_pred_linear): Linear(in_features=100, out_features=2, bias=True)
  (softmax): Softmax(dim=1)
)


In [30]:
device = torch.device('cuda:4')

In [31]:
#ABIDE1에서 제공하는 전처리된 aal roi 사용해서 ASD and control classification(downstream task)을 위한 dataset 만들기
sub_list = open('/nasdata4/kyj0305/ABIDE/Download_preprocess/NYU_list','r').read().split('\n')
sub_list.pop()
one_hot = torch.zeros(116,116).to(device)
for i in range(116):
    one_hot[i][i] = 1
dataset=[]
x = []
edge_index = []
for i in sub_list:
    with open(f'/nasdata4/kyj0305/ABIDE/Download_preprocess/roi_aal_1d/pcc_{i}.txt','r') as f:
       #l = pickle.load(f)
       x =[[float(num) for num in line.split(' ')] for line in f]
       x = torch.FloatTensor(x).to(device)
       x = torch.transpose(x,0,1)
       x = torch.cat((x,one_hot),1)
       #print(x)
        
    with open(f'/nasdata4/kyj0305/ABIDE/Download_preprocess/y/{i}.txt','r') as f:
        f = f.read()
        y = torch.LongTensor([int(f)]).to(device)
        
    with open(f'/nasdata4/kyj0305/ABIDE/Download_preprocess/edge/edge_index20/ROICorrelation_FisherZ_{i}_edge_index.txt') as f:
        edge_index = torch.LongTensor([[int(num) for num in line.split()] for line in f]).to(device)
        edge_index = torch.transpose(edge_index,0,1)

    my_data = Data(x=x, edge_index=edge_index, y=y)

    dataset.append(my_data)

In [32]:
x = dataset[0].x
edge_indel=dataset[0].edge_index
print(x)
print(edge_index)

tensor([[ 1.0000,  0.3842,  0.0923,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3842,  1.0000,  0.0387,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0923,  0.0387,  1.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.1210,  0.0976, -0.0380,  ...,  1.0000,  0.0000,  0.0000],
        [-0.2144,  0.2723,  0.1242,  ...,  0.0000,  1.0000,  0.0000],
        [-0.2691,  0.1478, -0.0415,  ...,  0.0000,  0.0000,  1.0000]],
       device='cuda:4')
tensor([[  0,   0,   0,  ..., 113, 113, 113],
        [  1,   2,   3,  ...,  99, 102, 110]], device='cuda:4')


In [37]:
explainer = SubgraphX(model, num_classes=2, device=4)
_, explanation_results, related_preds = explainer(x, edge_index, max_nodes=5)

TypeError: forward() got an unexpected keyword argument 'data'

: 