Import Libraries

In [None]:
import jax
import jax.numpy as jnp
import jax.nn as jnn
from flax import nnx
import optax
import matplotlib.pyplot as plt

Unpickling the data

In [None]:
# Workaround for unpickling: the pickle file refers to a module named
# "DataSetup" that is not importable here, so an empty stand-in module object
# is created now and populated/registered further down this cell.
import pickle
import sys
import types

fake_module = types.ModuleType("DataSetup")

class DataStore:
    """Empty placeholder for the ``DataSetup.DataStore`` class the pickle refers to.

    The class defines no attributes of its own, so any data attributes on an
    unpickled instance come from the state stored in the pickle file.
    """

    def __init__(self):
        """Create an instance with no attributes."""
        return None

import os

# Expose the placeholder class as DataSetup.DataStore and register the stand-in
# module so that unpickling can resolve a DataSetup.DataStore reference.
fake_module.DataStore = DataStore
sys.modules["DataSetup"] = fake_module

# Path to the simulation pickle. Set the SIMULATION_PICKLE environment variable
# to point elsewhere; the default is the original hard-coded location.
data_file = os.environ.get(
    "SIMULATION_PICKLE", r"C:\Users\samue\Downloads\Simulation.pickle"
)

# SECURITY: pickle.load can execute arbitrary code contained in the file.
# Only load pickles that come from a trusted source.
with open(data_file, "rb") as f:
    data_unpickled = pickle.load(f)

# The pickled payload is unpacked as a two-item sequence: an index and the
# DataStore instance holding the data.
data_index, data_object = data_unpickled

print(dir(data_object))

