In [4]:
from clrs._src.processors import GATv2FullD2
import haiku as hk
import optax
import jax
import jax.numpy as jnp
from jax.numpy import array
import pickle

# Load Data

In [5]:
# Load the pickled graph dictionary from disk.
# NOTE(review): pickle.load can execute arbitrary code stored in the file, so
# only open pickles that come from a trusted source. The absolute path below
# is specific to one machine.
grafo_dati_path = 'C:\\Users\\Lenovo\\OneDrive\\Documenti\\Università\\Ricerca\\NFL kaggle challenge 2024\\data\\grafo_dati.pkl'
with open(grafo_dati_path, 'rb') as pkl_file:
    grafo_dati = pickle.load(pkl_file)
    print('grafo dati loaded succesfully')

grafo dati loaded succesfully


## Global variables for training

In [6]:
# Keys of the loaded samples (list() over a dict yields its keys).
indice = list(grafo_dati)
num_epochs = 10
batch_size = 128
# Number of *full* mini-batches per epoch. The training loop slices
# [batch_idx * batch_size, (batch_idx + 1) * batch_size), and adj_mat below is
# built for exactly `batch_size` graphs, so a trailing partial batch is dropped.
# (Using len(indice) here would iterate once per sample and slice past the end
# of the dataset.)
num_batches = len(indice) // batch_size
# All-ones adjacency over the 23 nodes, one matrix per element of a batch.
adj_mat = jnp.ones(shape=(batch_size, 23, 23))

In [12]:
# Peek at the shape of the graph-level features of the first sample.
first_sample = grafo_dati[indice[0]]
first_sample["graph_fts"].shape

(5,)

In [14]:
# Sanity check: every sample must carry feature arrays of the expected shape.
expected_shapes = {
    "node_fts": (23, 7),
    "edge_fts": (23, 23),
    "graph_fts": (5,),
}
for key in indice:
    sample = grafo_dati[key]
    for field, shape in expected_shapes.items():
        assert sample[field].shape == shape
    

In [57]:
# Pre-allocate the datasets. Axis 0 indexes the samples; axis 1 (size 4) holds
# the four views of each sample that the populate cell writes with .at[i].set.
# dataset_node_fts must be created here as well: the populate cell reads it
# before assigning it, so leaving it out raises NameError on a fresh kernel.
dataset_node_fts = jnp.zeros(shape=(len(indice), 4, 23, 7))
dataset_edge_fts = jnp.zeros(shape=(len(indice), 4, 23, 23, 1))
dataset_graph_fts = jnp.zeros(shape=(len(indice), 4, 5))
labels = jnp.zeros(shape=(len(indice), 1))

In [58]:
# populate datasets
#
# Each sample is expanded into four views: the original array plus copies
# reversed along axis 0, along axis 1, and along both axes. The per-sample
# results are collected in Python lists and stacked once at the end: updating
# a pre-allocated array with `.at[i].set(...)` outside of jit copies the whole
# dataset on every iteration.
node_views, edge_views, graph_rows, label_rows = [], [], [], []
for ind in indice:
    sample = grafo_dati[ind]

    node_fts = sample["node_fts"]
    node_views.append(jnp.stack([node_fts,
                                 node_fts[::-1, :],
                                 node_fts[:, ::-1],
                                 node_fts[::-1, ::-1]], axis=0))

    edge_fts = sample["edge_fts"]
    edge_views.append(jnp.stack([edge_fts,
                                 edge_fts[::-1, :],
                                 edge_fts[:, ::-1],
                                 edge_fts[::-1, ::-1]], axis=0))

    graph_rows.append(sample["graph_fts"])
    label_rows.append(sample["label"])

