In [1]:
import os
from cnf import CNF
import numpy as np

In [2]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [63]:
class Net(nn.Module):
    def __init__(self, h=5):
        super(Net, self).__init__()
        self.lin = nn.Linear(4, h)
        self.dropout = nn.Dropout(0.5)
        self.lin2 = nn.Linear(h, 1)
    def forward(self, x):
        x = self.lin(x)
        x = F.relu(self.dropout(x))
        x = self.lin2(x)
        return x

In [4]:
def init_net(model):
    with torch.no_grad():
        model.lin.weight[0, 0] = 10
        model.lin2.weight[0, 0] = -1

In [5]:
class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.lin = nn.Linear(4, 1)
    def forward(self, x):
        x = self.lin(x)
        return x

In [6]:
def init_net2(policy):
    with torch.no_grad():
        policy.lin.weight[0, 0] = -1
        policy.lin.weight[0, 1] = 0
        policy.lin.weight[0, 2] = 0
        policy.lin.weight[0, 3] = 0
        policy.lin.bias[0] = 0

In [17]:
def load_dir(path):
    data = []
    for filename in os.listdir(path):
        name, ext = os.path.splitext(filename)
        if ext != '.cnf':
            continue
        f = CNF.from_file(os.path.join(path, filename))
        data.append(f)
    return data

In [8]:
class WalkSATLN:
    def __init__(self, policy, max_tries, max_flips, p=0.5, discount=0.5):
        self.policy = policy
        self.max_tries = max_tries
        self.max_flips = max_flips
        self.p = p
        self.discount = discount
        self.flips_to_solution = []
        self.backflips = []
        self.unsat_clauses = []
        self.age = []
        self.last_10 = []
        self.sol = []
        
    def compute_true_lit_count(self, clauses):
        n_clauses = len(clauses)
        true_lit_count = [0] * n_clauses
        for index in range(n_clauses):
            for literal in clauses[index]:
                if self.sol[abs(literal)] == literal:
                    true_lit_count[index] += 1
        return true_lit_count
    
    def select_variable_reinforce(self, x):
        logit = self.policy(x)
        prob = F.softmax(logit, dim=0)
        dist = Categorical(prob.view(-1))
        v = dist.sample()
        return v, dist.log_prob(v)
    
    def do_flip(self, literal, occur_list):
        for i in occur_list[literal]:
            self.true_lit_count[i] += 1
        for i in occur_list[-literal]:
            self.true_lit_count[i] -= 1
        self.sol[abs(literal)] *= -1
        
    def stats_per_clause(self, f, unsat_clause):
        """ computes the featutes needed for the model
        """ 
        r = f.n_variables/ len(f.clauses)
        variables = [abs(v) for v in unsat_clause]
        breaks = np.zeros(len(variables))
        last_5 = self.last_10[:5]
        for i, literal in enumerate(unsat_clause):
            broken_count = 0
            for index in f.occur_list[-literal]:
                if self.true_lit_count[index] == 1:
                    broken_count += 1
            breaks[i] = broken_count
        in_last_10 = np.array([int(v in self.last_10) for v in variables]) 
        age = np.array([self.age[v] for v in variables])/(self.age[0] + 1)
        in_last_5 = np.array([int(v in last_5) for v in variables]) 
        return np.stack([breaks, in_last_10, in_last_5, age], axis=1)
    
    def walksat_step(self, f, unsat_clause):
        """Returns chosen literal"""
        broken_min = float('inf')
        min_breaking_lits = []
        for literal in unsat_clause:
            broken_count = 0
            for index in f.occur_list[-literal]:
                if self.true_lit_count[index] == 1:
                    broken_count += 1
                if broken_count > broken_min:
                    break
            if broken_count < broken_min:
                broken_min = broken_count
                min_breaking_lits = [literal]
            elif broken_count == broken_min:
                min_breaking_lits.append(literal)
        return abs(random.choice(min_breaking_lits))
    
    def reinforce_step(self, f, unsat_clause):
        x = self.stats_per_clause(f, unsat_clause)
        x = torch.from_numpy(x).float()
        index, log_prob = self.select_variable_reinforce(x)
        literal = unsat_clause[index]
        return literal, log_prob
    
    def generate_episode_reinforce(self, f, walksat):
        self.sol = [x if random.random() < 0.5 else -x for x in range(f.n_variables + 1)]
        self.true_lit_count = self.compute_true_lit_count(f.clauses)
        self.age = np.zeros(f.n_variables + 1)
        log_probs = []
        flips = 0
        flipped = set()
        backflipped = 0
        while flips < max_flips:
            unsat_clause_indices = [k for k in range(len(f.clauses)) if self.true_lit_count[k] == 0]
            sat = not unsat_clause_indices
            if sat:
                break
            unsat_clause = f.clauses[random.choice(unsat_clause_indices)]
            log_prob = None
            if random.random() < self.p:
                literal = random.choice(unsat_clause)
            else:
                if walksat:
                    literal = self.walksat_step(f, unsat_clause)   
                else:
                    literal, log_prob = self.reinforce_step(f, unsat_clause)
                v = abs(literal)
                if v not in flipped:
                    flipped.add(v)
                else:
                    backflipped += 1
                self.last_10.insert(0, v)
                self.last_10 = self.last_10[:10]
                self.age[v] = flips
            self.do_flip(literal, f.occur_list)
            flips += 1
            self.age[0] = flips
            log_probs.append(log_prob)
        return sat, flips, backflipped, log_probs

    def reinforce_loss(self, log_probs_list):
        T = len(log_probs_list)
        log_probs_filtered = []
        mask = np.zeros(T, dtype=bool)
        for i, x in enumerate(log_probs_list):
            if x is not None:
                log_probs_filtered.append(x)
                mask[i] = 1

        log_probs = torch.stack(log_probs_filtered)
        p_rewards = self.discount ** torch.arange(T - 1, -1, -1, dtype=torch.float32, device=log_probs.device)
        return -torch.mean(p_rewards[torch.from_numpy(mask).to(log_probs.device)] * log_probs)

    def generate_episodes(self, f, walksat=False):
        flips_stats = []
        losses = []
        backflips = []
        num_sols = 0
        for i in range(self.max_tries):
            sat, flips, backflipped, log_probs = self.generate_episode_reinforce(f, walksat)
            flips_stats.append(flips)
            backflips.append(backflipped)
            if sat and flips > 0 and not all(map(lambda x: x is None, log_probs)):
                loss = self.reinforce_loss(log_probs)
                losses.append(loss)
            num_sols += sat    
        if losses:
            losses = torch.stack(losses).sum()  
        return np.mean(flips), np.mean(backflips), losses, num_sols/self.max_tries
    
    
    def evaluate(self, data, walksat=False):
        mean_flips = []
        mean_backflips = []
        mean_losses = []
        accuracy = []
        self.policy.eval()
        for f in data:
            flips, backflips, losses, acc = self.generate_episodes(f, walksat)
            mean_flips.append(flips)
            mean_backflips.append(backflips)
            if losses:
                mean_losses.append(losses.item())
            accuracy.append(acc)
            mean_loss = None
            if mean_losses:
                mean_loss = np.mean(mean_losses)
        print(np.mean(mean_flips), np.mean(mean_backflips), mean_loss, np.mean(accuracy))
        
    def train_epoch(self, optimizer, data):
        losses = []
        for f in data:
            self.policy.train()
            flips, backflips, loss, acc = self.generate_episodes(f)
            if acc > 0:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
        print(np.mean(losses))

