In [15]:
import json
import csv

# IDs paired position-by-position with the SMILES rows read below.
with open('/home/pj20/gode/data_process/valid_smiles_ids.json', 'r') as f:
    valid_smiles_ids = json.load(f)

# The csv module asks for files to be opened with newline='' so that quoted
# fields containing line breaks are parsed correctly on every platform.
with open('/home/pj20/gode/data_process/valid_smiles.csv', newline='') as f:
    reader = csv.reader(f)
    next(reader, None)  # skip header (tolerates an empty file)

    # The SMILES string is the first column of every data row.
    valid_smiles = [line[0] for line in reader]

In [16]:
# Sanity check: display both lengths side by side; the two lists are paired
# by position later, so the counts are expected to match.
len(valid_smiles_ids), len(valid_smiles)

(41212, 41212)

In [17]:
# Map each SMILES string to the ID stored at the same position in
# valid_smiles_ids. The pairing is purely positional, so refuse to build the
# mapping when the two lists disagree in length instead of failing midway
# (IDs too short) or silently ignoring the tail (IDs too long).
if len(valid_smiles) != len(valid_smiles_ids):
    raise ValueError(
        f"valid_smiles has {len(valid_smiles)} entries but "
        f"valid_smiles_ids has {len(valid_smiles_ids)}; they must align one-to-one"
    )

# For a SMILES string that occurs more than once, the last occurrence wins.
smile2entid = dict(zip(valid_smiles, valid_smiles_ids))

In [1]:
import torch
import numpy as np
import json
from grover.model.models import GroverFpGeneration, GroverFinetuneTask, GroverFinetuneKGE, GroverKGNNFinetuneTask, KGNN, MGNN

def get_everything(data_path):
    """Load the entity-type labels, motif labels and knowledge graph from ``data_path``.

    Three files are read from the directory:

    * ``ent_type_onehot.npy`` -- entity-type array, one row per entity.
    * ``id2motifs.json`` -- dict keyed by the entity index written as a
      string; each value is that entity's motif vector. Entities without an
      entry receive an all-zero row.
    * ``graph.pt`` -- the knowledge graph, deserialized with ``torch.load``.

    Args:
        data_path: Directory that contains the three files above.

    Returns:
        Tuple ``(ent_type, motifs, G_tg)`` where ``ent_type`` is the tensor
        built from the ``.npy`` file, ``motifs`` is a ``torch.long`` tensor of
        shape ``(len(ent_type), motif_len)`` and ``G_tg`` is whatever object
        ``graph.pt`` contains.

    Raises:
        ValueError: If a motif vector's length differs from ``motif_len``.
    """
    # Training Labels
    ## Load entity type labels
    print('Loading entity type labels...')
    ent_type = torch.tensor(np.load(f'{data_path}/ent_type_onehot.npy')) # (num_ent, num_ent_type)

    ## Load center molecule motifs
    print('Loading center molecule motifs...')
    with open(f'{data_path}/id2motifs.json', 'r') as f:
        id2motifs = json.load(f)

    # Entity '0' defines the vector width when it is present; otherwise fall
    # back to any available entry so a missing '0' key no longer raises.
    if '0' in id2motifs:
        reference = id2motifs['0']
    else:
        reference = next(iter(id2motifs.values()), [])
    motif_len = len(reference)

    # Fill one preallocated table rather than building an array per entity.
    num_ent = len(ent_type)
    motif_table = np.zeros((num_ent, motif_len), dtype=np.int64)
    for i in range(num_ent):
        row = id2motifs.get(str(i))
        if row is None:
            continue  # no motifs recorded for this entity -> keep the zero row
        if len(row) != motif_len:
            raise ValueError(
                f'motif vector for entity {i} has length {len(row)}, expected {motif_len}'
            )
        motif_table[i] = row

    motifs = torch.tensor(motif_table, dtype=torch.long) # (num_ent, motif_len)


    # Entire Knowledge Graph (MolKG)
    print('Loading entire knowledge graph...')
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only point data_path at graph files from a trusted source.
    with open(f'{data_path}/graph.pt', 'rb') as f:
        G_tg = torch.load(f)

    return ent_type, motifs, G_tg

  from .autonotebook import tqdm as notebook_tqdm


In [41]:
# Settings that select which pre-trained checkpoint file is loaded.
KHOP = 3
KGE = True
# Not used by the default checkpoint name; an alternative checkpoint naming
# scheme appends it as a suffix (kgnn_last_{KHOP}_hops_kge_{KGE}_{HIDDEN_EMB}.pkl).
HIDDEN_EMB = 1200
def build_model_kgnn(data_path='/data/pj20/molkg/pretrain_data', ckpt_path=None, device=None):
    """Build a KGNN, load pre-trained weights into it and move it to ``device``.

    Args:
        data_path: Directory handed to ``get_everything``; the entity count
            and the motif vector width are taken from what it returns.
        ckpt_path: Checkpoint file to load. ``None`` selects
            ``/data/pj20/molkg/kgnn_last_{KHOP}_hops_kge_{KGE}.pkl``.
        device: Target device for the weights and the model. ``None`` selects
            ``'cuda:0'`` when CUDA is available and ``'cpu'`` otherwise.

    Returns:
        The KGNN instance on ``device``.
    """
    print("Preparing KGNN data...")
    ent_type, motifs, _ = get_everything(data_path)

    kgnn = KGNN(
        node_emb=None,
        rel_emb=None,
        num_nodes=ent_type.shape[0],
        num_rels=39,
        embedding_dim=512,
        hidden_dim=200,
        num_motifs=motifs.shape[1],
    )

    print("Loading Pre-trained KGNN ...")
    if ckpt_path is None:
        ckpt_path = f'/data/pj20/molkg/kgnn_last_{KHOP}_hops_kge_{KGE}.pkl'
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location=device)

    # strict=False tolerates key mismatches, so surface them explicitly:
    # otherwise a checkpoint that populates nothing would go unnoticed.
    load_result = kgnn.load_state_dict(state_dict, strict=False)
    missing = list(getattr(load_result, 'missing_keys', ()))
    unexpected = list(getattr(load_result, 'unexpected_keys', ()))
    if missing or unexpected:
        print(f"WARNING: checkpoint keys not loaded -- missing: {missing}, unexpected: {unexpected}")

    kgnn = kgnn.to(device)

    return kgnn

