In [1]:
import os
from cnf import CNF
import numpy as np

In [2]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [133]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.lin = nn.Linear(3, 5)
        self.lin.weight.data.uniform_(0,0.01)
        self.lin.bias.data.uniform_(0,0.001)
        self.lin2 = nn.Linear(5, 1)
    def forward(self, x):
        x = self.lin(x)
        x = F.relu(x)
        x = self.lin2(x)
        return x

In [117]:
def load_dir(path):
    data = []
    for filename in os.listdir(path):
        name, ext = os.path.splitext(filename)
        if ext != '.cnf':
            continue
        f = CNF.from_file(os.path.join(path, filename))
        data.append(f)
    return data

In [118]:
def select_variable_reinforce(x, policy):
    logit = policy(x)
    prob = F.softmax(logit, dim=0)
    dist = Categorical(prob.view(-1))
    v = dist.sample()
    return v, dist.log_prob(v)

In [119]:
def compute_true_lit_count(clauses, sol):
    n_clauses = len(clauses)
    true_lit_count = [0] * n_clauses
    for i in range(n_clauses):
        for lit in clauses[i]:
            if sol[abs(lit)] == lit:
                true_lit_count[i] += 1
    return true_lit_count

In [120]:
def do_flip(sol, true_lit_count, v, occur_list):
    sol[v + 1] *= -1
    literal = sol[v + 1]
    for i in occur_list[literal]:
        true_lit_count[i] += 1
    for i in occur_list[-literal]:
        true_lit_count[i] -= 1

In [142]:
# this needs to be more efficient
# consider a subset of the unsat clauses
def stats(f, sol, last_10, last_5, true_lit_count):
    """ computes the featutes needed for the model
    """
    r = f.n_variables/ len(f.clauses)
    breaks = np.zeros(f.n_variables)
    for v in range(1, len(sol)):
        broken_count = 0
        literal = sol[v]
        for index in f.occur_list[-literal]:
            if true_lit_count[index] == 1:
                broken_count += 1
        breaks[v - 1] = r*broken_count
    in_last_10 = np.array([int(i + 1 in last_10) for i in range(f.n_variables)]) 
    in_last_5 = np.array([int(i + 1 in last_5) for i in range(f.n_variables)]) 
    return np.stack([breaks, in_last_10, in_last_5], axis=1)

In [147]:
 def generate_episode_reinforce(f, policy, max_flips, walk_prob):
    sol = [x if random.random() < 0.5 else -x for x in range(f.n_variables + 1)]
    true_lit_count = compute_true_lit_count(f.clauses, sol)
    log_probs = []
    flips = 0
    flipped = set()
    backflipped = 0
    unsat_clauses = []
    last_10 = [0]*10
    last_5 = []
    while flips < max_flips:
        unsat_clause_indices = [k for k in range(len(f.clauses)) if true_lit_count[k] == 0]
        unsat_clauses.append(len(unsat_clause_indices))
        sat = not unsat_clause_indices
        if sat:
            break
        if random.random() < walk_prob:
            unsat_clause = f.clauses[random.choice(unsat_clause_indices)]
            v, log_prob = abs(random.choice(unsat_clause)) - 1, None
        else:
            x = stats(f, sol, last_10, last_5, true_lit_count)
            x = torch.from_numpy(x).float()
            v, log_prob = select_variable_reinforce(x, policy)
            if v.item() not in flipped:
                flipped.add(v.item())
                last_10.insert(0, v.item() + 1)
                last_10 = last_10[:10]
                last_5 = last_10[:5]
            else:
                backflipped += 1
        do_flip(sol, true_lit_count, v, f.occur_list)
        flips += 1
        log_probs.append(log_prob)
    return sat, (flips, backflipped, unsat_clauses), (log_probs,)

In [144]:
def reinforce_loss(history, discount):
    log_probs_list = history[0]
    T = len(log_probs_list)
    log_probs_filtered = []
    mask = np.zeros(T, dtype=bool)
    for i, x in enumerate(log_probs_list):
        if x is not None:
            log_probs_filtered.append(x)
            mask[i] = 1

    log_probs = torch.stack(log_probs_filtered)
    partial_rewards = discount ** torch.arange(T - 1, -1, -1, dtype=torch.float32, device=log_probs.device)
    return -torch.mean(partial_rewards[torch.from_numpy(mask).to(log_probs.device)] * log_probs)

