In [248]:
import json
import networkx as nx
import numpy as np
import torch
import KGNN
from tqdm import tqdm
from torch_geometric.utils import to_networkx, from_networkx
from networkx.algorithms.traversal.breadth_first_search import bfs_edges
from torch_geometric.data import Data


# Data path
data_path = '../data_process/pretrain_data'
# Training Labels
## Load entity type labels
print('Loading entity type labels...')
ent_type = torch.tensor(np.load(f'{data_path}/ent_type_onehot.npy')) # (num_ent, num_ent_type)

## Load center molecule motifs
print('Loading center molecule motifs...')
motifs = []
with open(f'{data_path}/id2motifs.json', 'r') as f:
    id2motifs = json.load(f)
motif_len = len(id2motifs['0'])
for i in range(len(ent_type)):
    if str(i) in id2motifs.keys():
        motifs.append(np.array(id2motifs[str(i)]))
    else:
        motifs.append(np.array([0] * motif_len))

motifs = torch.tensor(np.array(motifs), dtype=torch.long) # (num_ent, motif_len)

## Center molecule ids
center_molecule_ids = [int(key) for key in id2motifs.keys()]

# Entire Knowledge Graph (MolKG)
print('Loading entire knowledge graph...')
G = nx.read_gpickle(f'{data_path}/graph.gpickle')
G_tg = from_networkx(G)

# molecule_mask
print('Loading molecule mask...')
molecule_mask = torch.tensor(ent_type[:,0][G_tg.edge_index[0]] == 1) # (num_edges,)

Loading entity type labels...
Loading center molecule motifs...
Loading entire knowledge graph...
Loading molecule mask...


In [247]:
motifs.shape

torch.Size([32022, 85])

In [95]:
from networkx.algorithms.traversal.breadth_first_search import bfs_edges
from torch_geometric.utils import k_hop_subgraph

def get_subgraph(G_nx, idx):    
    # Prepare a subgraph with the specified center molecule
    subgraph = nx.Graph()

    # BFS, limit to 3 hops
    bfs_tree = bfs_edges(G_nx, idx, depth_limit=2)
    
    for edge in bfs_tree:
        node1, node2 = edge

        # Stop extending if the node is not a molecule.
        ## Case 1: node1 is molecule, node2 is not
        if ent_type[node1][0] == 1 and ent_type[node2][0] != 1:
            subgraph.add_edge(node1, node2, **G_nx.edges[edge])
        ## Case 2: node1 is molecule, node2 is also molecule
        elif (ent_type[node1][0] == 1) and (ent_type[node2][0] == 1):
            subgraph.add_edge(node1, node2, **G_nx.edges[edge])
                
    return subgraph


In [109]:
from torch_geometric.utils import k_hop_subgraph
import torch

molecule_mask = ent_type[:,0][G_tg.edge_index[0]] == 1
def get_subgraph(G_tg, ent_type, idx):
    nodes, filt_edge_index, _, edge_mask = k_hop_subgraph(idx, 1, G_tg.edge_index)
    double_mask = molecule_mask * edge_mask
    mask_idx = torch.where(double_mask)[0]
    edge_subgraph = G_tg.edge_subgraph(mask_idx)
    subgraph = edge_subgraph.subgraph(nodes)
    masked_node_ids = edge_subgraph.edge_index[0]

    return subgraph, masked_node_ids

    

In [124]:
nodes, filt_edge_index, idx_mapping, edge_mask = k_hop_subgraph(0, 1, G_tg.edge_index)

In [129]:
filt_edge_index[0][0]

tensor(0)

In [130]:
import torch

ent_type = torch.tensor(ent_type, dtype=torch.long)

In [234]:
# nodes, filt_edge_index, _, edge_mask = k_hop_subgraph(0, 1, G_tg.edge_index)
# molecule_mask = ent_type[:,0][filt_edge_index[0]] == 1
# masked_edge_index= filt_edge_index[:, molecule_mask] # (2, num_edges)
# relation = G_tg.relation[edge_mask][molecule_mask]
# subgraph = Data(edge_index=masked_edge_index, relabel=True)
# subgraph.relation = relation

In [252]:
molecule_mask = torch.tensor(ent_type[:,0][G_tg.edge_index[0]] == 1)
ent_type = torch.tensor(np.load(f'{data_path}/ent_type_onehot.npy'))

nodes, _, _, edge_mask = k_hop_subgraph(0, 1, G_tg.edge_index)
double_mask = molecule_mask * edge_mask
mask_idx = torch.where(double_mask)[0]
edge_subgraph = G_tg.edge_subgraph(mask_idx)
subgraph = edge_subgraph.subgraph(nodes)
masked_node_ids = edge_subgraph.edge_index[0] # (num_masked_nodes,)
motif_labels = motifs[masked_node_ids] # (num_masked_nodes, motif_len)
node_labels = ent_type[masked_node_ids] # (num_masked_nodes, num_ent_type)
rel_labels = subgraph.relation

subgraph, masked_node_ids.shape, motif_labels.shape, node_labels.shape, rel_labels.shape

