In [1]:
from clrs._src.processors import GATv2FullD2
import haiku as hk
import optax
import jax
import jax.numpy as jnp
from jax.numpy import array

In [2]:
# import matplotlib.pyplot as plt \\ TODO DOWNLOAD MATPLOTLIB IN TACTIC_AI ENV
# \\ TODO rendere più leggibile il codice introducendo un datastructure che consideri in maniera più compatta le x (node_fts, edge_fts, graph_fts, adj_mat, hidden)

In [3]:
# Per utilizzare il metodo d2_forward() in GATv2FullD2 bisogna aggiungere applicare le 4 proiezioni 
# degli input per node (b, n, d), edge (b, n, n, d) e graph (b, d) features 
# stack(originale, flip_verticale, flip_orizzontale, flip_verticale_orizzontale).
# Nuova dimensione dell'input per node_fts p.e.: (4, b, n, d)
# ESEMPIO
mat = jnp.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
mat_verticale = mat[::-1, :]
mat_orizzontale = mat[:, ::-1]
mat_verticale_orizzontale = mat[::-1, ::-1]
print(mat, mat_verticale, mat_orizzontale, mat_verticale_orizzontale, sep="\n")
# poi stack sulla prima dimensione
mats = [mat, mat_verticale, mat_orizzontale, mat_verticale_orizzontale]
mat = jnp.stack(mats, axis=0)
print(jnp.mean(mat, axis=0).shape)
# jnp.mean(jnp.stack(mats, axis=0), axis=1)

[[ 1  2  3  4]
 [ 5  6  7  8]
 [ 9 10 11 12]
 [13 14 15 16]]
[[13 14 15 16]
 [ 9 10 11 12]
 [ 5  6  7  8]
 [ 1  2  3  4]]
[[ 4  3  2  1]
 [ 8  7  6  5]
 [12 11 10  9]
 [16 15 14 13]]
[[16 15 14 13]
 [12 11 10  9]
 [ 8  7  6  5]
 [ 4  3  2  1]]
(4, 4)


In [4]:
K = jax.random.PRNGKey(0)
jnp.mean(jax.random.normal(K, shape=(10, 20, 30)), axis=0).shape


(20, 30)

In [5]:
def model_fn(node_fts, edge_fts, graph_fts, adj_mat, hidden):
    model = GATv2FullD2(out_size=8, nb_heads=8)
    return model(node_fts, edge_fts, graph_fts, adj_mat, hidden) # d2_forward()

model_init, model_apply = hk.transform(model_fn)

# Prepare data

In [6]:
# creo finte features per il grafo
rng = jax.random.PRNGKey(0)
# nodo
node_fts = jax.random.normal(rng, shape=(200, 22, 4))
# egdes
edge_fts = jax.random.normal(rng, shape=(200, 22, 22, 1))
# etichetta grafo
graph_fts = jax.random.normal(rng, shape=(200, 2))
# matrice di adiacenza
adj_mat = jnp.ones(shape=(200, 22, 22))
# hidden (?)
hidden = jax.random.normal(rng, shape=(200, 22, 4))
# labels tackle
labels = jax.random.permutation(rng, jnp.concat([jnp.ones(100), jnp.zeros(100)]))

DATI FLIPPATI PER IL METODO d2_forward()

In [7]:
edge_fts.shape

(200, 22, 22, 1)

In [8]:
(4%8)

4

In [9]:
fnode_fts = jnp.stack([node_fts,
                     node_fts[:, ::-1, :],
                     node_fts[:, :, ::-1],
                     node_fts[:, ::-1, ::-1]], axis=0)
fedge_fts = jnp.stack([edge_fts,
                     edge_fts[:, ::-1, :],
                     edge_fts[:, :, ::-1],
                     edge_fts[:, ::-1, ::-1]], axis=0)
# da valutare se graph feat deve essere flippato
# fgraph_fts = jnp.stack([graph_fts,
#                      graph_fts[:, ::-1, :],
#                      graph_fts[:, :, ::-1],
#                      graph_fts[:, ::-1, ::-1]], axis=0)
fgraph_fts = jnp.stack((graph_fts,)*4, axis=0)

