In [None]:
!pip install jraph dm-haiku optax python-sat memory_profiler

In [1]:

# Reload edited project modules (e.g. constraint_problems, model, random_walk)
# automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Makes the memory_profiler magics (%memit / %mprun) available.
%load_ext memory_profiler

In [None]:
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()

In [2]:
import jax

from constraint_problems import get_problem_from_cnf
from model import train_model, network_definition
from random_walk import moser_walk_sampler, moser_walk

In [None]:
# Disable up-front GPU memory preallocation and set the XLA client memory fraction.
# NOTE(review): JAX reads these when its backend is first initialised, so they only
# take effect if no JAX computation has run yet in this kernel (they are set here
# after `import jax`). Per the JAX GPU-memory docs MEM_FRACTION governs the
# preallocated fraction — confirm it still matters with preallocation disabled.
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env XLA_PYTHON_CLIENT_MEM_FRACTION=.40

In [None]:
# from pysat.formula import CNF
# import glob
# instances = [get_problem_from_cnf(CNF(from_file=f)) for f in glob.glob('../Data/uf20-91/*.cnf')]
# train_instances, test_instances = instances[:950],instances[950:]

In [4]:
from pysat.formula import CNF
# import glob
# test_instances = [get_problem_from_cnf(CNF(from_file=f)) for f in glob.glob('../Data/blocksworld/*.cnf') if f == '../Data/blocksworld/bw_large.d.cnf']
# train_instances = [get_problem_from_cnf(CNF(from_file=f)) for f in glob.glob('../Data/blocksworld/*.cnf') if f != '../Data/blocksworld/bw_large.d.cnf']

In [None]:
import os, gzip

# Directory holding the `<name>.cnf.gz` instances. Set SAT_DATA_DIR to point at a
# local copy instead of editing a user-specific absolute path.
DATA_DIR = os.environ.get("SAT_DATA_DIR", "/Users/p390943/Downloads/10K")

instance_name = "49-122082"
path = os.path.join(DATA_DIR, instance_name)
# Read one gzipped CNF instance into a pysat CNF object.
with gzip.open(path + ".cnf.gz", "rt") as f:
    cnf = CNF(from_string=f.read())

In [None]:
# Number of clauses in the instance loaded above.
len(cnf.clauses)

In [None]:
# Drop the reference to the (potentially large) CNF object once it is no longer needed.
del cnf

In [5]:
# Two blocksworld instances for training and one separate instance for testing,
# each parsed from its CNF file into the project's problem representation.
train = [
    get_problem_from_cnf(CNF(from_file=cnf_path))
    for cnf_path in ("../Data/blocksworld/huge.cnf", "../Data/blocksworld/medium.cnf")
]
test = [get_problem_from_cnf(CNF(from_file="../Data/blocksworld/bw_large.a.cnf"))]

In [None]:
# for i in test_instances:
#     print(i.graph.n_node, i.graph.n_edge, i.params)

In [None]:
# Train the model on the `train` instances, also passing the `test` instances.
# NOTE(review): the meaning of the positional 1000 and 0.02 depends on
# model.train_model's signature (presumably an iteration count and a learning
# rate) — TODO confirm and name them.
network, params = train_model(1000,0.02,train, test, num_steps=10)

In [None]:
planning_problem = test[0]
node_logits = network.apply(params, planning_problem.graph)
# Softmax over the last axis; column 1 becomes each node's sampling weight.
weights = jax.nn.softmax(node_logits)[:, 1]
trajectory, energies = moser_walk_sampler(weights, planning_problem, 1000, 0)

In [None]:
# Capture a JAX profiler trace of one `moser_walk` call into /tmp/jax-trace.
# With create_perfetto_link=True JAX prints a Perfetto UI link for the trace
# (per the JAX docs it blocks until that link is opened).
# Depends on `weights` and `planning_problem` from the sampling cell above.
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    # Run the operations to be profiled
    final_assignment, energy = moser_walk(weights, planning_problem, 1000, 0)

