In [1]:
import os
from cnf import CNF
import numpy as np

In [2]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [3]:
class Net(nn.Module):
    def __init__(self, h=5):
        super(Net, self).__init__()
        self.lin = nn.Linear(3, h)
        self.dropout = nn.Dropout(0.5)
        self.lin2 = nn.Linear(h, 1)
    def forward(self, x):
        x = self.lin(x)
        x = F.relu(self.dropout(x))
        x = self.lin2(x)
        return x

In [4]:
def init_net(model):
    with torch.no_grad():
        model.lin.weight[0, 0] = 10
        model.lin2.weight[0, 0] = -1

In [76]:
class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.lin = nn.Linear(4, 1)
    def forward(self, x):
        x = self.lin(x)
        return x

In [77]:
def init_net2(policy):
    with torch.no_grad():
        policy.lin.weight[0, 0] = -1
        policy.lin.weight[0, 1] = 0
        policy.lin.weight[0, 2] = 0
        policy.lin.weight[0, 3] = 0
        policy.lin.bias[0] = 0

In [7]:
def load_dir(path):
    data = []
    for filename in os.listdir(path):
        name, ext = os.path.splitext(filename)
        if ext != '.cnf':
            continue
        f = CNF.from_file(os.path.join(path, filename))
        data.append(f)
    return data

In [79]:
class WalkSATLN:
    def __init__(self, policy, max_tries, max_flips, p=0.5, discount=0.5):
        self.policy = policy
        self.max_tries = max_tries
        self.max_flips = max_flips
        self.p = p
        self.discount = discount
        self.flips_to_solution = []
        self.backflips = []
        self.unsat_clauses = []
        self.age = []
        self.last_10 = []
        self.sol = []
        
    def compute_true_lit_count(self, clauses):
        n_clauses = len(clauses)
        true_lit_count = [0] * n_clauses
        for index in range(n_clauses):
            for literal in clauses[index]:
                if self.sol[abs(literal)] == literal:
                    true_lit_count[index] += 1
        return true_lit_count
    
    def select_variable_reinforce(self, x):
        logit = self.policy(x)
        prob = F.softmax(logit, dim=0)
        dist = Categorical(prob.view(-1))
        v = dist.sample()
        return v, dist.log_prob(v)
    
    def do_flip(self, literal, occur_list):
        for i in occur_list[literal]:
            self.true_lit_count[i] += 1
        for i in occur_list[-literal]:
            self.true_lit_count[i] -= 1
        self.sol[abs(literal)] *= -1
        
    def stats_per_clause(self, f, unsat_clause):
        """ computes the featutes needed for the model
        """ 
        r = f.n_variables/ len(f.clauses)
        variables = [abs(v) for v in unsat_clause]
        breaks = np.zeros(len(variables))
        last_5 = self.last_10[:5]
        for i, literal in enumerate(unsat_clause):
            broken_count = 0
            for index in f.occur_list[-literal]:
                if self.true_lit_count[index] == 1:
                    broken_count += 1
            breaks[i] = broken_count
        in_last_10 = np.array([int(v in self.last_10) for v in variables]) 
        age = np.array([self.age[v] for v in variables])/(self.age[0] + 1)
        in_last_5 = np.array([int(v in last_5) for v in variables]) 
        return np.stack([breaks, in_last_10, in_last_5, age], axis=1)
    
    def walksat_step(self, f, unsat_clause):
        """Returns chosen literal"""
        broken_min = float('inf')
        min_breaking_lits = []
        for literal in unsat_clause:
            broken_count = 0
            for index in f.occur_list[-literal]:
                if self.true_lit_count[index] == 1:
                    broken_count += 1
                if broken_count > broken_min:
                    break
            if broken_count < broken_min:
                broken_min = broken_count
                min_breaking_lits = [literal]
            elif broken_count == broken_min:
                min_breaking_lits.append(literal)
        return abs(random.choice(min_breaking_lits))
    
    def reinforce_step(self, f, unsat_clause):
        x = self.stats_per_clause(f, unsat_clause)
        x = torch.from_numpy(x).float()
        index, log_prob = self.select_variable_reinforce(x)
        literal = unsat_clause[index]
        return literal, log_prob
    
    def generate_episode_reinforce(self, f, walksat):
        self.sol = [x if random.random() < 0.5 else -x for x in range(f.n_variables + 1)]
        self.true_lit_count = self.compute_true_lit_count(f.clauses)
        self.age = np.zeros(f.n_variables + 1)
        log_probs = []
        flips = 0
        flipped = set()
        backflipped = 0
        while flips < max_flips:
            unsat_clause_indices = [k for k in range(len(f.clauses)) if self.true_lit_count[k] == 0]
            sat = not unsat_clause_indices
            if sat:
                break
            unsat_clause = f.clauses[random.choice(unsat_clause_indices)]
            log_prob = None
            if random.random() < self.p:
                literal = random.choice(unsat_clause)
            else:
                if walksat:
                    literal = self.walksat_step(f, unsat_clause)   
                else:
                    literal, log_prob = self.reinforce_step(f, unsat_clause)
                v = abs(literal)
                if v not in flipped:
                    flipped.add(v)
                else:
                    backflipped += 1
                self.last_10.insert(0, v)
                self.last_10 = self.last_10[:10]
                self.age[v] = flips
            self.do_flip(literal, f.occur_list)
            flips += 1
            self.age[0] = flips
            log_probs.append(log_prob)
        return sat, flips, backflipped, log_probs

    def reinforce_loss(self, log_probs_list):
        T = len(log_probs_list)
        log_probs_filtered = []
        mask = np.zeros(T, dtype=bool)
        for i, x in enumerate(log_probs_list):
            if x is not None:
                log_probs_filtered.append(x)
                mask[i] = 1

        log_probs = torch.stack(log_probs_filtered)
        p_rewards = self.discount ** torch.arange(T - 1, -1, -1, dtype=torch.float32, device=log_probs.device)
        return -torch.mean(p_rewards[torch.from_numpy(mask).to(log_probs.device)] * log_probs)

    def generate_episodes(self, f, walksat=False):
        flips_stats = []
        losses = []
        backflips = []
        num_sols = 0
        for i in range(self.max_tries):
            sat, flips, backflipped, log_probs = self.generate_episode_reinforce(f, walksat)
            flips_stats.append(flips)
            backflips.append(backflipped)
            if sat and flips > 0 and not all(map(lambda x: x is None, log_probs)):
                loss = self.reinforce_loss(log_probs)
                losses.append(loss)
            num_sols += sat    
        if losses:
            losses = torch.stack(losses).sum()  
        return np.mean(flips), np.mean(backflips), losses, num_sols/self.max_tries
    
    
    def evaluate(self, data, walksat=False):
        mean_flips = []
        mean_backflips = []
        mean_losses = []
        accuracy = []
        self.policy.eval()
        for f in data:
            flips, backflips, losses, acc = self.generate_episodes(f, walksat)
            mean_flips.append(flips)
            mean_backflips.append(backflips)
            if losses:
                mean_losses.append(losses.item())
            accuracy.append(acc)
            mean_loss = None
            if mean_losses:
                mean_loss = np.mean(mean_losses)
        print(np.mean(mean_flips), np.mean(mean_backflips), mean_loss, np.mean(accuracy))
        
    def train_epoch(self, optimizer, data):
        losses = []
        for f in data:
            self.policy.train()
            flips, backflips, loss, acc = self.generate_episodes(f)
            if acc > 0:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
        print(np.mean(losses))

