In [1]:
import os
from cnf import CNF
import numpy as np

In [2]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [3]:
class Net(nn.Module):
    def __init__(self, h=5):
        super(Net, self).__init__()
        self.lin = nn.Linear(3, h)
        self.dropout = nn.Dropout(0.5)
        self.lin2 = nn.Linear(h, 1)
    def forward(self, x):
        x = self.lin(x)
        x = F.relu(self.dropout(x))
        x = self.lin2(x)
        return x

In [4]:
def init_net(model):
    with torch.no_grad():
        model.lin.weight[0, 0] = 10
        model.lin2.weight[0, 0] = -1

In [5]:
class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.lin = nn.Linear(3, 1)
    def forward(self, x):
        x = self.lin(x)
        return x

In [6]:
def init_net2(policy):
    with torch.no_grad():
        policy.lin.weight[0, 0] = -1
        policy.lin.weight[0, 1] = 0
        policy.lin.weight[0, 2] = 0
        policy.lin.bias[0] = 0

In [7]:
def load_dir(path):
    data = []
    for filename in os.listdir(path):
        name, ext = os.path.splitext(filename)
        if ext != '.cnf':
            continue
        f = CNF.from_file(os.path.join(path, filename))
        data.append(f)
    return data

In [32]:
class WalkSATLN:
    def __init__(self, policy, max_tries, max_flips, p=0.5, discount=0.5):
        self.policy = policy
        self.max_tries = max_tries
        self.max_flips = max_flips
        self.p = p
        self.discount = discount
        self.flips_to_solution = []
        self.backflips = []
        self.unsat_clauses = []
        self.age = []
        self.last_10 = []
        self.sol = []
        
    def compute_true_lit_count(self, clauses):
        n_clauses = len(clauses)
        true_lit_count = [0] * n_clauses
        for index in range(n_clauses):
            for literal in clauses[index]:
                if self.sol[abs(literal)] == literal:
                    true_lit_count[index] += 1
        return true_lit_count
    
    def select_variable_reinforce(self, x):
        logit = self.policy(x)
        prob = F.softmax(logit, dim=0)
        dist = Categorical(prob.view(-1))
        v = dist.sample()
        return v, dist.log_prob(v)
    
    def do_flip(self, literal, occur_list):
        for i in occur_list[literal]:
            self.true_lit_count[i] += 1
        for i in occur_list[-literal]:
            self.true_lit_count[i] -= 1
        self.sol[abs(literal)] *= -1
        
    def stats_per_clause(self, f, unsat_clause):
        """ computes the featutes needed for the model
        """ 
        r = f.n_variables/ len(f.clauses)
        variables = [abs(v) for v in unsat_clause]
        breaks = np.zeros(len(variables))
        last_5 = self.last_10[:5]
        for i, literal in enumerate(unsat_clause):
            broken_count = 0
            for index in f.occur_list[-literal]:
                if self.true_lit_count[index] == 1:
                    broken_count += 1
            breaks[i] = broken_count
        in_last_10 = np.array([int(i + 1 in self.last_10) for i in variables]) 
        in_last_5 = np.array([int(i + 1 in last_5) for i in variables]) 
        return np.stack([breaks, in_last_10, in_last_5], axis=1)
    
    def walksat_step(self, f, unsat_clause):
        """Returns chosen literal"""
        broken_min = float('inf')
        min_breaking_lits = []
        for literal in unsat_clause:
            broken_count = 0
            for index in f.occur_list[-literal]:
                if self.true_lit_count[index] == 1:
                    broken_count += 1
                if broken_count > broken_min:
                    break
            if broken_count < broken_min:
                broken_min = broken_count
                min_breaking_lits = [literal]
            elif broken_count == broken_min:
                min_breaking_lits.append(literal)
        return abs(random.choice(min_breaking_lits))
    
    def reinforce_step(self, f, unsat_clause):
        x = self.stats_per_clause(f, unsat_clause)
        x = torch.from_numpy(x).float()
        index, log_prob = self.select_variable_reinforce(x)
        literal = unsat_clause[index]
        return literal, log_prob
    
    def generate_episode_reinforce(self, f, walksat):
        self.sol = [x if random.random() < 0.5 else -x for x in range(f.n_variables + 1)]
        self.true_lit_count = self.compute_true_lit_count(f.clauses)
        log_probs = []
        flips = 0
        flipped = set()
        backflipped = 0
        while flips < max_flips:
            unsat_clause_indices = [k for k in range(len(f.clauses)) if self.true_lit_count[k] == 0]
            sat = not unsat_clause_indices
            if sat:
                break
            unsat_clause = f.clauses[random.choice(unsat_clause_indices)]
            log_prob = None
            if random.random() < self.p:
                literal = random.choice(unsat_clause)
            else:
                if walksat:
                    literal = self.walksat_step(f, unsat_clause)   
                else:
                    literal, log_prob = self.reinforce_step(f, unsat_clause)
                v = abs(literal)
                if v not in flipped:
                    flipped.add(v)
                else:
                    backflipped += 1
                self.last_10.insert(0, v)
                self.last_10 = self.last_10[:10]
            self.do_flip(literal, f.occur_list)
            flips += 1
            log_probs.append(log_prob)
        return sat, flips, backflipped, log_probs

    def reinforce_loss(self, log_probs_list):
        T = len(log_probs_list)
        log_probs_filtered = []
        mask = np.zeros(T, dtype=bool)
        for i, x in enumerate(log_probs_list):
            if x is not None:
                log_probs_filtered.append(x)
                mask[i] = 1

        log_probs = torch.stack(log_probs_filtered)
        p_rewards = self.discount ** torch.arange(T - 1, -1, -1, dtype=torch.float32, device=log_probs.device)
        return -torch.mean(p_rewards[torch.from_numpy(mask).to(log_probs.device)] * log_probs)

    def generate_episodes(self, f, walksat=False):
        flips_stats = []
        losses = []
        backflips = []
        num_sols = 0
        for i in range(self.max_tries):
            sat, flips, backflipped, log_probs = self.generate_episode_reinforce(f, walksat)
            flips_stats.append(flips)
            backflips.append(backflipped)
            if sat and flips > 0 and not all(map(lambda x: x is None, log_probs)):
                loss = self.reinforce_loss(log_probs)
                losses.append(loss)
            num_sols += sat    
        if losses:
            losses = torch.stack(losses).sum()  
        return np.mean(flips), np.mean(backflips), losses, num_sols/self.max_tries
    
    
    def evaluate(self, data, walksat=False):
        mean_flips = []
        mean_backflips = []
        mean_losses = []
        accuracy = []
        self.policy.eval()
        for f in data:
            flips, backflips, losses, acc = self.generate_episodes(f, walksat)
            mean_flips.append(flips)
            mean_backflips.append(backflips)
            if losses:
                mean_losses.append(losses.item())
            accuracy.append(acc)
            mean_loss = None
            if mean_losses:
                mean_loss = np.mean(mean_losses)
        print(np.mean(mean_flips), np.mean(mean_backflips), mean_loss, np.mean(accuracy))
        
    def train_epoch(self, optimizer, data):
        losses = []
        for f in data:
            self.policy.train()
            flips, backflips, loss, acc = self.generate_episodes(f)
            if acc > 0:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
        print(np.mean(losses))