In [9]:
def change_lr(optimizer, lr):
    for g in optimizer.param_groups:
        g['lr'] = lr

In [10]:
train_ds = load_dir("../data/rand3sat/10-43")
#val_ds = load_dir("../data/rand3sat/25-106")

In [11]:
train_ds = load_dir("../data/rand3sat/25-106")
#val_ds = load_dir("../data/rand3sat/25-106")
val_ds = train_ds[1900:]
train_ds = train_ds[:1900]
len(train_ds), len(val_ds)

(1900, 100)

In [12]:
policy = Net()
optimizer = optim.RMSprop(policy.parameters(), lr=0.01, weight_decay=1e-5)
max_tries = 10
max_flips = 10000

In [13]:
init_net(policy)

In [14]:
policy.lin2.weight

Parameter containing:
tensor([[-1.0000, -0.3289, -0.2722,  0.3029, -0.3633]], requires_grad=True)

In [15]:
train_ds = load_dir("../data/kcolor/4-15-0.5/")
train_ds = load_dir("../data/kcolor/3-10-0.5/")
val_ds = train_ds[1900:]
train_ds = train_ds[:1900]
len(train_ds), len(val_ds)

(1900, 100)

In [80]:
policy = Net2()
optimizer = optim.RMSprop(policy.parameters(), lr=0.001, weight_decay=0)
max_tries = 10
max_flips = 10000

In [81]:
init_net2(policy)

In [82]:
#init_net(policy)
[p for p in policy.parameters()]

