In [1]:
import os
from cnf import CNF
import numpy as np

In [2]:
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [33]:
from local_search import WalkSATLN
from warm_up import WarmUP

In [34]:
class Net(nn.Module):
    def __init__(self, h=5):
        super(Net, self).__init__()
        self.lin = nn.Linear(4, h)
        self.dropout = nn.Dropout(0.5)
        self.lin2 = nn.Linear(h, 1)
    def forward(self, x):
        x = self.lin(x)
        x = F.relu(self.dropout(x))
        x = self.lin2(x)
        return x

In [35]:
class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.lin = nn.Linear(4, 1)
    def forward(self, x):
        x = self.lin(x)
        return x

In [36]:
def load_dir(path):
    data = []
    for filename in os.listdir(path):
        name, ext = os.path.splitext(filename)
        if ext != '.cnf':
            continue
        f = CNF.from_file(os.path.join(path, filename))
        data.append(f)
    return data

In [37]:
def split_data(data):
    train_ds = data[:1500]
    val_ds = data[1500:1700]
    test_ds = data[1700:]
    return train_ds, val_ds, test_ds

In [38]:
data = load_dir("../data/rand3sat/25-106")
train_ds, val_ds, test_ds = split_data(data)

In [39]:
policy = Net(10)
optimizer = optim.RMSprop(policy.parameters(), lr=0.001, weight_decay=1e-5)
max_flips = 10000

In [19]:
ls = WalkSATLN(policy, 10, 10000)
flips, backflips,  loss, accuracy = ls.evaluate(val_ds, walksat=True)

In [20]:
np.median(flips), np.mean(flips)

(66.0, 148.3915)

In [21]:
ls = WalkSATLN(policy, 10, 10000)
flips, backflips,  loss, accuracy = ls.evaluate(val_ds)
np.median(flips), np.mean(flips)

(302.0, 798.296)

In [None]:
data = load_dir("../data/rand3sat/25-106")
train_ds, val_ds, test_ds = split_data(data)

In [31]:
wup = WarmUP(policy, max_flips=1000)

In [23]:
for i in range(5):
    wup.train_epoch(optimizer, train_ds)

0.6879837957272927
0.5716830270042023
0.5282477676523849
0.5202376209789267
0.5107719466562073


In [24]:
torch.save(policy.state_dict(), "walksat_Net.pt")

In [25]:
ls = WalkSATLN(policy, 10, 10000)
flips, backflips,  loss, accuracy = ls.evaluate(val_ds)
np.median(flips), np.mean(flips)

(71.0, 143.944)

In [29]:
wup.break_histo[:15]/1e+5

array([1.28738e+00, 7.37441e+00, 5.11266e+00, 2.54323e+00, 9.75490e-01,
       2.90400e-01, 7.04200e-02, 1.41300e-02, 2.52000e-03, 3.70000e-04,
       4.00000e-05, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00])