In [None]:
!pip install dm-clrs

In [None]:
import clrs
import numpy as np
import jax
import jax.numpy as jnp

import pprint

# Derive the JAX PRNG key from a seeded NumPy RandomState so the run is reproducible.
# dtype=np.int64 is explicit because RandomState.randint's default integer dtype is 32-bit
# on Windows with NumPy < 2, where randint(2**32) raises "high is out of bounds"; on
# platforms with a 64-bit default it draws exactly the same seed as before.
rng = np.random.RandomState(1234)
rng_key = jax.random.PRNGKey(rng.randint(2**32, dtype=np.int64))



In [None]:
# Train and test samplers share the algorithm and dataset size and differ only in
# `length` (16 for training, 64 for testing). Both calls return the same algorithm
# spec, so keeping the second one is fine.
_sampler_kwargs = {'name': 'bellman_ford', 'num_samples': 100}

train_sampler, spec = clrs.build_sampler(length=16, **_sampler_kwargs)
test_sampler, spec = clrs.build_sampler(length=64, **_sampler_kwargs)

pprint.pprint(spec)

def _iterate_sampler(sampler, batch_size):
  while True:
    yield sampler.next(batch_size)

# Replace the sampler objects with endless batch generators:
# training batches hold 32 examples, test batches 100.
train_sampler, test_sampler = (
    _iterate_sampler(train_sampler, batch_size=32),
    _iterate_sampler(test_sampler, batch_size=100),
)

{'A': ('input', 'edge', 'scalar'),
 'adj': ('input', 'edge', 'mask'),
 'd': ('hint', 'node', 'scalar'),
 'msk': ('hint', 'node', 'mask'),
 'pi': ('output', 'node', 'pointer'),
 'pi_h': ('hint', 'node', 'pointer'),
 'pos': ('input', 'node', 'scalar'),
 's': ('input', 'node', 'mask_one')}


In [None]:
# Baseline CLRS model built around an 'mpnn' processor (use_ln=True), hidden width 32,
# with hints both encoded and decoded.
processor_factory = clrs.get_processor_factory('mpnn', use_ln=True)
model_params = {
    'processor_factory': processor_factory,
    'hidden_dim': 32,
    'encode_hints': True,
    'decode_hints': True,
    'decode_diffs': False,
    'hint_teacher_forcing_noise': 1.0,
    'use_lstm': False,
    'learning_rate': 0.001,
    'checkpoint_path': '/tmp/checkpt',
    'freeze_processor': False,
    'dropout_prob': 0.0,
}

# Both the constructor and `init` take an example batch; note this consumes one batch
# from the training stream.
dummy_trajectory = next(train_sampler)

model = clrs.models.BaselineModel(
    spec=spec, dummy_trajectory=dummy_trajectory, **model_params)

# NOTE(review): 1234 is presumably the parameter-initialisation seed — confirm against the CLRS API.
model.init(dummy_trajectory.features, 1234)

In [None]:
# Train for NUM_TRAIN_STEPS + 1 updates (steps 0..NUM_TRAIN_STEPS inclusive). Every
# EVAL_EVERY steps, report the loss, the accuracy on the batch that was just trained on
# ("train_acc"), and the accuracy on a freshly drawn batch from the test stream ("test_acc").
NUM_TRAIN_STEPS = 100
EVAL_EVERY = 10

for step in range(NUM_TRAIN_STEPS + 1):
  feedback = next(train_sampler)
  # Split off fresh sub-keys each step so that no key is both consumed and split again.
  rng_key, train_key, eval_key = jax.random.split(rng_key, 3)
  cur_loss = model.feedback(train_key, feedback)
  if step % EVAL_EVERY == 0:
    train_predictions, _ = model.predict(eval_key, feedback.features)
    train_out = clrs.evaluate(feedback.outputs, train_predictions)
    # Draw a test batch only when it is actually evaluated, not on every step.
    test_feedback = next(test_sampler)
    test_predictions, _ = model.predict(eval_key, test_feedback.features)
    test_out = clrs.evaluate(test_feedback.outputs, test_predictions)
    print(f'step = {step} | loss = {cur_loss} | train_acc = {train_out["score"]} | test_acc = {test_out["score"]}')


step = 0 | loss = 6.8649001121521 | val_acc = 0.25390625 | test_acc = 0.1237499937415123
step = 10 | loss = 3.8234963417053223 | val_acc = 0.466796875 | test_acc = 0.1704687476158142
step = 20 | loss = 3.022090435028076 | val_acc = 0.609375 | test_acc = 0.3075000047683716
step = 30 | loss = 2.4777908325195312 | val_acc = 0.732421875 | test_acc = 0.3806249797344208
step = 40 | loss = 2.105839729309082 | val_acc = 0.78125 | test_acc = 0.4154687523841858
step = 50 | loss = 1.7853212356567383 | val_acc = 0.7890625 | test_acc = 0.4312499761581421
step = 60 | loss = 1.6517027616500854 | val_acc = 0.79296875 | test_acc = 0.5023437142372131
step = 70 | loss = 1.4947378635406494 | val_acc = 0.849609375 | test_acc = 0.5318750143051147
step = 80 | loss = 1.404116153717041 | val_acc = 0.849609375 | test_acc = 0.532031238079071
step = 90 | loss = 1.276430368423462 | val_acc = 0.85546875 | test_acc = 0.5393750071525574
step = 100 | loss = 1.1954240798950195 | val_acc = 0.869140625 | test_acc = 0.552