#### Before fine-tuning, generate the pretrained model with `train_chemprop.sh` and `pretrainingset.csv`. This notebook loads it from `../../results/models/model_GCN_cleaned_pretrained.ckpt`.

In [1]:
import pandas as pd
from pathlib import Path

from lightning import pytorch as pl
from sklearn.preprocessing import StandardScaler
import torch

from chemprop import data, featurizers, models

In [2]:
# Configuration: input data location, column names and DataLoader workers.
# Paths are relative to the notebook's working directory.
chemprop_dir = Path.cwd().parent
input_path = '../../data/processed/trainingset.csv'

smiles_column = 'SMILES'      # column holding the SMILES strings
target_columns = ['labels']   # label column(s) used as fine-tuning targets

num_workers = 0               # DataLoader worker processes (0 = load in the main process)

In [3]:
df_input = pd.read_csv(input_path)

# Fail fast with a clear message if the CSV does not contain the columns configured
# above, rather than a bare KeyError when the next cells index into it.
missing_columns = [col for col in [smiles_column, *target_columns] if col not in df_input.columns]
assert not missing_columns, f"{input_path} is missing required column(s): {missing_columns}"

# Rows without a SMILES string cannot be turned into molecules downstream.
assert df_input[smiles_column].notna().all(), f"{input_path} has rows with a missing {smiles_column!r}"

df_input