In [6]:
# Time one `moser_walk` run. Depends on `weights` and `planning_problem` from the
# sampling cell above; the saved output below is a NameError because that cell
# had not been executed in this kernel session.
%timeit final_assignment, energy = moser_walk(weights, planning_problem, 1000, 0)

NameError: name 'weights' is not defined

In [None]:
# Shape of column 1 of the edge features of the first test instance.
test[0].graph.edges[:, 1].shape

In [None]:
import numpy as np
# Distribution of the per-node sampling weights (numpy's default of 10 equal-width bins).
np.histogram(weights)

In [None]:
import jax.numpy as jnp

def violated_constraints2(edges, senders, mask, assignment):
    edge_is_violated = jnp.mod(
        edges + assignment[senders], 2
    )
    violated_constraint_edges = edge_is_violated @ mask  # (x,) @ (x,m)  = (m,)
    return violated_constraint_edges == 1

In [None]:
rng_key = jax.random.PRNGKey(0)
random_assignment = jax.random.bernoulli(rng_key, weights)

edges = planning_problem.graph.edges[:,1].astype(np.int32)
senders = planning_problem.graph.senders
mask, _, _ = planning_problem.constraint_utils

%timeit violated_constraints2(edges, senders, mask, random_assignment).block_until_ready()
%timeit jax.jit(violated_constraints2)(edges, senders, mask, random_assignment).block_until_ready()

In [None]:
import glob
import gzip
import os
import pickle

import numpy as np
# Assumed to be PyTorch's data utilities (used below via data.random_split /
# data.DataLoader); the notebook never imported `data` before this cell.
from torch.utils import data


class SATTrainingDataset(data.Dataset):
    """Map-style dataset of solved SAT instances stored in ``data_dir``.

    For each instance ``<name>`` the directory is expected to contain a gzipped
    CNF file ``<name>.cnf.gz`` and a pickled solution ``<name>_sol.pkl``.
    Instances whose pickle holds a falsy value are treated as unsolved and skipped.

    Attributes:
        data_dir: Directory passed to the constructor.
        solved_instances: Instance names (pickle basenames with the trailing
            ``_<suffix>.pkl`` removed) whose pickle is truthy, in sorted file order.
        ratio_of_solved_instances: ``len(self)`` divided by the number of ``*.pkl``
            files found (0.0 when there are none).
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # Sorted so the instance order (and therefore any seeded random split)
        # does not depend on the filesystem's listing order.
        processed_instances = sorted(glob.glob(os.path.join(data_dir, "*.pkl")))
        self.solved_instances = []
        for f in processed_instances:
            # SECURITY: pickle.load can execute arbitrary code — only point this
            # at trusted data directories.
            with open(f, "rb") as fl:
                s = pickle.load(fl)
            if not s:
                continue
            # Store the bare instance name. Splitting the full path on "_" would
            # truncate at any "_" in the directory name, and a path stored here
            # would be joined with data_dir a second time in __getitem__.
            self.solved_instances.append(os.path.basename(f).rsplit("_", 1)[0])
        self.ratio_of_solved_instances = (
            len(self.solved_instances) / len(processed_instances)
            if processed_instances
            else 0.0
        )

    def __len__(self):
        return len(self.solved_instances)

    @staticmethod
    def solution_dict_to_array(solution_dict):
        """Return the dict's values, in iteration order, as an integer array."""
        # NOTE(review): relies on the dict's insertion order matching the
        # problem's variable order — confirm against how *_sol.pkl is written.
        return np.array(list(solution_dict.values()), dtype=int)

    def __getitem__(self, idx):
        """Load instance ``idx`` from disk and return ``(problem, solution_array)``."""
        instance_name = self.solved_instances[idx]
        path = os.path.join(self.data_dir, instance_name)
        with gzip.open(path + ".cnf.gz", "rt") as f:
            problem = get_problem_from_cnf(CNF(from_string=f.read()))
        with open(path + "_sol.pkl", "rb") as f:
            solution_dict = pickle.load(f)
        return problem, self.solution_dict_to_array(solution_dict)

In [None]:
import os

# Dataset of solved instances. Set SAT_DATA_DIR to override the default
# user-specific location.
sat_data = SATTrainingDataset(os.environ.get("SAT_DATA_DIR", "/Users/p390943/Downloads/10K"))

