In [1]:
import os
from cnf import CNF
import numpy as np

In [2]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [78]:
class Net(nn.Module):
    def __init__(self, h=5):
        super(Net, self).__init__()
        self.lin = nn.Linear(3, h)
        self.lin.weight.data.uniform_(-0.001, 0)
        self.lin.bias.data.uniform_(-.0001, 0)
        self.dropout = nn.Dropout(0.5)
        self.lin2 = nn.Linear(h, 1)
        self.lin2.weight.data.uniform_(0,0.0001)
        self.lin2.bias.data.uniform_(0,0.001)
    def forward(self, x):
        x = self.lin(x)
        x = F.relu(self.dropout(x))
        x = self.lin2(x)
        return x

In [79]:
def init_net(model):
    with torch.no_grad():
        model.lin.weight[0, 0] = 10
        model.lin2.weight[0, 0] = -1

In [4]:
class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.lin = nn.Linear(3, 1)
        self.lin.weight.data.uniform_(0,0.01)
    def forward(self, x):
        x = self.lin(x)
        return x

In [73]:
def init_net2(policy):
    with torch.no_grad():
        policy.lin.weight[0, 0] = -1
        policy.lin.weight[0, 1] = 0
        policy.lin.weight[0, 2] = 0
        policy.lin.bias[0] = 0

In [5]:
def load_dir(path):
    data = []
    for filename in os.listdir(path):
        name, ext = os.path.splitext(filename)
        if ext != '.cnf':
            continue
        f = CNF.from_file(os.path.join(path, filename))
        data.append(f)
    return data

In [6]:
def select_variable_reinforce(x, policy):
    logit = policy(x)
    prob = F.softmax(logit, dim=0)
    dist = Categorical(prob.view(-1))
    v = dist.sample()
    return v, dist.log_prob(v)

In [7]:
def compute_true_lit_count(clauses, sol):
    n_clauses = len(clauses)
    true_lit_count = [0] * n_clauses
    for i in range(n_clauses):
        for lit in clauses[i]:
            if sol[abs(lit)] == lit:
                true_lit_count[i] += 1
    return true_lit_count

In [8]:
def do_flip(sol, true_lit_count, literal, occur_list):
    for i in occur_list[literal]:
        true_lit_count[i] += 1
    for i in occur_list[-literal]:
        true_lit_count[i] -= 1
    sol[abs(literal)] *= -1

In [91]:
# this needs to be more efficient
# consider a subset of the unsat clauses
def stats_per_clause(f, last_10, true_lit_count, unsat_clause):
    """ computes the featutes needed for the model
    """
    r = f.n_variables/ len(f.clauses)
    variables = [abs(v) for v in unsat_clause]
    breaks = np.zeros(len(variables))
    last_5 = last_10[:5]
    for i, literal in enumerate(unsat_clause):
        broken_count = 0
        for index in f.occur_list[-literal]:
            if true_lit_count[index] == 1:
                broken_count += 1
        breaks[i] = broken_count
    in_last_10 = np.array([int(i + 1 in last_10) for i in variables]) 
    in_last_5 = np.array([int(i + 1 in last_5) for i in variables]) 
    return np.stack([breaks, in_last_10, in_last_5], axis=1)

In [10]:
def walksat_step(f, true_lit_count, unsat_clause):
    """Returns chosen literal"""
    broken_min = float('inf')
    min_breaking_lits = []
    for literal in unsat_clause:
        broken_count = 0
        for index in f.occur_list[-literal]:
            if true_lit_count[index] == 1:
                broken_count += 1
            if broken_count > broken_min:
                break
        if broken_count < broken_min:
            broken_min = broken_count
            min_breaking_lits = [literal]
        elif broken_count == broken_min:
            min_breaking_lits.append(literal)
    return abs(random.choice(min_breaking_lits))

In [11]:
def reinforce_step(f, policy, last_10, true_lit_count, unsat_clause):
    x = stats_per_clause(f, last_10, true_lit_count, unsat_clause)
    x = torch.from_numpy(x).float()
    index, log_prob = select_variable_reinforce(x, policy)
    literal = unsat_clause[index]
    return literal, log_prob, last_10

