In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import torch
from skorch import NeuralNetRegressor
from skorch.dataset import CVSplit
from skorch.callbacks import Checkpoint, EpochScoring
from skorch.callbacks.lr_scheduler import LRScheduler
import skorch.callbacks.base
# from amp.descriptor.gaussian import Gaussian
from amptorch.gaussian import SNN_Gaussian
from amptorch.model import BPNN, CustomMSELoss
from amptorch.data_preprocess import AtomsDataset, factorize_data, collate_amp, TestDataset
from amptorch.skorch_model import AMP
from amptorch.skorch_model.utils import target_extractor, energy_score, forces_score, train_end_load_best_loss
from amptorch.analysis import parity_plot
from torch.utils.data import DataLoader
from torch.nn import init
from skorch.utils import to_numpy
import numpy as np
from ase import Atoms
from ase.calculators.emt import EMT
import ase.io
from sklearn.model_selection import ShuffleSplit
import copy

# Run label; passed as `label=` to AtomsDataset in the training cell.
label = "zeolite"

In [3]:
from pymatgen.analysis.local_env import VoronoiNN, CovalentBondNN, JmolNN
from pymatgen.io.ase import AseAtomsAdaptor

# Read every frame of the trajectory, then rewrite each frame's tags:
# an atom ends up tagged 1 when its minimum-image distance to at least one
# atom that was tagged 1 lies strictly between 0.5 and 3, and 0 otherwise.
images = ase.io.read('traj_taged_adsorptionenergy.traj', index=':')
for atoms in images:
    # Rows: atoms currently tagged 1; columns: every atom in the frame.
    dist_from_tagged = atoms.get_all_distances(mic=True)[atoms.get_tags() == 1]
    # The lower bound drops the zero self-distance of the tagged atoms.
    in_shell = np.logical_and(dist_from_tagged < 3, dist_from_tagged > 0.5)
    # Collapse over the tagged rows and cast the boolean mask to 0/1 integers.
    atoms.set_tags(in_shell.any(axis=0) * 1)

In [4]:
# images = images[0:100]

def Split(images, test_size=0.2, random_state=None):
    """Randomly partition ``images`` into a training list and a test list.

    Args:
        images: Sequence supporting ``len()`` and integer indexing.
        test_size: Forwarded to ``ShuffleSplit``; the default of 0.2 matches
            the value that used to be hard-coded here.
        random_state: Forwarded to ``ShuffleSplit``. ``None`` (the default)
            keeps the previous unseeded behaviour; pass an int to make the
            split reproducible across kernel restarts.

    Returns:
        tuple: ``(train_images, test_images)``, two new lists holding the
        selected elements of ``images`` (the elements themselves are not
        copied).
    """
    indices = np.arange(len(images))
    # Only the first split is ever consumed, so request exactly one.
    splitter = ShuffleSplit(
        n_splits=1, test_size=test_size, train_size=None, random_state=random_state
    )
    idx_train, idx_test = next(splitter.split(indices, groups=None))
    train_images = [images[index] for index in idx_train]
    test_images = [images[index] for index in idx_test]
    return train_images, test_images

# Split once, then work on deep copies so the *_original lists stay untouched.
train_images_original, test_images_original = Split(images)
train_images, test_images = (
    copy.deepcopy(subset)
    for subset in (train_images_original, test_images_original)
)

# train_images = train_images[0:200]


In [19]:

# Cosine-annealing schedule object.
# NOTE(review): this object is never added to the callbacks list of the net
# built later in this cell (that list creates its own WarmRestartLR scheduler),
# so as written it has no effect on training.
LR_schedule = LRScheduler('CosineAnnealingLR', T_max=5)
# `cp` and `load_best_valid_loss` are assigned further down in this cell,
# before the net is built. The copies that used to be created here were
# overwritten by those later assignments before their first use, so they
# are omitted.


# define symmetry functions to be used

# Symmetry-function hyperparameters, collected in a single literal.
Gs = {
    # four G2 eta values, log-spaced between 0.05 and 5.0
    "G2_etas": np.logspace(np.log10(0.05), np.log10(5.0), num=4),
    # one zero shift per G2 eta
    "G2_rs_s": [0, 0, 0, 0],
    "G4_etas": [0.005],
    "G4_zetas": [1.0],
    "G4_gammas": [1.0, -1],
    "cutoff": 6.5,
}

