In [None]:
# milestone2_best_hungarian.py
import os, cv2, time, csv, itertools
import numpy as np
from pathlib import Path

# Hungarian (optimal assignment) solver is optional: when SciPy is missing the
# assignment step falls back to a greedy matcher.
try:
    from scipy.optimize import linear_sum_assignment
    SCIPY_AVAILABLE = True
except ImportError:
    # Only a missing/broken SciPy install should trigger the fallback; a bare
    # `except:` would also swallow KeyboardInterrupt/SystemExit.
    SCIPY_AVAILABLE = False
    linear_sum_assignment = None

# ---------------- CONFIG ----------------
# Dataset root. This is a machine-specific absolute path; every other folder
# below is derived from it.
ORIGIN = r"C:\Term 5\image_project\OneDrive_2025-11-26\Jigsaw Puzzle Dataset"
# Folder listed by run_batch() for scrambled puzzle images.
SCRAMBLED_FOLDER = os.path.join(ORIGIN, "Gravity Falls", "puzzle_2x2")
# Folder searched by run_batch() for the ground-truth image with the same base name.
CORRECT_FOLDER   = os.path.join(ORIGIN, "Gravity Falls", "correct")
# Folder where "<base>_segmented.jpg" / "<base>_denoised.jpg" are looked up.
M1_RESULTS_DIR   = os.path.join(ORIGIN, "results", "results_2x2")
# Destination for assembled images and the summary CSV.
OUTPUT_DIR       = os.path.join(ORIGIN, "milestone2_outputs_tile_solver")
# NOTE: executed at import time, so importing this module creates the folder.
os.makedirs(OUTPUT_DIR, exist_ok=True)

N = 2                 # 2x2 puzzles (grid side length used throughout the module)
BORDER_WIDTH = 12     # edge strip width in pixels
# Relative weights of the cost terms. build_costs() z-normalizes each term
# before applying these, so they express relative importance only.
W_APP = 1.0   # appearance (descriptor) weight
W_SCR = 4.0   # scrambled adjacency weight
W_COR = 3.0   # correct adjacency weight
ROT_PEN = 0.06  # weight of the rotation term (rotation index in 90-degree steps)

