Import Libraries

In [57]:
import jax
import jax.numpy as jnp
import jax.nn as jnn
from flax import nnx
from flax import struct
import optax
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Any

Unpickling the data

In [58]:
# Due to errors I was experiencing this seems to be the quickest fix I could find to allow me to unpickle the data
import sys
import types
import pickle

fake_module = types.ModuleType("DataSetup")

class DataStore:
    def __init__(self):
        pass

fake_module.DataStore = DataStore

sys.modules["DataSetup"] = fake_module

data_file_1 = r"C:\Users\samue\Downloads\Simulation.pickle"
data_file_2 = r"C:\Users\samue\Downloads\Simulation 2.pickle"

with open(data_file_1,"rb") as f:
    data_unpickled_1 = pickle.load(f)

with open(data_file_2,"rb") as f:
    data_unpickled_2 = pickle.load(f)

_,data_object_1 = data_unpickled_1
_,data_object_2 = data_unpickled_2

input_dataset_1 = jnp.array(data_object_1.Indata)
#data_index_1 = data_object_1.i
e_dataset_1 = jnp.array(data_object_1.SE)
e_prime_dataset_1 = jnp.array(data_object_1.Jac)

input_dataset_2 = jnp.array(data_object_2.Indata)
#data_index_2 = data_object_2.i
e_dataset_2 = jnp.array(data_object_2.SE)
e_prime_dataset_2 = jnp.array(data_object_2.Jac)

print(input_dataset_2.shape)
print(input_dataset_1.shape)
print(e_dataset_1.shape)
print(e_prime_dataset_1.shape)

input_dataset = jax.numpy.concatenate([input_dataset_1,input_dataset_2],axis=0)
target_e_dataset = jax.numpy.concatenate([e_dataset_1, e_dataset_2],axis=0)
target_e_dataset = jax.numpy.expand_dims(target_e_dataset,axis=1)
target_e_prime_dataset = jax.numpy.concatenate([e_prime_dataset_1,e_prime_dataset_2],axis=0)
print(input_dataset.shape)
print(target_e_dataset.shape)
print(target_e_prime_dataset.shape)

(10000, 152, 3)
(10000, 152, 3)
(10000,)
(10000, 152, 3)
(20000, 152, 3)
(20000, 1)
(20000, 152, 3)


Redimensionalise

In [59]:
def Redimensionalise(self):
    self.Disp = jnp.zeros((self.Dims,self.Dims,self.Dims,3))
    m = 0
    for i in range(self.Dims):
        for j in range(self.Dims):
            for k in range(self.Dims):
                if self.xInMesh[0][i,j,k] == 0 or self.xInMesh[0][i,j,k] == 1 or self.xInMesh[1][i,j,k] == 0 or self.xInMesh[1][i,j,k] == 1 or self.xInMesh[2][i,j,k] == 0 or self.xInMesh[2][i,j,k] == 1:
                    self.Disp[i,j,k,:] = self.RandDisp[self.Index,m,:]
                    m = m +1
    return self.Disp
    

RNG key

In [60]:
seed = 42 # This can be changed but is here to make the results easy to reproduce

base_key = jax.random.PRNGKey(seed)
rngs = nnx.Rngs(base_key)

Hyper Params

In [61]:
Epochs = 10
alpha = 1.0
gamma = 1.0
lambda_ = 1.0
Learn_Rate = 0.001
beta_1 = 0.9
beta_2 = 0.999
Batch_size = 10

Dataset

In [62]:
batch_num = input_dataset.shape[0] // Batch_size

# input_dataset target_e_dataset target_e_prime_dataset 

input_dataset = input_dataset.reshape((20000,456))
target_e_dataset = target_e_dataset.reshape((20000,))
target_e_prime_dataset = target_e_prime_dataset.reshape((20000,456))

Dataset = {
    'displacements':input_dataset,
    'target_e':target_e_dataset,
    'target_e_prime':target_e_prime_dataset
}


Node Classes and Acivations

In [63]:
class Linear(nnx.Module):
    """Linear node for neural network"""

    def __init__(self,din: int,dout: int,*,rngs: nnx.Rngs):
        key = rngs.params()
        self.W = nnx.Param(jax.random.uniform(key=key, shape=(din,dout)))
        self.b = nnx.Param(jnp.zeros(shape=(dout,)))
        self.din, self.dout = din, dout

    def __call__(self,x: jax.Array):
        return(x @ self.W + self.b)
    
