In [2]:
"""
TSP Factor-Graph Message-Passing Solver
========================================
논문의 알고리즘을 식 번호 단위로 정확히 구현.

구조:
  1) Trellis Module   — Forward ψ,α (eq 15-16) / Backward β,ξ (eq 17-18)
  2) Assignment Module — λ_it (eq 21-27), ζ_it (eq 19-20/23)
  3) Bipartite Matching Module — γ̃,ω̃,φ̃,η̃,ρ̃,δ̃ (eq 37-42)

상태 표현:
  x_t = (m_t, a_t)
    m_t ∈ {0,1}^N : 방문 마스크 (N-bit)
    a_t ∈ {0,...,N-1} : 현재 방문 노드

시간: t = 1,...,T (T = N), depot에서 출발/귀환.
"""

import numpy as np

NEG = -1e12  # numerical -∞


class TSPFactorGraphSolver:

    def __init__(self, D, start_city=0,
                 damping=0.3, iters=200, verbose=False,
                 seed=0, patience=20, cost_tol=1e-12):
        """
        Parameters
        ----------
        D : (C x C) 거리 행렬. C = N+1 (N개 도시 + 1개 depot).
        start_city : depot으로 사용할 노드 인덱스.
        damping : γ̃, ω̃ 업데이트 시 damping factor (0~1).
        iters : 최대 반복 횟수.
        verbose : 매 반복 출력 여부.
        patience : cost 변화 없으면 조기 종료까지의 횟수.
        cost_tol : cost 변화 판단 임계값.
        """
        D = np.array(D, dtype=float)
        assert D.shape[0] == D.shape[1], "D must be square"
        C = D.shape[0]
        N = C - 1

        # --- 내부 인덱싱: start_city를 마지막 인덱스(depot)로 이동 ---
        perm = np.arange(C)
        if start_city != C - 1:
            perm[start_city], perm[C - 1] = perm[C - 1], perm[start_city]
        inv_perm = np.empty(C, dtype=int)
        inv_perm[perm] = np.arange(C)

        self.orig_D = D
        self.D_perm = D[perm][:, perm]
        self.perm = perm
        self.inv_perm = inv_perm
        self.C = C
        self.N = N               # 도시 수 (depot 제외)
        self.depot = C - 1       # 내부 depot 인덱스

        # --- Similarity matrix (eq 5): s(u,v) = max(D) - D(u,v) ---
        mx = np.max(self.D_perm)
        self.s = mx - self.D_perm

        self.T = N                # 시간 단계 수
        self.M = 1 << N           # 마스크 상태 수

        self.damping = float(damping)
        self.iters = int(iters)
        self.verbose = bool(verbose)
        self.rng = np.random.default_rng(seed)
        self.patience = int(patience)
        self.cost_tol = float(cost_tol)

        # --- Bipartite Matching 메시지 초기화 (shape: [N, T]) ---
        # γ̃_it, ω̃_it 만 상태로 유지 (damping 적용 대상)
        self.gamma_t = np.zeros((N, N))   # γ̃_it
        self.omega_t = np.zeros((N, N))   # ω̃_it

    # =================================================================
    #                         Public Interface
    # =================================================================
    def run(self):
        """메시지 패싱 반복 실행. (best_route, best_cost) 반환."""
        best_route, best_cost = None, None
        stable = 0
        last_cost = None

        for it in range(self.iters):
            # ── (1) BM 메시지에서 φ̃, η̃, ρ̃ 도출 (eq 37, 40, 41) ──
            phi_t, eta_t, rho_t = self._derive_bm_messages()

            # ── (2) Forward pass: ψ_t, α_t (eq 15-16) ──
            psi, alpha, backptr = self._forward(rho_t)

            # ── (3) Backward pass: β_t, ξ_t (eq 17-18) ──
            beta, xi = self._backward(rho_t)

            # ── (4) δ̃_it 계산 (eq 42, via ζ from eq 23) ──
            delta_t = self._compute_delta(psi, beta, rho_t)

            # ── (5) γ̃, ω̃ 업데이트 with damping (eq 38-39) ──
            gamma_new = eta_t + delta_t    # eq (39): γ̃ = η̃ + δ̃
            omega_new = phi_t + delta_t    # eq (38): ω̃ = φ̃ + δ̃
            self.gamma_t = self.damping * gamma_new + (1 - self.damping) * self.gamma_t
            self.omega_t = self.damping * omega_new + (1 - self.damping) * self.omega_t

            # ── (6) Decode & cost ──
            route = self._decode(alpha, backptr)
            cost = self._route_cost(route)

            if self.verbose:
                print(f"[{it+1:03d}] cost={cost:.6f}  route={route}")

            if best_cost is None or cost < best_cost:
                best_cost, best_route = cost, route

            # Early stopping
            if last_cost is not None and abs(cost - last_cost) <= self.cost_tol:
                stable += 1
            else:
                stable = 0
            last_cost = cost
            if stable >= self.patience:
                break

        return best_route, best_cost

    # =================================================================
    #        Bipartite Matching: φ̃, η̃, ρ̃ 도출 (eq 37, 40, 41)
    # =================================================================
    def _derive_bm_messages(self):
        """
        현재 γ̃, ω̃로부터:
          φ̃_it = -max_{i'≠i} γ̃_{i't}   ... (40)
          η̃_it = -max_{t'≠t} ω̃_{it'}   ... (37)
          ρ̃_it = η̃_it + φ̃_it           ... (41)
        """
        N, T = self.N, self.T
        phi_t = np.zeros((N, T))
        eta_t = np.zeros((N, T))

        # eq (40): φ̃_it = -max_{i'≠i} γ̃_{i't}
        for t in range(T):
            col = self.gamma_t[:, t]
            for i in range(N):
                phi_t[i, t] = -np.max(np.delete(col, i)) if N > 1 else 0.0

        # eq (37): η̃_it = -max_{t'≠t} ω̃_{it'}
        for i in range(N):
            row = self.omega_t[i, :]
            for t in range(T):
                eta_t[i, t] = -np.max(np.delete(row, t)) if T > 1 else 0.0

        # eq (41): ρ̃_it = η̃_it + φ̃_it
        rho_t = eta_t + phi_t

        return phi_t, eta_t, rho_t

    # =================================================================
    #             Assignment Module: λ 관련 헬퍼 (eq 24-27)
    # =================================================================
    def _lambda_sum(self, rho_t, t_idx, a):
        """
        Σ_i λ_it(a_t=a).

        Row-gauge-fix (eq 24-27)에 의해:
          λ_it(a) = (N-1)/N · ρ̃_it  if a=i
                   -1/N    · ρ̃_it  if a≠i

        합산하면:
          Σ_i λ_it(a) = ρ̃_{a,t} - (1/N) Σ_j ρ̃_{j,t}
        """
        return rho_t[a, t_idx] - np.mean(rho_t[:, t_idx])

    def _lambda_it_val(self, rho_t, i, t_idx, a):
        """
        λ_it(x_t) where a_t = a.  (eq 27)
          = (N-1)/N · ρ̃_it   if a = i
          = -1/N    · ρ̃_it   if a ≠ i
        """
        N = self.N
        if a == i:
            return (N - 1.0) / N * rho_t[i, t_idx]
        else:
            return -1.0 / N * rho_t[i, t_idx]

    def _lambda_sum_excl_i(self, rho_t, i, t_idx, a):
        """
        Σ_{i'≠i} λ_{i't}(a_t=a)  =  Σ_i λ_it(a) - λ_it(a).

        eq (19): ζ_it = ψ_t + β_t + Σ_{i'≠i} λ_{i't}  에서 사용.
        """
        return self._lambda_sum(rho_t, t_idx, a) - self._lambda_it_val(rho_t, i, t_idx, a)

    # =================================================================
    #             Trellis Forward Pass: ψ_t, α_t (eq 15-16)
    # =================================================================
    def _forward(self, rho_t):
        """
        eq (15): ψ_t(x_t) = max_{x_{t-1}} [ G_t(x_{t-1}, x_t) + α_{t-1}(x_{t-1}) ]
        eq (16): α_t(x_t) = ψ_t(x_t) + Σ_i λ_it(x_t)

        반환: psi[t, mask, a], alpha[t, mask, a], backptr[t, mask, a]
        """
        T, N, M = self.T, self.N, self.M
        depot = self.depot

        psi   = np.full((T + 1, M, N), NEG)
        alpha = np.full((T + 1, M, N), NEG)
        backptr = np.full((T + 1, M, N, 2), -1, dtype=int)

        # ── t = 1: depot → 첫 번째 도시 ──
        # α_0(x_0) = 0  (depot 단일 상태)
        # G_1(x_0, x_1) = s(depot, a_1)
        # ψ_1({a}, a) = s(depot, a) + 0 = s(depot, a)
        # α_1({a}, a) = ψ_1 + Σ_i λ_{i,1}(a)
        for a in range(N):
            m = 1 << a
            psi[1, m, a] = self.s[depot, a]
            alpha[1, m, a] = psi[1, m, a] + self._lambda_sum(rho_t, 0, a)
            backptr[1, m, a] = (0, -1)  # sentinel

        # ── t = 2 .. T ──
        for t in range(2, T + 1):
            t_idx = t - 1   # ρ̃의 0-based 시간 인덱스
            for mask in range(M):
                if bin(mask).count('1') != t:
                    continue
                for a in range(N):
                    if not (mask & (1 << a)):
                        continue  # a ∈ mask 이어야 유효
                    prev_mask = mask ^ (1 << a)

                    best = NEG
                    best_last = -1
                    # prev_mask 안의 모든 last에 대해 max
                    m = prev_mask
                    while m:
                        last = (m & -m).bit_length() - 1
                        m ^= (1 << last)
                        # eq (15): G_t + α_{t-1}
                        cand = alpha[t - 1, prev_mask, last] + self.s[last, a]
                        if cand > best:
                            best = cand
                            best_last = last

                    if best > NEG / 2:
                        psi[t, mask, a] = best                                    # eq (15)
                        alpha[t, mask, a] = best + self._lambda_sum(rho_t, t_idx, a)  # eq (16)
                        backptr[t, mask, a] = (prev_mask, best_last)

        return psi, alpha, backptr

    # =================================================================
    #            Trellis Backward Pass: β_t, ξ_t (eq 17-18)
    # =================================================================
    def _backward(self, rho_t):
        """
        eq (18): β_t(x_t) = max_{x_{t+1}} [ G_{t+1}(x_t, x_{t+1}) + ξ_{t+1}(x_{t+1}) ]
        eq (17): ξ_t(x_t) = β_t(x_t) + Σ_i λ_it(x_t)

        경계: t = T → β_T(full, a) = s(a, depot)  (depot 귀환)

        반환: beta[t, mask, a], xi[t, mask, a]
        """
        T, N, M = self.T, self.N, self.M
        full = (1 << N) - 1
        depot = self.depot

        beta = np.full((T + 1, M, N), NEG)
        xi   = np.full((T + 1, M, N), NEG)

        # ── t = T: closure (마지막 도시 → depot) ──
        for a in range(N):
            beta[T, full, a] = self.s[a, depot]
            xi[T, full, a] = beta[T, full, a] + self._lambda_sum(rho_t, T - 1, a)

        # ── t = T-1 .. 1 ──
        for t in range(T - 1, 0, -1):
            t_idx = t - 1   # ρ̃의 0-based 시간 인덱스
            for mask in range(M):
                if bin(mask).count('1') != t:
                    continue
                for last in range(N):
                    if not (mask & (1 << last)):
                        continue  # last ∈ mask

                    best = NEG
                    avail = (~mask) & full
                    m = avail
                    while m:
                        a = (m & -m).bit_length() - 1
                        m ^= (1 << a)
                        new_mask = mask | (1 << a)
                        # eq (18): G_{t+1} + ξ_{t+1}
                        cand = self.s[last, a] + xi[t + 1, new_mask, a]
                        if cand > best:
                            best = cand

                    if best > NEG / 2:
                        beta[t, mask, last] = best                                      # eq (18)
                        xi[t, mask, last] = best + self._lambda_sum(rho_t, t_idx, last)  # eq (17)

        return beta, xi

    # =================================================================
    #                   δ̃_it 계산 (eq 42, via ζ eq 23)
    # =================================================================
    def _compute_delta(self, psi, beta, rho_t):
        """
        eq (42):
          δ̃_it = max_{m_t} ζ_it(m_t, i)  -  max_{m_t, a_t≠i} ζ_it(m_t, a_t)

        eq (23):
          ζ_it(x_t) = ψ_t(x_t) + β_t(x_t) + Σ_{i'≠i} λ_{i't}(a_t)

        ψ + β = Γ (total path metric without λ at time t).
        Σ_{i'≠i} λ 는 a_t에만 의존 (m_t에 무관).
        """
        T, N, M = self.T, self.N, self.M
        delta = np.zeros((N, T))

        for t in range(1, T + 1):
            t_idx = t - 1

            for i in range(N):
                best_with = NEG      # max over (mask, a_t=i)
                best_without = NEG   # max over (mask, a_t≠i)

                for mask in range(M):
                    if bin(mask).count('1') != t:
                        continue

                    # ── a_t = i (bit=1 case) ──
                    if mask & (1 << i):
                        gamma_v = psi[t, mask, i] + beta[t, mask, i]
                        if gamma_v > NEG / 2:
                            z = gamma_v + self._lambda_sum_excl_i(rho_t, i, t_idx, i)
                            if z > best_with:
                                best_with = z

                    # ── a_t ≠ i (bit=0 case) ──
                    m2 = mask
                    while m2:
                        a = (m2 & -m2).bit_length() - 1
                        m2 ^= (1 << a)
                        if a == i:
                            continue
                        gamma_v = psi[t, mask, a] + beta[t, mask, a]
                        if gamma_v > NEG / 2:
                            z = gamma_v + self._lambda_sum_excl_i(rho_t, i, t_idx, a)
                            if z > best_without:
                                best_without = z

                if best_with <= NEG / 2 and best_without <= NEG / 2:
                    delta[i, t_idx] = 0.0
                else:
                    delta[i, t_idx] = best_with - best_without

        return delta

    # =================================================================
    #                         Route Decoding
    # =================================================================
    def _decode(self, alpha, backptr):
        """
        α_T(full, a) + s(a, depot)를 최대화하는 경로를 backtrack.
        """
        T, N = self.T, self.N
        full = (1 << N) - 1
        depot = self.depot

        best_val = NEG
        best_last = -1
        for a in range(N):
            val = alpha[T, full, a] + self.s[a, depot]
            if val > best_val:
                best_val = val
                best_last = a

        if best_last < 0:
            # Fallback: greedy nearest-neighbor
            route = [depot]
            used = set()
            for _ in range(N):
                scores = self.s[route[-1], :N].copy()
                for u in used:
                    scores[u] = NEG
                a = int(np.argmax(scores))
                used.add(a)
                route.append(a)
            route.append(depot)
        else:
            route_inner = []
            mask = full
            last = best_last
            t = T
            while t > 0 and last >= 0:
                route_inner.append(last)
                pm, pl = backptr[t, mask, last]
                mask, last = pm, pl
                t -= 1
            route_inner.reverse()
            route = [depot] + route_inner + [depot]

        # 내부 인덱스 → 원래 인덱스 변환
        return [int(self.inv_perm[c]) for c in route]

    def _route_cost(self, route):
        """원래 거리 행렬 기준 경로 비용."""
        return float(sum(self.orig_D[route[k], route[k + 1]]
                         for k in range(len(route) - 1)))