In [56]:
def change_lr(optimizer, lr):
    for g in optimizer.param_groups:
        g['lr'] = lr

In [10]:
train_ds = load_dir("../data/rand3sat/10-43")
#val_ds = load_dir("../data/rand3sat/25-106")

In [11]:
train_ds = load_dir("../data/rand3sat/25-106")
#val_ds = load_dir("../data/rand3sat/25-106")
val_ds = train_ds[1900:]
train_ds = train_ds[:1900]
len(train_ds), len(val_ds)

(1900, 100)

In [12]:
policy = Net()
optimizer = optim.RMSprop(policy.parameters(), lr=0.01, weight_decay=1e-5)
max_tries = 10
max_flips = 10000

In [13]:
init_net(policy)

In [14]:
policy.lin2.weight

Parameter containing:
tensor([[-1.0000, -0.3289, -0.2722,  0.3029, -0.3633]], requires_grad=True)

In [15]:
train_ds = load_dir("../data/kcolor/4-15-0.5/")
train_ds = load_dir("../data/kcolor/3-10-0.5/")
val_ds = train_ds[1900:]
train_ds = train_ds[:1900]
len(train_ds), len(val_ds)

(1900, 100)

In [80]:
policy = Net2()
optimizer = optim.RMSprop(policy.parameters(), lr=0.001, weight_decay=0)
max_tries = 10
max_flips = 10000

In [81]:
init_net2(policy)