# `.astype(float)` / `dtype=float` select JAX's default floating dtype, i.e.
# the same dtype a jnp.zeros(...) pre-allocation would have.
dataset_node_fts = jnp.stack(node_views).astype(float)            # (N, 4, 23, 7)
# Trailing singleton axis: one scalar feature per edge.
dataset_edge_fts = jnp.expand_dims(jnp.stack(edge_views), axis=-1).astype(float)  # (N, 4, 23, 23, 1)
# The graph-level features are identical for all four views, so repeat them.
graph_stack = jnp.stack(graph_rows).astype(float)                  # (N, 5)
dataset_graph_fts = jnp.stack([graph_stack] * dataset_node_fts.shape[1], axis=1)  # (N, 4, 5)
labels = jnp.array(label_rows, dtype=float).reshape(len(indice), 1)  # (N, 1)


In [71]:
# Report the shape of each assembled dataset (node, edge, graph features).
for dataset in (dataset_node_fts, dataset_edge_fts, dataset_graph_fts):
    print(dataset.shape)

(3444, 4, 23, 7)
(3444, 4, 23, 23, 1)
(3444, 4, 5)


In [72]:
# NOTE: this cell rebinds its own inputs, so it must be run exactly once after
# the datasets are populated; running it twice swaps the axes back and breaks
# the label reshape.

# One-hot encode the label column into two classes, one row per sample.
labels = jax.nn.one_hot(labels, 2).reshape((len(indice), 2))

# Exchange the first two axes of every dataset so that the size-4 view axis
# comes first and the sample axis second.
dataset_node_fts = jnp.swapaxes(dataset_node_fts, 0, 1)
dataset_edge_fts = jnp.swapaxes(dataset_edge_fts, 0, 1)
dataset_graph_fts = jnp.swapaxes(dataset_graph_fts, 0, 1)


# Create GNN

In [67]:
class myNet_d2(hk.Module):
    """Four GATv2FullD2 layers applied in sequence, followed by a linear read-out.

    After each layer the layer output is stacked with ``array`` and averaged
    over its leading axis; that average is fed to the next layer as ``hidden``.
    The average of the last layer is flattened per batch element and passed
    through ``hk.Linear``.

    Args:
        gatv2_out_size: ``out_size`` forwarded to every GATv2FullD2 layer.
        gatv2_nb_heads: ``nb_heads`` forwarded to every GATv2FullD2 layer.
        linear_out_size: output size of the final linear layer (default 2).
    """

    def __init__(self, gatv2_out_size: int, gatv2_nb_heads: int, linear_out_size: int = 2):
        super().__init__()
        gat_kwargs = dict(out_size=gatv2_out_size, nb_heads=gatv2_nb_heads)
        self.gatv2_1 = GATv2FullD2(**gat_kwargs)
        self.gatv2_2 = GATv2FullD2(**gat_kwargs)
        self.gatv2_3 = GATv2FullD2(**gat_kwargs)
        self.gatv2_4 = GATv2FullD2(**gat_kwargs)

        self.linear = hk.Linear(output_size=linear_out_size)

    def __call__(self, fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden):
        """Run the four attention layers and return the linear layer's output.

        ``fedge_fts``, ``fgraph_fts`` and ``adj_mat`` are passed unchanged to
        every layer; ``fnode_fts`` and ``hidden`` seed the first layer only.
        Assumes ``d2_forward`` returns a sequence of equally shaped 3-D arrays
        (batch, nodes, features) that ``array`` can stack -- TODO confirm
        against the GATv2FullD2 implementation.
        """
        stacked = fnode_fts
        pooled = hidden
        for layer in (self.gatv2_1, self.gatv2_2, self.gatv2_3, self.gatv2_4):
            stacked = array(layer.d2_forward(stacked, fedge_fts, fgraph_fts, adj_mat, pooled))
            # Average over the leading (stacked) axis.
            pooled = jnp.mean(stacked, axis=0)

        batch, nodes, feats = stacked.shape[1:]
        # Flatten the node features of each batch element onto the last axis.
        flat = jnp.reshape(pooled, (batch, nodes * feats))
        return self.linear(flat)

def myNet_d2_fn(fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden):
    """Build a myNet_d2 (out size 8, 8 heads) and apply it to the given inputs.

    Intended to be wrapped by ``hk.transform``; all arguments are forwarded
    unchanged to the network's ``__call__``.
    """
    return myNet_d2(gatv2_out_size=8, gatv2_nb_heads=8)(
        fnode_fts, fedge_fts, fgraph_fts, adj_mat, hidden
    )