(Data(edge_index=[2, 1951], relation=[1951], num_nodes=385),
 torch.Size([1951]),
 torch.Size([1951, 85]),
 torch.Size([1951, 16]),
 torch.Size([1951]))

In [251]:
rel_labels

tensor([ 0,  0,  0,  ..., 31, 31, 31])

In [238]:
print(G_tg.edge_index[:, mask_idx][0])

tensor([    0,     0,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,    29,    29,    29,    29,   661,   661,   661,   661,   661,
          661,   661,   661,   661,   661,  2277,  2343,  2343,  2343, 18686,
        18686, 18686, 18686, 18686, 18686, 18686, 28000, 28000, 28000, 28749,
        29943, 29943, 29943, 29943, 29943, 29943, 33784, 33784, 33784, 33784,
        33784, 39941, 39941, 39941, 39941, 39941, 39941, 42393, 42393, 42393,
        42393, 42393, 42393, 42393, 42393, 42393, 42624, 42624, 42624, 42624,
        42624, 42624, 43274, 43462, 48532, 48532, 48532, 48532, 48532, 48532,
        58132, 58132, 58132, 58132, 58132, 58132, 58132, 60341, 65188])


In [239]:
subgraph.edge_index

tensor([[ 0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  4,
          5,  5,  5,  6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  8,  9,  9,  9,  9,
          9,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 14, 15, 16, 16, 16, 16, 16, 16,
         17, 17, 17, 17, 17, 17, 17, 18, 19],
        [ 1,  7,  0,  2,  4,  5,  3, 18,  8,  6, 14, 11, 16,  7, 15, 13,  9, 10,
         17, 19, 12,  1,  2,  5,  3,  1,  2, 16, 11, 12, 10,  6, 13,  9, 17,  1,
          1,  2,  7,  1,  3, 13,  9, 17, 12, 11,  0,  1,  5,  1,  1,  3,  6, 13,
         17, 12,  1,  3, 11, 16, 12,  1,  3,  6, 10, 16, 12,  1,  3,  6,  9, 10,
         11, 16, 13, 17,  1,  3,  6,  9, 12, 17,  1,  1,  1,  3, 10, 11, 12, 17,
          1,  3,  6,  9, 12, 13, 16,  1,  1]])

In [107]:
s1.number_of_nodes(), s1.number_of_edges()

(3433, 3432)

In [94]:
for c_id in tqdm(center_molecule_ids):
    subgraph = s1 = get_subgraph(G, c_id)
    

16979

In [1]:
from pretrain import *

# your credentials

params = {
    "lr": 1e-4,
    "hidden_dim": 200,
    "epochs": 100,
    "lambda": "08_15_10"
}
logger = get_logger(lr=params['lr'], hidden_dim=params['hidden_dim'], epochs=params['epochs'], lambda_=params['lambda'])


# Data path
data_path = '../data_process/pretrain_data'
print('Getting everything prepared...')
ent_type, motifs, G_tg, center_molecule_ids, molecule_mask = get_everything(data_path)

# Get dataloader
train_loader, val_loader = get_dataloader(G_tg, center_molecule_ids, molecule_mask, motifs, ent_type, batch_size=1)

# Load KGE embeddings
# return:
# entity_embedding: (num_ent, emb_dim)
# relation_embedding: (num_rel, emb_dim)
emb_path = '/data/pj20/molkg_kge/transe'
print('Loading KGE embeddings...')
entity_embedding, relation_embedding = load_kge_embeddings(emb_path)

# Initialize model
print('Initializing model...')
model = KGNN(
    node_emb=entity_embedding,
    rel_emb=relation_embedding,
    num_nodes=ent_type.shape[0],
    num_rels=39,
    embedding_dim=512,
    hidden_dim=200,
    num_motifs=motifs.shape[1],
    lambda_edge=0.8,
    lambda_motif=1.5,
    lambda_mol_class=1
)

# Train
device = torch.device('cuda:1')
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 100

  from .autonotebook import tqdm as notebook_tqdm


Getting everything prepared...
Loading entity type labels...
Loading center molecule motifs...
Loading entire knowledge graph...
Loading molecule mask...
Loading KGE embeddings...
Loading KGE embeddings...


  molecule_mask = torch.tensor(ent_type[:,0][G_tg.edge_index[0]] == 1) # (num_edges,)
  return torch.tensor(entity_embedding), torch.tensor(relation_embedding)


Initializing model...


In [4]:
for batch in train_loader:
    print(batch)
    break

DataBatch(edge_index=[2, 122], relation=[122], num_nodes=22, masked_node_ids=[122], rel_ids=[122], center_molecule_id=[1], motif_labels=[122, 85], node_labels=[122, 16], batch=[22], ptr=[2])


In [2]:
 train_loop(model, train_loader, val_loader, optimizer, device, epochs)

  0%|          | 0/30420 [00:00<?, ?it/s]


RuntimeError: Index tensor must have the same number of dimensions as self tensor

In [6]:
torch.zeros((5, 39))

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])