In [82]:
#init_net(policy)
[p for p in policy.parameters()]

[Parameter containing:
 tensor([[-1.,  0.,  0.,  0.]], requires_grad=True),
 Parameter containing:
 tensor([0.], requires_grad=True)]

In [83]:
max_tries = 10
max_flips = 10000

In [84]:
ls = WalkSATLN(policy, max_tries, max_flips, p=0.5)

In [85]:
for i in range(3):
    ls.evaluate(val_ds, walksat=True)

46.58 11.414000000000001 None 1.0
45.52 11.923000000000002 None 1.0
47.01 11.539000000000001 None 1.0


In [86]:
for i in range(3):
    ls.evaluate(val_ds)

89.64 29.127000000000002 0.2603824938088655 1.0
98.23 30.545 0.2574538692086935 1.0
94.22 29.067000000000007 0.25322111681103704 1.0


In [87]:
ls.train_epoch(optimizer, train_ds)

0.1812230912164638


In [88]:
for i in range(3):
    ls.evaluate(val_ds)

58.21 12.152999999999999 0.14510542072355748 1.0
58.22 12.187000000000003 0.14102069918066262 1.0
55.31 12.955 0.1360379394888878 1.0


In [89]:
for j in range(10):
    ls.train_epoch(optimizer, train_ds)
    ls.evaluate(val_ds)

0.12453894591056987
50.01 11.043000000000005 0.11110588377341628 1.0
0.10370882746616476
48.02 10.163999999999998 0.10437269033864141 1.0
0.09388547586620247
57.51 9.96 0.09492984515614808 1.0
0.08792231511483949
46.17 9.558 0.08627009075134992 1.0
0.08383730543581279
48.29 9.048 0.08164749088697136 1.0
0.0819346508753829
46.43 9.223 0.07573967995122075 1.0
0.07892688288230841
44.93 10.091999999999999 0.082134427735582 1.0
0.07633932366538303
48.02 9.06 0.07406406317371875 1.0
0.0770908094403383
49.24 9.122 0.07616797674447298 1.0
0.07556656423490495
45.82 9.402000000000001 0.07587006560293957 1.0


In [90]:
for i in range(3):
    ls.evaluate(val_ds)

51.37 9.123 0.08137148874811828 1.0
44.49 9.124999999999998 0.07867962806951255 1.0
45.86 8.37 0.07515300078783184 1.0


In [91]:
for j in range(5):
    ls.train_epoch(optimizer, train_ds)

0.07313184804632328
0.07416222611313539
0.07207129431232859
0.07216185999225433
0.07024976932737781


In [92]:
for i in range(3):
    ls.evaluate(val_ds)

40.58 8.068000000000001 0.07303019106388092 1.0
49.0 8.913 0.06776832322590053 1.0
49.04 8.673 0.07228743204846978 1.0


In [93]:
[p for p in policy.parameters()]

[Parameter containing:
 tensor([[-16.5120,  -1.4553,  -4.0445,  -5.7023]], requires_grad=True),
 Parameter containing:
 tensor([-0.0211], requires_grad=True)]

In [102]:
train_ds = load_dir("../data/rand3sat/25-106")
val_ds = train_ds[1900:]
train_ds = train_ds[:1900]
len(train_ds), len(val_ds)

(1900, 100)

In [103]:
policy = Net2()
optimizer = optim.RMSprop(policy.parameters(), lr=0.001, weight_decay=0)
init_net2(policy)
max_tries = 10
max_flips = 10000

In [104]:
[p for p in policy.parameters()]

[Parameter containing:
 tensor([[-1.,  0.,  0.,  0.]], requires_grad=True),
 Parameter containing:
 tensor([0.], requires_grad=True)]

In [105]:
ls = WalkSATLN(policy, max_tries, max_flips, p=0.5)

In [106]:
for i in range(3):
    ls.evaluate(val_ds, walksat=True)

66.93 19.721 None 1.0
66.44 19.753 None 1.0
72.35 19.808 None 1.0


In [107]:
for i in range(3):
    ls.evaluate(val_ds)

204.92 88.72300000000001 0.19727257976308465 1.0
235.66 88.44699999999999 0.20111599389463664 1.0
194.67 82.65000000000002 0.2127423594892025 1.0


In [108]:
for j in range(15):
    ls.train_epoch(optimizer, train_ds)
    ls.evaluate(val_ds)

