In [1]:
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()

In [2]:
from typing import Sequence
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)

In [90]:
# Copyright 2020 DeepMind Technologies Limited.


# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""2-SAT solver example.
Here we train a graph neural network to solve 2-sat problems.
https://en.wikipedia.org/wiki/2-satisfiability
For instance a 2 sat problem with 3 literals would look like this:
   (a or b)  and  (not a or c)  and (not b or not c)
We represent this problem in form of a bipartite-graph, with edges
connecting the literal-nodes (a, b, c) with the constraint-nodes (O).
The corresponding graph looks like this:
     O    O   O
     |\  /\  /|
     | \/  \/ |
     | /\  /\ |
     |/  \/  \|
     a    b   c
The nodes are one-hot encoded with literal nodes as (1, 0) and constraint nodes
as (0, 1). The edges are one-hot encoded with (1, 0) if the literal should be
true and (0, 1) if the literal should be false.
The graph neural network encodes the nodes and the edges and runs multiple
message passing steps by calculating message for each edge and aggregating
all the messages of the nodes.
The training dataset consists of randomly generated 2-sat problems with 2 to 15
literals.
The test dataset consists of randomly generated 2-sat problems with 16 to 20
literals.
"""

import collections
import random
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
import numpy as np
import optax
from functools import partial


LabeledProblem = collections.namedtuple("Problem", ("graph", "labels", "mask", "meta"))

Problem = collections.namedtuple("Problem", ("graph", "mask", "meta"))

def get_2sat_problem(min_n_literals: int, max_n_literals: int) -> LabeledProblem:
    """Creates bipartite-graph representing a randomly generated 2-sat problem.
    Args:
      min_n_literals: minimum number of literals in the 2-sat problem.
      max_n_literals: maximum number of literals in the 2-sat problem.
    Returns:
      bipartite-graph, node labels and node mask.
    """
    n_literals = random.randint(min_n_literals, max_n_literals)
    n_literals_true = random.randint(1, n_literals - 1)
    n_constraints = n_literals * (n_literals - 1) // 2

    n_node = n_literals +  n_constraints
    # 0 indicates a literal node
    # 1 indicates a constraint node.
    nodes = [0 if i < n_literals else 1 for i in range(n_node)]
    edges = []
    senders = []
    for literal_node1 in range(n_literals):
        for literal_node2 in range(literal_node1 + 1, n_literals):
            senders.append(literal_node1)
            senders.append(literal_node2)
            # 1 indicates that the literal must be true for this constraint.
            # 0 indicates that the literal must be false for this constraint.
            # I.e. with literals a and b, we have the following possible constraints:
            # 0, 0 -> a or b
            # 1, 0 -> not a or b
            # 0, 1 -> a or not b
            # 1, 1 -> not a or not b
            edges.append(1 if literal_node1 < n_literals_true else 0)
            edges.append(1 if literal_node2 < n_literals_true else 0)

    graph = jraph.GraphsTuple(
        n_node=np.asarray([n_node]),
        n_edge=np.asarray([2 * n_constraints]),
        # One-hot encoding for nodes and edges.
        edges=np.eye(2)[edges],
        nodes=np.eye(2)[nodes],
        globals=None,
        senders=np.asarray(senders),
        receivers=np.repeat(np.arange(n_constraints) + n_literals, 2))

    # In order to jit compile our code, we have to pad the nodes and edges of
    # the GraphsTuple to a static shape.
    max_n_constraints = max_n_literals * (max_n_literals - 1) // 2
    max_nodes = max_n_literals + max_n_constraints  + 1
    max_edges = 2 * max_n_constraints
    graph = jraph.pad_with_graphs(graph, max_nodes, max_edges)

    # The ground truth solution for the 2-sat problem.
    labels = (np.arange(max_nodes) < n_literals_true).astype(np.int32)
    labels = np.eye(2)[labels]

    # For the loss calculation we create a mask for the nodes, which masks
    # the constraint nodes and the padding nodes.
    mask = (np.arange(max_nodes) < n_literals).astype(np.int32)
    meta = {"n_vars": n_literals, "n_constraints": n_constraints}
    return LabeledProblem(graph=graph, labels=labels, mask=mask, meta=meta)

def get_k_sat_problem(n, m, k):
    n_node = n + m
    nodes = [0 if i < n else 1 for i in range(n_node)]
    edges = []
    senders = []
    receivers = []

    for c in range(n, n_node):
        support = np.random.choice(n, replace=False, size=(k,))
        bits = np.random.randint(2, size=(k,))
        senders.extend(support)
        edges.extend(bits)
        receivers.extend(np.repeat(c,k))

    graph = jraph.GraphsTuple(
        n_node=jnp.asarray([n_node]),
        n_edge=jnp.asarray([m * k]),
        edges=np.eye(2)[edges],
        nodes=np.eye(2)[nodes],
        globals=None,
        senders=jnp.asarray(senders),
        receivers=jnp.asarray(receivers)
    )


    # For the loss calculation we create a mask for the nodes, which masks
    # the constraint nodes and the padding nodes.
    mask = (np.arange(n_node) < n).astype(np.int32)
    meta = {"n": n, "m": m, "k": k}
    return Problem(graph=graph, mask=mask, meta=meta)



def network_definition(
        graph: jraph.GraphsTuple,
        num_message_passing_steps: int = 5) -> jraph.ArrayTree:
    """Defines a graph neural network.
    Args:
      graph: Graphstuple the network processes.
      num_message_passing_steps: number of message passing steps.
    Returns:
      Decoded nodes.
    """
    embedding = jraph.GraphMapFeatures(
        embed_edge_fn=jax.vmap(hk.Linear(output_size=16)),
        embed_node_fn=jax.vmap(hk.Linear(output_size=16)))
    graph = embedding(graph)

    @jax.vmap
    @jraph.concatenated_args
    def update_fn(features):
        net = hk.Sequential([
            hk.Linear(10), jax.nn.relu,
            hk.Linear(10), jax.nn.relu,
            hk.Linear(10), jax.nn.relu])
        return net(features)

    for _ in range(num_message_passing_steps):
        gn = jraph.InteractionNetwork(
            update_edge_fn=update_fn,
            update_node_fn=update_fn,
            include_sent_messages_in_node_update=True)
        graph = gn(graph)

    return hk.Linear(2)(graph.nodes)


def train(num_steps: int):
    """Trains a graph neural network on a 2-sat problem."""
    train_dataset = (2, 15)
    test_dataset = (16, 20)
    random.seed(42)

    network = hk.without_apply_rng(hk.transform(network_definition))
    problem = get_2sat_problem(*train_dataset)
    params = network.init(jax.random.PRNGKey(42), problem.graph)

    @jax.jit
    def supervised_prediction_loss(params, problem):
        decoded_nodes = network.apply(params, problem.graph)
        # We interpret the decoded nodes as a pair of logits for each node.
        log_prob = jax.nn.log_softmax(decoded_nodes) * problem.labels
        return -jnp.sum(log_prob * problem.mask[:, None]) / jnp.sum(problem.mask)

    opt_init, opt_update = optax.adam(2e-4)
    opt_state = opt_init(params)

    @jax.jit
    def update(params, opt_state, problem):
        g = jax.grad(supervised_prediction_loss)(params, problem)
        updates, opt_state = opt_update(g, opt_state)
        return optax.apply_updates(params, updates), opt_state

    for step in range(num_steps):
        problem = get_2sat_problem(*train_dataset)
        params, opt_state = update(params, opt_state, problem)
        if step % 1000 == 0:
            train_loss = jnp.mean(
                jnp.asarray([
                    supervised_prediction_loss(params, get_2sat_problem(*train_dataset))
                    for _ in range(100)
                ])).item()
            test_loss = jnp.mean(
                jnp.asarray([
                    supervised_prediction_loss(params, get_2sat_problem(*test_dataset))
                    for _ in range(100)
                ])).item()
            print("step %r loss train %r test %r", step, train_loss, test_loss)
            #logging.info("step %r loss train %r test %r", step, train_loss, test_loss)
    return network, params

In [91]:
def moser_walk_sampler(weights, problem, n_samples = 1000):
    """
    This is a modified moser walker that will run [n_samples] iterations of the random walk even if it has found a solution already
    :param n_samples: int
    :param weights: np array of size n_vars
    :param graph: jraph.GraphTuple
    :return:
    """
    n,m,k = problem.meta.values()
    graph = problem.graph
    # we assume that the instance is k-CNF, so all clauses have the same length
    # This means there k*m edges.

    @jax.jit
    def resample_constraint(rng_key, assignment, j):
        _, rng_key = jax.random.split(rng_key)
        support = jnp.where(jnp.asarray(graph.receivers) == j, size=k)
        support_mask = jnp.zeros(n)
        support_mask.at[support].set(1)
        assignment_proposals = jax.random.bernoulli(rng_key, weights, (n,))
        new_assignment = jnp.where(support_mask, assignment_proposals, assignment)
        return rng_key, new_assignment # assignment.at[support].set(sub_assignment)

    @jax.jit
    def identify_violated_constraint(assignment, graph):

        edge_is_violated = jnp.mod(jnp.asarray(graph.edges)[:,1].astype(np.int32) + assignment[graph.senders], 2)
        constraint_is_violated = jax.vmap(jnp.sum)(jnp.reshape(edge_is_violated, (m,k))) == k
        return jnp.where(constraint_is_violated, size=1)[0], jnp.sum(constraint_is_violated)

    @jax.jit
    def step(i, state):
        rng_key, assignment = state
        violated_constraint, number_violations = identify_violated_constraint(assignment, graph)
        return resample_constraint(rng_key, assignment, violated_constraint)

    rng_key = jax.random.PRNGKey(42)
    init_assignment = jax.random.bernoulli(rng_key, weights, (n,))
    final_assignment = jax.lax.fori_loop(0, n_samples, step, (rng_key, init_assignment))

    return final_assignment


In [92]:
get_k_sat_problem(2,1,2)

Problem(graph=GraphsTuple(nodes=array([[1., 0.],
       [1., 0.],
       [0., 1.]]), edges=array([[0., 1.],
       [1., 0.]]), receivers=DeviceArray([2, 2], dtype=int32), senders=DeviceArray([0, 1], dtype=int32), globals=None, n_node=DeviceArray([3], dtype=int32), n_edge=DeviceArray([2], dtype=int32)), mask=array([1, 1, 0], dtype=int32), meta={'n': 2, 'm': 1, 'k': 2})

In [93]:
get_2sat_problem(2,2)

Problem(graph=GraphsTuple(nodes=array([[1., 0.],
       [1., 0.],
       [0., 1.],
       [0., 0.]]), edges=array([[0., 1.],
       [1., 0.]]), receivers=array([2, 2]), senders=array([0, 1]), globals=None, n_node=array([3, 1]), n_edge=array([2, 0])), labels=array([[0., 1.],
       [1., 0.],
       [1., 0.],
       [1., 0.]]), mask=array([1, 1, 0, 0], dtype=int32), meta={'n_vars': 2, 'n_constraints': 1})

In [96]:
network = hk.without_apply_rng(hk.transform(network_definition))
problem = get_k_sat_problem(100,150,3)
# problem = get_2sat_problem(16,20)
params = network.init(jax.random.PRNGKey(42), problem.graph)
prdictions = network.apply(params, problem.graph)
weights = jax.nn.softmax(prdictions)[: problem.meta["n"],0]
final_assignment = moser_walk_sampler(weights, problem, 10000)

In [97]:
final_assignment

(DeviceArray([1680141627, 3487984247], dtype=uint32),
 DeviceArray([False, False, False,  True,  True,  True,  True, False,
               True,  True, False,  True, False, False, False, False,
              False,  True,  True,  True,  True,  True, False, False,
              False,  True, False,  True,  True, False,  True,  True,
               True,  True,  True, False, False,  True,  True,  True,
              False, False, False,  True,  True, False, False, False,
              False, False,  True, False, False, False,  True,  True,
               True, False, False,  True,  True, False, False, False,
              False, False, False,  True, False, False,  True, False,
               True, False,  True,  True,  True, False, False,  True,
              False, False,  True, False, False, False, False,  True,
               True,  True, False,  True,  True,  True,  True,  True,
               True,  True, False,  True], dtype=bool))

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=2/1)>
The problem arose with the `bool` function. 
The error occurred while tracing the function constraint_satisfied at /var/folders/4s/mnw4xd1129g6vmrwysfzd97r0000gp/T/ipykernel_65648/4198255437.py:128 for xla_call. This concrete value was not available in Python because it depends on the value of the argument 'j'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError