# Using the model

### Model

In [31]:
import abc
from typing import Any, Callable, List, Optional, Tuple

import chex
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


_Array = chex.Array
_Fn = Callable[..., Any]
BIG_NUMBER = 1e6
PROCESSOR_TAG = 'clrs_processor'


class Processor(hk.Module):
  """Abstract interface implemented by every processor in this file."""

  def __init__(self, name: str):
    """Registers the module under a name that ends with `PROCESSOR_TAG`.

    Args:
      name: Requested module name. `'_' + PROCESSOR_TAG` is appended unless
        the name already ends with `PROCESSOR_TAG`.
    """
    already_tagged = name.endswith(PROCESSOR_TAG)
    tagged_name = name if already_tagged else f'{name}_{PROCESSOR_TAG}'
    super().__init__(name=tagged_name)

  @abc.abstractmethod
  def __call__(
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **kwargs,
  ) -> Tuple[_Array, Optional[_Array]]:
    """Runs one processor step; concrete subclasses must override this.

    Args:
      node_fts: Node features.
      edge_fts: Edge features.
      graph_fts: Graph features.
      adj_mat: Graph adjacency matrix.
      hidden: Hidden features.
      **kwargs: Extra kwargs.

    Returns:
      A 2-tuple of (node, edge) embeddings produced by the step. The edge
      embeddings may be None.
    """

  @property
  def inf_bias(self) -> bool:
    """Flag read by callers; the base class always reports False."""
    return False

  @property
  def inf_bias_edge(self) -> bool:
    """Edge counterpart of `inf_bias`; the base class always reports False."""
    return False