In [12]:
def generate_episode_reinforce(f, policy, max_flips, walk_prob=0.5, walksat=False):
    sol = [x if random.random() < 0.5 else -x for x in range(f.n_variables + 1)]
    true_lit_count = compute_true_lit_count(f.clauses, sol)
    log_probs = []
    flips = 0
    flipped = set()
    backflipped = 0
    unsat_clauses = []
    last_10 = [0]*10
    while flips < max_flips:
        unsat_clause_indices = [k for k in range(len(f.clauses)) if true_lit_count[k] == 0]
        unsat_clauses.append(len(unsat_clause_indices))
        sat = not unsat_clause_indices
        if sat:
            break
        unsat_clause = f.clauses[random.choice(unsat_clause_indices)]
        if random.random() < walk_prob:
            literal, log_prob = random.choice(unsat_clause), None
        else:
            if walksat:
                literal = walksat_step(f, true_lit_count, unsat_clause)
                log_prob = None
            else:
                literal, log_prob, last_10 = reinforce_step(f, policy, last_10, true_lit_count, unsat_clause)
            v = abs(literal)
            if v not in flipped:
                flipped.add(v)
                last_10.insert(0, v)
                last_10 = last_10[:10]
            else:
                backflipped += 1
        do_flip(sol, true_lit_count, literal, f.occur_list)
        flips += 1
        log_probs.append(log_prob)
    return sat, (flips, backflipped, unsat_clauses), (log_probs,)

In [13]:
def reinforce_loss(history, discount):
    log_probs_list = history[0]
    T = len(log_probs_list)
    log_probs_filtered = []
    mask = np.zeros(T, dtype=bool)
    for i, x in enumerate(log_probs_list):
        if x is not None:
            log_probs_filtered.append(x)
            mask[i] = 1

    log_probs = torch.stack(log_probs_filtered)
    partial_rewards = discount ** torch.arange(T - 1, -1, -1, dtype=torch.float32, device=log_probs.device)
    return -torch.mean(partial_rewards[torch.from_numpy(mask).to(log_probs.device)] * log_probs)

In [14]:
def generate_episodes(policy, f, max_fries, max_flips, discount, walk_prob, walksat):
    flips_stats = []
    losses = []
    backflips = []
    num_sols = 0
    for i in range(max_tries):
        out = generate_episode_reinforce(f, policy, max_flips, walk_prob, walksat)
        sat, (flips, backflipped, unsat_clauses), history = out
    
        flips_stats.append(flips)
        backflips.append(backflipped)
        if sat and flips > 0 and not all(map(lambda x: x is None, history[0])):
            loss = reinforce_loss(history, discount)
            losses.append(loss)
        num_sols += sat    
    if losses:
        losses = torch.stack(losses).sum()  
    return np.mean(flips), np.mean(backflips), losses, num_sols/max_tries

In [15]:
def evaluate(policy, data, max_tries, max_flips, discount, walk_prob=.5, walksat=False):
    mean_flips = []
    mean_backflips = []
    mean_losses = []
    accuracy = []
    policy.eval()
    for f in data:
        flips, backflips, losses, acc = generate_episodes(policy, f, max_tries, max_flips, discount,
                                                          walk_prob, walksat)
        mean_flips.append(flips)
        mean_backflips.append(backflips)
        if losses:
            mean_losses.append(losses.item())
        accuracy.append(acc)
        mean_loss = None
        if mean_losses:
            mean_loss = np.mean(mean_losses)
    print(np.mean(mean_flips), np.mean(mean_backflips), mean_loss, np.mean(accuracy))

In [16]:
def train_epoch(policy, optimizer, data, max_tries = 10, max_flips = 10000, discount = 0.5, walk_prob=.5):
    losses = []
    for f in data:
        policy.train()
        flips, backflips, loss, acc = generate_episodes(
            policy, f, max_tries, max_flips, discount, walk_prob, False)
        if acc > 0:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    print(np.mean(losses))

In [17]:
def change_lr(optimizer, lr):
    for g in optimizer.param_groups:
        g['lr'] = lr

In [18]:
train_ds = load_dir("../data/rand3sat/10-43")
#val_ds = load_dir("../data/rand3sat/25-106")

In [25]:
train_ds = load_dir("../data/rand3sat/25-106")
#val_ds = load_dir("../data/rand3sat/25-106")
val_ds = train_ds[1900:]
train_ds = train_ds[:1900]
len(train_ds), len(val_ds)

(1900, 100)

In [80]:
policy = Net()
optimizer = optim.RMSprop(policy.parameters(), lr=0.01, weight_decay=1e-5)
max_tries = 10
max_flips = 10000

