In [36]:
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import DataLoader
from torch_geometric.nn import global_mean_pool
from rdkit.Chem import AllChem

from contextSub.model import GNN
from contextSub.loader import MoleculeDataset
from contextSub.splitters import scaffold_split, scaffold_split_from_smiles
from contextSub.dataloader import DataLoaderPooling


In [2]:
# GIN encoder fine-tuned with context prediction; these hyper-parameters must
# match the checkpoint whose state dict is loaded in the next cell.
contextpred_model = GNN(
    gnn_type="gin",
    num_layer=5,
    emb_dim=300,
    JK="last",
    drop_ratio=0.4,
    input_mlp=False,
    partial_charge=False,
)

In [3]:
# This model is evaluated on the CPU (no device transfer follows), so map any
# tensors that may have been saved from a GPU onto the CPU while loading.
contextpred_model.load_state_dict(
    torch.load(
        "contextSub/trained_models/finetuned/contextPred/gnn.pth",
        map_location="cpu",
    )
)

<All keys matched successfully>

In [4]:
# Plain graph dataset for the context-prediction model: every optional input
# (partial charges, substructures, context subgraphs, pooling indicator) is off.
dataset = MoleculeDataset(
    dataset="lightbbb",
    root="contextSub/dataset/lightbbb",
    pattern_path=None,
    hops=None,
    context=False,
    partial_charge=False,
    substruct_input=False,
    pooling_indicator=False,
)

In [5]:
# First (and only used) column of the header-less SMILES file feeds the scaffold splitter.
smiles_list = pd.read_csv(
    "contextSub/dataset/lightbbb/processed/smiles.csv", header=None
).iloc[:, 0].tolist()

# 80 / 10 / 10 scaffold split of the graph dataset.
train_dataset, valid_dataset, test_dataset = scaffold_split(
    dataset, smiles_list, null_value=0,
    frac_train=0.8, frac_valid=0.1, frac_test=0.1,
)