class GAT(Processor):
  """Graph Attention Network (Velickovic et al., ICLR 2018).

  Each call computes multi-head attention coefficients from node, edge and
  graph features, masks them with the adjacency matrix, and uses them to
  aggregate per-head linear projections of the node inputs.
  """

  def __init__(
      self,
      out_size: int,
      nb_heads: int,
      activation: Optional[_Fn] = jax.nn.relu,
      residual: bool = True,
      use_ln: bool = False,
      name: str = 'gat_aggr',
  ):
    """Stores the layer configuration.

    Args:
      out_size: Width of the output node embeddings.
      nb_heads: Number of attention heads; must divide `out_size`.
      activation: Function applied to the aggregated output, or None to skip
        the activation.
      residual: Whether to add a linear projection of the inputs to the
        aggregated output.
      use_ln: Whether to apply layer normalisation to the final output.
      name: Module name, forwarded to the parent constructor.

    Raises:
      ValueError: If `nb_heads` does not divide `out_size`.
    """
    super().__init__(name=name)
    self.out_size = out_size
    self.nb_heads = nb_heads
    if out_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the width!')
    self.head_size = out_size // nb_heads  # Per-head width (F in comments).
    self.activation = activation
    self.residual = residual
    self.use_ln = use_ln

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> _Array:
    """GAT inference step.

    Shape comments use B = batch size, N = number of nodes, H = `nb_heads`
    and F = `head_size`.

    Args:
      node_fts: Node features, shape [B, N, ...].
      edge_fts: Edge features, shape [B, N, N, ...].
      graph_fts: Graph features, shape [B, ...].
      adj_mat: Adjacency matrix, shape [B, N, N].
      hidden: Hidden node features; concatenated with `node_fts` along the
        last axis, so the leading axes must match.
      **unused_kwargs: Accepted and ignored.

    Returns:
      The 2-tuple `(ret, None)`, where `ret` has shape [B, N, out_size].
      Note that this differs from the `_Array` return annotation.
    """

    # Shape checks; these also bind B and N for the rest of the method.
    b, n, _ = node_fts.shape
    assert edge_fts.shape[:-1] == (b, n, n)
    assert graph_fts.shape[:-1] == (b,)
    assert adj_mat.shape == (b, n, n)

    z = jnp.concatenate([node_fts, hidden], axis=-1)
    m = hk.Linear(self.out_size)     # Value projection.
    skip = hk.Linear(self.out_size)  # Residual projection (used if residual).

    # Additive attention mask: 0 where adj_mat is 1 and -1e9 where it is 0, so
    # the softmax below gives those entries (numerically) zero weight.
    # Assumes adj_mat holds 0/1 entries -- TODO confirm against callers.
    bias_mat = (adj_mat - 1.0) * 1e9
    bias_mat = jnp.tile(bias_mat[..., None],
                        (1, 1, 1, self.nb_heads))     # [B, N, N, H]
    bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2))  # [B, H, N, N]

    # One scalar attention term per head from each feature source.
    a_1 = hk.Linear(self.nb_heads)
    a_2 = hk.Linear(self.nb_heads)
    a_e = hk.Linear(self.nb_heads)
    a_g = hk.Linear(self.nb_heads)

    values = m(z)                                      # [B, N, H*F]
    values = jnp.reshape(
        values,
        values.shape[:-1] + (self.nb_heads, self.head_size))  # [B, N, H, F]
    values = jnp.transpose(values, (0, 2, 1, 3))              # [B, H, N, F]

    att_1 = jnp.expand_dims(a_1(z), axis=-1)          # [B, N, H, 1]
    att_2 = jnp.expand_dims(a_2(z), axis=-1)          # [B, N, H, 1]
    att_e = a_e(edge_fts)                             # [B, N, N, H]
    att_g = jnp.expand_dims(a_g(graph_fts), axis=-1)  # [B, H, 1]

    # The four terms broadcast against each other to one logit per
    # (batch, head, node, node) entry.
    logits = (
        jnp.transpose(att_1, (0, 2, 1, 3)) +  # + [B, H, N, 1]
        jnp.transpose(att_2, (0, 2, 3, 1)) +  # + [B, H, 1, N]
        jnp.transpose(att_e, (0, 3, 1, 2)) +  # + [B, H, N, N]
        jnp.expand_dims(att_g, axis=-1)       # + [B, H, 1, 1]
    )                                         # = [B, H, N, N]
    # Normalise over the last node axis after applying the mask.
    coefs = jax.nn.softmax(jax.nn.leaky_relu(logits) + bias_mat, axis=-1)
    ret = jnp.matmul(coefs, values)  # [B, H, N, F]
    ret = jnp.transpose(ret, (0, 2, 1, 3))  # [B, N, H, F]
    ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,))  # [B, N, H*F]

    if self.residual:
      ret += skip(z)

    if self.activation is not None:
      ret = self.activation(ret)

    if self.use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    # No edge embeddings are produced, hence the None second element.
    return ret, None  # pytype: disable=bad-return-type  # numpy-scalars


