In [1]:
import json
import os
import pickle
import random
import sys

from typing import Callable, Dict, List, Optional
import haiku as hk
import ase
import ase.io
import jax
import jax.numpy as jnp
import numpy as np
import optax
import yaml


from model.datasets import becs_eps_datasets
from model.utils import (
    create_directory_with_random_name,
    compute_avg_num_neighbors,
)
from model.data_utils import (
    get_atomic_number_table_from_zs,
    compute_average_E0s,
)
from model.predictors import predict_becs_eps
from model.optimizer import optimizer
from model.becs_eps_train import BECS_EPS_train
from model.loss import BecsEpsLoss


from model.becs_eps_model import BECS_EPS_model

# Numerical debugging: with these JAX flags on, a NaN/Inf produced during a
# computation raises an error instead of silently propagating. Handy while developing,
# but the extra checks slow execution, so consider setting this to False for long
# production training runs.
DEBUG_NUMERICS = True
jax.config.update("jax_debug_nans", DEBUG_NUMERICS)
jax.config.update("jax_debug_infs", DEBUG_NUMERICS)

# Compact numpy printing (3 decimals, no scientific notation) for the tensor dumps below.
np.set_printoptions(precision=3, suppress=True)

In [2]:
# Training configuration, relative to the notebook's working directory.
CONFIG_PATH = "data/train_eps.yaml"

# The config holds plain data only, so safe_load is sufficient and, unlike
# FullLoader, never constructs arbitrary Python objects from YAML tags.
with open(CONFIG_PATH) as f:
    config = yaml.safe_load(f)

# Create this run's output directory; all artifacts of the run are written there.
# (The printed name below suggests the format "<timestamp>-eps_training-<random words>".)
save_dir_name = create_directory_with_random_name("eps_training")

2024-03-12-18:18-eps_training-amazing-caren


In [3]:
# Save the resolved config (and, when run as a script, the source code) into the
# run directory so the run can be reproduced later.
with open(f"{save_dir_name}/config.yaml", "w") as f:
    yaml.dump(config, f)

# Inside Jupyter, sys.argv[0] is the kernel launcher (ipykernel_launcher.py), not
# this notebook, so copying it would archive the wrong file. Only snapshot the
# source when this code is executed as a plain script.
script_path = sys.argv[0] if sys.argv else ""
if os.path.isfile(script_path) and "ipykernel" not in os.path.basename(script_path):
    with open(script_path) as g, open(f"{save_dir_name}/train.py", "w") as f:
        f.write(g.read())
else:
    print(f"Source snapshot skipped: {script_path!r} is not a training script.")

# Build the data loaders. valid_num graphs are held out for validation
# (presumably split off train_path, since no valid_path is given — TODO confirm).
train_loader, valid_loader, test_loader, r_max = becs_eps_datasets(
    r_max=config["cutoff"],
    train_path=config["dataset"]["train_path"],
    # valid_path=config["dataset"]["valid_path"],
    # train_num=config["dataset"]["train_num"],
    valid_num=config["dataset"]["valid_num"],
    n_node=config["dataset"]["num_nodes"],
    n_edge=config["dataset"]["num_edges"],
    n_graph=config["dataset"]["num_graphs"],
)

print(len(train_loader.graphs))
print(len(valid_loader.graphs))

# Optional checkpoint directory to reload from. Tolerate a missing or empty
# "initialization" section instead of raising KeyError/TypeError.
reload_dir = (config.get("initialization") or {}).get("reload")

model_fn, params, num_message_passing = BECS_EPS_model(
    r_max=r_max,
    atomic_energies_dict={},
    train_graphs=train_loader.graphs,
    initialize_seed=config["model"]["seed"],
    num_species=config["model"]["num_species"],
    use_sc=True,
    graph_net_steps=config["model"]["num_layers"],
    hidden_irreps=config["model"]["internal_irreps"],
    nonlinearities={'e': 'swish', 'o': 'tanh'},
    save_dir_name=save_dir_name,
    reload=reload_dir,
)

print("num_params:", sum(p.size for p in jax.tree_util.tree_leaves(params)))

# Jitted predictor: w are the model parameters, g a graph (batch). The result is
# indexed by 'becs' and 'eps' further down.
predictor = jax.jit(
    lambda w, g: predict_becs_eps(lambda *x: model_fn(w, *x), g)
)

gradient_transform, steps_per_interval, max_num_intervals = optimizer(
    lr=config["training"]["learning_rate"],
    max_num_intervals=config["training"]["max_num_intervals"],
    steps_per_interval=config["training"]["steps_per_interval"],
    # weight_decay=config["training"]["weight_decay"],
)
optimizer_state = gradient_transform.init(params)
print("optimizer num_params:", sum(p.size for p in jax.tree_util.tree_leaves(optimizer_state)))