In [None]:
import torch

# 80/20 train/test split. The seeded generator makes the split reproducible across
# kernel restarts (the default generator is unseeded).
train_data, test_data = data.random_split(
    sat_data, [0.8, 0.2], generator=torch.Generator().manual_seed(42)
)

In [None]:
# With batch_size=1 each batch is a one-element list holding a
# (SATProblem, solution) pair; the collate_fn returns that pair unchanged.
# PyTorch's default collate would try to convert every field into torch tensors,
# which fails on the `None` fields of the jraph GraphsTuple (e.g. `globals`).
# NOTE(review): neither loader is used by the training loop in this notebook.
train_loader = data.DataLoader(
    train_data, batch_size=1, shuffle=True, collate_fn=lambda samples: samples[0]
)
test_loader = data.DataLoader(
    test_data, batch_size=1, shuffle=True, collate_fn=lambda samples: samples[0]
)

In [None]:
# Load and display the first (problem, solution) pair of the test split.
test_data[0]

In [None]:
import time
import optax
import haiku as hk
import jax.numpy as jnp

# Number of passes over the training instances.
num_epochs = 10

# Haiku model from the project's `network_definition`; without_apply_rng removes
# the rng argument from `apply`.
network = hk.without_apply_rng(hk.transform(network_definition))
# Parameters are initialised from the graph of the first training instance
# (indexing train_data loads that instance from disk).
params = network.init(jax.random.PRNGKey(42), train_data[0][0].graph)
optimizer = optax.adam(2e-4)
opt_state = optimizer.init(params)

@jax.jit
def prediction_loss(params, problem, solution):
    """Masked cross-entropy between per-node predictions and one-hot targets.

    Args:
        params: Haiku parameters of `network`.
        problem: Instance exposing `.graph` (fed to the network) and `.mask`
            (per-node weights; nodes with zero mask do not contribute).
        solution: One-hot targets with 2 columns. If it has fewer rows than the
            network output, it is zero-padded at the end.

    Returns:
        Scalar: minus the masked sum of the log-probabilities selected by
        `solution`, divided by the sum of the mask.
    """
    decoded_nodes = network.apply(params, problem.graph)
    # NOTE(review): solutions presumably hold one row per variable while the
    # network emits one row per graph node; this assumes variable nodes come
    # first (the mask in this notebook's printed example is 1 on a leading block)
    # — TODO confirm. Padding is a no-op when the row counts already match.
    num_pad_rows = decoded_nodes.shape[0] - solution.shape[0]
    solution = jnp.pad(solution, ((0, num_pad_rows), (0, 0)))
    # We interpret the decoded nodes as a pair of logits for each node.
    log_prob = jax.nn.log_softmax(decoded_nodes) * solution
    return -jnp.sum(log_prob * problem.mask[:, None]) / jnp.sum(problem.mask)

# Vectorised forward pass over a leading batch axis of stacked graphs.
# NOTE(review): not used for training — the loop below feeds one instance at a
# time, and instances of different sizes cannot be stacked without padding.
batched_predict = jax.vmap(network.apply, in_axes=(None, 0))


def loss(params, problems, targets):
    """Training objective for a single SAT instance.

    Delegates to `prediction_loss` (masked cross-entropy on log-softmax outputs).
    Raw logits are not used directly because ``-mean(logits * targets)`` is
    unbounded below, and the network must be given the instance's `.graph`
    rather than the whole instance.

    Args:
        params: Haiku parameters of `network`.
        problems: One SAT instance (with `.graph` and `.mask`); no batch axis.
        targets: One-hot targets for that instance.

    Returns:
        Scalar loss.
    """
    return prediction_loss(params, problems, targets)

def one_hot(x, k, dtype=jnp.float32):
    """Return a (len(x), k) matrix of `dtype` whose row i is 1 at column x[i].

    `x` must support ``x[:, None]`` (e.g. a 1-D integer array); labels outside
    ``[0, k)`` produce all-zero rows.
    """
    class_ids = jnp.arange(k)
    matches = jnp.equal(x[:, None], class_ids)
    return matches.astype(dtype)

