In [1]:
import os
from cnf import CNF
import numpy as np

In [2]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [184]:
class Net(nn.Module):
    def __init__(self, h=5):
        super(Net, self).__init__()
        self.lin = nn.Linear(3, h)
        #self.lin.weight.data.uniform_(0,0.01)
        #self.lin.bias.data.uniform_(0,0.001)
        self.lin2 = nn.Linear(h, 1)
    def forward(self, x):
        x = self.lin(x)
        x = F.relu(x)
        x = self.lin2(x)
        return x

In [4]:
def load_dir(path):
    data = []
    for filename in os.listdir(path):
        name, ext = os.path.splitext(filename)
        if ext != '.cnf':
            continue
        f = CNF.from_file(os.path.join(path, filename))
        data.append(f)
    return data

In [5]:
def select_variable_reinforce(x, policy):
    logit = policy(x)
    prob = F.softmax(logit, dim=0)
    dist = Categorical(prob.view(-1))
    v = dist.sample()
    return v, dist.log_prob(v)

In [6]:
def compute_true_lit_count(clauses, sol):
    n_clauses = len(clauses)
    true_lit_count = [0] * n_clauses
    for i in range(n_clauses):
        for lit in clauses[i]:
            if sol[abs(lit)] == lit:
                true_lit_count[i] += 1
    return true_lit_count

In [7]:
def do_flip(sol, true_lit_count, v, occur_list):
    sol[v] *= -1
    literal = sol[v]
    for i in occur_list[literal]:
        true_lit_count[i] += 1
    for i in occur_list[-literal]:
        true_lit_count[i] -= 1

In [8]:
# this needs to be more efficient
# consider a subset of the unsat clauses
def stats_per_clause(f, sol, last_10, true_lit_count, unsat_clause):
    """ computes the featutes needed for the model
    """
    r = f.n_variables/ len(f.clauses)
    variables = [abs(v) for v in unsat_clause]
    breaks = np.zeros(len(variables))
    last_5 = last_10[:5]
    for i, v in enumerate(variables):
        broken_count = 0
        literal = sol[v]
        for index in f.occur_list[-literal]:
            if true_lit_count[index] == 1:
                broken_count += 1
        breaks[i] = r*broken_count
    in_last_10 = np.array([int(i + 1 in last_10) for i in variables]) 
    in_last_5 = np.array([int(i + 1 in last_5) for i in variables]) 
    return np.stack([breaks, in_last_10, in_last_5], axis=1), np.array(variables)

In [32]:
def walksat_step(f, true_lit_count, unsat_clause):
    broken_min = float('inf')
    min_breaking_lits = []
    for literal in unsat_clause:
        broken_count = 0
        for index in f.occur_list[-literal]:
            if true_lit_count[index] == 1:
                broken_count += 1
            if broken_count > broken_min:
                break
            if broken_count < broken_min:
                broken_min = broken_count
                min_breaking_lits = [literal]
            elif broken_count == broken_min:
                min_breaking_lits.append(literal)
    return abs(random.choice(min_breaking_lits))

In [42]:
def reinforce_step(f, policy, sol, last_10, true_lit_count, unsat_clause):
    x, variables = stats_per_clause(f, sol, last_10, true_lit_count, unsat_clause)
    x = torch.from_numpy(x).float()
    index, log_prob = select_variable_reinforce(x, policy)
    v = variables[index]
    return v.item(), log_prob, last_10

In [43]:
def generate_episode_reinforce(f, policy, max_flips, walk_prob, walksat=False):
    sol = [x if random.random() < 0.5 else -x for x in range(f.n_variables + 1)]
    true_lit_count = compute_true_lit_count(f.clauses, sol)
    log_probs = []
    flips = 0
    flipped = set()
    backflipped = 0
    unsat_clauses = []
    last_10 = [0]*10
    last_5 = []
    while flips < max_flips:
        unsat_clause_indices = [k for k in range(len(f.clauses)) if true_lit_count[k] == 0]
        unsat_clauses.append(len(unsat_clause_indices))
        sat = not unsat_clause_indices
        if sat:
            break
        unsat_clause = f.clauses[random.choice(unsat_clause_indices)]
        if random.random() < walk_prob:
            v, log_prob = abs(random.choice(unsat_clause)), None
        else:
            if walksat:
                v = walksat_step(f, true_lit_count, unsat_clause)
                log_prob = None
            else:
                v, log_prob, last_10 = reinforce_step(f, policy, sol, last_10, true_lit_count, unsat_clause)
            
            if v not in flipped:
                flipped.add(v)
                last_10.insert(0, v)
                last_10 = last_10[:10]
            else:
                backflipped += 1
        do_flip(sol, true_lit_count, v, f.occur_list)
        flips += 1
        log_probs.append(log_prob)
    return sat, (flips, backflipped, unsat_clauses), (log_probs,)