# Weighted loss over per-atom BECs, their sum over atoms, and the dielectric tensor
# (semantics of each term as implemented in BecsEpsLoss — confirm there).
loss_fn = BecsEpsLoss(
    becs_weight=config["training"]["becs_weight"],
    becs_sum_weight=config["training"]["becs_sum_weight"],
    eps_weight=config["training"]["eps_weight"],
)

100%|█████████████████████████████████████| 1420/1420 [00:00<00:00, 9067.21it/s]
100%|███████████████████████████████████████| 100/100 [00:00<00:00, 7507.52it/s]


1420
100
z_table= AtomicNumberTable: (1, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 55, 56, 57, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83)
Compute the average number of neighbors: 28.767
Do not normalize the radial basis (avg_r_min=None)
Create BECS/EPS (NequIP-based) model with parameters {'use_sc': True, 'graph_net_steps': 3, 'hidden_irreps': '48x0e + 48x0o + 32x1o + 32x1e +24x2o + 24x2e', 'nonlinearities': {'e': 'swish', 'o': 'tanh'}, 'r_max': 5.0, 'avg_num_neighbors': 28.766618979494517, 'avg_r_min': None, 'num_species': 100, 'radial_basis': <function bessel_basis at 0x7f26cccf2c00>, 'radial_envelope': <function soft_envelope at 0x7f26cccf3c40>}
model: hidden_irreps=48x0e+48x0o+32x1o+32x1e+24x2o+24x2e sh_irreps=1x0e+1x1o+1x2e+1x3o 
num_params: 1426672
optimizer num_params: 4280018


In [4]:
# Interval-based training loop. Per the log below, each interval first evaluates on
# the train and validation sets, then runs `steps_per_interval` optimizer steps;
# `patience` bounds how long training continues without improvement.
# The return value is not used: the trained model is reloaded from save_dir_name
# in the next cell (presumably where BECS_EPS_train writes its checkpoint).
BECS_EPS_train(
    predictor,
    params,
    optimizer_state,
    train_loader,
    valid_loader,
    # test_loader,
    gradient_transform,
    loss_fn=loss_fn,
    max_num_intervals=max_num_intervals,
    steps_per_interval=steps_per_interval,
    save_dir_name=save_dir_name,
    patience=config["training"]["patience"],
    # ema_decay=config["training"]["ema_decay"],
)

print('training done!')
    

Started training


eval_train:   0%|                                        | 0/79 [00:00<?, ?it/s]

Compiled function `model` for args:
- n_node=[ 3  8  8  9 10  8  2  8 10 12 10 10 12 12  4 10 12  2 10  2  8 22  0  0] total=192
- n_edge=[112 140 160 166 336 208  32 284 152 152 404 124 352 208  76 156 336  36
 316  52 160  38   0   0] total=4000
cache size: 1


eval_train:   3%|▊                               | 2/79 [00:07<05:03,  3.94s/it]

Compiled function `model` for args:
- n_node=[12  8  3  5  8  3  6  3  5  5  4 12 16 12  7 12  7  0  0  0  0  0  0  0] total=128
- n_edge=[160 464 128 250 400  64 108  64 142 250  56 216 752 288  96 204 358   0
   0   0   0   0   0   0] total=4000
cache size: 2
Compiled function `model` for args:
- n_node=[12 10  8  9 10  8  5  5  9 10  4  6 10  8  9  5] total=128
- n_edge=[760 192 348 280 254 348 116 106 220 360 150  72 148 204 174 268] total=4000
cache size: 3


eval_train:  20%|██████▎                        | 16/79 [00:16<00:47,  1.33it/s]

Compiled function `model` for args:
- n_node=[12  5 10  4  2  6 12 16  6  2 11  8  2  0  0  0] total=96
- n_edge=[504  92 560 146  28 110 432 800  72  52 474 292 438   0   0   0] total=4000
cache size: 4


eval_train:  27%|████████▏                      | 21/79 [00:17<00:21,  2.74it/s]

Compiled function `model` for args:
- n_node=[ 5 10  5  9  4  4  7  4  2  5  4  2  5 12  4  2 12  5 10  8  8  6 12 12
  4 31  0  0  0  0  0  0] total=192
- n_edge=[144 308 194 378 174 100  98  56  56 106  72  52 122 144 114  36 152 170
 176 192 126 176 168 174 104 408   0   0   0   0   0   0] total=4000
cache size: 5


eval_train:  99%|██████████████████████████████▌| 78/79 [00:23<00:00,  3.27it/s]


Interval 0: eval_train: mae_becs=0.7741, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 139.78it/s]


Interval 0: eval_valid: mae_becs=0.7818, 
Interval 0: Time per interval: 0.0s, among which 24.0s for evaluation.


