Import Libraries

In [None]:
import jax
import jax.numpy as jnp
import jax.nn as jnn
from flax import nnx
import optax
import matplotlib.pyplot as plt
from tqdm import tqdm

Unpickling the data

In [None]:
# Workaround for unpickling errors: a module named "DataSetup" exposing a
# `DataStore` class is registered by hand before loading, presumably because
# the pickle refers to `DataSetup.DataStore` and that module is not importable
# here. This was the quickest fix found.
import pickle
import sys
import types


class DataStore:
    """Empty stand-in class exposed as `DataSetup.DataStore` for unpickling."""

    def __init__(self):
        pass


# Build the placeholder module and make it importable under the expected name.
fake_module = types.ModuleType("DataSetup")
fake_module.DataStore = DataStore
sys.modules["DataSetup"] = fake_module

data_file = r"C:\Users\samue\Downloads\Simulation.pickle"

# NOTE: pickle can execute arbitrary code while loading; only open trusted files.
with open(data_file, "rb") as f:
    data_unpickled = pickle.load(f)

# The pickled payload is a pair: (index, data object).
data_index, data_object = data_unpickled

# List the attributes available on the restored object.
print(dir(data_object))

['DIR', 'Indata', 'Jac', 'SE', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__firstlineno__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__static_attributes__', '__str__', '__subclasshook__', '__weakref__', 'i']


Dataset

RNG key

In [None]:
# Fixed seed so that results are reproducible; change it to get a different
# random initialisation.
seed = 42

# `jax.random.PRNGkey` does not exist (the legacy constructor is spelled
# `PRNGKey`). `jax.random.key` creates a new-style typed key, which is the
# form `nnx.Rngs` accepts directly.
base_key = jax.random.key(seed)
rngs = nnx.Rngs(base_key)

Hyper Params

In [None]:
# Number of passes over the data in the training loop.
Epochs = 1_000

# Weights of the individual terms of the training loss.
alpha = 1.0    # weight of the energy error term
gamma = 1.0    # weight of the energy-derivative error term
lambda_ = 1.0  # presumably the weight of the zero-input penalty (`lam` in loss_fn)

# Settings handed to the Adam optimiser.
Learn_Rate = 1e-3
beta_1 = 0.9
beta_2 = 0.999

Node Classes and Activations

In [None]:
class Linear(nnx.Module):
    """Fully connected (affine) layer: ``y = x @ W + b``.

    Args:
        din: Size of the last axis of the input.
        dout: Size of the last axis of the output.
        rngs: NNX random-number container; its ``params`` stream supplies the
            key used to initialise the weights.
    """

    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        key = rngs.params()
        # NOTE(review): weights are drawn from U[0, 1), so they are all
        # positive and not scaled by fan-in. Consider a variance-scaling
        # initialiser if training turns out to be unstable.
        self.W = nnx.Param(jax.random.uniform(key=key, shape=(din, dout)))
        self.b = nnx.Param(jnp.zeros(shape=(dout,)))
        self.din, self.dout = din, dout

    def __call__(self, x: jax.Array):
        """Apply the affine map to ``x`` (last axis must have size ``din``)."""
        return x @ self.W + self.b
    
def SiLU(x: jax.Array):
    """Sigmoid-weighted linear unit: scales each element of ``x`` by its own sigmoid."""
    gate = jax.nn.sigmoid(x)
    return x * gate

Model Architecture

In [None]:
class energy_prediction(nnx.Module):
    """Four-layer fully connected network returning an energy and its input gradient.

    Layers 1-3 are each followed by a SiLU activation; layer 4 is the linear
    output layer.

    Args:
        dim_in: Number of input features.
        dim_hidden1_in: Width of the first hidden layer.
        dim_hidden2_in: Width of the second hidden layer.
        dim_hidden3_in: Width of the third hidden layer.
        dim_out: Width of the output layer. The gradient returned by
            ``__call__`` is that of the sum of all outputs, so it is only a
            per-sample energy derivative when ``dim_out == 1``.
        rngs: NNX random-number container used to initialise every layer.
    """

    def __init__(
        self,
        dim_in: int,
        dim_hidden1_in: int,
        dim_hidden2_in: int,
        dim_hidden3_in: int,
        dim_out: int,
        *,
        rngs: nnx.Rngs,
    ):
        # `Linear` requires `rngs` as a keyword argument, so every layer gets it.
        self.layer1 = Linear(din=dim_in, dout=dim_hidden1_in, rngs=rngs)
        self.layer2 = Linear(din=dim_hidden1_in, dout=dim_hidden2_in, rngs=rngs)
        self.layer3 = Linear(din=dim_hidden2_in, dout=dim_hidden3_in, rngs=rngs)
        self.layer4 = Linear(din=dim_hidden3_in, dout=dim_out, rngs=rngs)

    def __call__(self, x_in):
        """Return ``(e, e_prime)`` for a single sample or a batch of samples.

        Args:
            x_in: Array whose last axis has size ``dim_in``.

        Returns:
            e: Network output with size-1 axes squeezed away.
            e_prime: Derivative of the energy with respect to ``x_in``; same
                shape as ``x_in``.
        """

        def forward_pass(x):
            # `SiLU` is a plain function and is applied directly to each
            # hidden layer's output.
            x = SiLU(self.layer1(x))
            x = SiLU(self.layer2(x))
            x = SiLU(self.layer3(x))
            return self.layer4(x).squeeze()

        def summed_energy(x):
            # jax.grad needs a scalar output. Each row of the batch only
            # influences its own energy, so the gradient of the sum with
            # respect to the inputs equals the per-sample gradients.
            e = forward_pass(x)
            return jnp.sum(e), e

        # One pass yields both the energies (auxiliary output) and the gradient.
        (_, e), e_prime = jax.value_and_grad(summed_energy, has_aux=True)(x_in)

        return e, e_prime


Define optimiser and loss

In [None]:
# Adam optimiser configured from the hyper-parameters defined above.
optimiser = optax.adam(
    learning_rate=Learn_Rate,
    b1=beta_1,
    b2=beta_2,
)

def loss_fn(x: jax.Array, target_e, target_e_prime, *, model, alpha, gamma, lam):
    """
    Weighted training loss for an energy model.

    Combines three mean-squared-error terms: the energy prediction against
    ``target_e`` (weight ``alpha``), the energy-derivative prediction against
    ``target_e_prime`` (weight ``gamma``), and the energy predicted for an
    all-zero input against zero (weight ``lam``), which pushes the learned
    function through the origin.

    ``model`` must be callable and return an ``(energy, derivative)`` pair.
    """

    def mean_squared_error(predicted, expected):
        return jnp.mean((predicted - expected) ** 2)

    # Data terms: energy and its derivative on the supplied inputs.
    energy_pred, derivative_pred = model(x)
    energy_term = mean_squared_error(energy_pred, target_e)
    derivative_term = mean_squared_error(derivative_pred, target_e_prime)

    # Origin term: a single all-zero sample shaped like one row of `x`,
    # with a leading batch axis of length one.
    origin_input = jnp.expand_dims(jnp.zeros(x[0].shape), axis=0)
    origin_pred, _ = model(origin_input)
    origin_term = mean_squared_error(origin_pred, 0)

    return alpha * energy_term + gamma * derivative_term + lam * origin_term

Train State Bundle

In [None]:
class TrainState:
    """Bundle of everything a training step needs.

    All constructor arguments are stored unchanged as attributes of the same
    name: the model, its parameters, the optimiser and its state, and the
    three loss weights ``alpha``, ``gamma`` and ``lambda_``.
    """

    def __init__(self, model, params, optimiser, opt_state, alpha, gamma, lambda_):
        self.model = model
        self.params = params
        self.optimiser = optimiser
        self.opt_state = opt_state
        self.alpha = alpha
        self.gamma = gamma
        self.lambda_ = lambda_

    def replace(self, **changes):
        """Return a new state with the given fields replaced.

        The original object is left untouched. Passing a name that is not a
        constructor argument raises ``TypeError``.
        """
        fields = {
            "model": self.model,
            "params": self.params,
            "optimiser": self.optimiser,
            "opt_state": self.opt_state,
            "alpha": self.alpha,
            "gamma": self.gamma,
            "lambda_": self.lambda_,
        }
        fields.update(changes)
        return type(self)(**fields)

Train Step

In [None]:
@nnx.jit(static_argnames=("optimiser", "loss_fn"))
def _apply_gradient_step(model, opt_state, batch, alpha, gamma, lam, *, optimiser, loss_fn):
    """Jitted core of one optimisation step; updates ``model`` in place.

    ``optimiser`` and ``loss_fn`` are static arguments because they are not
    array data. Returns the new optimiser state and the loss evaluated before
    the update.
    """
    strain_in, e_target, e_prime_target = batch

    def wrapped_loss_fn(candidate):
        # `loss_fn` takes the model as a keyword-only argument.
        return loss_fn(
            strain_in,
            e_target,
            e_prime_target,
            model=candidate,
            alpha=alpha,
            gamma=gamma,
            lam=lam,
        )

    # Loss and gradients from a single pass instead of two evaluations.
    loss, grads = nnx.value_and_grad(wrapped_loss_fn)(model)

    # Optax works on the parameter pytree; the result is written back into
    # the module.
    params = nnx.state(model, nnx.Param)
    updates, new_opt_state = optimiser.update(grads, opt_state, params)
    nnx.update(model, optax.apply_updates(params, updates))

    return new_opt_state, loss


def training_step(TrainState, batch, loss_fn):
    """Perform one optimisation step on a batch.

    Args:
        TrainState: Training-state instance exposing ``model``, ``optimiser``,
            ``opt_state``, ``alpha``, ``gamma`` and ``lambda_``. The parameter
            keeps its original name even though it holds an instance.
        batch: Triple ``(strain_in, e_target, e_prime_target)``.
        loss_fn: Loss callable invoked as
            ``loss_fn(x, target_e, target_e_prime, model=..., alpha=..., gamma=..., lam=...)``.

    Returns:
        ``(new_state, loss)``: a fresh state object carrying the updated
        parameters and optimiser state, and the loss for this batch.
    """
    state = TrainState

    # Only the model, arrays and scalars go through jit; the state object
    # itself is a plain Python object and cannot be traced.
    new_opt_state, loss = _apply_gradient_step(
        state.model,
        state.opt_state,
        batch,
        state.alpha,
        state.gamma,
        state.lambda_,
        optimiser=state.optimiser,
        loss_fn=loss_fn,
    )

    # Rebuild the state through its constructor so that no `replace` method
    # is required on the state class.
    new_state = type(state)(
        model=state.model,
        params=nnx.state(state.model, nnx.Param),
        optimiser=state.optimiser,
        opt_state=new_opt_state,
        alpha=state.alpha,
        gamma=state.gamma,
        lambda_=state.lambda_,
    )
    return new_state, loss

Training Loop

In [None]:
# Instantiate energy prediction NN
Model = energy_prediction(
    dim_in=9,
    dim_hidden1_in=1024,
    dim_hidden2_in=512,
    dim_hidden3_in=124,
    dim_out=1,
    rngs=rngs,
)

# Trainable parameters of the model as a pytree that optax can work with.
params = nnx.state(Model, nnx.Param)

opt_state = optimiser.init(params)

# The loss weights are part of the state, so they must be supplied here.
train_state = TrainState(
    model=Model,
    params=params,
    optimiser=optimiser,
    opt_state=opt_state,
    alpha=alpha,
    gamma=gamma,
    lambda_=lambda_,
)

# Summed batch loss of every epoch.
loss_record = []

# NOTE(review): `dataloader` is not defined anywhere in the visible part of
# this file; it must yield (strain_in, e_target, e_prime_target) batches.
for epoch in range(1, Epochs + 1):
    running_loss = 0.0

    for batch in tqdm(dataloader, desc=f"Epoch {epoch}/{Epochs}", leave=False):
        # Pass the state *instance* and carry the returned state forward.
        train_state, loss_batch = training_step(train_state, batch, loss_fn)

        running_loss += loss_batch

    # Store a plain Python float so the record is easy to plot.
    loss_record.append(float(running_loss))