# In [None]:
# coding=utf-8
import os
import sys
import time
import glob
import csv
import signal
import subprocess
import argparse

import networkx as nx


# Instance names (the .edge file basename, without the suffix) that
# batch_run never processes. Why each entry is excluded is not recorded
# here — presumably too expensive for the solver; confirm before editing.
SKIP_LIST = {'Paley17'}


def apex_vertices(g):
    """
    Strip every universal (apex) vertex from ``g``.

    A vertex counts as apex when its degree equals ``n - 1``, where ``n`` is
    the node count before anything is removed. ``g`` itself is modified in
    place; the surviving nodes are then relabelled to consecutive integers
    starting at 0 in a new graph.

    Returns:
        ``(relabelled_graph, number_of_vertices_removed)``.
    """
    target_degree = g.number_of_nodes() - 1
    apexes = [node for node, deg in g.degree() if deg == target_degree]
    g.remove_nodes_from(apexes)
    relabelled = nx.convert_node_labels_to_integers(g, first_label=0, ordering="default")
    return relabelled, len(apexes)


def degree_one_reduction(g):
    """
    Keep at most one degree-1 neighbour (leaf) per vertex.

    For every vertex, the first leaf met while iterating its neighbours is
    kept and all further leaves are deleted. Degrees are read from the
    unmodified graph; deletions happen only after the scan. ``g`` is
    modified in place and a relabelled copy (nodes ``0..k-1``) is returned.
    """
    surplus_leaves = set()
    for hub in g.nodes():
        leaves = [nbr for nbr in g.neighbors(hub) if g.degree(nbr) == 1]
        # Everything after the first leaf is redundant.
        surplus_leaves.update(leaves[1:])
    g.remove_nodes_from(surplus_leaves)
    return nx.convert_node_labels_to_integers(g, first_label=0, ordering="default")


def read_edge(filename):
    """
    Read the edge list of a DIMACS-style ``.edge`` file.

    Blank lines are ignored. The first line starting with ``'p '`` is the
    problem line; its last token is taken as the edge count ``m``. After it,
    the first ``m`` lines starting with ``'e '`` are parsed as
    ``e <a> <b> ...`` (extra tokens ignored). Any other lines, such as
    comments, are skipped, and so is everything before the problem line.

    Args:
        filename: path of the ``.edge`` file.

    Returns:
        List of ``(a, b)`` integer tuples in file order. It is shorter than
        ``m`` if the file runs out of edge lines, and empty if ``m <= 0``.

    Raises:
        ValueError: if there is no problem line, its edge count is not an
            integer, or an edge line is malformed.
    """
    # errors='replace' keeps stray non-UTF-8 bytes in comments from aborting
    # the read; the numeric lines are plain ASCII.
    with open(filename, encoding='utf-8', errors='replace') as f:
        lines = [stripped for stripped in (raw.strip() for raw in f) if stripped]

    p_idx = next((i for i, line in enumerate(lines) if line.startswith('p ')), None)
    if p_idx is None:
        # Fail with a clear message instead of leaking StopIteration.
        raise ValueError(f"{filename}: no problem line starting with 'p ' found")
    try:
        m = int(lines[p_idx].split()[-1])
    except ValueError as err:
        raise ValueError(f"{filename}: malformed problem line {lines[p_idx]!r}") from err

    edges = []
    if m <= 0:
        return edges
    # Single pass: collect the first m edge lines after the problem line.
    for line in lines[p_idx + 1:]:
        if not line.startswith('e '):
            continue
        _, a, b, *_ = line.split()
        edges.append((int(a), int(b)))
        if len(edges) >= m:
            break
    return edges


def make_vars(g, width):
    """
    Allocate DIMACS variable ids for the treedepth encoding.

    Returns ``(s, nvar)``. ``s`` is an ``n x n`` nested list
    (``n = g.number_of_nodes()``) whose upper-triangular entries ``s[u][v]``
    with ``u <= v`` each hold ``width`` consecutive ids. Ids start at 1 and
    are handed out row by row. Entries with ``u > v`` stay all-zero and are
    never meant to be used. ``nvar`` is the largest id handed out (0 if none).
    """
    size = g.number_of_nodes()
    table = [[[0] * width for _col in range(size)] for _row in range(size)]
    last_id = 0
    for row in range(size):
        for col in range(row, size):
            block = list(range(last_id + 1, last_id + 1 + width))
            table[row][col] = block
            last_id += len(block)
    return table, last_id