fgraph_fts.shape

(4, 200, 2)

In [10]:
print(edge_fts.shape[:-1])

(200, 22, 22)


In [11]:
len(graph_fts)

200

In [12]:
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG rng to `init`, since parameters
# are typically initialized randomly:
rng = jax.random.PRNGKey(42)
params = model_init(rng, node_fts, edge_fts, graph_fts, adj_mat, hidden)

# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument.  Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = model_apply(params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden)[0]

In [13]:
y.shape

(200, 22, 8)

In [14]:
b, n, d = y.shape
jnp.reshape(y, (b, n*d)).shape

(200, 176)

### ora per d2_forward()

In [17]:
def model_fwd_fn(fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden):
    model = GATv2FullD2(out_size=4, nb_heads=4)
    return model.d2_forward(fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden)

model_fwd_init, model_fwd_apply = hk.transform(model_fwd_fn)

In [18]:
rng = jax.random.PRNGKey(0)

params = model_fwd_init(rng, fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden)

y = model_fwd_apply(params, rng, fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden)


In [None]:
y.shape

(4, 200, 22, 4)

In [19]:
class myNet(hk.Module):
    def __init__(self, gatv2_out_size: int, gatv2_nb_heads: int, linear_out_size: int = 2):
        super().__init__()
        self.gatv2 = GATv2FullD2(out_size=gatv2_out_size, nb_heads=gatv2_nb_heads)
        self.linear = hk.Linear(output_size=linear_out_size)

    def __call__(self, node_fts, edge_fts, graph_fts, adj_mat, hidden):
        gat_output, _ = self.gatv2(node_fts, edge_fts, graph_fts, adj_mat, hidden)
        b, n, d = gat_output.shape
        flattened_output = jnp.reshape(gat_output, (b, n*d))  # Flat node features on last axis
        scalar_output = self.linear(flattened_output)
        return scalar_output

def myNet_fn(node_fts, edge_fts, graph_fts, adj_mat, hidden):
    model = myNet(gatv2_out_size=8, gatv2_nb_heads=8)
    return model(node_fts, edge_fts, graph_fts, adj_mat, hidden)

myNet_init, myNet_apply = hk.transform(myNet_fn)

rng = jax.random.PRNGKey(42)
params = myNet_init(rng, node_fts, edge_fts, graph_fts, adj_mat, hidden)

# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument.  Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = myNet_apply(params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden)

In [20]:
# output di d2_forward() è una lista di len uguale a 4 (da capire come utilizzare questi dati moltiplicati)
# per 4
class myNet_d2(hk.Module):
    def __init__(self, gatv2_out_size: int, gatv2_nb_heads: int, linear_out_size: int = 2):
        super().__init__()
        self.gatv2_1 = GATv2FullD2(out_size=gatv2_out_size, nb_heads=gatv2_nb_heads)
        self.gatv2_2 = GATv2FullD2(out_size=gatv2_out_size, nb_heads=gatv2_nb_heads)
        self.gatv2_3 = GATv2FullD2(out_size=gatv2_out_size, nb_heads=gatv2_nb_heads)
        self.gatv2_4 = GATv2FullD2(out_size=gatv2_out_size, nb_heads=gatv2_nb_heads)
        
        self.linear = hk.Linear(output_size=linear_out_size)

    def __call__(self, fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden):
        gat_output = self.gatv2_1.d2_forward(fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden)
        gat_output = array(gat_output)
        hidden = jnp.mean(gat_output, axis=0)
        
        gat_output = self.gatv2_2.d2_forward(gat_output, fedge_fts, fgraph_fts, adj_mat, hidden)
        gat_output = array(gat_output)
        hidden = jnp.mean(gat_output, axis=0)
        
        gat_output = self.gatv2_3.d2_forward(gat_output, fedge_fts, fgraph_fts, adj_mat, hidden)
        gat_output = array(gat_output)
        hidden = jnp.mean(gat_output, axis=0)
        
        gat_output = self.gatv2_4.d2_forward(gat_output, fedge_fts, fgraph_fts, adj_mat, hidden)
        
        b, n, d = gat_output[0].shape
        gat_output = jnp.mean(array(gat_output), axis=0)
        flattened_output = jnp.reshape(gat_output, (b, n*d))  # Flat node features on last axis
        scalar_output = self.linear(flattened_output)
        return scalar_output

