In [1]:
import jax.numpy as np
from jax import lax, nn, random, vmap
from jax._src.nn.functions import normalize
from jax.experimental import stax
from jax.nn.initializers import glorot_normal
from jax.random import normal

In [7]:
# Two small integer matrices used to probe how vmap's in_axes / out_axes
# pair up rows and columns of the two arguments to np.dot.
a = np.array([[1, 2],
              [3, 4]])
b = np.array([[5, 6],
              [7, 8]])

In [10]:
# Show both input matrices.
(a, b)

(DeviceArray([[1, 2],
              [3, 4]], dtype=int32),
 DeviceArray([[5, 6],
              [7, 8]], dtype=int32))

In [8]:
# Reference result: the ordinary matrix product a @ b = [[19, 22], [43, 50]].
np.dot(a, b)

DeviceArray([[19, 22],
             [43, 50]], dtype=int32)

In [19]:
# A scalar in_axes=0 maps axis 0 of *both* arguments: row i of a . row i of b
# -> [1*5 + 2*6, 3*7 + 4*8] = [17, 53]. The result is 1-D, so out_axes=0 and
# out_axes=-1 give the same array.
vmap(np.dot, in_axes=0)(a, b)

DeviceArray([17, 53], dtype=int32)

In [20]:
# Pair each row of a with the matching column of b: a[i, :] . b[:, i]
# -> [1*5 + 2*7, 3*6 + 4*8] = [19, 50], i.e. the diagonal of a @ b.
vmap(np.dot, in_axes=(0, 1))(a, b)

DeviceArray([19, 50], dtype=int32)

In [14]:
# Pair each column of a with the matching row of b: a[:, i] . b[i, :]
# -> [1*5 + 3*6, 2*7 + 4*8] = [23, 46], i.e. the diagonal of b @ a.
vmap(np.dot, in_axes=(1, 0))(a, b)

DeviceArray([23, 46], dtype=int32)

In [22]:
# A scalar in_axes=1 maps axis 1 of both arguments: column i of a . column i of b
# -> [1*5 + 3*7, 2*6 + 4*8] = [26, 44], i.e. the diagonal of a.T @ b.
vmap(np.dot, in_axes=1)(a, b)

DeviceArray([26, 44], dtype=int32)

In [25]:
# None broadcasts the whole of `a`; b is mapped over its columns, so result[i] = a @ b[:, i]:
#   [[1*5 + 2*7, 3*5 + 4*7],
#    [1*6 + 2*8, 3*6 + 4*8]]  ==  (a @ b).T
vmap(np.dot, in_axes=(None, 1))(a, b)

DeviceArray([[19, 43],
             [22, 50]], dtype=int32)

In [27]:
# Stacking on axis 0 yields (a @ b).T; transposing it recovers the plain product a @ b.
np.transpose(vmap(np.dot, in_axes=(None, 1))(a, b))

DeviceArray([[19, 22],
             [43, 50]], dtype=int32)

In [17]:
# out_axes=-1 stacks the per-column results along the last axis instead of axis 0.
# That equals out_axes=0 followed by a transpose, so the result is a @ b.
vmap(np.dot, in_axes=(None, 1), out_axes=-1)(a, b)

DeviceArray([[19, 22],
             [43, 50]], dtype=int32)

In [30]:
# Re-display the inputs before the in_axes=(..., None) experiments below.
(a, b)

(DeviceArray([[1, 2],
              [3, 4]], dtype=int32),
 DeviceArray([[5, 6],
              [7, 8]], dtype=int32))

In [28]:
# Map the last axis of `a` (its columns) and broadcast `b`; out_axes defaults to 0.
# result[i] = a[:, i] @ b -> [[1*5 + 3*7, 1*6 + 3*8], [2*5 + 4*7, 2*6 + 4*8]] == a.T @ b
vmap(np.dot, in_axes=(-1, None))(a, b)

DeviceArray([[26, 30],
             [38, 44]], dtype=int32)

In [29]:
# For a 2-D array, axis 1 *is* the last axis, so this matches in_axes=(-1, None): a.T @ b.
vmap(np.dot, in_axes=(1, None))(a, b)

DeviceArray([[26, 30],
             [38, 44]], dtype=int32)

In [3]:
# Root PRNG key, split once so A and F come from independent streams.
key = random.PRNGKey(12)
k1, k2 = random.split(key, num=2)