def generate_encoding(g, width):
    """
    Return the DIMACS CNF text of the treedepth encoding of ``g``.

    Variables come from ``make_vars``: ``s[u][v][i]`` for ``u <= v`` and
    levels ``i = 0 .. width-1``. Judging from the clause structure,
    ``s[u][v][i]`` (u < v) means "u and v are together at level i" and
    ``s[u][u][i]`` means "u is present at level i".
    NOTE(review): this interpretation is inferred, not documented.

    Group 1 forces every pair together at the top level and nothing present
    at level 0. solve_component therefore reads a satisfying assignment as
    treedepth <= width-1.

    Args:
        g: graph with integer nodes ``0..n-1``.
        width: number of levels.

    Returns:
        ``"p cnf <nvar> <nclauses>\\n"`` followed by one clause per line,
        each terminated by `` 0``, and a trailing newline.
    """
    s, nvar = make_vars(g, width)
    nv = g.number_of_nodes()
    top = width - 1
    out = []

    def neg(var):
        return f"-{var}"

    def add(*lits):
        # One DIMACS clause: literals separated by spaces, terminated by 0.
        out.append(" ".join(str(lit) for lit in lits) + " 0")

    # Pairs in the same order as nested `for u: for v in range(u[, +1], nv)`.
    diagonal_pairs = [(u, v) for u in range(nv) for v in range(u, nv)]
    strict_pairs = [(u, v) for u in range(nv) for v in range(u + 1, nv)]

    # 1) every pair is together at the top level; nothing exists at level 0
    for u, v in diagonal_pairs:
        add(s[u][v][top])
        add(neg(s[u][v][0]))

    # 2) monotonicity: true at level i-1 implies true at level i
    for u, v in diagonal_pairs:
        for i in range(1, width):
            add(neg(s[u][v][i - 1]), s[u][v][i])

    # 3) transitivity of "together" within each level
    for a, b in strict_pairs:
        for c in range(b + 1, nv):
            for i in range(width):
                ab, ac, bc = s[a][b][i], s[a][c][i], s[b][c][i]
                add(neg(ab), neg(ac), bc)
                add(neg(ab), neg(bc), ac)
                add(neg(ac), neg(bc), ab)

    # 4) a pair being together implies both vertices are present
    for u, v in strict_pairs:
        for i in range(width):
            add(neg(s[u][v][i]), s[u][u][i])
            add(neg(s[u][v][i]), s[v][v][i])

    # 5) a pair joined at level i had one member present at level i-1
    for u, v in strict_pairs:
        for i in range(1, width):
            add(neg(s[u][v][i]), s[u][u][i - 1], s[v][v][i - 1])

    # 6) edge endpoints: the one that appears later joins the other's group
    for a, b in g.edges():
        u, v = min(a, b), max(a, b)
        for i in range(1, width):
            add(neg(s[u][u][i]), s[u][u][i - 1], neg(s[v][v][i]), s[u][v][i])
            add(neg(s[u][u][i]), s[v][v][i - 1], neg(s[v][v][i]), s[u][v][i])

    header = f"p cnf {nvar} {len(out)}"
    return header + "\n" + "\n".join(out) + "\n"


class Timer:
    """
    Context manager that measures the time spent inside its ``with`` block.

    On exit, even when the block raises, the elapsed seconds are stored in
    ``self.elapsed`` and, if a ``time_list`` was given, appended to it.
    Exceptions are never suppressed.
    """

    def __init__(self, time_list=None):
        # Optional list that collects one duration per completed block.
        self.time_list = time_list
        # Seconds spent in the most recent block; None until the first exit.
        self.elapsed = None

    def __enter__(self):
        # perf_counter is monotonic and high-resolution. time.time() can jump
        # when the system clock is adjusted, corrupting interval measurements.
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.elapsed = time.perf_counter() - self.start
        if self.time_list is not None:
            self.time_list.append(self.elapsed)
        return False  # propagate any exception raised in the block


