In [1]:
import json
import numpy as np
from dscribe.descriptors import ACSF
from pymatgen.core import Structure
from ase import Atoms
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
S_min = -1
S_max = 1
acsf_descriptors = []
target_energies = []

def precondition_symmetry_functions(symmetry_functions, S_min, S_max):
    # Step 1: Shifting the Mean to Zero
    mean_values = np.mean(symmetry_functions, axis=0)
    shifted_symmetry_functions = symmetry_functions - mean_values

    # Step 2: Rescaling to a Predefined Interval
    min_values = np.min(shifted_symmetry_functions, axis=0)
    max_values = np.max(shifted_symmetry_functions, axis=0)
    
    # Ensure no division by zero
    max_values[max_values == min_values] += 1e-12
    
    rescaled_symmetry_functions = (shifted_symmetry_functions - min_values) / (max_values - min_values) * (S_max - S_min) + S_min
    
    return rescaled_symmetry_functions


with open('position_force_train_all.json', 'r') as f:
    p_f = json.loads(f.read())

acsf = ACSF(
    species=["B"],
    r_cut=4.0,
    g2_params=[[1, 1], [1, 2], [1, 3]],
    g4_params=[[1, 1, 1], [1, 2, 1], [1, 1, -1], [1, 2, -1]],
    periodic=True
)

# Initialize empty lists to store ACSF descriptors and target energies
acsf_descriptors = []
target_energies = []

# for i in range(len(p_f)):
for i in range(1000):
    pmg_structure = Structure.from_dict(p_f[i]['structure'])
    energy = p_f[i]['energy']
    forces = p_f[i]['forces']
    bc = p_f[i]['bc']
    pbc = [0, 0, 0] if 'free' in bc else [1,1,1]
    ase_structure = Atoms(pmg_structure.composition.formula, positions=pmg_structure.cart_coords, cell=pmg_structure.lattice.matrix, pbc=pbc)
    acsf_value = acsf.create(ase_structure)
    acsf_descriptors.append(precondition_symmetry_functions(acsf_value, S_min, S_max))
#     print(precondition_symmetry_functions(acsf_value, S_min, S_max))
    target_energies.append(energy)

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
import torch.nn.utils.rnn as rnn_utils
from sklearn.preprocessing import StandardScaler

packed_acsf_tensor = rnn_utils.pack_sequence([torch.tensor(seq) for seq in acsf_descriptors], enforce_sorted=False)

# Use the packed sequence directly or convert it to a padded sequence if needed
# Example: Convert packed sequence to a padded sequence
padded_acsf_tensor, _ = rnn_utils.pad_packed_sequence(packed_acsf_tensor, batch_first=True, padding_value=0.0)

# Assuming target_energies is a NumPy array containing the target total energies for each structure
target_energies = torch.tensor(target_energies, dtype=torch.float32)

# Combine ACSF descriptors and target energies into a PyTorch Dataset
dataset = TensorDataset(padded_acsf_tensor, target_energies)

# Split the dataset into training and validation sets (80% - 20%)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create DataLoader for training and validation sets
batch_size = 500
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


# Define your neural network architecture
class AtomEnergyNN(nn.Module):
    def __init__(self, input_size, hidden_size_1, hidden_size_2, output_size):
        super(AtomEnergyNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size_1, dtype=torch.double)
        self.relu1 = nn.ReLU()                                                         #self.tanh1 = nn.Tanh()
        self.fc2 = nn.Linear(hidden_size_1, hidden_size_2, dtype=torch.double)
        self.relu2= nn.ReLU()
        self.fc3 = nn.Linear(hidden_size_2, output_size, dtype=torch.double)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x

# Set parameters
input_size = 8  # Number of features in ACSF descriptor for each atom
hidden_size_1 = 16  # Number of neurons in the hidden layer 1
hidden_size_2 = 16 # Number of neurons in the hidden layer 2
output_size = 1  # Single output for each atom's energy

# Create the neural network
model = AtomEnergyNN(input_size, hidden_size_1, hidden_size_2, output_size)

# Define the loss function and optimizer
criterion = torch.nn.MSELoss()  # Mean Squared Error loss
# optimizer = optim.Adam(model.parameters(), lr=0.01)  # Adam optimizer
optimizer = optim.SGD(model.parameters(), lr=0.05)     #SGD optimizer

In [3]:
import torch
import torch.optim as optim

# Define your optimizer, for example:
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Define your loss function, for example:
loss_function = torch.nn.MSELoss()
requires_grad = True

# Define batch size
batch_size = 100  # Adjust as needed