['DIR', 'Indata', 'Jac', 'SE', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__firstlineno__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__static_attributes__', '__str__', '__subclasshook__', '__weakref__', 'i']


Dataset

RNG key

In [None]:
seed = 42  # Change for a different, but still reproducible, initialisation

# The JAX constructor is PRNGKey (capital K); ``PRNGkey`` raises AttributeError.
base_key = jax.random.PRNGKey(seed)
# nnx.Rngs wraps the key; layers draw their own keys from it via rngs.params().
rngs = nnx.Rngs(base_key)

Hyperparameters

In [None]:
# Number of training epochs
Epochs = 1000

# Loss weights: alpha scales the energy error, gamma the derivative error
alpha, gamma = 1.0, 1.0

# Adam optimiser settings: step size and first/second moment decay rates
Learn_Rate = 1e-3
beta_1, beta_2 = 0.9, 0.999

Node Classes and Activations

In [None]:
class Linear(nnx.Module):
    """Fully connected layer computing ``x @ W + b``.

    Args:
        din: Input feature size (rows of ``W``).
        dout: Output feature size (columns of ``W`` and length of ``b``).
        rngs: Random-number streams; one key is drawn from ``rngs.params()``
            to initialise ``W``.
    """

    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        key = rngs.params()
        # NOTE(review): weights are drawn from U[0, 1): all positive and not
        # scaled by din. Stacking several such layers can give very large
        # activations; consider a scaled initialiser before serious training.
        self.W = nnx.Param(jax.random.uniform(key=key, shape=(din, dout)))
        self.b = nnx.Param(jnp.zeros(shape=(dout,)))
        self.din, self.dout = din, dout

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the affine map; the last axis of ``x`` must have size ``din``."""
        return x @ self.W + self.b
    
def SiLU(x: jax.Array):
    """Sigmoid-weighted linear unit: ``x`` scaled elementwise by ``sigmoid(x)``."""
    gate = jax.nn.sigmoid(x)
    return gate * x

Model Architecture

In [None]:
class energy_prediction(nnx.Module):
    """Four-layer SiLU MLP predicting an energy and its derivative w.r.t. the input.

    Args:
        dim_in: Number of input features.
        dim_hidden1_in: Width of the first hidden layer.
        dim_hidden2_in: Width of the second hidden layer.
        dim_hidden3_in: Width of the third hidden layer.
        dim_out: Output width; expected to be 1 (a single energy). The
            derivative is taken of the summed outputs.
        rngs: Random-number streams shared by all layers for initialisation.
    """

    def __init__(self, dim_in: int, dim_hidden1_in: int, dim_hidden2_in: int, dim_hidden3_in: int, dim_out: int, *, rngs: nnx.Rngs):
        # Linear requires the keyword-only rngs argument to initialise weights.
        self.layer1 = Linear(din=dim_in, dout=dim_hidden1_in, rngs=rngs)
        self.layer2 = Linear(din=dim_hidden1_in, dout=dim_hidden2_in, rngs=rngs)
        self.layer3 = Linear(din=dim_hidden2_in, dout=dim_hidden3_in, rngs=rngs)
        self.layer4 = Linear(din=dim_hidden3_in, dout=dim_out, rngs=rngs)
        # SiLU is a plain function taking the activations, so store a reference
        # to it rather than calling it with no arguments.
        self.silu = SiLU

    def __call__(self, x_in):
        """Return ``(e, e_prime)``.

        ``e`` is the predicted energy (trailing size-1 axes squeezed).
        ``e_prime`` is ``d e / d x_in`` and has the same shape as ``x_in``.
        Works for a single sample ``(dim_in,)`` or a batch ``(B, dim_in)``.
        """

        def forward(x):
            x = self.silu(self.layer1(x))
            x = self.silu(self.layer2(x))
            x = self.silu(self.layer3(x))
            return self.layer4(x).squeeze()

        def summed_energy(x):
            e = forward(x)
            # jax.grad needs a scalar output. Rows of a batch never interact in
            # this MLP, so the gradient of the summed energy equals each
            # sample's own derivative. The per-sample energies go out as aux data.
            return jnp.sum(e), e

        # One forward + backward pass yields both the energy and its derivative.
        (_, e), e_prime = jax.value_and_grad(summed_energy, has_aux=True)(x_in)
        return e, e_prime

# Instantiate the model.
# NOTE(review): energy_prediction.__init__ requires dim_in, dim_hidden1_in,
# dim_hidden2_in, dim_hidden3_in, dim_out and the keyword-only rngs, so this
# no-argument call raises TypeError. TODO: pass layer sizes matching the data,
# e.g. energy_prediction(D_IN, H1, H2, H3, 1, rngs=rngs).
Model = energy_prediction()

Define optimiser and loss

In [None]:
# Adam optimiser; step size and moment-decay rates come from the hyperparameter cell.
optimiser = optax.adam(
    learning_rate=Learn_Rate,
    b1=beta_1,
    b2=beta_2,
)

class MSE_loss:
    """Weighted mean-squared error on an energy and its input derivative.

    loss = alpha * mean((e_pred - e_true)**2)
         + gamma * mean((e_prime_pred - e_prime_true)**2)

    Args:
        model_class: Kept for backward compatibility. It is not needed when the
            model instance itself is passed to ``__call__``.
        alpha: Weight of the energy term.
        gamma: Weight of the derivative term.
    """

    def __init__(self, model_class, alpha, gamma):
        self.model_class = model_class
        self.alpha = alpha
        self.gamma = gamma

    def __call__(self, params, x, targets):
        """Return the scalar weighted loss.

        Args:
            params: The model itself (any callable returning ``(e, e_prime)``,
                such as an ``energy_prediction`` instance, which is what
                ``nnx.grad``/``nnx.value_and_grad`` differentiate), or a
                ``(graphdef, state, ...)`` tuple as produced by ``nnx.split``.
            x: Model input.
            targets: Indexable pair; ``targets[0]`` holds the true energies and
                ``targets[1]`` the true derivatives.
        """
        if callable(params):
            model = params
        else:
            # Rebuild the module from its split (graphdef, state) form.
            model = nnx.merge(*params)
        predictions = model(x)
        loss_e = jnp.mean((predictions[0] - targets[0]) ** 2)
        # Compare predicted derivatives with the *target* derivatives; the
        # previous version subtracted the prediction from itself (always 0).
        loss_e_prime = jnp.mean((predictions[1] - targets[1]) ** 2)
        return self.alpha * loss_e + self.gamma * loss_e_prime

# Loss object combining energy and derivative errors, weighted by alpha and gamma.
loss = MSE_loss(
    model_class=energy_prediction,
    alpha=alpha,
    gamma=gamma,
)