Train interval 0:   0%|         | 1/500 [00:07<1:03:29,  7.63s/it, loss=469.983]

Compiled function `update_fn` for args:
- n_node=[ 8  6  7  6  2  7  2 12 10  3  7  3  2  3 10  6  8  8  6 12  5 59  0  0] total=192
- n_edge=[236 108 298 228  32 122  92 534 446 112  96 128  56  64 140 252 212 188
 176 384  76  20   0   0] total=4000
Outout: loss= 366.340
Compilation time: 7.634s, cache size: 1
Compiled function `update_fn` for args:
- n_node=[ 5 10  8  8  6  5 12  7  5  9  6  9 10 10 12  9 61  0  0  0  0  0  0  0] total=192
- n_edge=[100 308 132 172 144 140 600  88 106 342  68 270 316 124 548 360 182   0
   0   0   0   0   0   0] total=4000
Outout: loss= 45.891
Compilation time: 0.035s, cache size: 2


Train interval 0:   2%|▏           | 9/500 [00:15<10:08,  1.24s/it, loss=92.336]

Compiled function `update_fn` for args:
- n_node=[ 7  6  6  5 12 12 10 12 10  8 12 10  5 13  0  0] total=128
- n_edge=[168 138 176 158 616 384 308 760 206 208 288 416 154  20   0   0] total=4000
Outout: loss= 99.672
Compilation time: 7.283s, cache size: 3


Train interval 0:   3%|▎          | 16/500 [00:22<08:05,  1.00s/it, loss=50.537]

Compiled function `update_fn` for args:
- n_node=[ 5  4 10  4 10  2  6  3  8  8  2  2 10  2  8  7  5 10 12  6  4  0  0  0] total=128
- n_edge=[164 156 260  94 254  32 104 128 540 184  36  28 308  36 252 104 150 184
 580  68 338   0   0   0] total=4000
Outout: loss= 54.153
Compilation time: 7.313s, cache size: 4


Train interval 0:   5%|▍         | 24/500 [00:32<03:24,  2.33it/s, loss=102.047]

Compiled function `update_fn` for args:
- n_node=[ 7  2 10 10 12  3  8  8  6  9  5  6  6  6 30  0] total=128
- n_edge=[ 96  28 140 148 394 128 288 160 192 124  76 256 296 436 310   0] total=3072
Outout: loss= 71.601
Compilation time: 9.382s, cache size: 5


Train interval 0:   6%|▋          | 31/500 [00:41<10:28,  1.34s/it, loss=57.240]

Compiled function `update_fn` for args:
- n_node=[11  3 10  4 10  8  8  8  6  9 40 11] total=128
- n_edge=[ 440   72  206  126  356  288  516  196  186  124 1464   26] total=4000
Outout: loss= 19.236
Compilation time: 9.460s, cache size: 6


Train interval 0:  13%|█▍         | 67/500 [00:50<03:30,  2.06it/s, loss=27.851]

Compiled function `update_fn` for args:
- n_node=[ 8  7  5  5  6  8  4  6  8  6  6  4 10  6 10  4  3 12 10  2  8  3  6  8
  2 35  0  0  0  0  0  0] total=192
- n_edge=[224 128  76 142 176 392 114 216 148 196 100 122 422 108 154 100  70 252
 152  32 196  64 144 216  12  44   0   0   0   0   0   0] total=4000
Outout: loss= 28.528
Compilation time: 7.756s, cache size: 7


Train interval 0:  18%|██         | 91/500 [01:00<05:13,  1.31it/s, loss=42.666]

Compiled function `update_fn` for args:
- n_node=[ 4  9 10 11 10  8  7 12 12 11 12 12  6 10 12 46] total=192
- n_edge=[174 124 244 334 158 160 232 204 384 306 416 464  92 328 144 236] total=4000
Outout: loss= 56.358
Compilation time: 9.124s, cache size: 8


Train interval 0:  20%|██▏        | 99/500 [01:08<05:11,  1.29it/s, loss=18.502]

Compiled function `update_fn` for args:
- n_node=[12  9  6  3  6  4 12  8 12  8  2 12  2  0  0  0] total=96
- n_edge=[616 234 184 194 180 144 858 116 580 300  52 368 174   0   0   0] total=4000
Outout: loss= 26.875
Compilation time: 7.383s, cache size: 9


Train interval 0:  33%|██▉      | 163/500 [01:19<04:18,  1.30it/s, loss=109.418]

Compiled function `update_fn` for args:
- n_node=[ 2 12  2  8  2  4  3 11 12  2  8  2  9 10  4  5  2  3  3 24  0  0  0  0] total=128
- n_edge=[ 52 300  36 172  52 114  64 306 300 160 348  28 438 144 104 116  52 194
  54  38   0   0   0   0] total=3072
