In [15]:
import os
import numpy as np
import pandas as pd

import deepchem as dc
from openpom.feat.graph_featurizer import GraphFeaturizer, GraphConvConstants
from openpom.models.mpnn_pom import MPNNPOMModel
from deepchem.models.optimizers import ExponentialDecay

import torch
from pathlib import Path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
relative_path = Path("..")


In [16]:
train_path = "data/processed/pubchem_info.csv"
test_path  = "data/processed/test_pubchem_props.csv"

train_df = pd.read_csv(relative_path / train_path)
test_df  = pd.read_csv(relative_path / test_path)
train_df = train_df[train_df['ConnectivitySMILES'].notna()]
print(train_df.shape, test_df.shape)
train_df.head()


(2064, 16) (888, 17)


Unnamed: 0,Pubchem_ID,taste_cluster,MolecularWeight,ConnectivitySMILES,XLogP,TPSA,Complexity,Charge,HBondDonorCount,HBondAcceptorCount,RotatableBondCount,HeavyAtomCount,AtomStereoCount,BondStereoCount,Volume3D,EffectiveRotorCount3D
0,95609,0,116.16,CCC(C(=O)CC)O,0.7,37.3,78.0,0.0,1.0,2.0,3.0,8.0,1.0,0.0,98.0,3.0
1,104224,0,216.36,CCCCCCCCCCC(OC)OC,5.1,18.5,111.0,0.0,0.0,2.0,11.0,15.0,0.0,0.0,192.4,11.0
3,5284499,6,268.5,CCCCCCCCC=CCCCCCCCCO,7.4,20.2,175.0,0.0,1.0,1.0,15.0,19.0,0.0,1.0,246.4,15.0
4,10886,4,186.29,CCCCCC(=O)OCCCCC,3.8,26.3,121.0,0.0,0.0,2.0,9.0,13.0,0.0,0.0,164.4,9.0
5,12978217,3,192.3,CC=CC(=O)C1C(C=CCC1(C)C)C,3.4,17.1,271.0,0.0,0.0,1.0,2.0,14.0,2.0,1.0,165.9,2.8


In [17]:
smiles_train = train_df["ConnectivitySMILES"].astype(str).tolist()
smiles_test  = test_df["ConnectivitySMILES"].astype(str).tolist()

featurizer = GraphFeaturizer()

X_train_graph = featurizer.featurize(smiles_train)
X_test_graph  = featurizer.featurize(smiles_test)

len(X_train_graph), len(X_test_graph)


(2064, 888)

In [18]:
import numpy as np

def get_valid_indices(graphs):
    
    valid_idx = []
    bad_idx = []
    for i, g in enumerate(graphs):
        if isinstance(g, np.ndarray) and g.size == 0:
            bad_idx.append(i)
        elif g is None:
            bad_idx.append(i)
        else:
            valid_idx.append(i)
    return valid_idx, bad_idx

valid_train_idx, bad_train_idx = get_valid_indices(X_train_graph)
valid_test_idx,  bad_test_idx  = get_valid_indices(X_test_graph)

print("Проблемные train-индексы:", bad_train_idx)
print("Проблемные test-индексы:",  bad_test_idx)
print("Осталось валидных train:", len(valid_train_idx), "из", len(X_train_graph))
print("Осталось валидных test:",  len(valid_test_idx),  "из", len(X_test_graph))


Проблемные train-индексы: []
Проблемные test-индексы: []
Осталось валидных train: 2064 из 2064
Осталось валидных test: 888 из 888


In [19]:
n_tasks_pom = 138 
X_train_graph_valid = [X_train_graph[i] for i in valid_train_idx]
X_test_graph_valid  = [X_test_graph[i]  for i in valid_test_idx]

dummy_y_train_valid = np.zeros((len(X_train_graph_valid), n_tasks_pom), dtype=np.float32)
dummy_y_test_valid  = np.zeros((len(X_test_graph_valid),  n_tasks_pom), dtype=np.float32)

dc_train = dc.data.NumpyDataset(X_train_graph_valid, dummy_y_train_valid)
dc_test  = dc.data.NumpyDataset(X_test_graph_valid,  dummy_y_test_valid)


In [20]:
dummy_y_train = np.zeros((len(X_train_graph), n_tasks_pom), dtype=np.float32)
dummy_y_test  = np.zeros((len(X_test_graph),  n_tasks_pom), dtype=np.float32)

dc_train = dc.data.NumpyDataset(X_train_graph, dummy_y_train)
dc_test  = dc.data.NumpyDataset(X_test_graph,  dummy_y_test)


In [21]:
lr_schedule = ExponentialDecay(
    initial_rate=1e-3,
    decay_rate=0.5,
    decay_steps=32 * 20,
    staircase=True,
)

