Import Libraries

In [None]:
import jax
import jax.numpy as jnp
import jax.nn as jnn
from flax import nnx
import optax
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Any
import jraph
from itertools import combinations

RNG key

In [None]:
# Fixed seed so every run is reproducible; change it to try a different initialisation.
seed = 42
base_key = jax.random.PRNGKey(seed)  # root PRNG key for the notebook
rngs = nnx.Rngs(base_key)  # nnx.Rngs wrapper around the root key, passed to every layer

Import data

Standardisation functions

In [None]:
def mean_and_std_dev(data, *, train_split):
    """
    Compute per-feature mean and standard deviation from the training portion of data.

    Only the leading ``int(data.shape[0] * train_split)`` rows are used, so the
    statistics never see validation/test samples.

    Args:
        data: array whose first axis indexes samples.
        train_split: fraction (0-1] of leading rows treated as training data.

    Returns:
        dict with keys 'mean' and 'std_dev', each reduced over axis 0. Features
        with zero spread get a std_dev of 1.0, so scaling only centres them.

    Raises:
        ValueError: if train_split selects no rows.
    """
    split_idx = int(data.shape[0] * train_split)
    if split_idx < 1:
        raise ValueError(
            f"train_split={train_split} selects no rows from {data.shape[0]} samples"
        )
    train_data = data[:split_idx]

    mean = jnp.mean(train_data, axis=0)
    std_dev = jnp.std(train_data, axis=0)
    # A constant feature has zero spread; dividing by it in scale_data would
    # produce inf/NaN, so fall back to 1 for those features.
    std_dev = jnp.where(std_dev > 0, std_dev, 1.0)
    return {'mean': mean, 'std_dev': std_dev}

def scale_data(data,*, data_params):
    """
    Standardise data using previously computed statistics.

    Args:
        data: array to standardise.
        data_params: dict with 'mean' and 'std_dev' entries (as returned by
            mean_and_std_dev), broadcastable against data.

    Returns:
        (data - mean) / std_dev
    """
    centred = data - data_params['mean']
    return centred / data_params['std_dev']
    

def unscale_data(data,*,data_params):
    """
    Invert scale_data: map standardised values back to the original units.

    Args:
        data: standardised array.
        data_params: dict with 'mean' and 'std_dev' entries, broadcastable
            against data.

    Returns:
        data * std_dev + mean
    """
    rescaled = data * data_params['std_dev']
    return rescaled + data_params['mean']

Functions to build a graph from a mesh

In [None]:
def Get_known(cells, points):
    """
    Flag the mesh points that lie on the outer boundary of a hexahedral mesh.

    A face that belongs to exactly one cell is a boundary face. A face shared
    by two cells is interior. Every point on a boundary face is treated as
    having a known displacement.

    Args:
        cells: iterable of 8-node hexahedral cells, each a sequence of point
            indices ordered so that ``faces_of_cube`` below describes its faces.
        points: array whose first axis indexes the mesh points.

    Returns:
        is_known: float array of shape (points.shape[0],), 1.0 for boundary
            points and 0.0 otherwise.
        edge_points: set of int indices of the boundary points.
    """
    from collections import Counter

    # Local point indices of the six faces of a hexahedral cell.
    faces_of_cube = ((0, 1, 5, 4), (1, 2, 6, 5), (2, 3, 7, 6),
                     (3, 0, 4, 7), (0, 1, 2, 3), (4, 5, 6, 7))

    # Key each face by its sorted point indices. Neighbouring cells list a
    # shared face in a different order, so the raw sequence cannot be used.
    # int() also makes indices hashable when cells is a JAX/NumPy array.
    face_counts = Counter()
    for cell in cells:
        for face_points in faces_of_cube:
            face = tuple(sorted(int(cell[i]) for i in face_points))
            face_counts[face] += 1

    # Faces that appear only once are on the boundary; collect their points.
    edge_points = set()
    for face, count in face_counts.items():
        if count == 1:
            edge_points.update(face)

    # JAX arrays are immutable, so set the flags with .at[].set() rather than
    # item assignment.
    is_known = jnp.zeros(points.shape[0])
    if edge_points:
        is_known = is_known.at[jnp.array(sorted(edge_points))].set(1.0)
    return is_known, edge_points