def solve_component(g, cli_args):
    """
    Compute treedepth bounds for one graph by descending SAT search.

    Follows the original paper's search order. ``generate_encoding(g, w)`` is
    solved for ``w = n+1, n, ..., 2`` and the search stops at the first UNSAT
    answer. The solver's exit code decides the outcome (SAT-competition
    convention):

    * 10 (SAT): read as treedepth <= w-1, since the encoding forces level 0
      to be empty. This tightens ``ub``.
    * 20 (UNSAT): read as treedepth >= w. This sets ``lb = w`` and returns.
    * anything else: inconclusive, e.g. 0 when the solver hits ``-cpu-lim``
      (assumed glucose behaviour; confirm for other solvers), a crash, or a
      signal. Such a run gives no information and must never be mistaken
      for SAT.

    Args:
        g: graph with integer nodes ``0..n-1``. Presumably connected, since
            the encoding forces all vertices into one tree; confirm at call
            sites.
        cli_args: namespace providing ``temp``, ``instance``, ``solver`` and
            ``timeout``. Files ``{instance}_{w}.cnf`` / ``{instance}_{w}.sol``
            are written into ``temp``.

    Returns:
        ``(w, lb, ub, inconclusive, encoding_times, solving_times)``:

        * ``w`` is the width that answered UNSAT, or 1 if none did.
        * ``lb <= treedepth <= ub``.
        * ``inconclusive`` tells whether any solver call gave no answer.
        * The two lists hold per-width timings in seconds.

        Graphs with at most one vertex return ``(n, n, n, False, [], [])``
        without calling the solver.
    """
    encoding_times = []
    solving_times = []
    n = g.number_of_nodes()
    if n <= 1:
        return n, n, n, False, encoding_times, solving_times

    temp = os.path.abspath(cli_args.temp)
    inst = cli_args.instance
    # Trivial bounds for a non-empty n-vertex graph: 1 <= td <= n.
    lb = 1
    ub = n
    inconclusive = False

    for w in range(n + 1, 1, -1):
        cnf_file = os.path.join(temp, f"{inst}_{w}.cnf")
        sol_file = os.path.join(temp, f"{inst}_{w}.sol")
        with Timer(time_list=encoding_times):
            cnf_str = generate_encoding(g, w)
            with open(cnf_file, 'w') as f:
                f.write(cnf_str)

        cmd = [cli_args.solver, f"-cpu-lim={cli_args.timeout}", cnf_file, sol_file]
        with Timer(time_list=solving_times):
            # Solver output is not used; discard it instead of buffering it.
            rc = subprocess.run(
                cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
            ).returncode

        if rc == 10:        # SAT: treedepth <= w-1
            ub = min(ub, w - 1)
        elif rc == 20:      # UNSAT: treedepth >= w
            lb = max(lb, w)
            return w, lb, ub, inconclusive, encoding_times, solving_times
        else:               # timeout / error: no information at this width
            inconclusive = True

    # No width answered UNSAT: only the trivial lower bound is known.
    return 1, lb, ub, inconclusive, encoding_times, solving_times


def parse_args():
    """
    Parse the batch-run command-line options.

    Options: ``--solver`` (SAT solver executable, default ``glucose``),
    ``--timeout`` (integer seconds per SAT call, default 900) and ``--temp``
    (directory for CNF/solution files, default ``./temp`` under the current
    working directory).

    Unrecognised arguments are silently ignored via ``parse_known_args``.
    Presumably this lets the script tolerate flags injected by a launcher
    such as a notebook kernel; confirm.
    """
    default_temp = os.path.join(os.getcwd(), 'temp')
    parser = argparse.ArgumentParser(description="Batch-run treedepthp2sat")
    option_specs = (
        ('--solver', str, 'glucose', 'SAT solver'),
        ('--timeout', int, 900, 'timeout per SAT (s)'),
        ('--temp', str, default_temp, 'temp dir'),
    )
    for flag, kind, default, help_text in option_specs:
        parser.add_argument(flag, type=kind, default=default, help=help_text)
    known, _unknown = parser.parse_known_args()
    return known