Outout: loss= 5.928
Compilation time: 9.835s, cache size: 10


Train interval 0:  85%|████████▍ | 423/500 [01:36<00:39,  1.95it/s, loss=26.350]

Compiled function `update_fn` for args:
- n_node=[10  8  8  7  5  3  3  8  4  8  4  5  4 19  0  0] total=96
- n_edge=[446 416 342 168 164 128  94 208 106 224 284  76 122 294   0   0] total=3072
Outout: loss= 94.990
Compilation time: 9.218s, cache size: 11


Train interval 0: 100%|██████████| 500/500 [01:39<00:00,  5.04it/s, loss=15.245]
eval_train:  44%|█████████████▋                 | 35/79 [00:02<00:06,  6.97it/s]

Compiled function `model` for args:
- n_node=[12  6  8  8  6  4 10  2 10  8  2  8  8  9 27  0] total=128
- n_edge=[520  80 136 232 186 180 308  92 228 264  32 282 228 270  34   0] total=3072
cache size: 6
Compiled function `model` for args:
- n_node=[12 10  9 12  2 12  7 10 12  5  5  0] total=96
- n_edge=[352 404 342 824  92 300 168 308 632  92 486   0] total=4000
cache size: 7


eval_train:  49%|███████████████▎               | 39/79 [00:08<00:18,  2.14it/s]

Compiled function `model` for args:
- n_node=[12 24  3  8 12  2  7 10  4 10  4 32] total=128
- n_edge=[ 504 1340   64  288  524   92  188  156   66  548   66  164] total=4000
cache size: 8


eval_train:  99%|██████████████████████████████▌| 78/79 [00:08<00:00,  8.67it/s]


Interval 1: eval_train: mae_becs=0.9917, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 143.14it/s]


Interval 1: eval_valid: mae_becs=0.9925, 
Interval 1: Time per interval: 61.6s, among which 16.5s for evaluation.


Train interval 1: 100%|██████████| 500/500 [00:14<00:00, 33.75it/s, loss=89.200]
eval_train:  97%|██████████████████████████████▏| 77/79 [00:02<00:00, 28.03it/s]


Compiled function `model` for args:
- n_node=[40 10 12  8  6 10  5 10 12 10 11  5  6 47  0  0] total=192
- n_edge=[1464  158  272  298  104  206  116  340  208  254  334   76   92   78
    0    0] total=4000
cache size: 9
Interval 2: eval_train: mae_becs=1.1085, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 137.45it/s]


Interval 2: eval_valid: mae_becs=1.1382, 
Interval 2: Time per interval: 49.0s, among which 11.9s for evaluation.


Train interval 2: 100%|██████████| 500/500 [00:15<00:00, 33.18it/s, loss=10.366]
eval_train:  73%|██████████████████████▊        | 58/79 [00:05<00:03,  6.00it/s]

Compiled function `model` for args:
- n_node=[ 8  2 12  3  9 10  7 12  9  8  4  6  4 12 10 12 10  8  3 10  9  4  8  7
  2  9 58  0  0  0  0  0] total=256
- n_edge=[224  36 180  74 124 140 136 240 124 212  56 144  56 312 156 152 140 136
 132 206 166  72 184 134  36 412  16   0   0   0   0   0] total=4000
cache size: 10


eval_train:  99%|██████████████████████████████▌| 78/79 [00:05<00:00, 14.76it/s]


Interval 3: eval_train: mae_becs=1.1640, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 137.98it/s]


Interval 3: eval_valid: mae_becs=1.2055, 
Interval 3: Time per interval: 55.0s, among which 5.7s for evaluation.


Train interval 3: 100%|██████████| 500/500 [00:14<00:00, 34.01it/s, loss=88.859]
eval_train:  38%|███████████▊                   | 30/79 [00:03<00:07,  6.97it/s]

Compiled function `model` for args:
- n_node=[ 8  2 15 12 12 12  4  4 10  6 11  0] total=96
- n_edge=[358  28 398 416 144 456 176 186 152 340 418   0] total=3072
cache size: 11


eval_train:  99%|██████████████████████████████▌| 78/79 [00:04<00:00, 19.39it/s]


Interval 4: eval_train: mae_becs=1.3650, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 139.43it/s]


Interval 4: eval_valid: mae_becs=1.4858, 
Interval 4: Time per interval: 20.6s, among which 4.1s for evaluation.


Train interval 4:  18%|██         | 92/500 [00:10<02:54,  2.34it/s, loss=16.126]

Compiled function `update_fn` for args:
- n_node=[ 2 10  9 10 12 11  2 12 10  7  3  8] total=96
- n_edge=[ 36 356 360 132 512 482  36 416 152 188  54 348] total=3072
Outout: loss= 8.382
Compilation time: 7.584s, cache size: 12