# Turn the Haiku forward function into a pure (init, apply) pair.
myNet_d2_transformed = hk.transform(myNet_d2_fn)
myNet_d2_init = myNet_d2_transformed.init
myNet_d2_apply = myNet_d2_transformed.apply

# Fixed seed so that parameter initialisation is reproducible.
rng = jax.random.PRNGKey(42)

# Loss function

In [68]:
def loss_fn(params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden, y):
    """Mean softmax cross-entropy between the network logits and targets ``y``.

    ``params`` and ``rng`` are handed to ``myNet_d2_apply`` together with the
    five model inputs; ``y`` holds the target distribution for each logit row.
    """
    logits = myNet_d2_apply(params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden)
    per_example = optax.softmax_cross_entropy(logits, y)
    return jnp.mean(per_example)

# Set up optimizer

In [69]:
import optax

# Adam optimizer with a fixed step size.
learning_rate = 1e-4
optimizer = optax.adam(learning_rate=learning_rate)

# Initialize params and optimizer

In [75]:
# Initialise the network parameters by tracing the model on the first
# mini-batch. The slice length must equal `batch_size`, because adj_mat is
# built with `batch_size` as its leading dimension; a hard-coded 128 would
# silently go out of sync if batch_size is changed.
params = myNet_d2_init(
    rng,
    dataset_node_fts[:, :batch_size],
    dataset_edge_fts[:, :batch_size],
    dataset_graph_fts[:, :batch_size],
    adj_mat,
    dataset_node_fts[0, :batch_size],  # first view's node features seed `hidden`
)

# Optimizer state matching the freshly initialised parameters.
opt_state = optimizer.init(params)

# Define train step

In [77]:
@jax.jit
def train_step(params, opt_state, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden, y):
    """Perform one gradient update with the module-level ``optimizer``.

    Differentiates ``loss_fn`` with respect to ``params`` and applies the
    resulting optimizer updates.

    Returns:
        A tuple ``(new_params, new_opt_state, loss)``, where ``loss`` is the
        value computed before the update.
    """
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden, y)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

In [79]:
# Smoke-test a single optimisation step on the first mini-batch (this also
# triggers jit compilation). The slice length follows `batch_size` so it stays
# consistent with adj_mat, and only the loss (third element of the returned
# tuple) is displayed instead of the full parameter and optimizer-state trees.
# The updated params/opt_state are discarded.
train_step(
    params,
    opt_state,
    rng,
    dataset_node_fts[:, :batch_size],
    dataset_edge_fts[:, :batch_size],
    dataset_graph_fts[:, :batch_size],
    adj_mat,
    dataset_node_fts[0, :batch_size],
    labels[:batch_size],
)[2]

# Train model

In [None]:
# Only full mini-batches can be fed to the model: adj_mat has a fixed leading
# dimension of `batch_size`, so a partial or empty slice would not match it.
# The sample axis of the datasets is axis 1. Clamping here keeps the loop
# correct even if num_batches is larger than the number of full batches.
full_batches = min(num_batches, dataset_node_fts.shape[1] // batch_size)

# NOTE(review): the same `rng` key is passed to every step; split it per step
# if the model draws random numbers at apply time -- confirm.
for epoch in range(num_epochs):
    for batch_idx in range(full_batches):
        batch = slice(batch_idx * batch_size, (batch_idx + 1) * batch_size)

        node_fts_batch = dataset_node_fts[:, batch]
        edge_fts_batch = dataset_edge_fts[:, batch]
        graph_fts_batch = dataset_graph_fts[:, batch]
        # The first view's node features are used as the initial hidden state.
        hidden_batch = dataset_node_fts[0, batch]

        y_batch = labels[batch]

        params, opt_state, loss = train_step(params, opt_state, rng, node_fts_batch, edge_fts_batch, graph_fts_batch, adj_mat, hidden_batch, y_batch)

        if batch_idx % 10 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss}')