[Parameter containing:
 tensor([[-1.,  0.,  0.,  0.]], requires_grad=True),
 Parameter containing:
 tensor([0.], requires_grad=True)]

In [83]:
max_tries = 10
max_flips = 10000

In [84]:
ls = WalkSATLN(policy, max_tries, max_flips, p=0.5)

In [85]:
for i in range(3):
    ls.evaluate(val_ds, walksat=True)

46.58 11.414000000000001 None 1.0
45.52 11.923000000000002 None 1.0
47.01 11.539000000000001 None 1.0


In [86]:
for i in range(3):
    ls.evaluate(val_ds)

89.64 29.127000000000002 0.2603824938088655 1.0
98.23 30.545 0.2574538692086935 1.0
94.22 29.067000000000007 0.25322111681103704 1.0


In [87]:
ls.train_epoch(optimizer, train_ds)

0.1812230912164638


In [88]:
for i in range(3):
    ls.evaluate(val_ds)

58.21 12.152999999999999 0.14510542072355748 1.0
58.22 12.187000000000003 0.14102069918066262 1.0
55.31 12.955 0.1360379394888878 1.0


In [89]:
for j in range(10):
    ls.train_epoch(optimizer, train_ds)
    ls.evaluate(val_ds)

0.12453894591056987
50.01 11.043000000000005 0.11110588377341628 1.0
0.10370882746616476
48.02 10.163999999999998 0.10437269033864141 1.0
0.09388547586620247
57.51 9.96 0.09492984515614808 1.0
0.08792231511483949
46.17 9.558 0.08627009075134992 1.0
0.08383730543581279
48.29 9.048 0.08164749088697136 1.0
0.0819346508753829
46.43 9.223 0.07573967995122075 1.0
0.07892688288230841
44.93 10.091999999999999 0.082134427735582 1.0
0.07633932366538303
48.02 9.06 0.07406406317371875 1.0
0.0770908094403383
49.24 9.122 0.07616797674447298 1.0
0.07556656423490495
45.82 9.402000000000001 0.07587006560293957 1.0


In [90]:
for i in range(3):
    ls.evaluate(val_ds)

51.37 9.123 0.08137148874811828 1.0
44.49 9.124999999999998 0.07867962806951255 1.0
45.86 8.37 0.07515300078783184 1.0


In [91]:
for j in range(5):
    ls.train_epoch(optimizer, train_ds)

0.07313184804632328
0.07416222611313539
0.07207129431232859
0.07216185999225433
0.07024976932737781


In [92]:
for i in range(3):
    ls.evaluate(val_ds)

40.58 8.068000000000001 0.07303019106388092 1.0
49.0 8.913 0.06776832322590053 1.0
49.04 8.673 0.07228743204846978 1.0


In [93]:
[p for p in policy.parameters()]

[Parameter containing:
 tensor([[-16.5120,  -1.4553,  -4.0445,  -5.7023]], requires_grad=True),
 Parameter containing:
 tensor([-0.0211], requires_grad=True)]

In [57]:
#train_ds = load_dir("../data/rand3sat/25-106")
train_ds = load_dir("../data/kclique/3-5-0.2/")
val_ds = train_ds[1800:]
train_ds = train_ds[:1800]
len(train_ds), len(val_ds)

(1800, 160)

In [64]:
policy = Net()
optimizer = optim.RMSprop(policy.parameters(), lr=0.001, weight_decay=0)
#init_net2(policy)
max_tries = 10
max_flips = 10000

In [65]:
[p for p in policy.parameters()]

[Parameter containing:
 tensor([[-0.4709, -0.4295, -0.1632,  0.1809],
         [ 0.2581,  0.4259, -0.3950, -0.3905],
         [-0.3567, -0.0589, -0.4169, -0.3181],
         [-0.0271, -0.4185,  0.4220,  0.3886],
         [ 0.4473, -0.2143, -0.3733, -0.0181]], requires_grad=True),
 Parameter containing:
 tensor([ 0.1345, -0.3361, -0.2100,  0.0532,  0.2223], requires_grad=True),
 Parameter containing:
 tensor([[-0.0039,  0.4246, -0.4139, -0.4040,  0.2209]], requires_grad=True),
 Parameter containing:
 tensor([-0.3221], requires_grad=True)]

In [66]:
ls = WalkSATLN(policy, max_tries, max_flips, p=0.5)

In [61]:
for i in range(3):
    ls.evaluate(val_ds, walksat=True)

35.46875 10.08375 None 1.0
34.3 10.196250000000001 None 1.0
36.10625 10.094999999999999 None 1.0