Train interval 4:  57%|█████▋    | 284/500 [00:23<01:31,  2.35it/s, loss=15.767]

Compiled function `update_fn` for args:
- n_node=[12  2 10  5  2  9 10 12 12 12  6  4] total=96
- n_edge=[362  52 152 250  52 438 560 256 652 336 444 446] total=4000
Outout: loss= 4.309
Compilation time: 7.570s, cache size: 13


Train interval 4: 100%|███████████| 500/500 [00:29<00:00, 16.68it/s, loss=7.073]
eval_train:  99%|██████████████████████████████▌| 78/79 [00:00<00:00, 83.79it/s]


Interval 5: eval_train: mae_becs=1.5619, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 138.62it/s]


Interval 5: eval_valid: mae_becs=1.5293, 
Interval 5: Time per interval: 24.0s, among which 3.5s for evaluation.


Train interval 5: 100%|███████████| 500/500 [00:14<00:00, 33.79it/s, loss=6.585]
eval_train:  99%|██████████████████████████████▌| 78/79 [00:00<00:00, 91.38it/s]


Interval 6: eval_train: mae_becs=1.3412, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 141.62it/s]


Interval 6: eval_valid: mae_becs=1.3159, 
Interval 6: Time per interval: 23.3s, among which 2.0s for evaluation.


Train interval 6:  12%|█▎         | 60/500 [00:11<03:56,  1.86it/s, loss=12.073]

Compiled function `update_fn` for args:
- n_node=[ 2  2  7  8  9  2  3 10  6 10  4 10  8  2 10  2  1  0  0  0  0  0  0  0] total=96
- n_edge=[ 36  28 232 128 342 112  64 152 256 360 108 152 168  56 356 100 422   0
   0   0   0   0   0   0] total=3072
Outout: loss= 18.333
Compilation time: 9.661s, cache size: 14


Train interval 6: 100%|██████████| 500/500 [00:24<00:00, 20.37it/s, loss=10.402]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 100.52it/s]


Interval 7: eval_train: mae_becs=1.3743, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 141.48it/s]


Interval 7: eval_valid: mae_becs=1.3265, 
Interval 7: Time per interval: 25.1s, among which 0.9s for evaluation.


Train interval 7:  62%|██████▊    | 312/500 [00:16<01:14,  2.51it/s, loss=7.724]

Compiled function `update_fn` for args:
- n_node=[10 10  3 10  9  3  7 10  9  8 10  2 11  8  6  2  2  2  2  5 12 10 10  9
  2 12  4  8  9 51  0  0] total=256
- n_edge=[316 148 128 176 130  64  98 160 130 204 152  32 206 124 104  28  56  52
  28  76 228 152 464 134  28 210  56 138 130  48   0   0] total=4000
Outout: loss= 8.648
Compilation time: 7.041s, cache size: 15


Train interval 7: 100%|██████████| 500/500 [00:21<00:00, 22.85it/s, loss=81.721]
eval_train:  99%|██████████████████████████████▌| 78/79 [00:01<00:00, 76.89it/s]


Interval 8: eval_train: mae_becs=1.3604, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 140.84it/s]


Interval 8: eval_valid: mae_becs=1.3159, 
Interval 8: Time per interval: 21.3s, among which 0.9s for evaluation.


Train interval 8:  98%|██████████▋| 488/500 [00:24<00:05,  2.07it/s, loss=4.746]

Compiled function `update_fn` for args:
- n_node=[12  8 12  9  6 11  3  8  6 12 12 29] total=128
- n_edge=[356 224 208 364 136 606  64 176 108 504 172 154] total=3072
Outout: loss= 19.675
Compilation time: 8.656s, cache size: 16


Train interval 8: 100%|███████████| 500/500 [00:24<00:00, 20.49it/s, loss=7.035]
eval_train:  97%|█████████████████████████████▏| 77/79 [00:00<00:00, 110.89it/s]


Interval 9: eval_train: mae_becs=1.4488, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 138.11it/s]


Interval 9: eval_valid: mae_becs=1.3832, 
Interval 9: Time per interval: 24.5s, among which 0.9s for evaluation.


Train interval 9: 100%|██████████| 500/500 [00:14<00:00, 33.60it/s, loss=45.585]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 145.00it/s]


Interval 10: eval_train: mae_becs=1.6583, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 139.89it/s]


Interval 10: eval_valid: mae_becs=1.6085, 
Interval 10: Time per interval: 21.3s, among which 0.8s for evaluation.


Train interval 10: 100%|██████████| 500/500 [00:15<00:00, 31.41it/s, loss=5.535]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 141.33it/s]


Interval 11: eval_train: mae_becs=1.9922, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 141.90it/s]


