In [1]:

%load_ext autoreload
%autoreload 2
%load_ext memory_profiler

In [2]:
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()

In [3]:
import jax

from constraint_problems import get_problem_from_cnf
from model import train_model
from random_walk import moser_walk_sampler, moser_walk

In [4]:
# from pysat.formula import CNF
# import glob
# instances = [get_problem_from_cnf(CNF(from_file=f)) for f in glob.glob('../Data/uf20-91/*.cnf')]
# train_instances, test_instances = instances[:950],instances[950:]

In [5]:
from pysat.formula import CNF
# import glob
# test_instances = [get_problem_from_cnf(CNF(from_file=f)) for f in glob.glob('../Data/blocksworld/*.cnf') if f == '../Data/blocksworld/bw_large.d.cnf']
# train_instances = [get_problem_from_cnf(CNF(from_file=f)) for f in glob.glob('../Data/blocksworld/*.cnf') if f != '../Data/blocksworld/bw_large.d.cnf']

In [6]:
def load_cnf_problem(cnf_path):
    """Parse the DIMACS file at ``cnf_path`` into a constraint problem."""
    return get_problem_from_cnf(CNF(from_file=cnf_path))


train = [
    load_cnf_problem(cnf_path)
    for cnf_path in ("../Data/blocksworld/huge.cnf", "../Data/blocksworld/medium.cnf")
]
test = [load_cnf_problem("../Data/blocksworld/bw_large.a.cnf")]

In [7]:
# for i in test_instances:
#     print(i.graph.n_node, i.graph.n_edge, i.params)

In [8]:
# Train the model on the `train` problem list, also passing the `test` list.
# NOTE(review): the positional arguments 1000 and 0.02 depend on
# `model.train_model`'s signature (presumably an iteration count and a learning
# rate) — TODO confirm and pass them by keyword / as named constants.
network, params = train_model(1000,0.02,train, test, num_steps=10)

step %r loss train %r test %r 0 18.0


In [9]:
# Run the trained network on the first test problem and turn its outputs into
# sampling weights: softmax over the last axis, keeping column 1.
planning_problem = test[0]
predictions = network.apply(params, planning_problem.graph)
weights = jax.nn.softmax(predictions)[:, 1]
# NOTE(review): 1000 and 0 are passed positionally to `moser_walk_sampler`
# (presumably a step budget and a seed) — TODO confirm.
trajectory, energies = moser_walk_sampler(weights, planning_problem, 1000, 0)

In [12]:
# Profile one Moser walk run; the context manager prints a Perfetto link for the trace.
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    # Run the operations to be profiled
    final_assignment, energy = moser_walk(weights, planning_problem, 1000, 0)
    # JAX dispatches work asynchronously: block inside the trace context so the
    # captured trace covers the computation itself, not just its dispatch.
    jax.block_until_ready((final_assignment, energy))

2022-10-27 10:24:52.561864: E external/org_tensorflow/tensorflow/python/profiler/internal/python_hooks.cc:373] Can't import tensorflow.python.profiler.trace
2022-10-27 10:25:01.698736: E external/org_tensorflow/tensorflow/python/profiler/internal/python_hooks.cc:373] Can't import tensorflow.python.profiler.trace


Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz


127.0.0.1 - - [27/Oct/2022 10:25:21] "GET /perfetto_trace.json.gz HTTP/1.1" 200 -


In [15]:
%timeit final_assignment, energy = moser_walk(weights, planning_problem, 1000, 0)

7.13 s ± 378 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
# Shape of column 1 of the first test problem's edge array.
edge_column = test[0].graph.edges[:, 1]
edge_column.shape

(10809,)

In [15]:
import numpy as np
# Distribution of the model-derived sampling weights over 10 equal-width bins
# (NumPy's default bin count, spelled out for the reader).
np.histogram(weights, bins=10)

(array([5105,    5,    1,    0,    0,    4,    2,    0,    7,   10]),
 array([0.49260545, 0.5084288 , 0.5242521 , 0.5400754 , 0.5558988 ,
        0.57172215, 0.58754545, 0.60336876, 0.6191921 , 0.6350155 ,
        0.6508388 ], dtype=float32))

In [58]:
import jax.numpy as jnp

def violated_constraints2(edges, senders, mask, assignment):
    """Flag constraints whose masked sum of per-edge parity bits equals exactly 1.

    For each edge, the parity ``(edges + assignment[senders]) mod 2`` is computed.
    These bits are contracted with ``mask`` via ``(x,) @ (x, m) -> (m,)``, and a
    boolean ``(m,)`` array is returned that is True where the result equals 1.

    NOTE(review): whether "== 1" means "all literals false" relies on how
    ``mask`` is built (e.g. normalised by constraint size) — TODO confirm.
    """
    sender_values = assignment[senders]
    parity = (edges + sender_values) % 2
    per_constraint = jnp.matmul(parity, mask)
    return per_constraint == 1

In [59]:
rng_key = jax.random.PRNGKey(0)
random_assignment = jax.random.bernoulli(rng_key, weights)

edges = planning_problem.graph.edges[:,1].astype(np.int32)
senders = planning_problem.graph.senders
mask, _, _ = planning_problem.constraint_utils

