In [1]:
import jax
from jax import numpy as jnp
from functools import partial
import copy

In [2]:
jax.default_device = jax.devices("cpu")[0]

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:
jax.config.update("jax_enable_x64", True)

In [4]:
X = jnp.array([[1,1,0,1,1],[1,1,1,0,0],[0,1,1,0,1],[1,0,0,1,0],[1,0,1,0,1]], dtype=jnp.float64)
X.dtype

dtype('float64')

In [5]:
seed = 33
key = jax.random.PRNGKey(seed)

In [6]:
def init_params(key, layer_sizes):
    scale_weights = 1e-1
    scale_biases = 1e-1
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = jax.random.split(key)
        params.append(dict(weights=scale_weights * jax.random.normal(subkey, (n_in, n_out)), 
                           biases=scale_biases * jax.random.normal(subkey, (n_out,)))
                      )
    return params

In [7]:
size_based_on_input = X.shape[0] * X.shape[1]
layer_sizes = [size_based_on_input, 128, 128, size_based_on_input]
params = init_params(key, layer_sizes)

In [8]:
jax.tree_map(lambda x: x.shape, params)

[{'biases': (128,), 'weights': (25, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (25,), 'weights': (128, 25)}]

In [9]:
@jax.jit
def forward(params, x):
    shape = x.shape
    x = jnp.ravel(x)
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
    raw_out = jnp.reshape(x @ last['weights'] + last['biases'], shape)
    out = jax.nn.softmax(raw_out, axis=1)
    return out
forward(params, X)

Array([[0.23096333, 0.19699308, 0.26861276, 0.18406525, 0.11936557],
       [0.17756499, 0.20624649, 0.18322543, 0.27584555, 0.15711755],
       [0.22797607, 0.1946949 , 0.22207917, 0.1648719 , 0.19037796],
       [0.18770376, 0.14453747, 0.19382609, 0.31161444, 0.16231824],
       [0.16732629, 0.17238931, 0.20177378, 0.20262801, 0.2558826 ]],      dtype=float64)

In [10]:
@jax.jit
def group_loss(size, bad_edge_sum, group_val_sum):
    max_size = size * (size - 1)
    max_size = jnp.where(max_size < 1., 0.5, max_size)
    return (bad_edge_sum / max_size) - (group_val_sum / size)

In [11]:
@jax.jit
def loss_fn(params, x):
    pred = forward(params, x)
    group_ids = jnp.argmax(pred, axis=1)
    groups = {}
    for x_idx, group_id in enumerate(group_ids):
        if not groups.get(group_id, False): groups[group_id] = {"size": 0, "x_idxes": set(), "bad_edge_sum": 0, "group_val_sum": 0}
        groups[group_id]["size"] += 1
        groups[group_id]["x_idxes"].add(x_idx)
        groups[group_id]["group_val_sum"] += pred[x_idx, group_id]
    for _, group in groups.items():
        for x_idx in group["x_idxes"]:
            for other_idx in range(x.shape[0]):
                if other_idx == x_idx: continue
                if other_idx not in group["x_idxes"]: group["bad_edge_sum"] += jnp.where(x[x_idx, other_idx] == 1, 1, 0)
                if other_idx in group["x_idxes"]: group["bad_edge_sum"] += jnp.where(x[x_idx, other_idx] == 0, 1, 0)
    group_infos = {group_id: {"size": group["size"], "bad_edge_sum": group["bad_edge_sum"], "group_val_sum": group["group_val_sum"]} for group_id, group in groups.items()}
    losses = list(map(lambda group: group_loss(group["size"], group["bad_edge_sum"], group["group_val_sum"]), list(group_infos.values())))
    loss = jnp.sum(jnp.asarray(losses))
    return loss
loss_fn(params, X)

Array(18.66006858, dtype=float64)

In [19]:
def get_groups(params, x):
    pred = forward(params, x)
    group_ids = jnp.argmax(pred, axis=1)
    groups = {}
    for x_idx, group_id in enumerate(group_ids):
        try:
            if not groups.get(group_id, False): groups[group_id] = {"size": 0, "x_idxes": set(), "bad_edge_sum": 0, "group_val_sum": 0}
        except:
            group_id = int(group_id)
            if not groups.get(group_id, False): groups[group_id] = {"size": 0, "x_idxes": set(), "bad_edge_sum": 0, "group_val_sum": 0}
        groups[group_id]["size"] += 1
        groups[group_id]["x_idxes"].add(x_idx)
        groups[group_id]["group_val_sum"] += pred[x_idx, group_id]
    for _, group in groups.items():
        for x_idx in group["x_idxes"]:
            for other_idx in range(x.shape[0]):
                if other_idx == x_idx: continue
                if other_idx not in group["x_idxes"]: group["bad_edge_sum"] += jnp.where(x[x_idx, other_idx] == 1, 1, 0)
                if other_idx in group["x_idxes"]: group["bad_edge_sum"] += jnp.where(x[x_idx, other_idx] == 0, 1, 0)
    group_infos = {group_id: {"size": group["size"], "bad_edge_sum": group["bad_edge_sum"], "group_val_sum": group["group_val_sum"]} for group_id, group in groups.items()}
    losses = list(map(lambda group: group_loss(group["size"], group["bad_edge_sum"], group["group_val_sum"]), list(group_infos.values())))
    loss = jnp.sum(jnp.asarray(losses))
    return groups
get_groups(params, X)

{2: {'size': 1,
  'x_idxes': {0},
  'bad_edge_sum': Array(3, dtype=int64, weak_type=True),
  'group_val_sum': Array(0.26861276, dtype=float64)},
 3: {'size': 2,
  'x_idxes': {1, 3},
  'bad_edge_sum': Array(5, dtype=int64, weak_type=True),
  'group_val_sum': Array(0.58745999, dtype=float64)},
 0: {'size': 1,
  'x_idxes': {2},
  'bad_edge_sum': Array(2, dtype=int64, weak_type=True),
  'group_val_sum': Array(0.22797607, dtype=float64)},
 4: {'size': 1,
  'x_idxes': {4},
  'bad_edge_sum': Array(2, dtype=int64, weak_type=True),
  'group_val_sum': Array(0.2558826, dtype=float64)}}

In [None]:
@jax.jit
def update(learning_rate, params, grads):
    return jax.tree_map(lambda p, g: p - learning_rate * g, params, grads)

In [None]:
print(loss_fn(params, X))
print(loss_fn(jax.tree_map(lambda p: p + 2.1, params), X))
print(loss_fn(jax.tree_map(lambda p: p - 2.1, params), X))

18.42001666342628
15.0
18.931123153843615


Array([[0.22673842, 0.18439643, 0.31268311, 0.16847783, 0.10770421],
       [0.15803546, 0.19700905, 0.16948328, 0.3250424 , 0.15042981],
       [0.27143823, 0.18074362, 0.21307662, 0.15608245, 0.17865908],
       [0.17196635, 0.14150871, 0.17120961, 0.36893546, 0.14637988],
       [0.15647498, 0.15951089, 0.19142187, 0.19070813, 0.30188413]],      dtype=float64)

In [None]:
learning_rate = 1e-3
n_iters = 50

for _ in range(n_iters):
    grads = jax.grad(loss_fn)(params, X)
    params = update(learning_rate, params, grads)
    print(loss_fn(params, X))

18.65598807103841
18.651880940815033
18.647746866440947
18.643585523811613
18.639396584378918
18.63517971508872
18.630938196359796
18.626685957276962
18.622404972717963
18.61813469327295
18.61384347483918
18.609522738472513
18.60517212392422
18.600791265994218
18.596379794462926
18.59193733402248
18.587463504207328
18.582957919324205
18.578420188381543
18.573849915018265
18.569246697432114
18.564613082064003
18.559946777803667
18.55524912793077
18.550514187964886
18.545744161579453
18.540940345567815
18.536100256969064
18.53122258822695
18.526308587457848
18.521359366771538
18.516369972626002
18.511342398488914
18.506278304647402
18.501173732846546
18.496028725343646
18.49084442741207
18.485620759862943
18.48037907085864
18.47510846327875
18.469796083513433
18.464441405256103
18.459043895211742
18.453603013024786
18.44811911959222
18.4425898821077
18.437015564459116
18.43139564118035
18.425729535562994
18.42001666342628


In [None]:
loss_fn(params, X)

Array(18.42001666, dtype=float64)