Unnamed: 0,SMILES,labels
0,O[C@H]1[C@H](O)[C@H](O[C@@H](OC[C@@H](O[C@@H](...,0
1,O[C@H]1[C@H](O)[C@H](O[C@@H](OC[C@@H](O[C@@H](...,0
2,[C@@H]1([C@H](O[C@@H](O[C@@H]([C@H](N)C)CCCCCC...,0
3,[C@@H]1(O)[C@H](O[C@@H](O[C@@H]([C@@H](N)CO)CC...,0
4,[C@@H](C#CC#CCO)(O)CCCCCCCCC\C=C/CCCCCCCC,0
...,...,...
15203,C1(=O)CC(=C(COC(=O)C)[C@]([H])(C[C@]2(C)[C@@](...,2
15204,C1(=O)CC(=C(COC(=O)CC(C)C)[C@]([H])(C[C@]2(C)[...,2
15205,c1c(C)c(C(=O)Oc(cc(O)cc2C)c2O3)c3c(Cc4ccc(O)cc...,2
15206,[C@@]12([H])[C@@]3(C(=O)N[C@H]1CC(C)C)[C@@H](\...,2


In [8]:
# Raw inputs for chemprop: a 1-D array of SMILES strings and a 2-D array of
# targets with one column per entry in `target_columns`.
smis = df_input[smiles_column].to_numpy()
ys = df_input[target_columns].to_numpy()

In [5]:
# Pair each SMILES string with its target row as a chemprop datapoint; each
# datapoint exposes the parsed molecule as `.mol`, which the split cell uses.
all_data = list(map(data.MoleculeDatapoint.from_smi, smis, ys))

In [6]:
# Exploratory: list the split strategies chemprop knows about ("RANDOM", used by the
# split cell below, is one of them). Nothing later in the notebook depends on this output.
list(data.SplitType.keys())

['SCAFFOLD_BALANCED',
 'RANDOM_WITH_REPEATED_SMILES',
 'RANDOM',
 'KENNARD_STONE',
 'KMEANS']

In [None]:
# Random 80/10/10 split of the datapoints into train / validation / test sets.
# NOTE(review): no seed is passed, so this relies on chemprop's default split seed;
# pass `seed=` explicitly if the test set must be pinned across chemprop versions.
split_fractions = (0.8, 0.1, 0.1)  # train / val / test
mols = [datapoint.mol for datapoint in all_data]
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", split_fractions)
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

In [None]:
# Load the pretrained MPNN that fine-tuning starts from. It must be generated
# beforehand (see the note at the top of the notebook).
checkpoint_path = '../../results/models/model_GCN_cleaned_pretrained.ckpt'  # your pretrained model
if not Path(checkpoint_path).is_file():
    raise FileNotFoundError(
        f"Pretrained model not found at {checkpoint_path!r}; "
        "generate it with train_chemprop.sh and pretrainingset.csv first."
    )

mpnn_cls = models.MPNN
mpnn = mpnn_cls.load_from_file(checkpoint_path)
mpnn

  d = torch.load(model_path, map_location=map_location)


MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=1100, bias=False)
    (W_h): Linear(in_features=1100, out_features=1100, bias=False)
    (W_o): Linear(in_features=1172, out_features=1100, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): GraphTransform(
      (V_transform): Identity()
      (E_transform): Identity()
    )
  )
  (agg): NormAggregation()
  (bn): BatchNorm1d(1100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): MulticlassClassificationFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=1100, out_features=1100, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=1100, out_features=1100, bias=True)
      )
      (2): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_

In [None]:
# Featurize each split into molecular graphs with the same featurizer.
# NOTE(review): this must match the featurizer used for pretraining (the loaded
# model's W_i expects 86 input features) — confirm the defaults match the pretraining run.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

train_dset, val_dset, test_dset = (
    data.MoleculeDataset(split, featurizer)
    for split in (train_data, val_data, test_data)
)

In [None]:
# The training loader keeps build_dataloader's default shuffling; validation and
# test loaders disable shuffling so their batch order is fixed.
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader, test_loader = (
    data.build_dataloader(dset, num_workers=num_workers, shuffle=False)
    for dset in (val_dset, test_dset)
)

In [None]:
# Freeze the pretrained encoder (message passing) and its batch norm so that only
# the remaining modules (the predictor head) are updated during fine-tuning:
#   - requires_grad_(False) already recurses into every submodule's parameters;
#   - eval() disables dropout and makes BatchNorm use its stored running statistics.
# The trainer summary during fit still reports both modules in "eval" mode.
for frozen_module in (mpnn.message_passing, mpnn.bn):
    frozen_module.requires_grad_(False)
    frozen_module.eval()

mpnn.bn

BatchNorm1d(1100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [None]:
# Keep the checkpoint with the lowest validation loss. Training for a fixed 200
# epochs drives train_loss to ~0 while val_loss stays high (overfitting), so the
# last-epoch weights are not the ones we want to evaluate or save.
best_val_checkpoint = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    save_last=True,
)

trainer = pl.Trainer(
    logger=True,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=200,
    callbacks=[best_val_checkpoint],
)

trainer.fit(mpnn, train_loader, val_loader)

# Restore the best-validation weights into `mpnn` in place, so evaluating or saving
# `mpnn` afterwards uses the selected model. load_state_dict does not change
# requires_grad flags or train/eval modes, so the frozen encoder stays frozen.
best_path = best_val_checkpoint.best_model_path
if best_path:
    # weights_only=False: Lightning checkpoints also store non-tensor hyperparameters.
    # The file was written by this trainer in this session, so it is trusted.
    best_state = torch.load(best_path, map_location="cpu", weights_only=False)["state_dict"]
    mpnn.load_state_dict(best_state)
    print(f"Restored best weights (val_loss={float(best_val_checkpoint.best_model_score):.3f}) from {best_path}")
else:
    print("No best checkpoint was recorded; keeping the last-epoch weights.")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
C:\Users\User01\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.

  | Name            | Type                        | Params | Mode 
------------------------------------------------------------------------
0 | message_passing | BondMessagePassing          | 2.6 M  | eval 
1 | agg             | NormAggregation             | 0      | train
2 | bn              | BatchNorm1d                 | 2.2 K  | eval 
3 | predictor       | Multiclas

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

C:\Users\User01\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Epoch 199: 100%|██████████| 191/191 [00:09<00:00, 19.29it/s, v_num=1, train_loss=0.000, val_loss=0.564]   

`Trainer.fit` stopped: `max_epochs=200` reached.


Epoch 199: 100%|██████████| 191/191 [00:10<00:00, 18.99it/s, v_num=1, train_loss=0.000, val_loss=0.564]


In [None]:
# Evaluate the fine-tuned model on the held-out test split; `results` is a list
# with one dict of logged test metrics per test dataloader.
results = trainer.test(model=mpnn, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
C:\Users\User01\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 24/24 [00:01<00:00, 20.23it/s]


In [None]:
# Persist the fine-tuned model with chemprop's save_model. The pretrained model was
# read with `models.MPNN.load_from_file`, which is presumably the matching loader for
# this file too (TODO confirm). The `.ckpt` suffix only mirrors the pretrained
# model's file name.
final_model_path = '../../results/models/model_GCN_cleaned_finetuned.ckpt'
Path(final_model_path).parent.mkdir(parents=True, exist_ok=True)  # don't fail if the output dir is missing
models.save_model(final_model_path, mpnn)