In [43]:
batch_size = 32
num_workers = 4
# shuffle=False is essential: embeddings and labels are later matched to the
# split order purely by position.
loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers)
train_loader = DataLoader(train_dataset, **loader_kwargs)
val_loader = DataLoader(valid_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

In [44]:
def extract_meanpool_embeddings(model, loader):
    """Embed every graph in ``loader`` and mean-pool its node embeddings.

    Args:
        model: GNN called as ``model(batch)`` that returns one embedding row per node.
        loader: PyG loader; it must not shuffle, so rows stay in split order.

    Returns:
        (graph_embs, batch_labels): a NumPy array with one row per graph, and the
        list of per-batch ``batch.y`` tensors, both in loader order.
    """
    model.eval()  # inference mode (e.g. no dropout; drop_ratio=0.4 at construction)
    pooled_batches, batch_labels = [], []
    with torch.no_grad():
        for batch in loader:
            node_embs = model(batch)
            pooled_batches.append(global_mean_pool(node_embs, batch.batch))
            batch_labels.append(batch.y)
    return torch.cat(pooled_batches, dim=0).cpu().numpy(), batch_labels


# Each split collects its own labels, so re-running this cell cannot duplicate
# labels or shift them relative to the embeddings.
train_graph_embs, train_labels = extract_meanpool_embeddings(contextpred_model, train_loader)
val_graph_embs, val_labels = extract_meanpool_embeddings(contextpred_model, val_loader)
test_graph_embs, test_labels = extract_meanpool_embeddings(contextpred_model, test_loader)
# Per-batch label tensors in train -> validation -> test order, matching the embeddings.
labels = train_labels + val_labels + test_labels

In [49]:
# Validation embeddings: one row per molecule, 300 columns (emb_dim of contextpred_model).
val_graph_embs.shape

(714, 300)

In [50]:
# Test embeddings: one row per molecule, 300 columns (emb_dim of contextpred_model).
test_graph_embs.shape

(715, 300)

In [51]:
# Training embeddings: one row per molecule, 300 columns (emb_dim of contextpred_model).
train_graph_embs.shape

(5712, 300)

In [52]:
# Serialise each embedding row as a single comma-separated string, keeping the
# train -> validation -> test row order.
embeddings = [
    ",".join(str(value) for value in row)
    for row in np.concatenate([train_graph_embs, val_graph_embs, test_graph_embs], axis=0)
]

In [53]:
# One row per molecule; `split` is derived from the block sizes, which relies on
# the embeddings being concatenated in train -> validation -> test order.
df = pd.DataFrame({
    "embeddings": embeddings,
    "split": (
        ["train"] * train_graph_embs.shape[0]
        + ["validation"] * val_graph_embs.shape[0]
        + ["test"] * test_graph_embs.shape[0]
    ),
})

In [56]:
# Preview: each embedding is stored as one comma-separated string.
df.head()

Unnamed: 0,embeddings,split
0,"-0.24097337,0.13866656,-0.020489423,0.5635318,...",train
1,"-0.12967293,-0.042312548,0.055031244,0.2166339...",train
2,"-0.24756524,0.084039055,-0.06386699,0.25360867...",train
3,"-0.10989679,-0.029275743,0.10355943,0.36867473...",train
4,"-0.08101643,-0.018708985,-2.1436466e-05,0.2513...",train


In [58]:
# Concatenate the per-batch label tensors (train, then validation, then test).
labels = torch.cat(labels, dim=0)
# Labels are gathered separately from the embeddings; check that they line up
# with the rows of `df` before they are attached as a column.
assert labels.shape[0] == len(df), (
    f"{labels.shape[0]} labels for {len(df)} embedding rows; re-run the extraction cell"
)

In [60]:
# One label per embedding row; this should equal len(df).
labels.size()

torch.Size([7141])

In [66]:
# Last ten labels, which belong to the test split; values are encoded as -1 / 1.
labels[-10:]

tensor([-1,  1,  1,  1,  1,  1,  1, -1,  1, -1])

In [61]:
# Attach the labels as a column (rows are aligned by position).
df = df.assign(labels=labels.numpy())

In [68]:
# Persist the context-prediction embeddings with their split and label columns.
contextpred_csv_path = "contextSub/dataset/lightbbb_contextpred_embeddings.csv"
df.to_csv(contextpred_csv_path)

In [2]:
# ContextSub GIN encoder with partial_charge enabled; the hyper-parameters must
# match the pretrained checkpoint loaded right after.
contextsub_model = GNN(
    num_layer=5,
    emb_dim=300,
    JK="last",
    drop_ratio=0.4,
    gnn_type="gin",
    partial_charge=True,
)
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from; the model is moved to `device` explicitly before embedding.
contextsub_model.load_state_dict(
    torch.load(
        "contextSub/trained_models/contextSub_chembl_partialCharge_noNorm_filteredPattern_epoch300.pth",
        map_location="cpu",
    )
)

<All keys matched successfully>

In [3]:
def separate_indicators(emb_indicator):
    """From emb_indicator, compute the indicators for the substructures only and
    the indices for the molecule level embeddings.

    Assumes ``emb_indicator`` holds one group id per embedding row, sorted so that
    each group's rows are contiguous and the last row carries the largest id
    (TODO confirm against DataLoaderPooling).

    Args:
        emb_indicator: integer tensor of shape (N,) or (N, 1).

    Returns:
        (sub_indicator, mol_indices):
            sub_indicator: a flattened copy of ``emb_indicator`` in which the first
                row of every group is relabelled to ``emb_indicator[-1] + 1``, so
                pooling over it gathers all of those rows into one extra trailing group.
            mol_indices: list of the positions of those first rows.
        The input tensor is not modified.
    """
    # reshape(-1) keeps a 1-D result even for a single row (squeeze would give 0-d),
    # and clone() stops the relabelling below from writing into the caller's tensor.
    sub_indicator = emb_indicator.reshape(-1).clone()
    _, counts = np.unique(sub_indicator.cpu().numpy(), return_counts=True)
    # Start offset of each group: 0 followed by the running sum of all but the last count.
    mol_indices = [0] + np.cumsum(counts[:-1]).tolist()
    sub_indicator[mol_indices] = sub_indicator[-1] + 1
    return sub_indicator, mol_indices

In [4]:
# Plain relative path string, like every other path in this notebook (`os` is
# not imported here, so os.path.join would raise NameError on a fresh kernel).
pattern_path = "contextSub/resources/pubchemFPKeys_to_SMARTSpattern_filtered.csv"
# ContextSub inputs: partial charges, SMARTS substructures from pattern_path,
# context subgraphs of `hops` hops, and pooling indicators.
dataset = MoleculeDataset(
    "contextSub/dataset/lightbbb",
    dataset="lightbbb",
    partial_charge=True,
    substruct_input=True,
    pattern_path=pattern_path,
    context=True,
    hops=5,
    pooling_indicator=True,
)

In [6]:
# Header-less SMILES file; its first column drives the scaffold split.
smiles_frame = pd.read_csv("contextSub/dataset/lightbbb/processed/smiles.csv", header=None)
smiles_list = smiles_frame[0].to_list()

# Same 80 / 10 / 10 scaffold split as used for the context-prediction embeddings.
train_dataset, valid_dataset, test_dataset = scaffold_split(
    dataset, smiles_list, null_value=0,
    frac_train=0.8, frac_valid=0.1, frac_test=0.1,
)

In [7]:
batch_size = 32
num_workers = 4
# Keep shuffle=False: pooled embeddings and labels are matched to the
# scaffold-split order purely by position.
loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers)
train_loader, val_loader, test_loader = (
    DataLoaderPooling(split_dataset, **loader_kwargs)
    for split_dataset in (train_dataset, valid_dataset, test_dataset)
)

