Libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

Device Selection

In [2]:
# Select the compute device once; later cells refer to it as `my_device`.
gpu_ready = torch.cuda.is_available()
my_device = torch.device("cuda" if gpu_ready else "cpu")
print("GPU Available" if gpu_ready else "GPU Not Available")

GPU Not Available


Hyperparameters

In [3]:
# Optimiser step size.
Learn_Rate = 0.001
# Adam exponential decay rates, passed to the optimiser as betas=(beta1, beta2):
# beta1 smooths the gradient (first moment), beta2 smooths the squared gradient
# (second moment). The two values were previously swapped (0.999 / 0.9); the
# conventional setting is beta1 < beta2, i.e. (0.9, 0.999).
beta1 = 0.9
beta2 = 0.999
# Number of passes over the training data.
Epochs = 1000000
# Weights of the energy term (alpha) and the energy-derivative term (gamma)
# in the total training loss.
alpha = 1.0
gamma = 1.0

Dataset

In [4]:
# Waiting for the dataset. Inputs will be 3x3 strain matrices (flattened to 9x1);
# labels will be scalar energy values and a 3x3 energy-derivative matrix
# (flattened to 9x1).
import pickle

# NOTE(review): machine-specific absolute path; point this at the local copy of
# the simulation data before running.
data_file = r"C:\Users\samue\Downloads\Simulation.pickle"

# SECURITY: pickle.load can execute arbitrary code while deserialising.
# Only load files that come from a trusted source.
try:
    with open(data_file, "rb") as f:
        data_unpickled = pickle.load(f)
except ModuleNotFoundError as err:
    # The pickle stores objects whose classes live in another module (the
    # recorded run failed with "No module named 'DataSetup'"). That module must
    # be importable in this kernel before the file can be unpickled.
    raise ModuleNotFoundError(
        f"Cannot unpickle {data_file!r}: module {err.name!r} is not importable. "
        "Make the module that defines the pickled classes available on sys.path "
        "and re-run this cell.",
        name=err.name,
    ) from err

ModuleNotFoundError: No module named 'DataSetup'

Network Architecture and Energy Gradient Calculator

In [None]:
strain_input_dims = 9  # flattened 3x3 strain matrix
energy_dims = 1  # scalar energy


class Energy_Net(nn.Module):
    """Fully connected network that predicts an energy and its strain gradient.

    The network maps a batch of flattened strain vectors to an energy through
    five linear layers with SiLU activations in between. The derivative of the
    energy with respect to the *input* strain is obtained with autograd.

    Args:
        strain_input_dims: Number of input features per sample.
        energy_dims: Number of output features per sample.
    """

    def __init__(self, strain_input_dims, energy_dims):
        super().__init__()

        self.Layer1 = nn.Linear(strain_input_dims, 1024)
        self.Layer2 = nn.Linear(1024, 512)
        self.Layer3 = nn.Linear(512, 128)
        self.Layer4 = nn.Linear(128, 32)
        self.Layer5 = nn.Linear(32, energy_dims)

        self.silu = nn.SiLU()

    def forward(self, x):
        """Return ``(energy, d energy / d x)`` for a batch of strain vectors.

        Args:
            x: Floating-point tensor whose last dimension is
                ``strain_input_dims``. It is not modified: if it does not
                already track gradients, a detached view is used internally.

        Returns:
            A tuple ``(energy, energy_derivatives)``. ``energy`` is the output
            of the last linear layer; ``energy_derivatives`` has the same shape
            as ``x`` and stays attached to the graph (``create_graph=True``) so
            a loss on it can be back-propagated to the parameters.
        """
        # The derivative needs a graph even if the caller wrapped the call in
        # torch.no_grad() (e.g. during evaluation), so force grad mode on here.
        with torch.enable_grad():
            # Keep a handle on the *input* so we differentiate w.r.t. the
            # strain, not w.r.t. a hidden activation. detach() avoids flipping
            # requires_grad on the caller's own tensor.
            strain = x if x.requires_grad else x.detach().requires_grad_(True)

            hidden = self.silu(self.Layer1(strain))
            hidden = self.silu(self.Layer2(hidden))
            hidden = self.silu(self.Layer3(hidden))
            hidden = self.silu(self.Layer4(hidden))
            energy = self.Layer5(hidden)

            # energy is not a scalar, so autograd.grad needs grad_outputs.
            # A vector of ones differentiates the sum of all outputs; because
            # each sample's output depends only on its own input row, this
            # yields the per-sample gradient (for energy_dims > 1 it is the
            # gradient of the summed output components).
            energy_derivatives = torch.autograd.grad(
                outputs=energy,
                inputs=strain,
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
                retain_graph=True,
            )[0]

        return energy, energy_derivatives


Model = Energy_Net(strain_input_dims, energy_dims)

Optimiser and Loss

In [None]:
# Mean-squared-error criterion; the training cell applies it to both the energy
# and the energy-derivative predictions.
loss = nn.MSELoss()

