In [12]:
import ase
import time
import sys
import torch
from torch.nn import MSELoss
from skorch import NeuralNetRegressor
from skorch.dataset import CVSplit
from skorch.callbacks import Checkpoint, EpochScoring
from skorch.callbacks.lr_scheduler import LRScheduler
from amptorch.gaussian import Gaussian
from amptorch.model import FullNN, CustomLoss, TanhLoss
from amptorch.data_preprocess import AtomsDataset, factorize_data, collate_amp, TestDataset
from md_work.md_utils import md_run, calculate_energies, calculate_forces, time_plots, kde_plots
from amptorch.skorch import AMP
from amptorch.skorch.utils import target_extractor, energy_score, forces_score
from amptorch.lj_model import lj_optim
from torch.utils.data import DataLoader
from torch.nn import init
from skorch.utils import to_numpy
import numpy as np
from ase import Atoms, units
from ase.calculators.emt import EMT
from ase.md import Langevin

In [13]:
# define symmetry functions to be used
Gs = {}
Gs["G2_etas"] = np.logspace(np.log10(0.05), np.log10(5.0), num=4)
Gs["G2_rs_s"] = [0] * 4
Gs["G4_etas"] = [0.005]
Gs["G4_zetas"] = [1.0, 4.0]
Gs["G4_gammas"] = [+1.0, -1]
Gs["cutoff"] = 6.0

In [14]:
# LJ Optimization
def lj_optimization(images, Gs, label):
    cutoff = Gs["cutoff"]
    p0 = [
        1.33905162,
        0.12290683,
        6.41914719,
        0.64021468,
        0.08010004,
        8.26082762,
        2.29284676,
        0.29639983,
        0.08071821,
    ]
    params_dict = {"C": [], "O": [], "Cu": []}
    lj_model = lj_optim(images, p0, params_dict, cutoff, label)
    fitted_params = lj_model.fit()
    lj_energies, lj_forces, num_atoms = lj_model.lj_pred(
        images, fitted_params, params_dict
    )
    lj_data = [
        lj_energies,
        lj_forces,
        num_atoms,
        fitted_params,
        params_dict,
        lj_model,
    ]
    return lj_data

In [15]:
# Define Training data
label = "skorch_example"
images = ase.io.read("../datasets/COCu/COCu_pbc_300K.traj", ":100")
lj_data = lj_optimization(images, Gs, label)
forcetraining = True
training_data = AtomsDataset(images, Gaussian, Gs, forcetraining=forcetraining,
        label=label, cores=4, lj_data=lj_data)
scalings = training_data.scalings
unique_atoms = training_data.elements
fp_length = training_data.fp_length
device = "cpu"

LJ optimization initiated...
Optimizer terminated successfully.
Calculating fingerprints...
Fingerprints Calculated!


In [16]:
import skorch.callbacks.base
class train_end_load_best_valid_loss(skorch.callbacks.base.Callback):
    def on_train_end(self, net, X, y):
        net.load_params('./results/checkpoints/valid_best_params.pt')

LR_schedule = LRScheduler('CosineAnnealingLR', T_max=5)
cp = Checkpoint(monitor='valid_loss_best', fn_prefix='./results/checkpoints/valid_best_')
load_best_valid_loss = train_end_load_best_valid_loss()


net = NeuralNetRegressor(
    module=FullNN(unique_atoms, [fp_length, 5, 5], device, forcetraining=forcetraining),
    criterion=TanhLoss,
    criterion__force_coefficient=0.3,
    optimizer=torch.optim.Adam,
    lr=1e-2,
    batch_size=10,
    max_epochs=500,
    iterator_train__collate_fn=collate_amp,
    iterator_train__shuffle=True,
    iterator_valid__collate_fn=collate_amp,
    iterator_valid__shuffle=False,
    device=device,
    train_split=CVSplit(0.2, random_state=1),
    callbacks=[
        EpochScoring(
            forces_score,
            on_train=True,
            use_caching=True,
            target_extractor=target_extractor,
        ),
        EpochScoring(
            energy_score,
            on_train=True,
            use_caching=True,
            target_extractor=target_extractor,
        ), cp,  load_best_valid_loss, LR_schedule
    ],
)

In [17]:
# Define calculator and train
calc = AMP(training_data, net, label=label)
calc.train(overwrite=True)



  epoch    energy_score    forces_score    train_loss    valid_loss    cp     dur
