In [None]:
!pip install dm-clrs

In [2]:
import clrs
import numpy as np
import jax
import jax.numpy as jnp

import pprint

rng = np.random.RandomState(1234)
rng_key = jax.random.PRNGKey(rng.randint(100))

In [65]:
def custom_dot_exp_log(inputs, w, basis_exp=jnp.exp(1), basis_log=jnp.exp(1)):
    max_val = jnp.maximum(jnp.max(inputs), jnp.max(w))
    out = max_val + jnp.log(jnp.dot(jnp.power(inputs - max_val,basis_exp), jnp.power(w - max_val, basis_exp)))
    out = out / jnp.log(basis_log)


In [66]:
inputs = jnp.array([1,6,7])
w = jnp.array([1,2,3])
print(jnp.log(jnp.dot(jnp.exp(inputs), jnp.exp(w))))
print(custom_dot_exp_log(inputs,w))

10.127223
None


In [3]:
train_sampler, spec = clrs.build_sampler(
    name='bellman_ford',
    num_samples=100,
    length=16)

test_sampler, spec = clrs.build_sampler(
    name='bellman_ford',
    num_samples=100,
    length=64)

pprint.pprint(spec)

def _iterate_sampler(sampler, batch_size):
  while True:
    yield sampler.next(batch_size)

train_sampler = _iterate_sampler(train_sampler, batch_size=32)
test_sampler = _iterate_sampler(test_sampler, batch_size=100)

{'A': ('input', 'edge', 'scalar'),
 'adj': ('input', 'edge', 'mask'),
 'd': ('hint', 'node', 'scalar'),
 'msk': ('hint', 'node', 'mask'),
 'pi': ('output', 'node', 'pointer'),
 'pi_h': ('hint', 'node', 'pointer'),
 'pos': ('input', 'node', 'scalar'),
 's': ('input', 'node', 'mask_one')}


In [11]:
processor_factory = clrs.get_processor_factory('mpnn', True, 0)
model_params = dict(
    processor_factory=processor_factory,
    hidden_dim=32,
    encode_hints=True,
    decode_hints=True,
    #decode_diffs=False,
    #hint_teacher_forcing_noise=1.0,
    use_lstm=False,
    learning_rate=0.001,
    checkpoint_path='/tmp/checkpt',
    freeze_processor=False,
    dropout_prob=0.0,
)

dummy_trajectory = next(train_sampler)

model = clrs.models.BaselineModel(
    spec=spec,
    dummy_trajectory=dummy_trajectory,
    **model_params
)

model.init(dummy_trajectory.features, 1234)

In [12]:
step = 0

while step <= 100:
  feedback, test_feedback = next(train_sampler), next(test_sampler)
  rng_key, new_rng_key = jax.random.split(rng_key)
  cur_loss = model.feedback(rng_key, feedback)
  rng_key = new_rng_key
  if step % 10 == 0:
    predictions_val, _ = model.predict(rng_key, feedback.features)
    out_val = clrs.evaluate(feedback.outputs, predictions_val)
    predictions, _ = model.predict(rng_key, test_feedback.features)
    out = clrs.evaluate(test_feedback.outputs, predictions)
    print(f'step = {step} | loss = {cur_loss} | val_acc = {out_val["score"]} | test_acc = {out["score"]}')
  step += 1

step = 0 | loss = 7.864797115325928 | val_acc = 0.28125 | test_acc = 0.07874999940395355
step = 10 | loss = 4.704117298126221 | val_acc = 0.41015625 | test_acc = 0.19765624403953552
step = 20 | loss = 3.7001068592071533 | val_acc = 0.513671875 | test_acc = 0.3335937559604645
step = 30 | loss = 3.19502854347229 | val_acc = 0.607421875 | test_acc = 0.3785937428474426
step = 40 | loss = 2.68827748298645 | val_acc = 0.669921875 | test_acc = 0.40687498450279236
step = 50 | loss = 2.3591408729553223 | val_acc = 0.734375 | test_acc = 0.43031248450279236
step = 60 | loss = 2.2472643852233887 | val_acc = 0.755859375 | test_acc = 0.41203123331069946
step = 70 | loss = 1.8525781631469727 | val_acc = 0.78515625 | test_acc = 0.4596875011920929
step = 80 | loss = 1.6861634254455566 | val_acc = 0.837890625 | test_acc = 0.4923437535762787
step = 90 | loss = 1.5138217210769653 | val_acc = 0.8359375 | test_acc = 0.5285937190055847
step = 100 | loss = 1.3650599718093872 | val_acc = 0.830078125 | test_acc

In [None]:
step = 100 | loss = 2.5001742839813232 | val_acc = 0.623046875 | test_acc = 0.38593748211860657