Import Libraries

In [24]:
import jax
import jax.numpy as jnp
import jax.nn as jnn
from flax import nnx
import optax
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Any
import jraph
from itertools import combinations
import meshio
import numpy as np
import os

Hyperparameters

In [25]:
Epochs = 500
alpha = 1.0      # weight of the energy MSE term in loss_fn
gamma = 1.0      # weight of the energy-gradient (Huber) term in loss_fn
lambda_ = 1.0    # weight of the zero-displacement term in loss_fn (passed as `lam`)
# Adam decay rates: b1 smooths the first moment (gradient mean), b2 the second moment.
# 0.9 / 0.999 are the standard Adam values; b1 must be the smaller of the two.
beta_1 = 0.9
beta_2 = 0.999
batch_size = 40  # graphs (simulations) per batch
# Fractions of the simulations assigned to each split; they should sum to 1.
train_split = 0.9
CV_split = 0.05
test_split = 0.05
Learn_Rate = 0.001

RNG key

In [None]:
# Global seed, fixed so that results are reproducible; change it to try other initialisations.
seed = 42
base_key = jax.random.PRNGKey(seed)
# Key source for the model layers (the layers below draw their keys via rngs.params()).
rngs = nnx.Rngs(base_key)

Graph generation

In [None]:
def Get_known(boundary_points, points):
    """
    Flag which nodes have a known (prescribed) value.

    Args:
        boundary_points: indices of the known nodes.
        points: node array; only its leading dimension (node count) is used.

    Returns:
        1-D float array of length points.shape[0]: 1.0 at boundary_points, 0.0 elsewhere.
    """
    node_count = points.shape[0]
    return jnp.zeros(node_count).at[boundary_points].set(1)

def build_send_receive(cell):
    """
    List every unordered node pair of one mesh cell as a (sender, receiver) edge.

    Pairs follow itertools.combinations order, so each sender appears before its
    receiver in `cell`.

    Returns:
        (sender_array, receiver_array): two equal-length lists of node ids.
    """
    node_pairs = list(combinations(cell, 2))
    senders_out = [first for first, _ in node_pairs]
    receivers_out = [second for _, second in node_pairs]
    return senders_out, receivers_out

def build_graphs(senders, receivers, positions, boundary_points, U) -> jraph.GraphsTuple:
    """
    Build a single-graph jraph.GraphsTuple for one simulation on the mesh.

    Node features are [positions | U_applied | is_known], concatenated column-wise.
    U_applied keeps U only on the boundary nodes (zero elsewhere), and is_known is 1.0
    on those nodes.

    Args:
        senders, receivers: edge endpoint index arrays.
        positions: (num_nodes, d) node coordinates.
        boundary_points: indices of the nodes whose displacement is known.
        U: per-node displacement; assumed (num_nodes, 3) — TODO confirm against the pickle.

    Returns:
        jraph.GraphsTuple with no edge or global features.
    """
    # Get_known returns a 1-D mask. Give it a column axis so it can be concatenated with
    # the 2-D position and displacement blocks (a 1-D array cannot join an axis=1 concat).
    is_known = Get_known(boundary_points, positions)[:, None]
    U_applied = jnp.zeros_like(U).at[boundary_points].set(U[boundary_points])

    node_features = jnp.concatenate([positions, U_applied, is_known], axis=1)
    num_nodes = positions.shape[0]

    return jraph.GraphsTuple(
        nodes=node_features,
        senders=senders,
        receivers=receivers,
        edges=None,
        globals=None,
        n_node=jnp.array([num_nodes]),
        n_edge=jnp.array([len(senders)]),
    )

Import mesh data

In [58]:
# Define the path to your result file
filepath = os.path.join('data', 'vtk', 'u_final.vtu')

if not os.path.exists(filepath):
    print(f"Error: '{filepath}' not found. Please check the file path.")
