In [1]:
import os
# to avoid the following error when using MPS (GPU in ARM architecture)):
# NotImplementedError: The operator 'aten::scatter_reduce.two_out'
# is not currently implemented for the MPS device. If you want
# this op to be added in priority during the prototype phase of
# this feature, please comment on
# https://github.com/pytorch/pytorch/issues/77764.
# As a temporary fix, you can set the environment variable
# `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as
# a fallback for this op.
# WARNING: this will be slower than running natively on MPS.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # before importing torch


\# TODO

- BatchNormalization?
- ハイパーパラメータチューニング (optuna?)
- モデルの解釈 (tanhやpoolingについて)
- dropoutの導入？←過学習対策
- autumentation?

In [2]:
import pandas as pd
from rdkit import Chem
from sklearn.model_selection import train_test_split

import torch
import torch_geometric.nn
import torch_geometric.data

from src.data import GraphDataset
from src.model import GCN
from src.torch_utils import fit, eval_loss, evaluate_history, torch_seed


In [3]:
seed = 334
batch_size = 256

torch_seed(seed)


In [4]:
df_raw = pd.read_csv("./data/curated-solubility-dataset.csv", index_col=0)
# 計算時間短縮
# df_raw = df_raw.iloc[:1000]
print(df_raw.shape)
df_raw.head()


(9982, 25)


Unnamed: 0_level_0,Name,InChI,InChIKey,SMILES,Solubility,SD,Ocurrences,Group,MolWt,MolLogP,...,NumRotatableBonds,NumValenceElectrons,NumAromaticRings,NumSaturatedRings,NumAliphaticRings,RingCount,TPSA,LabuteASA,BalabanJ,BertzCT
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
A-3,"N,N,N-trimethyloctadecan-1-aminium bromide",InChI=1S/C21H46N.BrH/c1-5-6-7-8-9-10-11-12-13-...,SZEMGTQCPRNXEG-UHFFFAOYSA-M,[Br-].CCCCCCCCCCCCCCCCCC[N+](C)(C)C,-3.616127,0.0,1,G1,392.51,3.9581,...,17.0,142.0,0.0,0.0,0.0,0.0,0.0,158.520601,0.0,210.377334
A-4,Benzo[cd]indol-2(1H)-one,InChI=1S/C11H7NO/c13-11-8-5-1-3-7-4-2-6-9(12-1...,GPYLCFQEKPUWLD-UHFFFAOYSA-N,O=C1Nc2cccc3cccc1c23,-3.254767,0.0,1,G1,169.183,2.4055,...,0.0,62.0,2.0,0.0,1.0,3.0,29.1,75.183563,2.582996,511.229248
A-5,4-chlorobenzaldehyde,InChI=1S/C7H5ClO/c8-7-3-1-6(5-9)2-4-7/h1-5H,AVPYQKSLYISFPO-UHFFFAOYSA-N,Clc1ccc(C=O)cc1,-2.177078,0.0,1,G1,140.569,2.1525,...,1.0,46.0,1.0,0.0,0.0,1.0,17.07,58.261134,3.009782,202.661065
A-8,"zinc bis[2-hydroxy-3,5-bis(1-phenylethyl)benzo...",InChI=1S/2C23H22O3.Zn/c2*1-15(17-9-5-3-6-10-17...,XTUPUYCJWKHGSW-UHFFFAOYSA-L,[Zn++].CC(c1ccccc1)c2cc(C(C)c3ccccc3)c(O)c(c2)...,-3.924409,0.0,1,G1,756.226,8.1161,...,10.0,264.0,6.0,0.0,0.0,6.0,120.72,323.755434,2.322963e-07,1964.648666
A-9,4-({4-[bis(oxiran-2-ylmethyl)amino]phenyl}meth...,InChI=1S/C25H30N2O4/c1-5-20(26(10-22-14-28-22)...,FAUAZXVRLVIARB-UHFFFAOYSA-N,C1OC1CN(CC2CO2)c3ccc(Cc4ccc(cc4)N(CC5CO5)CC6CO...,-4.662065,0.0,1,G1,422.525,2.4854,...,12.0,164.0,2.0,4.0,4.0,6.0,56.6,183.183268,1.084427,769.899934


In [5]:
smiles = df_raw["SMILES"]
y = df_raw["Solubility"]


In [6]:
# scaling
y_mean = y.mean()
y_std = y.std(ddof=1)


In [7]:
smiles = smiles.tolist()
y = ((y - y_mean) / y_std).tolist()


In [8]:
mols = map(Chem.MolFromSmiles, smiles)


In [9]:
dataset = GraphDataset(mols, y, n_jobs=-1)
dataset


  torch.tensor(y, dtype=torch.float32)


GraphDataset(9982)

In [10]:
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
dataloader = torch_geometric.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True
)
dataloader




<torch_geometric.deprecation.DataLoader at 0x17b83aa00>

In [11]:
dataset_train, dataset_test = train_test_split(
    dataset, test_size=0.2, random_state=seed
)
print(len(dataset_train), len(dataset_test))


7985 1997


In [12]:
dataloader_train = torch_geometric.data.DataLoader(
    dataset_train, batch_size=batch_size, shuffle=True
)
dataloader_test = torch_geometric.data.DataLoader(
    dataset_test, batch_size=batch_size, shuffle=False
)


In [13]:
model = GCN(in_channels=dataset_train[0].x.shape[1], embedding_size=64)
print(model)


GCN(
  (initial_conv): GCNConv(30, 64)
  (conv1): GCNConv(64, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (out): Linear(in_features=128, out_features=1, bias=True)
)


In [14]:
data = dataset_train[0]


In [15]:
# Need 'tabulate' package
print(torch_geometric.nn.summary(model, data))


+-------------------------+---------------+----------------+----------+
| Layer                   | Input Shape   | Output Shape   | #Param   |
|-------------------------+---------------+----------------+----------|
| GCN                     | [11, 11]      | [1]            | 14,593   |
| ├─(initial_conv)GCNConv |               | [11, 64]       | 1,984    |
| ├─(conv1)GCNConv        |               | [11, 64]       | 4,160    |
| ├─(conv2)GCNConv        |               | [11, 64]       | 4,160    |
| ├─(conv3)GCNConv        |               | [11, 64]       | 4,160    |
| ├─(out)Linear           | [1, 128]      | [1, 1]         | 129      |
+-------------------------+---------------+----------------+----------+


In [16]:
model(data)


tensor([-0.1715], grad_fn=<ViewBackward0>)

In [17]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")


In [18]:
history = fit(
    model.to(device),
    optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
    criterion=torch.nn.MSELoss(),
    train_loader=dataloader_train,
    test_loader=dataloader_test,
    # num_epochs=50,
    num_epochs=100,
    device=device,
    ipynb=True,
)


  0%|          | 0/32 [00:00<?, ?it/s]

  return src.new_zeros(size).scatter_reduce_(


Epoch [1/100], loss: 0.80411 val_loss: 0.72426, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [2/100], loss: 0.61204 val_loss: 0.60290, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [3/100], loss: 0.53518 val_loss: 0.54612, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [4/100], loss: 0.49334 val_loss: 0.53877, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [5/100], loss: 0.48017 val_loss: 0.48902, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [6/100], loss: 0.44535 val_loss: 0.45545, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [7/100], loss: 0.41692 val_loss: 0.43777, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [8/100], loss: 0.40497 val_loss: 0.42486, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [9/100], loss: 0.38402 val_loss: 0.41873, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [10/100], loss: 0.37296 val_loss: 0.41694, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [11/100], loss: 0.36773 val_loss: 0.39749, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [12/100], loss: 0.35441 val_loss: 0.38654, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [13/100], loss: 0.34303 val_loss: 0.37627, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [14/100], loss: 0.33472 val_loss: 0.38137, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [15/100], loss: 0.33508 val_loss: 0.37040, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [16/100], loss: 0.32542 val_loss: 0.35497, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [17/100], loss: 0.31981 val_loss: 0.35487, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [18/100], loss: 0.31240 val_loss: 0.35536, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [19/100], loss: 0.30896 val_loss: 0.33759, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [20/100], loss: 0.29878 val_loss: 0.33229, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [21/100], loss: 0.29547 val_loss: 0.32998, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [22/100], loss: 0.29172 val_loss: 0.32351, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [23/100], loss: 0.28780 val_loss: 0.33311, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [24/100], loss: 0.29333 val_loss: 0.34069, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [25/100], loss: 0.28594 val_loss: 0.31363, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [26/100], loss: 0.27633 val_loss: 0.32053, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [27/100], loss: 0.28024 val_loss: 0.31410, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [28/100], loss: 0.27054 val_loss: 0.31423, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [29/100], loss: 0.26965 val_loss: 0.32864, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [30/100], loss: 0.27083 val_loss: 0.30331, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [31/100], loss: 0.27216 val_loss: 0.29509, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [32/100], loss: 0.26093 val_loss: 0.29588, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [33/100], loss: 0.26409 val_loss: 0.31124, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [34/100], loss: 0.26359 val_loss: 0.31169, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [35/100], loss: 0.25825 val_loss: 0.29007, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [36/100], loss: 0.25271 val_loss: 0.29040, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [37/100], loss: 0.25409 val_loss: 0.28613, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [38/100], loss: 0.25010 val_loss: 0.29206, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [39/100], loss: 0.24564 val_loss: 0.28708, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [40/100], loss: 0.24611 val_loss: 0.28638, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [41/100], loss: 0.24344 val_loss: 0.29677, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [42/100], loss: 0.24776 val_loss: 0.28098, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [43/100], loss: 0.23830 val_loss: 0.28753, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [44/100], loss: 0.24061 val_loss: 0.28065, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [45/100], loss: 0.23867 val_loss: 0.28424, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [46/100], loss: 0.24366 val_loss: 0.29108, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [47/100], loss: 0.23535 val_loss: 0.27467, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [48/100], loss: 0.23180 val_loss: 0.27984, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [49/100], loss: 0.23303 val_loss: 0.29520, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [50/100], loss: 0.23223 val_loss: 0.27874, 


  0%|          | 0/32 [00:00<?, ?it/s]

Epoch [51/100], loss: 0.23831 val_loss: 0.28448, 


  0%|          | 0/32 [00:00<?, ?it/s]

In [None]:
evaluate_history(history)