In [42]:
# Build the KGNN with its pre-trained checkpoint loaded.
kgnn = build_model_kgnn()
# NOTE(review): add_embedding() lives in grover.model.models and is not visible
# here. The model is constructed with node_emb=None and the checkpoint is
# loaded with strict=False *before* this call -- confirm that the embedding
# table created here really holds the checkpoint weights and is not freshly
# initialized, since its values are exported below.
kgnn.add_embedding()

Preparing KGNN data...
Loading entity type labels...
Loading center molecule motifs...
Loading entire knowledge graph...
Loading Pre-trained KGNN ...


In [43]:
# Pull the node-embedding weights out of the model as a NumPy array on the host.
kgnn_emb = kgnn.node_emb.weight.detach().cpu().numpy()

In [44]:
# Number of rows in the exported embedding table.
len(kgnn_emb)

184820

In [45]:
# Length of a single embedding row (here the second-to-last one).
len(kgnn_emb[-2])

512

In [46]:
import numpy as np
from tqdm import tqdm

def emb_map(task, save_combined=False):
    """Look up the KGNN embedding for every molecule of a fine-tuning task and save it.

    Reads ``./exampledata/finetune/{task}.csv`` (header row skipped, SMILES in
    the first column), maps each SMILES through the module-level
    ``smile2entid`` to a row of the module-level ``kgnn_emb`` table, and
    writes the stacked result to ``./exampledata/finetune/{task}_kgnn_3hop.npy``.
    SMILES that are absent from ``smile2entid`` get an all-zero row.

    The output always uses the dtype of ``kgnn_emb``; previously a single
    unmatched SMILES made the whole array float64 because the zero filler was
    float64.

    NOTE(review): the "3hop" in the output name is a literal and does not
    follow the KHOP setting -- confirm that is intended.

    Args:
        task: Dataset name, used to build the input and output file names.
        save_combined: When True, additionally load the ``features`` array
            from ``./exampledata/finetune/{task}.npz``, append the embeddings
            column-wise and save the result to
            ``./exampledata/finetune/{task}_fg_kgnn.npy``. Off by default, in
            which case the ``.npz`` file is not read at all.

    Returns:
        Array of shape ``(num_molecules, embedding_dim)`` with one row per
        non-blank CSV data row, in file order.
    """
    # newline='' is the csv module's documented way to open files for reading.
    with open(f"./exampledata/finetune/{task}.csv", newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header

        # Blank lines come back as empty lists; they hold no molecule.
        smiles = [line[0] for line in reader if line]

    # Preallocate in the table's own dtype so unmatched rows stay zero without
    # upcasting the matched ones.
    kgnn_emb_ = np.zeros((len(smiles), kgnn_emb.shape[1]), dtype=kgnn_emb.dtype)
    for row, smile in enumerate(tqdm(smiles)):
        id_ = smile2entid.get(smile)
        if id_ is not None:
            kgnn_emb_[row] = kgnn_emb[id_]

    np.save(f"./exampledata/finetune/{task}_kgnn_3hop.npy", kgnn_emb_)

    if save_combined:
        # NpzFile keeps a file handle open; the context manager closes it.
        with np.load(f"./exampledata/finetune/{task}.npz") as archive:
            pre_feature = archive['features']
        post_feature = np.concatenate((pre_feature, kgnn_emb_), axis=1)
        np.save(f"./exampledata/finetune/{task}_fg_kgnn.npy", post_feature)

    return kgnn_emb_

In [47]:
# Fine-tuning datasets to export embeddings for; each name must have a
# matching CSV under ./exampledata/finetune/.
tasks = [
    'bace',
    'bbbp',
    'clintox',
    'esol',
    'freesolv',
    'lipo',
    'qm7',
    'qm8',
    'sider',
    'tox21',
    'toxcast',
]

# The return value is not needed here; emb_map writes the .npy file itself.
for task in tasks:
    emb_map(task=task)

100%|██████████| 1513/1513 [00:00<00:00, 515514.37it/s]
100%|██████████| 2039/2039 [00:00<00:00, 931306.31it/s]
100%|██████████| 1478/1478 [00:00<00:00, 517504.07it/s]
100%|██████████| 1128/1128 [00:00<00:00, 963775.70it/s]
100%|██████████| 642/642 [00:00<00:00, 1011169.05it/s]
100%|██████████| 4200/4200 [00:00<00:00, 1087479.28it/s]
100%|██████████| 6830/6830 [00:00<00:00, 960184.22it/s]
100%|██████████| 21786/21786 [00:00<00:00, 558946.34it/s]
100%|██████████| 1427/1427 [00:00<00:00, 385053.51it/s]
100%|██████████| 7831/7831 [00:00<00:00, 958911.47it/s]
100%|██████████| 8576/8576 [00:00<00:00, 826961.65it/s]
