In [40]:
import jax
from jax import numpy as jnp
from functools import partial

In [2]:
jax.default_device = jax.devices("cpu")[0]

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:
jax.config.update("jax_enable_x64", True)

In [4]:
X = jnp.array([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=jnp.float64)
X.dtype

dtype('float64')

In [5]:
seed = 33
key = jax.random.PRNGKey(seed)

In [6]:
def init_params(key, layer_sizes):
    scale_weights = 1e-1
    scale_biases = 1e-1
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = jax.random.split(key)
        params.append(dict(weights=scale_weights * jax.random.normal(subkey, (n_in, n_out)), 
                           biases=scale_biases * jax.random.normal(subkey, (n_out,)))
                      )
    return params

In [7]:
layer_sizes = [9, 128, 128, 9]
params = init_params(key, layer_sizes)

In [8]:
jax.tree_map(lambda x: x.shape, params)

[{'biases': (128,), 'weights': (9, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (9,), 'weights': (128, 9)}]

In [9]:
@jax.jit
def forward(params, x):
    x = jnp.ravel(x)
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
    raw_out = jnp.reshape(x @ last['weights'] + last['biases'], (3, 3))
    out = jax.nn.softmax(raw_out, axis=1)
    return out
forward(params, X)

Array([[0.34678932, 0.33656011, 0.31665057],
       [0.22175218, 0.32623454, 0.45201328],
       [0.3609601 , 0.28891771, 0.35012219]], dtype=float64)

In [46]:
@jax.jit
def group_loss(size, bad_edge_sum):
    max_size = size * (size - 1)
    max_size = jnp.where(max_size < 1., 1., max_size)
    return bad_edge_sum / max_size

In [48]:
@jax.jit
def loss_fn(params, x, y):
    pred = forward(params, x)
    group_ids = jnp.argmax(pred, axis=1)
    groups = {}
    for x_idx, group_id in enumerate(group_ids):
        if not groups.get(group_id, False):
            groups[group_id] = {"size": 0, "x_idxes": set(), "bad_edge_sum": 0}
        groups[group_id]["size"] += 1
        groups[group_id]["x_idxes"].add(x_idx)
    for _, group in groups.items():
        for x_idx in group["x_idxes"]:
            for other_idx in range(x.shape[0]):
                if other_idx == x_idx: continue
                if other_idx not in group["x_idxes"]: group["bad_edge_sum"] += jnp.where(x[x_idx, other_idx] == 1, 1, 0)
                if other_idx in group["x_idxes"]: group["bad_edge_sum"] += jnp.where(x[x_idx, other_idx] == 0, 1, 0)
    group_infos = {group_id: {"size": group["size"], "bad_edge_sum": group["bad_edge_sum"]} for group_id, group in groups.items()}
    losses = list(map(lambda group: group_loss(group["size"], group["bad_edge_sum"]), list(group_infos.values())))
    loss = jnp.sum(jnp.asarray(losses))
    return loss
loss_fn(params, X, forward(params, X))

Array(4., dtype=float64)