# ----------------- HELPERS -----------------
def rotate_img(img, angle):
    """Rotate ``img`` clockwise by ``angle`` degrees, in quarter-turn steps.

    Args:
        img: NumPy image array.
        angle: Rotation in degrees; only the number of whole 90-degree steps
            (``angle // 90``) is used.

    Returns:
        ``img`` itself (no copy) when ``angle`` is a multiple of 360,
        otherwise a C-contiguous rotated array.
    """
    if angle % 360 != 0:
        quarter_turns = (angle // 90) % 4
        # np.rot90 turns counter-clockwise for positive k, hence the negation.
        return np.ascontiguousarray(np.rot90(img, -quarter_turns))
    return img

def split_into_tiles(img, N):
    """Cut ``img`` into an ``N`` x ``N`` grid of equally sized tiles.

    Tiles are produced in row-major order. Any remainder pixels on the right
    or bottom (when the image size is not divisible by ``N``) are dropped.

    Returns:
        ``(tiles, (tile_h, tile_w))`` where each tile is a dict with keys
        ``'id'`` (row-major index), ``'pos'`` (``(row, col)``) and ``'tile'``
        (an independent copy of the pixel block).
    """
    img_h, img_w = img.shape[:2]
    tile_h = img_h // N
    tile_w = img_w // N
    tiles = [
        {
            'id': row * N + col,
            'pos': (row, col),
            'tile': img[row * tile_h:(row + 1) * tile_h,
                        col * tile_w:(col + 1) * tile_w].copy(),
        }
        for row in range(N)
        for col in range(N)
    ]
    return tiles, (tile_h, tile_w)

# ----------------- DESCRIPTORS -----------------
def color_hist_descriptor(img, bins=16):
    """Return concatenated H, S and V histograms of ``img`` as one float32 vector.

    ``img`` is converted with ``cv2.COLOR_BGR2HSV``. Each channel gets
    ``bins`` bins (hue over [0, 180), saturation and value over [0, 256)).
    The concatenated vector is divided by its sum unless the sum is zero.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    channel_ranges = ((0, [0, 180]), (1, [0, 256]), (2, [0, 256]))
    histograms = [
        cv2.calcHist([hsv], [channel], None, [bins], value_range).flatten()
        for channel, value_range in channel_ranges
    ]
    vec = np.concatenate(histograms).astype(np.float32)
    total = vec.sum()
    if total > 0:
        return vec / total
    return vec

def hu_moments_descriptor(img):
    """Return the seven Hu moments of the blurred grayscale image, log-scaled.

    Each non-zero moment ``h`` is mapped to ``-sign(h) * log10(|h| + 1e-12)``;
    exact zeros are left as zero. The result is a float32 vector.
    """
    gray = cv2.GaussianBlur(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), (3, 3), 0)
    hu = cv2.HuMoments(cv2.moments(gray)).flatten()
    nonzero = hu != 0
    # Log-compress magnitudes while keeping sign information.
    hu[nonzero] = -np.sign(hu[nonzero]) * np.log10(np.abs(hu[nonzero]) + 1e-12)
    return hu.astype(np.float32)

def contour_fourier_descriptor(img, n_coeff=32):
    """Return Fourier magnitude coefficients of the longest Canny contour.

    The longest external contour is treated as a complex signal ``x + 1j*y``,
    linearly resampled to ``max(128, 2 * n_coeff)`` points and transformed
    with an FFT. The first ``n_coeff`` magnitudes are divided by the first
    magnitude (plus a small epsilon) and returned as float32.

    Returns an all-zero vector of length ``n_coeff`` when no contour is found
    or the longest contour has fewer than 6 points.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 60, 150)
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if not contours:
        return np.zeros(n_coeff, dtype=np.float32)

    longest = max(contours, key=lambda cnt: cnt.shape[0]).squeeze()
    too_short = longest.ndim != 2 or longest.shape[0] < 6
    if too_short:
        return np.zeros(n_coeff, dtype=np.float32)

    boundary = longest[:, 0] + 1j * longest[:, 1]
    n_samples = max(128, n_coeff * 2)
    src_t = np.linspace(0, 1, len(boundary))
    dst_t = np.linspace(0, 1, n_samples)
    # np.interp is real-valued, so resample real and imaginary parts separately.
    resampled = (np.interp(dst_t, src_t, boundary.real)
                 + 1j * np.interp(dst_t, src_t, boundary.imag))
    magnitudes = np.abs(np.fft.fft(resampled)[:n_coeff])
    return (magnitudes / (magnitudes[0] + 1e-12)).astype(np.float32)

def make_descriptor(img):
    """Build the appearance descriptor of ``img`` as one float32 vector.

    The vector is the concatenation, in this order, of the colour histogram,
    the Hu-moment descriptor and the 32-coefficient contour Fourier descriptor.
    """
    parts = (
        color_hist_descriptor(img),
        hu_moments_descriptor(img),
        contour_fourier_descriptor(img, n_coeff=32),
    )
    return np.concatenate(parts).astype(np.float32)

# ----------------- EDGE STRIP + NCC -----------------
def extract_edge_strip_rgb(img, edge_idx, width=BORDER_WIDTH):
    """Return a normalized per-channel intensity profile along one tile edge.

    Args:
        img: Image array; assumed to be 3-channel (the normalization loop
            touches exactly three channels).
        edge_idx: 0 = top, 1 = right, 2 = bottom, 3 = left.
        width: Thickness in pixels of the border band that is averaged.

    Returns:
        Array of shape ``(3, L)`` where ``L`` is the edge length. The band is
        averaged across its thickness, weighted by a triangular window that
        peaks at the middle of the edge, then each channel is shifted to zero
        mean and divided by its standard deviation.

    Raises:
        ValueError: if ``edge_idx`` is not one of 0..3.
    """
    n_rows, n_cols = img.shape[:2]
    pixels = img.astype(np.float32)

    if edge_idx == 0:
        profile = pixels[0:width, :, :].mean(axis=0)
    elif edge_idx == 1:
        profile = pixels[:, n_cols - width:n_cols, :].mean(axis=1)
    elif edge_idx == 2:
        band = pixels[n_rows - width:n_rows, :, :]
        profile = band[::-1, :, :].mean(axis=0)
    elif edge_idx == 3:
        band = pixels[:, 0:width, :]
        profile = band[:, ::-1, :].mean(axis=1)
    else:
        raise ValueError("edge idx must be 0..3")

    strip = profile.T.astype(np.float32)  # channels first: (3, L)
    length = strip.shape[1]
    if length > 1:
        # Triangular window: 0 at both ends of the edge, 1 in the middle.
        window = 1.0 - np.abs(2 * np.linspace(0, 1, length) - 1.0)
        window = window / (window.max() + 1e-12)
        strip = strip * window[np.newaxis, :]

    for channel in range(3):
        row = strip[channel]
        strip[channel] = (row - row.mean()) / (row.std() + 1e-12)
    return strip