In [25]:
def reinforce_loss(history, discount):
    log_probs_list = history[0]
    T = len(log_probs_list)
    log_probs_filtered = []
    mask = np.zeros(T, dtype=bool)
    for i, x in enumerate(log_probs_list):
        if x is not None:
            log_probs_filtered.append(x)
            mask[i] = 1

    log_probs = torch.stack(log_probs_filtered)
    partial_rewards = discount ** torch.arange(T - 1, -1, -1, dtype=torch.float32, device=log_probs.device)
    return -torch.mean(partial_rewards[torch.from_numpy(mask).to(log_probs.device)] * log_probs)

In [40]:
def generate_episodes(policy, f, max_fries, max_flips, discount, walk_prob, walksat):
    flips_stats = []
    losses = []
    backflips = []
    num_sols = 0
    for i in range(max_tries):
        out = generate_episode_reinforce(f, policy, max_flips, walk_prob, walksat)
        sat, (flips, backflipped, unsat_clauses), history = out
    
        flips_stats.append(flips)
        backflips.append(backflipped)
        if sat and flips > 0 and not all(map(lambda x: x is None, history[0])):
            loss = reinforce_loss(history, discount)
            losses.append(loss)
        num_sols += sat    
    if losses:
        losses = torch.stack(losses).sum()  
    return np.mean(flips), np.mean(backflips), losses, num_sols/max_tries

In [57]:
def evaluate(policy, data, max_tries, max_flips, discount, walk_prob=.5, walksat=False):
    mean_flips = []
    mean_backflips = []
    mean_losses = []
    accuracy = []
    policy.eval()
    for f in data:
        flips, backflips, losses, acc = generate_episodes(policy, f, max_tries, max_flips, discount,
                                                          walk_prob, walksat)
        mean_flips.append(flips)
        mean_backflips.append(backflips)
        if losses:
            mean_losses.append(losses.item())
        accuracy.append(acc)
        mean_loss = None
        if mean_losses:
            mean_loss = np.mean(mean_losses)
    print(np.mean(mean_flips), np.mean(mean_backflips), mean_loss, np.mean(accuracy))

In [56]:
mean_losses = []
if mean_losses:
    print(True)

In [46]:
def train_epoch(policy, optimizer, data, max_tries = 10, max_flips = 10000, discount = 0.5, walk_prob=.5):
    losses = []
    for f in data:
        policy.train()
        flips, backflips, loss, acc = generate_episodes(
            policy, f, max_tries, max_flips, discount, walk_prob, False)
        if acc > 0:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    print(np.mean(losses))

In [29]:
def change_lr(optimizer, lr):
    for g in optimizer.param_groups:
        g['lr'] = lr

In [95]:
train_ds = load_dir("../data/rand3sat/10-43")
#val_ds = load_dir("../data/rand3sat/25-106")

In [155]:
train_ds = load_dir("../data/rand3sat/25-106")
#val_ds = load_dir("../data/rand3sat/25-106")
val_ds = train_ds[1900:]
train_ds = train_ds[:1900]
len(train_ds), len(val_ds)

(1900, 100)

In [185]:
policy = Net(h=10)
optimizer = optim.RMSprop(policy.parameters(), lr=0.01, weight_decay=1e-5)
max_tries = 10
max_flips = 10000

In [186]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, 0.5, 0.5, walksat=True)

445.11 187.17899999999997 None 0.9990000000000001
418.66 180.21599999999995 None 1.0
482.02 179.59 None 1.0


In [None]:
train_epoch(policy, optimizer, train_ds[:1000], walk_prob=.5)

In [160]:
# this is old
train_epoch(policy, optimizer, train_ds[:1000], walk_prob=.5)

0.15922583890100941


In [None]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, 0.5, 0.5)

In [162]:
train_epoch(policy, optimizer, train_ds[1000:19000], walk_prob=.5)

0.15637464732345607


In [163]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, 0.5, 0.5)

869.41 403.2699999999999 0.151531353876926 0.9969999999999999
848.05 415.268 0.15788671686314046 0.993
717.56 382.342 0.157049910267815 0.9940000000000001


In [164]:
change_lr(optimizer, 0.001)

In [165]:
train_epoch(policy, optimizer, train_ds[:1000], walk_prob=.5)

0.15107456147880294


In [166]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, 0.5, 0.5)

861.64 393.195 0.16378629960119725 0.996
873.49 374.065 0.14964713517576456 0.995
832.02 378.163 0.16430504124611617 0.9969999999999999


In [167]:
train_epoch(policy, optimizer, train_ds[1000:19000])

0.16016710352566507


In [168]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, 0.5, 0.5)

951.89 398.6030000000001 0.1411107956338674 0.99
703.77 397.0199999999999 0.14821282254531978 0.995
772.65 389.138 0.13893511860165744 0.996


