In [1]:
import os

# Work around an operator that is missing on MPS (the GPU backend on ARM
# architecture). Without this setting PyTorch raises:
#   NotImplementedError: The operator 'aten::scatter_reduce.two_out'
#   is not currently implemented for the MPS device. If you want
#   this op to be added in priority during the prototype phase of
#   this feature, please comment on
#   https://github.com/pytorch/pytorch/issues/77764.
# As a temporary fix, setting the environment variable
# `PYTORCH_ENABLE_MPS_FALLBACK=1` makes such ops fall back to the CPU.
# WARNING: this is slower than running natively on MPS.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # before importing torch

print(os.cpu_count())


8


# TODO

- <s>BatchNormalization?</s>
- ハイパーパラメータチューニング (optuna?)
- <s>モデルの解釈 (tanhやpoolingについて)</s>
  - tanh: 活性化関数
  - pooling: すべての原子の情報を統合する。原子数が異なるためそれらを揃える役割も。
- <s>dropoutの導入？←過学習対策</s>
- augmentation?
- モデルの途中保存 (エポック毎?)
- エッジの重みも学習する (`nn.Parameter`?)
- early-stopping
- max-poolingの検討
- 転移学習 (fine-tuning and feature-extraction)
- [GIN層](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GINConv.html#torch_geometric.nn.conv.GINConv)を試してみる

In [2]:
import numpy as np
import pandas as pd
import rdkit
from rdkit import Chem
from rdkit.Chem import Descriptors
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor

import torch
import torch_geometric
import torch_geometric.nn
import torch_geometric.data
import torch_geometric.loader

import lightning
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from torch_utils.data import GraphDataset
from torch_utils.model import GCN
from torch_utils.torch_utils import (
    torch_seed,
)
from torch_utils.utils import yyplot
from torch_utils.lightning_utils import LightningGCN


In [3]:
# Record the PyTorch version used for this run.
torch.__version__


'2.1.0'

In [4]:
# Record the PyTorch Geometric version used for this run.
torch_geometric.__version__


'2.4.0'

In [5]:
# Record the RDKit version used for this run.
rdkit.__version__


'2023.09.1'

In [6]:
# Run-wide settings shared by the cells below.
seed = 334  # passed to torch_seed and reused as random_state for the data splits
batch_size = 256  # batch size used by every DataLoader in this notebook

# NOTE(review): torch_seed is a project helper (torch_utils.torch_utils);
# presumably it seeds the relevant RNGs -- confirm in that module.
torch_seed(seed)


In [7]:
# Alternative dataset, currently disabled:
# df_raw = pd.read_csv("./data/curated-solubility-dataset.csv", index_col=0)
# Load the logS dataset; the first CSV column (the SMILES strings) becomes the index.
df_raw = pd.read_csv("./data/logSdataset1290.csv", index_col=0)
# Optional: keep only the first 1000 rows to shorten computation time.
# df_raw = df_raw.iloc[:1000]
print(df_raw.shape)
df_raw.head()


(1290, 197)


Unnamed: 0,logS,MolWt,HeavyAtomMolWt,ExactMolWt,NumValenceElectrons,NumRadicalElectrons,MaxPartialCharge,MinPartialCharge,MaxAbsPartialCharge,MinAbsPartialCharge,...,fr_sulfide,fr_sulfonamd,fr_sulfone,fr_term_acetylene,fr_tetrazole,fr_thiazole,fr_thiocyan,fr_thiophene,fr_unbrch_alkane,fr_urea
CC(N)=O,1.58,59.068,54.028,59.037114,24,0,0.21379,-0.369921,0.369921,0.21379,...,0,0,0,0,0,0,0,0,0,0
CNN,1.34,46.073,40.025,46.053098,20,0,-0.001725,-0.271722,0.271722,0.001725,...,0,0,0,0,0,0,0,0,0,0
CC(=O)O,1.22,60.052,56.02,60.021129,24,0,0.299685,-0.481433,0.481433,0.299685,...,0,0,0,0,0,0,0,0,0,0
C1CCNC1,1.15,71.123,62.051,71.073499,30,0,-0.004845,-0.316731,0.316731,0.004845,...,0,0,0,0,0,0,0,0,0,0
NC(=O)NO,1.12,76.055,72.023,76.027277,30,0,0.335391,-0.349891,0.349891,0.335391,...,0,0,0,0,0,0,0,0,0,1


In [8]:
# Disabled variant using explicit "SMILES"/"Solubility" columns
# (presumably for the commented-out alternative dataset -- confirm):
# smiles = df_raw["SMILES"]
# y = df_raw["Solubility"]
# For the dataset loaded here, the SMILES strings are the index and the target is "logS".
smiles = df_raw.index
y = df_raw["logS"]


In [9]:
# Statistics used to standardise the target (sample standard deviation, ddof=1).
# NOTE(review): these are computed on the full dataset, before the
# train/val/test split further down, so held-out targets contribute to the
# scaling -- confirm this is acceptable.
y_mean = y.mean()
y_std = y.std(ddof=1)


In [10]:
# Convert the inputs to plain containers for graph construction.
# `list(...)` accepts both a pandas Index and a list, so re-running this cell
# no longer fails with "list has no attribute tolist".
smiles = list(smiles)

# Standardise the target and shape it as a column tensor of shape (n, 1).
# The isinstance guard makes the cell idempotent: without it, a second run
# would silently standardise the already-standardised tensor again.
# `torch.tensor` with an explicit dtype replaces the legacy `torch.Tensor(...)`
# constructor; under the default dtype the resulting values are the same.
if not isinstance(y, torch.Tensor):
    y = torch.tensor(
        ((y - y_mean) / y_std).tolist(), dtype=torch.float32
    ).view(-1, 1)


In [11]:
# Parse every SMILES string into an RDKit molecule.
# A list is used instead of `map(...)`: a map object is a one-shot iterator,
# so after the first GraphDataset build it would be exhausted and re-running
# the dataset cell would see no molecules.
mols = [Chem.MolFromSmiles(smi) for smi in smiles]

# Chem.MolFromSmiles returns None when a SMILES string cannot be parsed.
# Fail early with the offending strings rather than passing None downstream.
_unparsed_smiles = [smi for smi, mol in zip(smiles, mols) if mol is None]
if _unparsed_smiles:
    raise ValueError(
        f"RDKit could not parse {len(_unparsed_smiles)} SMILES string(s), "
        f"e.g. {_unparsed_smiles[:5]}"
    )


In [12]:
# Build the graph dataset from the parsed molecules and the standardised targets.
# NOTE(review): GraphDataset is a project helper (torch_utils.data); the exact
# meaning of n_jobs=-1 and ipynb=True is not visible here -- presumably
# "use all workers" and "notebook-style progress bar"; confirm in that module.
dataset = GraphDataset(mols, y, n_jobs=-1, ipynb=True)
dataset


  0%|          | 0/1290 [00:00<?, ?it/s]

GraphDataset(size=1290)

In [13]:
# Shuffled mini-batch loader over the complete (unsplit) dataset.
# The loader comes from `torch_geometric.loader`; the older
# `torch_geometric.data.DataLoader` entry point is deprecated.
loader_options = {"batch_size": batch_size, "shuffle": True}
dataloader = torch_geometric.loader.DataLoader(dataset, **loader_options)
dataloader


<torch_geometric.loader.dataloader.DataLoader at 0x17fd221c0>

In [14]:
# Two-stage split with a fixed random_state:
#   1) hold out 20% of all samples as the test set;
#   2) hold out 10% of the remainder as the validation set.
dataset_trainval, dataset_test = train_test_split(
    dataset, random_state=seed, test_size=0.2
)
dataset_train, dataset_val = train_test_split(
    dataset_trainval, random_state=seed, test_size=0.1
)
# Report the size of each subset: train, validation, test.
print(*map(len, (dataset_train, dataset_val, dataset_test)))


928 104 258


In [15]:
# One loader per subset, all with the shared batch size.
# Only the training loader shuffles; validation and test keep dataset order.
dataloader_train, dataloader_val, dataloader_test = (
    torch_geometric.loader.DataLoader(
        subset, batch_size=batch_size, shuffle=reshuffle
    )
    for subset, reshuffle in (
        (dataset_train, True),
        (dataset_val, False),
        (dataset_test, False),
    )
)


In [16]:
# Take one training sample and read the size of the second dimension of
# `data.x`; it is passed to the model as `in_channels` in a later cell.
data = dataset_train[0]
num_features = data.x.shape[1]


In [17]:
# Instantiate the project's GCN (torch_utils.model). `in_channels` is the
# feature width read from the sample graph; `embedding_size=64` shows up as
# the output width of each GCNConv layer in the printed module tree.
model = GCN(in_channels=num_features, embedding_size=64)
print(model)


GCN(
  (initial_conv): GCNConv(30, 64)
  (batch_norm0): BatchNorm(64)
  (conv1): GCNConv(64, 64)
  (batch_norm1): BatchNorm(64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (dropout): Dropout(p=0.01, inplace=False)
  (relu): LeakyReLU(negative_slope=0.01)
  (fc): Linear(in_features=128, out_features=1, bias=True)
)


In [18]:
# Print a per-layer table (input/output shapes, parameter counts) for the
# model applied to the single sample graph `data`.
# Need 'tabulate' package
print(torch_geometric.nn.summary(model, data))


+----------------------------+---------------+----------------+----------+
| Layer                      | Input Shape   | Output Shape   | #Param   |
|----------------------------+---------------+----------------+----------|
| GCN                        | [14, 14]      | [1, 1]         | 14,849   |
| ├─(initial_conv)GCNConv    |               | [14, 64]       | 1,984    |
| ├─(batch_norm0)BatchNorm   | [14, 64]      | [14, 64]       | 128      |
| │    └─(module)BatchNorm1d | [14, 64]      | [14, 64]       | 128      |
| ├─(conv1)GCNConv           |               | [14, 64]       | 4,160    |
| ├─(batch_norm1)BatchNorm   | [14, 64]      | [14, 64]       | 128      |
| │    └─(module)BatchNorm1d | [14, 64]      | [14, 64]       | 128      |
| ├─(conv2)GCNConv           |               | [14, 64]       | 4,160    |
| ├─(conv3)GCNConv           |               | [14, 64]       | 4,160    |
| ├─(dropout)Dropout         | [14, 64]      | [14, 64]       | --       |
| ├─(relu)LeakyReLU      

In [19]:
# Pick the MPS backend when it is available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device


device(type='mps')

In [20]:
# Wrap the GCN in the project's Lightning module together with its
# optimiser (Adam, learning rate 1e-3) and loss (mean squared error).
adam_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
mse_criterion = torch.nn.MSELoss()
model_lightning = LightningGCN(
    model, optimizer=adam_optimizer, criterion=mse_criterion
)


In [21]:
# Early-stopping callback configuration: watch the logged "val_loss" metric,
# lower is better, any decrease counts (min_delta=0), patience of 5.
early_stopping_options = {
    "monitor": "val_loss",
    "mode": "min",
    "min_delta": 0,
    "patience": 5,
}
early_stopping = EarlyStopping(**early_stopping_options)


In [22]:
# Train for at most 100 epochs on whatever accelerator Lightning selects,
# with the early-stopping callback attached.
trainer = lightning.Trainer(
    callbacks=[early_stopping],
    max_epochs=100,
    accelerator="auto",
)
# Fit on the training loader and validate on the validation loader.
trainer.fit(
    model=model_lightning,
    train_dataloaders=dataloader_train,
    val_dataloaders=dataloader_val,
)


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/yu9824/miniforge3/envs/torch39/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name      | Type    | Params
--------------------------------------
0 | net       | GCN     | 14.8 K
1 | criterion | MSELoss | 0     
--------------------------------------
14.8 K    Trainable params
0         Non-trainable params
14.8 K    Total params
0.059     Total estimated model params s

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/yu9824/miniforge3/envs/torch39/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
  return src.new_zeros(size).scatter_reduce_(
/Users/yu9824/miniforge3/envs/torch39/lib/python3.9/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1306. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/Users/yu9824/miniforge3/envs/torch39/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/yu9824/miniforge3/e

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [23]:
# Evaluate the fitted model on the held-out test loader.
trainer.test(model=model_lightning, dataloaders=dataloader_test)


/Users/yu9824/miniforge3/envs/torch39/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

/Users/yu9824/miniforge3/envs/torch39/lib/python3.9/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 3319. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/Users/yu9824/miniforge3/envs/torch39/lib/python3.9/site-packages/lightning/pytorch/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 14. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[{'test_loss': 0.14455369114875793}]

In [36]:
# Predict on the training set through a NON-shuffled loader.
# `dataloader_train` is built with shuffle=True, so predicting through it
# returns rows in a random order that no longer lines up with
# `dataset_train` (Lightning warns about shuffling in a predict dataloader).
# Any later comparison against the true targets would then pair predictions
# with the wrong molecules. A fresh loader with shuffle=False keeps the
# predictions in dataset order.
y_pred_train_scaled: torch.Tensor = torch.cat(
    trainer.predict(
        model_lightning,
        torch_geometric.loader.DataLoader(
            dataset_train, batch_size=batch_size, shuffle=False
        ),
    ),
    dim=0,
)
y_pred_train_scaled.shape


/Users/yu9824/miniforge3/envs/torch39/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:492: Your `predict_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/Users/yu9824/miniforge3/envs/torch39/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

torch.Size([928, 1])

In [37]:
# Predict on the validation set; predictions are still in standardised units.
pred_batches_val = trainer.predict(
    model=model_lightning, dataloaders=dataloader_val
)
# Concatenate the per-batch outputs along the sample axis.
y_pred_val_scaled: torch.Tensor = torch.cat(pred_batches_val, dim=0)
y_pred_val_scaled.shape


/Users/yu9824/miniforge3/envs/torch39/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

torch.Size([104, 1])

In [38]:
# Predict on the test set; predictions are still in standardised units.
pred_batches_test = trainer.predict(
    model=model_lightning, dataloaders=dataloader_test
)
# Concatenate the per-batch outputs along the sample axis.
y_pred_test_scaled: torch.Tensor = torch.cat(pred_batches_test, dim=0)
y_pred_test_scaled.shape


Predicting: |          | 0/? [00:00<?, ?it/s]

torch.Size([258, 1])

In [None]:
# NOTE(review): bare reference to the imported `yyplot` helper -- it is not
# called, and this cell has no execution count. Looks like an unfinished
# plotting step; its signature is not visible in this notebook.
yyplot