-------  --------------  --------------  ------------  ------------  ----  ------
      1          [36m0.0034[0m          [32m0.0598[0m       [35m41.8852[0m       [31m41.9433[0m     +  0.3835
      2          [36m0.0012[0m          [32m0.0463[0m       [35m40.2051[0m       [31m40.9248[0m     +  0.3734
      3          [36m0.0005[0m          [32m0.0416[0m       [35m39.7332[0m       [31m40.4871[0m     +  0.3817
      4          [36m0.0002[0m          [32m0.0415[0m       [35m39.3255[0m       40.6496        0.3819
      5          0.0004          [32m0.0403[0m       [35m39.1676[0m       40.6608        0.3708
      6          0.0004          0.0442       [35m39.1109[0m       40.6608        0.3747
      7          0.0004          0.0436       [35m39.0626[0m       40.5153        0.3785
      8          0.0003          0.0485       39.0646       [31m40.3304[0m     +  0.3823


     90          0.0004          0.0439       35.5625       37.6379        0.3721
     91          0.0004          0.0443       35.5827       37.8477        0.3894
     92          0.0004          0.0450       35.5664       37.6232        0.3729
     93          0.0003          0.0475       35.5219       37.7129        0.3711
     94          0.0006          0.0437       35.5431       [31m37.4917[0m     +  0.3763
     95          0.0003          0.0445       [35m35.3793[0m       [31m37.3662[0m     +  0.3699
     96          0.0003          0.0464       [35m35.3086[0m       37.3662        0.3772
     97          0.0004          0.0454       35.3249       37.3799        0.3657
     98          0.0004          0.0486       35.3668       37.5806        0.3706
     99          0.0003          0.0482       35.4898       37.7017        0.3901
    100          0.0003          0.0483       35.6626       38.1907        0.4311
    101          0.0005          0.0467       35.7272       37

    188          0.0003          0.0448       34.9983       37.0963        0.4355
    189          0.0003          0.0498       35.1077       37.4098        0.3478
    190          0.0004          0.0470       35.2453       37.5916        0.4067
    191          0.0002          0.0448       35.3744       37.5404        0.3768
    192          0.0004          0.0450       35.4652       37.7452        0.3428
    193          0.0007          0.0421       35.4095       37.6134        0.3305
    194          0.0005          0.0494       35.1604       37.2433        0.3717
    195          0.0002          0.0468       35.0218       37.1796        0.3641
    196          0.0002          0.0465       34.9810       37.1796        0.3677
    197          0.0002          0.0470       34.9995       37.2152        0.3818
    198          0.0002          0.0453       35.0207       37.1709        0.3922
    199          0.0002          0.0457       35.2239       37.2320        0.3791
    200         

    287          0.0002          0.0450       34.8604       37.0353        0.3771
    288          0.0002          0.0451       34.8197       37.1343        0.3425
    289          0.0004          0.0448       34.9327       37.1250        0.3855
    290          0.0008          0.0413       34.9305       37.0683        0.3892
    291          [36m0.0001[0m          0.0436       34.8795       37.1284        0.3274
    292          0.0003          0.0422       35.0589       37.3882        0.3271
    293          0.0002          0.0457       35.0911       37.3402        0.3705
    294          0.0006          0.0431       34.9717       37.2616        0.3792
    295          0.0006          0.0459       34.8938       37.0867        0.3673
    296          0.0002          0.0434       34.7887       37.0867        0.3586
    297          0.0002          0.0470       34.7996       37.0172        0.3797
    298          0.0002          0.0459       34.7886       36.9558        0.3840
    299

    386          0.0002          0.0495       34.6884       36.8775        0.3884
    387          0.0003          0.0499       [35m34.6212[0m       36.9635        0.4038
    388          0.0003          0.0391       34.7481       36.9392        0.3727
    389          0.0003          0.0429       34.7460       37.1464        0.4088
    390          0.0003          0.0476       34.8379       37.1722        0.3720
    391          0.0004          0.0413       35.0614       37.2161        0.3297
    392          0.0005          0.0451       35.0478       37.0762        0.4189
    393          0.0003          0.0421       34.9799       37.2148        0.3703
    394          0.0003          0.0464       34.8403       37.0398        0.4124
    395          0.0005          0.0432       34.7761       36.9922        0.3568
    396          0.0003          0.0454       34.7187       36.9922        0.3849
    397          0.0003          0.0490       34.7043       36.9719        0.3715
    398

    485          0.0003          0.0457       34.6609       36.8951        0.3727
    486          0.0003          0.0428       [35m34.6054[0m       36.8951        0.3765
    487          [36m0.0001[0m          0.0442       34.6348       36.8909        0.3741
    488          0.0002          0.0426       34.7377       37.0431        0.3745
    489          0.0006          0.0388       34.7890       37.1191        0.3740
    490          0.0006          0.0418       34.9467       37.0686        0.3772
    491          0.0006          0.0420       34.9012       36.8490        0.3740
    492          0.0003          0.0417       34.8554       36.9855        0.3740
    493          0.0003          0.0463       34.8695       36.9405        0.3749
    494          0.0006          0.0441       34.7528       36.8397        0.3531
    495          0.0004          0.0418       34.6168       36.7544        0.4045
    496          0.0003          0.0463       [35m34.5950[0m       36.7544    

In [None]:
# MD Simulation
md_run(calc=calc, starting_image=images[0].copy(), temp=300, count=100, label=label)

In [None]:
# Calculate forces of base and generated trajectory
ml_images = ase.io.read(label+".traj", ":")
emt_energy, ml_apparent_energy, ml_actual_energy = calculate_energies(images, ml_images)
emt_forces, ml_apparent_forces, ml_actual_forces = calculate_forces(images, ml_images, type="max")

In [None]:
# Time Plots
import matplotlib.pyplot as plt
%matplotlib inline

time_plots(emt_energy, [ml_actual_energy], None, ['ML-LJ'], 'energy', None )
time_plots(emt_forces, [ml_actual_forces], None, ['ML-LJ'], 'forces', None )

In [None]:
kde_plots(emt_forces, [ml_apparent_forces, ml_actual_forces] , ['ML-LJ apparent', 'ML-LJ actual'])

In [61]:
# Resample MD Simulation
import random
import copy
sample_points = random.sample(range(1, len(ml_images)), 10) #sample 10 points
images = ase.io.read("../../datasets/COCu/COCu_pbc_300K.traj", ":100")
resampled_images = copy.copy(images)
for i in sample_points:
    ml_image = ml_images[i].copy()
    ml_image.set_calculator(EMT())
    resampled_images.append(ml_image)

# Define Training data
label = "skorch_resample"
lj_data = lj_optimization(resampled_images, Gs, label)
forcetraining = True
training_data = AtomsDataset(images, Gaussian, Gs, forcetraining=forcetraining,
        label=label, cores=4, lj_data=lj_data)
scalings = training_data.scalings
unique_atoms = training_data.elements
fp_length = training_data.fp_length
device = "cpu"

# Train
net = NeuralNetRegressor(
    module=FullNN(unique_atoms, [fp_length, 5, 5], device, forcetraining=forcetraining),
    criterion=TanhLoss,
    criterion__force_coefficient=0.3,
    optimizer=torch.optim.Adam,
    lr=1e-2,
    batch_size=20,
    max_epochs=500,
    iterator_train__collate_fn=collate_amp,
    iterator_train__shuffle=True,
    iterator_valid__collate_fn=collate_amp,
    iterator_valid_shuffle=False,
    device=device,
    train_split=CVSplit(0.2, random_state=0),
    callbacks=[
        EpochScoring(
            forces_score,
            on_train=True,
            use_caching=True,
            target_extractor=target_extractor,
        ),
        EpochScoring(
            energy_score,
            on_train=True,
            use_caching=True,
            target_extractor=target_extractor,
        ), cp, load_best_valid_loss, LR_schedule
    ],
)

# Define calculator and train
calc = AMP(training_data, net, label=label)
calc.train(overwrite=True)

LJ optimization initiated...
Optimizer terminated successfully.
Calculating fingerprints...
Fingerprints Calculated!
  epoch    energy_score    forces_score    train_loss    valid_loss    cp     dur
-------  --------------  --------------  ------------  ------------  ----  ------
      1          [36m0.1102[0m          [32m6.0268[0m       [35m25.8569[0m       [31m22.6212[0m     +  0.3630
      2          [36m0.0925[0m          [32m2.9995[0m       [35m21.0314[0m       [31m17.0310[0m     +  0.2722
      3          [36m0.0821[0m          3.3918       [35m15.4188[0m       [31m13.9845[0m     +  0.2570
      4          [36m0.0572[0m          3.0379       [35m13.3430[0m       [31m12.8867[0m     +  0.2602
      5          [36m0.0324[0m          [32m2.3715[0m       [35m12.7989[0m       [31m12.8146[0m     +  0.2590
      6          [36m0.0318[0m          [32m1.7649[0m       [35m12.7018[0m       12.8146        0.2750
      7          [36m0.0317[0m      

     83          0.0065          0.0860        [35m5.9910[0m        5.8288        0.2493
     84          0.0062          0.0923        5.9940        5.8190        0.3309
     85          0.0068          0.0881        [35m5.9765[0m        [31m5.7982[0m     +  0.2730
     86          0.0056          0.0886        [35m5.9661[0m        5.7982        0.2575
     87          0.0061          0.0893        [35m5.9599[0m        [31m5.7939[0m     +  0.2591
     88          0.0063          0.1034        [35m5.9591[0m        [31m5.7799[0m     +  0.2710
     89          [36m0.0044[0m          0.0954        [35m5.9574[0m        5.7825        0.2880
     90          0.0058          0.0861        [35m5.9499[0m        5.8287        0.2466
     91          0.0070          0.0860        5.9823        5.8568        0.2213
     92          0.0063          0.0912        5.9799        [31m5.7765[0m     +  0.2395
     93          0.0068          0.0965        5.9600        [31m5.7607

    176          0.0062          0.0889        [35m5.6846[0m        5.5186        0.2805
    177          0.0055          0.0882        [35m5.6822[0m        [31m5.5099[0m     +  0.2690
    178          0.0052          0.0860        5.6873        5.5357        0.2618
    179          0.0050          0.0969        5.7119        5.5299        0.2623
    180          0.0058          0.1045        5.7520        5.7266        0.2594
    181          0.0069          0.0941        5.8505        5.5719        0.2691
    182          0.0066          0.0808        5.7901        5.5828        0.2611
    183          0.0053          0.0921        5.7765        5.5532        0.2583
    184          0.0058          0.0941        5.7165        5.5414        0.2340
    185          0.0047          0.0893        5.6839        5.5317        0.2345
    186          0.0056          0.0948        [35m5.6723[0m        5.5317        0.2483
    187          0.0051          0.0855        [35m5.6680[0m

    272          0.0064          0.0999        5.6151        5.3810        0.2636
    273          0.0062          0.0955        5.5633        5.3775        0.2694
    274          0.0064          0.0955        [35m5.5548[0m        [31m5.3533[0m     +  0.2346
    275          0.0050          0.0893        [35m5.5385[0m        5.3548        0.2293
    276          0.0055          0.0994        [35m5.5340[0m        5.3548        0.2521
    277          0.0066          0.0801        [35m5.5338[0m        [31m5.3504[0m     +  0.2666
    278          0.0051          0.0835        5.5450        5.3630        0.2656
    279          0.0061          0.0892        5.5515        5.3601        0.2434
    280          0.0060          0.0930        5.5783        5.4301        0.2747
    281          0.0050          0.0858        5.6227        5.4564        0.2513
    282          0.0049          0.0986        5.6287        5.4670        0.2657
    283          0.0051          0.0929     

    368          0.0050          0.1000        5.4525        5.2592        0.2605
    369          0.0042          0.0866        5.4568        5.2688        0.2522
    370          0.0052          0.0792        5.4898        5.3101        0.2275
    371          0.0068          0.0949        5.4968        5.3540        0.2282
    372          0.0086          0.0987        5.5213        5.2700        0.2470
    373          0.0115          0.0888        5.4739        5.2635        0.2874
    374          0.0103          0.0896        5.4555        [31m5.2319[0m     +  0.2519
    375          0.0039          0.0979        [35m5.4229[0m        [31m5.2274[0m     +  0.2539
    376          0.0068          0.0901        [35m5.4176[0m        5.2274        0.2642
    377          0.0044          0.0854        [35m5.4142[0m        5.2288        0.2580
    378          0.0049          0.0901        5.4248        5.2347        0.2602
    379          0.0057          0.0889        5.4235

    465          0.0052          0.0863        [35m5.3115[0m        [31m5.1014[0m     +  0.2265
    466          0.0057          0.0848        [35m5.3003[0m        5.1014        0.2899
    467          0.0055          0.0970        [35m5.2973[0m        [31m5.0955[0m     +  0.2879
    468          0.0053          0.0832        5.2994        5.1048        0.2279
    469          0.0050          0.0754        5.3266        5.1196        0.2642
    470          0.0053          0.0843        5.3935        5.5004        0.2549
    471          0.0047          0.0866        5.5912        5.2215        0.3020
    472          0.0077          0.0978        5.5561        5.4501        0.2404
    473          0.0062          0.0940        5.5477        5.1869        0.2907
    474          0.0098          0.0890        5.4088        5.1326        0.2799
    475          0.0054          0.0936        5.3469        5.1210        0.2675
    476          0.0051          0.0912        5.3344

In [62]:
# MD Simulation
md_run(calc=calc, starting_image=images[0].copy(), temp=300, count=100, label=label)

KeyError: 'be1e89e1cddc3f0ac60f2f7d71e69329'

In [None]:
# Calculate forces of base and generated trajectory
ml_resample_images = ase.io.read(label+".traj", ":")
emt_energy, ml_r_apparent_energy, ml_r_actual_energy = calculate_energies(images, ml_resample_images)
emt_forces, ml_r_apparent_forces, ml_r_actual_forces = calculate_forces(images, ml_resample_images, type="max")

In [None]:
# Time Plots

time_plots(emt_energy, [ml_actual_energy, ml_r_actual_energy], None, ['ML-LJ', 'ML-LJ resample'], 'energy', None )
time_plots(emt_forces, [ml_actual_forces, ml_r_actual_forces], None, ['ML-LJ', 'ML-LJ resample'], 'forces', None )

In [None]:
kde_plots(emt_forces, [ml_actual_forces, ml_r_actual_forces] , ['ML-LJ', 'ML-LJ resample'])