def rgb_ncc(stripA, stripB, L_target=128):
    """Mean normalized cross-correlation of two edge strips over three channels.

    Both strips (channels-first 2-D arrays) are linearly resampled to
    ``L_target`` samples when their length differs from it. The NCC is then
    computed for channels 0, 1 and 2 separately and averaged.

    Returns:
        A Python float; a small epsilon in the denominator makes constant
        channels score 0 instead of dividing by zero.
    """
    def _to_length(strip, length):
        # Returns a copy when no resampling is needed so callers never alias input.
        current = strip.shape[1]
        if current == length:
            return strip.copy()
        src = np.linspace(0, 1, current)
        dst = np.linspace(0, 1, length)
        resampled = np.zeros((strip.shape[0], length), dtype=np.float32)
        for row_idx, row in enumerate(strip):
            resampled[row_idx] = np.interp(dst, src, row)
        return resampled

    def _channel_ncc(u, v):
        du = u - u.mean()
        dv = v - v.mean()
        denom = (np.sqrt((du * du).sum()) * np.sqrt((dv * dv).sum())) + 1e-12
        return float((du * dv).sum() / denom)

    a = _to_length(stripA, L_target)
    b = _to_length(stripB, L_target)
    per_channel = [_channel_ncc(a[ch], b[ch]) for ch in range(3)]
    return float(np.mean(per_channel))

# ----------------- SCRAMBLED↔SCRAMBLED ADJACENCY -----------------
def compute_scrambled_adjacency(scrambled_tiles, border_width=BORDER_WIDTH):
    """Score how well every ordered pair of scrambled tiles fits edge to edge.

    Args:
        scrambled_tiles: Tile dicts with ``'id'`` and ``'tile'`` keys. The ids
            are also used as array indices, so they are expected to be
            ``0..len(scrambled_tiles) - 1``.
        border_width: Edge band thickness passed to the strip extractor.

    Returns:
        float32 array ``adj`` of shape ``(NT, NT, 4)``. ``adj[i, j, e]`` is the
        best NCC between edge ``e`` of tile ``i`` (unrotated) and the opposite
        edge of tile ``j`` over all four rotations of ``j``. Entries with
        ``i == j`` stay 0.
    """
    strips = {}
    for tile in scrambled_tiles:
        for quarter in range(4):
            rotated = rotate_img(tile['tile'], quarter * 90)
            strips.update({
                (tile['id'], quarter, edge):
                    extract_edge_strip_rgb(rotated, edge, width=border_width)
                for edge in range(4)
            })

    count = len(scrambled_tiles)
    adj = np.zeros((count, count, 4), dtype=np.float32)
    facing = {0: 2, 1: 3, 2: 0, 3: 1}
    for i, j in itertools.permutations(range(count), 2):
        for edge, other_edge in facing.items():
            own_strip = strips[(i, 0, edge)]
            scores = (rgb_ncc(own_strip, strips[(j, quarter, other_edge)])
                      for quarter in range(4))
            # Seeding with -2.0 (below any NCC) mirrors a strict ">" running max.
            adj[i, j, edge] = max(itertools.chain((-2.0,), scores))
    return adj

# ----------------- CORRECT-IMAGE EDGE STRIPS -----------------
def compute_correct_edge_strips(correct_tiles, border_width=BORDER_WIDTH):
    """Extract all four edge strips of every ground-truth tile.

    Returns:
        Dict mapping ``(tile_index, edge)`` to the edge strip, where
        ``tile_index`` is the tile's position in ``correct_tiles`` and
        ``edge`` runs over 0..3.
    """
    return {
        (index, edge): extract_edge_strip_rgb(tile['tile'], edge, width=border_width)
        for index, tile in enumerate(correct_tiles)
        for edge in range(4)
    }

