# In [None]:
#!/usr/bin/env python3 
# coding:utf-8

import os
import sys
import glob
import csv
import time
import math
import argparse
import subprocess
import importlib
from typing import List, Tuple, Optional, Dict

# Dependency bootstrap: try to import each third-party package the solver
# needs; if the import raises ImportError, install the package with pip.
# subprocess.check_call raises CalledProcessError when pip exits non-zero.
for pkg in ("networkx", "gurobipy"):
    try:
        importlib.import_module(pkg)
    except ImportError:
        # sys.executable makes pip run under the interpreter executing this file.
        subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])

import networkx as nx
import gurobipy as gp
from gurobipy import GRB

# Instance names (file name without extension) that main() never attempts.
SKIP_LIST = {"Paley17"}
# Default size limits; main() skips graphs with more vertices or edges.
DEFAULT_MAX_V = 20
DEFAULT_MAX_E = 2000
# Default value handed to Gurobi's TimeLimit parameter for each solve.
DEFAULT_TIME = 180
# Default path of the CSV file that main() writes.
DEFAULT_OUTFILE = "results_treedepth_ilp_optimized99.csv"
# Use a quarter of the available cores, but at least one. os.cpu_count() may
# return None when the count cannot be determined, so fall back to 1 instead
# of failing with a TypeError at import time.
DEFAULT_THREADS = max(1, (os.cpu_count() or 1) // 4)


def apex_vertices(g):
    """
    Strip every apex vertex (a vertex whose degree equals n-1) from ``g``.

    ``g`` is modified in place and the result is relabelled to consecutive
    integers starting at 0. A graph with at most one node is returned as-is
    with a count of 0.

    Returns a pair ``(reduced_graph, number_of_vertices_removed)``.
    """
    n = g.number_of_nodes()
    if n <= 1:
        return g, 0

    # Degrees are compared against the size of the graph before any removal.
    full_degree = n - 1
    apexes = []
    for node, deg in g.degree():
        if deg == full_degree:
            apexes.append(node)

    g.remove_nodes_from(apexes)
    relabelled = nx.convert_node_labels_to_integers(g, first_label=0, ordering="default")
    return relabelled, len(apexes)


def degree_one_reduction(g):
    """
    Keep a single degree-1 neighbour per vertex and delete the others.

    ``g`` is modified in place; the returned graph is relabelled to
    consecutive integers starting at 0. A graph with at most one node is
    returned unchanged.
    """
    if g.number_of_nodes() <= 1:
        return g

    surplus = set()
    for hub in g.nodes():
        # All leaves hanging off this vertex, in neighbour order.
        leaves = [nbr for nbr in g.neighbors(hub) if g.degree(nbr) == 1]
        # The first leaf stays; every further one is marked for deletion.
        surplus.update(leaves[1:])

    g.remove_nodes_from(surplus)
    return nx.convert_node_labels_to_integers(g, first_label=0, ordering="default")


def preprocess_graph(G, enable_preprocessing=True):
    """
    Apply the degree-1 and apex reductions to a copy of ``G``.

    Returns ``(graph, apex_buffer, stats)`` where ``apex_buffer`` is the number
    of apex vertices removed and ``stats`` records node/edge counts before and
    after, plus how many vertices each reduction removed. When
    ``enable_preprocessing`` is false, a plain copy of ``G`` is returned with
    an apex buffer of 0.
    """
    n_before = G.number_of_nodes()
    stats = {
        "original_nodes": n_before,
        "original_edges": G.number_of_edges(),
        "degree_one_removed": 0,
        "apex_removed": 0,
        "final_nodes": 0,
        "final_edges": 0,
    }

    if enable_preprocessing:
        # Both reductions operate on the copy, never on the caller's graph.
        pruned = degree_one_reduction(G.copy())
        stats["degree_one_removed"] = n_before - pruned.number_of_nodes()
        result, apex_buffer = apex_vertices(pruned)
        stats["apex_removed"] = apex_buffer
    else:
        result, apex_buffer = G.copy(), 0

    stats["final_nodes"] = result.number_of_nodes()
    stats["final_edges"] = result.number_of_edges()
    return result, apex_buffer, stats


def read_edge_file(filename: str) -> Tuple[List[Tuple[int, int]], int]:
    """
    Parse a DIMACS-style ``.edge`` file.

    Recognised records (fields separated by any whitespace):
      * ``p <format> <n> ...`` - the third field is taken as the declared
        number of vertices;
      * ``e <u> <v>`` - an edge; the 1-based endpoints in the file are
        converted to 0-based indices.
    Blank lines, lines starting with ``c`` and records with fewer than three
    fields are ignored.

    Args:
        filename: Path of the file to read (decoded as UTF-8).

    Returns:
        ``(edges, n_declared)``; ``n_declared`` is 0 when no usable ``p``
        record is present.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a vertex count or endpoint is not an integer.
    """
    edges: List[Tuple[int, int]] = []
    n_declared = 0
    with open(filename, 'r', encoding='utf-8') as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith('c'):
                continue
            parts = line.split()
            if len(parts) < 3:
                continue
            # Dispatch on the first token rather than on a literal "p "/"e "
            # prefix so that tab-separated records are not silently dropped.
            tag = parts[0]
            if tag == 'p':
                n_declared = int(parts[2])
            elif tag == 'e':
                edges.append((int(parts[1]) - 1, int(parts[2]) - 1))
    return edges, n_declared


def create_graph_from_edges(edges: List[Tuple[int, int]], n_nodes: int) -> nx.Graph:
    """
    Build a simple undirected graph on nodes ``0..n_nodes-1`` plus ``edges``.

    Endpoints outside ``range(n_nodes)`` are added as extra nodes. Self-loops
    are dropped after insertion (their endpoints remain as nodes): they do not
    influence treedepth, but they would inflate vertex degrees in the
    reductions and make the ILP infeasible, since an edge (u, u) would demand
    ``a[u, u] + a[u, u] >= 1`` while ``a[u, u]`` is fixed to 0.

    Args:
        edges: 0-based endpoint pairs.
        n_nodes: Number of vertices declared for the instance.

    Returns:
        The constructed ``nx.Graph``.
    """
    G = nx.Graph()
    G.add_nodes_from(range(n_nodes))
    G.add_edges_from(edges)
    # Materialise the list first: the graph must not change while iterating.
    G.remove_edges_from(list(nx.selfloop_edges(G)))
    return G


def estimate_bounds(G: nx.Graph) -> Tuple[int, int]:
    """
    Return cheap ``(lower, upper)`` bounds on the treedepth of ``G``.

    Upper bound: the height (counted in vertices) of a depth-first search tree
    of each connected component, maximised over components and capped at the
    number of nodes.

    Lower bound: for each component a double BFS sweep finds a shortest path
    with ``diam`` edges. That path is a subgraph on ``diam + 1`` vertices, and
    a path on m vertices has treedepth ``ceil(log2(m + 1))``, so the component
    needs at least ``ceil(log2(diam + 2))`` levels.

    Graphs with at most one node return ``(n, n)``.
    """
    n = G.number_of_nodes()
    if n <= 1:
        return n, n

    def dfs_height(start) -> int:
        """Height of the DFS tree grown from ``start`` (root counts as 1)."""
        visited = set()
        height = 1
        stack = [(start, 1, None)]
        while stack:
            v, depth, parent = stack.pop()
            # A vertex may be pushed several times; only its first pop counts.
            if v in visited:
                continue
            visited.add(v)
            height = max(height, depth)
            for w in G.neighbors(v):
                if w != parent:
                    stack.append((w, depth + 1, v))
        return height

    lower = 1
    upper = 1
    for comp in nx.connected_components(G):
        src = next(iter(comp))
        upper = max(upper, dfs_height(src))
        if len(comp) == 1:
            continue  # an isolated vertex contributes the trivial bound 1
        # Double sweep: the farthest vertex from src, then its eccentricity.
        dist1 = nx.single_source_shortest_path_length(G, src)
        far = max(dist1, key=dist1.get)
        dist2 = nx.single_source_shortest_path_length(G, far)
        diam = max(dist2.values())
        # diam edges -> diam + 1 path vertices -> ceil(log2((diam + 1) + 1)).
        lower = max(lower, math.ceil(math.log2(diam + 2)))

    return lower, min(upper, n)


def build_dfs_mipstart(G: nx.Graph):
    """
    Derive a warm-start assignment from a depth-first search forest of ``G``.

    Each connected component is explored from its smallest node. The returned
    dict contains:
      * ``"H"``        - the largest depth reached (0 for an empty graph);
      * ``"depth"``    - node -> depth, roots at depth 1;
      * ``"parent"``   - node -> DFS parent, ``None`` for roots;
      * ``"ancestor"`` - (u, v) -> 1 if u is a proper ancestor of v, else 0;
      * ``"root"``     - node -> 1 if the node has no parent, else 0.
    """
    nodes = list(G.nodes())
    parent = dict.fromkeys(nodes)
    depth = dict.fromkeys(nodes)

    for comp in nx.connected_components(G):
        pending = [(min(comp), 1, None)]
        while pending:
            v, level, via = pending.pop()
            # Nodes can be queued more than once; the first pop fixes them.
            if depth[v] is not None:
                continue
            depth[v] = level
            parent[v] = via
            pending.extend(
                (w, level + 1, v) for w in G.neighbors(v) if depth[w] is None
            )

    H = max(depth.values()) if depth else 0

    def ancestors_of(v):
        """Collect every node on the parent chain above ``v``."""
        chain = set()
        cur = parent[v]
        while cur is not None:
            chain.add(cur)
            cur = parent[cur]
        return chain

    above = {v: ancestors_of(v) for v in nodes}
    # A node never appears on its own parent chain, so (v, v) maps to 0.
    ancestor = {(u, v): int(u in above[v]) for u in nodes for v in nodes}
    root = {v: int(parent[v] is None) for v in nodes}

    return {"H": H, "depth": depth, "parent": parent, "ancestor": ancestor, "root": root}


def lazy_transitivity_callback(model: gp.Model, where):
    """
    Gurobi callback that adds violated transitivity cuts on the ``a`` variables.

    Only acts at ``MIPSOL``. For every ordered triple of distinct nodes
    (i, j, k) with a[i, j] = 1 and a[j, k] = 1 but a[i, k] = 0 in the candidate
    solution, the lazy constraint ``a[i, k] >= a[i, j] + a[j, k] - 1`` is added.
    Reads ``model._a`` (the variables) and ``model._nodes`` (the node list).
    """
    if where != GRB.Callback.MIPSOL:
        return

    anc = model._a
    verts = model._nodes

    # Snapshot the candidate value of every off-diagonal ancestor variable.
    sol = {}
    for i in verts:
        for j in verts:
            if i != j:
                sol[i, j] = model.cbGetSolution(anc[i, j])

    for (i, j), ij_val in sol.items():
        if ij_val < 0.5:
            continue
        for k in verts:
            if k == i or k == j:
                continue
            if sol[j, k] > 0.5 and sol[i, k] < 0.5:
                model.cbLazy(anc[i, k] >= anc[i, j] + anc[j, k] - 1)


def solve_single_component(G: nx.Graph, time_limit: int, threads: int) -> Tuple[Optional[int], bool]:
    """
    Solve the treedepth ILP for one graph (callers pass a connected component).

    Variables:
      * ``d[v]``    integer depth of v, in [1, U];
      * ``r[v]``    1 if v is a root;
      * ``p[u, v]`` 1 if u is the parent of v;
      * ``a[u, v]`` 1 if u is an ancestor of v;
      * ``max_depth`` upper envelope of all ``d[v]``, which is minimised.

    Transitivity of ``a`` is enforced lazily via ``lazy_transitivity_callback``
    and a DFS forest from ``build_dfs_mipstart`` is supplied as a warm start.

    Args:
        G: Graph to solve; node labels are used as variable indices.
        time_limit: Value for Gurobi's ``TimeLimit`` parameter.
        threads: Value for Gurobi's ``Threads`` parameter.

    Returns:
        ``(treedepth, timed_out)``. For graphs with at most one node this is
        ``(n, False)``. ``(value, False)`` on proven optimality,
        ``(incumbent, True)`` when the time limit hit with a solution in hand,
        and ``(None, True)`` for any other outcome.
    """
    n = G.number_of_nodes()
    if n <= 1:
        return n, False

    nodes = list(G.nodes())
    edges = list(G.edges())

    LB, UB = estimate_bounds(G)
    # U bounds every depth and doubles as the big-M constant below.
    U = UB if UB > 0 else n

    model = gp.Model("Treedepth_ILP_Direct")
    model.Params.OutputFlag = 0
    model.Params.TimeLimit = time_limit
    model.Params.Threads = threads
    model.Params.MIPFocus = 1
    model.Params.Heuristics = 0.4
    model.Params.Presolve = 2
    model.Params.Cuts = 2
    model.Params.Symmetry = 2
    model.Params.NodefileStart = 0.5
    model.Params.LazyConstraints = 1  # required because the callback uses cbLazy

    d = model.addVars(nodes, vtype=GRB.INTEGER, lb=1, ub=U, name="d")
    r = model.addVars(nodes, vtype=GRB.BINARY, name="r")
    p = model.addVars(nodes, nodes, vtype=GRB.BINARY, name="p")
    a = model.addVars(nodes, nodes, vtype=GRB.BINARY, name="a")

    # Minimise the maximum depth in epigraph form. This replaces an attempt to
    # pass gp.max_(...) straight to setObjective guarded by a bare `except:`;
    # the auxiliary-variable form is the same formulation the old fallback
    # used, without swallowing unrelated errors such as KeyboardInterrupt.
    max_depth = model.addVar(vtype=GRB.INTEGER, lb=LB, ub=U, name="max_depth")
    model.setObjective(max_depth, GRB.MINIMIZE)
    for v in nodes:
        model.addConstr(max_depth >= d[v])

    # No vertex is its own parent or ancestor.
    for i in nodes:
        model.addConstr(p[i, i] == 0)
        model.addConstr(a[i, i] == 0)

    for v in nodes:
        # Every vertex is a root or has exactly one parent.
        model.addConstr(gp.quicksum(p[u, v] for u in nodes if u != v) + r[v] == 1)
        # r[v] = 1 forces d[v] = 1; both rows are slack when r[v] = 0.
        model.addConstr(d[v] <= 1 + U * (1 - r[v]))
        model.addConstr(d[v] >= 1 - U * (1 - r[v]))

    for u in nodes:
        for v in nodes:
            if u == v:
                continue
            # p[u, v] = 1 forces d[v] = d[u] + 1.
            model.addConstr(d[v] - d[u] >= 1 - U * (1 - p[u, v]))
            model.addConstr(d[v] - d[u] <= 1 + U * (1 - p[u, v]))
            # a[u, v] = 1 forces u to be strictly shallower than v.
            model.addConstr(d[u] + 1 <= d[v] + U * (1 - a[u, v]))
            # A parent is an ancestor.
            model.addConstr(a[u, v] >= p[u, v])
            # Ancestry is antisymmetric.
            model.addConstr(a[u, v] + a[v, u] <= 1)

    # The number of ancestors of v equals its depth minus one.
    for v in nodes:
        model.addConstr(gp.quicksum(a[u, v] for u in nodes if u != v) == d[v] - 1)

    # The endpoints of every edge must be in an ancestor relation.
    for u, v in edges:
        if u == v:
            # A self-loop would read 2 * a[u, u] >= 1 with a[u, u] fixed to 0
            # and make the whole model infeasible; it is irrelevant here.
            continue
        model.addConstr(a[u, v] + a[v, u] >= 1)

    # Warm start from a DFS forest.
    ms = build_dfs_mipstart(G)
    max_depth.Start = ms["H"]
    for v in nodes:
        d[v].Start = ms["depth"][v]
        r[v].Start = ms["root"][v]
        for u in nodes:
            p[u, v].Start = 1 if ms["parent"][v] == u else 0
            a[u, v].Start = ms["ancestor"][(u, v)]

    # The callback reads these attributes.
    model._a = a
    model._nodes = nodes
    model.optimize(lazy_transitivity_callback)

    if model.Status == GRB.OPTIMAL:
        timed_out = False
    elif model.Status == GRB.TIME_LIMIT and model.SolCount > 0:
        timed_out = True
    else:
        return None, True

    # Integer variables are reported as floats that may sit slightly off the
    # integer (e.g. 2.9999999); round rather than truncate.
    result = max(int(round(d[v].X)) for v in nodes)
    return result, timed_out


def build_ilp_and_solve_direct(G: nx.Graph, time_limit: int, threads: int, enable_preprocessing: bool = True) -> Tuple[Optional[int], bool, Dict]:
    """
    Preprocess ``G``, solve each connected component, and combine the results.

    The answer is the largest component value plus the number of apex vertices
    removed during preprocessing. ``time_limit`` and ``threads`` are forwarded
    to every component solve.

    Returns ``(treedepth, timed_out, preprocess_stats)``; ``treedepth`` is
    ``None`` (with ``timed_out`` True) as soon as one component yields no
    solution.
    """
    reduced, apex_buffer, preprocess_stats = preprocess_graph(G, enable_preprocessing)

    remaining = reduced.number_of_nodes()
    if remaining <= 1:
        # Nothing left to optimise: 0 or 1 vertices plus the removed apexes.
        return remaining + apex_buffer, False, preprocess_stats

    best = 0
    any_timeout = False
    for comp in nx.connected_components(reduced):
        # Solve on an independent copy with labels 0..k-1.
        piece = nx.convert_node_labels_to_integers(
            reduced.subgraph(comp).copy(), first_label=0, ordering="default"
        )
        td, timeout = solve_single_component(piece, time_limit, threads)
        if td is None:
            return None, True, preprocess_stats
        if td > best:
            best = td
        any_timeout |= bool(timeout)

    return best + apex_buffer, any_timeout, preprocess_stats


def main():
    """
    Command-line entry point: solve every matching ``.edge`` instance.

    Expands the given glob patterns (or two default ones), skips blacklisted
    and oversized instances, runs the solver on the rest, writes one CSV row
    per attempted instance to ``--output`` and prints a summary. Instances that
    are skipped or cannot be read are reported on stdout but get no CSV row.
    """
    parser = argparse.ArgumentParser(description="ILP treedepth solver (with preprocessing)")
    parser.add_argument("--timeout", type=int, default=DEFAULT_TIME)
    parser.add_argument("--max_v", type=int, default=DEFAULT_MAX_V)
    parser.add_argument("--max_e", type=int, default=DEFAULT_MAX_E)
    parser.add_argument("--threads", type=int, default=DEFAULT_THREADS)
    parser.add_argument("--output", type=str, default=DEFAULT_OUTFILE)
    parser.add_argument("--enable_preprocessing", action="store_true", default=False,
                       help="Enable graph preprocessing (degree-1 reduction and apex removal)")
    parser.add_argument("--disable_preprocessing", action="store_true", default=False,
                       help="Disable graph preprocessing")
    parser.add_argument("files", nargs="*")
    # parse_known_args tolerates unrecognised arguments instead of exiting
    # (presumably so the script can run inside a notebook kernel - confirm).
    args, _ = parser.parse_known_args()

    # Preprocessing is on unless --disable_preprocessing is given; the
    # --enable_preprocessing flag does not change the outcome.
    if args.disable_preprocessing:
        enable_preprocessing = False
    elif args.enable_preprocessing:
        enable_preprocessing = True
    else:
        enable_preprocessing = True

    if args.files:
        file_patterns = args.files
    else:
        file_patterns = ["inputs/famous/*.edge", "inputs/standard/*.edge"]

    # Expand the patterns, then de-duplicate and sort for a stable order.
    files = []
    for pat in file_patterns:
        files.extend(glob.glob(pat))
    files = sorted(set(files))
    if not files:
        print("No .edge files found")
        return

    print(f"Preprocessing: {'Enabled' if enable_preprocessing else 'Disabled'}")
    
    # Outcome counters and (instance, treedepth text, category) tuples that
    # feed the summary printed at the end.
    stats = {"solved": 0, "timeout": 0, "skipped": 0, "failed": 0}
    results_detail = []

    csv_headers = [
        "dataset", "instance", "n_original", "m_original", 
        "n_after_preprocess", "m_after_preprocess",
        "degree_one_removed", "apex_removed",
        "treedepth", "timeout", "time_sec"
    ]

    with open(args.output, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(csv_headers)

        for idx, filepath in enumerate(files, 1):
            # dataset = name of the containing directory, instance = file stem.
            dataset = os.path.basename(os.path.dirname(filepath))
            instance = os.path.splitext(os.path.basename(filepath))[0]
            print(f"\n[{idx}/{len(files)}] Processing {dataset}/{instance}")

            if instance in SKIP_LIST:
                print("  Skipped (blacklist)")
                stats["skipped"] += 1
                results_detail.append((instance, "-", "Skipped"))
                continue

            # Any error while reading or building the graph marks the
            # instance as failed and moves on to the next file.
            try:
                edges, n_declared = read_edge_file(filepath)
                G = create_graph_from_edges(edges, n_declared)
                n, m = G.number_of_nodes(), G.number_of_edges()
                print(f"  Original size: n={n}, m={m}")
            except Exception as e:
                print(f"  Error: {e}")
                stats["failed"] += 1
                results_detail.append((instance, "-", "Failed"))
                continue

            if n > args.max_v or m > args.max_e:
                print(f"  Skipped (exceeds size limit: n>{args.max_v} or m>{args.max_e})")
                stats["skipped"] += 1
                results_detail.append((instance, "-", "Skipped"))
                continue

            # The measured time covers preprocessing and all component solves.
            start_time = time.time()
            try:
                td, is_timeout, preprocess_stats = build_ilp_and_solve_direct(
                    G, args.timeout, args.threads, enable_preprocessing
                )
                solve_time = time.time() - start_time
                
                if enable_preprocessing:
                    print(f"  Preprocessing: degree-1 removed={preprocess_stats['degree_one_removed']}, "
                          f"apex removed={preprocess_stats['apex_removed']}, "
                          f"final size: n={preprocess_stats['final_nodes']}, m={preprocess_stats['final_edges']}")
                
                if td is not None:
                    if is_timeout:
                        # A value obtained under a timeout is only an upper bound.
                        stats["timeout"] += 1
                        print(f"  Result: treedepth≤{td} (timeout), time {solve_time:.2f}s")
                        results_detail.append((instance, f"≤{td}", "Timeout"))
                    else:
                        stats["solved"] += 1
                        print(f"  Result: treedepth={td} (optimal), time {solve_time:.2f}s")
                        results_detail.append((instance, str(td), "Optimal"))
                else:
                    stats["failed"] += 1
                    print(f"  Failed: could not solve, time {solve_time:.2f}s")
                    results_detail.append((instance, "-", "Failed"))
                
                writer.writerow([
                    dataset, instance, n, m, 
                    preprocess_stats["final_nodes"], preprocess_stats["final_edges"],
                    preprocess_stats["degree_one_removed"], preprocess_stats["apex_removed"],
                    td, is_timeout, round(solve_time, 2)
                ])
            except Exception as e:
                # Solver or preprocessing error: still emit a row, with the
                # preprocessing columns zeroed and no treedepth value.
                solve_time = time.time() - start_time
                print(f"  Error: {e}")
                stats["failed"] += 1
                results_detail.append((instance, "-", "Failed"))
                writer.writerow([
                    dataset, instance, n, m, 0, 0, 0, 0,
                    None, True, round(solve_time, 2)
                ])

    print("\n=== Processing complete ===")
    print(f"Preprocessing: {'Enabled' if enable_preprocessing else 'Disabled'}")
    print(f"Optimal solutions: {stats['solved']}")
    print(f"Timeouts: {stats['timeout']}")
    print(f"Skipped: {stats['skipped']}")
    print(f"Failed: {stats['failed']}")
    print(f"Results saved to: {args.output}")

    print("\n=== Summary by category ===")
    def show_category(title, category):
        """Print the count and the ``name(treedepth)`` list for one category."""
        items = [f"{name}({td})" for name, td, st in results_detail if st == category]
        print(f"[{title} {len(items)}]")
        if items:
            print(", ".join(items))

    show_category("Optimal", "Optimal")
    show_category("Timeout", "Timeout")
    show_category("Failed", "Failed")
    show_category("Skipped", "Skipped")


# Run the batch solver only when this file is executed as a script.
if __name__ == "__main__":
    main()