In [9]:
def change_lr(optimizer, lr):
    for g in optimizer.param_groups:
        g['lr'] = lr

In [10]:
train_ds = load_dir("../data/rand3sat/10-43")
#val_ds = load_dir("../data/rand3sat/25-106")

In [11]:
train_ds = load_dir("../data/rand3sat/25-106")
#val_ds = load_dir("../data/rand3sat/25-106")
val_ds = train_ds[1900:]
train_ds = train_ds[:1900]
len(train_ds), len(val_ds)

(1900, 100)

In [12]:
policy = Net()
optimizer = optim.RMSprop(policy.parameters(), lr=0.01, weight_decay=1e-5)
max_tries = 10
max_flips = 10000

In [13]:
init_net(policy)

In [14]:
policy.lin2.weight

Parameter containing:
tensor([[-1.0000, -0.3289, -0.2722,  0.3029, -0.3633]], requires_grad=True)

In [15]:
train_ds = load_dir("../data/kcolor/4-15-0.5/")
train_ds = load_dir("../data/kcolor/3-10-0.5/")
val_ds = train_ds[1900:]
train_ds = train_ds[:1900]
len(train_ds), len(val_ds)

(1900, 100)

In [16]:
policy = Net2()
optimizer = optim.RMSprop(policy.parameters(), lr=0.001, weight_decay=0)
max_tries = 10
max_flips = 10000

In [17]:
init_net2(policy)

In [18]:
#init_net(policy)
[p for p in policy.parameters()]

[Parameter containing:
 tensor([[-1.,  0.,  0.]], requires_grad=True),
 Parameter containing:
 tensor([0.], requires_grad=True)]

In [19]:
max_tries = 10
max_flips = 10000

In [33]:
ls = WalkSATLN(policy, max_tries, max_flips, p=0.5)

In [26]:
for i in range(3):
    ls.evaluate(val_ds, walksat=True)

50.36 11.426999999999998 None 1.0
44.94 11.057999999999998 None 1.0
46.86 11.525 None 1.0


In [27]:
for i in range(3):
    ls.evaluate(val_ds)

96.13 32.422999999999995 0.2566214936226606 1.0
97.68 31.145 0.25815081529319284 1.0
90.28 31.678999999999995 0.24032414235174657 1.0


In [None]:
ls.train_epoch(optimizer, train_ds[:1000])

In [None]:
for i in range(3):
    ls.evaluate(val_ds)

In [None]:
for j in range(10):
    ls.train_epoch(optimizer, train_ds)
    ls.evaluate(val_ds)

In [94]:
for i in range(3):
    ls.evaluate(val_ds)
    evaluate(policy, val_ds, max_tries, max_flips, discount=0.5, walk_prob=.5)

74.63 17.503999999999998 0.12857144352048636 1.0
66.37 17.093 0.1335478150472045 1.0
62.71 17.196 0.13198083877563477 1.0


In [96]:
[p for p in policy.parameters()]

[Parameter containing:
 tensor([[-18.3708,  -0.6255,  -0.8659]], requires_grad=True),
 Parameter containing:
 tensor([0.0004], requires_grad=True)]

In [75]:
# old values
for i in range(3):
    evaluate(policy, val_ds, max_tries, max_flips, discount=0.5, walk_prob=.5)

54.99 16.944999999999997 0.14899197135120631 1.0
59.57 16.757 0.1426751885190606 1.0
65.47 17.554000000000002 0.14114499025046826 1.0