# ----------------- BUILD COST MATRICES & NORMALIZE -----------------
def build_costs(scrambled_tiles, correct_tiles, adj_scrambled,
                border_width=BORDER_WIDTH):
    """Build the tile-to-position cost matrix used by the assignment step.

    Four raw terms are computed for every (scrambled tile, grid position)
    pair, each is z-normalized over the whole matrix, and the weighted sum
    (W_APP, W_SCR, W_COR, ROT_PEN) is returned.

    Args:
        scrambled_tiles: Tile dicts (``'id'``, ``'tile'``); ids are used as
            row indices of the matrices, so they must be ``0..NT-1``.
        correct_tiles: Ground-truth tile dicts, indexed by position index
            (row-major, ``y * N + x``).
        adj_scrambled: Array indexed ``[tile, <second index>, edge]``; see the
            review note at its use below.
        border_width: Edge band thickness for strip extraction.

    Returns:
        ``(cost, best_rot_choice)`` where ``cost`` is an ``NT x NT`` array
        (rows = tile ids, columns = position indices, lower is better) and
        ``best_rot_choice`` maps ``(tile_id, position_index)`` to the rotation
        (0..3, in 90-degree steps) with the smallest descriptor distance.
    """
    NT = len(scrambled_tiles)
    # Position index pi corresponds to (x, y) with pi == y*N + x (row-major).
    positions = [(x,y) for y in range(N) for x in range(N)]

    # descriptors precompute
    correct_desc = [make_descriptor(t['tile']) for t in correct_tiles]

    # scrambled descriptors and edge strips for all four rotations
    scr_desc = {}
    scr_strip = {}
    for t in scrambled_tiles:
        tid = t['id']
        for r in range(4):
            rimg = rotate_img(t['tile'], r*90)
            scr_desc[(tid,r)] = make_descriptor(rimg)
            for e in range(4):
                scr_strip[(tid,r,e)] = extract_edge_strip_rgb(rimg, e, width=border_width)

    # precompute correct strips
    corr_strip = compute_correct_edge_strips(correct_tiles, border_width)

    # build raw maps (appearance, scrambled-adj, correct-adj, rotation penalty)
    app_raw = np.zeros((NT, NT), dtype=np.float32)   # tid x pos
    scr_raw = np.zeros((NT, NT), dtype=np.float32)   # negative NCC sums -> lower cost when high NCC
    cor_raw = np.zeros((NT, NT), dtype=np.float32)
    rot_raw = np.zeros((NT, NT), dtype=np.float32)
    best_rot_for_app = {}

    # Edge index -> (dx, dy) of the neighbouring cell: 0=top, 1=right, 2=bottom, 3=left.
    offsets = {0:(0,-1), 1:(1,0), 2:(0,1), 3:(-1,0)}
    opposite = {0:2,1:3,2:0,3:1}

    for t in scrambled_tiles:
        tid = t['id']
        for pi, pos in enumerate(positions):
            # appearance: choose best rotation minimizing descriptor distance
            best_d = 1e9; best_r = 0
            for r in range(4):
                d = np.linalg.norm(scr_desc[(tid,r)] - correct_desc[pi])
                if d < best_d:
                    best_d = d; best_r = r
            app_raw[tid, pi] = best_d
            best_rot_for_app[(tid, pi)] = best_r
            # NOTE(review): the penalty is linear in the rotation index, so 270
            # degrees costs three times 90 degrees — confirm this is intended.
            rot_raw[tid, pi] = best_r  # will be scaled later

            # correct adjacency: with the tile at rotation best_r, compare each
            # of its edges against the facing edge of the ground-truth tile at
            # the neighbouring position (out-of-grid neighbours are skipped)
            gx, gy = pos
            cor_sum = 0.0
            for ei in range(4):
                nx, ny = gx + offsets[ei][0], gy + offsets[ei][1]
                if not (0 <= nx < N and 0 <= ny < N):
                    continue
                neighbor_idx = ny*N + nx
                corr_edge = corr_strip[(neighbor_idx, opposite[ei])]
                scr_edge = scr_strip[(tid, best_r, ei)]
                cor_sum += rgb_ncc(scr_edge, corr_edge)
            cor_raw[tid, pi] = -cor_sum  # negative because high NCC should lower cost

            # scrambled adjacency term
            # NOTE(review): compute_scrambled_adjacency fills adj[i, j, e] with
            # both i and j being scrambled tile ids, but neighbor_pos below is
            # a grid position index. The lookup therefore scores "tile tid next
            # to the tile whose id equals that position index", which is only
            # the tile actually placed there if tiles sit at their own index —
            # confirm whether this is intended.
            scr_sum = 0.0
            for ei in range(4):
                nx, ny = gx + offsets[ei][0], gy + offsets[ei][1]
                if not (0 <= nx < N and 0 <= ny < N):
                    continue
                neighbor_pos = ny*N + nx
                # use scr_adj[tid, neighbor_pos, edge] — it's the best match of neighbor (any rotation)
                scr_sum += adj_scrambled[tid, neighbor_pos, ei]
            scr_raw[tid, pi] = -scr_sum

    # Normalize each raw matrix to zero-mean unit-std to balance scales
    def znorm(mat):
        # A (near-)constant matrix is only centred, to avoid dividing by ~0.
        m = np.nanmean(mat)
        s = np.nanstd(mat)
        if s < 1e-9:
            return mat - m
        return (mat - m) / (s + 1e-12)

    app_n = znorm(app_raw)
    scr_n = znorm(scr_raw)
    cor_n = znorm(cor_raw)
    rot_n = znorm(rot_raw)

    # Now combine using initial relative weights but scale each component's std relative to others
    # Compute std of normalized matrices (should be ~1 but keep robust)
    s_app = np.nanstd(app_n) + 1e-12
    s_scr = np.nanstd(scr_n) + 1e-12
    s_cor = np.nanstd(cor_n) + 1e-12
    s_rot = np.nanstd(rot_n) + 1e-12

    # target weights (user set)
    Wapp = W_APP; Wscr = W_SCR; Wcor = W_COR; Wrot = ROT_PEN

    # scale factors so each weighted term contributes roughly in proportion
    # final cost = Wapp*(app_n/s_app) + Wscr*(scr_n/s_scr) + Wcor*(cor_n/s_cor) + Wrot*(rot_n/s_rot)
    app_scaled = (Wapp / s_app) * app_n
    scr_scaled = (Wscr / s_scr) * scr_n
    cor_scaled = (Wcor / s_cor) * cor_n
    rot_scaled = (Wrot / s_rot) * rot_n

    cost = app_scaled + scr_scaled + cor_scaled + rot_scaled

    # also create best_rot_choice using best_rot_for_app (as a starting rotation)
    best_rot_choice = { (tid,pi): int(best_rot_for_app[(tid,pi)]) for tid in range(NT) for pi in range(NT) }

    return cost, best_rot_choice