def myNet_d2_fn(fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden):
    model = myNet_d2(gatv2_out_size=8, gatv2_nb_heads=8)
    return model(fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden)

myNet_d2_init, myNet_d2_apply = hk.transform(myNet_d2_fn)

rng = jax.random.PRNGKey(42)
params = myNet_d2_init(rng, fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden)

# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument.  Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = myNet_d2_apply(params, rng, fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden)

In [21]:
jnp.ravel(y).shape

(400,)

In [22]:
y.shape

(200, 2)

In [25]:
len(params)
print(jax.devices())

[CpuDevice(id=0)]


# setup loss func

In [26]:
def loss_fn(params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden, y):
    logits = myNet_apply(params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden)
    loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(y, 2)))
    return loss

In [27]:
#check
loss_fn(params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden, labels)

ValueError: Unable to retrieve parameter 'w' for module 'my_net/~/gatv2_aggr_clrs_processor/linear' All parameters must be created as part of `init`.

# setup the optimizer

In [None]:
import optax

learning_rate = 1e-4
optimizer = optax.adam(learning_rate)

#  Initialize Parameters and Optimizer State

In [None]:
rng = jax.random.PRNGKey(42)

params = myNet_init(rng, node_fts, edge_fts, graph_fts, adj_mat, hidden)

opt_state = optimizer.init(params)


# Define the training step