def build_send_receive(cells):
    """
    Build graph edges connecting every pair of distinct nodes that share a cell.

    Each connection is emitted in both directions (sender -> receiver and
    receiver -> sender). Attention/aggregation layers only pass messages
    towards receivers, so both directions are needed for an undirected mesh.
    Edges shared by neighbouring cells are deduplicated.

    Args:
        cells: iterable of cells, each a sequence of node indices.

    Returns:
        (sender_array, receiver_array): two equal-length lists of int node
        indices, sorted by (sender, receiver).
    """
    edges = set()
    for cell in cells:
        # int() keeps indices hashable when cells is a JAX/NumPy array.
        node_ids = [int(node) for node in cell]
        for a, b in combinations(node_ids, 2):
            if a != b:  # skip self-loops from degenerate (collapsed) cells
                edges.add((a, b))
                edges.add((b, a))

    ordered = sorted(edges)
    sender_array = [s for s, _ in ordered]
    receiver_array = [r for _, r in ordered]
    return sender_array, receiver_array

def build_graph(cells, points, U) -> jraph.GraphsTuple:
    """
    Build a single-graph jraph.GraphsTuple from a hexahedral mesh.

    Node features are the columns of ``points``, then ``U``, then one
    boundary flag from Get_known, concatenated along axis 1. Edges come from
    build_send_receive.

    Args:
        cells: iterable of cells, each a sequence of point indices.
        points: 2-D array with one row per mesh point.
        U: 2-D array with one row per mesh point (assumed per-point
            displacement — confirm against the data loader).

    Returns:
        jraph.GraphsTuple with integer ``senders``/``receivers`` arrays and
        ``n_node``/``n_edge`` of length 1.
    """
    num_nodes = points.shape[0]
    is_known, _ = Get_known(cells, points)
    # is_known has one flag per point; give it a trailing feature axis so it
    # can be concatenated column-wise with the 2-D point/U features.
    is_known = jnp.reshape(is_known, (num_nodes, 1))
    node_features = jnp.concatenate([points, U, is_known], axis=1)

    sender_array, receiver_array = build_send_receive(cells)
    # JAX arrays cannot be indexed with Python lists, so convert to int arrays.
    senders = jnp.asarray(sender_array, dtype=jnp.int32)
    receivers = jnp.asarray(receiver_array, dtype=jnp.int32)

    return jraph.GraphsTuple(
        nodes=node_features,
        senders=senders,
        receivers=receivers,
        edges=None,
        globals=None,
        n_node=jnp.array([num_nodes]),
        n_edge=jnp.array([senders.shape[0]]),
    )


Activation Functions

In [None]:
def Silu(x: jax.Array) -> jax.Array:
    """
    Sigmoid-weighted linear unit, applied element-wise: ``x * sigmoid(x)``.

    Annotated with ``jax.Array``, the JAX array type. The earlier
    ``nnx.Array`` annotation is evaluated at definition time and is not a
    documented flax.nnx export.
    """
    return x * jnn.sigmoid(x)

Linear layer

In [None]:
class Linear(nnx.Module):
    """
    Apply a trainable affine transformation ``x @ w + b``.

    Inputs: x: array whose last axis has size ``din`` (e.g. one row per node)
    Return: array with the last axis mapped to size ``dout``
    Trainable Params:
        w: (din, dout) weight matrix, LeCun-normal initialised
        b: (dout,) bias vector, zero initialised
    """
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        self.din, self.dout = din, dout
        initialiser = nnx.initializers.lecun_normal()
        self.w = nnx.Param(initialiser(key=rngs.params(), shape=(din, dout)))
        # LeCun normal is a variance-scaling initialiser and rejects 1-D
        # shapes (no fan-in). A zero bias is the conventional choice and also
        # avoids reusing the weight key.
        self.b = nnx.Param(jnp.zeros((dout,)))

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.w + self.b