else:
    mesh = meshio.read(filepath)

    positions = mesh.points
    right_face_indices = np.where(np.isclose(positions[:, 0], 1.0))[0]
    element_connectivity = mesh.cells[0].data

    unique_edges = set()

    for element in element_connectivity:
        element_senders, element_receivers = build_send_receive(element)
        
        for i in range(len(element_senders)):
            edge = tuple(sorted((element_senders[i], element_receivers[i])))
            unique_edges.add(edge)

    edge_list = jnp.array(list(unique_edges))
    senders = edge_list[:, 0]
    receivers = edge_list[:, 1]

    on_face_x0 = np.isclose(positions[:, 0], 0.0)
    on_face_x1 = np.isclose(positions[:, 0], 1.0)
    on_face_y0 = np.isclose(positions[:, 1], 0.0)
    on_face_y1 = np.isclose(positions[:, 1], 1.0)
    on_face_z0 = np.isclose(positions[:, 2], 0.0)
    on_face_z1 = np.isclose(positions[:, 2], 1.0)

    is_on_any_face = (on_face_x0 | on_face_x1 |
                      on_face_y0 | on_face_y1 |
                      on_face_z0 | on_face_z1)

    boundary_nodes = np.where(is_on_any_face)[0]

    print("Data extraction complete.\n")
    print(f"Positions array shape: {positions.shape}")
    print(f"Boundary indices array shape: {boundary_nodes.shape}")
    print(f"Senders array shape: {senders.shape}")
    print(f"Receivers array shape: {receivers.shape}")

Data extraction complete.

Positions array shape: (1331, 3)
Boundary indices array shape: (602,)
Senders array shape: (14230,)
Receivers array shape: (14230,)


Unpickling the data

In [None]:
import sys
import types
import pickle

# The pickle references a class `DataSetup.DataStore` from the simulation code. Register a
# stand-in module so unpickling succeeds without that code being importable.
fake_module = types.ModuleType("DataSetup")

class DataStore:
    """Empty stand-in so objects pickled as DataSetup.DataStore can be reconstructed."""
    def __init__(self):
        pass

fake_module.DataStore = DataStore

sys.modules["DataSetup"] = fake_module

# Relative to the notebook's working directory (like the mesh path), so the notebook
# runs on machines other than the author's.
data_file = os.path.join('data', 'simulation_results.pkl')

# NOTE: pickle.load can execute arbitrary code; only load files produced by your own simulations.
with open(data_file, "rb") as f:
    data_unpickled_1 = pickle.load(f)

dataset_dict = data_unpickled_1

# Not tunable, is known from how many sims ran
num_sims = 5
# Permutation used to shuffle simulations before splitting/batching. It is derived from
# `seed`, so changing the seed in the RNG cell also reshuffles the splits.
index_list = jnp.arange(num_sims)
permutated_index_list = jax.random.permutation(jax.random.PRNGKey(seed), index_list)

print(dataset_dict[1]['boundary_strain_energy_gradient'].shape)

(1331, 3)


Pre-processing functions — these need to be updated once pre-processing is implemented, to accommodate the data format

In [None]:
def mean_and_std_dev(data,*,train_split):
    """
    Per-feature mean and standard deviation of the training portion of `data`.

    Only the leading `int(len(data) * train_split)` rows are used, so validation and test
    rows do not leak into the scaling statistics.

    Args:
        data: array whose axis 0 indexes samples.
        train_split: fraction of leading rows treated as training data.

    Returns:
        {'mean': ..., 'std_dev': ...}, each reduced over axis 0. Zero standard deviations
        (constant features) are replaced by 1, so scale_data centres such features instead
        of dividing by zero.

    Raises:
        ValueError: if train_split selects no rows (the statistics would be NaN).
    """
    split_idx = int(data.shape[0] * train_split)
    if split_idx == 0:
        raise ValueError("train_split selects no rows; cannot compute scaling statistics")
    train_data = data[:split_idx]
    mean = jnp.mean(train_data, axis=0)
    std_dev = jnp.std(train_data, axis=0)
    std_dev = jnp.where(std_dev == 0, 1.0, std_dev)
    return {'mean':mean, 'std_dev':std_dev}

def scale_data(data,*, data_params):
    """Standardise `data` with the 'mean' / 'std_dev' entries of data_params."""
    centred = data - data_params['mean']
    return centred / data_params['std_dev']
    