In [34]:
"""
TSP Factor-Graph Message-Passing Solver — GPU-Accelerated (PyTorch)
====================================================================
핵심 가속 전략:
  1) 마스크별 순차 루프 제거 → 동일 popcount 마스크를 텐서 배치로 처리
  2) Forward/Backward에서 per-city 루프 + 배치 gather/scatter
  3) δ̃ 계산: peak 팩토리제이션으로 O(N²T·M) → O(N²T + N·T·M) 축소

메모리 프로파일 (float64 기준):
  psi, alpha, beta, xi : 각 [T+1, M, N] → N=15: ~61MB, N=20: ~320MB
  중간 텐서 (per step)  : [M, N]         → N=15: ~4MB,  N=20: ~160MB
  총 피크: N=15 ~300MB, N=20 ~2GB

권장 N 범위: ≤ 20 (GPU 24GB 기준), ≤ 15 (GPU 8GB 기준)
"""

import torch
import numpy as np
import gc
from typing import Optional, Tuple, List

NEG = -1e12


def get_device(device: Optional[str] = None) -> torch.device:
    """CUDA > MPS > CPU 자동 선택."""
    if device is not None:
        return torch.device(device)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def flush_gpu(device: Optional[torch.device] = None):
    """GPU 메모리 강제 정리. solver 생성 전/후에 호출 가능."""
    gc.collect()
    if device is not None and device.type == "cpu":
        return
    if torch.cuda.is_available():
        torch.cuda.synchronize()          # 미완료 커널 대기
        gc.collect()                       # Python 참조 해제
        torch.cuda.empty_cache()           # PyTorch 캐시 → OS 반환
        torch.cuda.ipc_collect()           # IPC 공유 메모리 정리
        torch.cuda.reset_peak_memory_stats()
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        gc.collect()
        if hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()        # PyTorch 2.1+