# Define number of epochs
num_epochs = 10  # Adjust as needed

# Training loop with epochs and batch training
for epoch in range(num_epochs):
    batch_no = 0
    # Iterate over the data in batches within each epoch
    for batch_start in range(0, len(padded_acsf_tensor), batch_size):
        # Zero the gradients
        optimizer.zero_grad()
        batch_no += 1
        batch_end = min(batch_start + batch_size, len(padded_acsf_tensor))
        batch_acsf_tensor = padded_acsf_tensor[batch_start:batch_end]

        # Compute total energy for each structure in the batch
        total_energies = []
        target_total_energies = []
        for struct_acsf_tensor, target_energy in zip(batch_acsf_tensor, target_energies[batch_start:batch_end]):
            struct_acsf_tensor = struct_acsf_tensor[struct_acsf_tensor.any(dim=1)]
            total_energy = model(struct_acsf_tensor).sum().item()
            total_energies.append(total_energy)
            target_total_energies.append(target_energy.item())

            # Print predicted and actual energies for each structure
            print(f"Epoch {epoch + 1}/{num_epochs} Batch {batch_no}")
            print('Predicted Energy and actual energy:', total_energy, target_energy.item())

        # Compute loss for the batch
        predicted_energies = torch.tensor(total_energies, dtype=torch.float32, requires_grad = True)
        target_energies_batch = torch.tensor(target_total_energies, dtype=torch.float32)
        loss = criterion(predicted_energies, target_energies_batch)
        print('Loss:', loss.item())

        # Check if gradients are being computed for fc1.weight
        for name, param in model.named_parameters():
            if name == 'fc2.weight':
                print(name, param.data)  # Ensure that fc1.weight requires gradients

#         # Zero the gradients
#         optimizer.zero_grad()
        
        # Backpropagation
        loss.backward()

        # Update the weights
        optimizer.step()

Epoch 1/10 Batch 1
Predicted Energy and actual energy: -7.028735051439224 -157.06329345703125
Epoch 1/10 Batch 1
Predicted Energy and actual energy: -7.028735051439224 -157.06329345703125
Epoch 1/10 Batch 1
Predicted Energy and actual energy: -6.878911497869062 -156.97494506835938
Epoch 1/10 Batch 1
Predicted Energy and actual energy: -7.062438862461599 -156.88255310058594
Epoch 1/10 Batch 1
Predicted Energy and actual energy: -7.443200856212064 -156.6575164794922
Epoch 1/10 Batch 1
Predicted Energy and actual energy: -6.730251580028099 -156.48768615722656
Epoch 1/10 Batch 1
Predicted Energy and actual energy: -6.4377556966548255 -156.21478271484375
Epoch 1/10 Batch 1
Predicted Energy and actual energy: -22.55747415929531 -554.1339111328125
Epoch 1/10 Batch 1
Predicted Energy and actual energy: -22.55747415929531 -554.1339111328125
Epoch 1/10 Batch 1
Predicted Energy and actual energy: -22.82103504622723 -553.8827514648438
Epoch 1/10 Batch 1
Predicted Energy and actual energy: -23.2378

Epoch 1/10 Batch 8
Predicted Energy and actual energy: -27.389980676250666 -563.1913452148438
Epoch 1/10 Batch 8
Predicted Energy and actual energy: -26.603744334632356 -559.9202270507812
Epoch 1/10 Batch 8
Predicted Energy and actual energy: -26.467459043767647 -559.6945190429688
Epoch 1/10 Batch 8
Predicted Energy and actual energy: -27.298635297988266 -559.38671875
Epoch 1/10 Batch 8
Predicted Energy and actual energy: -26.96004689101582 -553.6113891601562
Epoch 1/10 Batch 8
Predicted Energy and actual energy: -6.924128905842245 -159.5568084716797
Epoch 1/10 Batch 8
Predicted Energy and actual energy: -6.886799761364152 -159.58363342285156
Epoch 1/10 Batch 8
Predicted Energy and actual energy: -6.846030514775648 -159.61676025390625
Epoch 1/10 Batch 8
Predicted Energy and actual energy: -6.821432623584994 -159.6465606689453
Epoch 1/10 Batch 8
Predicted Energy and actual energy: -6.877238760437445 -159.62950134277344
Epoch 1/10 Batch 8
Predicted Energy and actual energy: -6.9054690194