def unscale_data(data,*,data_params):
    """Invert scale_data: map standardised values back to original units."""
    return data_params['mean'] + data * data_params['std_dev']

Data pre-processing and graph building

In [None]:
# Pre-processing (scaling etc.) is not implemented yet, so the raw simulation data is used as-is.
processed_dataset_dict = dataset_dict

graphs_list = []
displacements_list = []
target_e_list = []
target_e_prime_list = []

# Graph Building: one graph per simulation on the shared mesh, plus that simulation's targets.
for i in range(num_sims):
    sim_record = processed_dataset_dict[i]
    U = jnp.array(sim_record['full_displacement_vector'])
    graph = build_graphs(senders, receivers, positions, boundary_nodes, U)
    target_e = jnp.array(sim_record['strain_energy'])
    target_e_prime = jnp.array(sim_record['boundary_strain_energy_gradient'])

    graphs_list.append(graph)
    displacements_list.append(U)
    target_e_list.append(target_e)
    target_e_prime_list.append(target_e_prime)

# Parallel lists, indexed by simulation number.
dataset = dict(
    graphs_list=graphs_list,
    displacements=displacements_list,
    target_e=target_e_list,
    target_e_prime=target_e_prime_list,
)

# NOTE(review): left empty. loss_fn reads Dataset_parameters['target_e']['mean'] / ['std_dev'],
# so scaling statistics presumably belong here once pre-processing exists.
param_dict = {}

Batching functions

In [None]:
def batch_and_split_dataset(dataset_dict, batch_size, train_split, CV_split, test_split, permutated_index_list, shuffle=True):
    """
    Split simulation indices into train / test / CV sets and group each set into batches.

    Args:
        dataset_dict: mapping with keys 'graphs_list', 'displacements', 'target_e' and
            'target_e_prime'. Each value is a list (or array) indexed by simulation number.
        batch_size: maximum number of simulations per batch. The last batch of a split
            may be smaller, so a split with fewer samples than batch_size still yields a batch.
        train_split, test_split: fractions of the simulations put in the train and test
            sets (rounded down).
        CV_split: not used for sizing; the CV set receives every index left after the
            train and test sets. Kept for interface compatibility.
        permutated_index_list: 1-D sequence of simulation indices in shuffled order.
        shuffle: if True, split in the order of permutated_index_list. Otherwise split in
            index order 0..n-1, using permutated_index_list only for its length.

    Returns:
        (train_batches, CV_batches, test_batches): lists of dicts with keys 'graphs'
        (a jraph-batched GraphsTuple), 'displacements', 'target_e', 'target_e_prime'.
    """
    n_samples = len(permutated_index_list)
    n_train_samples = int(train_split * n_samples)
    n_test_samples = int(test_split * n_samples)

    # Plain Python ints, so the indices can address Python lists.
    if shuffle:
        ordered = [int(j) for j in permutated_index_list]
    else:
        ordered = list(range(n_samples))

    train_idx = ordered[:n_train_samples]
    test_idx = ordered[n_train_samples:(n_train_samples + n_test_samples)]
    CV_idx = ordered[(n_train_samples + n_test_samples):]

    def batch_indices(idx):
        # Step through idx in chunks of batch_size, keeping any trailing partial chunk.
        for start in range(0, len(idx), batch_size):
            batch_idx = idx[start:start + batch_size]
            # Index element-wise: lists cannot be indexed by an index array.
            yield {
                'graphs': jraph.batch([dataset_dict['graphs_list'][j] for j in batch_idx]),
                'displacements': jnp.stack([dataset_dict['displacements'][j] for j in batch_idx]),
                'target_e': jnp.stack([dataset_dict['target_e'][j] for j in batch_idx]),
                'target_e_prime': jnp.stack([dataset_dict['target_e_prime'][j] for j in batch_idx]),
            }

    train_batches = list(batch_indices(train_idx))
    test_batches = list(batch_indices(test_idx))
    CV_batches = list(batch_indices(CV_idx))

    return train_batches, CV_batches, test_batches

Batching graphs