GAT Layer

In [None]:
class GAT(nnx.Module):
    """
    Single-head graph attention layer.

    Each node's features are projected by W. Every edge (sender -> receiver)
    gets a score LeakyReLU([h_sender, h_receiver] @ A). Scores are
    softmax-normalised over the edges arriving at each receiver. Each node's
    new features are the attention-weighted sum of its senders' projected
    features.

    Args:
        in_features: size of the input node feature axis.
        out_features: size of the output node feature axis.
        rngs: nnx.Rngs supplying parameter keys.
    """
    def __init__(self, in_features, out_features, *, rngs):
        initialiser = nnx.initializers.lecun_normal()
        # Separate keys so W and A are initialised independently.
        self.W = nnx.Param(initialiser(key=rngs.params(), shape=(in_features, out_features)))
        self.A = nnx.Param(initialiser(key=rngs.params(), shape=(2 * out_features, 1)))
        # Store the functions themselves; calling them here with no arguments
        # raises a TypeError.
        self.SoftMax = jraph.segment_softmax
        self.Leaky_Relu = nnx.leaky_relu

    def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Return ``graph`` with its nodes replaced by the attention-aggregated features."""
        if graph.nodes is None:
            raise ValueError("GAT requires nodes to have features")

        # Static total node count (across all graphs in a batch), used as the
        # number of segments for the per-receiver softmax and sum.
        num_nodes = graph.nodes.shape[0]

        # Project once per node, then gather per edge.
        h = graph.nodes @ self.W
        h_sender = h[graph.senders]
        h_receiver = h[graph.receivers]

        send_receive_features = jnp.concatenate([h_sender, h_receiver], axis=-1)
        attention_scores = self.Leaky_Relu(send_receive_features @ self.A)  # (num_edges, 1)

        attention_coefficients = self.SoftMax(
            logits=attention_scores,
            segment_ids=graph.receivers,
            num_segments=num_nodes,
        )

        weighted_features = attention_coefficients * h_sender

        # Sum each receiver's incoming weighted messages. Nodes with no
        # incoming edges end up with zeros.
        aggregate_nodes = jax.ops.segment_sum(
            weighted_features, graph.receivers, num_segments=num_nodes
        )

        return graph._replace(nodes=aggregate_nodes)

SAGPool (work in progress)

In [None]:
class SAGPool(nnx.Module):
    """
    Placeholder for a SAGPool (self-attention graph pooling) layer.

    Not implemented yet. GNN.decoder currently pools with a plain sum over
    nodes instead.
    """

Model

In [None]:
class GNN(nnx.Module):
    """
    Graph network predicting a per-graph output and its gradient with
    respect to node features.

    Pipeline:
        1. Linear embedding.
        2. Three GAT layers with residual connections; the first two are
           followed by BatchNorm and ReLU.
        3. Sum over each graph's nodes.
        4. Linear decoding layer.
    """
    def __init__(self, input_dim: int, embedding_dim: int, output_dim: int, rngs: nnx.Rngs):
        self.embedding_layer = Linear(input_dim, embedding_dim, rngs=rngs)
        self.decoding_layer = Linear(embedding_dim, output_dim, rngs=rngs)

        # Store the activation function itself; nnx.relu() with no argument
        # raises a TypeError.
        self.ReLU = nnx.relu

        self.encoderL1 = GAT(embedding_dim, embedding_dim, rngs=rngs)
        self.BatchNormL1 = nnx.BatchNorm(num_features=embedding_dim, rngs=rngs)
        self.encoderL2 = GAT(embedding_dim, embedding_dim, rngs=rngs)
        self.BatchNormL2 = nnx.BatchNorm(num_features=embedding_dim, rngs=rngs)
        self.encoderL3 = GAT(embedding_dim, embedding_dim, rngs=rngs)

    def embedder(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Map raw node features to embedding_dim with the embedding layer."""
        return graph._replace(nodes=self.embedding_layer(graph.nodes))

    def apply_activation_and_res(self, graph: jraph.GraphsTuple, residual: jax.Array) -> jraph.GraphsTuple:
        """Apply ReLU to the node features, then add the residual."""
        return graph._replace(nodes=self.ReLU(graph.nodes) + residual)

    def apply_res(self, graph: jraph.GraphsTuple, residual: jax.Array) -> jraph.GraphsTuple:
        """Add the residual to the node features (no activation)."""
        return graph._replace(nodes=graph.nodes + residual)

    def decoder(self, graph: jraph.GraphsTuple) -> jax.Array:  # Switch to SAGPool when it's finished
        """
        Sum node embeddings per graph and decode them.

        Returns:
            array of shape (number of graphs in the batch, output_dim).
        """
        num_nodes = graph.nodes.shape[0]
        num_graphs = graph.n_node.shape[0]
        # Graph index of every node, derived from the per-graph node counts.
        graph_ids = jnp.repeat(
            jnp.arange(num_graphs), graph.n_node, total_repeat_length=num_nodes
        )
        pooled = jax.ops.segment_sum(graph.nodes, graph_ids, num_segments=num_graphs)
        return self.decoding_layer(pooled)

    def forward_pass(self, G: jraph.GraphsTuple, use_running_average: bool) -> jax.Array:
        """Run embedding, the three GAT layers and decoding; return per-graph outputs."""
        G = self.embedder(G)
        res1 = G.nodes

        # The BatchNorm mode is passed per call rather than written onto the
        # layer, so calling the model does not change its configuration.
        G = self.encoderL1(G)
        G = G._replace(nodes=self.BatchNormL1(G.nodes, use_running_average=use_running_average))
        G = self.apply_activation_and_res(G, res1)
        res2 = G.nodes

        G = self.encoderL2(G)
        G = G._replace(nodes=self.BatchNormL2(G.nodes, use_running_average=use_running_average))
        G = self.apply_activation_and_res(G, res2)
        res3 = G.nodes

        G = self.encoderL3(G)
        G = self.apply_res(G, res3)

        return self.decoder(G)

    def __call__(self, G: jraph.GraphsTuple, use_running_average):
        """
        Return (e, e_prime).

        e is the per-graph output of forward_pass. e_prime is the gradient of
        the summed output with respect to node feature columns 4:7.

        The gradient is taken with respect to node features only, because
        GraphsTuple also carries integer senders/receivers that cannot be
        differentiated. nnx.grad is used so the module's BatchNorm state
        updates are handled inside the transform. e and its gradient come
        from a single forward pass. In training mode, BatchNorm batch
        statistics couple the graphs in a batch, so each node's gradient can
        also depend on other graphs.
        """
        def energy(model, nodes):
            e = model.forward_pass(G._replace(nodes=nodes), use_running_average)
            return jnp.sum(e), e

        node_grads, e = nnx.grad(energy, argnums=1, has_aux=True)(self, G.nodes)
        # NOTE(review): with build_graph's layout [points, U, is_known] and
        # 3-D points, U occupies columns 3:6; confirm that 4:7 is intended.
        e_prime = node_grads[:, 4:7]

        return e, e_prime