0.1544297411689829
139.45 53.26499999999999 0.11312383824493737 1.0
0.10077853788168317
112.72 60.613 0.09554155359510333 1.0
0.0812333848534495
131.45 54.929 0.06751992259174586 1.0
0.07252814758464841
127.51 50.636 0.0627974709845148 1.0
0.06574691613998239
162.69 53.94299999999999 0.07377243232214824 1.0
0.061656820494559055
134.56 52.05199999999999 0.06159733816049993 1.0
0.06249577908609737
138.74 55.047000000000004 0.05994226535549387 1.0
0.05791741791190457
173.59 59.261 0.05636318602017127 1.0
0.06004922917210742
121.91 52.61499999999999 0.05441363054793328 1.0
0.05631683989587289
123.59 50.077 0.05824942529201507 1.0
0.05818386168720041
134.55 55.455 0.05165729889238719 1.0
0.05329474430102365
130.35 51.323999999999984 0.057044581160880625 1.0
0.05352749843308151
138.93 50.863 0.05162087505217641 1.0
0.05422014358712324
149.14 58.915 0.055577607933664695 1.0
0.05325949137209376
165.63 54.643 0.051833666417514905 1.0


In [109]:
for i in range(3):
    ls.evaluate(val_ds)

126.44 52.32599999999999 0.05601068066898733 1.0
154.3 57.70099999999999 0.05302362880902365 1.0
144.15 57.972 0.04758059546351433 1.0


In [110]:
for j in range(15):
    ls.train_epoch(optimizer, train_ds)
    ls.evaluate(val_ds)

0.05075577362033073
141.7 50.56000000000001 0.056652487966930495 1.0
0.049056783326262746
139.63 53.21799999999999 0.04776167414034717 1.0
0.04929888884537688
144.26 46.379000000000005 0.04096380891278386 1.0
0.04917042895685882
134.84 52.16799999999999 0.04742288448498584 1.0
0.049287893496498564
205.53 53.99000000000001 0.05089841512381099 1.0
0.04563707548605874
146.66 52.4 0.0508113626530394 1.0
0.04765723780494514
155.52 56.452 0.042181635486194864 1.0
0.0478724420093376
123.42 53.698 0.051565392259508375 1.0
0.04602471422340482
90.14 48.987999999999985 0.03849637843144592 1.0
0.046688915330282486
113.35 55.09399999999999 0.03537640028167516 1.0
0.044437741633622574
123.2 53.698 0.04528322346624918 1.0
0.04674516648551952
98.49 53.479000000000006 0.053240925012505616 1.0
0.04583556882943652
111.43 59.07300000000001 0.04382167976349592 1.0
0.04389332783660441
157.14 62.68600000000001 0.04307505873905029 1.0
0.04477601390507373
120.96 52.782999999999994 0.04354304140782915 1.0


In [111]:
for i in range(3):
    ls.evaluate(val_ds)

116.25 52.54900000000001 0.047982482310617344 1.0
160.15 54.340999999999994 0.045732693669851866 1.0
154.72 52.681000000000004 0.036913557433872486 1.0


In [112]:
[p for p in policy.parameters()]

[Parameter containing:
 tensor([[-12.7003,  -1.8755,  -3.2198,  -4.1707]], requires_grad=True),
 Parameter containing:
 tensor([0.2007], requires_grad=True)]

In [113]:
change_lr(optimizer, lr=0.0001)

In [114]:
for j in range(5):
    ls.train_epoch(optimizer, train_ds)
    ls.evaluate(val_ds)

0.04604027766098673
132.83 50.22300000000001 0.046052083629765546 1.0
0.045425660096462754
112.48 57.06300000000001 0.04760960756408167 1.0
0.04646750557525114
119.12 44.55200000000001 0.04366783687932184 1.0
0.045103035410299384
107.21 47.16 0.041057201857911424 1.0
0.04418696907793147
147.7 66.445 0.040336825030390176 1.0


In [115]:
for j in range(5):
    ls.train_epoch(optimizer, train_ds)
    ls.evaluate(val_ds)

0.044634085781286366
143.83 55.553999999999995 0.04541555834002793 1.0
0.04396421247739789
117.44 49.674 0.04313059221196454 1.0
0.044676057761937144
145.69 58.878 0.04326639292150503 1.0
0.04380048438964877
138.63 53.80700000000001 0.053035292758140715 1.0
0.04267985111540551
147.07 55.093 0.05460201552486978 1.0


In [116]:
for i in range(3):
    ls.evaluate(val_ds)

145.44 54.019000000000005 0.04472931834054179 1.0
132.89 51.23799999999999 0.05314328968815971 1.0
134.78 51.717 0.051823291562031956 1.0
