Import Libraries

In [1]:
import jax
import jax.numpy as jnp
import jax.nn as jnn
from flax import nnx
import optax
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Any
import jraph

Unpickling the data and importing the mesh

Create graph from mesh function

In [None]:
def build_graph() -> jraph.GraphsTuple:
    """Placeholder for the mesh-to-graph conversion.

    The conversion is not implemented yet: the function currently returns the
    integer ``0`` rather than an actual ``jraph.GraphsTuple``.
    """
    # TODO: build the real graph from the mesh; until then return a dummy value.
    return 0

Activation Functions

In [None]:
def Silu(x: jax.Array) -> jax.Array:
    """Apply the SiLU (swish) activation elementwise: ``x * sigmoid(x)``.

    Args:
        x: Input array.

    Returns:
        An array with the same shape as ``x``.
    """
    # Annotated with jax.Array: annotations are evaluated when the function is
    # defined, so they must name a type that actually exists.
    return x * jnn.sigmoid(x)

Linear layer

In [None]:
class Linear(nnx.Module):
    """Trainable affine map ``y = x @ w + b``.

    Args:
        din: Size of the last axis of the input.
        dout: Size of the last axis of the output.
        rngs: Flax NNX RNG container; its ``params`` stream seeds the weight
            initialiser.

    Attributes:
        w: Weight matrix of shape ``(din, dout)``, LeCun-normal initialised.
        b: Bias vector of shape ``(dout,)``, initialised to zeros.
    """

    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        self.din, self.dout = din, dout
        initialiser = nnx.initializers.lecun_normal()
        self.w = nnx.Param(initialiser(rngs.params(), (din, dout)))
        # Variance-scaling initialisers such as lecun_normal need at least a
        # 2-D shape to derive fan-in, so they cannot be used for the 1-D bias;
        # start the bias at zero instead.
        self.b = nnx.Param(jnp.zeros((dout,)))

    def __call__(self, x: jax.Array) -> jax.Array:
        """Return ``x @ w + b``; the last axis of ``x`` must have size ``din``."""
        return x @ self.w + self.b

GAT Layer