In [169]:
train_epoch(policy, optimizer, train_ds[:1000])

0.15675873104389756


In [170]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, 0.5, 0.5)

719.37 384.3540000000001 0.1605460020666942 0.9940000000000001
639.63 394.7660000000001 0.18049253313802183 0.9919999999999999
758.54 389.92600000000004 0.1717829753877595 0.9940000000000001


In [171]:
train_epoch(policy, optimizer, train_ds[:1000], walk_prob=.5)

0.1620051857996732


In [172]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, 0.5, 0.5)

843.64 373.1000000000001 0.1694360166718252 0.996
662.63 389.443 0.1678937450866215 0.992
1051.82 425.6410000000001 0.14378401993075385 0.988


In [173]:
train_epoch(policy, optimizer, train_ds[1000:1900], walk_prob=.5)

0.15798000554812866


In [174]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, 0.5, 0.5)

781.03 394.702 0.12954339002724738 0.9890000000000001
967.63 377.03100000000006 0.17004246404860168 0.9939999999999999
683.5 377.32599999999996 0.16464639690238983 0.99


In [175]:
train_epoch(policy, optimizer, train_ds[500:1500], walk_prob=.5)

0.155600578377489


In [176]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, 0.5, 0.5)

680.27 403.6 0.1574638346210122 0.995
873.41 386.48600000000005 0.1633128352602944 0.9940000000000001
793.25 401.07300000000004 0.17766096994280814 0.9939999999999999


In [177]:
train_epoch(policy, optimizer, train_ds[500:1500], walk_prob=.5)

0.16137498652585783


In [178]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, 0.5, 0.5)

586.88 357.41600000000005 0.1631914810091257 0.995
719.2 414.005 0.15400963881518692 0.987
901.22 382.50200000000007 0.1647008233424276 0.993


## Using a model train on the smaller dataset

In [107]:
train_ds = load_dir("../data/rand3sat/25-106")
val_ds = train_ds[1900:]
train_ds = train_ds[:1900]
len(train_ds), len(val_ds)

(1900, 100)

In [108]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, 0.5, 0.5, walksat=True)

416.73 194.42499999999995 None 1.0
343.34 175.79300000000003 None 1.0
382.89 179.39299999999994 None 1.0


In [109]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, 0.5, 0.5)

980.46 490.61100000000005 0.12019802499125945 0.964
1073.29 492.02 0.13488227373803965 0.96
1213.91 546.9559999999999 0.12811717163567665 0.955


In [110]:
train_epoch(policy, optimizer, train_ds[:1000], walk_prob=.5, discount=0.5)

0.13988973849065317


In [111]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, 0.5, 0.5)

1041.16 411.69200000000006 0.15662544228835032 0.98
784.59 425.652 0.15783004622673616 0.9769999999999999
717.45 431.52600000000007 0.16076064602937548 0.975


In [112]:
train_epoch(policy, optimizer, train_ds[1000:1900], walk_prob=.5, discount=0.5)

0.17063769628500772


In [113]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, 0.5, 0.5)

697.78 386.7660000000001 0.17932638720609248 0.987
957.21 396.1840000000001 0.16634278478100895 0.991
971.07 411.4370000000001 0.17387648035306483 0.992


In [114]:
train_epoch(policy, optimizer, train_ds[500:1500], walk_prob=.5, discount=0.5)

0.15878655840922146


In [115]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, 0.5, 0.5)

883.59 379.3310000000001 0.16394700217992067 0.995
884.46 417.501 0.16850176389794796 0.992
744.65 392.30799999999994 0.1524734530900605 0.993


In [116]:
train_ds = load_dir("../data/kcolor/4-15-0.5/")
val_ds = train_ds[1900:]
train_ds = train_ds[:1900]
len(train_ds), len(val_ds)

(1900, 100)

In [117]:
policy = Net()
optimizer = optim.RMSprop(policy.parameters(), lr=0.001, weight_decay=1e-5)
max_tries = 10
max_flips = 10000

In [118]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, discount=0.5, walk_prob=.5, walksat=True)

1336.9 749.6160000000001 None 0.983
1463.93 722.4960000000001 None 0.986
1901.37 799.176 None 0.982


In [119]:
train_epoch(policy, optimizer, train_ds[:1000], walk_prob=.5, discount=0.5)

0.03242877619287174


In [120]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, discount=0.5, walk_prob=.5)

3408.94 1671.2410000000004 0.03272874584421515 0.883
2947.71 1495.197 0.03670568984933197 0.9129999999999999
3356.58 1580.612 0.031729125855490564 0.897


In [121]:
train_epoch(policy, optimizer, train_ds[1000:1900], walk_prob=.5, discount = 0.5)

0.036617499960120765