%timeit violated_constraints2(edges, senders, mask, random_assignment).block_until_ready()
%timeit jax.jit(violated_constraints2)(edges, senders, mask, random_assignment).block_until_ready()

7.81 ms ± 651 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.6 ms ± 1.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [51]:
# Random dense matrix with the same shape as `mask`, used to benchmark a plain matmul.
# Seeded so a fresh "Restart & Run All" reproduces the same values.
np_rng = np.random.default_rng(0)
a = np_rng.random(mask.shape)

In [56]:
# Time the matmul of `edges` with the dense matrix `a`.
# NOTE(review): the saved AttributeError below does not match this line (there is
# no `.block_until_ready()` call here); it appears to be a stale output from an
# earlier version of the cell, and suggests `edges @ a` returns a NumPy array.
%timeit (edges @ a)

AttributeError: 'numpy.ndarray' object has no attribute 'block_until_ready'

In [57]:
%timeit jax.jit(lambda x : edges @ x)(a).block_until_ready()

109 ms ± 8.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
from torch.utils import data
import glob
import os
import pickle
import gzip

In [7]:
class SATTrainingDataset(data.Dataset):
    """PyTorch dataset over the solved SAT instances stored in ``data_dir``.

    For each instance ``<name>`` the directory is expected to contain a gzipped
    DIMACS file ``<name>.cnf.gz`` and a pickled solution ``<name>_sol.pkl``.
    Only instances whose unpickled solution is truthy are kept.

    Attributes:
        data_dir: directory the instances are read from.
        solved_instances: instance base names (no directory, no ``_sol.pkl``).
        ratio_of_solved_instances: fraction of ``*.pkl`` files holding a truthy
            solution (0.0 when the directory has none).

    Items are ``(problem, solution)`` pairs. ``problem`` comes from
    ``get_problem_from_cnf``; ``solution`` holds the solution dict's values as an
    integer NumPy array.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        processed_instances = glob.glob(os.path.join(data_dir, "*.pkl"))
        self.solved_instances = []
        for f in processed_instances:
            # SECURITY: pickle.load can execute arbitrary code; only load trusted data.
            with open(f, "rb") as fl:
                s = pickle.load(fl)
            if not s:
                continue
            # Keep just the file name so __getitem__ can re-join it with data_dir
            # (works for relative dirs too). Strip only the trailing "_sol.pkl" part,
            # so names that themselves contain "_" are not truncated.
            self.solved_instances.append(os.path.basename(f).rsplit("_", 1)[0])
        self.ratio_of_solved_instances = (
            len(self) / len(processed_instances) if processed_instances else 0.0
        )

    def __len__(self):
        return len(self.solved_instances)

    @staticmethod
    def solution_dict_to_array(solution_dict):
        """Return the dict's values, in insertion order, as an integer array.

        NOTE(review): assumes insertion order matches variable order — TODO confirm.
        """
        # `np.int` was an alias of the builtin `int` and was removed in NumPy 1.24.
        return np.array(list(solution_dict.values()), dtype=int)

    def __getitem__(self, idx):
        instance_name = self.solved_instances[idx]
        path = os.path.join(self.data_dir, instance_name)
        # CNF(from_file=...) expects a path, so pass the open (text-mode) gzip
        # handle via from_fp. `with` closes the handle afterwards.
        with gzip.open(path + ".cnf.gz", "rt") as problem_file:
            problem = get_problem_from_cnf(CNF(from_fp=problem_file))
        # Pickles must be read in binary mode.
        with open(path + "_sol.pkl", "rb") as solution_file:
            solution_dict = pickle.load(solution_file)
        return problem, self.solution_dict_to_array(solution_dict)



In [8]:
# Root of the pre-processed SAT instances. Set the SAT_DATA_DIR environment variable
# to override the machine-specific default instead of editing the notebook.
SAT_DATA_DIR = os.environ.get("SAT_DATA_DIR", "/Users/p390943/Downloads/10K")
sat_data = SATTrainingDataset(SAT_DATA_DIR)

In [9]:
import torch

# 80/20 train/test split. The seeded generator makes the partition reproducible
# across kernel restarts. (Fractional lengths need a recent torch, >= 1.13.)
SPLIT_SEED = 0
train_data, test_data = data.random_split(
    sat_data, [0.8, 0.2], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

In [10]:
BATCH_SIZE = 20
# Shuffle only the training data; evaluation does not need a random order.
# NOTE(review): the default collate_fn stacks items into batched tensors. If
# problems or solution arrays differ in size across instances, batching will
# fail and a custom collate_fn is needed — TODO confirm.
train_loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
import time

num_epochs = 10

# FIXME: this loop was adapted from the JAX MNIST example and is not runnable yet.
# `one_hot`, `n_targets`, `update`, `accuracy`, `train_images`, `train_labels`,
# `test_images` and `test_labels` are not defined anywhere in this notebook.
for epoch in range(num_epochs):
    # perf_counter is monotonic and meant for measuring elapsed time.
    start_time = time.perf_counter()
    # Iterate over the training DataLoader (the original referenced an undefined
    # `training_generator`).
    for x, y in train_loader:
        y = one_hot(y, n_targets)
        params = update(params, x, y)
    epoch_time = time.perf_counter() - start_time

    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
    print(f"Training set accuracy {train_acc}")
    print(f"Test set accuracy {test_acc}")