class TSPFactorGraphSolverGPU:

    def __init__(self, D, start_city: int = 0,
                 damping: float = 0.3, iters: int = 200,
                 verbose: bool = False, seed: int = 0,
                 patience: int = 20, cost_tol: float = 1e-12,
                 device: Optional[str] = None):
        """
        Parameters
        ----------
        D : (C x C) 거리 행렬. C = N+1 (N개 도시 + 1개 depot).
        device : 'cuda', 'mps', 'cpu' 또는 None (자동).
        """
        self.device = get_device(device)
        self.dtype = torch.float32
        torch.manual_seed(seed)

        # --- 기존 GPU 텐서 정리 ---
        flush_gpu(self.device)

        D_np = np.array(D, dtype=np.float32)
        assert D_np.shape[0] == D_np.shape[1], "D must be square"
        C = D_np.shape[0]
        N = C - 1

        # --- 내부 인덱싱: start_city → 마지막 인덱스(depot) ---
        perm = np.arange(C)
        if start_city != C - 1:
            perm[start_city], perm[C - 1] = perm[C - 1], perm[start_city]
        inv_perm = np.empty(C, dtype=int)
        inv_perm[perm] = np.arange(C)

        self.orig_D = torch.tensor(D_np, dtype=self.dtype, device=self.device)
        D_perm = D_np[perm][:, perm]
        self.D_perm = torch.tensor(D_perm, dtype=self.dtype, device=self.device)
        self.inv_perm = inv_perm
        self.C, self.N, self.depot = C, N, C - 1

        # Similarity (eq 5)
        self.s = self.D_perm.max() - self.D_perm
        # 도시 간 유사도 서브매트릭스 [N, N]
        self.s_city = self.s[:N, :N]

        self.T = N
        self.M = 1 << N
        self.damping = damping
        self.iters = iters
        self.verbose = verbose
        self.patience = patience
        self.cost_tol = cost_tol

        # BM 메시지
        self.gamma_t = torch.zeros((N, N), dtype=self.dtype, device=self.device)
        self.omega_t = torch.zeros((N, N), dtype=self.dtype, device=self.device)

        # --- 마스크 테이블 사전 계산 ---
        self._precompute_tables()

        if self.verbose:
            print(f"[Device: {self.device}] N={N}, M={self.M}, "
                  f"peak mem ~{self._estimate_mem_mb():.0f} MB")

    # =================================================================
    #                    마스크 테이블 사전 계산
    # =================================================================
    def _precompute_tables(self):
        N, M = self.N, self.M
        dev, dt = self.device, self.dtype

        mask_range = torch.arange(M, device=dev, dtype=torch.long)

        # popcount[m] = number of set bits in m
        pc = torch.zeros(M, dtype=torch.int32, device=dev)
        for i in range(N):
            pc += ((mask_range >> i) & 1).int()
        self._popcount = pc

        # bit_set[m, a] = bool: bit a is set in mask m
        bit_vals = (1 << torch.arange(N, device=dev, dtype=torch.long))
        self._bit_set = (mask_range.unsqueeze(1) & bit_vals.unsqueeze(0)) != 0  # [M, N]

        # prev_mask[m, a] = m ^ (1<<a)  (remove bit a)
        self._prev_mask = mask_range.unsqueeze(1) ^ bit_vals.unsqueeze(0)  # [M, N]

        # next_mask[m, a] = m | (1<<a)  (add bit a)
        self._next_mask = mask_range.unsqueeze(1) | bit_vals.unsqueeze(0)  # [M, N]

        # NEG 텐서 캐시
        self._NEG_M = torch.full((M,), NEG, dtype=dt, device=dev)
        self._NEG_MN = torch.full((M, N), NEG, dtype=dt, device=dev)
        self._NEG_1 = torch.tensor(NEG, dtype=dt, device=dev)
        self._ZERO = torch.tensor(0.0, dtype=dt, device=dev)
        self._NEG1_long = torch.tensor(-1, dtype=torch.long, device=dev)

    def _estimate_mem_mb(self):
        T, M, N = self.T, self.M, self.N
        main = 4 * (T + 1) * M * N * 8
        tables = 3 * M * N * 8
        return (main + tables) / (1024 ** 2)

    # =================================================================
    #                    GPU 메모리 관리
    # =================================================================
    def cleanup(self):
        """명시적 GPU 메모리 해제. 솔버 사용 후 호출 권장."""
        dev = self.device if hasattr(self, 'device') else None

        # 모든 텐서 속성 삭제
        tensor_attrs = [k for k, v in self.__dict__.items()
                        if isinstance(v, torch.Tensor)]
        for attr in tensor_attrs:
            delattr(self, attr)

        # 사전 계산 테이블 등 나머지 대형 속성
        for attr in ['_bit_set', '_prev_mask', '_next_mask', '_popcount',
                      '_NEG_M', '_NEG_MN', '_NEG_1', '_ZERO', '_NEG1_long']:
            if hasattr(self, attr):
                delattr(self, attr)

        flush_gpu(dev)

    def __del__(self):
        try:
            self.cleanup()
        except Exception:
            pass

    def _print_gpu_mem(self, tag: str = ""):
        """CUDA 메모리 사용량 출력 (디버깅용)."""
        if self.device.type == "cuda":
            alloc = torch.cuda.memory_allocated(self.device) / (1024 ** 2)
            reserved = torch.cuda.memory_reserved(self.device) / (1024 ** 2)
            print(f"  [GPU mem {tag}] alloc={alloc:.1f}MB, reserved={reserved:.1f}MB")

    # =================================================================
    #                         Public Interface
    # =================================================================
    def run(self) -> Tuple[List[int], float]:
        best_route, best_cost = None, None
        stable = 0
        last_cost = None

        for it in range(self.iters):
            phi_t, eta_t, rho_t = self._derive_bm_messages()
            psi, alpha, backptr = self._forward(rho_t)
            beta, xi = self._backward(rho_t)
            delta_t = self._compute_delta(psi, beta, rho_t)

            # γ̃, ω̃ damping 업데이트
            gamma_new = eta_t + delta_t
            omega_new = phi_t + delta_t
            self.gamma_t = self.damping * gamma_new + (1 - self.damping) * self.gamma_t
            self.omega_t = self.damping * omega_new + (1 - self.damping) * self.omega_t
            print(self.gamma_t[0])
            route = self._decode(alpha, backptr)
            cost = self._route_cost(route)

            # ── 중간 텐서 즉시 해제 ──
            del psi, beta, xi, phi_t, eta_t, rho_t, delta_t
            del gamma_new, omega_new

            if self.verbose:
                print(f"[{it+1:03d}] cost={cost:.6f}  route={route}")

            if best_cost is None or cost < best_cost:
                best_cost, best_route = cost, route

            # decode에서 사용 완료
            del alpha, backptr

            if last_cost is not None and abs(cost - last_cost) <= self.cost_tol:
                stable += 1
            else:
                stable = 0
            last_cost = cost
            if stable >= self.patience:
                break

        return best_route, best_cost

    # =================================================================
    #        Bipartite Matching: φ̃, η̃, ρ̃ (eq 37, 40, 41)
    # =================================================================
    def _derive_bm_messages(self):
        """
        벡터화된 top-2 기반 "자기 제외 max" 계산.
        """
        N, T = self.N, self.T
        dev, dt = self.device, self.dtype

        # --- eq (40): φ̃_it = -max_{i'≠i} γ̃_{i't} ---
        phi_t = torch.zeros((N, T), dtype=dt, device=dev)
        if N > 1:
            top2_v, top2_i = self.gamma_t.topk(2, dim=0)  # [2, T]
            for i in range(N):
                is_top = (top2_i[0] == i)
                phi_t[i] = -torch.where(is_top, top2_v[1], top2_v[0])

        # --- eq (37): η̃_it = -max_{t'≠t} ω̃_{it'} ---
        eta_t = torch.zeros((N, T), dtype=dt, device=dev)
        if T > 1:
            top2_v, top2_i = self.omega_t.topk(2, dim=1)  # [N, 2]
            for t in range(T):
                is_top = (top2_i[:, 0] == t)
                eta_t[:, t] = -torch.where(is_top, top2_v[:, 1], top2_v[:, 0])

        rho_t = eta_t + phi_t
        return phi_t, eta_t, rho_t

    # =================================================================
    #         Forward Pass (벡터화): ψ_t, α_t (eq 15-16)
    # =================================================================
    def _forward(self, rho_t):
        """
        per-city 루프 + 배치 gather로 마스크 루프 제거.
        메모리: O(M·N) per time step.
        """
        T, N, M = self.T, self.N, self.M
        dev, dt = self.device, self.dtype
        depot = self.depot

        psi     = torch.full((T + 1, M, N), NEG, dtype=dt, device=dev)
        alpha   = torch.full((T + 1, M, N), NEG, dtype=dt, device=dev)
        backptr = torch.full((T + 1, M, N), -1, dtype=torch.long, device=dev)

        # λ_sum 전체 사전 계산: [N, T]
        lambda_sum_all = rho_t - rho_t.mean(dim=0, keepdim=True)

        # ── t = 1: depot → 첫 도시 ──
        for a in range(N):
            m = 1 << a
            s_val = self.s[depot, a]
            psi[1, m, a] = s_val
            alpha[1, m, a] = s_val + lambda_sum_all[a, 0]

        # ── t = 2 .. T: 배치 처리 ──
        for t in range(2, T + 1):
            t_idx = t - 1
            valid_pop = (self._popcount == t)  # [M]

            for a in range(N):
                # valid[m] = popcount(m)==t AND bit a set in m
                valid = valid_pop & self._bit_set[:, a]  # [M]

                # prev_mask[m] = m ^ (1<<a)
                prev = self._prev_mask[:, a]  # [M]

                # gather alpha[t-1, prev[m], :] → [M, N]
                alpha_gathered = alpha[t - 1][prev]  # [M, N]

                # scores[m, last] = alpha[t-1, prev, last] + s(last, a)
                scores = alpha_gathered + self.s_city[:, a].unsqueeze(0)  # [M, N]

                # last는 prev_mask에 bit가 있어야 유효
                valid_last = self._bit_set[prev]  # [M, N]
                scores = torch.where(valid_last, scores, self._NEG_MN)

                # max over last
                psi_val, argmax_last = scores.max(dim=1)  # [M]

                psi[t, :, a] = torch.where(valid, psi_val, self._NEG_M)
                alpha[t, :, a] = torch.where(
                    valid,
                    psi_val + lambda_sum_all[a, t_idx],
                    self._NEG_M
                )
                backptr[t, :, a] = torch.where(
                    valid, argmax_last, self._NEG1_long.expand(M)
                )

        return psi, alpha, backptr

    # =================================================================
    #         Backward Pass (벡터화): β_t, ξ_t (eq 17-18)
    # =================================================================
    def _backward(self, rho_t):
        """
        per-city gather + 배치 scatter.
        """
        T, N, M = self.T, self.N, self.M
        dev, dt = self.device, self.dtype
        full = (1 << N) - 1
        depot = self.depot

        beta = torch.full((T + 1, M, N), NEG, dtype=dt, device=dev)
        xi   = torch.full((T + 1, M, N), NEG, dtype=dt, device=dev)

        lambda_sum_all = rho_t - rho_t.mean(dim=0, keepdim=True)

        # ── t = T: closure ──
        for a in range(N):
            beta[T, full, a] = self.s[a, depot]
            xi[T, full, a] = self.s[a, depot] + lambda_sum_all[a, T - 1]

        # ── t = T-1 .. 1 ──
        for t in range(T - 1, 0, -1):
            t_idx = t - 1
            valid_pop = (self._popcount == t)  # [M]

            # xi_next_a[m, a] = xi[t+1, m|(1<<a), a]
            xi_next = xi[t + 1]  # [M, N]
            xi_next_a = torch.full((M, N), NEG, dtype=dt, device=dev)
            for a in range(N):
                next_m = self._next_mask[:, a]  # [M]
                xi_next_a[:, a] = xi_next[next_m, a]

            for last in range(N):
                valid = valid_pop & self._bit_set[:, last]  # [M]

                # scores[m, a] = s(last, a) + xi[t+1, m|(1<<a), a]
                scores = self.s_city[last, :].unsqueeze(0) + xi_next_a  # [M, N]

                # a는 현재 마스크에 없어야 함
                not_in_mask = ~self._bit_set  # [M, N]
                scores = torch.where(not_in_mask, scores, self._NEG_MN)

                # max over a
                beta_val, _ = scores.max(dim=1)  # [M]

                beta[t, :, last] = torch.where(valid, beta_val, self._NEG_M)
                xi[t, :, last] = torch.where(
                    valid,
                    beta_val + lambda_sum_all[last, t_idx],
                    self._NEG_M
                )

        return beta, xi

    # =================================================================
    #     δ̃ 계산 (최적화): peak 팩토리제이션 (eq 42)
    # =================================================================
    def _compute_delta(self, psi, beta, rho_t):
        """
        핵심 최적화:
          gamma_val = ψ + β  (경로 메트릭)
          peak[t, a] = max_{m: popcount=t, bit_a∈m} gamma_val[t, m, a]

        λ_sum_excl_i는 mask에 무관하므로 factoring 가능:
          best_with[i,t]    = excl_λ(i,t,i) + peak[t,i]
          best_without[i,t] = max_{a≠i} [excl_λ(i,t,a) + peak[t,a]]

        복잡도: O(N²·T) vs 원래 O(N²·T·M)
        """
        T, N, M = self.T, self.N, self.M
        dev, dt = self.device, self.dtype

        gamma_val = psi + beta  # [T+1, M, N]

        # --- peak[t, a] 계산 ---
        peak = torch.full((T + 1, N), NEG, dtype=dt, device=dev)
        for t in range(1, T + 1):
            valid = (self._popcount == t).unsqueeze(1) & self._bit_set  # [M, N]
            gv = gamma_val[t]  # [M, N]
            gv_masked = torch.where(valid & (gv > NEG / 2), gv, self._NEG_MN)
            peak[t] = gv_masked.max(dim=0).values  # [N]

        # --- lambda_sum_all ---
        lambda_sum_all = rho_t - rho_t.mean(dim=0, keepdim=True)  # [N, T]

        delta = torch.zeros((N, T), dtype=dt, device=dev)

        for t in range(1, T + 1):
            t_idx = t - 1
            pk = peak[t]                    # [N]
            lsa = lambda_sum_all[:, t_idx]  # [N]
            rho_col = rho_t[:, t_idx]       # [N]

            # --- best_with[i] ---
            # excl_λ(i, t, a=i) = λ_sum(t,i) - (N-1)/N · ρ[i,t]
            excl_w = lsa - (N - 1.0) / N * rho_col  # [N]
            best_with = excl_w + pk  # [N]
            best_with = torch.where(pk > NEG / 2, best_with,
                                    torch.full_like(best_with, NEG))

            # --- best_without[i] = max_{a≠i} [excl_λ(i,t,a) + peak[t,a]] ---
            # excl_λ(i, t, a≠i) = λ_sum(t,a) + 1/N · ρ[i,t]
            # score[i, a] = (lsa[a] + pk[a]) + 1/N · ρ[i,t]
            base_a = lsa + pk  # [N]
            rho_term = (1.0 / N) * rho_col  # [N]

            scores_wo = base_a.unsqueeze(0) + rho_term.unsqueeze(1)  # [N, N]
            scores_wo.fill_diagonal_(NEG)

            # peak가 NEG인 a 제외
            invalid_pk = (pk <= NEG / 2)
            scores_wo[:, invalid_pk] = NEG

            best_without = scores_wo.max(dim=1).values  # [N]

            both_neg = (best_with <= NEG / 2) & (best_without <= NEG / 2)
            delta[:, t_idx] = torch.where(both_neg, self._ZERO,
                                          best_with - best_without)

        return delta

    # =================================================================
    #                         Route Decoding
    # =================================================================
    def _decode(self, alpha, backptr):
        T, N = self.T, self.N
        full = (1 << N) - 1
        depot = self.depot

        final_scores = alpha[T, full, :] + self.s[:N, depot]
        best_last = final_scores.argmax().item()

        if final_scores[best_last].item() <= NEG / 2:
            return self._greedy_fallback()

        route_inner = []
        mask = full
        a = best_last
        for t in range(T, 0, -1):
            route_inner.append(a)
            prev_a = backptr[t, mask, a].item()
            mask = mask ^ (1 << a)
            a = prev_a
        route_inner.reverse()
        route = [depot] + route_inner + [depot]
        return [int(self.inv_perm[c]) for c in route]

    def _greedy_fallback(self):
        N, depot = self.N, self.depot
        route = [depot]
        used = set()
        for _ in range(N):
            scores = self.s[route[-1], :N].clone()
            for u in used:
                scores[u] = NEG
            a = scores.argmax().item()
            used.add(a)
            route.append(a)
        route.append(depot)
        return [int(self.inv_perm[c]) for c in route]

    def _route_cost(self, route) -> float:
        cost = 0.0
        for k in range(len(route) - 1):
            cost += self.orig_D[route[k], route[k + 1]].item()
        return cost