In [67]:
for i in range(3):
    ls.evaluate(val_ds)

191.0625 78.3 0.4523022493813187 1.0
199.725 80.59625 0.4411004173569381 1.0
182.94375 78.62937500000001 0.4927808415144682 1.0


In [68]:
for j in range(15):
    ls.train_epoch(optimizer, train_ds)
    ls.evaluate(val_ds)

0.4302416797582474
45.4375 12.6175 0.34005261920392515 1.0
0.24499379313292188
40.825 11.126249999999999 0.29827029216103257 1.0
0.2097663758788258
35.71875 10.28875 0.29054862991906705 1.0
0.21226228304828207
38.05625 9.68 0.26661623902618886 1.0
0.20121002536474003
34.975 9.391875 0.24267770475707948 1.0
0.19861296176392998
36.31875 9.13875 0.23618312077596784 1.0
0.1950332091645234
36.51875 8.960625 0.2220528768375516 1.0
0.19326598833522035
33.96875 8.698125000000001 0.2086977020604536 1.0
0.19179700914227094
34.28125 8.068125 0.19989559159148484 1.0
0.1934044715630201
36.7375 8.42125 0.1954457562416792 1.0
0.19226405694615095
37.25 8.408750000000001 0.18346511863637716 1.0
0.18701020144055494
32.975 8.34375 0.18828705131309106 1.0
0.18385154952450344
36.275 8.055625000000001 0.1708696506684646 1.0
0.1842433035528908
34.3375 8.20625 0.15983921515289695 1.0
0.18066280145430938
33.84375 7.78125 0.14616115840617566 1.0


In [69]:
for j in range(10):
    ls.train_epoch(optimizer, train_ds)
    ls.evaluate(val_ds)

0.1900133056889495
28.26875 7.88125 0.1598966016783379 1.0
0.18571612565622975
35.28125 7.835625 0.1493968229740858 1.0
0.18074108219783133
32.84375 7.983125000000001 0.13599841308314353 1.0
0.1765854917262267
30.14375 7.8050000000000015 0.13796455648262054 1.0
0.1810314372052542
32.24375 7.433750000000001 0.1414280461729504 1.0
0.1801449193000897
34.96875 8.31625 0.13618679160717875 1.0
0.18026766701497965
32.49375 8.1125 0.13329387855483218 1.0
0.16985321175384646
34.0 7.7425000000000015 0.14587803810136393 1.0
0.17907653575944196
33.9125 7.766875000000001 0.11901606121100486 1.0
0.17503390257945284
33.125 7.799375 0.12751176658202895 1.0


In [70]:
for i in range(3):
    ls.evaluate(val_ds)

36.04375 7.583125000000001 0.12679847616236656 1.0
31.39375 7.8125 0.12810946155805142 1.0
35.2125 8.2875 0.1146646652603522 1.0


In [71]:
[p for p in policy.parameters()]

[Parameter containing:
 tensor([[-2.4546, -0.5375, -0.4934, -0.9594],
         [ 0.0429,  0.2279, -0.4182, -0.5133],
         [-0.3567, -0.0589, -0.4169, -0.3181],
         [ 2.0970,  0.3866,  0.7357,  0.1994],
         [ 2.2783,  0.5436,  0.7418,  0.2649]], requires_grad=True),
 Parameter containing:
 tensor([ 2.4530, -0.5514, -0.2100, -1.3267, -1.5038], requires_grad=True),
 Parameter containing:
 tensor([[ 6.5825,  0.2712, -0.4139, -3.4072, -2.8941]], requires_grad=True),
 Parameter containing:
 tensor([0.0931], requires_grad=True)]

In [72]:
change_lr(optimizer, lr=0.0001)

In [73]:
for j in range(5):
    ls.train_epoch(optimizer, train_ds)
    ls.evaluate(val_ds)

0.17591446167090907
32.45 7.956874999999999 0.1163583088782616 1.0
0.17709127300496524
34.25625 7.95875 0.14412777799880133 1.0
0.17739224164876052
32.79375 8.01625 0.13437118483707308 1.0
0.17430420260048574
32.46875 7.830625 0.12468807835830376 1.0
0.17605081266775313
35.6875 8.038749999999999 0.13829632542910986 1.0


In [74]:
for i in range(3):
    ls.evaluate(val_ds)

36.06875 7.825625 0.120837081305217 1.0
30.125 7.878750000000001 0.133009365783073 1.0
39.49375 8.409375 0.1410188393376302 1.0