class GATFull(GAT):
  """GAT that attends over every node pair, whatever `adj_mat` contains."""

  def __call__(
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> _Array:
    """Runs the parent GAT step with `adj_mat` replaced by all ones.

    Only the shape and dtype of `adj_mat` are used; extra keyword arguments
    are accepted and dropped.
    """
    dense_adj = jnp.ones_like(adj_mat)
    return super().__call__(
        node_fts=node_fts,
        edge_fts=edge_fts,
        graph_fts=graph_fts,
        adj_mat=dense_adj,
        hidden=hidden,
    )


class GATv2(Processor):
  """Graph Attention Network v2 (Brody et al., ICLR 2022).

  Unlike `GAT`, the node, edge and graph features are first projected to a
  shared `mid_size` space and summed; a leaky ReLU and a per-head linear map
  then turn that sum into the attention logits.
  """

  def __init__(
      self,
      out_size: int,
      nb_heads: int,
      mid_size: Optional[int] = None,
      activation: Optional[_Fn] = jax.nn.relu,
      residual: bool = True,
      use_ln: bool = False,
      name: str = 'gatv2_aggr',
  ):
    """Stores the layer configuration.

    Args:
      out_size: Width of the output node embeddings.
      nb_heads: Number of attention heads; must divide both `out_size` and
        `mid_size`.
      mid_size: Width of the pre-attention projections. Defaults to
        `out_size` when None.
      activation: Function applied to the aggregated output, or None to skip
        the activation.
      residual: Whether to add a linear projection of the inputs to the
        aggregated output.
      use_ln: Whether to apply layer normalisation to the final output.
      name: Module name, forwarded to the parent constructor.

    Raises:
      ValueError: If `nb_heads` does not divide `out_size` or `mid_size`.
    """
    super().__init__(name=name)
    if mid_size is None:
      self.mid_size = out_size
    else:
      self.mid_size = mid_size
    self.out_size = out_size
    self.nb_heads = nb_heads
    if out_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the width!')
    self.head_size = out_size // nb_heads  # Per-head width of the values.
    if self.mid_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the message!')
    self.mid_head_size = self.mid_size // nb_heads  # Per-head logit input.
    self.activation = activation
    self.residual = residual
    self.use_ln = use_ln

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> _Array:
    """GATv2 inference step.

    Shape comments use B = batch size, N = number of nodes and
    H = `nb_heads`. F is `head_size` for `values`/`ret` and `mid_head_size`
    for the `pre_att` tensors.

    Args:
      node_fts: Node features, shape [B, N, ...].
      edge_fts: Edge features, shape [B, N, N, ...].
      graph_fts: Graph features, shape [B, ...].
      adj_mat: Adjacency matrix, shape [B, N, N].
      hidden: Hidden node features; concatenated with `node_fts` along the
        last axis, so the leading axes must match.
      **unused_kwargs: Accepted and ignored.

    Returns:
      The 2-tuple `(ret, None)`, where `ret` has shape [B, N, out_size].
      Note that this differs from the `_Array` return annotation.
    """

    # Shape checks; these also bind B and N for the rest of the method.
    b, n, _ = node_fts.shape
    assert edge_fts.shape[:-1] == (b, n, n)
    assert graph_fts.shape[:-1] == (b,)
    assert adj_mat.shape == (b, n, n)

    z = jnp.concatenate([node_fts, hidden], axis=-1)
    m = hk.Linear(self.out_size)     # Value projection.
    skip = hk.Linear(self.out_size)  # Residual projection (used if residual).

    # Additive attention mask: 0 where adj_mat is 1 and -1e9 where it is 0, so
    # the softmax below gives those entries (numerically) zero weight.
    # Assumes adj_mat holds 0/1 entries -- TODO confirm against callers.
    bias_mat = (adj_mat - 1.0) * 1e9
    bias_mat = jnp.tile(bias_mat[..., None],
                        (1, 1, 1, self.nb_heads))     # [B, N, N, H]
    bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2))  # [B, H, N, N]

    # Projections of each feature source into the shared mid_size space.
    w_1 = hk.Linear(self.mid_size)
    w_2 = hk.Linear(self.mid_size)
    w_e = hk.Linear(self.mid_size)
    w_g = hk.Linear(self.mid_size)

    # One scalar-output linear map per attention head.
    a_heads = []
    for _ in range(self.nb_heads):
      a_heads.append(hk.Linear(1))

    values = m(z)                                      # [B, N, H*F]
    values = jnp.reshape(
        values,
        values.shape[:-1] + (self.nb_heads, self.head_size))  # [B, N, H, F]
    values = jnp.transpose(values, (0, 2, 1, 3))              # [B, H, N, F]

    pre_att_1 = w_1(z)            # [B, N, mid_size]
    pre_att_2 = w_2(z)            # [B, N, mid_size]
    pre_att_e = w_e(edge_fts)     # [B, N, N, mid_size]
    pre_att_g = w_g(graph_fts)    # [B, mid_size]

    # Broadcast-sum the four projections into one tensor per node pair.
    pre_att = (
        jnp.expand_dims(pre_att_1, axis=1) +     # + [B, 1, N, H*F]
        jnp.expand_dims(pre_att_2, axis=2) +     # + [B, N, 1, H*F]
        pre_att_e +                              # + [B, N, N, H*F]
        jnp.expand_dims(pre_att_g, axis=(1, 2))  # + [B, 1, 1, H*F]
    )                                            # = [B, N, N, H*F]

    # Split the last axis into heads.
    pre_att = jnp.reshape(
        pre_att,
        pre_att.shape[:-1] + (self.nb_heads, self.mid_head_size)
    )  # [B, N, N, H, F]

    pre_att = jnp.transpose(pre_att, (0, 3, 1, 2, 4))  # [B, H, N, N, F]

    # This part is not very efficient, but we agree to keep it this way to
    # enhance readability, assuming `nb_heads` will not be large.
    logit_heads = []
    for head in range(self.nb_heads):
      logit_heads.append(
          jnp.squeeze(
              a_heads[head](jax.nn.leaky_relu(pre_att[:, head])),
              axis=-1)
      )  # [B, N, N]

    logits = jnp.stack(logit_heads, axis=1)  # [B, H, N, N]

    # Normalise over the last node axis after applying the mask.
    coefs = jax.nn.softmax(logits + bias_mat, axis=-1)
    ret = jnp.matmul(coefs, values)  # [B, H, N, F]
    ret = jnp.transpose(ret, (0, 2, 1, 3))  # [B, N, H, F]
    ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,))  # [B, N, H*F]

    if self.residual:
      ret += skip(z)

    if self.activation is not None:
      ret = self.activation(ret)

    if self.use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    # No edge embeddings are produced, hence the None second element.
    return ret, None  # pytype: disable=bad-return-type  # numpy-scalars