# ----------------- ASSIGNMENT -----------------
def assign_tiles_by_hungarian(cost, best_rot_choice):
    """Assign tiles to grid positions by minimizing the total cost.

    Uses SciPy's ``linear_sum_assignment`` when available. Otherwise it falls
    back to a greedy pass that repeatedly takes the cheapest remaining
    (tile, position) pair.

    Args:
        cost: 2-D array, rows = tile ids, columns = position indices
            (row-major, ``y * N + x``).
        best_rot_choice: Dict ``(tile_id, position_index) -> rotation``;
            missing keys default to rotation 0.

    Returns:
        ``N x N`` nested list ``grid[y][x] = (tile_id, rotation)``; cells that
        receive no tile remain ``None``.
    """
    slots = [(x, y) for y in range(N) for x in range(N)]
    grid = [[None for _ in range(N)] for __ in range(N)]

    if SCIPY_AVAILABLE:
        tile_rows, slot_cols = linear_sum_assignment(cost)
        for tile, slot in zip(tile_rows, slot_cols):
            gx, gy = slots[slot]
            rotation = best_rot_choice.get((int(tile), int(slot)), 0)
            grid[gy][gx] = (int(tile), int(rotation))
        return grid

    # greedy fallback: walk all pairs from cheapest to most expensive
    n_tiles, n_slots = cost.shape[0], cost.shape[1]
    candidates = sorted(
        ((cost[tile, slot], tile, slot)
         for tile in range(n_tiles)
         for slot in range(n_slots)),
        key=lambda entry: entry[0],
    )
    placed_tiles = set()
    filled_slots = set()
    for _, tile, slot in candidates:
        if tile in placed_tiles or slot in filled_slots:
            continue
        gx, gy = slots[slot]
        grid[gy][gx] = (int(tile), int(best_rot_choice.get((tile, slot), 0)))
        placed_tiles.add(tile)
        filled_slots.add(slot)
        if len(placed_tiles) == n_tiles:
            break
    return grid