In [145]:
def generate_episodes(policy, f, max_fries, max_flips, discount, walk_prob=.5):
    flips_stats = []
    losses = []
    backflips = []
    num_sols = 0
    for i in range(max_tries):
        out = generate_episode_reinforce(f, policy, max_flips=max_flips, walk_prob=walk_prob)
        sat, (flips, backflipped, unsat_clauses), history = out
    
        flips_stats.append(flips)
        backflips.append(backflipped)
        if sat and flips > 0 and not all(map(lambda x: x is None, history[0])):
            loss = reinforce_loss(history, discount)
            losses.append(loss)
        num_sols += sat    
    if losses:
        losses = torch.stack(losses).sum()  
    return np.mean(flips), np.mean(backflips), losses, num_sols/max_tries

In [137]:
def evaluate(policy, data, max_tries, max_flips, discount, walk_prob=.5):
    mean_flips = []
    mean_backflips = []
    mean_losses = []
    accuracy = []
    policy.eval()
    for f in data:
        flips, backflips, losses, acc = generate_episodes(policy, f, max_tries, max_flips, discount,
                                                          walk_prob)
        mean_flips.append(flips)
        mean_backflips.append(backflips)
        if losses:
            mean_losses.append(losses.item())
        accuracy.append(acc)
    print(np.mean(mean_flips), np.mean(mean_backflips), np.mean(mean_losses), np.mean(accuracy))

In [None]:
def train_epoch(policy, optimizer, data, max_tries = 10, max_flips = 10000, discount = 0.5, walk_prob=.5):
    losses = []
    for f in data:
        policy.train()
        flips, backflips, loss, acc = generate_episodes(policy, f, max_tries, max_flips, discount, walk_prob)
        if acc > 0:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    print(np.mean(losses))

In [156]:
train_ds = load_dir("../data/rand3sat/10-43")
val_ds = load_dir("../data/rand3sat/25-106")
len(train_ds), len(val_ds)

(2000, 2000)

In [157]:
def change_lr(optimizer, lr):
    for g in optimizer.param_groups:
        g['lr'] = lr

In [158]:
policy = Net()
optimizer = optim.RMSprop(policy.parameters(), lr=0.001, weight_decay=1e-5)
max_tries = 10
max_flips = 10000

In [None]:
for i in range(5):
    evaluate(policy, val_ds[:100] , max_tries, max_flips, 0.5, 0.5)

In [None]:
train_epoch(policy, optimizer, train_ds[:500], walk_prob=.5)

In [None]:
for i in range(3):
    evaluate(policy, val_ds[:100] , max_tries, max_flips, 0.5, 0.5)

In [None]:
train_epoch(policy, optimizer, train_ds[500:1000], walk_prob=.5)

In [152]:
for i in range(3):
    evaluate(policy, val_ds[:100] , max_tries, max_flips, 0.5, 0.5)

58.8 28.726 1.7562458352930843 1.0
68.1 31.000999999999994 1.6202993234992027 1.0
69.8 27.831999999999997 1.511616210155189 1.0


In [153]:
change_lr(optimizer, 0.0005)

In [154]:
train_epoch(policy, optimizer, train_ds[1000:2000])

[Parameter containing:
tensor([[-0.3106,  0.2155,  0.2198],
        [-0.3193,  0.2277,  0.2245],
        [ 0.7599, -0.1529, -0.1491],
        [ 0.7396, -0.1564, -0.1603],
        [ 0.7672, -0.1613, -0.1648]], requires_grad=True), Parameter containing:
tensor([0.3773, 0.3879, 0.3021, 0.3200, 0.3262], requires_grad=True), Parameter containing:
tensor([[-1.0492, -0.8462,  0.9129,  1.1484,  0.9808]], requires_grad=True), Parameter containing:
tensor([-0.0013], requires_grad=True)]
[Parameter containing:
tensor([[-0.4077,  0.2480,  0.2523],
        [-0.4169,  0.2604,  0.2572],
        [ 1.0000, -0.1836, -0.1797],
        [ 0.9783, -0.1869, -0.1909],
        [ 1.0068, -0.1919, -0.1953]], requires_grad=True), Parameter containing:
tensor([0.4781, 0.4899, 0.3548, 0.3693, 0.3788], requires_grad=True), Parameter containing:
tensor([[-1.2844, -1.0807,  1.1638,  1.3998,  1.2319]], requires_grad=True), Parameter containing:
tensor([-0.0009], requires_grad=True)]


In [155]:
for i in range(5):
    evaluate(policy, val_ds[:100], max_tries, max_flips, 0.5, 0.5)

64.54 29.662000000000003 1.3064033984579146 1.0
73.82 27.777 1.1989874884858727 1.0
119.43 30.728999999999996 1.2746308460831641 1.0
64.36 30.643999999999995 1.2890113532729446 1.0
77.1 28.85 1.3198073209263383 1.0