fc2.weight tensor([[-2.3017e-02,  2.4867e-01, -9.0384e-02,  2.2981e-01, -2.0484e-01,
         -7.7866e-03,  1.2468e-01, -1.2984e-01,  3.2675e-02, -1.8956e-01,
          1.7521e-01, -2.0414e-01, -1.8847e-01, -1.8600e-01, -1.6917e-01,
         -1.3337e-01],
        [-1.7104e-01, -2.1066e-01, -1.5134e-01, -9.9021e-02, -1.8519e-01,
         -1.1152e-02, -6.8181e-02, -8.8197e-02, -1.1343e-01, -2.4065e-01,
          1.3579e-01, -7.6533e-02,  2.0688e-01,  4.0513e-02, -9.6893e-02,
         -2.3026e-02],
        [ 7.6858e-02,  2.0925e-01, -1.9633e-01,  1.4729e-01,  3.5941e-02,
          1.9597e-01,  3.7686e-02, -8.8825e-02,  3.4985e-02, -2.0993e-01,
         -1.1715e-02, -1.7662e-01,  1.8629e-01, -6.0084e-02, -1.0374e-01,
          7.6331e-02],
        [ 5.5914e-02, -1.4660e-01, -3.0199e-02,  1.1367e-01, -2.2873e-01,
          4.6192e-03,  2.3986e-01, -2.2636e-01, -1.1887e-04,  1.5854e-01,
         -2.2687e-01,  1.1218e-01, -1.2671e-01, -1.0015e-01,  1.0271e-02,
          1.3589e-01],
        [

Epoch 3/10 Batch 5
Predicted Energy and actual energy: -10.849943537491846 -305.1438903808594
Epoch 3/10 Batch 5
Predicted Energy and actual energy: -10.944896055661541 -304.95477294921875
Epoch 3/10 Batch 5
Predicted Energy and actual energy: -7.702496477672496 -152.04139709472656
Epoch 3/10 Batch 5
Predicted Energy and actual energy: -7.702496477672496 -152.04139709472656
Epoch 3/10 Batch 5
Predicted Energy and actual energy: -7.756544908984608 -151.90655517578125
Epoch 3/10 Batch 5
Predicted Energy and actual energy: -7.469110980797621 -151.60496520996094
Epoch 3/10 Batch 5
Predicted Energy and actual energy: -7.610405180999489 -151.28176879882812
Epoch 3/10 Batch 5
Predicted Energy and actual energy: -8.444305470420346 -157.99252319335938
Epoch 3/10 Batch 5
Predicted Energy and actual energy: -8.444305470420346 -157.99252319335938
Epoch 3/10 Batch 5
Predicted Energy and actual energy: -8.554965643384072 -157.9524688720703
Epoch 3/10 Batch 5
Predicted Energy and actual energy: -8.69

Epoch 4/10 Batch 2
Predicted Energy and actual energy: -29.682536423870083 -539.6726684570312
Epoch 4/10 Batch 2
Predicted Energy and actual energy: -15.63303212179122 -279.4229736328125
Epoch 4/10 Batch 2
Predicted Energy and actual energy: -27.44631429500475 -591.80908203125
Epoch 4/10 Batch 2
Predicted Energy and actual energy: -27.44631429500475 -591.80908203125
Epoch 4/10 Batch 2
Predicted Energy and actual energy: -27.422684515300066 -591.5975952148438
Epoch 4/10 Batch 2
Predicted Energy and actual energy: -27.990988596250766 -591.2819213867188
Epoch 4/10 Batch 2
Predicted Energy and actual energy: -28.19026306601548 -589.4938354492188
Loss: 141945.984375
fc2.weight tensor([[-2.3017e-02,  2.4867e-01, -9.0384e-02,  2.2981e-01, -2.0484e-01,
         -7.7866e-03,  1.2468e-01, -1.2984e-01,  3.2675e-02, -1.8956e-01,
          1.7521e-01, -2.0414e-01, -1.8847e-01, -1.8600e-01, -1.6917e-01,
         -1.3337e-01],
        [-1.7104e-01, -2.1066e-01, -1.5134e-01, -9.9021e-02, -1.8519e-01,


Predicted Energy and actual energy: -22.55747415929531 -554.1339111328125
Epoch 5/10 Batch 1
Predicted Energy and actual energy: -22.55747415929531 -554.1339111328125
Epoch 5/10 Batch 1
Predicted Energy and actual energy: -22.82103504622723 -553.8827514648438
Epoch 5/10 Batch 1
Predicted Energy and actual energy: -23.23789281889378 -553.5497436523438
Epoch 5/10 Batch 1
Predicted Energy and actual energy: -23.16001686047438 -552.1422119140625
Epoch 5/10 Batch 1
Predicted Energy and actual energy: -22.976691178799026 -551.7311401367188
Epoch 5/10 Batch 1
Predicted Energy and actual energy: -23.94880447658562 -549.706787109375
Epoch 5/10 Batch 1
Predicted Energy and actual energy: -24.24805240658063 -549.3270874023438
Epoch 5/10 Batch 1
Predicted Energy and actual energy: -24.147956379548372 -548.971923828125
Epoch 5/10 Batch 1
Predicted Energy and actual energy: -29.378668975774502 -537.402587890625
Epoch 5/10 Batch 1
Predicted Energy and actual energy: -29.378668975774502 -537.402587890

Epoch 5/10 Batch 8
Predicted Energy and actual energy: -6.877238760437445 -159.62950134277344
Epoch 5/10 Batch 8
Predicted Energy and actual energy: -6.905469019458209 -159.61947631835938
Epoch 5/10 Batch 8
Predicted Energy and actual energy: -7.092182250751668 -159.50430297851562
Epoch 5/10 Batch 8
Predicted Energy and actual energy: -14.278576823610369 -287.468505859375
Epoch 5/10 Batch 8
Predicted Energy and actual energy: -14.278576823610369 -287.468505859375
Epoch 5/10 Batch 8
Predicted Energy and actual energy: -14.132089262637505 -287.2696533203125
Epoch 5/10 Batch 8
Predicted Energy and actual energy: -7.6951749298994185 -162.3910675048828
Epoch 5/10 Batch 8
Predicted Energy and actual energy: -7.6951749298994185 -162.3910675048828
Epoch 5/10 Batch 8
Predicted Energy and actual energy: -7.999498517337387 -162.26876831054688
Epoch 5/10 Batch 8
Predicted Energy and actual energy: -8.221122128154201 -161.78965759277344
Epoch 5/10 Batch 8
Predicted Energy and actual energy: -7.5231

Epoch 6/10 Batch 6
Predicted Energy and actual energy: -13.710668351842644 -286.7972412109375
Epoch 6/10 Batch 6
Predicted Energy and actual energy: -13.681451248815087 -286.7265625
Epoch 6/10 Batch 6
Predicted Energy and actual energy: -13.720251668221753 -286.3158874511719
Epoch 6/10 Batch 6
Predicted Energy and actual energy: -9.141044934598613 -148.51144409179688
Epoch 6/10 Batch 6
Predicted Energy and actual energy: -9.141044934598613 -148.51144409179688
Epoch 6/10 Batch 6
Predicted Energy and actual energy: -9.154658000601732 -148.49436950683594
Epoch 6/10 Batch 6
Predicted Energy and actual energy: -9.103013221744272 -148.40267944335938
Epoch 6/10 Batch 6
Predicted Energy and actual energy: -15.529030790702972 -295.3515625
Epoch 6/10 Batch 6
Predicted Energy and actual energy: -15.529030790702972 -295.3515625
Epoch 6/10 Batch 6
Predicted Energy and actual energy: -15.58476926203219 -295.17523193359375
Epoch 6/10 Batch 6
Predicted Energy and actual energy: -15.506277440778009 -29

Epoch 7/10 Batch 5
Predicted Energy and actual energy: -7.702496477672496 -152.04139709472656
Epoch 7/10 Batch 5
Predicted Energy and actual energy: -7.756544908984608 -151.90655517578125
Epoch 7/10 Batch 5
Predicted Energy and actual energy: -7.469110980797621 -151.60496520996094
Epoch 7/10 Batch 5
Predicted Energy and actual energy: -7.610405180999489 -151.28176879882812
Epoch 7/10 Batch 5
Predicted Energy and actual energy: -8.444305470420346 -157.99252319335938
Epoch 7/10 Batch 5
Predicted Energy and actual energy: -8.444305470420346 -157.99252319335938
Epoch 7/10 Batch 5
Predicted Energy and actual energy: -8.554965643384072 -157.9524688720703
Epoch 7/10 Batch 5
Predicted Energy and actual energy: -8.695629755723296 -157.8834686279297
Epoch 7/10 Batch 5
Predicted Energy and actual energy: -8.576162828624028 -157.86012268066406
Epoch 7/10 Batch 5
Predicted Energy and actual energy: -8.995623887053732 -157.691162109375
Epoch 7/10 Batch 5
Predicted Energy and actual energy: -9.098874

Epoch 8/10 Batch 3
Predicted Energy and actual energy: -27.940142755362228 -589.30029296875
Epoch 8/10 Batch 3
Predicted Energy and actual energy: -7.990352829323746 -169.84115600585938
Epoch 8/10 Batch 3
Predicted Energy and actual energy: -7.990352829323746 -169.84115600585938
Epoch 8/10 Batch 3
Predicted Energy and actual energy: -8.170109318068608 -169.78565979003906
Epoch 8/10 Batch 3
Predicted Energy and actual energy: -15.960694447840893 -282.67266845703125
Epoch 8/10 Batch 3
Predicted Energy and actual energy: -13.424784588422067 -299.0736389160156
Epoch 8/10 Batch 3
Predicted Energy and actual energy: -13.424784588422067 -299.0736389160156
Epoch 8/10 Batch 3
Predicted Energy and actual energy: -13.115347178335941 -298.84228515625
Epoch 8/10 Batch 3
Predicted Energy and actual energy: -12.585607367430772 -297.9523010253906
Epoch 8/10 Batch 3
Predicted Energy and actual energy: -12.761300902918094 -297.5878601074219
Epoch 8/10 Batch 3
Predicted Energy and actual energy: -11.8268

Epoch 9/10 Batch 1
Predicted Energy and actual energy: -28.48130361646652 -565.9724731445312
Epoch 9/10 Batch 1
Predicted Energy and actual energy: -29.071678323834636 -565.205322265625
Epoch 9/10 Batch 1
Predicted Energy and actual energy: -28.87279630583756 -565.0123291015625
Epoch 9/10 Batch 1
Predicted Energy and actual energy: -29.168999909201553 -563.8114624023438
Epoch 9/10 Batch 1
Predicted Energy and actual energy: -27.190386859005788 -578.1500244140625
Epoch 9/10 Batch 1
Predicted Energy and actual energy: -27.190386859005788 -578.1500244140625
Epoch 9/10 Batch 1
Predicted Energy and actual energy: -27.62974154781305 -577.9699096679688
Epoch 9/10 Batch 1
Predicted Energy and actual energy: -27.525356037812465 -577.7468872070312
Epoch 9/10 Batch 1
Predicted Energy and actual energy: -27.68335764829103 -577.9420166015625
Epoch 9/10 Batch 1
Predicted Energy and actual energy: -25.71196434887273 -586.5313110351562
Epoch 9/10 Batch 1
Predicted Energy and actual energy: -25.7119643

Epoch 9/10 Batch 9
Predicted Energy and actual energy: -27.423725949335292 -590.432373046875
Epoch 9/10 Batch 9
Predicted Energy and actual energy: -27.51797191907157 -589.7951049804688
Epoch 9/10 Batch 9
Predicted Energy and actual energy: -26.912609445986075 -587.9435424804688
Epoch 9/10 Batch 9
Predicted Energy and actual energy: -27.164035142494665 -587.1688842773438
Epoch 9/10 Batch 9
Predicted Energy and actual energy: -27.27080900119991 -586.5784912109375
Epoch 9/10 Batch 9
Predicted Energy and actual energy: -26.777620122438385 -586.2308349609375
Epoch 9/10 Batch 9
Predicted Energy and actual energy: -27.63071940676708 -585.3802490234375
Epoch 9/10 Batch 9
Predicted Energy and actual energy: -27.614514273284886 -585.2945556640625
Epoch 9/10 Batch 9
Predicted Energy and actual energy: -27.66009889061851 -553.6220092773438
Epoch 9/10 Batch 9
Predicted Energy and actual energy: -27.66009889061851 -553.6220092773438
Epoch 9/10 Batch 9
Predicted Energy and actual energy: -27.5409518

Epoch 10/10 Batch 7
Predicted Energy and actual energy: -28.44563674737791 -563.4025268554688
Epoch 10/10 Batch 7
Predicted Energy and actual energy: -28.59110943097112 -563.2037963867188
Epoch 10/10 Batch 7
Predicted Energy and actual energy: -28.497216316685957 -562.637939453125
Epoch 10/10 Batch 7
Predicted Energy and actual energy: -28.470688038480375 -559.8316650390625
Epoch 10/10 Batch 7
Predicted Energy and actual energy: -28.470264659784178 -559.8170166015625
Epoch 10/10 Batch 7
Predicted Energy and actual energy: -13.648345077429637 -285.7696533203125
Epoch 10/10 Batch 7
Predicted Energy and actual energy: -13.648345077429637 -285.7696533203125
Epoch 10/10 Batch 7
Predicted Energy and actual energy: -13.752747172713025 -285.7219543457031
Epoch 10/10 Batch 7
Predicted Energy and actual energy: -14.723873886359964 -300.77044677734375
Epoch 10/10 Batch 7
Predicted Energy and actual energy: -14.723873886359964 -300.77044677734375
Epoch 10/10 Batch 7
Predicted Energy and actual ene