# ----------------- RECONSTRUCT / EVAL -----------------
def reconstruct_from_grid(grid, scrambled_tiles, tile_size):
    """Paint the solved grid onto a blank BGR canvas.

    Args:
        grid: ``N x N`` nested list of ``(tile_id, rotation)`` or ``None``.
        scrambled_tiles: Tile dicts providing the pixels for each ``'id'``.
        tile_size: ``(tile_h, tile_w)`` of one canvas cell.

    Returns:
        uint8 image of shape ``(tile_h * N, tile_w * N, 3)``; cells whose grid
        entry is ``None`` stay black.
    """
    tile_h, tile_w = tile_size
    canvas = np.zeros((tile_h * N, tile_w * N, 3), dtype=np.uint8)
    pixels_by_id = {tile['id']: tile['tile'] for tile in scrambled_tiles}
    for row, col in itertools.product(range(N), repeat=2):
        entry = grid[row][col]
        if entry is not None:
            tile_id, quarter_turns = entry
            patch = rotate_img(pixels_by_id[tile_id], quarter_turns * 90)
            # Resize so a rotated non-square tile still fits its cell.
            patch = cv2.resize(patch, (tile_w, tile_h))
            canvas[row * tile_h:(row + 1) * tile_h,
                   col * tile_w:(col + 1) * tile_w] = patch
    return canvas

def generate_true_grid(N=N):
    """Return the reference ``N x N`` grid used for evaluation.

    Cell ``(row, col)`` holds ``(row * N + col, 0)``: the row-major tile index
    with rotation 0.

    NOTE(review): used as ground truth, this treats the tile cut from
    scrambled position k as belonging at position k unrotated — confirm this
    matches how the dataset's scrambled images were produced.
    """
    return [[(row * N + col, 0) for col in range(N)] for row in range(N)]

def evaluate_by_tile_ids(pred_grid, true_grid):
    """Compare a predicted grid with a reference grid cell by cell.

    Both grids are ``N x N`` nested lists of ``(tile_id, rotation)`` pairs.
    Rotation matches are counted independently of whether the tile id matches.

    Returns:
        Dict with ``"placement_accuracy"`` and ``"rotation_accuracy"``, each
        the fraction of the ``N * N`` cells that agree.
    """
    cell_pairs = [
        (pred_grid[row][col], true_grid[row][col])
        for row in range(N)
        for col in range(N)
    ]
    placed = 0
    rotated = 0
    for (pred_id, pred_rot), (true_id, true_rot) in cell_pairs:
        if pred_id == true_id:
            placed += 1
        if pred_rot == true_rot:
            rotated += 1
    cells = N * N
    return {"placement_accuracy": placed / cells, "rotation_accuracy": rotated / cells}

# ----------------- TOP-LEVEL SOLVER (Enhanced Hungarian Only) -----------------
def _refine_rotations(grid, scrambled_tiles, max_iters=3):
    """Greedily re-pick each placed tile's rotation to best match its neighbours.

    For every cell, all four rotations are scored by the summed edge NCC
    against the current neighbours, and the best one is written back. Passes
    repeat until nothing changes or ``max_iters`` passes have run. ``grid`` is
    modified in place.
    """
    offsets = {0: (0, -1), 1: (1, 0), 2: (0, 1), 3: (-1, 0)}
    opp = {0: 2, 1: 3, 2: 0, 3: 1}
    id2tile = {t['id']: t['tile'] for t in scrambled_tiles}
    strip_cache = {}

    def strip(tid, rot, edge):
        # Strips depend only on (tile, rotation, edge); compute each once.
        key = (tid, rot, edge)
        if key not in strip_cache:
            strip_cache[key] = extract_edge_strip_rgb(
                rotate_img(id2tile[tid], rot * 90), edge, width=BORDER_WIDTH)
        return strip_cache[key]

    improved = True
    iters = 0
    while improved and iters < max_iters:
        improved = False
        iters += 1
        for y in range(N):
            for x in range(N):
                if grid[y][x] is None:
                    continue
                tid, r = grid[y][x]

                # Collect occupied in-grid neighbours as (edge, tile, rotation).
                neighbours = []
                for ei in range(4):
                    nx, ny = x + offsets[ei][0], y + offsets[ei][1]
                    if not (0 <= nx < N and 0 <= ny < N):
                        continue
                    if grid[ny][nx] is None:
                        continue
                    ntid, nrot = grid[ny][nx]
                    neighbours.append((ei, ntid, nrot))
                if not neighbours:
                    # No adjacency evidence: keep the rotation chosen earlier.
                    continue

                best_local = r
                best_score = -1e9
                for cand_r in range(4):
                    score = 0.0
                    for ei, ntid, nrot in neighbours:
                        score += rgb_ncc(strip(tid, cand_r, ei),
                                         strip(ntid, nrot, opp[ei]))
                    if score > best_score:
                        best_score = score
                        best_local = cand_r

                # Candidates are scored without touching the grid, so a cell
                # whose current rotation is already best keeps it (the cell
                # must never be left holding the last candidate tried).
                if best_local != r:
                    grid[y][x] = (tid, best_local)
                    improved = True
    return grid