In [30]:
def extract_contextsub_embeddings(model, loader, device):
    """Compute ContextSub graph representations for every molecule in ``loader``.

    Per batch: node embeddings selected by ``batch.mask`` are mean-pooled into one
    row per ``pooling_indicator`` group. ``separate_indicators`` then splits those
    rows: the substructure rows are mean-pooled per molecule, and the
    molecule-level rows are taken directly. The two parts are concatenated along
    dim 1.

    Returns:
        (pooled_batches, batch_labels): per-batch lists of pooled tensors and of
        ``batch.y`` tensors, both on ``device`` and in loader order.
    """
    model.eval()
    pooled_batches, batch_labels = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            node_embs = model(batch)

            mask = batch.mask.to(torch.bool).squeeze()
            emb_repr = global_mean_pool(
                node_embs[mask], batch.pooling_indicator[mask].squeeze(),
            )

            sub_indi, mol_indi = separate_indicators(batch.emb_indicator)
            # Drop the last pooled row: separate_indicators relabels the
            # molecule-level rows to emb_indicator[-1] + 1, i.e. (for a sorted
            # indicator) one extra trailing group.
            sub_pooled = global_mean_pool(emb_repr, sub_indi)[:-1]
            mol_pooled = emb_repr[mol_indi]
            pooled_batches.append(torch.cat([sub_pooled, mol_pooled], dim=1))
            batch_labels.append(batch.y)
    return pooled_batches, batch_labels


# Fall back to the CPU when no GPU is available instead of failing on "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
contextsub_model.to(device)

# Each split collects its own labels, so re-running this cell cannot duplicate them.
train_graph_embs, train_labels = extract_contextsub_embeddings(contextsub_model, train_loader, device)
val_graph_embs, val_labels = extract_contextsub_embeddings(contextsub_model, val_loader, device)
test_graph_embs, test_labels = extract_contextsub_embeddings(contextsub_model, test_loader, device)
# Per-batch label tensors in train -> validation -> test order, matching the embeddings.
labels = train_labels + val_labels + test_labels

In [33]:
# Stack the per-batch tensors into one host-side NumPy array.
train_graph_embs = torch.cat(train_graph_embs, dim=0).to("cpu").numpy()

In [34]:
# 600 columns: substructure-pooled (300) and molecule-level (300) embeddings concatenated.
train_graph_embs.shape

(5712, 600)

In [35]:
# Stack the per-batch tensors of both held-out splits into host-side NumPy arrays.
val_graph_embs, test_graph_embs = (
    torch.cat(batches, dim=0).cpu().numpy()
    for batches in (val_graph_embs, test_graph_embs)
)

In [36]:
# Validation shape on the first line, test shape on the second.
print(val_graph_embs.shape, test_graph_embs.shape, sep="\n")

(714, 600)
(715, 600)


In [37]:
# One comma-separated string per molecule, in train -> validation -> test order.
stacked_embs = np.concatenate([train_graph_embs, val_graph_embs, test_graph_embs], axis=0)
embeddings = [",".join(str(value) for value in row) for row in stacked_embs]

In [38]:
# Split labels follow from the block sizes of the concatenated embeddings.
split_column = (
    ["train"] * train_graph_embs.shape[0]
    + ["validation"] * val_graph_embs.shape[0]
    + ["test"] * test_graph_embs.shape[0]
)
df = pd.DataFrame({"embeddings": embeddings, "split": split_column})

In [39]:
# Concatenate the per-batch label tensors (train, then validation, then test) on the host.
labels = torch.cat(labels, dim=0).cpu().numpy()
# Labels are gathered separately from the embeddings; check they line up with `df`.
assert labels.shape[0] == len(df), (
    f"{labels.shape[0]} labels for {len(df)} embedding rows; re-run the extraction cell"
)

In [40]:
# Attach the labels as a column (rows are aligned by position).
df = df.assign(labels=labels)

In [41]:
# The last rows belong to the test split.
df.tail()