Interval 11: eval_valid: mae_becs=1.9541, 
Interval 11: Time per interval: 19.2s, among which 0.6s for evaluation.


Train interval 11:  54%|████▊    | 268/500 [00:16<02:38,  1.47it/s, loss=29.760]

Compiled function `update_fn` for args:
- n_node=[ 8  8  6  4 10  8 12  4  9  4  2  6 12 11  8  3  5  8 64  0  0  0  0  0] total=192
- n_edge=[228 216 108 106 140 252 504  56 130 104  36 104 176 230 160  64 154 216
  88   0   0   0   0   0] total=3072
Outout: loss= 9.978
Compilation time: 8.716s, cache size: 17


Train interval 11: 100%|█████████| 500/500 [00:23<00:00, 21.19it/s, loss=12.682]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 127.41it/s]


Interval 12: eval_train: mae_becs=2.2603, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 138.15it/s]


Interval 12: eval_valid: mae_becs=2.3016, 
Interval 12: Time per interval: 18.8s, among which 0.6s for evaluation.


Train interval 12: 100%|██████████| 500/500 [00:14<00:00, 33.34it/s, loss=4.701]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 111.76it/s]


Interval 13: eval_train: mae_becs=2.3597, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 135.08it/s]


Interval 13: eval_valid: mae_becs=2.5293, 
Interval 13: Time per interval: 18.8s, among which 0.7s for evaluation.


Train interval 13: 100%|█████████| 500/500 [00:14<00:00, 33.74it/s, loss=14.826]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 145.98it/s]


Interval 14: eval_train: mae_becs=3.3262, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 137.89it/s]


Interval 14: eval_valid: mae_becs=4.0378, 
Interval 14: Time per interval: 18.5s, among which 0.7s for evaluation.


Train interval 14: 100%|██████████| 500/500 [00:15<00:00, 31.38it/s, loss=3.365]
eval_train: 100%|███████████████████████████████| 79/79 [00:04<00:00, 19.54it/s]


Compiled function `model` for args:
- n_node=[ 6  8  2  4  5  5 10  3  8  6  6  2  3  6  2  8  4  2  2  4 10  3  2  8
  9  0  0  0  0  0  0  0] total=128
- n_edge=[104 152  12  96 106 154 296  28 108 172  68 160  72  92  56 192  96  68
  28 122 252 128  32 432  46   0   0   0   0   0   0   0] total=3072
cache size: 12
Interval 15: eval_train: mae_becs=3.0846, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 139.89it/s]


Interval 15: eval_valid: mae_becs=3.3494, 
Interval 15: Time per interval: 15.9s, among which 1.8s for evaluation.


Train interval 15: 100%|█████████| 500/500 [00:14<00:00, 33.55it/s, loss=11.903]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 146.45it/s]


Interval 16: eval_train: mae_becs=3.5410, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 142.60it/s]


Interval 16: eval_valid: mae_becs=3.7970, 
Interval 16: Time per interval: 17.0s, among which 1.8s for evaluation.


Train interval 16: 100%|█████████| 500/500 [00:14<00:00, 33.62it/s, loss=13.517]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 126.43it/s]


Interval 17: eval_train: mae_becs=3.4246, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 134.29it/s]


Interval 17: eval_valid: mae_becs=3.4916, 
Interval 17: Time per interval: 17.0s, among which 1.8s for evaluation.


Train interval 17: 100%|██████████| 500/500 [00:14<00:00, 33.72it/s, loss=3.054]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 111.74it/s]


Interval 18: eval_train: mae_becs=3.5054, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 135.22it/s]


Interval 18: eval_valid: mae_becs=3.6232, 
Interval 18: Time per interval: 16.6s, among which 0.7s for evaluation.


Train interval 18: 100%|██████████| 500/500 [00:14<00:00, 33.88it/s, loss=2.190]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 144.48it/s]


Interval 19: eval_train: mae_becs=3.4200, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 140.93it/s]


Interval 19: eval_valid: mae_becs=3.4577, 
Interval 19: Time per interval: 15.5s, among which 0.7s for evaluation.


Train interval 19:  54%|█████▍    | 272/500 [00:17<02:03,  1.84it/s, loss=2.538]

Compiled function `update_fn` for args:
- n_node=[12  9  3 12  9 10 12 12  7  8 40 58] total=192
- n_edge=[ 300  166   64  252  320  120  488  208  152  224 1464  242] total=4000
Outout: loss= 2.842
Compilation time: 9.788s, cache size: 18


Train interval 19: 100%|██████████| 500/500 [00:24<00:00, 20.28it/s, loss=1.042]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 145.50it/s]


Interval 20: eval_train: mae_becs=3.3959, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 131.11it/s]