def SiLU(x: jax.Array):
    """Sigmoid Weighted Linear Unit activation function"""
    return x * jax.nn.sigmoid(x)

Model Architecture

In [64]:
class energy_prediction(nnx.Module):
    """Model architecture"""

    def __init__(self,dim_in: int, dim_hidden1_in: int, dim_hidden2_in: int,dim_hidden3_in, dim_out: int,*,rngs: nnx.Rngs):
        self.layer1 = Linear(din=dim_in,dout=dim_hidden1_in,rngs=rngs)
        self.layer2 = Linear(din=dim_hidden1_in,dout=dim_hidden2_in,rngs=rngs)
        self.layer3 = Linear(din=dim_hidden2_in,dout=dim_hidden3_in,rngs=rngs)
        self.layer4 = Linear(din=dim_hidden3_in,dout=dim_out,rngs=rngs)
        self.silu = SiLU
        
    def __call__(self,x_in):
        # pass to calculate e
        def forwardPass(x):
            x = self.layer1(x)
            x = self.silu(x)
            x = self.layer2(x)
            x = self.silu(x)
            x = self.layer3(x)
            x = self.silu(x)
            x = self.layer4(x)
            return x.squeeze(-1)
        
        e = forwardPass(x_in)
        dedx = jax.vmap(jax.grad(forwardPass,argnums=(0)))
        e_prime = dedx(x_in)

        return e, e_prime


Define optimiser and loss

In [65]:
optimiser = optax.adam(learning_rate=Learn_Rate, b1=beta_1, b2=beta_2)

def loss_fn(x: jax.Array, target_e, target_e_prime,*, Model,alpha,gamma,lam): 
    """
    Calculates the loss of a model, works to minimise the mean square error of both 
    the strain energy prediction and the strain energy derivative prediction,
    whilst forcing the function through zero.
    """
    prediction_e, prediction_e_prime = Model(x)
    loss_e = jnp.mean((prediction_e - target_e)**2)
    loss_e_prime = jnp.mean((prediction_e_prime - target_e_prime)**2)

    target_zero = 0
    x_zero = jnp.zeros(x[0].shape)
    x_zero = jnp.expand_dims(x_zero, axis=0)
    prediction_zero, _ = Model(x_zero)
    loss_zero = jnp.mean((prediction_zero - target_zero)**2)

    return (alpha * loss_e + gamma * loss_e_prime + lam * loss_zero)

Train State Bundle

In [66]:
@nnx.dataclass
class TrainState(nnx.Object):
    params: Any
    graph_def: Any 
    state: Any
    alpha: float 
    gamma: float 
    lambda_: float 

Train Step

In [67]:
@nnx.jit
def training_step(params,state,opt_state,batch,*,graph_def,alpha,gamma,lambda_):

    disp_in = batch['displacements']
    e_target = batch['target_e']
    e_prime_target = batch['target_e_prime']

    def wrapped_loss_fn(params_,state_):
        Model, new_state = nnx.merge(graph_def,params_,state_)
        loss = loss_fn(
            disp_in,
            e_target,
            e_prime_target,
            model=Model,
            alpha=alpha,
            gamma=gamma,
            lam=lambda_
        )
        return loss, new_state

    (loss, new_state), grads = nnx.value_and_grad(wrapped_loss_fn, argnums=0,has_aux=True)(params, state) 
    updates, new_opt_state = optimiser.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    return new_params, new_state, new_opt_state, loss

Batch Creator and test set creator

In [68]:
def split_and_batch_dataset(dataset, batch_size, test_split=0.2, shuffle=True):
    """
    Splits the dataset into training and test sets, then yields batches for each.
    Returns: (train_batches, test_batches)
    """
    N = dataset['displacements'].shape[0]
    indices = jnp.arange(N)
    if shuffle:
        indices = jax.random.permutation(jax.random.PRNGKey(0), indices)
    split_idx = int(N * (1 - test_split))
    train_idx = indices[:split_idx]
    test_idx = indices[split_idx:]

    def batch_indices(idx):
        for start in range(0, len(idx), batch_size):
            end = start + batch_size
            batch_idx = idx[start:end]
            yield {key: value[batch_idx] for key, value in dataset.items()}

    train_batches = list(batch_indices(train_idx))
    test_batches = list(batch_indices(test_idx))
    return train_batches, test_batches

Create test and train batches

