In [1]:
import jraph
import numpy as np
import jax
from jax.experimental.sparse import BCOO
from pysat.formula import CNF
import jax.numpy as jnp
import collections
from functools import partial

In [2]:
SATProblem = collections.namedtuple("SATProblem", ("graph", "mask", "params", "clause_lengths"))

class HashableSATProblem(SATProblem):
    def __hash__(self):
        return hash((self.graph.senders.tostring(),
                     self.graph.receivers.tostring(),
                     self.graph.edges.tostring(),
                     self.params,
                     tuple(self.clause_lengths)
                     ))

    def __eq__(self, other):
        return self.__hash__() == other.__hash__()


def get_problem_from_cnf(cnf: CNF, pad_nodes=0, pad_edges=0):
    cnf.clauses = [c for c in cnf.clauses if len(c) > 0]
    n = cnf.nv
    m = len(cnf.clauses)
    n_node = n + m
    clause_lengths = [len(c) for c in cnf.clauses]
    k = max(clause_lengths)
    n_edge = sum(clause_lengths)

    # for sake of jitting, if the cnf isn't already strictly in k-cnf form, we introduce
    # additional dummy variables and constraints. NB: While this in principles solves the problem,
    # it actually is to be avoided, if possible: This is because it very easy to satisfy all constraint except one
    # by just setting the dummy variables to True. This creates local minima and also breaks locality.
    # if any([len(c) != k for c in cnf.clauses]):
    #     m += 2 ** k - 1
    #     n += k
    #
    #     dummy_vars = np.arange(n - k, n)
    #     senders.extend(np.repeat(dummy_vars, 2 ** k - 1))
    #
    #     # we introduce additional constraints to force the dummy variables into the all zeros string
    #     additional_constraints = all_bitstrings(k)[1:, :]
    #
    #     for j in range(2 ** k - 1):
    #         edges.extend(additional_constraints[j, :])
    #         receivers.extend(np.repeat(m - 2 ** k + 1, k))

    edges = []
    senders = []
    receivers = []
    nodes = [0 if i < n else 1 for i in range(n_node)]
    for j, c in enumerate(cnf.clauses):
        support = [(abs(l) - 1) for l in c]
        assert len(support) == len(
            set(support)
        ), "Multiple occurrences of single variable in constraint"

        vals = ((np.sign(c) + 1) // 2).astype(np.int32)

        senders.extend(support)
        edges.extend(vals)
        receivers.extend(np.repeat(j + n, len(c)))

    assert len(nodes) == n_node
    assert len(receivers) == len(senders)
    assert len(senders) == len(edges)
    assert len(edges) == n_edge

    graph = jraph.GraphsTuple(
        n_node=np.asarray([n_node]),
        n_edge=np.asarray([n_edge]),
        edges=np.eye(2)[edges],
        nodes=np.eye(2)[nodes],
        globals=None,
        senders=np.asarray(senders),
        receivers=np.asarray(receivers),
    )

    # padding done in case we want to jit the graph, this is relevant mostly for training the gnn model, not for
    # executing moser's walk on single instances

    if pad_nodes > n_node or pad_edges > n_edge:
        n_node = max(pad_nodes, n_node)
        n_edge = max(pad_edges, n_edge)
        graph = jraph.pad_with_graphs(
            graph, n_node, n_edge
        )

    # For the loss calculation we create a mask for the nodes, which masks
    # the constraint nodes and the padding nodes.

    mask = (np.arange(pad_nodes) < n).astype(np.int32)

    return HashableSATProblem(
        graph=graph,
        mask=mask,
        clause_lengths=clause_lengths,
        params=(n, m, k)
    )

@partial(jax.jit, static_argnames=("problem",))
def violated_constraints(problem: SATProblem, assignment):
    graph = problem.graph
    edge_is_violated = jnp.mod(graph.edges[:, 1] + assignment[graph.senders].T, 2)

    e = len(graph.edges)
    _, m, k = problem.params
    edge_mask_sp = BCOO(
        (np.ones(e), np.column_stack((np.arange(e), graph.receivers))), shape=(e, m)
    )

    violated_constraint_edges = edge_is_violated @ edge_mask_sp  # (,x) @ (x,m)  = (,m)
    constraint_is_violated = violated_constraint_edges == jnp.asarray(problem.clause_lengths)

    # if all clauses have the same length, then we could implement the above simply as
    # constraint_is_violated = (
    #     jax.vmap(jnp.sum)(jnp.reshape(edge_is_violated, (m, k))) == k
    # )
    return constraint_is_violated

In [3]:
test_cnf = get_problem_from_cnf(CNF(from_file="../Data/BroadcastTestSet/41-20486.cnf"))
n,m,k = test_cnf.params
np.random.seed(42)
assignment = np.random.randint(0,2,n)

In [4]:
%timeit violated_constraints(test_cnf, assignment).block_until_ready()

  return hash((self.graph.senders.tostring(),
  self.graph.receivers.tostring(),
  self.graph.edges.tostring(),


10 ms ± 465 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
