## Trained models and optimal hyperparameters

We provide trained SphereNet models and their hyperparameters for QM9 and MD17 [here](https://github.com/divelab/DIG_storage/tree/main/3dgraph).

## Example of 3D Graph

Here we provide example code for SphereNet on the QM93D and MD17 datasets. To use SchNet or DimeNetPP instead, change the model class and its constructor parameters.

```python
import torch
import sys
# Add parent directories to the import path (e.g. to use a local DIG checkout)
sys.path.insert(0, '..')
sys.path.insert(0, '../..')
from dig.threedgraph.dataset import QM93D
from dig.threedgraph.dataset import MD17
from dig.threedgraph.method import SphereNet  # SchNet and DimeNetPP are imported the same way
from dig.threedgraph.method import run
from dig.threedgraph.evaluation import ThreeDEvaluator
```

```python
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device
```

```
device(type='cuda', index=0)
```

### Example code for QM93D data
***Note***: '3D' means that the data includes positional information for atoms.
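Concretely, each molecule is described by its atomic numbers and the Cartesian coordinates of its atoms, and models such as SphereNet connect only atoms that lie within `cutoff` of each other (5.0 in the defaults below). The sketch shows that neighbour rule with plain lists; the names `z` and `pos` mirror the attribute names used by DIG's 3D datasets, but `radius_graph` here is our own illustrative helper, and its boundary handling (strict `<`) may differ from the library's.

```python
import math
from itertools import permutations


def radius_graph(pos, cutoff=5.0):
    """Return directed edges (i, j), i != j, for every atom pair closer than `cutoff`."""
    return [
        (i, j)
        for i, j in permutations(range(len(pos)), 2)
        if math.dist(pos[i], pos[j]) < cutoff
    ]


# Hypothetical 3-atom example: atomic numbers (not needed to build edges) and positions.
z = [6, 8, 1]
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (7.0, 0.0, 0.0)]

edges = radius_graph(pos, cutoff=5.0)
print(edges)  # [(0, 1), (1, 0)]: the third atom is 6 and 7 units away, beyond the cutoff
```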

We trained a separate model for each target except _gap_, which was predicted as _lumo_ minus _homo_. The default hyperparameters give comparable results; for better performance, we also tuned lr, lr_decay_factor, lr_decay_step_size, cutoff, num_spherical, num_radial, basis_emb_size_dist, basis_emb_size_angle, and basis_emb_size_torsion. The values and search spaces for these hyperparameters are listed in the Appendix of our paper.
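In other words, the _gap_ prediction for a molecule is the difference between the outputs of the separately trained _lumo_ and _homo_ models. A minimal sketch with hypothetical prediction values (`predict_gap` is an illustrative helper, not part of DIG):

```python
def predict_gap(homo_predictions, lumo_predictions):
    """gap = lumo - homo, element-wise over paired predictions for the same molecules."""
    if len(homo_predictions) != len(lumo_predictions):
        raise ValueError("need one HOMO and one LUMO prediction per molecule")
    return [lumo - homo for homo, lumo in zip(homo_predictions, lumo_predictions)]


homo = [-6.5, -5.75]  # hypothetical outputs of the homo model
lumo = [0.5, -0.25]   # hypothetical outputs of the lumo model
print(predict_gap(homo, lumo))  # [7.0, 5.5]
```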

The default hyperparameters for QM93D are:

- Model (`SphereNet`): `energy_and_force=False, cutoff=5.0, num_layers=4, hidden_channels=128, out_channels=1, int_emb_size=64, basis_emb_size_dist=8, basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256, num_spherical=3, num_radial=6, envelope_exponent=5, num_before_skip=1, num_after_skip=2, num_output_layers=3`
- Training (`run`): `epochs=1000, batch_size=32, vt_batch_size=32, lr=0.0005, lr_decay_factor=0.5, lr_decay_step_size=100`
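The three `lr*` settings describe a step decay: the learning rate starts at `lr` and is multiplied by `lr_decay_factor` every `lr_decay_step_size` epochs. The sketch assumes the scheduler is stepped once at the end of each epoch, as a `torch.optim.lr_scheduler.StepLR` would be; `step_lr` is our own helper, not a DIG function. Under that assumption, the 20-epoch QM93D run later in this notebook (`lr_decay_step_size=15`) first halves the learning rate at epoch 16.

```python
def step_lr(epoch, lr=0.0005, lr_decay_factor=0.5, lr_decay_step_size=100):
    """Learning rate used during a 1-indexed epoch under a step-decay schedule.

    Assumes one scheduler step at the end of every epoch (StepLR-style).
    """
    if epoch < 1:
        raise ValueError("epoch is 1-indexed")
    return lr * lr_decay_factor ** ((epoch - 1) // lr_decay_step_size)


# Default QM93D schedule: 0.0005 for epochs 1-100, 0.00025 for 101-200, and so on.
for epoch in (1, 100, 101, 201, 1000):
    print(epoch, step_lr(epoch))

# Demo schedule used in the 20-epoch run (lr_decay_step_size=15).
print(step_lr(15, lr_decay_step_size=15), step_lr(16, lr_decay_step_size=15))
```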




#### Loading dataset

```python
dataset = QM93D(root='dataset/')
target = 'U0'
dataset.data.y = dataset.data[target]  # use U0 as the regression target

split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=110000, valid_size=10000, seed=42)

train_dataset, valid_dataset, test_dataset = dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']]
print('train, validation, test:', len(train_dataset), len(valid_dataset), len(test_dataset))
```

```
train, validation, test: 110000 10000 10831
```
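The split takes 110,000 molecules for training and 10,000 for validation, and all remaining molecules (10,831 of the 130,831 here) become the test set. A seeded random split of that shape can be sketched with the standard library; `random_split` is a hypothetical stand-in, and DIG's `get_idx_split` may shuffle differently, so it will not reproduce the same indices.

```python
import random


def random_split(num_samples, train_size, valid_size, seed=42):
    """Shuffle indices with a fixed seed, then cut them into train/valid/test."""
    if train_size + valid_size > num_samples:
        raise ValueError("train_size + valid_size exceeds the dataset size")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # same seed -> same split on every run
    return {
        "train": indices[:train_size],
        "valid": indices[train_size:train_size + valid_size],
        "test": indices[train_size + valid_size:],  # everything left over
    }


split = random_split(130831, train_size=110000, valid_size=10000, seed=42)
print(len(split["train"]), len(split["valid"]), len(split["test"]))  # 110000 10000 10831
```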


#### Loading model, loss and evaluation function

The evaluation metric is mean absolute error (MAE).
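MAE is the average absolute difference between predictions and targets, which is also what `torch.nn.L1Loss` computes with its default `reduction='mean'`; that is why the same quantity serves as both training loss and evaluation metric here. A plain-Python version for reference (`mean_absolute_error` is our own helper):

```python
def mean_absolute_error(predictions, targets):
    """MAE = (1/n) * sum_i |prediction_i - target_i|."""
    if len(predictions) != len(targets) or not predictions:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(predictions)


print(mean_absolute_error([1.0, 2.0, 3.0], [1.5, 2.0, 2.0]))  # 0.5
```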

```python
model = SphereNet(energy_and_force=False, cutoff=5.0, num_layers=4,
        hidden_channels=128, out_channels=1, int_emb_size=64,
        basis_emb_size_dist=8, basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256,
        num_spherical=3, num_radial=6, envelope_exponent=5,
        num_before_skip=1, num_after_skip=2, num_output_layers=3, use_node_features=True
        )
loss_func = torch.nn.L1Loss()
evaluation = ThreeDEvaluator()
```

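#### Training

For a quick demonstration, this run trains for only 20 epochs with `lr_decay_step_size=15` instead of the defaults listed above (`epochs=1000`, `lr_decay_step_size=100`), so its MAE should not be compared with fully trained results.

```python
run3d = run()
run3d.run(device, train_dataset, valid_dataset, test_dataset, model, loss_func, evaluation,
          epochs=20, batch_size=32, vt_batch_size=32,
          lr=0.0005, lr_decay_factor=0.5, lr_decay_step_size=15)
```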

```
#Params: 1890118

=====Epoch 1
Training...
100%|██████████| 3438/3438 [07:22<00:00,  7.77it/s]
Evaluating...
100%|██████████| 313/313 [00:20<00:00, 15.30it/s]
Testing...
100%|██████████| 339/339 [00:21<00:00, 16.04it/s]
{'Train': 0.8305539944409076, 'Validation': 0.7885677814483643, 'Test': 0.7943109273910522}

=====Epoch 2
Training...
100%|██████████| 3438/3438 [06:16<00:00,  9.12it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 28.91it/s]
Testing...
100%|██████████| 339/339 [00:12<00:00, 27.32it/s]
{'Train': 0.3417653005923415, 'Validation': 0.16290859878063202, 'Test': 0.16250823438167572}

=====Epoch 3
Training...
100%|██████████| 3438/3438 [04:24<00:00, 13.01it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 29.31it/s]
Testing...
100%|██████████| 339/339 [00:12<00:00, 27.83it/s]
{'Train': 0.2626579807482881, 'Validation': 0.10924234241247177, 'Test': 0.1091669574379921}

=====Epoch 4
Training...
100%|██████████| 3438/3438 [04:26<00:00, 12.88it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 29.67it/s]
Testing...
100%|██████████| 339/339 [00:11<00:00, 28.65it/s]
{'Train': 0.2185871605092249, 'Validation': 0.1412947177886963, 'Test': 0.14113298058509827}

=====Epoch 5
Training...
100%|██████████| 3438/3438 [04:25<00:00, 12.97it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 29.38it/s]
Testing...
100%|██████████| 339/339 [00:11<00:00, 29.38it/s]
{'Train': 0.18415136586759867, 'Validation': 0.08948442339897156, 'Test': 0.08791808038949966}

=====Epoch 6
Training...
100%|██████████| 3438/3438 [04:24<00:00, 13.00it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 29.43it/s]
Testing...
100%|██████████| 339/339 [00:11<00:00, 29.58it/s]
{'Train': 0.17059671088246983, 'Validation': 0.10857655853033066, 'Test': 0.1086759939789772}

=====Epoch 7
Training...
100%|██████████| 3438/3438 [04:30<00:00, 12.69it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 28.61it/s]
Testing...
100%|██████████| 339/339 [00:11<00:00, 29.14it/s]
{'Train': 0.15622219235277093, 'Validation': 0.08192159235477448, 'Test': 0.08170071989297867}

=====Epoch 8
Training...
100%|██████████| 3438/3438 [04:35<00:00, 12.48it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 29.14it/s]
Testing...
100%|██████████| 339/339 [00:11<00:00, 28.63it/s]
{'Train': 0.1442768630192958, 'Validation': 0.08120342344045639, 'Test': 0.08138693124055862}

=====Epoch 9
Training...
100%|██████████| 3438/3438 [04:24<00:00, 13.00it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 28.65it/s]
Testing...
100%|██████████| 339/339 [00:11<00:00, 28.29it/s]
{'Train': 0.13906806218478485, 'Validation': 0.07339970022439957, 'Test': 0.0732196718454361}

=====Epoch 10
Training...
100%|██████████| 3438/3438 [04:35<00:00, 12.49it/s]
Evaluating...
100%|██████████| 313/313 [00:11<00:00, 27.44it/s]
Testing...
100%|██████████| 339/339 [00:12<00:00, 26.88it/s]
{'Train': 0.12617339688792625, 'Validation': 0.11456501483917236, 'Test': 0.11438193917274475}

=====Epoch 11
Training...
100%|██████████| 3438/3438 [04:27<00:00, 12.85it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 28.90it/s]
Testing...
100%|██████████| 339/339 [00:12<00:00, 26.23it/s]
{'Train': 0.12321726725571651, 'Validation': 0.0715189278125763, 'Test': 0.07092428207397461}

=====Epoch 12
Training...
100%|██████████| 3438/3438 [04:31<00:00, 12.68it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 29.38it/s]
Testing...
100%|██████████| 339/339 [00:11<00:00, 29.33it/s]
{'Train': 0.11304465457233598, 'Validation': 0.1164650246500969, 'Test': 0.11696784943342209}

=====Epoch 13
Training...
100%|██████████| 3438/3438 [04:28<00:00, 12.80it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 28.54it/s]
Testing...
100%|██████████| 339/339 [00:11<00:00, 29.32it/s]
{'Train': 0.11311055924429181, 'Validation': 0.1142609491944313, 'Test': 0.11372711509466171}

=====Epoch 14
Training...
100%|██████████| 3438/3438 [04:25<00:00, 12.97it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 29.43it/s]
Testing...
100%|██████████| 339/339 [00:12<00:00, 27.68it/s]
{'Train': 0.1103381712277869, 'Validation': 0.05894898623228073, 'Test': 0.05792950466275215}

=====Epoch 15
Training...
100%|██████████| 3438/3438 [04:30<00:00, 12.69it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 29.26it/s]
Testing...
100%|██████████| 339/339 [00:11<00:00, 28.55it/s]
{'Train': 0.09813584842398945, 'Validation': 0.13913576304912567, 'Test': 0.1383834183216095}

=====Epoch 16
Training...
100%|██████████| 3438/3438 [04:26<00:00, 12.89it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 29.29it/s]
Testing...
100%|██████████| 339/339 [00:11<00:00, 28.72it/s]
{'Train': 0.05428033658000465, 'Validation': 0.06030373275279999, 'Test': 0.059175316244363785}

=====Epoch 17
Training...
100%|██████████| 3438/3438 [04:28<00:00, 12.80it/s]
Evaluating...
100%|██████████| 313/313 [00:11<00:00, 27.83it/s]
Testing...
100%|██████████| 339/339 [00:12<00:00, 27.47it/s]
{'Train': 0.054203004988561614, 'Validation': 0.03810606151819229, 'Test': 0.03703922778367996}

=====Epoch 18
Training...
100%|██████████| 3438/3438 [04:29<00:00, 12.77it/s]
Evaluating...
100%|██████████| 313/313 [00:11<00:00, 28.11it/s]
Testing...
100%|██████████| 339/339 [00:12<00:00, 26.31it/s]
{'Train': 0.0530719623151666, 'Validation': 0.04359658062458038, 'Test': 0.043418560177087784}

=====Epoch 19
Training...
100%|██████████| 3438/3438 [04:26<00:00, 12.89it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 28.87it/s]
Testing...
100%|██████████| 339/339 [00:12<00:00, 28.13it/s]
{'Train': 0.05202796294149651, 'Validation': 0.04247582331299782, 'Test': 0.04204947501420975}

=====Epoch 20
Training...
100%|██████████| 3438/3438 [04:31<00:00, 12.64it/s]
Evaluating...
100%|██████████| 313/313 [00:10<00:00, 29.59it/s]
Testing...
100%|██████████| 339/339 [00:11<00:00, 29.71it/s]
{'Train': 0.04962607438894397, 'Validation': 0.04090351238846779, 'Test': 0.040894996374845505}
Best validation MAE so far: 0.03810606151819229
Test MAE when got best validation result: 0.03703922778367996
```
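The two summary lines report the test MAE at the epoch with the lowest validation MAE (epoch 17), not the test MAE of the final epoch. That selection rule can be reproduced from the per-epoch numbers in the log above, copied here by hand; `select_by_validation` is our own helper, not DIG code.

```python
# (validation MAE, test MAE) for epochs 1-20, copied from the QM93D log.
LOG = [
    (0.7885677814483643, 0.7943109273910522),
    (0.16290859878063202, 0.16250823438167572),
    (0.10924234241247177, 0.1091669574379921),
    (0.1412947177886963, 0.14113298058509827),
    (0.08948442339897156, 0.08791808038949966),
    (0.10857655853033066, 0.1086759939789772),
    (0.08192159235477448, 0.08170071989297867),
    (0.08120342344045639, 0.08138693124055862),
    (0.07339970022439957, 0.0732196718454361),
    (0.11456501483917236, 0.11438193917274475),
    (0.0715189278125763, 0.07092428207397461),
    (0.1164650246500969, 0.11696784943342209),
    (0.1142609491944313, 0.11372711509466171),
    (0.05894898623228073, 0.05792950466275215),
    (0.13913576304912567, 0.1383834183216095),
    (0.06030373275279999, 0.059175316244363785),
    (0.03810606151819229, 0.03703922778367996),
    (0.04359658062458038, 0.043418560177087784),
    (0.04247582331299782, 0.04204947501420975),
    (0.04090351238846779, 0.040894996374845505),
]


def select_by_validation(log):
    """Return (1-indexed epoch, val MAE, test MAE) for the lowest validation MAE."""
    best_idx = min(range(len(log)), key=lambda i: log[i][0])
    val_mae, test_mae = log[best_idx]
    return best_idx + 1, val_mae, test_mae


best_epoch, best_val, best_test = select_by_validation(LOG)
print(best_epoch)  # 17
```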



### Example code for MD17 data

We trained a separate model for each molecule. The default hyperparameters give comparable results; for better performance, we also tuned lr, lr_decay_factor, lr_decay_step_size, batch_size, basis_emb_size_dist, basis_emb_size_angle, and basis_emb_size_torsion. The values and search spaces for these hyperparameters are listed in the Appendix of our paper.

The default hyperparameters for MD17 are:

- Model (`SphereNet`): `energy_and_force=True, cutoff=5.0, num_layers=4, hidden_channels=128, out_channels=1, int_emb_size=64, basis_emb_size_dist=8, basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256, num_spherical=3, num_radial=6, envelope_exponent=5, num_before_skip=1, num_after_skip=2, num_output_layers=3`
- Training (`run`): `epochs=1000, batch_size=1, vt_batch_size=32, lr=0.0005, lr_decay_factor=0.5, lr_decay_step_size=200`
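The MD17 defaults differ from the QM93D defaults in only three entries: `energy_and_force`, `batch_size`, and `lr_decay_step_size`. Writing both lists as plain Python dictionaries (a layout of our own, not a DIG configuration format) makes that easy to check:

```python
QM93D_DEFAULTS = dict(
    energy_and_force=False, cutoff=5.0, num_layers=4, hidden_channels=128,
    out_channels=1, int_emb_size=64, basis_emb_size_dist=8,
    basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256,
    num_spherical=3, num_radial=6, envelope_exponent=5, num_before_skip=1,
    num_after_skip=2, num_output_layers=3, epochs=1000, batch_size=32,
    vt_batch_size=32, lr=0.0005, lr_decay_factor=0.5, lr_decay_step_size=100,
)

MD17_DEFAULTS = dict(
    energy_and_force=True, cutoff=5.0, num_layers=4, hidden_channels=128,
    out_channels=1, int_emb_size=64, basis_emb_size_dist=8,
    basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256,
    num_spherical=3, num_radial=6, envelope_exponent=5, num_before_skip=1,
    num_after_skip=2, num_output_layers=3, epochs=1000, batch_size=1,
    vt_batch_size=32, lr=0.0005, lr_decay_factor=0.5, lr_decay_step_size=200,
)


def config_diff(a, b):
    """Map each key whose value differs between configs a and b to (a_value, b_value)."""
    keys = a.keys() | b.keys()
    return {k: (a.get(k), b.get(k)) for k in sorted(keys) if a.get(k) != b.get(k)}


for key, (qm9_value, md17_value) in config_diff(QM93D_DEFAULTS, MD17_DEFAULTS).items():
    print(f"{key}: QM93D={qm9_value}, MD17={md17_value}")
```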

#### Loading dataset

```python
dataset_md17 = MD17(root='dataset/', name='aspirin')

split_idx_md17 = dataset_md17.get_idx_split(len(dataset_md17.data.y), train_size=1000, valid_size=1000, seed=42)

train_dataset_md17, valid_dataset_md17, test_dataset_md17 = dataset_md17[split_idx_md17['train']], dataset_md17[split_idx_md17['valid']], dataset_md17[split_idx_md17['test']]
print('train, validation, test:', len(train_dataset_md17), len(valid_dataset_md17), len(test_dataset_md17))
```

```
train, validation, test: 1000 1000 209762
```
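The split sizes also account for the progress-bar totals in both training logs: if the last, partial batch is kept (which the logged totals are consistent with), each pass runs ceil(n / batch_size) iterations. With `batch_size=1` and `vt_batch_size=64`, the aspirin run therefore does 1000 training steps, 16 validation batches, and 3278 test batches per epoch. `num_batches` is our own helper:

```python
def num_batches(num_samples, batch_size):
    """Iterations per pass over the data when the last, partial batch is kept."""
    return (num_samples + batch_size - 1) // batch_size  # integer ceil(n / b)


# (split, number of samples, batch size) for the two runs in this notebook.
RUNS = {
    "QM93D": [("train", 110000, 32), ("valid", 10000, 32), ("test", 10831, 32)],
    "MD17 aspirin": [("train", 1000, 1), ("valid", 1000, 64), ("test", 209762, 64)],
}

for name, splits in RUNS.items():
    for split, n, bs in splits:
        print(f"{name} {split}: {num_batches(n, bs)} batches")
```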


#### Loading model, loss and evaluation function

We predict the energy and take the negative gradient of the predicted energy with respect to the atomic positions as the predicted forces, i.e. F_i = -∂E/∂r_i for each atom i.

The evaluation metric is mean absolute error (MAE).
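The sign convention F = -∂E/∂r can be checked on a toy system. In DIG the gradient is taken through the network by automatic differentiation; the sketch below instead uses central finite differences on a hypothetical two-atom harmonic bond (`bond_energy`, `k`, and `r0` are illustrative names, not SphereNet's energy), so it runs with only the standard library. A stretched bond should produce forces that pull the two atoms back together.

```python
import math


def bond_energy(positions, k=1.0, r0=1.0):
    """Toy harmonic bond between atoms 0 and 1: E = 0.5 * k * (r - r0)**2."""
    r = math.dist(positions[0], positions[1])
    return 0.5 * k * (r - r0) ** 2


def finite_difference_forces(energy_fn, positions, h=1e-6):
    """Approximate F_i = -dE/dr_i for every atom i and axis with central differences."""
    forces = []
    for i in range(len(positions)):
        force_i = []
        for axis in range(3):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[i][axis] += h
            minus[i][axis] -= h
            dE = (energy_fn(plus) - energy_fn(minus)) / (2 * h)
            force_i.append(-dE)  # force is the NEGATIVE gradient of the energy
        forces.append(force_i)
    return forces


# Bond stretched to 1.5 (equilibrium r0 = 1.0).
positions = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0)]
forces = finite_difference_forces(bond_energy, positions)
print(forces)  # approximately [[0.5, 0, 0], [-0.5, 0, 0]]
```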

```python
model_md17 = SphereNet(energy_and_force=True, cutoff=5.0, num_layers=4,
        hidden_channels=128, out_channels=1, int_emb_size=64,
        basis_emb_size_dist=8, basis_emb_size_angle=8, basis_emb_size_torsion=8, out_emb_channels=256,
        num_spherical=3, num_radial=6, envelope_exponent=5,
        num_before_skip=1, num_after_skip=2, num_output_layers=3
        )
loss_func_md17 = torch.nn.L1Loss()
evaluation_md17 = ThreeDEvaluator()
```

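#### Training

For a quick demonstration, this run trains for only 5 epochs and uses `vt_batch_size=64`; the other training settings match the MD17 defaults listed above.

```python
run3d_md17 = run()
run3d_md17.run(device, train_dataset_md17, valid_dataset_md17, test_dataset_md17,
               model_md17, loss_func_md17, evaluation_md17,
               epochs=5, batch_size=1, vt_batch_size=64,
               lr=0.0005, lr_decay_factor=0.5, lr_decay_step_size=200,
               energy_and_force=True)
```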

```
#Params: 1890118

=====Epoch 1
Training...
100%|██████████| 1000/1000 [03:20<00:00,  4.99it/s]
Evaluating...
100%|██████████| 16/16 [00:03<00:00,  5.06it/s]
{'Energy MAE': 21076.162109375, 'Force MAE': 88.1651611328125}
Testing...
100%|██████████| 3278/3278 [10:49<00:00,  5.05it/s]
{'Energy MAE': 21079.103515625, 'Force MAE': 87.94914245605469}
{'Train': 62999.26555371094, 'Validation': 29892.67822265625, 'Test': 29874.01776123047}

=====Epoch 2
Training...
100%|██████████| 1000/1000 [03:35<00:00,  4.65it/s]
Evaluating...
100%|██████████| 16/16 [00:03<00:00,  5.09it/s]
{'Energy MAE': 11752.1220703125, 'Force MAE': 41.66204833984375}
Testing...
100%|██████████| 3278/3278 [10:27<00:00,  5.23it/s]
{'Energy MAE': 11752.3837890625, 'Force MAE': 41.8145866394043}
{'Train': 21397.7178671875, 'Validation': 15918.326904296875, 'Test': 15933.84245300293}

=====Epoch 3
Training...
100%|██████████| 1000/1000 [03:44<00:00,  4.46it/s]
Evaluating...
100%|██████████| 16/16 [00:03<00:00,  5.15it/s]
{'Energy MAE': 12009.2421875, 'Force MAE': 70.98027038574219}
Testing...
100%|██████████| 3278/3278 [10:42<00:00,  5.10it/s]
{'Energy MAE': 12010.466796875, 'Force MAE': 71.2234115600586}
{'Train': 14609.853533203124, 'Validation': 19107.26922607422, 'Test': 19132.80795288086}

=====Epoch 4
Training...
100%|██████████| 1000/1000 [03:46<00:00,  4.42it/s]
Evaluating...
100%|██████████| 16/16 [00:03<00:00,  5.11it/s]
{'Energy MAE': 14571.55078125, 'Force MAE': 49.371952056884766}
Testing...
100%|██████████| 3278/3278 [10:43<00:00,  5.09it/s]
{'Energy MAE': 14570.435546875, 'Force MAE': 49.68278121948242}
{'Train': 13809.444167358399, 'Validation': 19508.745986938477, 'Test': 19538.713668823242}

=====Epoch 5
Training...
100%|██████████| 1000/1000 [03:43<00:00,  4.47it/s]
Evaluating...
100%|██████████| 16/16 [00:03<00:00,  5.20it/s]
{'Energy MAE': 19717.771484375, 'Force MAE': 30.181947708129883}
Testing...
100%|██████████| 3278/3278 [10:41<00:00,  5.11it/s]
{'Energy MAE': 19717.19140625, 'Force MAE': 30.4222354888916}
{'Train': 12309.488090698242, 'Validation': 22735.96625518799, 'Test': 22759.41495513916}
Best validation MAE so far: 15918.326904296875
Test MAE when got best validation result: 15933.84245300293
```
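In the MD17 log, the per-epoch 'Validation' and 'Test' values, and hence the "Best validation MAE" line, are not a plain energy MAE: they match Energy MAE + 100 × Force MAE (for example, 21076.16 + 100 × 88.165 ≈ 29892.68 in epoch 1). The weight of 100 is inferred from the logged numbers alone; it presumably corresponds to a force-weighting setting inside `run`, which should be checked in the DIG source. The check below recomputes the combined score from the logged validation MAEs (`combined_mae` and `FORCE_WEIGHT` are our own names):

```python
# Validation numbers for epochs 1-5, copied from the MD17 log.
VALIDATION_LOG = [
    {"energy_mae": 21076.162109375, "force_mae": 88.1651611328125, "reported": 29892.67822265625},
    {"energy_mae": 11752.1220703125, "force_mae": 41.66204833984375, "reported": 15918.326904296875},
    {"energy_mae": 12009.2421875, "force_mae": 70.98027038574219, "reported": 19107.26922607422},
    {"energy_mae": 14571.55078125, "force_mae": 49.371952056884766, "reported": 19508.745986938477},
    {"energy_mae": 19717.771484375, "force_mae": 30.181947708129883, "reported": 22735.96625518799},
]

FORCE_WEIGHT = 100.0  # inferred from the log, not read from DIG's code


def combined_mae(energy_mae, force_mae, force_weight=FORCE_WEIGHT):
    """Energy MAE plus weighted force MAE, the quantity the log appears to report."""
    return energy_mae + force_weight * force_mae


for epoch, row in enumerate(VALIDATION_LOG, start=1):
    combined = combined_mae(row["energy_mae"], row["force_mae"])
    print(f"epoch {epoch}: recomputed {combined:.4f}, reported {row['reported']:.4f}")
```

Because the force term carries a weight of 100, the epoch with the best combined score (epoch 2) is not the one with the lowest force MAE (epoch 5).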



