# Implementation of deep learning based moser-type CSP solver

Here we basically combine the approach from the Deepmind paper "MIP solving using deep neural networks" with " a local lemma for focussed stochastic search" from Achlioptas.

## Defining Atomic CSPs and the random walker

We focus on binary problems first (basically these are SAT problems I think)

In [578]:
import numpy as np
from typing import NamedTuple
import numpy.typing as npt
import torch
from torch_geometric.data import Data
import torch_geometric.transforms as T
from pysat.formula import CNF

class AtomicConstraint:
    
    def __init__(self,support, vals):
        self.support: npt.NDArray[np.int_] = support
        self.vals: npt.NDArray[np.int_] = vals

    def is_violated_by(self, assignment: npt.NDArray[np.int_]):
        return np.array_equiv(self.vals,assignment[self.support])
      
    
class BinAtCSP:
    variables: npt.NDArray[np.int_]
    constraints: list[AtomicConstraint]

    def num_vars(self):
        self = len(self.variables)
    
    def from_cnf(self, cnf: CNF):
        self.variables = np.arange(cnf.nv)
        self.constraints = [
            AtomicConstraint(
                support=np.array([(abs(l) - 1) for l in c]), 
                vals=np.array(((np.sign(c) + 1) // 2))
            ) for c in cnf.clauses if len(c) > 0
        ]
        assert all(len(c.support) == len(c.vals) for c in self.constraints)
        return self


    def to_pytorch(self):
        edges = []
        edge_features = []
        n = len(self.variables)
        m = len(self.constraints)
        
        for j,c in enumerate(self.constraints, start=n):
            edge_features.extend(c.vals)
            edges.extend([v, j] for v in c.support)   

        # index connecting variable nodes and constraint nodes
        edge_index = torch.tensor(edges)

        # edge attributes: A matrix of shape [num_edges, 1] that contains for each edge the element of the constraint
        edge_attr = torch.tensor(edge_features, dtype=torch.float)

        # node attributes will be initialised to 1.
        x = torch.ones(n + m,1)

        # define the graph
        data = Data(x=x, edge_index=edge_index.t().contiguous(), edge_attr=edge_attr)
        data = T.ToUndirected()(data)
        return data
            
class Oracle:
    
    def __init__(self, dist, params):
        self.dist = dist
        self.params = params

    def sample(self, support: npt.NDArray[np.int_] ):
        conditional_dist = self.dist(logits=self.params[support])
        return conditional_dist.sample().flatten()
    
    
class OracleRandomWalker:
    
    def __init__(self, oracle: Oracle, instance: BinAtCSP):
        self.oracle = oracle
        self.instance = instance
        
    def find_violated_constraint(self, assignment: npt.NDArray[np.int_]):
        # to accelerate this: 
        
        for i, constraint in enumerate(self.instance.constraints):
            if constraint.is_violated_by(assignment):
                return constraint, True
        return None, False
    
    def run(self,step_limit=1000, return_trajectory=False):
        assignment = self.oracle.sample(self.instance.variables)
        trajectory = assignment if return_trajectory else None
        counter = 1
        while counter < step_limit:
            flaw, flaw_exists = self.find_violated_constraint(assignment.numpy())
            if not flaw_exists:
                return assignment, "satisfied", counter, trajectory
                
            # update assignment
            assignment[flaw.support] = self.oracle.sample(support = flaw.support)
            if return_trajectory:
                trajectory = torch.vstack([trajectory, assignment])
            counter += 1
            
        return assignment, "unsatisfied", counter, trajectory
        

## Define the neural network

Following the GCN + MLP generative apprpoach from the Deepmind paper here. Since we deal with CSPs instead of MIPs, the encoding is slightly different: I assign the violated bitstrings to edges and then pool them in a first step to generate node embeddings.

In [571]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch.nn as nn

class GCN(nn.Module):
    def __init__(self, out_dim = 20):
        super().__init__()
        self.conv1 = GCNConv(1, 1, normalize=True, add_self_loops=False)
        self.conv2 = GCNConv(1, 16)
        self.conv3 = GCNConv(16, out_dim)

    def forward(self, data):
        x = self.conv1(data.x, data.edge_index, edge_weight=data.edge_attr)
        x = self.conv2(x, data.edge_index)
        x = F.relu(x)
        x = self.conv3(x, data.edge_index)
        return F.log_softmax(x, dim=1)

In [572]:
# defining the MLP 
class MLP(nn.Module):
    # the input here should be the final node embedding dimension
    def __init__(self, in_dim=20, hidden_n_2=400):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(in_dim, hidden_n_2)
        self.fc2 = nn.Linear(hidden_n_2, 1)

        self.reLU = nn.ReLU() 
        # self.sigmoid = nn.Sigmoid()

    def forward(self, embedded):
        h1 = self.reLU(self.fc1(embedded))
        return self.fc2(h1) # self.sigmoid(self.fc2(h1))


class Model(nn.Module):
    def __init__(self, latent_dim=20):
        super(Model, self).__init__()
        self.decoder = MLP(in_dim=latent_dim)
        self.gcn = GCN(out_dim=latent_dim)

    def forward(self, x):
        embedding = self.gcn(x)
        return self.decoder(embedding)

## Training

How do we traing this model? In other words, what is the loss as a function of the current model? WE follow two approaches, one unsupervised, the other supervised. 

In both cases, we follow Deepmind in trying to minimize the cross entropy between the generative distribution and a Gibbs distribution that favors low energy state. However, our energy function is different from their linear function. 

For the supervised, we do exactly what Deepmind did and use an off-the-shelf solver to find solutions and then we estimate their loss function using the trajectory of the solver (so including sub-optimal solutions). We use these points estimate the energy function and then get a differentiable estimate of the cross entropy.

Unsupervised: For given parameters, we run the random walker for $N$ steps and then use this to estimate the cross entropy. This also gives us a differentiable estimate of the loss function (with the subtlety that now the energy estimates also depend on the parameters so we need some kind of reparametrization trick to deal with that, maybe we can also do this analytically??  

### Loading the training instances

In [573]:
import glob

instances = [BinAtCSP().from_cnf(CNF(from_file=f)) for f in glob.glob('uf20-91/*.cnf')]

train_instances, test_instances = instances[:800], instances[800:]

### Implementing the unsupervised approach

Lets first do it without consider the fact that our weights will depend on the parameters (this might lead to bad performance, but lets see)

#### Defining the loss function

In [601]:
def energy(assignment, instance):
    # this is a bit of a random energy function and to be investigated/improved
    p = np.sum([c.is_violated_by(assignment) for c in instance.constraints])/len(instance.constraints)
    return np.maximum(np.sqrt(p), 1e-12)


class UnsupervisedLoss(nn.Module):
    def __init__(self, number_samples = 100, temperature = 1, alpha=3):
        super(UnsupervisedLoss, self).__init__()
        
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')
        self.random_walker = OracleRandomWalker
        self.number_samples = number_samples
        self.temperature = temperature
        self.alpha = alpha
        
    def forward(self, params, instance):
        variable_params = params[:len(instance.variables)]
        oracle = Oracle(dist=Bernoulli, params=variable_params.detach())
        walker = self.random_walker(oracle, instance)
        _, _, _, trajectory = walker.run(step_limit=self.number_samples, return_trajectory=True)
        weights = torch.Tensor(self.estimate_gibbs_weights(trajectory, instance))
        neg_log_probs_model = torch.sum(self.bce_loss(variable_params.tile(self.number_samples).T, trajectory), axis=1)
        
        return weights.dot(neg_log_probs_model)
    
    def estimate_gibbs_weights(self, trajectory, instance):
        trajectory = trajectory.numpy()
        energies = np.apply_along_axis(func1d=energy, arr=trajectory, instance=instance, axis=1)
        weights = np.exp(- self.temperature * energies)
        return weights/np.sum(weights)
        
    

In [602]:
from torch.optim import Rprop

learning_rate = 1e-3


unsat_csp = BinAtCSP()
unsat_csp.from_cnf(unsat_cnf)

num_vars = len(unsat_csp.variables)

m = Model()

loss_fn = UnsupervisedLoss(number_samples=50)
optimizer = torch.optim.Rprop(params=m.parameters(), lr=learning_rate) 

def train():
    for (idx, instance) in enumerate(train_instances):
        print(f"Starting to train on instance {idx}")
        
        # print([c.support for c in instance.constraints])
        try:
            pred = m(instance.to_pytorch())
            loss = loss_fn(pred, instance)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if idx % 10 == 0:
                loss = loss.item()
                print(f"loss: {loss:>7f}  [{idx:>5d}/{len(train_data):>5d}]")
        except:
            print(f"could not train on instance {idx}")
                
train()


Starting to train on instance 0
loss: 15.110449  [    0/  800]
Starting to train on instance 1
Starting to train on instance 2
Starting to train on instance 3
Starting to train on instance 4
Starting to train on instance 5
Starting to train on instance 6
Starting to train on instance 7
Starting to train on instance 8
Starting to train on instance 9
Starting to train on instance 10
loss: 12.845843  [   10/  800]
Starting to train on instance 11
Starting to train on instance 12
Starting to train on instance 13
could not train on instance 13
Starting to train on instance 14
Starting to train on instance 15
Starting to train on instance 16
Starting to train on instance 17
Starting to train on instance 18
Starting to train on instance 19
could not train on instance 19
Starting to train on instance 20
loss: 13.869633  [   20/  800]
Starting to train on instance 21
Starting to train on instance 22
Starting to train on instance 23
Starting to train on instance 24
Starting to train on instance 

Starting to train on instance 213
Starting to train on instance 214
Starting to train on instance 215
Starting to train on instance 216
Starting to train on instance 217
Starting to train on instance 218
Starting to train on instance 219
Starting to train on instance 220
loss: 13.826776  [  220/  800]
Starting to train on instance 221
Starting to train on instance 222
Starting to train on instance 223
Starting to train on instance 224
Starting to train on instance 225
Starting to train on instance 226
could not train on instance 226
Starting to train on instance 227
Starting to train on instance 228
Starting to train on instance 229
Starting to train on instance 230
loss: 13.509699  [  230/  800]
Starting to train on instance 231
Starting to train on instance 232
Starting to train on instance 233
Starting to train on instance 234
Starting to train on instance 235
Starting to train on instance 236
Starting to train on instance 237
Starting to train on instance 238
Starting to train on i

Starting to train on instance 429
Starting to train on instance 430
loss: 13.244860  [  430/  800]
Starting to train on instance 431
Starting to train on instance 432
Starting to train on instance 433
Starting to train on instance 434
Starting to train on instance 435
Starting to train on instance 436
Starting to train on instance 437
Starting to train on instance 438
Starting to train on instance 439
Starting to train on instance 440
loss: 13.569084  [  440/  800]
Starting to train on instance 441
Starting to train on instance 442
Starting to train on instance 443
Starting to train on instance 444
Starting to train on instance 445
Starting to train on instance 446
Starting to train on instance 447
could not train on instance 447
Starting to train on instance 448
Starting to train on instance 449
Starting to train on instance 450
loss: 13.160810  [  450/  800]
Starting to train on instance 451
Starting to train on instance 452
Starting to train on instance 453
Starting to train on inst

Starting to train on instance 638
Starting to train on instance 639
Starting to train on instance 640
loss: 13.273820  [  640/  800]
Starting to train on instance 641
Starting to train on instance 642
Starting to train on instance 643
could not train on instance 643
Starting to train on instance 644
Starting to train on instance 645
Starting to train on instance 646
Starting to train on instance 647
Starting to train on instance 648
Starting to train on instance 649
Starting to train on instance 650
loss: 13.763478  [  650/  800]
Starting to train on instance 651
Starting to train on instance 652
Starting to train on instance 653
Starting to train on instance 654
Starting to train on instance 655
Starting to train on instance 656
Starting to train on instance 657
Starting to train on instance 658
Starting to train on instance 659
Starting to train on instance 660
could not train on instance 660
Starting to train on instance 661
Starting to train on instance 662
Starting to train on ins

#### Evaluate

To evaluate the performance, we compare the trained model with a naive Moser's algorithm, to see whether it performs any better

In [650]:
def evaluate(model, test_instances, steps_eval):
    results = []
    for i in test_instances[:100]:
        
        pred = model(i.to_pytorch())
        variable_params = pred[:len(i.variables)]
        trained_oracle = Oracle(Bernoulli, params=variable_params)
        trained_walker = OracleRandomWalker(trained_oracle, i)
        
    
        random_oracle = Oracle(Bernoulli, params=torch.full_like(variable_params, 0.5))
        mosers = OracleRandomWalker(random_oracle, i)
        

        moser_out, _, _, _ = mosers.run(step_limit=steps_eval)
        trained_out, _, _, _ = trained_walker.run(step_limit=steps_eval)
        
        results.append(np.divide(energy(trained_out, i)**2,energy(moser_out, i)**2))
    
    print(np.sum(results))

In [651]:
evaluate(m, test_instances, 1000)

1.4835164835164835e+24