@jax.jit
def update(params, opt_state, x, y):
    """Take one `optimizer` step on `loss` for a single (instance, target) pair.

    Returns ``(new_params, new_opt_state)``. Being jitted, it is compiled again
    for every new combination of input shapes.
    """
    grads = jax.grad(loss)(params, x, y)
    param_updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, param_updates)
    return new_params, new_opt_state

for epoch in range(num_epochs):
    start_time = time.time()
    num_done = 0
    # Train on the training split only: `sat_data` also contains every instance
    # of the held-out `test_data` split.
    for x, y in train_data:
        # `y` holds integer labels; the loss expects one-hot targets.
        y = one_hot(y, 2)
        # `update` is jitted, so each instance with a new graph size recompiles.
        params, opt_state = update(params, opt_state, x, y)
        num_done += 1
    epoch_time = time.time() - start_time
    # Number of processed instances (well-defined even if the split is empty).
    print(f"{num_done} done")

    # TODO: report the loss on `test_data` once per epoch.
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))

In [None]:
import torch

# 80/20 train/test split. The seeded generator makes the split reproducible across
# kernel restarts (the default generator is unseeded).
train_data, test_data = data.random_split(
    sat_data, [0.8, 0.2], generator=torch.Generator().manual_seed(42)
)

In [None]:
# With batch_size=1 each batch is a one-element list holding a
# (SATProblem, solution) pair; the collate_fn returns that pair unchanged.
# PyTorch's default collate would try to convert every field into torch tensors,
# which fails on the `None` fields of the jraph GraphsTuple (e.g. `globals`).
# NOTE(review): neither loader is used by the training loop in this notebook.
train_loader = data.DataLoader(
    train_data, batch_size=1, shuffle=True, collate_fn=lambda samples: samples[0]
)
test_loader = data.DataLoader(
    test_data, batch_size=1, shuffle=True, collate_fn=lambda samples: samples[0]
)

In [None]:
# Load and display the first (problem, solution) pair of the test split.
test_data[0]

In [None]:
import time
import optax
import haiku as hk
import jax.numpy as jnp

# NOTE(review): this cell repeats an earlier training cell in this notebook
# verbatim; keep only one copy.
# Number of passes over the training instances.
num_epochs = 10

# Haiku model from the project's `network_definition`; without_apply_rng removes
# the rng argument from `apply`.
network = hk.without_apply_rng(hk.transform(network_definition))
# Parameters are initialised from the graph of the first training instance
# (indexing train_data loads that instance from disk).
params = network.init(jax.random.PRNGKey(42), train_data[0][0].graph)
optimizer = optax.adam(2e-4)
opt_state = optimizer.init(params)

@jax.jit
def prediction_loss(params, problem, solution):
    """Masked cross-entropy between per-node predictions and one-hot targets.

    Args:
        params: Haiku parameters of `network`.
        problem: Instance exposing `.graph` (fed to the network) and `.mask`
            (per-node weights; nodes with zero mask do not contribute).
        solution: One-hot targets with 2 columns. If it has fewer rows than the
            network output, it is zero-padded at the end.

    Returns:
        Scalar: minus the masked sum of the log-probabilities selected by
        `solution`, divided by the sum of the mask.
    """
    decoded_nodes = network.apply(params, problem.graph)
    # NOTE(review): solutions presumably hold one row per variable while the
    # network emits one row per graph node; this assumes variable nodes come
    # first (the mask in this notebook's printed example is 1 on a leading block)
    # — TODO confirm. Padding is a no-op when the row counts already match.
    num_pad_rows = decoded_nodes.shape[0] - solution.shape[0]
    solution = jnp.pad(solution, ((0, num_pad_rows), (0, 0)))
    # We interpret the decoded nodes as a pair of logits for each node.
    log_prob = jax.nn.log_softmax(decoded_nodes) * solution
    return -jnp.sum(log_prob * problem.mask[:, None]) / jnp.sum(problem.mask)