# Adam over every parameter of the model, configured from the hyperparameter cell.
optimiser = optim.Adam(
    params=Model.parameters(),
    betas=(beta1, beta2),
    lr=Learn_Rate,
)

Training Loop

In [None]:
# TODO: replace this placeholder with a DataLoader over the training set that
# yields (strain, energy, energy_derivative) batches. Until then the loop
# below cannot iterate.
dataloader = 0
loss_record = []  # summed batch loss of every completed epoch

# Module.to() moves the parameters in place, so the optimiser created earlier
# keeps pointing at the same Parameter objects.
Model.to(my_device)
Model.train()

diverged = False  # set when a NaN loss is seen so the epoch loop stops as well

for epoch in range(Epochs):

    running_loss = 0.0

    for i, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}/{Epochs}", leave=False)):
        input_batch, target_energy_batch, target_energy_Deriv_batch = batch # Correct this to get the actual values

        input_batch = input_batch.to(my_device)
        target_energy_batch = target_energy_batch.to(my_device)
        target_energy_Deriv_batch = target_energy_Deriv_batch.to(my_device)

        optimiser.zero_grad()

        energy_pred, energy_deriv_pred = Model(input_batch)

        # NOTE(review): presumably the targets have the same shapes as the
        # predictions; MSELoss broadcasts mismatched shapes, which would give a
        # wrong loss silently. Confirm once the dataset is available.
        loss_E = loss(energy_pred, target_energy_batch)

        loss_E_Deriv = loss(energy_deriv_pred, target_energy_Deriv_batch)

        loss_total = alpha * loss_E + gamma * loss_E_Deriv

        # Check the loss tensor (not the MSELoss module) and do it *before*
        # backward/step so a NaN never reaches the weights.
        if torch.isnan(loss_total):
            print(f"Loss became NaN at batch {i} in epoch {epoch}!")
            if torch.isnan(Model.Layer1.weight).any():
                print("Model weights have been corrupted by NaN values.")
            diverged = True
            break

        loss_total.backward()

        optimiser.step()

        running_loss += loss_total.item()

    if diverged:
        # Stop training entirely; the partial epoch is not recorded.
        break

    loss_record.append(running_loss)

# Plot Loss Across Training
plt.plot(loss_record)
plt.xlabel("Epoch")
plt.ylabel("Summed batch loss")
plt.title("Training loss")

if loss_record:
    print(f"final loss after {len(loss_record)} epochs is {loss_record[-1]}")
else:
    print("No epoch completed, so there is no final loss to report.")

Testing Loop

In [None]:
# TODO: replace this placeholder with a DataLoader over a held-out test set
# (not the training set) that yields (strain, energy, energy_derivative) batches.
dataloader = 0 # correct this to be a test set not the training set
energy_rmse = 0.0  # running sum of per-batch energy RMSE
energy_deriv_rmse = 0.0  # running sum of per-batch energy-derivative RMSE
loss_record_test = []  # total loss of each test batch, as Python floats
num_batches = 0

Model.to(my_device)
Model.eval()

# No torch.no_grad() here: Model.forward differentiates the energy with
# respect to its input, which needs autograd. The outputs are detached
# instead so no graph is kept between batches.
for databatch in tqdm(dataloader, desc="Testing", leave=False):

    strain_batch, energy_target, energy_deriv_target = databatch

    strain_batch = strain_batch.to(my_device)
    energy_target = energy_target.to(my_device)
    energy_deriv_target = energy_deriv_target.to(my_device)

    energy_pred, energy_deriv_pred = Model(strain_batch)
    energy_pred = energy_pred.detach()
    energy_deriv_pred = energy_deriv_pred.detach()

    # Compare against this batch's targets (not variables left over from the
    # training cell).
    loss_E = loss(energy_pred, energy_target)

    loss_E_Deriv = loss(energy_deriv_pred, energy_deriv_target)

    loss_total = alpha * loss_E + gamma * loss_E_Deriv

    energy_rmse_batch = torch.sqrt(torch.mean((energy_pred - energy_target) ** 2))
    energy_deriv_rmse_batch = torch.sqrt(torch.mean((energy_deriv_pred - energy_deriv_target) ** 2))

    energy_rmse += energy_rmse_batch.item()
    energy_deriv_rmse += energy_deriv_rmse_batch.item()
    num_batches += 1

    loss_record_test.append(loss_total.item())

if num_batches == 0:
    raise ValueError("The test dataloader produced no batches; cannot compute RMSE.")

# Mean of the per-batch RMSE values.
energy_rmse_mean = energy_rmse / num_batches
energy_deriv_rmse_mean = energy_deriv_rmse / num_batches

print(energy_rmse_mean, "\n", energy_deriv_rmse_mean)

plt.plot(loss_record_test)
plt.xlabel("Test batch")
plt.ylabel("Total loss")
plt.title("Test loss per batch")