# P vs NP: Deep Search Workspace

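This notebook runs the 'Frozen Backbone' search loop. It generates random 3-SAT instances and uses a DPLL solver to compute the backbone of each satisfiable one. The backbone is the set of variables that take the same value in every satisfying assignment. The notebook then trains a graph neural network (GNN) to predict which variables belong to the backbone.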

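**Goal:** Find a model with >95% per-variable validation accuracy at predicting backbone membership.

Caveat: such accuracy on small random instances would only show that a polynomial-time *heuristic* works well empirically. It would not be a polynomial-time algorithm for identifying hard SAT structures, because the model carries no correctness guarantee on every instance, so it would not bear on P vs NP. Compare the number against the majority-class baseline (always predicting the more common label) before reading much into it.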

In [None]:
# 1. Install Dependencies
# `!` is an IPython shell escape: pip runs in the notebook's environment.
# Note: the cells below import torch everywhere, but numpy only in
# src/sat_generator.py, where it is never used; networkx is not imported
# by any cell shown here.
!pip install torch numpy networkx

In [None]:
# 2. Setup Project Structure
# `src/` receives the modules written by the %%writefile cells below;
# `data/` is created up front but is not referenced by the cells shown here.
import os

for project_dir in ('src', 'data'):
    os.makedirs(project_dir, exist_ok=True)

In [None]:
%%writefile src/sat_generator.py
import random
import numpy as np

class SatGenerator:
    """Generator for uniformly random 3-SAT instances."""

    @staticmethod
    def generate_3sat(n_vars, alpha=4.26, rng=None):
        """Sample a random 3-SAT formula with ``round(n_vars * alpha)`` clauses.

        Each clause picks 3 distinct variables from ``1..n_vars`` and negates
        each one with probability 1/2. Literal ``v`` means variable ``v`` is
        true; ``-v`` means it is false.

        Args:
            n_vars: number of variables; must be at least 3 whenever any
                clause is generated.
            alpha: clause-to-variable ratio. The default 4.26 presumably
                targets the random 3-SAT satisfiability threshold (~4.27).
            rng: optional ``random.Random`` instance for reproducible
                instances. ``None`` (default) uses the global ``random``
                module state, exactly as before.

        Returns:
            ``(clauses, n_vars)`` where ``clauses`` is a list of 3-element
            int lists.

        Raises:
            ValueError: if clauses are requested but ``n_vars < 3``.
        """
        rng = random if rng is None else rng
        n_clauses = int(round(n_vars * alpha))
        if n_clauses > 0 and n_vars < 3:
            raise ValueError(
                f"3-SAT clauses need 3 distinct variables, got n_vars={n_vars}"
            )
        variables = range(1, n_vars + 1)
        clauses = []
        for _ in range(n_clauses):
            chosen = rng.sample(variables, 3)
            clauses.append([v if rng.random() > 0.5 else -v for v in chosen])
        return clauses, n_vars

In [None]:
%%writefile src/dpll_solver.py
import sys
from collections import Counter
sys.setrecursionlimit(50000)

class DpllSolver:
    """Recursive DPLL SAT solver over CNF given as lists of int literals.

    A literal ``v`` means variable ``v`` is true and ``-v`` means it is false.
    Search alternates unit propagation with branching on the literal that
    occurs most often in the remaining clauses.

    Attributes:
        steps: number of ``_dpll`` calls made by the most recent ``solve``.
        backtracks: number of times the most recent ``solve`` retried the
            opposite polarity after the first branch failed.
    """

    def __init__(self):
        self.steps = 0
        self.backtracks = 0

    def solve(self, clauses, n_vars):
        """Decide satisfiability of ``clauses`` (``n_vars`` is not used here).

        Returns:
            ``(True, assignment)`` where ``assignment`` maps variables to bools
            and satisfies every clause (variables absent from it are left
            unconstrained by this model), or ``(False, None)``.
        """
        self.steps = self.backtracks = 0
        if clauses:
            return self._dpll(clauses, {})
        return True, {}

    def _dpll(self, clauses, assignment):
        """Search for a model of ``clauses`` that extends ``assignment``.

        ``assignment`` is updated in place by unit propagation; branches work
        on copies (see ``_branch``) so a failed branch leaves it intact.
        """
        self.steps += 1

        # Propagate unit clauses until none are left, a clause is violated,
        # or every clause is satisfied.
        while True:
            if not clauses:
                return True, assignment
            if any(len(clause) == 0 for clause in clauses):
                return False, None
            unit = next((clause[0] for clause in clauses if len(clause) == 1), None)
            if unit is None:
                break
            assignment[abs(unit)] = unit > 0
            clauses = self._simplify(clauses, unit)

        # Branch on the most frequent literal; on ties, Counter.most_common
        # keeps the one encountered first.
        occurrences = Counter(lit for clause in clauses for lit in clause)
        if not occurrences:
            return True, assignment
        pivot = occurrences.most_common(1)[0][0]

        found, model = self._branch(clauses, assignment, pivot)
        if found:
            return True, model
        self.backtracks += 1
        return self._branch(clauses, assignment, -pivot)

    def _branch(self, clauses, assignment, literal):
        """Recurse with ``literal`` made true on a copy of ``assignment``."""
        extended = dict(assignment)
        extended[abs(literal)] = literal > 0
        return self._dpll(self._simplify(clauses, literal), extended)

    def _simplify(self, clauses, literal):
        """Drop clauses satisfied by ``literal``; strip its negation elsewhere.

        Clauses that do not mention ``-literal`` are reused, not copied.
        """
        negated = -literal
        return [
            [lit for lit in clause if lit != negated] if negated in clause else clause
            for clause in clauses
            if literal not in clause
        ]

In [None]:
%%writefile src/backbone_finder.py
from .dpll_solver import DpllSolver

class BackboneFinder:
    """Compute the backbone of a CNF formula using repeated DPLL calls.

    The backbone is the set of variables that take the same value in every
    satisfying assignment of the formula.
    """

    def __init__(self):
        self.solver = DpllSolver()

    def find_backbone(self, clauses, n_vars):
        """Return the backbone of ``clauses`` over variables ``1..n_vars``.

        Solve once to get a reference model. Then, for each variable not yet
        ruled out, check whether the formula stays satisfiable with that
        variable forced to the opposite value.

        - An unsatisfiable result puts the variable in the backbone.
        - A satisfiable result yields another model. Every candidate that
          this model leaves unassigned, or assigns differently from the
          reference, is also not forced, so it needs no solver call of its
          own. The result is the same backbone with fewer SAT calls.

        Args:
            clauses: CNF as a list of clauses of non-zero int literals.
            n_vars: number of variables.

        Returns:
            ``(backbone, satisfiable)``. ``backbone`` maps each backbone
            variable (ascending order) to its forced bool value. It is ``{}``
            when the formula is unsatisfiable.
        """
        satisfiable, reference = self.solver.solve(clauses, n_vars)
        if not satisfiable:
            return {}, False

        # A variable the reference model leaves unassigned is free in that
        # model, so only assigned variables can belong to the backbone.
        candidates = [var for var in range(1, n_vars + 1) if var in reference]
        ruled_out = set()
        backbone = {}
        for var in candidates:
            if var in ruled_out:
                continue
            val = reference[var]
            # Force the opposite polarity with an extra unit clause.
            negated_unit = -var if val else var
            flippable, model = self.solver.solve(clauses + [[negated_unit]], n_vars)
            if not flippable:
                backbone[var] = val
                continue
            # ``model`` satisfies every clause, so any candidate it leaves
            # unset or sets differently from the reference is not forced.
            ruled_out.update(
                other for other in candidates if model.get(other) != reference[other]
            )
        return backbone, True

In [None]:
%%writefile src/gnn_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class BackboneMPNN(nn.Module):
    """Message-passing network on the variable-clause incidence graph of a CNF.

    Nodes ``0..n_vars-1`` are variables and nodes ``n_vars..`` are clauses.
    Each literal occurrence is an edge whose sign (+1/-1) is fed to the
    message MLPs. Each layer:

    - sends variable->clause messages and updates clause states with a GRU
      cell;
    - then sends clause->variable messages and updates variable states.

    A final MLP head maps every variable state to a probability.
    """

    def __init__(self, hidden_dim=64, num_layers=4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # Initial node state depends only on node type (0 = variable, 1 = clause).
        self.type_embed = nn.Embedding(2, hidden_dim)
        # Message MLPs take [sender state, receiver state, edge sign].
        self.msg_v2c = nn.Sequential(nn.Linear(hidden_dim * 2 + 1, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))
        self.msg_c2v = nn.Sequential(nn.Linear(hidden_dim * 2 + 1, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))
        self.update_var = nn.GRUCell(hidden_dim, hidden_dim)
        self.update_clause = nn.GRUCell(hidden_dim, hidden_dim)
        self.projection_head = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1))

    def forward(self, n_vars, clauses, device=None):
        """Predict, for each variable, the probability that it is in the backbone.

        Args:
            n_vars: number of variables; every literal must satisfy
                ``1 <= abs(lit) <= n_vars``.
            clauses: list of clauses, each a sequence of non-zero int literals
                (``v`` positive, ``-v`` negated).
            device: device for the graph tensors. ``None`` (default) uses the
                device of the model's parameters.

        Returns:
            Float tensor of shape ``(n_vars,)`` holding sigmoid probabilities.

        Raises:
            ValueError: if a literal is 0 or refers to a variable outside
                ``1..n_vars``.
        """
        if device is None:
            # Follow the model's own parameters so a CUDA model works without
            # the caller passing ``device`` explicitly.
            device = self.type_embed.weight.device
        num_clauses = len(clauses)
        num_nodes = n_vars + num_clauses
        node_types = torch.zeros(num_nodes, dtype=torch.long, device=device)
        node_types[n_vars:] = 1  # 0 = variable node, 1 = clause node
        h = self.type_embed(node_types)

        # One edge per literal occurrence: variable node <-> clause node.
        v_indices = []
        c_indices = []
        signs = []
        for c_idx, clause in enumerate(clauses):
            c_node = n_vars + c_idx
            for lit in clause:
                # An out-of-range literal would otherwise address a clause node
                # (or index -1 for 0) instead of a variable and corrupt the graph.
                if lit == 0 or abs(lit) > n_vars:
                    raise ValueError(
                        f"literal {lit} in clause {c_idx} is out of range for n_vars={n_vars}"
                    )
                v_indices.append(abs(lit) - 1)
                c_indices.append(c_node)
                signs.append(1.0 if lit > 0 else -1.0)

        v_tensor = torch.tensor(v_indices, dtype=torch.long, device=device)
        c_tensor = torch.tensor(c_indices, dtype=torch.long, device=device)
        s_tensor = torch.tensor(signs, dtype=torch.float, device=device).unsqueeze(1)

        for _ in range(self.num_layers):
            # Variable -> clause messages, summed per clause, then GRU update.
            h_v = h[v_tensor]
            h_c_target = h[c_tensor]
            msg = self.msg_v2c(torch.cat([h_v, h_c_target, s_tensor], dim=1))
            agg_c = torch.zeros(num_nodes, self.hidden_dim, device=device)
            agg_c.index_add_(0, c_tensor, msg)
            h_clauses_new = self.update_clause(agg_c[n_vars:], h[n_vars:])
            h = torch.cat([h[:n_vars], h_clauses_new], dim=0)

            # Clause -> variable messages, computed from the updated clause states.
            h_c = h[c_tensor]
            h_v_target = h[v_tensor]
            msg = self.msg_c2v(torch.cat([h_c, h_v_target, s_tensor], dim=1))
            agg_v = torch.zeros(num_nodes, self.hidden_dim, device=device)
            agg_v.index_add_(0, v_tensor, msg)
            h_vars_new = self.update_var(agg_v[:n_vars], h[:n_vars])
            h = torch.cat([h_vars_new, h[n_vars:]], dim=0)

        return torch.sigmoid(self.projection_head(h[:n_vars])).squeeze(1)

In [None]:
# 3. Run the Automated Search Loop
import torch
import torch.optim as optim
import torch.nn as nn
from src.sat_generator import SatGenerator
from src.backbone_finder import BackboneFinder
from src.gnn_model import BackboneMPNN

def generate_batch(num_samples, n_vars, alpha=4.26, max_attempts=None):
    """Generate satisfiable random 3-SAT instances with backbone labels.

    Unsatisfiable draws are discarded and redrawn.

    Args:
        num_samples: number of satisfiable instances to return.
        n_vars: number of variables per instance.
        alpha: clause-to-variable ratio passed to ``SatGenerator.generate_3sat``.
        max_attempts: optional cap on the total number of instances drawn,
            satisfiable or not. ``None`` (default) keeps drawing until
            ``num_samples`` satisfiable instances are found. That may never
            finish if ``alpha`` makes satisfiable draws rare.

    Returns:
        List of ``(n_vars, clauses, labels)`` tuples. ``labels`` is a float
        tensor of shape ``(n_vars,)`` with 1.0 for backbone variables and 0.0
        otherwise.

    Raises:
        RuntimeError: if ``max_attempts`` draws yield fewer than
            ``num_samples`` satisfiable instances.
    """
    finder = BackboneFinder()
    samples = []
    attempts = 0
    while len(samples) < num_samples:
        if max_attempts is not None and attempts >= max_attempts:
            raise RuntimeError(
                f"only {len(samples)}/{num_samples} satisfiable instances found "
                f"in {attempts} attempts (n_vars={n_vars}, alpha={alpha})"
            )
        attempts += 1
        clauses, _ = SatGenerator.generate_3sat(n_vars, alpha)
        backbone, sat = finder.find_backbone(clauses, n_vars)
        if not sat:
            continue
        # Labels mark backbone membership only, not the forced value.
        labels = torch.tensor(
            [1.0 if v in backbone else 0.0 for v in range(1, n_vars + 1)],
            dtype=torch.float,
        )
        samples.append((n_vars, clauses, labels))
    return samples

def search_loop(*, max_cycles=None, target_acc=None):
    """Repeatedly generate fresh data, train the GNN and report validation accuracy.

    Each cycle runs these steps:

    - draw 20 satisfiable 30-variable instances;
    - train on the first 80% for 15 epochs, one optimizer step per instance;
    - measure per-variable accuracy of thresholded (> 0.5) predictions on the
      remaining instances.

    The same model and optimizer persist across cycles.

    Args:
        max_cycles: stop after this many cycles. ``None`` (default) runs
            forever, as before.
        target_acc: stop as soon as validation accuracy (percent) exceeds this
            value. ``None`` (default) disables the check.

    Returns:
        ``(best_acc, best_state)``, returned only if the loop terminates.
        ``best_acc`` is the best validation accuracy seen, in percent.
        ``best_state`` is a detached copy of the model's ``state_dict`` at
        that cycle, or ``None`` if no cycle beat 0%.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = BackboneMPNN(hidden_dim=64, num_layers=4).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.002)
    criterion = nn.BCELoss()

    cycle = 0
    best_acc = 0.0
    best_state = None

    while max_cycles is None or cycle < max_cycles:
        cycle += 1
        print(f"\n--- CYCLE {cycle} ---")
        print("Generating data...")
        data = generate_batch(num_samples=20, n_vars=30)
        data = [(n, c, labels.to(device)) for n, c, labels in data]

        split = int(0.8 * len(data))
        train_set = data[:split]
        val_set = data[split:]

        print("Training...")
        model.train()
        for _epoch in range(15):
            epoch_loss = 0.0
            for n, c, labels in train_set:
                optimizer.zero_grad()
                preds = model(n, c, device)
                loss = criterion(preds, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
        if train_set:
            print(f"Final-epoch mean training loss: {epoch_loss / len(train_set):.4f}")

        model.eval()
        total_corr = 0
        total_pos = 0
        total_nodes = 0
        with torch.no_grad():
            for n, c, labels in val_set:
                preds = model(n, c, device)
                bin_preds = (preds > 0.5).float()
                total_corr += (bin_preds == labels).sum().item()
                total_pos += labels.sum().item()
                total_nodes += n

        acc = (total_corr / total_nodes) * 100 if total_nodes > 0 else 0
        print(f"Validation Accuracy: {acc:.2f}%")
        if total_nodes > 0:
            # Accuracy of always predicting the more common label; the model is
            # only informative if it beats this number.
            majority = max(total_pos, total_nodes - total_pos) / total_nodes * 100
            print(f"Majority-class baseline: {majority:.2f}%")

        if acc > best_acc:
            best_acc = acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            print(f"New Best! {best_acc:.2f}%")

        if target_acc is not None and acc > target_acc:
            print(f"Target accuracy {target_acc:.2f}% reached.")
            break

    return best_acc, best_state


search_loop()