Interval 20: eval_valid: mae_becs=3.4303, 
Interval 20: Time per interval: 18.8s, among which 0.6s for evaluation.


Train interval 20: 100%|██████████| 500/500 [00:15<00:00, 33.23it/s, loss=0.946]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 144.83it/s]


Interval 21: eval_train: mae_becs=3.6222, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 138.74it/s]


Interval 21: eval_valid: mae_becs=3.6309, 
Interval 21: Time per interval: 18.8s, among which 0.6s for evaluation.


Train interval 21: 100%|██████████| 500/500 [00:15<00:00, 33.12it/s, loss=4.578]
eval_train: 100%|██████████████████████████████| 79/79 [00:00<00:00, 146.39it/s]


Interval 22: eval_train: mae_becs=3.7389, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 139.11it/s]


Interval 22: eval_valid: mae_becs=3.6983, 
Interval 22: Time per interval: 18.9s, among which 0.6s for evaluation.


Train interval 22: 100%|██████████| 500/500 [00:15<00:00, 32.96it/s, loss=0.838]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 141.88it/s]


Interval 23: eval_train: mae_becs=3.8113, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 135.45it/s]


Interval 23: eval_valid: mae_becs=3.8067, 
Interval 23: Time per interval: 15.7s, among which 0.6s for evaluation.


Train interval 23: 100%|██████████| 500/500 [00:15<00:00, 33.10it/s, loss=5.338]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 142.29it/s]


Interval 24: eval_train: mae_becs=3.7896, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 134.88it/s]


Interval 24: eval_valid: mae_becs=3.7184, 
Interval 24: Time per interval: 15.7s, among which 0.6s for evaluation.


Train interval 24:  56%|█████▌    | 280/500 [00:18<02:01,  1.81it/s, loss=1.347]

Compiled function `update_fn` for args:
- n_node=[ 8  2  6  5  6  5 10  3  2  2  3  5 10  3  2  7  2  4 10  8  6  6  8  4
  1  0  0  0  0  0  0  0] total=128
- n_edge=[328  92 256 106 192 154 500  64  32  68  64  84 254  64  56 128  52  56
 206 216 236 256 192 110 234   0   0   0   0   0   0   0] total=4000
Outout: loss= 1.232
Compilation time: 9.980s, cache size: 19


Train interval 24: 100%|██████████| 500/500 [00:25<00:00, 19.96it/s, loss=1.508]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 141.18it/s]


Interval 25: eval_train: mae_becs=4.0721, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 131.02it/s]


Interval 25: eval_valid: mae_becs=4.0273, 
Interval 25: Time per interval: 19.0s, among which 0.6s for evaluation.


Train interval 25: 100%|██████████| 500/500 [00:15<00:00, 33.16it/s, loss=2.181]
eval_train:  99%|██████████████████████████████▌| 78/79 [00:02<00:00, 31.09it/s]


Compiled function `model` for args:
- n_node=[ 9 10  2  2  5 12  3  6 12  8 10  8  9  0  0  0] total=96
- n_edge=[438 500 172  56 106 368 128 112 388 128 308 160 208   0   0   0] total=3072
cache size: 13
Interval 26: eval_train: mae_becs=4.0431, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 132.75it/s]


Interval 26: eval_valid: mae_becs=4.0266, 
Interval 26: Time per interval: 19.0s, among which 1.3s for evaluation.


Train interval 26: 100%|██████████| 500/500 [00:15<00:00, 33.32it/s, loss=1.275]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 124.55it/s]


Interval 27: eval_train: mae_becs=4.0453, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 137.26it/s]


Interval 27: eval_valid: mae_becs=4.0245, 
Interval 27: Time per interval: 19.6s, among which 1.3s for evaluation.


Train interval 27: 100%|██████████| 500/500 [00:14<00:00, 33.63it/s, loss=2.343]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 144.80it/s]


Interval 28: eval_train: mae_becs=4.1249, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 141.06it/s]


Interval 28: eval_valid: mae_becs=4.0544, 
Interval 28: Time per interval: 16.3s, among which 1.3s for evaluation.


Train interval 28: 100%|██████████| 500/500 [00:14<00:00, 33.50it/s, loss=0.755]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 145.22it/s]


Interval 29: eval_train: mae_becs=4.1200, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 136.60it/s]


Interval 29: eval_valid: mae_becs=4.0848, 
Interval 29: Time per interval: 16.2s, among which 0.6s for evaluation.


Train interval 29: 100%|██████████| 500/500 [00:14<00:00, 33.56it/s, loss=0.474]
eval_train:  99%|█████████████████████████████▌| 78/79 [00:00<00:00, 144.75it/s]


Interval 30: eval_train: mae_becs=4.1604, 


eval_valid: 100%|████████████████████████████████| 5/5 [00:00<00:00, 139.06it/s]