# The same flag is handed to the dataset here and to the model built below.
forcetraining = False
training_data = AtomsDataset(
    train_images,
    SNN_Gaussian,
    Gs,
    forcetraining=forcetraining,
    label=label,
    cores=1,
    delta_data=None,
    specific_atoms=True,
)

# Values read back from the dataset; both are fed to the BPNN constructor.
unique_atoms, fp_length = training_data.elements, training_data.fp_length
device = "cpu"

# loads best validation loss at the end of training
class train_end_load_best_valid_loss(skorch.callbacks.base.Callback):
    """skorch callback that reloads saved parameters into the net when training ends.

    Args:
        filename: Path handed to ``net.load_params``. Defaults to
            'valid_best_params.pt', the value that used to be hard-coded.
            NOTE(review): presumably this is the file written by the
            Checkpoint created below (fn_prefix 'valid_best_'); confirm
            against the installed skorch version.
    """

    def __init__(self, filename='valid_best_params.pt'):
        self.filename = filename

    def on_train_end(self, net, X=None, y=None, **kwargs):
        """Load the parameters stored in ``self.filename`` into ``net``.

        ``X``, ``y`` and any extra keyword arguments are accepted (and ignored)
        so the hook matches the callback signature documented by skorch.
        """
        net.load_params(self.filename)
cp = Checkpoint(monitor='valid_loss_best', fn_prefix='valid_best_')
load_best_valid_loss = train_end_load_best_valid_loss()

# Wrap the BPNN module in a skorch regressor and train it through the AMP wrapper.
net = NeuralNetRegressor(
    # model and loss
    # NOTE(review): [fp_length, 3, 10] is presumably
    # [input size, hidden layers, nodes per layer] - confirm against BPNN.
    module=BPNN(unique_atoms, [fp_length, 3, 10], device, forcetraining=forcetraining),
    criterion=CustomMSELoss,
    criterion__force_coefficient=0.0,
    # optimisation
    optimizer=torch.optim.AdamW,
    lr=0.01,
    batch_size=40,
    max_epochs=1000,
    # data loading: same collate function for both iterators, shuffle training only
    iterator_train__collate_fn=collate_amp,
    iterator_train__shuffle=True,
    iterator_valid__collate_fn=collate_amp,
    iterator_valid__shuffle=False,
    optimizer__weight_decay=0.001,
    device=device,
    train_split=CVSplit(5),
    callbacks=[
        # One energy scorer per data split: validation first, then training.
        *(
            EpochScoring(
                energy_score,
                name=scorer_name,
                on_train=scores_training_split,
                use_caching=False,
                target_extractor=target_extractor,
            )
            for scorer_name, scores_training_split in (
                ('energy_score_valid', False),
                ('energy_score_train', True),
            )
        ),
        LRScheduler(skorch.callbacks.WarmRestartLR, max_lr=0.01),
        cp,
        load_best_valid_loss,
    ],
)

# Hand the dataset and the net to the AMP wrapper and start training.
calc = AMP(training_data, net, 'test')
calc.train(overwrite=True)
# Optional parity plots (disabled):
# parity_plot(calc, images, data="energy", label=label)
# parity_plot(calc, images, data="forces", label=label)
# parity_plot(calc, images, data="energy", label=label)
# parity_plot(calc, images, data="forces", label=label)

Calculating fingerprints...
Fingerprints Calculated!
  epoch    energy_score_train    energy_score_valid    train_loss    valid_loss    cp     dur