In [35]:
import numpy as np
from itertools import permutations

class TSPBruteForceSolver:
    """
    Brute-force TSP solver for small N.
    - Fixes start_city to remove rotational symmetry.
    - Returns a Hamiltonian cycle: [start, ..., start]
    """

    def __init__(self, D: np.ndarray, start_city: int = 0, verbose: bool = False):
        D = np.asarray(D, dtype=float)
        if D.ndim != 2 or D.shape[0] != D.shape[1]:
            raise ValueError("D must be a square 2D matrix.")
        if not (0 <= start_city < D.shape[0]):
            raise ValueError("start_city out of range.")
        self.D = D
        self.N = D.shape[0]
        self.start = start_city
        self.verbose = verbose

    def route_cost(self, route):
        """
        route: list/tuple of cities, must be a cycle [s, ..., s]
        """
        total = 0.0
        for a, b in zip(route[:-1], route[1:]):
            total += float(self.D[a, b])
        return total

    def run(self):
        cities = list(range(self.N))
        others = [c for c in cities if c != self.start]

        best_cost = float("inf")
        best_route = None

        # Evaluate all permutations of the remaining cities
        # Tour: start -> perm... -> start
        for idx, perm in enumerate(permutations(others), start=1):
            route = (self.start,) + perm + (self.start,)
            cost = 0.0
            # inline cost to reduce overhead a bit
            for a, b in zip(route[:-1], route[1:]):
                cost += self.D[a, b]

            if cost < best_cost:
                best_cost = float(cost)
                best_route = list(route)
                if self.verbose:
                    print(f"[BruteForce] New best @ {idx}: cost={best_cost:.6f}, route={best_route}")

        return best_route, best_cost

