In [None]:
import cace
from cace.tools import torch_geometric
from cace.data import AtomicData
import torch
import torch.nn as nn

import ase
import ase.io

In [None]:
model_path = "./best_model.pth"
# Load the full pickled model object onto the CPU.
# NOTE: weights_only=False unpickles arbitrary Python objects (arbitrary code execution
# risk) — only use this with checkpoints from a trusted source.
cace_nnp = torch.load(
    model_path,
    map_location="cpu",
    weights_only=False,
)

In [None]:
# Human-readable summary of the loaded model (for torch modules this lists the submodules).
print(str(cace_nnp))

In [None]:
# Interactive inspection: list the attribute and method names exposed by the loaded model object.
dir(cace_nnp)


In [None]:
# Sanity check that the unpickled object is a torch module (nn is the torch.nn alias imported above).
isinstance(cace_nnp, nn.Module)

In [None]:
# Reference structure used by the exploratory cells that follow.
config_file = "1-doped.xyz"
config = ase.io.read(filename=config_file)

In [None]:
# Prepare a single-structure batch for the model.
cutoff = 5.5  # neighbor-list cutoff passed to AtomicData.from_atoms
data_loader = torch_geometric.dataloader.DataLoader(
    dataset=[AtomicData.from_atoms(config, cutoff=cutoff)],
    batch_size=1,
    drop_last=False,
    shuffle=False,
)

batch_base = next(iter(data_loader))
# Work on a clone; presumably so any in-place changes made downstream leave `batch_base` intact.
batch = batch_base.clone()

In [None]:
# Forward pass on the single-structure batch. The returned dict contains keys such as
# 'CACE_energy', 'CACE_forces', 'ewald_forces', 'q', 'ewald_potential' and 'tot_q'.
# NOTE(review): training=True is presumably needed to keep the autograd graph used for
# the forces — confirm against the CACE API before switching to an inference mode.
output = cace_nnp(batch.to_dict(), training=True)

In [None]:
# Display the full model output dict (energy, forces, Ewald terms, per-atom charges 'q', total charge 'tot_q').
output

```
{'CACE_energy': tensor([-0.8672], grad_fn=<SumBackward1>),
 'CACE_forces': tensor([[-1.7958e-01, -1.7900e-01, -4.4851e-01],
         [ 7.3744e-04, -1.7397e-03,  6.0522e-01],
         [ 6.6829e-02,  6.6681e-02,  5.1197e-01],
         [-1.3696e-04,  2.8252e-04, -2.1782e-01],
         [-4.7263e-03, -4.7280e-03, -7.5841e-01],
         [ 4.6168e-05, -9.0119e-05,  5.6561e-01],
         [ 1.7629e-03,  1.7628e-03,  1.2225e+00],
         [-2.0259e-06,  6.2241e-06, -6.3129e-01],
         [-2.2012e-04, -2.2026e-04,  4.8308e-01],
         [-8.4812e-07, -5.2516e-06, -3.5642e-01],
         [-1.6054e-05, -1.3541e-05, -1.2626e+00],
         [-6.1457e-07,  4.4424e-07,  7.4960e-01],
         [-1.7986e-01,  1.8003e-01, -4.4969e-01],
         [ 2.6261e-04,  2.3255e-02,  2.4504e-01],
         [ 6.8505e-02, -6.8553e-02,  5.0100e-01],
         [-1.6164e-03,  5.0609e-02, -5.5275e-01],
         [-2.3773e-02,  2.3775e-02, -4.9607e-01],
         [ 4.4597e-02, -5.3495e-02,  5.3141e-01],
         [ 8.1630e-02, -8.1628e-02,  3.5991e-01],
         [-1.7466e-01,  1.7603e-01, -4.7133e-01],
         [-1.2092e-01,  1.2092e-01,  2.6380e-01],
         [ 6.7240e-01, -6.7264e-01, -3.5470e-01],
         [ 6.4911e-02, -6.4911e-02, -3.5534e-01],
         [-1.6821e-01,  1.6821e-01,  5.3846e-01],
         [ 3.3063e-02, -7.9771e-04, -1.0469e-01],
         [ 5.3887e-06, -2.2173e-02,  2.4396e-01],
         [-2.0083e-02,  1.9813e-03,  5.6694e-01],
         [ 1.5904e-03, -5.0805e-02, -5.5238e-01],
         [ 1.5227e-02, -1.8937e-02, -5.1282e-01],
         [-4.4572e-02,  5.3572e-02,  5.3126e-01],
         [-7.8607e-02,  7.9778e-02,  3.6405e-01],
         [ 1.7466e-01, -1.7604e-01, -4.7132e-01],
         [ 1.2067e-01, -1.2078e-01,  2.6332e-01],
         [-6.7240e-01,  6.7264e-01, -3.5470e-01],
         [-6.4965e-02,  6.4924e-02, -3.5539e-01],
         [ 1.6822e-01, -1.6821e-01,  5.3847e-01],
         [ 1.7954e-01, -1.7937e-01, -4.4801e-01],
         [ 2.2418e-02, -1.2329e-04,  2.4420e-01],
         [-6.8613e-02,  6.8564e-02,  5.0073e-01],
         [ 5.0756e-02, -1.5677e-03, -5.5246e-01],
         [ 2.3819e-02, -2.3820e-02, -4.9608e-01],
         [-5.3552e-02,  4.4556e-02,  5.3129e-01],
         [-8.1634e-02,  8.1635e-02,  3.5989e-01],
         [ 1.7604e-01, -1.7466e-01, -4.7132e-01],
         [ 1.2092e-01, -1.2092e-01,  2.6380e-01],
         [-6.7264e-01,  6.7240e-01, -3.5470e-01],
         [-6.4912e-02,  6.4913e-02, -3.5534e-01],
         [ 1.6821e-01, -1.6822e-01,  5.3847e-01],
         [ 1.7914e-01,  1.7972e-01, -4.4909e-01],
         [ 3.6582e-02,  3.6515e-02,  9.4006e-02],
         [-6.6662e-02, -6.6810e-02,  5.1206e-01],
         [ 1.9213e-02,  1.9254e-02, -5.5882e-01],
         [ 4.7114e-03,  4.7105e-03, -7.5841e-01],
         [-9.6703e-03, -9.6837e-03,  5.8675e-01],
         [-1.7619e-03, -1.7606e-03,  1.2225e+00],
         [ 8.8576e-04,  8.8466e-04, -6.4575e-01],
         [ 2.2148e-04,  2.2032e-04,  4.8309e-01],
         [-1.9423e-05, -1.7329e-05, -3.5535e-01],
         [ 1.3966e-05,  1.3622e-05, -1.2626e+00],
         [-2.7854e-05, -2.7116e-05,  7.4970e-01],
         [-3.3192e-02, -7.3075e-05, -1.0469e-01],
         [ 3.6424e-02, -3.6447e-02,  9.1966e-02],
         [ 2.0080e-02, -1.7267e-03,  5.6702e-01],
         [ 1.7606e-02, -1.7593e-02, -5.4549e-01],
         [-1.5188e-02,  1.9034e-02, -5.1293e-01],
         [ 3.4889e-02, -3.4897e-02,  5.3110e-01],
         [ 7.8605e-02, -7.9779e-02,  3.6405e-01],
         [-1.7394e-01,  1.7393e-01, -4.7348e-01],
         [-1.2067e-01,  1.2078e-01,  2.6332e-01],
         [ 6.7251e-01, -6.7251e-01, -3.5434e-01],
         [ 6.4962e-02, -6.4923e-02, -3.5539e-01],
         [-1.6825e-01,  1.6825e-01,  5.3855e-01],
         [-1.3700e-04,  3.3248e-02, -1.0470e-01],
         [-2.3011e-02, -3.8092e-04,  2.4479e-01],
         [ 1.7895e-03, -2.0074e-02,  5.6705e-01],
         [-5.0657e-02,  1.6369e-03, -5.5266e-01],
         [-1.9011e-02,  1.5170e-02, -5.1298e-01],
         [ 5.3514e-02, -4.4611e-02,  5.3137e-01],
         [ 7.9779e-02, -7.8605e-02,  3.6405e-01],
         [-1.7603e-01,  1.7466e-01, -4.7133e-01],
         [-1.2078e-01,  1.2068e-01,  2.6332e-01],
         [ 6.7265e-01, -6.7240e-01, -3.5470e-01],
         [ 6.4923e-02, -6.4963e-02, -3.5539e-01],
         [-1.6821e-01,  1.6822e-01,  5.3847e-01],
         [ 5.8586e-04, -3.3003e-02, -1.0468e-01],
         [-3.6830e-02,  3.6805e-02,  9.2216e-02],
         [-1.9182e-03,  2.0086e-02,  5.6691e-01],
         [-1.7624e-02,  1.7636e-02, -5.4576e-01],
         [ 1.8958e-02, -1.5247e-02, -5.1277e-01],
         [-3.4941e-02,  3.4934e-02,  5.3088e-01],
         [-7.9778e-02,  7.8606e-02,  3.6405e-01],
         [ 1.7394e-01, -1.7394e-01, -4.7350e-01],
         [ 1.2078e-01, -1.2067e-01,  2.6332e-01],
         [-6.7251e-01,  6.7251e-01, -3.5434e-01],
         [-6.4922e-02,  6.4963e-02, -3.5539e-01],
         [ 1.6825e-01, -1.6825e-01,  5.3855e-01],
         [ 1.8526e-04, -3.5584e-04, -2.9660e-02],
         [-3.6395e-02, -3.6461e-02,  9.3941e-02],
         [ 2.1448e-05, -3.8322e-05,  5.5794e-01],
         [-1.9244e-02, -1.9204e-02, -5.5875e-01],
         [-3.5484e-05,  6.5508e-05, -7.8907e-01],
         [ 9.6986e-03,  9.6851e-03,  5.8682e-01],
         [-3.3075e-07,  7.7043e-07,  1.2294e+00],
         [-8.8660e-04, -8.8542e-04, -6.4576e-01],
         [ 2.4317e-06, -4.4467e-07,  4.8165e-01],
         [ 1.8162e-05,  1.4909e-05, -3.5535e-01],
         [ 2.5186e-06, -3.2997e-06, -1.2627e+00],
         [ 2.5182e-05,  2.9126e-05,  7.4970e-01],
         [ 2.4638e-04,  2.9811e-04, -1.5462e-02],
         [-7.6467e-06, -2.9034e-04,  2.2911e-02]], grad_fn=<SumBackward1>),
 'ewald_forces': tensor([[ 2.3439e-01,  2.3453e-01, -4.2193e-01],
         [ 3.7892e-04, -4.9347e-04,  1.3044e+00],
         [-7.0494e-02, -7.0603e-02,  4.0045e-01],
         [ 3.5615e-05, -6.6343e-05, -5.3789e-01],
         [-1.2558e-03, -1.2301e-03, -5.6650e-02],
         [ 3.1198e-05, -5.6528e-05,  8.3658e-02],
         [ 1.5968e-03,  1.5979e-03, -1.2538e+00],
         [-3.0762e-07, -4.2430e-06,  5.0943e-01],
         [-2.1970e-04, -2.2002e-04,  6.7592e-01],
         [ 3.5952e-06,  1.9119e-06, -5.1214e-01],
         [-1.4516e-05, -1.3790e-05,  5.1758e-01],
         [-1.3633e-06, -1.2126e-06,  1.4202e-01],
         [ 2.3365e-01, -2.3365e-01, -4.2267e-01],
         [ 2.6312e-04,  1.5398e-01,  3.2279e-01],
         [-7.3078e-02,  7.3042e-02,  3.8219e-01],
         [ 1.7639e-03, -3.8597e-02, -3.6006e-01],
         [-3.4830e-02,  3.4839e-02, -1.1276e-01],
         [ 1.4236e-02, -5.5298e-03,  8.4928e-02],
         [-8.4961e-02,  8.4960e-02, -4.7338e-01],
         [ 9.8848e-02, -9.6847e-02,  4.2856e-01],
         [ 1.0718e-01, -1.0718e-01,  5.5866e-01],
         [-5.9489e-01,  5.9465e-01, -5.7719e-01],
         [-6.6489e-02,  6.6487e-02, -1.9634e-01],
         [ 1.0280e-01, -1.0280e-01,  2.3225e-01],
         [ 9.6843e-02, -7.5151e-04, -2.5246e-01],
         [-2.6779e-04, -1.5433e-01,  3.2275e-01],
         [-3.5460e-02, -2.4278e-03,  3.2262e-01],
         [-1.7107e-03,  3.8693e-02, -3.5955e-01],
         [ 3.0290e-02, -3.3509e-02, -8.2309e-02],
         [-1.4223e-02,  5.5814e-03,  8.4818e-02],
         [ 8.7822e-02, -8.6649e-02, -4.6747e-01],
         [-9.8850e-02,  9.6850e-02,  4.2855e-01],
         [-1.0742e-01,  1.0731e-01,  5.5818e-01],
         [ 5.9489e-01, -5.9465e-01, -5.7719e-01],
         [ 6.6435e-02, -6.6476e-02, -1.9639e-01],
         [-1.0280e-01,  1.0280e-01,  2.3225e-01],
         [-2.3414e-01,  2.3414e-01, -4.2220e-01],
         [ 1.5418e-01,  2.9883e-04,  3.2271e-01],
         [ 7.3046e-02, -7.3082e-02,  3.8239e-01],
         [-3.8670e-02,  1.6904e-03, -3.5967e-01],
         [ 3.4924e-02, -3.4916e-02, -1.1280e-01],
         [-5.5686e-03,  1.4217e-02,  8.4843e-02],
         [ 8.4957e-02, -8.4959e-02, -4.7336e-01],
         [-9.6848e-02,  9.8850e-02,  4.2856e-01],
         [-1.0718e-01,  1.0718e-01,  5.5866e-01],
         [ 5.9465e-01, -5.9489e-01, -5.7719e-01],
         [ 6.6486e-02, -6.6487e-02, -1.9634e-01],
         [-1.0280e-01,  1.0280e-01,  2.3225e-01],
         [-2.3433e-01, -2.3419e-01, -4.2201e-01],
         [ 5.2129e-02,  5.2034e-02,  2.4370e-01],
         [ 7.0604e-02,  7.0494e-02,  4.0039e-01],
         [-3.0555e-02, -3.0496e-02, -3.1026e-01],
         [ 1.2023e-03,  1.2272e-03, -5.6640e-02],
         [-6.5610e-04, -6.8471e-04,  7.5294e-02],
         [-1.6008e-03, -1.5980e-03, -1.2538e+00],
         [ 9.3533e-04,  9.3444e-04,  4.9138e-01],
         [ 2.1854e-04,  2.1799e-04,  6.7592e-01],
         [-2.1373e-05, -2.1438e-05, -5.1107e-01],
         [ 1.3310e-05,  1.3094e-05,  5.1758e-01],
         [-2.5917e-05, -2.4787e-05,  1.4212e-01],
         [-9.6962e-02,  2.3187e-04, -2.5256e-01],
         [ 5.2411e-02, -5.2441e-02,  2.4374e-01],
         [ 3.5436e-02,  2.5147e-03,  3.2273e-01],
         [-2.8850e-02,  2.8870e-02, -2.9746e-01],
         [-3.0236e-02,  3.3601e-02, -8.2452e-02],
         [ 1.3539e-02, -1.3548e-02,  7.7408e-02],
         [-8.7826e-02,  8.6648e-02, -4.6747e-01],
         [ 9.9626e-02, -9.9624e-02,  4.2472e-01],
         [ 1.0742e-01, -1.0731e-01,  5.5818e-01],
         [-5.9479e-01,  5.9479e-01, -5.7683e-01],
         [-6.6437e-02,  6.6476e-02, -1.9639e-01],
         [ 1.0277e-01, -1.0277e-01,  2.3233e-01],
         [-3.5648e-04,  9.7019e-02, -2.5261e-01],
         [-1.5412e-01, -2.3287e-04,  3.2283e-01],
         [-2.4933e-03, -3.5422e-02,  3.2278e-01],
         [ 3.8619e-02, -1.7851e-03, -3.5993e-01],
         [-3.3577e-02,  3.0210e-02, -8.2515e-02],
         [ 5.5413e-03, -1.4240e-02,  8.4902e-02],
         [-8.6646e-02,  8.7822e-02, -4.6747e-01],
         [ 9.6850e-02, -9.8851e-02,  4.2856e-01],
         [ 1.0731e-01, -1.0742e-01,  5.5818e-01],
         [-5.9465e-01,  5.9489e-01, -5.7719e-01],
         [-6.6477e-02,  6.6435e-02, -1.9639e-01],
         [ 1.0280e-01, -1.0280e-01,  2.3225e-01],
         [ 6.2572e-04, -9.6787e-02, -2.5241e-01],
         [-5.2346e-02,  5.2313e-02,  2.4356e-01],
         [ 2.4488e-03,  3.5470e-02,  3.2257e-01],
         [ 2.8776e-02, -2.8758e-02, -2.9760e-01],
         [ 3.3531e-02, -3.0317e-02, -8.2245e-02],
         [-1.3594e-02,  1.3583e-02,  7.7148e-02],
         [ 8.6650e-02, -8.7825e-02, -4.6748e-01],
         [-9.9622e-02,  9.9624e-02,  4.2471e-01],
         [-1.0731e-01,  1.0742e-01,  5.5818e-01],
         [ 5.9479e-01, -5.9479e-01, -5.7683e-01],
         [ 6.6477e-02, -6.6438e-02, -1.9639e-01],
         [-1.0277e-01,  1.0277e-01,  2.3233e-01],
         [ 1.5863e-04, -3.0865e-04, -2.3184e-01],
         [-5.2078e-02, -5.2173e-02,  2.4377e-01],
         [ 3.7935e-05, -7.4372e-05,  3.0947e-01],
         [ 3.0524e-02,  3.0583e-02, -3.1022e-01],
         [-5.0645e-05,  9.6702e-05, -2.4708e-02],
         [ 6.9954e-04,  6.7047e-04,  7.5379e-02],
         [-7.5510e-07,  2.9405e-07, -1.2447e+00],
         [-9.3335e-04, -9.3312e-04,  4.9138e-01],
         [ 1.2732e-06,  7.9958e-07,  6.7449e-01],
         [ 2.0209e-05,  1.9460e-05, -5.1107e-01],
         [-8.6020e-07, -3.0314e-06,  5.1745e-01],
         [ 2.6016e-05,  2.7647e-05,  1.4212e-01],
         [ 2.5232e-05, -6.5890e-04, -2.2419e-01],
         [-4.4619e-04,  1.4254e-03, -4.8457e-01]], grad_fn=<MulBackward0>),
 'q': tensor([[-0.0768],
         [-0.3036],
         [ 0.1244],
         [-0.0505],
         [ 0.1689],
         [ 0.0734],
         [-0.5993],
         [-0.0663],
         [ 0.1358],
         [-0.6272],
         [-0.5350],
         [-0.5851],
         [-0.0764],
         [ 0.1782],
         [ 0.1242],
         [-0.0503],
         [ 0.0259],
         [ 0.0737],
         [ 0.1122],
         [-0.0150],
         [-0.0469],
         [-0.6660],
         [ 0.0317],
         [-0.4945],
         [ 0.0702],
         [ 0.1777],
         [ 0.1420],
         [-0.0503],
         [ 0.0259],
         [ 0.0737],
         [ 0.1122],
         [-0.0150],
         [-0.0469],
         [-0.6660],
         [ 0.0317],
         [-0.4945],
         [-0.0770],
         [ 0.1779],
         [ 0.1245],
         [-0.0503],
         [ 0.0259],
         [ 0.0737],
         [ 0.1122],
         [-0.0150],
         [-0.0469],
         [-0.6660],
         [ 0.0317],
         [-0.4945],
         [-0.0766],
         [ 0.0589],
         [ 0.1243],
         [-0.0502],
         [ 0.1689],
         [ 0.0734],
         [-0.5993],
         [-0.0663],
         [ 0.1358],
         [-0.6272],
         [-0.5350],
         [-0.5851],
         [ 0.0702],
         [ 0.0589],
         [ 0.1420],
         [-0.0503],
         [ 0.0259],
         [ 0.0737],
         [ 0.1122],
         [-0.0150],
         [-0.0469],
         [-0.6660],
         [ 0.0317],
         [-0.4945],
         [ 0.0702],
         [ 0.1781],
         [ 0.1420],
         [-0.0503],
         [ 0.0259],
         [ 0.0737],
         [ 0.1122],
         [-0.0150],
         [-0.0469],
         [-0.6660],
         [ 0.0317],
         [-0.4945],
         [ 0.0702],
         [ 0.0589],
         [ 0.1420],
         [-0.0503],
         [ 0.0259],
         [ 0.0737],
         [ 0.1122],
         [-0.0150],
         [-0.0469],
         [-0.6660],
         [ 0.0317],
         [-0.4945],
         [ 0.0693],
         [ 0.0589],
         [ 0.1420],
         [-0.0503],
         [ 0.1689],
         [ 0.0734],
         [-0.5993],
         [-0.0663],
         [ 0.1358],
         [-0.6272],
         [-0.5350],
         [-0.5851],
         [-0.4255],
         [ 0.9177]], grad_fn=<AddBackward0>),
 'ewald_potential': tensor([3.5615], grad_fn=<MulBackward0>),
 'tot_q': tensor([-10.0429], grad_fn=<SqueezeBackward1>)}
 ```

In [None]:
# Raw per-atom charges as a NumPy array (model units, not rescaled).
output['q'].detach().cpu().numpy()

In [None]:
# Indices of the gold atoms in the reference structure `config`.
species_list = config.get_chemical_symbols()
Au_indices = [idx for idx, symbol in enumerate(species_list) if symbol == 'Au']


In [None]:
# Inspect the Cartesian atomic positions of the reference structure `config`.
config.get_positions()

In [None]:
# Divisor applied to the model's raw charges 'q'.
# NOTE(review): 90.0474 presumably matches the normalization constant of the model's Ewald
# term, so dividing by its square root rescales q to physical charge units — TODO confirm
# against the CACE Ewald implementation.
fac = (90.0474)**0.5


def calculate_charges(config, cace_nnp, cutoff=5.5, charge_scale=None):
    """Predict per-atom charges for a single ASE structure with a CACE model.

    Parameters
    ----------
    config : ase.Atoms
        Structure to evaluate.
    cace_nnp : torch.nn.Module
        Loaded CACE potential; its output dict must contain the key ``'q'``.
    cutoff : float, optional
        Neighbor-list cutoff passed to ``AtomicData.from_atoms`` (default 5.5).
    charge_scale : float, optional
        Divisor applied to the raw model charges. ``None`` (default) uses the
        module-level ``fac`` looked up at call time, preserving the original behavior.

    Returns
    -------
    numpy.ndarray
        ``output['q']`` detached, moved to the CPU, converted to NumPy and divided by
        ``charge_scale``; same shape as ``output['q']`` (one row per atom).
    """
    if charge_scale is None:
        charge_scale = fac

    # A one-element dataset collated by the CACE DataLoader yields a batch of size 1.
    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=[AtomicData.from_atoms(config, cutoff=cutoff)],
        batch_size=1,
        shuffle=False,
        drop_last=False,
    )
    batch = next(iter(data_loader))

    # NOTE(review): training=True mirrors the exploratory cells; presumably it keeps the
    # autograd graph needed for the forces — confirm whether an inference mode suffices.
    output = cace_nnp(batch.to_dict(), training=True)

    # .cpu() lets the NumPy conversion work even when the model runs on a GPU.
    charges = output['q'].detach().cpu().numpy()
    return charges / charge_scale

In [None]:
# Reference structures: systems 1 and 3, each undoped and doped.
undoped1_config = ase.io.read("1-undoped.xyz")
doped1_config = ase.io.read("1-doped.xyz")
undoped3_config = ase.io.read("3-undoped.xyz")
doped3_config = ase.io.read("3-doped.xyz")

# Locally built doped variant of system 1; per the notes at the end of this notebook its Au
# atoms sit at the same positions as in 1-undoped.xyz (no additional relaxation).
my_doped1_config = ase.io.read("my-1-doped.xyz")

In [None]:
# Evaluate one structure and report the (rescaled) charges on its Au atoms.
test_config = my_doped1_config

# Build the Au indices from the structure actually being evaluated, not from the global
# `config` (1-doped.xyz). Otherwise the charges of any structure with a different atom
# ordering (e.g. the 3-* files) would be read at the wrong indices.
species_list = test_config.get_chemical_symbols()
Au_indices = [idx for idx, symbol in enumerate(species_list) if symbol == 'Au']
assert Au_indices, "test_config contains no Au atoms"

atomic_charges = calculate_charges(test_config, cace_nnp)

print(atomic_charges[Au_indices])

undoped1_config Au charges 
- 3.7371927e-05
- 1.2652083e-01

my_doped1_config Au charges (the Au atoms are at exactly the same positions as in undoped1_config; no additional relaxation)
- 3.7371927e-05
- 1.2652083e-01

doped1_config Au charges 
- -0.04484142
- 0.0967114 
 
undoped3_config Au charges
- -0.20415413
- -0.20426635

doped3_config Au charges
- -0.213786  
- -0.21378516

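So there is no effective charge transfer to the hydrogens; the predicted Au charges are simply very sensitive to the atomic positions.
If this were not the case, I would have had to dig much deeper into the model to understand why.

TODO: run MD to see how much these charges fluctuate. Even so, this position sensitivity is ultimately why Figure 7 of that paper appears to show a response charge on the hydrogen.

Caveat: the values recorded above were obtained with `Au_indices` built from `config` (1-doped.xyz), not from each evaluated structure. The 3-undoped/3-doped numbers are therefore only valid if those files list their Au atoms at the same indices; this is worth re-checking.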