model = MPNNPOMModel(
    n_tasks=n_tasks_pom,
    batch_size=128,
    learning_rate=lr_schedule,
    class_imbalance_ratio=None,
    loss_aggr_type='sum',
    node_out_feats=100,
    edge_hidden_feats=75,
    edge_out_feats=100,
    num_step_message_passing=5,
    mpnn_residual=True,
    message_aggregator_type='sum',
    mode='classification',
    number_atom_features=GraphConvConstants.ATOM_FDIM,
    number_bond_features=GraphConvConstants.BOND_FDIM,
    n_classes=1,
    readout_type='set2set',
    num_step_set2set=3,
    num_layer_set2set=2,
    ffn_hidden_list=[392, 392],
    ffn_embeddings=256,      # размер POM-эмбеддинга
    ffn_activation='relu',
    ffn_dropout_p=0.12,
    ffn_dropout_at_input_no_act=False,
    weight_decay=1e-5,
    self_loop=False,
    optimizer_name='adam',
    log_frequency=0,
    model_dir="./pom_ckpt",
    device_name=device,      # "cuda" или "cpu"
)
model


No class imbalance ratio provided!


<openpom.models.mpnn_pom.MPNNPOMModel at 0x7f6fe8bddba0>

In [22]:
model.restore(relative_path / 'models/checkpoint2.pt')

Если хочется попробовать другие чекпоинты, необходимо запустить \
''' \
git clone https://github.com/BioMachineLearning/openpom.git \
ls openpom \
ls openpom/models \
ls openpom/models/ensemble_models/ \
'''

In [23]:
emb_train = model.predict_embedding(dc_train)
emb_test  = model.predict_embedding(dc_test)

emb_train.shape, emb_test.shape

((2064, 256), (888, 256))

In [None]:
pom_cols = [f"pom_{i}" for i in range(emb_train.shape[1])]

train_pom_df = pd.DataFrame(emb_train, columns=pom_cols)
test_pom_df  = pd.DataFrame(emb_test,  columns=pom_cols)

train_pom_df.insert(0, "ConnectivitySMILES", train_df["ConnectivitySMILES"].astype(str).values)
test_pom_df.insert(0,  "ConnectivitySMILES", test_df["ConnectivitySMILES"].astype(str).values)
train_pom_df.insert(1, "taste_cluster", train_df["taste_cluster"].astype(str).values)

id_col = None
for cand in ["CID", "MoleculeID", "ID"]:
    if cand in train_df.columns:
        id_col = cand
        break

if id_col is not None:
    train_pom_df.insert(0, id_col, train_df[id_col].values)
    test_pom_df.insert(0,  id_col, test_df[id_col].values)

train_pom_df.to_csv(relative_path / "data/processed/train_pom_embeds1.csv", index=False)
test_pom_df.to_csv(relative_path / "data/processed/test_pom_embeds1.csv", index=False)

train_pom_df.head()


Unnamed: 0,ConnectivitySMILES,taste_cluster,pom_0,pom_1,pom_2,pom_3,pom_4,pom_5,pom_6,pom_7,...,pom_246,pom_247,pom_248,pom_249,pom_250,pom_251,pom_252,pom_253,pom_254,pom_255
0,CCC(C(=O)CC)O,0,0.015913,-0.409914,-0.724948,0.814938,0.933522,-0.388368,0.052653,-0.535332,...,-1.446216,-1.587512,1.835953,-0.837952,-0.430424,0.569843,-0.693606,0.607416,-0.241914,0.48346
1,CCCCCCCCCCC(OC)OC,0,-0.423619,-0.285783,-1.385904,-0.745874,-0.018778,-0.384136,-1.068029,-1.488349,...,-1.483234,-0.259887,-0.25751,-1.475285,0.209974,-0.118646,-1.088198,0.429376,-1.372742,0.443456
2,CCCCCCCCC=CCCCCCCCCO,6,-1.159231,-0.265581,-1.951779,-0.408302,0.00073,0.15365,-1.017931,-1.200404,...,-1.835846,-1.031263,-0.973224,-1.436581,-0.438759,-0.280401,-0.76917,0.058871,-2.395262,0.134085
3,CCCCCC(=O)OCCCCC,4,0.092277,-0.316251,-0.153518,-0.241032,0.174953,0.158541,-0.439297,-0.784887,...,0.053043,-0.131249,0.088169,-1.027134,-0.137164,-0.563481,-0.58367,-0.042767,-0.864206,-0.070622
4,CC=CC(=O)C1C(C=CCC1(C)C)C,3,0.190403,-0.606193,-0.491406,-0.627852,-0.079033,-0.215921,-0.1074,-0.386821,...,-0.618083,-0.087135,-0.009854,-0.888789,0.054674,0.318542,-0.912135,-0.049696,-0.010296,-0.103215