In [122]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, discount=0.5, walk_prob=.5)

2661.37 1253.543 0.04447664735023864 0.9390000000000001
2536.92 1276.642 0.046337029247079047 0.93
3263.22 1243.623 0.04456117057823576 0.939


In [123]:
train_epoch(policy, optimizer, train_ds[:1000], walk_prob=.5, discount=0.5)

0.04779403054586146


In [124]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, discount=0.5, walk_prob=.5)

1620.53 765.1440000000001 0.058775811828672885 0.985
1684.17 768.3710000000001 0.05517783605493605 0.98
1812.28 791.8889999999999 0.05703072253614664 0.981


In [125]:
train_epoch(policy, optimizer, train_ds[500:1500], walk_prob=.5, discount=0.5)

0.05916170927230269


In [126]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, discount=0.5, walk_prob=.5)

1045.23 499.50100000000003 0.06391769652254879 1.0
1290.13 564.948 0.06700162066146731 0.996
962.88 527.331 0.06789143403060734 0.998


In [127]:
train_epoch(policy, optimizer, train_ds[:500], walk_prob=.5, discount=0.5)

0.06636926405038684


In [128]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, discount=0.5, walk_prob=.5)

1116.17 506.40399999999994 0.06637105594854802 0.997
1554.8 523.349 0.07122086943127215 0.995
1101.52 471.147 0.06678609466180205 0.998


In [129]:
train_epoch(policy, optimizer, train_ds[900:1900], walk_prob=.5, discount=0.5)

0.06597169480565936


In [130]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, discount=0.5, walk_prob=.5)

1072.57 452.00800000000004 0.06657401758246123 0.9990000000000001
856.1 417.89 0.07686279580928385 0.9990000000000001
1035.32 430.24700000000007 0.07512179552577436 0.998


In [132]:
train_epoch(policy, optimizer, train_ds[:1000], walk_prob=.5, discount=0.5)

0.07190380474086851


In [133]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, discount=0.5, walk_prob=.5)

919.97 393.606 0.0699298278708011 0.9990000000000001
1114.82 418.51700000000005 0.07240340700373053 0.9990000000000001
1052.15 434.9 0.06658170715905726 0.9990000000000001


In [134]:
change_lr(optimizer, 0.0001)

In [135]:
train_epoch(policy, optimizer, train_ds[900:1900], walk_prob=.5, discount=0.5)

0.06914114368450829


In [136]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, discount=0.5, walk_prob=.5)

1174.81 439.64599999999996 0.06793487512040883 0.9990000000000001
922.37 409.96899999999994 0.07614302545785905 0.9990000000000001
1128.09 420.574 0.08216650203801691 1.0


In [84]:
train_ds = load_dir("../data/kcolor/3-10-0.5/")
val_ds = train_ds[1900:]
train_ds = train_ds[:1900]
len(train_ds), len(val_ds)

(1900, 100)

In [85]:
policy = Net()
optimizer = optim.RMSprop(policy.parameters(), lr=0.001, weight_decay=1e-5)
max_tries = 10
max_flips = 10000

In [86]:
for i in range(3):
    evaluate(policy, val_ds , max_tries, max_flips, discount=0.5, walk_prob=.5, walksat=True)

121.54 53.349999999999994 None 1.0
136.97 54.977 None 1.0
155.14 54.88500000000001 None 1.0


In [87]:
train_epoch(policy, optimizer, train_ds[:1000], walk_prob=.5, discount=0.5)

0.23722988438606263


In [88]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, discount=0.5, walk_prob=.5)

193.22 76.306 0.23265512000769376 1.0
209.12 74.878 0.262177706733346 1.0
201.46 74.525 0.24403006821870804 1.0


In [89]:
train_epoch(policy, optimizer, train_ds[500:1500], walk_prob=.5, discount = 0.5)

0.24774086069129408


In [90]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, discount=0.5, walk_prob=.5)

152.72 57.14699999999998 0.26695315714925527 1.0
161.81 52.11600000000001 0.28142453625798225 1.0
158.83 52.914000000000016 0.273628425039351 1.0


In [91]:
train_epoch(policy, optimizer, train_ds[500:1500], walk_prob=.5, discount = 0.5)

0.27094397219642996


In [92]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, discount=0.5, walk_prob=.5)

105.26 37.832 0.27669035162776706 1.0
119.64 40.663 0.29247995145618916 1.0
100.49 37.184 0.28421006713062524 1.0


In [93]:
train_epoch(policy, optimizer, train_ds[500:1900], walk_prob=.5, discount = 0.5)

0.2762653423872377


In [94]:
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, discount=0.5, walk_prob=.5)

102.92 32.092999999999996 0.2857779210805893 1.0
104.05 34.479 0.26322814900428054 1.0
109.89 34.550999999999995 0.276728473380208 1.0