In [None]:
@jax.jit
def train_step(params, opt_state, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

In [None]:
# check
train_step(params, opt_state, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden, labels)

({'my_net/~/gatv2_aggr_clrs_processor/linear': {'b': Array([-0.00099999, -0.00099999], dtype=float32),
   'w': Array([[ 0.25564265,  0.15695918],
          [-0.4795112 , -0.37934893],
          [-0.41237822, -0.22365592],
          [-0.43243405,  0.215911  ],
          [-0.18614386, -0.10927615],
          [ 0.3672926 , -0.14089692],
          [ 0.10815946,  0.4379233 ],
          [-0.07625037,  0.08147988]], dtype=float32)},
  'my_net/~/gatv2_aggr_clrs_processor/linear_1': {'b': Array([-0.00099999, -0.00099999], dtype=float32),
   'w': Array([[-0.48068088,  0.6309275 ],
          [ 0.18759483, -0.14978436],
          [-0.44798693,  0.20787208],
          [ 0.14497875, -0.19853099],
          [ 0.16847068, -0.18189548],
          [-0.378203  ,  0.02443844],
          [ 0.04869739,  0.31703183],
          [-0.2213596 , -0.32207504]], dtype=float32)},
  'my_net/~/gatv2_aggr_clrs_processor/linear_2': {'b': Array([ 0.00099999, -0.00099994], dtype=float32),
   'w': Array([[ 0.14765775,  0.0

Train model

In [None]:
num_epochs = 10
batch_size = 128
num_batches = node_fts.shape[0] // batch_size

for epoch in range(num_epochs):
    for batch_idx in range(num_batches):
        node_fts_batch = node_fts[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        edge_fts_batch = edge_fts[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        graph_fts_batch = graph_fts[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        adj_mat_batch = adj_mat[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        hidden_batch = hidden[batch_idx * batch_size:(batch_idx + 1) * batch_size]

        y_batch = labels[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        
        params, opt_state, loss = train_step(params, opt_state, rng, node_fts_batch, edge_fts_batch, graph_fts_batch, adj_mat_batch, hidden_batch, y_batch)
        
        if batch_idx % 10 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss}')


Epoch 0, Batch 0, Loss: 0.9068794250488281
Epoch 0, Batch 10, Loss: 0.831483006477356
Epoch 0, Batch 20, Loss: 0.5115661025047302
Epoch 0, Batch 30, Loss: 0.8358895182609558
Epoch 1, Batch 0, Loss: 0.8545390963554382
Epoch 1, Batch 10, Loss: 0.8069888949394226
Epoch 1, Batch 20, Loss: 0.5200721621513367
Epoch 1, Batch 30, Loss: 0.8098663687705994
Epoch 2, Batch 0, Loss: 0.8402116298675537
Epoch 2, Batch 10, Loss: 0.7819634675979614
Epoch 2, Batch 20, Loss: 0.5270403623580933
Epoch 2, Batch 30, Loss: 0.7877657413482666
Epoch 3, Batch 0, Loss: 0.8303874731063843
Epoch 3, Batch 10, Loss: 0.7595011591911316
Epoch 3, Batch 20, Loss: 0.5323825478553772
Epoch 3, Batch 30, Loss: 0.7688813209533691
Epoch 4, Batch 0, Loss: 0.8212447166442871
Epoch 4, Batch 10, Loss: 0.739532470703125
Epoch 4, Batch 20, Loss: 0.5364087820053101
Epoch 4, Batch 30, Loss: 0.7528592944145203
Epoch 5, Batch 0, Loss: 0.814746081829071
Epoch 5, Batch 10, Loss: 0.7228638529777527
Epoch 5, Batch 20, Loss: 0.54020351171493

In [28]:
def loss_fn(params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden, y):
    logits = myNet_d2_apply(params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden)
    loss = jnp.mean(optax.softmax_cross_entropy(logits, jax.nn.one_hot(y, 2)))
    return loss

learning_rate = 1e-3
optimizer = optax.adam(learning_rate)

rng = jax.random.PRNGKey(42)

params = myNet_d2_init(rng, fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden)

opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

In [29]:
num_epochs = 10
batch_size = 6
num_batches = node_fts.shape[0] // batch_size

for epoch in range(num_epochs):
    for batch_idx in range(num_batches):
        fnode_fts_batch = fnode_fts[:, batch_idx * batch_size:(batch_idx + 1) * batch_size]
        fedge_fts_batch = fedge_fts[:, batch_idx * batch_size:(batch_idx + 1) * batch_size]
        fgraph_fts_batch = fgraph_fts[:, batch_idx * batch_size:(batch_idx + 1) * batch_size]
        adj_mat_batch = adj_mat[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        hidden_batch = hidden[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        y_batch = labels[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        
        params, opt_state, loss = train_step(params, opt_state, rng, fnode_fts_batch, fedge_fts_batch, fgraph_fts_batch, adj_mat_batch, hidden_batch, y_batch)
        
        if batch_idx % 10 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss}')


Epoch 0, Batch 0, Loss: 0.6903144717216492
Epoch 0, Batch 10, Loss: 0.7177843451499939
Epoch 0, Batch 20, Loss: 0.6700448393821716
Epoch 0, Batch 30, Loss: 0.6930404901504517
Epoch 1, Batch 0, Loss: 0.6693402528762817
Epoch 1, Batch 10, Loss: 0.6974799036979675
Epoch 1, Batch 20, Loss: 0.6561087965965271
Epoch 1, Batch 30, Loss: 0.684623122215271
Epoch 2, Batch 0, Loss: 0.6561148166656494
Epoch 2, Batch 10, Loss: 0.7010778188705444
Epoch 2, Batch 20, Loss: 0.652472734451294
Epoch 2, Batch 30, Loss: 0.6765684485435486
Epoch 3, Batch 0, Loss: 0.6411766409873962
Epoch 3, Batch 10, Loss: 0.7023191452026367
Epoch 3, Batch 20, Loss: 0.6400326490402222
Epoch 3, Batch 30, Loss: 0.6728585958480835
Epoch 4, Batch 0, Loss: 0.6226884126663208
Epoch 4, Batch 10, Loss: 0.6976571083068848
Epoch 4, Batch 20, Loss: 0.6121337413787842
Epoch 4, Batch 30, Loss: 0.6634902358055115
Epoch 5, Batch 0, Loss: 0.5955710411071777
Epoch 5, Batch 10, Loss: 0.6877567172050476
Epoch 5, Batch 20, Loss: 0.5906676054000

# esempio

In [None]:
import haiku as hk
import jax
import jax.numpy as jnp

class MyLayer(hk.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

    def __call__(self, edge_fts, graph_fts, adj_mat, hidden):
        edge_fts = hk.Linear(self.hidden_dim)(edge_fts)
        graph_fts = hk.Linear(self.hidden_dim)(graph_fts)
        hidden = hk.Linear(self.hidden_dim)(hidden)
        aggregated = jnp.matmul(adj_mat, edge_fts)
        output = aggregated + graph_fts + hidden
        return output

def model_fn(edge_fts, graph_fts, adj_mat, hidden):
    model = MyLayer(hidden.shape[-1])
    return model(edge_fts, graph_fts, adj_mat, hidden)

model_init, model_apply = hk.transform(model_fn)

# Example input data
edge_fts = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # Shape (num_edges, feature_dim)
graph_fts = jnp.array([[1.0, 2.0]])             # Shape (num_graphs, feature_dim)
adj_mat = jnp.array([[0, 1], [1, 0]])           # Shape (num_nodes, num_nodes)
hidden = jnp.array([[0.5, 0.5]])                # Shape (num_nodes, hidden_dim)

rng = jax.random.PRNGKey(42)

# Initialize the model parameters
params = model_init(rng, edge_fts, graph_fts, adj_mat, hidden)

# Perform a forward pass
output = model_apply(params, rng, edge_fts, graph_fts, adj_mat, hidden)
print(output)


[[-1.6547499   4.2710185 ]
 [ 0.34689993  1.204605  ]]


In [None]:
num_d2_actions = 4

d2_inverses = [
0, 1, 2, 3  # All members of D_2 are self-inverses!
]

d2_multiply = [
[0, 1, 2, 3],
[1, 0, 3, 2],
[2, 3, 0, 1],
[3, 2, 1, 0],
]

# assert len(node_fts) == num_d2_actions
# assert len(edge_fts) == num_d2_actions
# assert len(graph_fts) == num_d2_actions

ret_nodes = []
adj_mat = jnp.ones_like(adj_mat)

for g in range(num_d2_actions):
    emb_values = []
    for h in range(num_d2_actions):
        gh = d2_multiply[d2_inverses[g]][h]
        # print("gh", gh)
        # print("h", h)
        # print("d2_inverses[g]", d2_inverses[g])
        node_features = jnp.concatenate(
            (node_fts[g], node_fts[gh]),
            axis=-1)
        edge_features = jnp.concatenate(
            (edge_fts[g], edge_fts[gh]),
            axis=-1)
        graph_features = jnp.concatenate(
            (graph_fts[g], graph_fts[gh]),
            axis=-1)
    #     cell_embedding = model_apply(params, rng,
    #         node_fts=node_features,
    #         edge_fts=edge_features,
    #         graph_fts=graph_features,
    #         adj_mat=adj_mat,
    #         hidden=hidden
    #     )
    #     emb_values.append(cell_embedding[0])
    # ret_nodes.append(
    # jnp.mean(jnp.stack(emb_values, axis=0), axis=0)
    # )

gh 0
h 0
d2_inverses[g] 0
gh 1
h 1
d2_inverses[g] 0
gh 2
h 2
d2_inverses[g] 0
gh 3
h 3
d2_inverses[g] 0
gh 1
h 0
d2_inverses[g] 1
gh 0
h 1
d2_inverses[g] 1
gh 3
h 2
d2_inverses[g] 1
gh 2
h 3
d2_inverses[g] 1
gh 2
h 0
d2_inverses[g] 2
gh 3
h 1
d2_inverses[g] 2
gh 0
h 2
d2_inverses[g] 2
gh 1
h 3
d2_inverses[g] 2
gh 3
h 0
d2_inverses[g] 3
gh 2
h 1
d2_inverses[g] 3
gh 1
h 2
d2_inverses[g] 3
gh 0
h 3
d2_inverses[g] 3