Optimiser and loss function

In [None]:
# NOTE(review): Learn_Rate, beta_1 and beta_2 are not defined in the visible
# cells; they must be set in an earlier cell before this one runs.
optimiser = optax.chain(
    # Adds 1e-5 * params to the gradients before they reach Adam
    # (decay applied ahead of Adam's adaptive scaling).
    optax.add_decayed_weights(weight_decay=1e-5),
    optax.adam(learning_rate=Learn_Rate, b1=beta_1, b2=beta_2),
)

def loss_fn(graph_batch, target_e, target_e_prime, *, Model, Dataset_parameters, alpha, gamma,
            lam, use_running_average=False):
    """
    Weighted loss for an energy model and its derivative.

    The loss combines three terms:
        * alpha * MSE between predicted and target energy,
        * gamma * Huber loss between predicted and target derivative,
        * lam * MSE pushing the prediction for an all-zero input graph towards
          the standardised value of a zero target energy.

    Args:
        graph_batch: batched jraph.GraphsTuple fed to the model.
        target_e: target energies (standardised).
        target_e_prime: target derivatives (standardised).
        Model: callable ``Model(graph, use_running_average) -> (e, e_prime)``.
        Dataset_parameters: dict with
            ``Dataset_parameters['target_e']['mean' / 'std_dev']``.
        alpha, gamma, lam: weights of the three loss terms.
        use_running_average: BatchNorm mode for the main prediction (False
            while training).

    Returns:
        scalar loss.
    """
    prediction_e, prediction_e_prime = Model(graph_batch, use_running_average)
    loss_e = jnp.mean((prediction_e - target_e) ** 2)
    loss_e_prime = jnp.mean(optax.huber_loss(prediction_e_prime, target_e_prime))

    # Standardised value that corresponds to a true energy of zero.
    mean_e = Dataset_parameters['target_e']['mean']
    std_dev_e = Dataset_parameters['target_e']['std_dev']
    target_zero = (0 - mean_e) / std_dev_e

    # Same topology with every node feature zeroed.
    # NOTE(review): this zeroes all features (not just displacements); confirm
    # this is the intended "zero" state.
    zero_graph = graph_batch._replace(nodes=jnp.zeros_like(graph_batch.nodes))
    # Use running statistics so this synthetic batch does not overwrite the
    # BatchNorm running statistics.
    prediction_zero, _ = Model(zero_graph, True)
    loss_zero = jnp.mean((prediction_zero - target_zero) ** 2)

    return alpha * loss_e + gamma * loss_e_prime + lam * loss_zero