# Vectorised forward pass over a leading batch axis of stacked graphs.
# NOTE(review): not used for training — the loop below feeds one instance at a
# time, and instances of different sizes cannot be stacked without padding.
batched_predict = jax.vmap(network.apply, in_axes=(None, 0))


def loss(params, problems, targets):
    """Training objective for a single SAT instance.

    Delegates to `prediction_loss` (masked cross-entropy on log-softmax outputs).
    Raw logits are not used directly because ``-mean(logits * targets)`` is
    unbounded below, and the network must be given the instance's `.graph`
    rather than the whole instance.

    Args:
        params: Haiku parameters of `network`.
        problems: One SAT instance (with `.graph` and `.mask`); no batch axis.
        targets: One-hot targets for that instance.

    Returns:
        Scalar loss.
    """
    return prediction_loss(params, problems, targets)

def one_hot(x, k, dtype=jnp.float32):
    """Return a (len(x), k) matrix of `dtype` whose row i is 1 at column x[i].

    `x` must support ``x[:, None]`` (e.g. a 1-D integer array); labels outside
    ``[0, k)`` produce all-zero rows.
    """
    class_ids = jnp.arange(k)
    matches = jnp.equal(x[:, None], class_ids)
    return matches.astype(dtype)

@jax.jit
def update(params, opt_state, x, y):
    """Take one `optimizer` step on `loss` for a single (instance, target) pair.

    Returns ``(new_params, new_opt_state)``. Being jitted, it is compiled again
    for every new combination of input shapes.
    """
    grads = jax.grad(loss)(params, x, y)
    param_updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, param_updates)
    return new_params, new_opt_state

for epoch in range(num_epochs):
    start_time = time.time()
    num_done = 0
    # Train on the training split only: `sat_data` also contains every instance
    # of the held-out `test_data` split.
    for x, y in train_data:
        # `y` holds integer labels; the loss expects one-hot targets.
        y = one_hot(y, 2)
        # `update` is jitted, so each instance with a new graph size recompiles.
        params, opt_state = update(params, opt_state, x, y)
        num_done += 1
    epoch_time = time.time() - start_time
    # Number of processed instances (well-defined even if the split is empty).
    print(f"{num_done} done")

    # TODO: report the loss on `test_data` once per epoch.
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))

In [32]:
# With batch_size=1 each batch is a one-element list holding a
# (SATProblem, solution) pair; the collate_fn returns that pair unchanged.
# PyTorch's default collate would try to convert every field into torch tensors,
# which fails on the `None` fields of the jraph GraphsTuple (e.g. `globals`).
# NOTE(review): neither loader is used by the training loop in this notebook.
train_loader = data.DataLoader(
    train_data, batch_size=1, shuffle=True, collate_fn=lambda samples: samples[0]
)
test_loader = data.DataLoader(
    test_data, batch_size=1, shuffle=True, collate_fn=lambda samples: samples[0]
)

In [33]:
# Load and display the first (problem, solution) pair of the test split.
test_data[0]

done creating objects
done creating GraphsTuple


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  return np.array(list(solution_dict.values()), dtype=np.int)


(SATProblem(graph=GraphsTuple(nodes=array([[1., 0.],
        [1., 0.],
        [1., 0.],
        ...,
        [0., 1.],
        [0., 1.],
        [0., 1.]]), edges=array([[0., 1.],
        [0., 1.],
        [0., 1.],
        ...,
        [1., 0.],
        [1., 0.],
        [1., 0.]]), receivers=array([ 29483,  29483,  29483, ..., 743848, 743849, 743849]), senders=array([   13,    14,    15, ..., 29481, 29460, 29482]), globals=None, n_node=array([743850]), n_edge=array([1455473])), mask=array([1, 1, 1, ..., 0, 0, 0], dtype=int32), params=[29483, 714367, 23]),
 array([0, 0, 0, ..., 0, 0, 0]))

In [2]:
import time
import optax
import haiku as hk
import jax.numpy as jnp

# NOTE(review): this cell repeats an earlier training cell in this notebook
# verbatim; keep only one copy. Its saved output is a NameError for
# `network_definition` because the import cell had not been run in this session.
# Number of passes over the training instances.
num_epochs = 10