In [36]:
import time

if __name__ == "__main__":
    
    np.random.seed(42)
    N_CITIES = 12
    D = np.random.rand(N_CITIES, N_CITIES)
    np.fill_diagonal(D, 0)

    '''start1 = time.time()
    solver = TSPFactorGraphSolver(
        D, start_city=0,
        damping=0.3,
        iters=100,
        verbose=True,
        patience=20,
    )
    route_fg, cost_fg = solver.run()
    end1 = time.time()
    print(f"\nFinal Route: {route_fg}")
    print(f"Final Cost:  {cost_fg:.6f}")
    print(f"Total Time:  {(end1 - start1)*1000:.3f} ms\n")'''

    start2 = time.time()
    solver_GPU = TSPFactorGraphSolverGPU(
        D, start_city=0,
        damping=0.3,
        iters=100,
        verbose=True,
        patience=20,
    )
    route_GPU, cost_GPU = solver_GPU.run()
    end2 = time.time()
    
    print(f"\nFinal Route: {route_GPU}")
    print(f"Final Cost:  {cost_GPU:.6f}")
    print(f"Total Time:  {(end2 - start2)*1000:.3f} ms")

    # 메모리 진단
    if solver_GPU.device.type == "cuda":
        solver_GPU._print_gpu_mem("before cleanup")

    solver_GPU.cleanup()

    if torch.cuda.is_available():
        alloc = torch.cuda.memory_allocated() / (1024 ** 2)
        reserved = torch.cuda.memory_reserved() / (1024 ** 2)
        print(f"  [GPU mem after cleanup] alloc={alloc:.1f}MB, reserved={reserved:.1f}MB")
        print(f"  ※ reserved는 PyTorch 풀이 유지하는 메모리. "
              f"프로세스 종료 시 반환됨.")
    print("Cleanup done.")

    '''# 브루트포스 솔버 (정답)
    start3 = time.time()
    bf = TSPBruteForceSolver(D, start_city=0, verbose=False)
    route_bf, cost_bf = bf.run()
    end3 = time.time()
    print(f"\n[BruteForce] Optimal Route: {route_bf}")
    print(f"[BruteForce] Optimal Cost:  {cost_bf:.6f}")
    print(f"Total Time:  {(end3 - start3)*1000:.3f} ms")

    # 비교 지표
    gap = cost_fg - cost_bf
    rel_gap = gap / (abs(cost_bf) + 1e-12)
    print(f"\nGap (FG - BF):     {gap:.6f}")
    print(f"Relative gap:      {rel_gap*100:.3f}%")
    print(f"FG matches optimum? {abs(gap) < 1e-9}")'''