class GATv2FullD2(GATv2):
  """GATv2 on a fully connected graph, combined over four D2 views."""

  def d2_forward(self,
                 node_fts: List[_Array],
                 edge_fts: List[_Array],
                 graph_fts: List[_Array],
                 adj_mat: _Array,
                 hidden: _Array,
                 **unused_kwargs) -> List[_Array]:
    """Runs GATv2 on every pairing of the four views and averages per view.

    For each view `g`, the features of `g` are concatenated (last axis) with
    those of each of the four views in turn, passed through the parent GATv2
    step with an all-ones adjacency matrix, and the four resulting node
    embeddings are averaged.

    Args:
      node_fts: Four node-feature arrays, one per D2 group element.
      edge_fts: Four edge-feature arrays, one per D2 group element.
      graph_fts: Four graph-feature arrays, one per D2 group element.
      adj_mat: Adjacency matrix; only its shape and dtype are used.
      hidden: Hidden node features, shared by every GATv2 call.
      **unused_kwargs: Accepted and ignored.

    Returns:
      A list of four node-embedding arrays, one per group element.
    """
    group_order = 4

    # All members of D_2 are self-inverses, so this table is the identity.
    inverse_of = (0, 1, 2, 3)

    # Multiplication table of D_2: product_of[a][b] is the index of a * b.
    product_of = (
        (0, 1, 2, 3),
        (1, 0, 3, 2),
        (2, 3, 0, 1),
        (3, 2, 1, 0),
    )

    assert len(node_fts) == group_order
    assert len(edge_fts) == group_order
    assert len(graph_fts) == group_order

    dense_adj = jnp.ones_like(adj_mat)
    # Bound once here: zero-argument super() is only valid directly inside
    # the method body, not inside the nested helper below.
    gatv2_step = super().__call__

    def paired(views, first, second):
      """Concatenates two views of one feature kind along the last axis."""
      return jnp.concatenate((views[first], views[second]), axis=-1)

    averaged_per_view = []
    for g in range(group_order):
      node_outputs = []
      for gh in product_of[inverse_of[g]]:
        node_out, _ = gatv2_step(
            node_fts=paired(node_fts, g, gh),
            edge_fts=paired(edge_fts, g, gh),
            graph_fts=paired(graph_fts, g, gh),
            adj_mat=dense_adj,
            hidden=hidden,
        )
        node_outputs.append(node_out)
      stacked = jnp.stack(node_outputs, axis=0)
      averaged_per_view.append(jnp.mean(stacked, axis=0))

    return averaged_per_view