In [None]:
# Split and batch the graph dataset built above. The raw pickle `dataset_dict` has no
# 'graphs_list' key, so it cannot be passed here.
train_batches, CV_batches, test_batches = batch_and_split_dataset(
    dataset,
    batch_size,
    train_split,
    CV_split,
    test_split,
    permutated_index_list,
    shuffle=True
)

Activation Functions

In [None]:
def Silu(x: jax.Array) -> jax.Array:
    """SiLU (swish) activation, x * sigmoid(x), applied element-wise."""
    return x * jax.nn.sigmoid(x)

Linear layer

In [None]:
class Linear(nnx.Module):
    """
    Applies a trainable affine transformation y = x @ w + b.

    Inputs: x: array whose last axis has size din (e.g. a (num_rows, din) matrix of row vectors)
    Return: array with the same leading shape and last axis of size dout
    Trainable Params: w: (din, dout) weight matrix (LeCun-normal init), b: (dout,) bias (zero init)
    """
    def __init__(self, din: int, dout: int,*, rngs: nnx.Rngs):
        self.din, self.dout = din, dout
        initialiser = nnx.initializers.lecun_normal()
        self.w = nnx.Param(initialiser(key=rngs.params(), shape=(din, dout)))
        # Fan-in based initialisers need an at-least-2D shape, so the 1-D bias starts at zero.
        self.b = nnx.Param(jnp.zeros((dout,)))

    def __call__(self, x: jax.Array):
        return x @ self.w + self.b

GAT Layer

In [None]:
class GAT(nnx.Module):
    """
    Single-head graph attention layer over a jraph.GraphsTuple.

    Every edge s -> r is scored by applying A, then a leaky ReLU, to the concatenation
    [W h_s, W h_r]. Scores are softmax-normalised over the edges arriving at each receiver.
    Each receiver's new feature is the attention-weighted sum of its senders' projected
    features; nodes with no incoming edges get zero features.
    """
    def __init__(self, in_features, out_features,*,rngs):
        initialiser = nnx.initializers.lecun_normal()
        # A separate key per weight, so W and A are initialised independently.
        self.W = nnx.Param(initialiser(key=rngs.params(), shape=(in_features, out_features)))
        self.A = nnx.Param(initialiser(key=rngs.params(), shape=(2 * out_features, 1)))
        # Store the functions themselves; calling them here without arguments raises TypeError.
        self.SoftMax = jraph.segment_softmax
        self.Leaky_Relu = nnx.leaky_relu

    def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:

        if graph.nodes is None:
            raise ValueError("GAT requires nodes to have features")

        # Static total node count (all graphs in a batch); segment ops need it as num_segments.
        num_nodes = graph.nodes.shape[0]

        # Project every node once, then gather per edge.
        projected = graph.nodes @ self.W
        h_sender = projected[graph.senders]
        h_receiver = projected[graph.receivers]

        send_receive_features = jnp.concatenate([h_sender, h_receiver], axis=-1)
        attention_scores = self.Leaky_Relu(send_receive_features @ self.A)

        # Normalise scores over the incoming edges of each receiver.
        attention_coefficients = self.SoftMax(
            attention_scores,
            segment_ids=graph.receivers,
            num_segments=num_nodes,
        )

        weighted_features = attention_coefficients * h_sender

        # Sum weighted messages into their receiving nodes.
        aggregate_nodes = jax.ops.segment_sum(
            weighted_features,
            graph.receivers,
            num_segments=num_nodes,
        )

        return graph._replace(nodes=aggregate_nodes)

SAGPool (work in progress)

In [None]:
class SAGPool(nnx.Module):
    """Placeholder for a SAGPool (self-attention graph pooling) layer; not implemented yet."""

Model