Interval 30: eval_valid: mae_becs=4.1131, 
Interval 30: Time per interval: 15.5s, among which 0.6s for evaluation.
Training complete
training done!


In [5]:
# Rebuild the model from the checkpoint written by the training run above.
# Reload from *this* run's directory (save_dir_name) rather than a hard-coded,
# timestamped name: the directory name is random per run, so a hard-coded name
# breaks Restart & Run All or silently reloads an older run.
# train_graphs is empty here; the printed parameters below (e.g. avg_num_neighbors)
# match the training-set values, presumably restored from the checkpoint — confirm
# in BECS_EPS_model.
model_fn, params, num_message_passing = BECS_EPS_model(
    r_max=r_max,
    atomic_energies_dict={},
    train_graphs=[],
    initialize_seed=config["model"]["seed"],
    num_species=config["model"]["num_species"],
    use_sc=True,
    graph_net_steps=config["model"]["num_layers"],
    hidden_irreps=config["model"]["internal_irreps"],
    nonlinearities={'e': 'swish', 'o': 'tanh'},
    save_dir_name=save_dir_name,
    reload=save_dir_name,
)

# Re-jit the predictor so it traces the reloaded model_fn with a fresh compilation
# cache, instead of possibly reusing traces made against the previous model_fn.
predictor = jax.jit(
    lambda w, g: predict_becs_eps(lambda *x: model_fn(w, *x), g)
)


Create BECS/EPS (NequIP-based) model with parameters {'use_sc': True, 'graph_net_steps': 3, 'hidden_irreps': '48x0e + 48x0o + 32x1o + 32x1e +24x2o + 24x2e', 'nonlinearities': {'e': 'swish', 'o': 'tanh'}, 'r_max': 5.0, 'avg_num_neighbors': 28.766618979494517, 'avg_r_min': None, 'num_species': 100, 'radial_basis': <function bessel_basis at 0x7f26cccf2c00>, 'radial_envelope': <function soft_envelope at 0x7f26cccf3c40>}


In [6]:
# Spot-check one validation structure: reference vs. predicted Born effective charges.
# NOTE(review): the recorded output shows predicted BECs with the opposite sign of the
# reference (1.917 vs -1.956, -3.835 vs 3.913). Over training, eval mae_becs rose
# from ~0.77 to ~4.16 on the *training* set while the training loss fell from ~470
# to ~0.5. That pattern suggests a sign-convention mismatch between the training
# target/loss and the evaluated quantity. Investigate BecsEpsLoss / predict_becs_eps.
SAMPLE_IDX = 7  # index into valid_loader.graphs

print(len(valid_loader.graphs))

idx = SAMPLE_IDX
valid_graph = valid_loader.graphs[idx]
# ground truth Born effective charges
print(valid_graph.nodes.becs)

# predicted Born effective charges
pred_becs = predictor(params, valid_graph)
print(pred_becs['becs'])

# Quantify agreement. Comparing against the negated reference exposes a global
# sign flip that is easy to miss when eyeballing the tensors.
becs_true = np.asarray(valid_graph.nodes.becs)
becs_pred = np.asarray(pred_becs['becs'])
if becs_true.shape == becs_pred.shape:
    print(f"BECS MAE: {np.abs(becs_pred - becs_true).mean():.3f}  "
          f"(MAE vs. -reference: {np.abs(becs_pred + becs_true).mean():.3f})")
else:
    print(f"Shape mismatch: reference {becs_true.shape} vs prediction {becs_pred.shape}")


100
[[[ 1.917  0.     0.   ]
  [ 0.     1.917  0.   ]
  [-0.    -0.     1.917]]

 [[ 1.917  0.     0.   ]
  [ 0.     1.917  0.   ]
  [-0.    -0.     1.917]]

 [[-3.835  0.     0.   ]
  [ 0.    -3.835  0.   ]
  [ 0.     0.    -3.835]]]
[[[-1.956  0.     0.   ]
  [ 0.    -1.956 -0.   ]
  [ 0.    -0.    -1.956]]

 [[-1.956  0.     0.   ]
  [ 0.    -1.956 -0.   ]
  [ 0.    -0.    -1.956]]

 [[ 3.913 -0.    -0.   ]
  [-0.     3.913  0.   ]
  [-0.     0.     3.913]]]


In [7]:


# Dielectric tensor of the same validation structure: reference first, then prediction.
reference_eps = valid_graph.globals.eps
print(reference_eps)

predicted_eps = pred_becs['eps']
print(predicted_eps)

[[[14.029 -0.    -0.   ]
  [-0.    14.029 -0.   ]
  [-0.    -0.    14.029]]]
[[[11.607 -0.    -0.   ]
  [-0.    11.607  0.   ]
  [-0.     0.    11.607]]]