Unnamed: 0,embeddings,split,labels
7136,"-0.030503618,0.0069729984,-0.17820397,-0.08808...",test,1
7137,"-0.106497735,0.070397,-0.1499533,-0.07778076,0...",test,1
7138,"-0.07903514,-0.034927227,-0.10666142,-0.077492...",test,-1
7139,"-0.0074477945,-0.060378015,-0.1882684,-0.06364...",test,1
7140,"-0.10663857,-0.09706559,-0.18169187,-0.0753999...",test,-1


In [42]:
# Number of values in the first serialised embedding (separators + 1).
embeddings[0].count(",") + 1

600

In [43]:
# Persist the ContextSub embeddings with their split and label columns.
contextsub_csv_path = "contextSub/dataset/lightbbb_contextsub_embeddings.csv"
df.to_csv(contextsub_csv_path)

## Add scaffold-split indicators (train / validation / test) to the label file (the descriptor file is loaded but not yet split)

In [4]:
# SMILES in processed-dataset order; the splitter's indices refer to this list.
smiles_list = pd.read_csv(
    "contextSub/dataset/lightbbb/processed/smiles.csv", header=None
)[0].to_list()

In [6]:
# Index lists into smiles_list for each scaffold split (used as smiles_list[idx] below).
# NOTE(review): called with the splitter's default fractions; the resulting counts
# match the 80/10/10 split used above — confirm the defaults if that changes.
train_smi, val_smi, test_smi = scaffold_split_from_smiles(smiles_list)

In [11]:
# NOTE(review): `descriptors` is not used by any cell below, despite the section title.
descriptors = pd.read_csv(filepath_or_buffer="contextSub/dataset/lightbbb/datasetNormalizedDescrs.csv")
# Raw label table (smiles, logBB, BBclass, ...) that receives the split column.
labels = pd.read_csv(filepath_or_buffer="contextSub/dataset/lightbbb/raw/y_test_indices.csv")

In [35]:
# NOTE(review): the `split` column shown in this output appears to come from an
# earlier run of the assignment cells below; on a fresh kernel it does not exist yet.
labels.head()

Unnamed: 0,smiles,logBB,BBclass,split
0,c1cc(F)ccc1Cn(c(c23)cccc2)c(n3)[C@@H]4CCCN(C)C4,0.43,1,
1,CC1CCN(CC1)C(=O)c(c2)ccc3n(CC=C)c(c4c23)CCN(C4...,-0.13,0,
2,ClCCCl,-0.14,1,train
3,c1cccc(c1C23C)C(N3)Cc4c2cccc4,1.11,1,
4,CC(=O)C,-0.15,0,


In [38]:
def canonical_smiles_or_none(smi):
    """Return RDKit's SMILES rewrite of ``smi``, or None when RDKit cannot parse it."""
    mol = AllChem.MolFromSmiles(smi)
    return None if mol is None else AllChem.MolToSmiles(mol)


# Re-write the raw SMILES with RDKit so they can be compared with smiles_list entries.
processed_smiles = [canonical_smiles_or_none(smi) for smi in labels["smiles"]]

In [57]:
# Array form so a single comparison/membership test covers every row at once.
processed_smiles_np = np.asarray(processed_smiles)

In [58]:
# Mark every row whose processed SMILES belongs to the training split. One hashed
# membership test replaces a full-array comparison (and .loc write) per index.
train_smiles = {smiles_list[idx] for idx in train_smi}
labels.loc[pd.Series(processed_smiles_np).isin(train_smiles).to_numpy(), "split"] = "train"

In [59]:
# Mark validation rows first and test rows second, so a SMILES present in both
# index lists ends up labelled "test". One hashed membership test per split
# replaces a full-array comparison per index.
for split_name, split_indices in (("validation", val_smi), ("test", test_smi)):
    split_smiles = {smiles_list[idx] for idx in split_indices}
    labels.loc[pd.Series(processed_smiles_np).isin(split_smiles).to_numpy(), "split"] = split_name

In [60]:
# Unassigned rows hold a missing value (NaN/None), not the string "NA", so select
# assigned rows with notna(); `split != "NA"` is True for every row.
labels.loc[labels["split"].notna()].shape

(7141, 4)

In [67]:
# Rows labelled "train"; expected to match the 5712 training embeddings above.
labels.loc[labels.split == "train"].shape

(5712, 4)

In [68]:
# Rows labelled "validation"; expected to match the 714 validation embeddings above.
labels.loc[labels.split == "validation"].shape

(714, 4)

In [69]:
# Rows labelled "test"; expected to match the 715 test embeddings above.
labels.loc[labels.split == "test"].shape

(715, 4)

In [63]:
# Persist the label table together with its new `split` column.
split_labels_csv_path = "contextSub/dataset/splitted_y_test_indices.csv"
labels.to_csv(split_labels_csv_path)