In [None]:
class GNN(nnx.Module):
    """
    Graph network predicting one energy per graph and its gradient w.r.t. the
    displacement node features.

    Pipeline:
      1. Linear node embedding.
      2. Three GAT layers. The first two are each followed by BatchNorm, ReLU and a residual
         connection; the third by a residual connection only.
      3. Per-graph sum of the node embeddings, then a linear decoder.
    """
    def __init__(self, input_dim: int, embedding_dim: int, output_dim: int, rngs: nnx.Rngs):
        self.embedding_layer = Linear(input_dim, embedding_dim, rngs=rngs)
        self.decoding_layer = Linear(embedding_dim, output_dim, rngs=rngs)

        # The activation function itself; nnx.relu() without an argument raises TypeError.
        self.ReLU = nnx.relu

        self.encoderL1 = GAT(embedding_dim, embedding_dim, rngs=rngs)
        self.BatchNormL1 = nnx.BatchNorm(num_features=embedding_dim, rngs=rngs)
        self.encoderL2 = GAT(embedding_dim, embedding_dim, rngs=rngs)
        self.BatchNormL2 = nnx.BatchNorm(num_features=embedding_dim, rngs=rngs)
        self.encoderL3 = GAT(embedding_dim, embedding_dim, rngs=rngs)

    def embedder(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Replace node features by their linear embedding."""
        nodes = graph.nodes
        embeddings = self.embedding_layer(nodes)
        return graph._replace(nodes=embeddings)

    def apply_activation_and_res(self, graph: jraph.GraphsTuple, residual: jax.Array) -> jraph.GraphsTuple:
        """ReLU the node features, then add the residual."""
        nodes = graph.nodes
        activated_nodes = self.ReLU(nodes) + residual
        return graph._replace(nodes=activated_nodes)

    def apply_res(self, graph: jraph.GraphsTuple, residual: jax.Array):
        """Add the residual to the node features."""
        new_nodes = graph.nodes + residual
        return graph._replace(nodes=new_nodes)

    def decoder(self, graph: jraph.GraphsTuple) -> jax.Array: # Switch to SAGPool when its finished
        """Sum node embeddings within each graph of the batch, then decode to (n_graph, output_dim)."""
        n_graph = graph.n_node.shape[0]
        num_nodes = graph.nodes.shape[0]
        # Graph id of every node, e.g. n_node = [2, 3] -> [0, 0, 1, 1, 1].
        node_graph_ids = jnp.repeat(jnp.arange(n_graph), graph.n_node, total_repeat_length=num_nodes)
        graph_sums = jax.ops.segment_sum(graph.nodes, node_graph_ids, num_segments=n_graph)
        return self.decoding_layer(graph_sums)

    def forward_pass(self, G: jraph.GraphsTuple, use_running_average: bool) -> jax.Array:
        """Energy prediction of shape (n_graph, output_dim).

        use_running_average selects BatchNorm's mode: True = stored statistics,
        False = batch statistics.
        """
        G = self.embedder(G)
        res1 = G.nodes

        G = self.encoderL1(G)
        G = G._replace(nodes=self.BatchNormL1(G.nodes, use_running_average=use_running_average))
        G = self.apply_activation_and_res(G, res1)
        res2 = G.nodes

        G = self.encoderL2(G)
        G = G._replace(nodes=self.BatchNormL2(G.nodes, use_running_average=use_running_average))
        G = self.apply_activation_and_res(G, res2)
        res3 = G.nodes

        G = self.encoderL3(G)
        G = self.apply_res(G, res3)

        e = self.decoder(G)
        return e

    def __call__(self, G: jraph.GraphsTuple, use_running_average):
        """
        Predict energies and their gradient w.r.t. the displacement node features.

        Returns:
            e: (n_graph, output_dim) predicted energies.
            e_prime: d(sum of e)/d(node features) restricted to columns 3:6, shape
                (total_nodes, 3). Assumes the node layout from build_graphs:
                [position (3), applied displacement (3), is_known (1)].
                Graphs in a batch only interact through BatchNorm batch statistics
                (use_running_average=False), so otherwise each node's gradient belongs
                to its own graph's energy.
        """
        def total_energy(model, nodes):
            # Scalar objective for differentiation; per-graph energies are returned as aux.
            e_graphs = model.forward_pass(G._replace(nodes=nodes), use_running_average)
            return jnp.sum(e_graphs), e_graphs

        # Differentiate w.r.t. the float node features only (graph index arrays are integers).
        # The nnx transform propagates BatchNorm statistic updates made inside the trace.
        (_, e), grad_nodes = nnx.value_and_grad(total_energy, argnums=1, has_aux=True)(self, G.nodes)
        e_prime = grad_nodes[:, 3:6]

        return e, e_prime

Loss function and Optimiser

In [None]:
# Adam with an L2 penalty. add_decayed_weights adds 1e-5 * params to the gradients before
# Adam rescales them (coupled weight decay, as opposed to optax.adamw's decoupled form).
# NOTE: add_decayed_weights needs the current params, i.e. optimiser.update(grads, opt_state, params).
_weight_decay = optax.add_decayed_weights(weight_decay=1e-5)
_adam = optax.adam(learning_rate=Learn_Rate, b1=beta_1, b2=beta_2)
optimiser = optax.chain(_weight_decay, _adam)

def loss_fn(graph_batch, target_e, target_e_prime,*, Model, Dataset_parameters, alpha, gamma, lam, use_running_average=False):
    """
    Calculates the loss of a model, works to minimise the mean square error of both
    the strain energy prediction and the strain energy derivative prediction,
    whilst forcing the function through zero.

    Args:
        graph_batch: batched jraph.GraphsTuple.
        target_e: target energies, one per graph.
        target_e_prime: target energy gradients. They must hold as many elements as the
            model's gradient prediction, which is reshaped to match.
        Model: callable (graph, use_running_average) -> (e, e_prime).
        Dataset_parameters: dict with Dataset_parameters['target_e']['mean'] and
            ['std_dev'], the energy scaling, used to express "zero energy" in scaled units.
        alpha, gamma, lam: weights of the energy, gradient and zero-displacement terms.
        use_running_average: BatchNorm mode forwarded to Model (False = batch statistics).

    Returns:
        Scalar weighted loss alpha * loss_e + gamma * loss_e_prime + lam * loss_zero.
    """
    prediction_e, prediction_e_prime = Model(graph_batch, use_running_average)
    # Match target shapes. E.g. (n_graph, 1) vs (n_graph,) would otherwise broadcast to (n_graph, n_graph).
    prediction_e = jnp.reshape(prediction_e, jnp.shape(target_e))
    prediction_e_prime = jnp.reshape(prediction_e_prime, jnp.shape(target_e_prime))
    loss_e = jnp.mean((prediction_e - target_e)**2)
    loss_e_prime = jnp.mean(optax.huber_loss(prediction_e_prime, target_e_prime))

    mean_e = Dataset_parameters['target_e']['mean']
    std_dev_e = Dataset_parameters['target_e']['std_dev']
    target_zero = (0 - mean_e) / std_dev_e

    # The same graphs with the applied displacement zeroed, where the energy should be zero.
    # Assumes displacement occupies node-feature columns 3:6 ([position, displacement, is_known]).
    zero_graph = graph_batch._replace(nodes=graph_batch.nodes.at[:, 3:6].set(0.0))
    prediction_zero, _ = Model(zero_graph, use_running_average)
    loss_zero = jnp.mean((prediction_zero - target_zero)**2)

    return (alpha * loss_e + gamma * loss_e_prime + lam * loss_zero)

Training Dataclass

In [None]:
@nnx.dataclass
class TrainState(nnx.Object):
    """
    Container for the split model, updated by the training loop.

    NOTE(review): confirm the installed flax version provides `nnx.dataclass`. If it does
    not, a plain `dataclasses.dataclass` would do: the training loop only reads and assigns
    these attributes in Python and never passes this object to train_step.
    """
    params: Any     # trainable parameter state from nnx.split
    graph_def: Any  # module structure from nnx.split; needed by nnx.merge
    state: Any      # remaining (non-parameter) state from nnx.split

CV loss

In [None]:
def CV_loss_fn(CV_batches, graph_def, params, state, dataset_parameters, alpha, gamma, lambda_):
    """
    Average loss_fn over the cross-validation batches.

    Args:
        CV_batches: iterable of batch dicts with keys 'graphs', 'target_e', 'target_e_prime'.
        graph_def, params, state: nnx.split pieces used to rebuild the model via nnx.merge.
        dataset_parameters: scaling statistics forwarded to loss_fn as Dataset_parameters.
        alpha, gamma, lambda_: loss-term weights forwarded to loss_fn (lambda_ as `lam`).

    Returns:
        Mean of the per-batch losses (not weighted by batch size).

    Raises:
        ValueError: if CV_batches is empty, instead of dividing by zero.
    """
    CV_batches = list(CV_batches)
    if not CV_batches:
        raise ValueError("CV_batches is empty; cannot compute a cross-validation loss")

    Model = nnx.merge(graph_def, params, state)

    # NOTE(review): loss_fn chooses the BatchNorm mode. Evaluation presumably wants the
    # running-average statistics; confirm loss_fn is called that way here.
    batch_losses = [
        loss_fn(
            CV_batch['graphs'],
            CV_batch['target_e'],
            CV_batch['target_e_prime'],
            Model=Model,
            Dataset_parameters=dataset_parameters,
            alpha=alpha,
            gamma=gamma,
            lam=lambda_
        )
        for CV_batch in CV_batches
    ]

    return sum(batch_losses) / len(batch_losses)

Train Step

In [None]:
@jax.jit
def train_step(params, graph_def, state, opt_state, GraphandTarget_batch, *, dataset_parameters, alpha, gamma, lambda_):
    """
    One optimisation step on a single batch.

    Args:
        params, graph_def, state: nnx.split pieces of the model (trainable params, structure,
            remaining state such as BatchNorm statistics).
        opt_state: optimiser state for `params`.
        GraphandTarget_batch: dict with keys 'graphs', 'target_e', 'target_e_prime'.
        dataset_parameters, alpha, gamma, lambda_: forwarded to loss_fn.

    Returns:
        (new_params, new_state, new_opt_state, loss). new_state carries the BatchNorm
        statistics updated during the forward pass.
    """
    target_e_batch = GraphandTarget_batch['target_e']
    target_e_prime_batch = GraphandTarget_batch['target_e_prime']
    graph_batch = GraphandTarget_batch['graphs']

    def wrapped_loss(params_, state_):
        Model = nnx.merge(graph_def, params_, state_)
        loss = loss_fn(
            graph_batch,
            target_e_batch,
            target_e_prime_batch,
            Model=Model,
            Dataset_parameters=dataset_parameters,
            alpha=alpha,
            gamma=gamma,
            lam=lambda_
        )
        # Extract the (possibly updated) non-parameter state so it is not lost.
        _, _, updated_state = nnx.split(Model, nnx.Param, ...)
        return loss, updated_state

    (loss, new_state), grads = jax.value_and_grad(wrapped_loss, argnums=0, has_aux=True)(params, state)
    # add_decayed_weights in the chain needs the current params.
    updates, new_opt_state = optimiser.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    return new_params, new_state, new_opt_state, loss

Train Loop

In [None]:
# Instantiate energy prediction NN
Model = GNN(
    input_dim=node_features[0], 
    embedding_dim=128,
    output_dim=1,
    rngs=rngs
)
input_dim: int, embedding_dim: int, output_dim: int, rngs: nnx.Rngs

graph_def,params,state = nnx.split(Model,nnx.Param,nnx.State)
opt_state = optimiser.init(params)

train_state = TrainState(
    graph_def=graph_def,
    params=params,
    state=state,
    )

loss_record = []

for epoch in range(Epochs):
    running_loss = 0.0
    batch_count = 0
    for batch in tqdm(train_batches, desc=f"Epoch {epoch}/{Epochs}", leave=False):

        new_params, new_state, new_opt_state, batch_loss = train_step(

        )

        CV_loss = CV_loss_fn(
            CV_batches,
            graph_def,
            new_params,
            new_state,
            dataset_parameters,
            alpha,
            gamma,
            lambda_
        )

        opt_state = new_opt_state
        train_state.params = new_params
        train_state.state = new_state


Trained model storage

In [None]:
@nnx.dataclass
class Model_storage(nnx.Object):
    """
    Container intended to hold a trained model's split pieces.

    NOTE(review): not used in the visible cells. Confirm the installed flax version
    provides `nnx.dataclass`.
    """
    params: Any     # presumably trained parameter state from nnx.split — confirm
    graph_def: Any  # presumably the module structure from nnx.split — confirm
    states: Any     # presumably the remaining non-parameter state — confirm