Batch Graphs

In [None]:
def batch_and_split_dataset(graphs, batch_size, train_split, CV_split, test_split, key, *,
                            batch_fn=None):
    """
    Shuffle graphs, group them into batches and split into train / CV / test.

    Each split receives ``int(split * len(graphs)) // batch_size`` full
    batches. Graphs that do not fill a whole batch are dropped. Batches are
    drawn from the shuffled order: train first, then test, then CV.

    Args:
        graphs: sequence of graphs (e.g. jraph.GraphsTuple).
        batch_size: number of graphs per batch (>= 1).
        train_split, CV_split, test_split: fractions of the dataset per split.
        key: JAX PRNG key used to shuffle.
        batch_fn: callable merging a list of graphs into one batch; defaults
            to ``jraph.batch``.

    Returns:
        (train_batches, CV_batches, test_batches): lists of batches.

    Raises:
        ValueError: if batch_size < 1 or the splits need more graphs than exist.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if batch_fn is None:
        batch_fn = jraph.batch

    n_graphs = len(graphs)
    order = [int(i) for i in jax.random.permutation(key, n_graphs)]

    # split * n is a float; int() is needed before it can be used for slicing.
    n_train_batches = int(train_split * n_graphs) // batch_size
    n_test_batches = int(test_split * n_graphs) // batch_size
    n_CV_batches = int(CV_split * n_graphs) // batch_size
    if (n_train_batches + n_test_batches + n_CV_batches) * batch_size > n_graphs:
        raise ValueError("train_split + CV_split + test_split must not exceed 1")

    def take(first_batch, n_batches):
        # Batch consecutive chunks of the shuffled order, starting at first_batch.
        batches = []
        for b in range(first_batch, first_batch + n_batches):
            members = order[b * batch_size:(b + 1) * batch_size]
            batches.append(batch_fn([graphs[i] for i in members]))
        return batches

    train_batches = take(0, n_train_batches)
    test_batches = take(n_train_batches, n_test_batches)
    CV_batches = take(n_train_batches + n_test_batches, n_CV_batches)
    return train_batches, CV_batches, test_batches

Training loop