def solve_one_puzzle_enhanced(scrambled_path, correct_path, use_m1=True, save_outdir=None):
    """Solve one scrambled puzzle against its ground-truth image.

    Args:
        scrambled_path: Path of the scrambled puzzle image.
        correct_path: Path of the ground-truth image.
        use_m1: When true, prefer ``<base>_segmented.jpg`` and then
            ``<base>_denoised.jpg`` from ``M1_RESULTS_DIR`` as the tile source,
            falling back to the scrambled image if neither exists or the
            chosen file cannot be read.
        save_outdir: Optional folder; when given, the assembled image and a
            side-by-side ground-truth comparison are written to
            ``<save_outdir>/<base>/``.

    Returns:
        ``(assembled, metrics, grid)``: the reconstructed image, the accuracy
        dict from ``evaluate_by_tile_ids`` and the final
        ``grid[y][x] = (tile_id, rotation)``.

    Raises:
        FileNotFoundError: if the scrambled or ground-truth image cannot be read.
    """
    base = os.path.splitext(os.path.basename(scrambled_path))[0]
    scr = cv2.imread(scrambled_path)
    corr = cv2.imread(correct_path)
    if scr is None or corr is None:
        raise FileNotFoundError(f"Missing images for {base}")

    # prefer M1 processed images
    proc_path = None
    if use_m1:
        seg = os.path.join(M1_RESULTS_DIR, f"{base}_segmented.jpg")
        den = os.path.join(M1_RESULTS_DIR, f"{base}_denoised.jpg")
        if os.path.exists(seg):
            proc_path = seg
        elif os.path.exists(den):
            proc_path = den

    proc_img = cv2.imread(proc_path) if proc_path else scr
    if proc_img is None:
        # cv2.imread signals an unreadable file by returning None; use the raw
        # scrambled image instead of crashing while splitting tiles.
        proc_img = scr
    scrambled_tiles, tile_size = split_into_tiles(proc_img, N)
    correct_tiles, _ = split_into_tiles(corr, N)

    # scrambled adjacency (tile-to-tile)
    adj_scrambled = compute_scrambled_adjacency(scrambled_tiles, border_width=BORDER_WIDTH)

    # build combined cost (appearance + adjacency terms)
    cost, best_rot_choice = build_costs(scrambled_tiles, correct_tiles, adj_scrambled, border_width=BORDER_WIDTH)

    # assign tiles to positions, then locally refine rotations
    grid = assign_tiles_by_hungarian(cost, best_rot_choice)
    grid = _refine_rotations(grid, scrambled_tiles, max_iters=3)

    assembled = reconstruct_from_grid(grid, scrambled_tiles, tile_size)
    metrics = evaluate_by_tile_ids(grid, generate_true_grid(N))
    if save_outdir:
        outdir = os.path.join(save_outdir, base)
        Path(outdir).mkdir(parents=True, exist_ok=True)
        cv2.imwrite(os.path.join(outdir, base + "_assembled.png"), assembled)
        # Resize the ground truth so both halves of the comparison line up.
        gt_resized = cv2.resize(corr, (assembled.shape[1], assembled.shape[0]))
        cv2.imwrite(os.path.join(outdir, "gt_vs_assembled.png"), np.concatenate([gt_resized, assembled], axis=1))
    return assembled, metrics, grid