[Device: mps] N=11, M=2048, peak mem ~9 MB
tensor([-0.2678, -0.0016, -0.0502, -0.1002, -0.0227, -0.0147,  0.0016, -0.0560,
        -0.0836, -0.0522, -0.0023], device='mps:0')
[001] cost=1.356860  route=[0, 10, 3, 4, 8, 2, 5, 11, 9, 1, 7, 6, 0]
tensor([-0.4774, -0.0194, -0.1075, -0.1888, -0.0496, -0.0397,  0.0194, -0.1074,
        -0.1547, -0.1089, -0.0259], device='mps:0')
[002] cost=1.356860  route=[0, 10, 3, 4, 8, 2, 5, 11, 9, 1, 7, 6, 0]
tensor([-0.7837, -0.2186, -0.3142, -0.4069, -0.1912, -0.2227,  0.1700, -0.2587,
        -0.3412, -0.2851, -0.2413], device='mps:0')
[003] cost=1.356860  route=[0, 10, 3, 4, 8, 2, 5, 11, 9, 1, 7, 6, 0]
tensor([-1.4784, -0.8331, -0.9663, -0.9561, -1.1899, -1.1377,  0.7708, -0.9471,
        -0.9697, -0.9778, -0.8184], device='mps:0')
[004] cost=1.356860  route=[0, 10, 3, 4, 8, 2, 5, 11, 9, 1, 7, 6, 0]
tensor([-3.2056, -2.4787, -2.6381, -2.5736, -3.3609, -3.0519,  2.3544, -2.6702,
        -2.6562, -2.6610, -2.4220], device='mps:0')
[005] cost=1.356860  