def _treedepth_bounds(g, inst, cli_args):
    """
    Solve every connected component of ``g`` and combine their bounds.

    Treedepth of a disconnected graph is the maximum over its components, so
    both the lower and the upper bound are maxima of the per-component
    bounds. A graph with no vertices has treedepth 0.

    Returns:
        ``(td_lb, td_ub, encoding_times, solving_times)``. The time lists
        concatenate the per-width timings of all components.
    """
    td_lb = td_ub = 0
    all_enc, all_sol = [], []
    # solve_component names its CNF/solution files after cli_args.instance.
    cli_args.instance = inst
    for comp in nx.connected_components(g):
        sub = nx.convert_node_labels_to_integers(
            g.subgraph(comp).copy(), first_label=0, ordering="default"
        )
        _, lb, ub, _, et, st = solve_component(sub, cli_args)
        td_lb = max(td_lb, lb)
        td_ub = max(td_ub, ub)
        all_enc.extend(et)
        all_sol.extend(st)
    return td_lb, td_ub, all_enc, all_sol


def batch_run(cli_args):
    """
    Compute treedepth bounds for the ``.edge`` graphs in ``inputs/famous``
    and ``inputs/standard`` and write one CSV row per solved graph to
    ``resultsat.csv`` in the current directory.

    For each graph:

    1. Skip it if blacklisted in ``SKIP_LIST``.
    2. Skip it if it has more than ``cli_args.max_vertices`` vertices
       (default 20 when that attribute is absent).
    3. Apply the degree-one reduction, then remove apex vertices.
    4. Solve component by component.
    5. Add the number of removed apex vertices back to both bounds.

    Timing columns are in seconds. ``cli_args`` also needs ``temp`` plus
    whatever ``solve_component`` reads; ``cli_args.instance`` is overwritten.
    """
    os.makedirs(cli_args.temp, exist_ok=True)
    max_vertices = getattr(cli_args, 'max_vertices', 20)

    # Sorted so the processing order (and CSV row order) is reproducible:
    # glob returns files in arbitrary, filesystem-dependent order.
    tasks = [
        (ds, fn)
        for ds in ['famous', 'standard']
        for fn in sorted(glob.glob(os.path.join('inputs', ds, '*.edge')))
    ]
    total = len(tasks)
    print(f"Found {total} graphs in total, starting batch processing...")

    with open('resultsat.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([
            'dataset','instance',
            'n_original','m_original',
            'n_reduced','buff',
            'td_lb','td_ub',
            'total_time_s',
            'sum_encoding_s','sum_solving_s'
        ])

        for idx, (ds, fn) in enumerate(tasks, start=1):
            inst = os.path.splitext(os.path.basename(fn))[0]
            if inst in SKIP_LIST:
                print(f"[{idx}/{total}] ({ds}) Skipping {inst} (blacklisted)")
                continue

            print(f"[{idx}/{total}] ({ds}) Processing {inst} ...", end='', flush=True)

            g0 = nx.Graph()
            g0.add_edges_from(read_edge(fn))
            n0, m0 = g0.number_of_nodes(), g0.number_of_edges()

            if n0 > max_vertices:
                print(f" Skipped, vertex count {n0} > {max_vertices}")
                continue

            t0 = time.perf_counter()
            g1 = degree_one_reduction(g0.copy())
            g2, buff = apex_vertices(g1)
            td_lb, td_ub, all_enc, all_sol = _treedepth_bounds(g2, inst, cli_args)

            # Each removed apex (universal) vertex adds one to the treedepth.
            td_lb += buff
            td_ub += buff

            total_time = time.perf_counter() - t0
            writer.writerow([
                ds, inst,
                n0, m0,
                g2.number_of_nodes(), buff,
                td_lb, td_ub,
                f"{total_time:.2f}",
                f"{sum(all_enc):.2f}", f"{sum(all_sol):.2f}"
            ])
            # Flush per row so finished results survive an interrupted run.
            f.flush()
            print(f" Completed, td=[{td_lb}-{td_ub}], time {total_time:.2f}s")


def main():
    """
    Script entry point.

    Installs a handler that exits with status 0 on SIGHUP (only where the
    platform defines that signal), then parses the command line and runs
    the batch.
    """
    def _on_hangup(signum, frame):
        sys.exit(0)

    hangup = getattr(signal, 'SIGHUP', None)
    if hangup is not None:  # SIGHUP is unavailable on some platforms (e.g. Windows)
        signal.signal(hangup, _on_hangup)

    options = parse_args()
    batch_run(options)


if __name__ == "__main__":
    main()