In [20]:
import jax.numpy as np
from jax import lax, nn, random, vmap
from jax._src.nn.functions import normalize
from jax.experimental import stax
from jax.nn.initializers import glorot_normal
from jax.random import normal
from functools import partial
from patch_gnn.layers import softmax_on_non_zero,concatenate_node_features,get_norm_attn_matrix

In [21]:
key = random.PRNGKey(12)
k1, k2,k3,k4,k5 = random.split(key, num = 5)
n_features =5
n_output_dims =13
n_node = 17
w = random.normal(k1, (n_features, n_output_dims))
a = random.normal(k2, (n_output_dims * 2,))
nfa = random.normal(k3, (n_features,))
adjacency_matrix = random.normal(k4, (n_node,n_node))
node_embeddings = random.normal(k5, (n_node, n_features))


In [23]:
node_feat_attn = vmap(partial(np.multiply, np.abs(nfa)))(node_embeddings)
node_projection = np.dot(
        node_feat_attn, w
    )
node_by_node_concat = concatenate_node_features(
        node_projection
    )
projection = np.dot(node_by_node_concat, a)
atten_leaky_relu = nn.leaky_relu(projection, negative_slope=0.1)
atten_leaky_relu = np.squeeze(atten_leaky_relu)
attention = atten_leaky_relu * np.squeeze(
        adjacency_matrix
    )
norm_attention = softmax_on_non_zero(attention, adjacency_matrix)#vmap(partial(softmax_on_non_zero, adj =adjacency_matrix ))(attention)
norm_attention.shape


(17, 17)

In [2]:
key = random.PRNGKey(12)
k1, k2 = random.split(key)
n_patches = 3
n_node = 5
n_features = 13
n_adjacencies = 7

A = random.normal(k1, (n_node,n_node, n_adjacencies))  # adjacency matrix, 5 node * 5 node
F = random.normal(k2, (n_node, n_features)) # feature matrix, 5 node *13 features
adjacency_weights_init=glorot_normal() # init weight 
adjacency_weights = adjacency_weights_init(k1, (n_adjacencies, 1))
adjacency_weights.shape
mp = vmap(np.dot, in_axes=(-1, None), out_axes=(-1))(A, F)

print(f"mp.shape is {mp.shape}") #shape (n_nodes, n_features, n_adjacency_like_matrics) 
print(f"np.dot(mp, adjacency_weights) shape is {np.dot(mp, adjacency_weights).shape}")
print(f"np.squeeze(np.dot(mp, adjacency_weights)) shape is {(np.dot(mp, adjacency_weights)).shape}") # shape (n_nodes, n_features)



mp.shape is (5, 13, 7)
np.dot(mp, adjacency_weights) shape is (5, 13, 1)
np.squeeze(np.dot(mp, adjacency_weights)) shape is (5, 13, 1)


In [5]:
from patch_gnn.layers import MessagePassing

In [6]:
init_fun, apply_fun = MessagePassing()

In [7]:
(n_nodes, n_features), (adjacency_weights)  = init_fun(random.PRNGKey(12), input_shape = (5,13,1))
adjacency_weights.shape # (1,1)
out = apply_fun(adjacency_weights[0][0], (A, F)) # why is message passing params not a matrix?
out.shape

(13, 5)

In [None]:
adjacency_weights.shape

In [None]:
n_adjacencies = 1
adjacency_weights = random.normal(k1, (n_adjacencies, 1))

In [None]:
mp_out = np.squeeze(np.dot(mp, params))
mp_out.shape

In [8]:
from patch_gnn.layers import concatenate_node_features

### according to the current implementation

### However, alternatively if node_by_node_cat and a are different shape

In [97]:
node_by_node_concat = random.normal(k1, (n_node, 2*n_output_dims))
a = random.normal(k2, (n_node, 2*n_output_dims))

projection = np.dot(node_by_node_concat, a.T)
print(projection.shape)
atten_after_relu = nn.leaky_relu(projection, negative_slope=0.1)
print(output.shape)


(3, 3)
(3, 3)


In [98]:
atten_after_relu

DeviceArray([[ 3.1671557 ,  7.295389  ,  1.4041579 ],
             [-1.2738957 , -0.39845598, -0.28209615],
             [ 5.189487  , -0.09320077, -0.586724  ]], dtype=float32)

In [99]:
np.tril(atten_after_relu)

DeviceArray([[ 3.1671557 ,  0.        ,  0.        ],
             [-1.2738957 , -0.39845598,  0.        ],
             [ 5.189487  , -0.09320077, -0.586724  ]], dtype=float32)

In [100]:
def softmax_on_non_zero(vect):
    """
    Apply softmax normalization on a vector and ignore the 0 values
    For example: [-0.3, 0 , 0] denotes the attention of a node on three 
    nodes, this node has no value on node 2 or 3, and only -0.3 value on
    itself, the softmax will return [1,0,0], given "0" doesn't participate in
    softmax calculation
    
    :param vect: a column vector from a matrix output from `leaky_relu`
    :return a list with softmax normalized values
    """
    vect_non_zero = np.asarray([i for i in vect if i!=0])
    return nn.softmax(vect_non_zero).tolist()

In [101]:
def get_norm_attn_matrix(atten_after_relu):
    """
    Apply softmax normalization on a matrix of shape (n_node, n_node).
    Turn upper triangle of this matrix to value of 0s to represent that 
    the attention of a given node only applies to the nodes before the 
    current node
    
    :param output: the attention matrix coming out from `leaky_relu`
    
    :returns matrix: a numpy.ndarray triangle matrix of shape (n_node, n_node)
    denoting how much each node pays attention to previous nodes. The attention
    of each node on other node should sum up to 1

    """
    matrix = []
    # turn previous atten matrix to triagnle matrix
    # btw, when performing masking, the matrix coming out from 
    # leaky_relu is not symmetrical, just always ignore the upper triangle of the matrix?
    atten_after_relu = np.tril(atten_after_relu)
    for i in range(len(atten_after_relu)):
        matrix.append(softmax_on_non_zero(atten_after_relu[i,:]))
    pad = len(max(matrix, key =len))
    matrix  = numpy.array([i + [0]*(pad-len(i)) for i in matrix])
    return matrix

In [93]:
for i in range(len(output)):
    print(softmax_on_non_zero(output[i,:]))

[1.0]
[0.29412367939949036, 0.705876350402832]
[0.9918871521949768, 0.005037559196352959, 0.0030752879101783037]


In [103]:
type(get_norm_attn_matrix(atten_after_relu))

numpy.ndarray

In [104]:
get_norm_attn_matrix(atten_after_relu)

array([[1.        , 0.        , 0.        ],
       [0.29412368, 0.70587635, 0.        ],
       [0.99188715, 0.00503756, 0.00307529]])

In [32]:
#import numpy
#a=  numpy.empty([3,3])
#a[:] = 1
#a[numpy.tril_indices(a.shape[0], -1)] = numpy.nan
#a.T

array([[ 1., nan, nan],
       [ 1.,  1., nan],
       [ 1.,  1.,  1.]])

In [None]:
# test message passing

In [87]:
from patch_gnn.layers import MessagePassing
from jax.experimental import stax
from jax.random import normal,PRNGKey
from jax import random,jit
from jax.experimental.optimizers import adam
from functools import partial
from patch_gnn.training import mseloss, step
from tqdm.auto import tqdm
import jax.numpy as np
from patch_gnn.models import MPNN

In [88]:
adjs = random.normal(k1, (3,20,20, 1))
feats = random.normal(k1, (3,20,67))
train_graph = (adjs, feats)
train_target = random.normal(k2, (3,))

In [89]:
num_training_steps = 10
node_feature_shape = (20,67)
num_adjacency =1
model_mpnn = MPNN(
        node_feature_shape=node_feature_shape,
        num_adjacency=num_adjacency,
        num_training_steps=num_training_steps
    )
model_mpnn.fit(train_graph, train_target)

  0%|          | 0/10 [00:00<?, ?it/s]

mp is Traced<ShapedArray(float32[20,67,1])>with<BatchTrace(level=3/1)>
  with val = Traced<ShapedArray(float32[3,20,67,1])>with<DynamicJaxprTrace(level=0/1)>
       batch_dim = 0
 output is Traced<ShapedArray(float32[20,67])>with<BatchTrace(level=3/1)>
  with val = Traced<ShapedArray(float32[3,20,67])>with<JVPTrace(level=2/1)>
               with primal = Traced<ShapedArray(float32[3,20,67])>with<DynamicJaxprTrace(level=0/1)>
                    tangent = Traced<ShapedArray(float32[3,20,67]):JaxprTrace(level=1/1)>
       batch_dim = 0


<patch_gnn.models.MPNN at 0x2aacfff90430>

In [90]:
model_mpnn.loss_history

[DeviceArray(95.95095, dtype=float32),
 DeviceArray(10.907946, dtype=float32),
 DeviceArray(29.840017, dtype=float32),
 DeviceArray(52.032665, dtype=float32),
 DeviceArray(38.774757, dtype=float32),
 DeviceArray(16.14018, dtype=float32),
 DeviceArray(6.3733273, dtype=float32),
 DeviceArray(12.946376, dtype=float32),
 DeviceArray(23.326376, dtype=float32),
 DeviceArray(25.031109, dtype=float32)]