In [69]:
train_batches, test_batches = split_and_batch_dataset(
    Dataset, 
    Batch_size, 
    test_split=0.2, 
    shuffle=True
)

Training Loop

In [70]:
# Instantiate energy prediction NN
Model = energy_prediction(
    dim_in=input_dataset.shape[1], 
    dim_hidden1_in=2024,
    dim_hidden2_in=1012,
    dim_hidden3_in=212, 
    dim_out=1,
    rngs=rngs
)

graph_def,params,state = nnx.split(Model,nnx.Param,nnx.State)
opt_state = optimiser.init(params)

train_state = TrainState(
    graph_def=graph_def,
    params=params,
    state=state,
    alpha=alpha,
    gamma=gamma,
    lambda_=lambda_
    )

loss_record = []

for epoch in range(Epochs):
    running_loss = 0.0
    batch_count = 0

    for batch in tqdm(train_batches,desc=f"Epoch {epoch}/{Epochs}", leave=False):
        print(f"DEBUG: The variable 'loss_fn' is actually: {loss_fn}")
        new_params, new_state, new_opt_state, loss_batch = training_step(
            train_state.params,
            train_state.state,
            opt_state,
            batch,
            graph_def=train_state.graph_def,
            alpha=train_state.alpha,
            gamma=train_state.gamma,
            lambda_=train_state.lambda_
        )

        opt_state = new_opt_state

        train_state = train_state.replace(
            graph_def=train_state.graph_def,
            params=new_params,
            state=new_state,
            alpha=train_state.alpha,
            gamma=train_state.gamma,
            lambda_=train_state.lambda_
        )

        running_loss += loss_batch
        batch_count += 1
    
    avg_loss = avg_loss = running_loss / batch_count if batch_count > 0 else 0.0
    loss_record.append(avg_loss)

                                                    

DEBUG: The variable 'loss_fn' is actually: <function loss_fn at 0x00000200D6FD8360>




TypeError: cannot unpack non-iterable energy_prediction object

Final model storage

In [None]:
@nnx.dataclass
class ModelData(nnx.Object):
    model_def: Any
    params: Any
    trained: bool

Create Final model instance

In [None]:
model_def_trained = train_state.model_def
params_trained = train_state.params

model_data = ModelData(
    model_def=model_def_trained,
    params=params_trained,
    trained=True
)

Plots

In [None]:
plt.plot(loss_record)

Eval State

In [None]:
test_model = train_state.model

class Eval_state:
    def __init__(self,model,params,alpha,gamma,lambda_):
        self.model = model
        self.params = params
        self.alpha = alpha
        self.gamma = gamma
        self.lambda_ = lambda_

evaluation_state = Eval_state(
    model=test_model,
    params=1,
    alpha=alpha,
    gamma=gamma,
    lambda_=lambda_
)

Model Testing

In [None]:
def avg_abs_error(pred,target):
    n1 = pred.shape[0]
    n2 = target.shape[0]

    if n1 != n2:
        raise("Error: inputs must have matching shape")
    
    return (jnp.sum(jnp.abs(pred - target)) / n1)

def test_model(evaluation_state,Batch_size,*,loss_fn):
    test_model = evaluation_state.model
    alpha = evaluation_state.alpha
    gamma = evaluation_state.gamma
    lambda_ = evaluation_state.lambda_
    
    loss_test = 0.0

    for batch in test_batches:
        displacements_test = batch['displacements']
        e_target_test = batch['target_e']
        e_prime_target_test = batch['target_e_prime']

        e_pred_test, e_prime_pred_test = test_model(displacements_test)

        batch_loss_test = loss_fn(
            displacements_test,
            e_target_test,
            e_prime_target_test,
            test_model,
            alpha=alpha,
            gamma=gamma,
            lam=lambda_
        )

        loss_test += batch_loss_test

        avg_e_abs_error = avg_abs_error(e_pred_test,e_target_test)
        avg_e_prime_abs_error = avg_abs_error(e_prime_pred_test,e_prime_target_test)

    avg_loss_test = loss_test / Batch_size
    zero_val_e,zero_val_e_prime = test_model(jnp.zeros())
    test_e_zero_error = avg_abs_error(zero_val_e, jnp.zeros_like(zero_val_e))
    test_e_prime_zero_error = avg_abs_error(zero_val_e_prime, jnp.zeros_like(zero_val_e_prime))

    return avg_loss_test, avg_e_abs_error, avg_e_prime_abs_error, test_e_zero_error, test_e_prime_zero_error