# ----------------- BATCH RUNNER -----------------
def run_batch(save_outdir=OUTPUT_DIR):
    """Solve every puzzle image in ``SCRAMBLED_FOLDER`` and write a summary CSV.

    Images without a same-named ground truth in ``CORRECT_FOLDER`` are skipped.
    A puzzle that raises is reported and recorded with empty metrics.

    Returns:
        ``(csv_path, overall, fully_correct)``: path of the written CSV, mean
        placement accuracy over solved puzzles (0.0 if none) and the number of
        puzzles with placement accuracy exactly 1.0.
    """
    image_exts = (".png", ".jpg", ".jpeg")
    files = sorted(
        name for name in os.listdir(SCRAMBLED_FOLDER)
        if name.lower().endswith(image_exts)
    )
    rows, accuracies = [], []
    start = time.time()

    for number, fname in enumerate(files, start=1):
        base = os.path.splitext(fname)[0]
        s_path = os.path.join(SCRAMBLED_FOLDER, fname)
        # find GT: first existing file among the supported extensions
        gt_candidates = (os.path.join(CORRECT_FOLDER, base + ext) for ext in image_exts)
        c_path = next((cand for cand in gt_candidates if os.path.exists(cand)), None)
        if c_path is None:
            print(f"[{number}/{len(files)}] Skipping {base} (no GT)")
            continue
        try:
            _, metrics, _ = solve_one_puzzle_enhanced(s_path, c_path, use_m1=True, save_outdir=save_outdir)
            placement = metrics['placement_accuracy']
            rotation = metrics['rotation_accuracy']
            rows.append([base, placement, rotation])
            accuracies.append(placement)
            print(f"[{number}/{len(files)}] {base}  acc={placement:.3f} rot={rotation:.3f}")
        except Exception as e:
            # Batch boundary: report the failure and keep going.
            print(f"[{number}/{len(files)}] ERROR {base}: {e}")
            rows.append([base, None, None])

    csv_path = os.path.join(save_outdir, "summary_results_best.csv")
    with open(csv_path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["base", "placement_acc", "rotation_acc"])
        writer.writerows(rows)

    total_time = time.time() - start
    if accuracies:
        overall = sum(accuracies) / len(accuracies)
    else:
        overall = 0.0
    fully_correct = len([row for row in rows if row[1] == 1.0])
    print("\nBatch finished.")
    print("Time (s):", total_time)
    print("Overall placement accuracy:", overall)
    print("Fully correct puzzles:", fully_correct)
    print("CSV saved to:", csv_path)
    return csv_path, overall, fully_correct

# ----------------- RUN -----------------
# Run the whole batch with the default output folder when executed as a script.
if __name__ == "__main__":
    run_batch()


[1/110] 0  acc=1.000 rot=0.500
[2/110] 1  acc=0.500 rot=0.000
[3/110] 10  acc=1.000 rot=0.500
[4/110] 100  acc=0.500 rot=0.250
[5/110] 101  acc=0.250 rot=0.000
[6/110] 102  acc=0.500 rot=0.250
[7/110] 103  acc=0.500 rot=0.000
[8/110] 104  acc=1.000 rot=0.250
[9/110] 105  acc=0.500 rot=0.250
[10/110] 106  acc=0.250 rot=0.000
[11/110] 107  acc=0.500 rot=0.000
[12/110] 108  acc=1.000 rot=0.000
[13/110] 109  acc=0.000 rot=0.000
[14/110] 11  acc=1.000 rot=0.250
[15/110] 12  acc=1.000 rot=0.250
[16/110] 13  acc=0.250 rot=0.250
[17/110] 14  acc=0.500 rot=0.250
[18/110] 15  acc=1.000 rot=0.000
[19/110] 16  acc=0.250 rot=0.000
[20/110] 17  acc=0.250 rot=0.000
[21/110] 18  acc=0.500 rot=0.250
[22/110] 19  acc=1.000 rot=0.000
[23/110] 2  acc=0.000 rot=0.000
[24/110] 20  acc=0.500 rot=0.000
[25/110] 21  acc=1.000 rot=0.250
[26/110] 22  acc=1.000 rot=0.000
[27/110] 23  acc=0.000 rot=0.000
[28/110] 24  acc=1.000 rot=0.250
[29/110] 25  acc=1.000 rot=0.000
[30/110] 26  acc=0.500 rot=0.000
[31/110] 27 