In [81]:
init_net(policy)

In [82]:
policy.lin2.weight

Parameter containing:
tensor([[-1.0000e+00,  8.6174e-05,  2.7535e-05,  6.9441e-05,  4.5019e-05]],
       requires_grad=True)

In [28]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, 0.5, 0.5, walksat=True)

70.27 20.523999999999997 None 1.0
65.62 19.508000000000003 None 1.0
70.35 20.115 None 1.0


In [50]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, 0.5, 0.5)

137.04 66.489 0.08295307789696381 0.9990000000000001
135.61 64.2 0.08592267288826407 1.0
160.76 63.029999999999994 0.0900525104929693 1.0


In [83]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, 0.5, 0.5)

168.37 59.137999999999984 0.08420275656040759 1.0
183.32 67.34100000000001 0.08476282902061939 1.0
234.61 69.435 0.08331708039157093 1.0


In [207]:
train_epoch(policy, optimizer, train_ds[:500], walk_prob=.5)

0.8294109015241266


In [86]:
train_ds = load_dir("../data/kcolor/4-15-0.5/")
train_ds = load_dir("../data/kcolor/3-10-0.5/")
val_ds = train_ds[1900:]
train_ds = train_ds[:1900]
len(train_ds), len(val_ds)

(1900, 100)

In [87]:
policy = Net()
optimizer = optim.RMSprop(policy.parameters(), lr=0.001, weight_decay=1e-5)
max_tries = 10
max_flips = 10000

In [88]:
init_net(policy)

In [89]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, discount=0.5, walk_prob=.5, walksat=True)

47.77 11.647 None 1.0
50.96 11.69 None 1.0
47.4 11.472999999999999 None 1.0


In [90]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, discount=0.5, walk_prob=.5)

67.14 17.039 0.14500915387645363 1.0
68.77 17.336000000000002 0.15029805906116964 1.0
65.06 17.078999999999997 0.14786683354526758 1.0


In [92]:
train_epoch(policy, optimizer, train_ds[:1000], walk_prob=.5, discount=0.5)

0.15667497039586306


In [93]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, discount=0.5, walk_prob=.5)

59.22 18.176 0.15272776201367377 1.0
65.53 18.135 0.14240527544170617 1.0
70.62 18.245 0.1481259680353105 1.0


In [None]:
train_epoch(policy, optimizer, train_ds[500:1500], walk_prob=.5, discount = 0.5)

In [None]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, discount=0.5, walk_prob=.5)

In [76]:
train_epoch(policy, optimizer, train_ds[500:1500], walk_prob=.5, discount = 0.5)

0.23761244366131723


In [77]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, discount=0.5, walk_prob=.5)

161.09 64.67200000000001 0.2483160090446472 1.0
169.22 72.015 0.25875902105122806 1.0
181.51 62.083999999999996 0.26591289695352316 1.0


In [78]:
train_epoch(policy, optimizer, train_ds[500:1900], walk_prob=.5, discount = 0.5)

0.21902772487673378


In [79]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, discount=0.5, walk_prob=.5)

137.88 54.611999999999995 0.2724594607204199 1.0
191.53 57.958999999999996 0.25007543120533227 1.0
156.28 61.147 0.27258397758007047 1.0


In [80]:
train_epoch(policy, optimizer, train_ds, walk_prob=.5, discount = 0.5)

0.210308761643736


In [81]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, discount=0.5, walk_prob=.5)

134.7 50.357000000000006 0.2677038526907563 1.0
112.72 46.394 0.27576946139335634 1.0
124.73 46.220000000000006 0.25905397575348615 1.0


In [82]:
train_epoch(policy, optimizer, train_ds, walk_prob=.5, discount = 0.5)
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, discount=0.5, walk_prob=.5)

0.20099698880687356
121.66 45.887 0.26536831188946963 1.0
140.23 43.32200000000001 0.23908030308783054 1.0
150.35 43.584999999999994 0.2926462286710739 1.0


In [83]:
train_epoch(policy, optimizer, train_ds, walk_prob=.5, discount = 0.5)
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, discount=0.5, walk_prob=.5)

0.19399403272000582
129.07 40.751999999999995 0.29220971882343294 1.0
123.3 41.93299999999999 0.2670344277843833 1.0
128.87 43.587999999999994 0.23773450043052435 1.0


In [40]:
(np.e)**0.9

2.4596031111569494