In [None]:
class GAT(nnx.Module):
    """Single-head graph attention layer operating on a ``jraph.GraphsTuple``.

    Every edge gets a score ``leaky_relu([W h_sender, W h_receiver] @ A)``.
    Scores are softmax-normalised over the incoming edges of each receiver,
    and each node's new feature is the attention-weighted sum of its senders'
    projected features. Nodes without incoming edges end up with all-zero
    features.

    Args:
        in_features: Size of the incoming node feature vectors.
        out_features: Size of the projected node feature vectors.
        rngs: Flax NNX RNG container; its ``params`` stream seeds the
            initialisers.

    Attributes:
        W: Projection matrix of shape ``(in_features, out_features)``.
        A: Attention vector of shape ``(2 * out_features, 1)``.
    """

    def __init__(self, in_features, out_features, *, rngs):
        initialiser = nnx.initializers.lecun_normal()
        # Draw a fresh key per parameter so W and A are initialised independently.
        self.W = nnx.Param(initialiser(rngs.params(), (in_features, out_features)))
        self.A = nnx.Param(initialiser(rngs.params(), (2 * out_features, 1)))
        # Store the functions themselves; they are invoked in __call__.
        self.SoftMax = jraph.segment_softmax
        self.Leaky_Relu = nnx.leaky_relu

    def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Return ``graph`` with its node features replaced by attended features.

        Raises:
            ValueError: If ``graph.nodes`` is ``None``.
        """
        if graph.nodes is None:
            raise ValueError("GAT requires nodes to have features")

        # Total node count, taken from the static array shape so the segment
        # operations below stay usable under jit.
        num_nodes = graph.nodes.shape[0]

        # Project every node once, then gather per edge endpoint.
        projected = graph.nodes @ self.W
        h_sender = projected[graph.senders]
        h_receiver = projected[graph.receivers]

        send_receive_features = jnp.concatenate([h_sender, h_receiver], axis=-1)
        attention_scores = self.Leaky_Relu(send_receive_features @ self.A)

        # Normalise the scores over the incoming edges of each receiver node.
        attention_coefficients = self.SoftMax(
            attention_scores,
            segment_ids=graph.receivers,
            num_segments=num_nodes,
        )

        # (E, 1) coefficients broadcast over the (E, out_features) messages.
        weighted_features = attention_coefficients * h_sender

        # Sum the weighted messages arriving at each receiver.
        aggregate_nodes = jraph.segment_sum(
            weighted_features,
            graph.receivers,
            num_segments=num_nodes,
        )

        return graph._replace(nodes=aggregate_nodes)

SAGPool WIP

In [None]:
class SAGPool(nnx.Module):
    """Work-in-progress placeholder for a SAGPool layer; defines no behaviour yet."""

Model

In [None]:
class GNN(nnx.Module):
    """Graph network: linear embedding, three residual GAT layers, sum-pool decoder.

    Calling the model returns the decoded per-graph output together with the
    gradient of that output with respect to columns ``4:7`` of the input node
    features.

    Args:
        input_dim: Size of the raw node feature vectors.
        embedding_dim: Width of the node embeddings used by the GAT layers.
        output_dim: Size of the decoded per-graph output.
        rngs: Flax NNX RNG container used to initialise all sub-layers.
    """

    def __init__(self, input_dim: int, embedding_dim: int, output_dim: int, rngs: nnx.Rngs):
        self.embedding_layer = Linear(input_dim, embedding_dim, rngs=rngs)
        self.decoding_layer = Linear(embedding_dim, output_dim, rngs=rngs)

        # Store the activation function itself; it is applied in
        # apply_activation_and_res.
        self.ReLU = nnx.relu

        self.encoderL1 = GAT(embedding_dim, embedding_dim, rngs=rngs)
        self.BatchNormL1 = nnx.BatchNorm(num_features=embedding_dim, rngs=rngs)
        self.encoderL2 = GAT(embedding_dim, embedding_dim, rngs=rngs)
        self.BatchNormL2 = nnx.BatchNorm(num_features=embedding_dim, rngs=rngs)
        self.encoderL3 = GAT(embedding_dim, embedding_dim, rngs=rngs)

    def embedder(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Return ``graph`` with its node features passed through the embedding layer."""
        return graph._replace(nodes=self.embedding_layer(graph.nodes))

    def apply_activation_and_res(self, graph: jraph.GraphsTuple, residual: jax.Array) -> jraph.GraphsTuple:
        """Return ``graph`` with nodes replaced by ``relu(nodes) + residual``."""
        return graph._replace(nodes=self.ReLU(graph.nodes) + residual)

    def apply_res(self, graph: jraph.GraphsTuple, residual: jax.Array) -> jraph.GraphsTuple:
        """Return ``graph`` with ``residual`` added to its node features."""
        return graph._replace(nodes=graph.nodes + residual)

    def decoder(self, graph: jraph.GraphsTuple) -> jax.Array:  # Switch to SAGPool when it's finished
        """Sum-pool node features per graph and decode them.

        Returns:
            Array of shape ``(n_graph, output_dim)``.
        """
        n_graph = graph.n_node.shape[0]
        n_nodes = graph.nodes.shape[0]
        # Map every node to the index of the graph it belongs to;
        # total_repeat_length keeps the result shape static under tracing.
        graph_ids = jnp.repeat(
            jnp.arange(n_graph), graph.n_node, total_repeat_length=n_nodes
        )
        pooled = jraph.segment_sum(graph.nodes, graph_ids, num_segments=n_graph)
        return self.decoding_layer(pooled)

    def forward_pass(self, G: jraph.GraphsTuple, use_running_average: bool) -> jax.Array:
        """Run embedding, the three residual GAT layers and the decoder.

        Args:
            G: Input graph with raw node features.
            use_running_average: Forwarded to both BatchNorm layers.

        Returns:
            The decoder output for ``G``.
        """
        G = self.embedder(G)
        res1 = G.nodes

        G = self.encoderL1(G)
        nodes_norm = self.BatchNormL1(G.nodes, use_running_average=use_running_average)
        G = self.apply_activation_and_res(G._replace(nodes=nodes_norm), res1)
        res2 = G.nodes

        G = self.encoderL2(G)
        nodes_norm = self.BatchNormL2(G.nodes, use_running_average=use_running_average)
        G = self.apply_activation_and_res(G._replace(nodes=nodes_norm), res2)
        res3 = G.nodes

        # The last layer has no normalisation or activation, only the skip connection.
        G = self.apply_res(self.encoderL3(G), res3)

        return self.decoder(G)

    def __call__(self, G: jraph.GraphsTuple, use_running_average):
        """Return the model output and its gradient w.r.t. node-feature columns 4:7.

        Args:
            G: Input graph with raw node features.
            use_running_average: Forwarded to the BatchNorm layers.

        Returns:
            ``(e, e_prime)`` where ``e`` is the decoder output and ``e_prime``
            is ``d(sum(e)) / d(G.nodes)[:, 4:7]``.
        """

        def summed_output(model, nodes):
            # Differentiate with respect to the float node features only:
            # the other GraphsTuple fields hold integer index arrays, which
            # cannot be differentiated.
            e = model.forward_pass(G._replace(nodes=nodes), use_running_average)
            # Gradients need a scalar objective; the un-summed output is
            # carried along as auxiliary data.
            return jnp.sum(e), e

        # nnx.value_and_grad (rather than jax.grad) lets the BatchNorm layers
        # update their running statistics inside the transform, and yields the
        # output and its gradient from a single forward pass.
        (_, e), node_grads = nnx.value_and_grad(
            summed_output, argnums=1, has_aux=True
        )(self, G.nodes)

        # NOTE(review): columns 4:7 are presumably the spatial coordinates of
        # each node - confirm against the feature layout used by build_graph.
        e_prime = node_grads[:, 4:7]

        return e, e_prime