class GATv2Full(GATv2):
  """GATv2 that attends over every node pair, whatever `adj_mat` contains."""

  def __call__(
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> _Array:
    """Runs the parent GATv2 step with `adj_mat` replaced by all ones.

    Only the shape and dtype of `adj_mat` are used; extra keyword arguments
    are accepted and dropped.
    """
    dense_adj = jnp.ones_like(adj_mat)
    return super().__call__(
        node_fts=node_fts,
        edge_fts=edge_fts,
        graph_fts=graph_fts,
        adj_mat=dense_adj,
        hidden=hidden,
    )

def forward(node_fts, edge_fts, graph_fts, adj_mat, hidden,
            out_size=128, nb_heads=8):
    """Builds a GATv2FullD2 processor and runs its D2 forward pass.

    The module is instantiated inside this function so that the function can
    be wrapped with `hk.transform`.

    Args:
      node_fts: List of four node-feature arrays, one per D2 view.
      edge_fts: List of four edge-feature arrays, one per D2 view.
      graph_fts: List of four graph-feature arrays, one per D2 view.
      adj_mat: Adjacency matrix.
      hidden: Hidden node features.
      out_size: Width of the output node embeddings (default 128, the value
        that used to be hard-coded here).
      nb_heads: Number of attention heads (default 8, the value that used to
        be hard-coded here). Must divide `out_size`.

    Returns:
      The list of four node-embedding arrays returned by `d2_forward`.
    """
    model = GATv2FullD2(out_size=out_size, nb_heads=nb_heads)
    return model.d2_forward(node_fts, edge_fts, graph_fts, adj_mat, hidden)

# main
batch_size = 1
n_players = 22
n_node_features = 5
n_edge_features = 3
n_graph_features = 2
n_latent_features = 4

# Create the data. Every array carries a leading batch axis of size
# `batch_size`, and the feature lists hold four entries because `d2_forward`
# asserts exactly four views. The batch axis is drawn directly instead of
# being added with `expand_dims`, so `batch_size` is actually honoured; with
# `batch_size == 1` the shapes and drawn values match the previous code.
node_fts = [
    jnp.array(np.random.rand(batch_size, n_players, n_node_features))
    for _ in range(4)
]
edge_fts = [
    jnp.array(
        np.random.rand(batch_size, n_players, n_players, n_edge_features))
    for _ in range(4)
]
graph_fts = [
    jnp.array(np.random.rand(batch_size, n_graph_features))
    for _ in range(4)
]
adj_mat = jnp.array(
    np.random.randint(0, 2, (batch_size, n_players, n_players)))
hidden = jnp.zeros((batch_size, n_players, n_latent_features))

# Transform the function using Haiku
transformed_forward = hk.transform(forward)

# Create a JAX random key for initialization
rng = jax.random.PRNGKey(42)

# Initialize the model (parameters)
params = transformed_forward.init(
    rng, node_fts, edge_fts, graph_fts, adj_mat, hidden)

# Apply the model. `output` is a list of four node-embedding arrays, one per
# view.
output = transformed_forward.apply(
    params, rng, node_fts, edge_fts, graph_fts, adj_mat, hidden)


### Receiver Prediction

In [27]:
# Define a linear layer to compute scores
class ReceiverPredictor(hk.Module):
    """Scores each embedding vector with a single linear layer."""

    def __init__(self, name=None):
        """Creates the scoring layer.

        Args:
          name: Optional Haiku module name.
        """
        super().__init__(name=name)
        # One output unit: a single scalar score per embedding vector.
        self.linear = hk.Linear(output_size=1)

    def __call__(self, embeddings):
        """Returns scores with the last axis of `embeddings` mapped to size 1."""
        scores = self.linear(embeddings)
        return scores

# Initialize and apply the linear layer
def receiver_prediction_forward(embeddings):
    """Builds a ReceiverPredictor and scores `embeddings` with it.

    The module is created inside the function so that the function can be
    wrapped with `hk.transform`.
    """
    return ReceiverPredictor()(embeddings)

# Collapse the four per-view embeddings into one by averaging over the views.
averaged_embeddings = jnp.mean(jnp.stack(output, axis=0), axis=0)

# Wrap the scoring function with Haiku, create its parameters, then score the
# averaged embeddings.
transformed_receiver_prediction = hk.transform(receiver_prediction_forward)
params_receiver = transformed_receiver_prediction.init(
    rng, averaged_embeddings)
scores = transformed_receiver_prediction.apply(
    params_receiver, rng, averaged_embeddings)

# Drop the trailing size-1 axis, then normalise over the remaining last axis
# (softmax uses axis=-1 by default).
scores = jnp.squeeze(scores, axis=-1)
probabilities = jax.nn.softmax(scores)

# Show the resulting distribution.
print(probabilities)

[[0.04747611 0.04750526 0.0488289  0.04205194 0.04901259 0.0424196
  0.04168977 0.04840901 0.04685603 0.04676333 0.04718176 0.04297247
  0.04319186 0.04787761 0.04247737 0.04570602 0.04057295 0.04831294
  0.04313481 0.04904582 0.04600187 0.04251195]]


### Shot prediction

In [32]:
# Step 1: Collapse the four per-view embeddings into one by averaging over
# the views.
averaged_embeddings = jnp.mean(jnp.stack(output, axis=0), axis=0)

# Step 2: Define a linear layer to compute the shot prediction score
class ShotPredictor(hk.Module):
    """Pools embeddings over axis 1 and maps the result to one scalar score."""

    def __init__(self, name=None):
        """Creates the scoring layer.

        Args:
          name: Optional Haiku module name.
        """
        super().__init__(name=name)
        # One output unit: a single scalar score per pooled embedding.
        self.linear = hk.Linear(output_size=1)

    def __call__(self, embeddings):
        """Returns the linear score of the mean of `embeddings` over axis 1."""
        pooled = jnp.mean(embeddings, axis=1)
        score = self.linear(pooled)
        return score

# Step 3: Initialize and apply the linear layer for shot prediction
def shot_prediction_forward(embeddings):
    """Builds a ShotPredictor and scores `embeddings` with it.

    The module is created inside the function so that the function can be
    wrapped with `hk.transform`.
    """
    return ShotPredictor()(embeddings)

# Wrap the shot prediction function with Haiku.
transformed_shot_prediction = hk.transform(shot_prediction_forward)

# Fixed-seed JAX key used for parameter initialisation.
rng = jax.random.PRNGKey(42)

# Create the parameters of the shot prediction model.
params_shot = transformed_shot_prediction.init(rng, averaged_embeddings)

# Run the model to obtain the raw score.
score = transformed_shot_prediction.apply(
    params_shot, rng, averaged_embeddings)

# Step 4: Squash the score through a sigmoid to get a probability.
shot_probability = jax.nn.sigmoid(score)

# Show the shot probability.
print(shot_probability)

[[0.5949052]]


### Suggest adjustments in player positions and velocities

In [34]:
from jax import random

# Step 1: Collapse the four per-view embeddings into one by averaging over
# the views.
averaged_embeddings = jnp.mean(jnp.stack(output, axis=0), axis=0)

# Step 2: Define a VAE (Variational Autoencoder) to generate new player positions and velocities
class TacticGenerator(hk.Module):
    """Samples a 4-vector per embedding from a learned Gaussian.

    A linear encoder feeds two linear heads that output the mean and the log
    standard deviation; a sample is then drawn with the reparameterization
    trick. The four outputs are used as positions (x, y) and velocities
    (vx, vy).
    """

    def __init__(self, name=None):
        """Creates the encoder and the two Gaussian-parameter heads.

        Args:
          name: Optional Haiku module name.
        """
        super().__init__(name=name)
        # Keep this creation order: Haiku numbers same-typed layers in the
        # order they are built, which fixes their parameter names.
        self.encoder = hk.Linear(64)
        self.mean = hk.Linear(4)     # (x, y, vx, vy) means
        self.log_std = hk.Linear(4)  # (x, y, vx, vy) log standard deviations

    def __call__(self, embeddings, rng):
        """Draws one sample per embedding vector.

        Args:
          embeddings: Input embeddings.
          rng: JAX PRNG key used to draw the Gaussian noise.

        Returns:
          A `(z, mean, log_std)` tuple: the sample and the two Gaussian
          parameters it was drawn from.
        """
        hidden_code = self.encoder(embeddings)

        mean = self.mean(hidden_code)
        log_std = self.log_std(hidden_code)

        # Reparameterization: shift and scale standard normal noise.
        noise = random.normal(rng, mean.shape)
        sample = mean + noise * jnp.exp(log_std)

        return sample, mean, log_std

# Step 3: Initialize and apply the VAE for tactic generation
def tactic_generation_forward(embeddings, rng):
    """Builds a TacticGenerator and samples from it.

    The module is created inside the function so that the function can be
    wrapped with `hk.transform`. `rng` is the key for the sampling noise.
    """
    return TacticGenerator()(embeddings, rng)

# Transform the tactic generation function using Haiku
transformed_tactic_generation = hk.transform(tactic_generation_forward)

# Create a JAX random key, then split it so that parameter initialisation and
# noise sampling use independent keys. Passing the same key to both (as was
# done before) reuses a key, which JAX's PRNG design asks callers never to do.
rng = jax.random.PRNGKey(42)
param_key, sample_key = jax.random.split(rng)

# Initialize the tactic generation model (parameters). The sample drawn while
# tracing `init` is discarded; only the parameters are kept.
params_tactic = transformed_tactic_generation.init(
    param_key, averaged_embeddings, sample_key)

# Apply the model to generate new positions and velocities. The module draws
# its noise from the explicit key argument and never asks Haiku for one, so
# Haiku's own apply-time key is None.
new_positions_velocities, mean, log_std = transformed_tactic_generation.apply(
    params_tactic, None, averaged_embeddings, sample_key)

# Split the last axis of the sample into positions and velocities.
new_positions = new_positions_velocities[:, :, :2]  # (x, y)
new_velocities = new_positions_velocities[:, :, 2:]  # (vx, vy)

# Print the new positions and velocities
print("New Positions:", new_positions)
print("New Velocities:", new_velocities)

New Positions: [[[-0.15373099 -0.5776852 ]
  [-0.54797816  0.349627  ]
  [ 0.796337    0.27583018]
  [ 0.8361689   0.658131  ]
  [-1.1743636   0.26393038]
  [ 1.1182784   0.09397618]
  [-0.972666    0.57598615]
  [ 0.88302577 -1.1866739 ]
  [ 0.38232225 -0.19630621]
  [-1.4119033  -0.42984524]
  [-1.0729328   1.243825  ]
  [-1.4856875   0.30136743]
  [-0.2459184   1.2662115 ]
  [-2.2002583   0.9062169 ]
  [-2.3522437   0.55627424]
  [ 0.46456093  0.10938837]
  [ 1.0150944   0.8138926 ]
  [-0.4559812   2.5382173 ]
  [-0.50415194  0.20640746]
  [-0.810416   -1.6976104 ]
  [ 0.7919555   1.781926  ]
  [ 1.92274     1.7019848 ]]]
New Velocities: [[[ 0.29783598  0.17285213]
  [ 0.4111511  -1.1176956 ]
  [-0.5487272  -0.26066583]
  [ 0.83609235 -0.8582128 ]
  [ 0.7402787  -1.2575378 ]
  [ 0.74992824  1.0300466 ]
  [ 0.05962151  0.18771464]
  [ 1.0322506  -1.770295  ]
  [ 0.619172    0.41612193]
  [ 0.2987381  -0.21577743]
  [-0.22905567  0.49388447]
  [ 0.2659631  -0.33611947]
  [ 0.14251202 