# Haiku model from the project's `network_definition`; without_apply_rng removes
# the rng argument from `apply`.
network = hk.without_apply_rng(hk.transform(network_definition))
# Parameters are initialised from the graph of the first training instance
# (indexing train_data loads that instance from disk).
params = network.init(jax.random.PRNGKey(42), train_data[0][0].graph)
optimizer = optax.adam(2e-4)
opt_state = optimizer.init(params)

@jax.jit
def prediction_loss(params, problem, solution):
    """Masked cross-entropy between per-node predictions and one-hot targets.

    Args:
        params: Haiku parameters of `network`.
        problem: Instance exposing `.graph` (fed to the network) and `.mask`
            (per-node weights; nodes with zero mask do not contribute).
        solution: One-hot targets with 2 columns. If it has fewer rows than the
            network output, it is zero-padded at the end.

    Returns:
        Scalar: minus the masked sum of the log-probabilities selected by
        `solution`, divided by the sum of the mask.
    """
    decoded_nodes = network.apply(params, problem.graph)
    # NOTE(review): solutions presumably hold one row per variable while the
    # network emits one row per graph node; this assumes variable nodes come
    # first (the mask in this notebook's printed example is 1 on a leading block)
    # — TODO confirm. Padding is a no-op when the row counts already match.
    num_pad_rows = decoded_nodes.shape[0] - solution.shape[0]
    solution = jnp.pad(solution, ((0, num_pad_rows), (0, 0)))
    # We interpret the decoded nodes as a pair of logits for each node.
    log_prob = jax.nn.log_softmax(decoded_nodes) * solution
    return -jnp.sum(log_prob * problem.mask[:, None]) / jnp.sum(problem.mask)

# Vectorised forward pass over a leading batch axis of stacked graphs.
# NOTE(review): not used for training — the loop below feeds one instance at a
# time, and instances of different sizes cannot be stacked without padding.
batched_predict = jax.vmap(network.apply, in_axes=(None, 0))


def loss(params, problems, targets):
    """Training objective for a single SAT instance.

    Delegates to `prediction_loss` (masked cross-entropy on log-softmax outputs).
    Raw logits are not used directly because ``-mean(logits * targets)`` is
    unbounded below, and the network must be given the instance's `.graph`
    rather than the whole instance.

    Args:
        params: Haiku parameters of `network`.
        problems: One SAT instance (with `.graph` and `.mask`); no batch axis.
        targets: One-hot targets for that instance.

    Returns:
        Scalar loss.
    """
    return prediction_loss(params, problems, targets)

def one_hot(x, k, dtype=jnp.float32):
    """Return a (len(x), k) matrix of `dtype` whose row i is 1 at column x[i].

    `x` must support ``x[:, None]`` (e.g. a 1-D integer array); labels outside
    ``[0, k)`` produce all-zero rows.
    """
    class_ids = jnp.arange(k)
    matches = jnp.equal(x[:, None], class_ids)
    return matches.astype(dtype)

@jax.jit
def update(params, opt_state, x, y):
    """Take one `optimizer` step on `loss` for a single (instance, target) pair.

    Returns ``(new_params, new_opt_state)``. Being jitted, it is compiled again
    for every new combination of input shapes.
    """
    grads = jax.grad(loss)(params, x, y)
    param_updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, param_updates)
    return new_params, new_opt_state

for epoch in range(num_epochs):
    start_time = time.time()
    num_done = 0
    # Train on the training split only: `sat_data` also contains every instance
    # of the held-out `test_data` split.
    for x, y in train_data:
        # `y` holds integer labels; the loss expects one-hot targets.
        y = one_hot(y, 2)
        # `update` is jitted, so each instance with a new graph size recompiles.
        params, opt_state = update(params, opt_state, x, y)
        num_done += 1
    epoch_time = time.time() - start_time
    # Number of processed instances (well-defined even if the split is empty).
    print(f"{num_done} done")

    # TODO: report the loss on `test_data` once per epoch.
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))

NameError: name 'network_definition' is not defined