# Dense Gaussian stand-in for an adjacency matrix: 5 nodes x 5 nodes
# (real-valued and, in general, not symmetric).
A = random.normal(k1, (5, 5))

# Node-feature matrix: 5 nodes x 13 features.
F = random.normal(k2, (5, 13))



In [6]:
# Same pattern as the a/b experiment: mapping the last axis of A slices its columns,
# so row i of mp0 is A[:, i] @ F. That makes mp0 == A.T @ F, which equals A @ F only
# when A is symmetric.
mp0 = vmap(np.dot, in_axes=(-1, None))(A, F)
mp0.shape

(5, 13)

In [5]:
# Same mapping, but results are stacked on the last axis: mp == (A.T @ F).T,
# shape (n_features, n_nodes). This is the pattern used by MessagePassing's apply_func in layers.py.
mp = vmap(np.dot, in_axes=(-1, None), out_axes=-1)(A, F)
mp.shape

(13, 5)

In [4]:
from jax.lax import batch_matmul

# Direct products for comparison with the vmap versions above.
# np.dot(A, F) is A @ F (whereas mp0 above is A.T @ F).
# batch_matmul with no leading batch dims presumably reduces to the same A @ F.
# TODO confirm the values agree, not just the shapes.
mp1 = np.dot(A, F)
mp2 = batch_matmul(A, F)
mp1.shape, mp2.shape

((5, 13), (5, 13))

In [5]:
from patch_gnn.layers import MessagePassing

In [6]:
# MessagePassing() returns an (init_fun, apply_fun) pair, following the stax layer convention.
init_fun, apply_fun = MessagePassing()

In [7]:
# Initialise the layer parameters. input_shape=(5, 13, 1) is presumably
# (n_nodes, n_features, n_adjacencies) — TODO confirm against patch_gnn.layers.
# NOTE(review): PRNGKey(12) is the same root key used to draw A and F above. If init_fun
# splits it the same way, the weights may be correlated with A — consider a separate key.
(n_nodes, n_features), adjacency_weights = init_fun(random.PRNGKey(12), input_shape=(5, 13, 1))

# Jupyter only renders a cell's last expression, so print the weight shape explicitly.
print(adjacency_weights.shape)  # expected (1, 1)

# `[0, 0]` extracts a single *scalar* weight, so apply_fun receives a scalar rather than the
# (n_adjacencies, 1) weight matrix. That is why the params don't look like a matrix here.
# NOTE(review): A is 2-D in this notebook; the layer presumably expects a stacked
# (n_nodes, n_nodes, n_adjacencies) adjacency — confirm against layers.py.
out = apply_fun(adjacency_weights[0, 0], (A, F))
out.shape

(13, 5)

In [None]:
# Inspect the shape of whatever `adjacency_weights` currently holds.
np.shape(adjacency_weights)

In [None]:
# Number of adjacency matrices mixed by the layer; there is one weight row per adjacency.
n_adjacencies = 1

# Draw the weights from a fresh key. `k1` was already consumed to sample A, and reusing
# a JAX PRNG key produces random bits that are not independent of A's.
weights_key = random.PRNGKey(13)
adjacency_weights = random.normal(weights_key, (n_adjacencies, 1))

In [None]:
# Mix per-adjacency messages with the weights (presumably mirroring layers.py's MessagePassing):
#   1. treat the adjacency as a stack (n_nodes, n_nodes, n_adjacencies). Only one matrix, A,
#      exists in this notebook, so add a trailing axis of size 1;
#   2. vmap over that last axis so each slice yields A_k @ F, stacked on the last axis
#      -> (n_nodes, n_features, n_adjacencies);
#   3. contract the adjacency axis with the (n_adjacencies, 1) weights in `adjacency_weights`.
# The 2-D `mp` above has shape (n_features, n_nodes) and cannot be contracted with those weights.
A_stack = np.expand_dims(A, -1)
assert A_stack.shape[-1] == adjacency_weights.shape[0], "need one weight row per adjacency"
mp_stack = vmap(np.dot, in_axes=(-1, None), out_axes=-1)(A_stack, F)
# Drop only the trailing size-1 output axis; np.squeeze would also drop a size-1 node/feature axis.
mp_out = np.dot(mp_stack, adjacency_weights)[..., 0]
mp_out.shape