-------  --------------------  --------------------  ------------  ------------  ----  ------
      1                2.8548                2.8181        [35m3.6970[0m        [31m0.7659[0m     +  1.7847




      2                2.5194                2.6503        [35m0.7029[0m        [31m0.6877[0m     +  1.4215
      3                2.1786                2.3032        [35m0.4976[0m        [31m0.5076[0m     +  1.5603
      4                1.8403                1.9889        [35m0.4095[0m        [31m0.3827[0m     +  1.3703
      5                1.9351                2.0118        0.4103        0.3938        1.3422
      6                1.7375                1.8561        [35m0.3384[0m        [31m0.3337[0m     +  1.4006
      7                1.7338                1.9003        [35m0.2912[0m        0.3522        1.4534
      8                1.6605                1.7593        [35m0.2783[0m        [31m0.2969[0m     +  1.4237
      9                1.6541                1.7512        0.2841        0.2983        1.4208
     10                1.6369                1.7698        [35m0.2666[0m        0.3061        1.3983
     11                1.6322                1

     86                1.3388                1.7106        0.1783        0.2661        1.4051
     87                1.2197                1.7089        0.1854        0.2762        1.4303
     88                1.3737                1.8469        0.1798        0.3273        1.2513
     89                1.3778                1.7582        0.1858        0.2799        1.2524
     90                1.3087                1.7302        0.2019        0.2832        1.2297
     91                1.2443                1.7439        0.1583        0.2857        1.2332
     92                1.4139                1.7965        0.1833        0.3024        1.4838
     93                1.2211                1.7239        0.1833        0.2829        1.4223
     94                1.4681                1.9058        0.1779        0.3312        1.4715
     95                1.6800                2.0620        0.2227        0.3978        1.7026
     96                1.3268                1.7396        0

    172                1.2594                1.8694        0.1298        0.3299        1.2965
    173                1.3434                1.9840        0.1634        0.3740        1.2385
    174                0.9924                1.7258        0.1438        0.2746        1.4561
    175                0.9768                1.7612        0.1194        0.2866        1.4023
    176                0.9829                1.7017        0.1046        0.2706        1.4028
    177                1.4198                2.0350        0.1153        0.3902        1.4081
    178                1.3064                1.8803        0.1936        0.3358        1.4241
    179                1.1046                1.8042        0.1479        0.2943        1.3866
    180                0.9925                1.6888        0.1110        0.2709        1.4266
    181                1.0774                1.7604        0.0972        0.2842        1.3651
    182                1.0278                1.7404        0

    258                0.5613                1.7264        [35m0.0330[0m        0.2719        1.4700
    259                0.5854                1.8040        0.0386        0.2957        1.2958
    260                0.5723                1.7909        0.0427        0.2941        1.2639
    261                0.5801                1.7497        0.0342        0.2805        1.2897
    262                0.6122                1.7735        [35m0.0328[0m        0.2838        1.3763
    263                0.5721                1.7639        0.0336        0.2866        1.4505
    264                0.5626                1.7644        0.0349        0.2813        1.4135
    265                0.5365                1.7517        [35m0.0305[0m        0.2783        1.4319
    266                0.5524                1.7759        [35m0.0295[0m        0.2847        1.5353
    267                0.5478                1.7579        0.0302        0.2784        1.4517
    268                0

    344                1.0117                1.9213        0.0794        0.3488        1.4502
    345                0.8008                1.7753        0.0709        0.2870        1.3596
    346                0.9394                1.8271        0.0729        0.3047        1.3807
    347                0.9901                1.8894        0.0780        0.3144        1.3990
    348                1.2098                1.9649        0.0917        0.3810        1.3876
    349                1.0498                1.8957        0.0804        0.3354        1.4236
    350                0.7573                1.7880        0.0745        0.2900        1.4050
    351                0.7841                1.7630        0.0641        0.2784        1.4918
    352                1.0971                1.9327        0.0907        0.3620        1.4590
    353                0.7312                1.7707        0.0784        0.2799        1.4224
    354                0.9274                1.8164        0

    431                0.4206                1.8221        [35m0.0185[0m        0.3025        1.5042
    432                0.5416                1.7786        0.0190        0.2908        1.4300
    433                0.4224                1.7729        0.0248        0.2838        1.4828
    434                0.4264                1.7946        0.0227        0.2945        1.4455
    435                0.3898                1.8028        0.0200        0.2957        1.6424
    436                0.3660                1.7884        [35m0.0159[0m        0.2927        1.4391
    437                0.3933                1.7856        0.0173        0.2904        1.4386
    438                0.4584                1.8777        0.0195        0.3224        1.4669
    439                0.4515                1.8039        0.0207        0.2941        1.4220
    440                0.5645                1.8409        0.0172        0.3033        1.5327
    441                0.7069             

    517                0.3011                1.8153        0.0101        0.2968        1.4427
    518                0.2840                1.8276        0.0088        0.3018        1.4726
    519                0.2890                1.8406        [35m0.0083[0m        0.3064        1.4767
    520                0.2843                1.8346        0.0086        0.3052        1.4325
    521                0.2870                1.8398        0.0088        0.3069        1.4476
    522                0.2687                1.8266        [35m0.0078[0m        0.3036        1.3906
    523                0.3000                1.8339        0.0079        0.3030        1.4926
    524                0.2745                1.8499        0.0083        0.3092        1.4730
    525                0.3030                1.8447        0.0079        0.3048        1.4304
    526                0.2716                1.8273        [35m0.0075[0m        0.3016        1.4113
    527                0.2818    

    603                0.2205                1.8426        [35m0.0048[0m        0.3066        1.5957
    604                0.2209                1.8419        [35m0.0047[0m        0.3062        1.4791
    605                0.2203                1.8407        0.0048        0.3057        1.4029
    606                0.2219                1.8419        0.0047        0.3058        1.4374
    607                0.2216                1.8416        0.0048        0.3058        1.4229
    608                0.2234                1.8443        0.0048        0.3077        1.4855
    609                0.2203                1.8428        0.0052        0.3069        1.7522
    610                0.2226                1.8369        0.0049        0.3039        1.3704
    611                0.2198                1.8418        0.0048        0.3064        1.2942
    612                0.2201                1.8434        0.0048        0.3069        2.1453
    613                0.2197             

    690                0.7095                1.7731        0.0671        0.2888        1.4998
    691                0.7475                1.8188        0.0533        0.3171        1.4789
    692                0.6742                1.7810        0.0497        0.2981        1.3824
    693                0.8461                1.8492        0.0568        0.3209        1.3233
    694                0.6387                1.7640        0.0528        0.2924        1.2832
    695                0.8227                1.8912        0.0605        0.3409        1.2986
    696                0.7920                1.8258        0.0523        0.3095        1.2743
    697                0.7745                1.8136        0.0563        0.3063        1.4738
    698                0.7527                1.7573        0.0644        0.2925        1.4175
    699                0.7233                1.7827        0.0525        0.2994        1.3936
    700                0.6947                1.7797        0

    778                0.4290                1.7887        0.0204        0.2972        1.4766
    779                0.4732                1.7740        0.0181        0.2939        1.4183
    780                0.4639                1.7485        0.0264        0.2877        1.2942
    781                0.6673                1.8322        0.0389        0.3130        1.3283
    782                0.4617                1.7844        0.0329        0.2998        1.3410
    783                0.4821                1.7599        0.0332        0.2923        1.5030
    784                0.4359                1.7541        0.0332        0.2850        1.4927
    785                0.5425                1.7821        0.0486        0.2978        1.4801
    786                0.4789                1.7605        0.0340        0.2886        1.4700
    787                0.5459                1.8223        0.0278        0.3082        1.5022
    788                0.5594                1.8336        0

    866                0.3845                1.7733        0.0212        0.2996        1.2866
    867                0.4218                1.8255        0.0162        0.3174        1.3187
    868                0.4808                1.8061        0.0197        0.3060        1.3264
    869                0.4914                1.8752        0.0207        0.3365        1.3374
    870                0.3809                1.8268        0.0162        0.3118        1.1781
    871                0.3311                1.8278        0.0140        0.3164        1.3122
    872                0.3070                1.8278        0.0120        0.3160        1.2654
    873                0.3621                1.8251        0.0107        0.3249        1.3894
    874                0.3591                1.8410        0.0140        0.3214        2.5067
    875                0.4050                1.8005        0.0161        0.3060        1.5392
    876                0.4140                1.8222        0

    954                0.2940                1.8465        0.0094        0.3246        1.3143
    955                0.2723                1.8588        0.0074        0.3352        1.4759
    956                0.2471                1.8502        0.0070        0.3226        1.5966
    957                0.2594                1.8583        0.0065        0.3289        1.6967
    958                0.2565                1.8329        0.0082        0.3221        1.7205
    959                0.2542                1.8598        0.0076        0.3243        1.6604
    960                0.2921                1.8592        0.0069        0.3289        1.3320
    961                0.2434                1.8699        0.0065        0.3281        1.4909
    962                0.2282                1.8722        0.0063        0.3302        1.6057
    963                0.3041                1.8701        0.0058        0.3312        1.3739
    964                0.3187                1.8727        0

In [6]:
# Standard deviation of the dataset's energy targets, viewed as a NumPy array.
energies = training_data.energy_dataset.detach().numpy()
np.std(energies)

0.9994212

In [7]:
# Variance of the same targets: mean squared deviation from their mean.
energies = training_data.energy_dataset.detach().numpy()
print(np.square(energies - energies.mean()).mean())

0.99884266


In [8]:
# Tags of the first training image (rewritten by the loop in the trajectory-loading cell).
train_images[0].get_tags()

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1,
       1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1,
       0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1,
       1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0])

In [9]:
# Length of the first training image (matches the number of tags shown above).
len(train_images[0])

124

In [10]:
# Spread of the potential energies reported by the training frames themselves.
np.std([frame.get_potential_energy() for frame in train_images])

2.2586267682543157

In [11]:
# Energy targets as stored on the dataset.
# NOTE(review): their std is ~1 (see the cells above) while the raw potential
# energies have std ~2.26, so these look standardized - confirm in AtomsDataset.
training_data.energy_dataset

tensor([-0.7988, -0.5639, -1.1549, -0.8455,  1.4445, -1.4874, -0.8727, -0.8596,
         0.2586, -1.0138, -0.6981, -1.2368,  0.2611, -0.9089,  0.7494,  0.4381,
         0.6183,  0.0199, -0.1472,  2.3756, -1.0614, -3.5384,  1.0726,  0.9567,
         0.2036,  1.4097,  1.2245, -1.0392,  0.0409,  1.2739, -1.1547,  2.4198,
         0.7855,  1.0311,  0.1562,  1.1425, -1.2950,  0.5064,  0.3752,  0.0644,
         0.8524,  0.3442,  0.2866, -1.3212, -0.9067, -0.0050, -1.1088,  0.5668,
        -0.1838,  0.6235,  0.2679, -0.5463, -1.1920, -0.2274,  0.2907, -0.3928,
        -1.4090,  0.3200,  0.3359, -1.1185, -0.7900,  0.0549,  0.3257, -1.1145,
         0.8812, -0.1959, -0.3429,  1.3555,  0.3090,  0.5742, -1.2567,  0.9791,
         0.2919,  1.2568, -0.2156, -0.3037, -1.2212, -0.2563,  0.3412,  1.2526,
        -1.0578,  1.1282,  0.0271,  1.5150, -0.2418,  1.3465, -0.3750, -0.7319,
        -0.3865,  0.3248, -0.5598,  0.7123, -1.9891,  0.2748,  0.1687,  1.2319,
        -0.9351, -0.8931,  0.2396, -0.05

In [12]:
# Split the dataset with a CVSplit(5) splitter, the same setting the net's train_split uses.
train_split, valid_split = CVSplit(cv=5)(training_data)

In [13]:
# First element of the net's forward output on the training split.
# NOTE(review): presumably the energy predictions - confirm what BPNN returns.
net.forward(train_split)[0]

tensor([[ 0.9257],
        [-0.5453],
        [-0.5206],
        [-1.2872],
        [ 0.0755],
        [ 0.1618],
        [ 0.5747],
        [-0.3535],
        [ 0.4618],
        [ 1.3920],
        [ 0.2584],
        [ 0.1435],
        [ 0.3212],
        [-0.4797],
        [-0.9153],
        [ 1.2592],
        [ 0.5297],
        [ 1.1969],
        [ 1.6084],
        [-0.4589],
        [ 0.1943],
        [ 0.2919],
        [ 1.4089],
        [ 0.1057],
        [ 0.4172],
        [-0.8489],
        [ 0.3337],
        [ 0.5900],
        [ 0.2462],
        [ 0.3558],
        [-1.1019],
        [-0.4412],
        [-0.0757],
        [ 1.5962],
        [-1.2163],
        [ 0.1902],
        [-0.7009],
        [-0.7492],
        [-0.1262],
        [ 0.9387],
        [ 0.3507],
        [-1.1474],
        [-0.1238],
        [ 0.5947],
        [ 0.6348],
        [ 0.1692],
        [ 0.1020],
        [ 1.4759],
        [ 1.7373],
        [-0.1425],
        [ 0.3549],
        [ 0.3926],
        [-2.

In [14]:
# Collect element [1] of every validation sample and undo the dataset scaling.
# NOTE(review): scalings[2] is presumably the energy/target scaler - confirm.
valid_targets_scaled = torch.Tensor(np.array([sample[1] for sample in valid_split]))
training_data.scalings[2].denorm(valid_targets_scaled)

tensor([ -4.9782,  -4.4473,  -5.7831,  -5.0839,   0.0915,  -6.5344,  -5.1452,
         -5.1157,  -2.5886,  -5.4642,  -4.7506,  -5.9681,  -2.5830,  -5.2270,
         -1.4794,  -2.1829,  -1.7756,  -3.1280,  -3.5056,   2.1957,  -5.5718,
        -11.1695,  -0.7489,  -1.0110,  -2.7128,   0.0128,  -0.4057,  -5.5215,
         -3.0806,  -0.2941,  -5.7826,   2.2955,  -1.3978,  -0.8429,  -2.8200,
         -0.5910,  -6.0996,  -2.0285,  -2.3250,  -3.0274,  -1.2467,  -2.3951,
         -2.5253,  -6.1587,  -5.2221,  -3.1844,  -5.6788,  -1.8920,  -3.5883,
         -1.7639,  -2.5676,  -4.4075,  -5.8668,  -3.6869,  -2.5160,  -4.0608,
         -6.3573,  -2.4499,  -2.4140,  -5.7008,  -4.9583,  -3.0489,  -2.4369,
         -5.6918,  -1.1816,  -3.6158,  -3.9479,  -0.1098,  -2.4746,  -1.8753,
         -6.0130,  -0.9602,  -2.5133,  -0.3326,  -3.6602,  -3.8592,  -5.9328,
         -3.7522,  -2.4018,  -0.3423,  -5.5635,  -0.6233,  -3.1118,   0.2507,
         -3.7195,  -0.1300,  -4.0204,  -4.8270,  -4.0464,  -2.43

In [15]:
# Root-mean-square difference between column 0 of the forward output and
# element [1] of each training sample (both left in the dataset's scaled units).
train_pred = net.forward(train_split)[0][:, 0]
train_true = torch.Tensor(np.array([sample[1] for sample in train_split]))
errors = train_pred - train_true
errors.pow(2).mean().sqrt()

tensor(0.4659, grad_fn=<SqrtBackward>)

In [16]:
# Column 0 of the first forward output on the training split, as a flat tensor.
net.forward(train_split)[0][:,0]


tensor([ 0.9257, -0.5453, -0.5206, -1.2872,  0.0755,  0.1618,  0.5747, -0.3535,
         0.4618,  1.3920,  0.2584,  0.1435,  0.3212, -0.4797, -0.9153,  1.2592,
         0.5297,  1.1969,  1.6084, -0.4589,  0.1943,  0.2919,  1.4089,  0.1057,
         0.4172, -0.8489,  0.3337,  0.5900,  0.2462,  0.3558, -1.1019, -0.4412,
        -0.0757,  1.5962, -1.2163,  0.1902, -0.7009, -0.7492, -0.1262,  0.9387,
         0.3507, -1.1474, -0.1238,  0.5947,  0.6348,  0.1692,  0.1020,  1.4759,
         1.7373, -0.1425,  0.3549,  0.3926, -2.9147, -0.1650, -0.5547,  0.7841,
        -0.0676, -0.1312, -0.5740,  1.1438,  0.0608, -2.6758,  0.3223,  0.4712,
         0.1515, -0.9067, -0.5550, -0.8153, -0.2218, -1.6152, -1.1686,  2.3239,
         0.9329,  0.1666, -0.4713, -0.0843,  0.6043,  0.1300,  0.2076,  1.1765,
         1.2182,  1.0393,  0.3523,  0.4985, -0.7186,  0.1035, -1.0049,  0.8152,
         0.3610, -0.8980, -0.0203, -0.0801,  0.5917,  0.0461,  2.0826, -0.4359,
        -1.4237, -0.1313, -0.9362,  2.00

In [17]:
# Validation errors after passing both predictions and targets through denorm.
# NOTE(review): scalings[2] is presumably the energy/target scaler - confirm.
energy_scaler = training_data.scalings[2]
valid_pred = energy_scaler.denorm(net.forward(valid_split)[0])[:, 0]
valid_true = energy_scaler.denorm(
    torch.Tensor(np.array([sample[1] for sample in valid_split]))
)
errors = valid_pred - valid_true

In [18]:
# Mean squared value of the errors computed in the previous cell.
errors.pow(2).mean()

tensor(2.7096, grad_fn=<MeanBackward0>)