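## Sinkhorn block optimal transport (OT) pixel morph

This notebook rearranges the pixels of an input image so that it resembles a target image. It works in four steps:
1. Describe each pixel by its colour plus a weighted (x, y) position.
2. Compute an entropic OT plan between the two pixel sets with Sinkhorn iterations.
3. Round the plan to a one-to-one pixel permutation.
4. Save the permuted image, together with a fluid morph video from the input to the result.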

In [20]:
import os

import imageio.v2 as imageio
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

In [22]:
def load_image(path, size=(128,128)):
    """
    Read an image file as an RGB uint8 array.

    Args:
        path: image file to open.
        size: (width, height) passed to PIL's resize (LANCZOS filter).

    Returns:
        numpy array of shape (height, width, 3).
    """
    resized = Image.open(path).convert("RGB").resize(size, Image.LANCZOS)
    return np.array(resized)

def save_image(array,path):
    """
    Write an image array to disk and report where it went.

    The array is cast to uint8 before being handed to PIL, so callers may pass
    any numeric dtype already in the 0..255 range.
    """
    Image.fromarray(array.astype(np.uint8)).save(path)
    print(f"Saved : {path}")

In [24]:
def sinkhorn_transport(cost_matrix, epsilon=0.01, n_iters=100):
    """
    Entropic optimal transport between two uniform discrete distributions (Sinkhorn).

    Args:
        cost_matrix: [n_src, n_tgt] tensor of transport costs. This notebook passes
            square matrices; rectangular ones are also accepted.
        epsilon: entropic regularisation strength (> 0); smaller -> sharper plan.
        n_iters: number of (u, v) scaling rounds.

    Returns:
        P: [n_src, n_tgt] float32 transport plan. Its columns sum to 1/n_tgt and
           its rows approach 1/n_src as the iteration converges.

    Raises:
        ValueError: if epsilon is not positive.
    """
    if epsilon <= 0:
        raise ValueError(f"epsilon must be positive, got {epsilon}")

    device = cost_matrix.device
    cost = cost_matrix.float()
    n_src, n_tgt = cost.shape
    a = torch.full((n_src,), 1.0 / n_src, device=device, dtype=torch.float32)
    b = torch.full((n_tgt,), 1.0 / n_tgt, device=device, dtype=torch.float32)

    # Shift every row, then every column, so that its smallest cost is exactly 0.
    # Sinkhorn is invariant to K -> diag(r) K diag(c): the factors are absorbed
    # into u and v, so the fixed point is unchanged.
    # Without the shift, exp(-cost/epsilon) drops below the 1e-12 floor for any
    # cost above ~0.28 at epsilon=0.01. A row with no close match then became
    # uniform (every entry equal to the floor).
    shifted = cost - cost.min(dim=1, keepdim=True).values
    shifted -= shifted.min(dim=0, keepdim=True).values
    # In-place ops reuse the `shifted` buffer, so K costs no extra N x N matrix.
    K = shifted.div_(-float(epsilon)).exp_().clamp_(min=1e-12)

    u = torch.ones_like(a)
    v = torch.ones_like(b)
    for _ in range(n_iters):
        # Clamps guard the divisions against zero denominators.
        u = a / torch.clamp(K @ v, min=1e-12)
        v = b / torch.clamp(K.t() @ u, min=1e-12)

    # P = diag(u) K diag(v), computed in place in K's buffer (K is not needed afterwards).
    return K.mul_(u.unsqueeze(1)).mul_(v.unsqueeze(0))

In [26]:
def build_pixel_features(img_np, pos_weight=0.1):
    """
    Build one feature vector per pixel: [r, g, b, pos_weight*x, pos_weight*y].

    Args:
        img_np: [H, W, 3] image with channel values in 0..255.
        pos_weight: scale applied to the normalised (x, y) coordinates. Larger
            values make spatial distance matter more in the OT cost.

    Returns:
        [H*W, 5] float64 array in row-major pixel order. Colours are in [0, 1];
        x and y come from linspace(0, 1, endpoint=False) along each axis.
    """
    height, width, _ = img_np.shape
    colour = img_np.reshape(-1, 3).astype(np.float32) / 255.0

    grid_y, grid_x = np.meshgrid(
        np.linspace(0, 1, height, endpoint=False),
        np.linspace(0, 1, width, endpoint=False),
        indexing="ij",
    )
    position = pos_weight * np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)

    return np.hstack([colour, position])

def compute_cost_matrix(src_feats, tgt_feats, device):
    """
    Pairwise squared Euclidean distances between two feature sets.

    Args:
        src_feats: [N, D] numpy array of source features.
        tgt_feats: [M, D] numpy array of target features.
        device: torch device for the computation and the result.

    Returns:
        [N, M] float32 tensor with cost[i, j] = ||src_i - tgt_j||^2. Values are
        clamped at 0 to remove tiny negatives caused by floating-point cancellation.
    """
    def as_tensor(arr):
        return torch.from_numpy(arr).to(device=device, dtype=torch.float32)

    x, y = as_tensor(src_feats), as_tensor(tgt_feats)

    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>, so one matmul covers all pairs.
    x_norm_sq = (x ** 2).sum(dim=1).unsqueeze(1)   # [N, 1]
    y_norm_sq = (y ** 2).sum(dim=1).unsqueeze(0)   # [1, M]
    dist_sq = x_norm_sq + y_norm_sq - 2 * (x @ y.T)

    return dist_sq.clamp(min=0.0)

In [28]:
def transport_to_permutation(P):
    """
    Greedily round a transport plan into a one-to-one pixel assignment.

    Plan entries are visited from largest to smallest. A (source i, target j)
    pair is accepted when neither i nor j has been used yet.

    Args:
        P: [n_src, n_tgt] tensor (square in this notebook). Rectangular plans
           with n_src >= n_tgt are also handled.

    Returns:
        perm: int array of length n_tgt; perm[j] = source index assigned to target j.

    Raises:
        AssertionError: if some target could not be assigned (e.g. n_src < n_tgt).
    """
    P_np = P.detach().cpu().numpy()
    n_src, n_tgt = P_np.shape
    # Flat descending order over the whole plan (an n_src*n_tgt index array).
    flat_indices = np.argsort(-P_np, axis=None)

    used_src = np.zeros(n_src, dtype=bool)
    used_tgt = np.zeros(n_tgt, dtype=bool)
    perm = -np.ones(n_tgt, dtype=int)

    n_needed = min(n_src, n_tgt)
    n_assigned = 0
    for idx in flat_indices:
        # Row-major flat index -> (row, col); the row stride is the column count.
        i, j = divmod(int(idx), n_tgt)
        if used_src[i] or used_tgt[j]:
            continue
        perm[j] = i
        used_src[i] = True
        used_tgt[j] = True
        n_assigned += 1
        # A counter replaces the used_tgt.all() scan previously run on every entry.
        if n_assigned == n_needed:
            break

    assert (perm >= 0).all(), "Some targets were not assigned a source pixel."
    return perm

In [30]:
def smooth_flow_field(flow, iters=200):
    """
    Relax a displacement field towards its 4-neighbour average (Jacobi-style).

    Each iteration replaces every vector with 0.5 * original + 0.5 * mean of
    its up/down/left/right neighbours (reflect padding at the borders). The
    anchor to the original field keeps the flow from collapsing to a constant.

    flow: [1, H, W, 2] tensor (dx, dy in normalized coords)
    returns: [1, H, W, 2] smoothed tensor
    """
    # Channels-first layout so spatial neighbours are simple slices.
    anchor = flow.permute(0, 3, 1, 2).clone()   # [1, 2, H, W]
    field = anchor.clone()

    for _ in range(iters):
        padded = F.pad(field, (1, 1, 1, 1), mode="reflect")
        four_sum = (
            padded[:, :, :-2, 1:-1]     # up
            + padded[:, :, 2:, 1:-1]    # down
            + padded[:, :, 1:-1, :-2]   # left
            + padded[:, :, 1:-1, 2:]    # right
        )
        field = 0.5 * anchor + 0.5 * (0.25 * four_sum)

    return field.permute(0, 2, 3, 1)

In [34]:
def _resize_frame_nearest(frame, size):
    """Nearest-neighbour resize of an [H, W, 3] uint8 frame to size=(width, height); None returns it unchanged."""
    if size is None:
        return frame
    return np.array(Image.fromarray(frame).resize(size, Image.NEAREST))


def _float_frame_to_uint8(frame):
    """Map an [H, W, 3] float image in [0, 1] to uint8, rounding to the nearest level."""
    # astype(np.uint8) truncates: k/255*255 can come back as k - 1e-5 in float32,
    # which would silently lose an intensity level. Round before casting.
    return np.clip(np.rint(frame * 255.0), 0, 255).astype(np.uint8)


def create_fluid_morph_video(
    src_np,
    perm,
    out_path="fluid_morph.mp4",
    n_frames=180,
    fps=30,
    smooth_iters=250,
    use_gpu=True,
    upsample_to=(128, 128)
):
    """
    Render an MP4 that fluidly warps ``src_np`` into its permuted arrangement.

    Timeline:
      1. A 1 s hold on the input.
      2. ``n_frames`` morph frames.
      3. A 1 s hold on the exact permuted image ``src_np.reshape(-1, 3)[perm]``.

    Args:
        src_np: [H, W, 3] uint8 source image at the working resolution. H and W
            must be >= 2, because coordinates are normalised by (W - 1) and (H - 1).
        perm: length H*W integer array (or sequence). perm[j] is the source pixel
            index shown at target pixel j (row-major).
        out_path: destination video file.
        n_frames: number of morph frames (>= 1).
        fps: frames per second; also the number of frames in each 1 s hold.
        smooth_iters: relaxation iterations passed to smooth_flow_field.
        use_gpu: use CUDA when available.
        upsample_to: (width, height) of the written frames (nearest-neighbour),
            or None to keep H x W.

    Raises:
        ValueError: if perm does not have H*W entries or n_frames < 1.
    """
    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")

    h, w, _ = src_np.shape
    N = h * w
    perm = np.asarray(perm)
    if perm.shape != (N,):
        raise ValueError(f"perm must have {N} entries (H*W), got shape {perm.shape}.")
    if n_frames < 1:
        raise ValueError(f"n_frames must be >= 1, got {n_frames}.")

    img = torch.from_numpy(src_np).float() / 255.0
    img = img.permute(2, 0, 1).unsqueeze(0).to(device)           # [1, 3, H, W]

    # Identity sampling grid in grid_sample's normalised [-1, 1] (x, y) coordinates.
    ys, xs = torch.meshgrid(
        torch.arange(h, device=device, dtype=torch.float32),
        torch.arange(w, device=device, dtype=torch.float32),
        indexing="ij"
    )
    id_grid = torch.stack([2 * xs / (w - 1) - 1, 2 * ys / (h - 1) - 1], dim=-1).unsqueeze(0)

    # Normalised position of the source pixel that ends up at each target pixel.
    perm_t = torch.from_numpy(perm).long().to(device)
    y_src = torch.div(perm_t, w, rounding_mode="floor").view(h, w)
    x_src = (perm_t % w).view(h, w)
    target_grid = torch.stack(
        [2 * x_src / (w - 1) - 1, 2 * y_src / (h - 1) - 1], dim=-1
    ).unsqueeze(0)

    disp_s = smooth_flow_field(target_grid - id_grid, iters=smooth_iters)
    grid_fluid_end = id_grid + disp_s

    frames = []
    hold_frames = int(1.0 * fps)

    # --- 1) hold input frame 1 sec ---
    first_frame = _resize_frame_nearest(src_np, upsample_to)
    frames.extend(first_frame.copy() for _ in range(hold_frames))

    # --- 2) morph ---
    phase_split = 0.9  # first 90% of the timeline: smoothed flow; last 10%: snap to exact target

    for fi in range(n_frames):
        raw = fi / (n_frames - 1) if n_frames > 1 else 1.0
        # (raw**3)**0.35 == raw**1.05: a near-linear, slightly eased-in timeline.
        raw_t = (raw ** 3) ** 0.35

        if raw_t < phase_split:
            t = (raw_t / phase_split) ** 0.8
            soft = min(1.0, raw_t * 4.0)  # ramp displacement in over the first quarter
            grid = id_grid + (t * soft) * disp_s
        else:
            # Blend from the smoothed end state to the exact permutation grid.
            k = ((raw_t - phase_split) / (1 - phase_split)) ** 1.8
            grid = (1 - k) * grid_fluid_end + k * target_grid

        warped = F.grid_sample(
            img, grid,
            mode="bilinear",
            padding_mode="border",
            align_corners=True
        )
        fr = _float_frame_to_uint8(warped[0].permute(1, 2, 0).detach().cpu().numpy())
        frames.append(_resize_frame_nearest(fr, upsample_to))

    # --- 3) hold final frame 1 sec ---
    out_final_small = src_np.reshape(-1, 3)[perm].reshape(h, w, 3).astype(np.uint8)
    last_frame = _resize_frame_nearest(out_final_small, upsample_to)
    frames.extend(last_frame.copy() for _ in range(hold_frames))

    # --- save MP4 ---
    writer = imageio.get_writer(
        out_path,
        fps=fps,
        codec="libx264",
        format="ffmpeg",
        quality=8,
        pixelformat="yuv420p"
    )
    try:
        for fr in frames:
            writer.append_data(fr)
    finally:
        # Always release the ffmpeg process/file handle, even if a frame fails.
        writer.close()
    print("Saved video:", out_path)

In [36]:
def make_patches_indices(h, w, patch_size):
    """
    Tile an h x w grid into square patches and list each patch's pixel indices.

    Patches are visited row by row (top to bottom, then left to right). Patches
    on the right/bottom edge are truncated when h or w is not a multiple of
    patch_size.

    Returns:
        list of 1D integer numpy arrays. Each holds the row-major flat indices
        (y * w + x) of one patch's pixels, in row-major order within the patch.
    """
    def patch_at(top, left):
        rows = np.arange(top, min(top + patch_size, h))
        cols = np.arange(left, min(left + patch_size, w))
        return (rows[:, None] * w + cols[None, :]).reshape(-1)

    return [
        patch_at(top, left)
        for top in range(0, h, patch_size)
        for left in range(0, w, patch_size)
    ]


def patch_centroids(feats_np, patch_indices):
    """
    Mean feature vector of every patch.

    Args:
        feats_np: [N, D] numpy array of per-pixel features.
        patch_indices: sequence of index arrays, one per patch (rows of feats_np).

    Returns:
        [M, D] float32 array, where M = len(patch_indices) and row k is the
        mean of feats_np over patch k's indices.
    """
    n_patches = len(patch_indices)
    n_dims = feats_np.shape[1]
    out = np.zeros((n_patches, n_dims), dtype=np.float32)
    for k in range(n_patches):
        out[k, :] = np.mean(feats_np[patch_indices[k]], axis=0)
    return out

In [38]:
def multi_scale_ot_permutation(
    src128_np,
    tgt128_np,
    pos_weight=0.1,
    coarse_size=32,
    epsilon_coarse=0.02,
    n_iters_coarse=80,
    use_gpu=True
):
    """
    Approximate a full-resolution OT permutation from a coarse Sinkhorn solve.

    Steps:
      1. Downsample both images (LANCZOS) to coarse_size x coarse_size.
      2. Run full Sinkhorn OT between the coarse pixels (colour + position features).
      3. For every coarse *source* pixel, take the barycentre of the target
         positions it is transported to (row-normalised plan @ target coords).
      4. Lift to full resolution. Each full-res source pixel is placed at its
         coarse cell's barycentric position, plus the pixel's offset inside that
         cell, so a cell moves as a block instead of collapsing onto one point.
      5. Snap the continuous positions onto the regular H x W grid. Sort all
         sources by y, cut that order into H rows of W, then sort each row by x.
         A single key y*W + x with continuous y would order each row by y
         instead of x.

    Args:
        src128_np: [H, W, 3] uint8 source image (historically 128x128). H and W
            must be multiples of coarse_size.
        tgt128_np: [h, w, 3] uint8 target image; it is only used after being
            resized to the coarse grid.
        pos_weight: weight of (x, y) in the coarse features (see build_pixel_features).
        coarse_size: side length of the coarse grid.
        epsilon_coarse, n_iters_coarse: Sinkhorn parameters for the coarse solve.
        use_gpu: use CUDA when available.

    Returns:
        perm128: int64 array of length H*W; perm128[j] = source index for
                 target index j (row-major).

    Raises:
        ValueError: if H or W is not a multiple of coarse_size.
    """
    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
    H, W, _ = src128_np.shape
    if H % coarse_size or W % coarse_size:
        raise ValueError(
            f"Image size {H}x{W} must be a multiple of coarse_size={coarse_size}."
        )
    block_y, block_x = H // coarse_size, W // coarse_size

    # ---------- 1. Coarse images ----------
    coarse_dims = (coarse_size, coarse_size)
    src_c_np = np.array(Image.fromarray(src128_np).resize(coarse_dims, Image.LANCZOS))
    tgt_c_np = np.array(Image.fromarray(tgt128_np).resize(coarse_dims, Image.LANCZOS))

    src_c_feats = build_pixel_features(src_c_np, pos_weight=pos_weight)  # [Nc, 5]
    tgt_c_feats = build_pixel_features(tgt_c_np, pos_weight=pos_weight)

    Nc = coarse_size * coarse_size
    print(f"Coarse OT at {coarse_size}x{coarse_size} ({Nc} pixels)…")

    # ---------- 2. Coarse OT ----------
    cost_coarse = compute_cost_matrix(src_c_feats, tgt_c_feats, device=device)
    P_coarse = sinkhorn_transport(cost_coarse, epsilon=epsilon_coarse, n_iters=n_iters_coarse)  # [Nc, Nc]

    # ---------- 3. Barycentric target coords for each coarse SOURCE pixel ----------
    # Coarse target coords (y, x) = (cy, cx) / coarse_size, in [0, 1).
    ys_c, xs_c = np.meshgrid(
        np.linspace(0.0, 1.0, coarse_size, endpoint=False),
        np.linspace(0.0, 1.0, coarse_size, endpoint=False),
        indexing="ij"
    )
    coords_c = np.stack([ys_c.reshape(-1), xs_c.reshape(-1)], axis=1).astype(np.float32)  # [Nc, 2]
    coords_c_t = torch.from_numpy(coords_c).to(device=device)

    # Row-normalise the plan so each source row is a set of barycentric weights.
    row_sums = P_coarse.sum(dim=1, keepdim=True)
    weights = P_coarse / torch.clamp(row_sums, min=1e-8)
    bary_src = (weights @ coords_c_t).detach().cpu().numpy()  # [Nc, 2] (y, x)

    # ---------- 4. Lift mapping to full resolution ----------
    ys_hi, xs_hi = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
    cy, off_y = np.divmod(ys_hi, block_y)
    cx, off_x = np.divmod(xs_hi, block_x)
    coarse_index = (cy * coarse_size + cx).reshape(-1)  # [N]

    # bary * H is the top-left pixel of the cell's transported position; adding
    # the in-cell offset keeps the cell's pixels distinct and in their layout.
    tY = bary_src[coarse_index, 0] * H + off_y.reshape(-1)  # [N], in [0, H-1]
    tX = bary_src[coarse_index, 1] * W + off_x.reshape(-1)  # [N], in [0, W-1]

    # ---------- 5. Snap to the grid: rows by y, then columns by x ----------
    rows = np.argsort(tY, kind="stable").reshape(H, W)       # row r = W sources with y-rank r*W..r*W+W-1
    within_row = np.argsort(tX[rows], axis=1, kind="stable")  # order each row by x
    perm128 = np.take_along_axis(rows, within_row, axis=1).reshape(-1).astype(np.int64)

    return perm128


In [42]:
def transform_image_to_target(
    input_path,
    target_path,
    out_path="output_modi_128.png",
    video_path="fluid_morph_128.mp4",
    work_size=(64, 64),
    out_size=(128, 128),
    pos_weight=0.3,
    use_gpu=True,
    epsilon=0.01,
    n_iters=150,
    n_frames=180,
    fps=30,
    smooth_iters=250,
):
    """
    Rearrange the pixels of an input image to resemble a target, and render a morph.

    Both images are loaded at ``work_size``. A Sinkhorn plan over colour +
    position features is rounded to a pixel permutation. The permuted image is
    saved (nearest-neighbour resized to ``out_size``), and a fluid morph video
    from the input to that result is written.

    Args:
        input_path, target_path: image files to read.
        out_path: where the permuted still image is saved.
        video_path: where the morph video is saved.
        work_size: (width, height) used for OT. The cost matrix is
            (width*height)^2 float32, e.g. ~1 GB at (128, 128).
        out_size: (width, height) of the saved image and video frames.
        pos_weight: weight of pixel position relative to colour in the cost.
        use_gpu: use CUDA when available.
        epsilon, n_iters: Sinkhorn regularisation and iteration count.
        n_frames, fps, smooth_iters: forwarded to create_fluid_morph_video.
    """
    # Load both images at the working resolution: arrays are (height, width, 3).
    src = load_image(input_path, size=work_size)
    tgt = load_image(target_path, size=work_size)

    # OT at the working resolution.
    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
    src_feats = build_pixel_features(src, pos_weight=pos_weight)
    tgt_feats = build_pixel_features(tgt, pos_weight=pos_weight)

    print("Computing cost matrix at", work_size, "…")
    cost = compute_cost_matrix(src_feats, tgt_feats, device=device)
    print("Running Sinkhorn OT…")
    P = sinkhorn_transport(cost, epsilon=epsilon, n_iters=n_iters)
    del cost  # release the N x N cost matrix before rounding the plan
    perm = transport_to_permutation(P)
    del P

    # Static result at working resolution, then nearest-neighbour resize to out_size.
    permuted = src.reshape(-1, 3)[perm].reshape(src.shape)
    out_img = np.array(Image.fromarray(permuted.astype(np.uint8)).resize(out_size, Image.NEAREST))
    save_image(out_img, out_path)

    # Fluid morph video; frames are upsampled to out_size.
    create_fluid_morph_video(
        src_np=src,
        perm=perm,
        out_path=video_path,
        n_frames=n_frames,
        fps=fps,
        smooth_iters=smooth_iters,
        use_gpu=use_gpu,
        upsample_to=out_size
    )

In [44]:
# Input, target and output files live in DATA_DIR (defaults to ~/Downloads;
# override it with the MORPH_DATA_DIR environment variable).
DATA_DIR = os.environ.get("MORPH_DATA_DIR", os.path.join(os.path.expanduser("~"), "Downloads"))
inp = os.path.join(DATA_DIR, "Image_created_with_a_mobile_phone.png")
out = os.path.join(DATA_DIR, "WhatsApp Image 2025-11-16 at 15.12.38_1ce65b8f.jpg")  # the *target* image

transform_image_to_target(
    input_path=inp,
    target_path=out,
    out_path=os.path.join(DATA_DIR, "output_modi_multiscale.png"),
    video_path=os.path.join(DATA_DIR, "morph_to_modi_multiscale.mp4"),
    work_size=(128, 128),  # internal OT resolution (16384 pixels -> ~1 GB cost matrix)
    out_size=(128, 128),   # final image / video size
    pos_weight=0.3,
    use_gpu=True,
)

Computing cost matrix at (128, 128) …
Running Sinkhorn OT…
Saved : C:\Users\vaibh\Downloads\output_modi_multiscale.png
Saved video: C:\Users\vaibh\Downloads\morph_to_modi_multiscale.mp4
