In [None]:
video without yolo clipping

In [None]:
import os, json, uuid, time, requests

# === CONFIG ===
# Base URL of the ComfyUI HTTP server that every helper below talks to.
COMFY = "http://192.168.27.13:23476"   # your ComfyUI server
# Workflow file that run_sam_on_image() loads with json.load and patches.
WORKFLOW_JSON = "ClothesDetect_api.json"
INPUT_IMAGE = "input.png"               # input image path
# Local directory where downloaded result images are written.
OUTPUT_DIR = "comfy_output"
# Values written into the GroundingDinoSAMSegment node's "prompt"/"threshold" inputs.
GDINO_PROMPT = "clothes"
GDINO_THRESHOLD = 0.35
# ===============

# Create the download directory up front (no error if it already exists).
os.makedirs(OUTPUT_DIR, exist_ok=True)


def upload_image_to_comfy(local_path, server=COMFY, folder_type="input"):
    """Upload a local image file to the ComfyUI server.

    The file is POSTed as multipart form data to ``{server}/upload/image``
    under its own basename, with ``overwrite`` set to ``"true"``.

    Args:
        local_path: Path of the image file on this machine.
        server: Base URL of the ComfyUI server.
        folder_type: Value sent as the ``type`` form field.

    Returns:
        The basename the file was uploaded under.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    import mimetypes  # local import so the block does not depend on the file's import cell

    dest_name = os.path.basename(local_path)
    # Declare the real content type (e.g. image/jpeg for .jpg) instead of always
    # claiming PNG; keep the historical "image/png" when the type cannot be guessed.
    content_type = mimetypes.guess_type(dest_name)[0] or "image/png"
    with open(local_path, "rb") as f:
        files = {"image": (dest_name, f, content_type)}
        data = {"type": folder_type, "overwrite": "true"}
        r = requests.post(f"{server}/upload/image", files=files, data=data, timeout=60)
        r.raise_for_status()
    return dest_name


def patch_nodes(prompt_dict, image_name, new_prompt=None, new_threshold=None):
    """Patch LoadImage and GroundingDinoSAMSegment nodes of a workflow in place.

    Args:
        prompt_dict: Workflow mapping of node id -> node dict; each node is
            expected to carry ``class_type`` and ``inputs``.
        image_name: Value written to ``inputs["image"]`` of every node whose
            class type is ``LoadImage`` (case-insensitive).
        new_prompt: If not ``None``, written to ``inputs["prompt"]`` of every
            node whose class type starts with ``GroundingDinoSAMSegment``.
        new_threshold: If not ``None``, written to ``inputs["threshold"]`` of
            those same nodes.

    Returns:
        None. ``prompt_dict`` is modified in place.
    """
    for node in prompt_dict.values():
        cls = node.get("class_type", "").lower()
        if cls == "loadimage":
            node["inputs"]["image"] = image_name
        if cls.startswith("groundingdinosamsegment"):
            # Compare against None rather than truthiness so that legitimate
            # falsy overrides (threshold 0.0, empty prompt) are not dropped.
            if new_prompt is not None:
                node["inputs"]["prompt"] = new_prompt
            if new_threshold is not None:
                node["inputs"]["threshold"] = new_threshold


def queue_prompt(prompt_dict, server=COMFY):
    """Submit a workflow to ``{server}/prompt`` and return its prompt id.

    A fresh UUID is sent as ``client_id``; it is also the value returned when
    the JSON response has no ``prompt_id`` key.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    fallback_id = str(uuid.uuid4())
    payload = {"prompt": prompt_dict, "client_id": fallback_id}
    response = requests.post(f"{server}/prompt", json=payload, timeout=120)
    response.raise_for_status()
    body = response.json()
    return body.get("prompt_id", fallback_id)


def get_history(prompt_id, server=COMFY):
    """Return the decoded JSON body of ``GET {server}/history/{prompt_id}``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    history_url = f"{server}/history/{prompt_id}"
    response = requests.get(history_url, timeout=60)
    response.raise_for_status()
    return response.json()


def download_output_image(filename, server=COMFY, folder_type="output", subfolder="", save_dir=OUTPUT_DIR):
    """Fetch one image through ``{server}/view`` and store it under ``save_dir``.

    Args:
        filename: Sent as the ``filename`` query parameter; also the local file name.
        server: Base URL of the ComfyUI server.
        folder_type: Sent as the ``type`` query parameter.
        subfolder: Sent as the ``subfolder`` query parameter.
        save_dir: Local directory (created if missing) that receives the file.

    Returns:
        The local path the response body was written to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    os.makedirs(save_dir, exist_ok=True)
    query = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    response = requests.get(f"{server}/view", params=query, timeout=60)
    response.raise_for_status()

    destination = os.path.join(save_dir, filename)
    with open(destination, "wb") as sink:
        sink.write(response.content)
    return destination


def run_sam_on_image(image_path):
    """Run the ComfyUI workflow on one image and return the downloaded result path.

    The image is uploaded, ``WORKFLOW_JSON`` is loaded and patched with the
    uploaded name plus ``GDINO_PROMPT`` / ``GDINO_THRESHOLD``, the workflow is
    queued, and the history endpoint is polled once per second for up to
    600 seconds. The first image of every output node is downloaded into
    ``OUTPUT_DIR``; the path of the last one downloaded is returned.

    Args:
        image_path: Local path of the image to process.

    Returns:
        Local path of the downloaded output image.

    Raises:
        RuntimeError: If ComfyUI reports an execution error, finishes without
            producing an image, or no image appears before the deadline.
    """
    # 1) Upload
    uploaded = upload_image_to_comfy(image_path)

    # 2) Load and patch workflow JSON
    with open(WORKFLOW_JSON, "r") as f:
        prompt = json.load(f)
    patch_nodes(prompt, uploaded, new_prompt=GDINO_PROMPT, new_threshold=GDINO_THRESHOLD)

    # 3) Send to ComfyUI
    prompt_id = queue_prompt(prompt)

    # 4) Poll for output
    print("[INFO] Waiting for ComfyUI output...")
    seg_path = None
    deadline = time.time() + 600
    while time.time() < deadline:
        item = get_history(prompt_id).get(prompt_id)
        if item and "outputs" in item:
            for node_out in item["outputs"].values():
                images = node_out.get("images", [])
                if not images:
                    continue
                im = images[0]  # only the first image of each node is used
                seg_path = download_output_image(
                    im["filename"],
                    folder_type=im.get("type", "output"),
                    subfolder=im.get("subfolder", ""),
                )
            if seg_path:
                break
            # NOTE(review): assumes history entries carry a "status" dict with
            # "status_str"/"completed" (recent ComfyUI builds) - confirm. When
            # the key is absent this falls back to plain polling as before.
            status = item.get("status") or {}
            if status.get("status_str") == "error":
                raise RuntimeError(f"ComfyUI reported an execution error for prompt {prompt_id}")
            if status.get("completed"):
                # Finished without any image: more polling cannot help.
                break
        time.sleep(1)

    if not seg_path:
        raise RuntimeError("No output image received from ComfyUI")

    print(f"[‚úÖ] SAM output saved to: {seg_path}")
    return seg_path


# ==== RUN ====
# Script entry point: process INPUT_IMAGE once and print where the result was saved.
if __name__ == "__main__":
    output_image = run_sam_on_image(INPUT_IMAGE)
    print("Output:", output_image)


In [1]:
import os
import cv2
import torch
# Startup diagnostics: report whether torch can see a CUDA device and which one.
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Using:", torch.cuda.get_device_name(0))
import numpy as np
import pickle
from tqdm import tqdm
from ultralytics import YOLO
from deoldify.visualize import get_image_colorizer
from deoldify import device
from deoldify.device_id import DeviceId
from PIL import Image
import uuid, json, requests, time

# ==== CONFIG ====
YOLO_MODEL_PATH = "models/yolo11x-seg.pt"
# Passed as conf= to YOLO predict and reused to filter boxes in the fusion stage.
CONF_THRESHOLD = 0.6
# Root directory; each input gets its own sub-folder named after the file stem.
OUTPUT_ROOT = "outputs"
COMFY = "http://192.168.27.13:23476"    # ComfyUI server
WORKFLOW_JSON = "ClothesDetect_api.json"
GDINO_PROMPT = "clothes" # grounding dino prompt
GDINO_THRESHOLD = 0.35   # grounding dino threshold
# =================
# Fraction of a YOLO box's width/height added on each side in run_fusion().
ENLARGE_PERCENT = 0.2 
USE_REGION_AB_AVERAGE = True   # toggle: use region-averaged A/B colors instead of per-pixel transfer
def apply_region_color_transfer(f_in, f_de, mask):
    """Copy colour from ``f_de`` into ``f_in`` wherever ``mask`` exceeds 127.

    With ``USE_REGION_AB_AVERAGE`` disabled the masked pixels are copied
    one-to-one from ``f_de``. Otherwise both frames are converted BGR -> LAB,
    the mask is split into connected components, and every component keeps
    the L channel of ``f_in`` while its A and B channels are set to the
    integer-truncated mean A/B of ``f_de`` over that component.

    Args:
        f_in: Base frame (treated as BGR); not modified.
        f_de: Colour source frame (treated as BGR).
        mask: Array compared against 127 to select pixels.

    Returns:
        A copy of ``f_in`` with the masked pixels recoloured.
    """
    selected = mask > 127
    result = f_in.copy()

    if not USE_REGION_AB_AVERAGE:
        # Plain per-pixel transfer.
        result[selected] = f_de[selected]
        return result

    base_lab = cv2.cvtColor(f_in, cv2.COLOR_BGR2LAB)
    color_lab = cv2.cvtColor(f_de, cv2.COLOR_BGR2LAB)
    merged_lab = base_lab.copy()

    region_count, region_ids = cv2.connectedComponents(selected.astype(np.uint8))

    # Label 0 is the background, so start at 1.
    for region_id in range(1, region_count):
        in_region = region_ids == region_id
        a_samples = color_lab[..., 1][in_region]
        b_samples = color_lab[..., 2][in_region]
        if a_samples.size == 0:
            continue

        # Lightness stays untouched; only the two chroma channels are flattened.
        merged_lab[in_region, 1] = int(np.mean(a_samples))
        merged_lab[in_region, 2] = int(np.mean(b_samples))

    recolored = cv2.cvtColor(merged_lab, cv2.COLOR_LAB2BGR)

    # Outside the mask the original pixels are kept as-is.
    result[selected] = recolored[selected]
    return result



def patch_groundingdino_node(prompt_dict, new_prompt=None, new_threshold=None):
    """Override prompt/threshold on the first GroundingDinoSAMSegment node.

    Only the first node whose ``class_type`` starts with
    ``groundingdinosamsegment`` (case-insensitive) is touched; arguments left
    at ``None`` are not written.

    Returns:
        True if such a node was found, False otherwise.
    """
    overrides = {"prompt": new_prompt, "threshold": new_threshold}
    for node in prompt_dict.values():
        class_name = node.get("class_type", "").lower()
        if not class_name.startswith("groundingdinosamsegment"):
            continue
        for key, value in overrides.items():
            if value is not None:
                node["inputs"][key] = value
        return True
    return False
# =================

# ---- Setup DeOldify ----
# Select GPU0 for DeOldify, then build the colorizer used by deoldify_inference().
# NOTE(review): the recorded output of this cell ends in a torch.load
# "Weights only load failed" UnpicklingError (PyTorch >= 2.6 default);
# presumably raised while loading model weights here - confirm.
device.set(device=DeviceId.GPU0)
colorizer = get_image_colorizer(artistic=True)

# ---- Setup YOLO ----
# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device_str = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {device_str}")
yolo_model = YOLO(YOLO_MODEL_PATH).to(device_str)


# ---- DeOldify inference ----
def deoldify_inference(frame_rgb):
    """Colorize one frame with the module-level DeOldify ``colorizer``.

    The array is wrapped in a PIL image, forced to RGB mode, run through
    ``colorizer.get_transformed_image`` (render_factor=16, post_process=True)
    and handed back as a NumPy array.
    """
    source = Image.fromarray(frame_rgb).convert("RGB")
    colorized = colorizer.get_transformed_image(
        source,
        render_factor=16,
        post_process=True,
    )
    return np.array(colorized)


# ---- ComfyUI helpers ----
def upload_image_to_comfy(local_path, server=COMFY, *, dest_name=None, folder_type="input"):
    """Upload a local image file to the ComfyUI server.

    The file is POSTed as multipart form data to ``{server}/upload/image``
    with ``overwrite`` set to ``"true"``.

    Args:
        local_path: Path of the image file on this machine.
        server: Base URL of the ComfyUI server.
        dest_name: File name to upload under; defaults to the basename of
            ``local_path``.
        folder_type: Value sent as the ``type`` form field.

    Returns:
        The name the file was uploaded under.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    import mimetypes  # local import so the block does not depend on the file's import cell

    if dest_name is None:
        dest_name = os.path.basename(local_path)
    # Declare the real content type (e.g. image/jpeg for .jpg inputs) instead of
    # always claiming PNG; keep the historical "image/png" when it cannot be guessed.
    content_type = mimetypes.guess_type(dest_name)[0] or "image/png"
    with open(local_path, "rb") as f:
        files = {"image": (dest_name, f, content_type)}
        data = {"type": folder_type, "overwrite": "true"}
        r = requests.post(f"{server}/upload/image", files=files, data=data, timeout=60)
        r.raise_for_status()
    return dest_name

def patch_loadimage_node(prompt_dict, new_filename):
    """Point the first LoadImage node of the workflow at ``new_filename``.

    Returns:
        True if a node with class type ``LoadImage`` (case-insensitive) was
        found and patched, False otherwise.
    """
    target = next(
        (
            node
            for node in prompt_dict.values()
            if node.get("class_type", "").lower() == "loadimage"
        ),
        None,
    )
    if target is None:
        return False
    target["inputs"]["image"] = new_filename
    return True

def queue_prompt(prompt_dict, server=COMFY):
    """Submit a workflow to ``{server}/prompt`` and return its prompt id.

    A fresh UUID is sent as ``client_id``; it is also the value returned when
    the JSON response has no ``prompt_id`` key.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    fallback_id = str(uuid.uuid4())
    payload = {"prompt": prompt_dict, "client_id": fallback_id}
    response = requests.post(f"{server}/prompt", json=payload, timeout=120)
    response.raise_for_status()
    body = response.json()
    return body.get("prompt_id", fallback_id)

def get_history(prompt_id, server=COMFY):
    """Return the decoded JSON body of ``GET {server}/history/{prompt_id}``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    history_url = f"{server}/history/{prompt_id}"
    response = requests.get(history_url, timeout=60)
    response.raise_for_status()
    return response.json()

def download_image(filename, server=COMFY, folder_type="output", subfolder="", to_path=None, save_dir=None):
    """Fetch one image through ``{server}/view`` and write it to disk.

    Args:
        filename: Sent as the ``filename`` query parameter.
        server: Base URL of the ComfyUI server.
        folder_type: Sent as the ``type`` query parameter.
        subfolder: Sent as the ``subfolder`` query parameter.
        to_path: Exact destination path; defaults to ``save_dir/filename``.
        save_dir: Download directory, created if missing (even when
            ``to_path`` is given); defaults to ``OUTPUT_ROOT/comfy_downloads``.

    Returns:
        The path the response body was written to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    target_dir = os.path.join(OUTPUT_ROOT, "comfy_downloads") if save_dir is None else save_dir
    os.makedirs(target_dir, exist_ok=True)

    query = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    response = requests.get(f"{server}/view", params=query, timeout=60)
    response.raise_for_status()

    destination = os.path.join(target_dir, filename) if to_path is None else to_path
    with open(destination, "wb") as sink:
        sink.write(response.content)

    return destination


def run_sam_on_frame(frame_path, comfy_server=COMFY):
    """Send one frame through the ComfyUI workflow and return the saved result path.

    The frame is uploaded, ``WORKFLOW_JSON`` is loaded and patched (LoadImage
    file name, GroundingDINO prompt/threshold), the workflow is queued and the
    history endpoint is polled every 0.5 s for up to 600 s. The result is
    written next to the frame as ``ComfyUI_<frame stem>.png``.

    When several output nodes report images, the first image of the last such
    node is saved. Earlier versions downloaded every node's first image to the
    same path, so this is the file that survived; it is now fetched only once.

    Args:
        frame_path: Local path of the frame image.
        comfy_server: Base URL of the ComfyUI server.

    Returns:
        Path of the downloaded ``ComfyUI_*.png`` file.

    Raises:
        RuntimeError: If the workflow has no LoadImage node, ComfyUI reports
            an execution error, or no image is available before the deadline.
    """
    uploaded = upload_image_to_comfy(frame_path, server=comfy_server)

    # Load workflow JSON
    with open(WORKFLOW_JSON, "r") as f:
        prompt = json.load(f)

    # Patch LoadImage node
    if not patch_loadimage_node(prompt, uploaded):
        raise RuntimeError("Could not patch LoadImage node in workflow JSON.")

    patch_groundingdino_node(prompt, new_prompt=GDINO_PROMPT, new_threshold=GDINO_THRESHOLD)

    # Queue
    prompt_id = queue_prompt(prompt, server=comfy_server)

    # Destination: ComfyUI_<originalname>.png in the frame's own folder
    base = os.path.splitext(os.path.basename(frame_path))[0]
    out_path = os.path.join(os.path.dirname(frame_path), f"ComfyUI_{base}.png")

    deadline = time.time() + 600  # up to 10 min per frame
    seg_path = None

    # Poll history until output is ready
    while time.time() < deadline:
        item = get_history(prompt_id, server=comfy_server).get(prompt_id)
        if item and "outputs" in item:
            candidates = [
                node_out["images"][0]
                for node_out in item["outputs"].values()
                if node_out.get("images")
            ]
            if candidates:
                im = candidates[-1]
                seg_path = download_image(
                    im["filename"],
                    server=comfy_server,
                    subfolder=im.get("subfolder", ""),
                    folder_type=im.get("type", "output"),
                    to_path=out_path,
                )
                break
            # NOTE(review): assumes history entries carry a "status" dict with
            # "status_str"/"completed" (recent ComfyUI builds) - confirm. When
            # the key is absent this falls back to plain polling as before.
            status = item.get("status") or {}
            if status.get("status_str") == "error":
                raise RuntimeError(f"ComfyUI execution failed for {frame_path}")
            if status.get("completed"):
                # Finished without any image: more polling cannot help.
                break
        time.sleep(0.5)

    if not seg_path:
        raise RuntimeError(f"No outputs from ComfyUI for {frame_path}")

    return seg_path


# ---- Stage 3: SAM (extract → ComfyUI per frame → video) ----
def run_sam(input_path, out_dir, name):
    """Build the SAM mask video for ``input_path`` via per-frame ComfyUI calls.

    Frames are extracted to ``<out_dir>/<name>_sam_frames/frame_NNNNNN.png``,
    each one is sent through ``run_sam_on_frame`` (frames whose
    ``ComfyUI_frame_*.png`` result already exists are skipped, which allows
    resuming), and the results are assembled into ``<out_dir>/<name>_sam.mp4``.

    The output always contains exactly one frame per extracted input frame.
    A frame whose ComfyUI call failed, or whose result cannot be read, becomes
    an all-black frame, so frame N of the mask video still corresponds to
    frame N of the input when the fusion stage reads both in lockstep.

    Args:
        input_path: Path of the source video.
        out_dir: Directory receiving the frames folder and the mask video.
        name: Base name used for the generated files.

    Returns:
        Path of the mask video (an existing file is reused as a cache).

    Raises:
        RuntimeError: If the input video cannot be opened.
    """
    sam_frames_dir = os.path.join(out_dir, f"{name}_sam_frames")
    os.makedirs(sam_frames_dir, exist_ok=True)
    sam_path = os.path.join(out_dir, f"{name}_sam.mp4")

    if os.path.exists(sam_path):
        print(f"[CACHE] Using cached SAM video: {sam_path}")
        return sam_path

    # --- Step 1: Extract all frames first ---
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        cap.release()
        # Fail early instead of writing (and later caching) an empty mask video.
        raise RuntimeError(f"Could not open video: {input_path}")
    try:
        # Keep fractional rates (e.g. 29.97); truncating to int changes the duration.
        fps = cap.get(cv2.CAP_PROP_FPS) or 25
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        idx = 0
        with tqdm(total=total_frames, desc="Extracting frames", unit="frame") as pbar:
            while True:
                ret, frame_bgr = cap.read()
                if not ret:
                    break
                frame_file = os.path.join(sam_frames_dir, f"frame_{idx:06d}.png")
                if not os.path.exists(frame_file):
                    cv2.imwrite(frame_file, frame_bgr)
                idx += 1
                pbar.update(1)
    finally:
        cap.release()

    # --- Step 2: Run ComfyUI on each frame ---
    frame_files = sorted(f for f in os.listdir(sam_frames_dir) if f.startswith("frame_"))
    for fname in tqdm(frame_files, desc="Processing with ComfyUI", unit="frame"):
        frame_file = os.path.join(sam_frames_dir, fname)
        out_file = os.path.join(sam_frames_dir, f"ComfyUI_{fname}")
        if os.path.exists(out_file):
            continue
        try:
            run_sam_on_frame(frame_file, comfy_server=COMFY)
        except Exception as e:
            # Best effort: a failed frame becomes a blank mask in step 3.
            print(f"‚ö†Ô∏è SAM failed on {fname}: {e}")

    # --- Step 3: Combine into video, one output frame per input frame ---
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(sam_path, fourcc, fps, (width, height))
    blank = np.zeros((height, width, 3), dtype=np.uint8)
    missing = 0
    try:
        for fname in tqdm(frame_files, desc="Building SAM video", unit="frame"):
            seg_file = os.path.join(sam_frames_dir, f"ComfyUI_{fname}")
            img = cv2.imread(seg_file) if os.path.exists(seg_file) else None
            if img is None:
                # Keep the timeline aligned with the input video.
                missing += 1
                writer.write(blank)
                continue
            writer.write(cv2.resize(img, (width, height)))
    finally:
        writer.release()

    if missing:
        print(f"[WARN] {missing} frame(s) had no SAM output; blank masks were written for them")
    print(f"[INFO] SAM video saved: {sam_path}")
    return sam_path



# ---- Stage 1: YOLO ----
def run_yolo(input_path, out_dir, name):
    """Run YOLO on every frame of a video; save an annotated video plus raw detections.

    Outputs are ``<out_dir>/<name>_yolo.mp4`` (frames rendered with
    ``results[0].plot()``) and ``<out_dir>/<name>_yolo_results.pkl`` (a pickled
    list with one dict per frame holding ``boxes`` (xyxy), ``conf`` and ``cls``
    as NumPy arrays). If both files already exist they are reused.

    Args:
        input_path: Path of the source video.
        out_dir: Directory receiving the outputs.
        name: Base name used for the generated files.

    Returns:
        Tuple ``(yolo_video_path, results_per_frame)``.
    """
    yolo_path = os.path.join(out_dir, f"{name}_yolo.mp4")
    results_path = os.path.join(out_dir, f"{name}_yolo_results.pkl")

    if os.path.exists(yolo_path) and os.path.exists(results_path):
        print(f"[CACHE] Using cached YOLO + results: {yolo_path}")
        # NOTE(security): pickle.load can execute arbitrary code. This file is
        # written by this function itself; do not point it at untrusted data.
        with open(results_path, "rb") as f:
            results_per_frame = pickle.load(f)
        return yolo_path, results_per_frame

    cap = cv2.VideoCapture(input_path)
    writer = None
    results_per_frame = []
    try:
        # Keep fractional rates (e.g. 29.97); truncating to int changes the duration.
        fps = cap.get(cv2.CAP_PROP_FPS) or 25
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(yolo_path, fourcc, fps, (width, height))

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        with tqdm(total=total_frames, desc="YOLO", unit="frame") as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                result = yolo_model.predict(
                    frame, conf=CONF_THRESHOLD, verbose=False, device=device_str
                )[0]
                writer.write(result.plot())
                boxes = result.boxes
                results_per_frame.append({
                    "boxes": boxes.xyxy.cpu().numpy(),
                    "conf": boxes.conf.cpu().numpy(),
                    "cls": boxes.cls.cpu().numpy(),
                })
                pbar.update(1)
    finally:
        # Release the handles even when inference fails part-way through.
        cap.release()
        if writer is not None:
            writer.release()

    with open(results_path, "wb") as f:
        pickle.dump(results_per_frame, f)

    return yolo_path, results_per_frame


# ---- Stage 2: DeOldify ----
def run_deoldify(input_path, out_dir, name):
    """Colorize every frame of a video with DeOldify and write ``<name>_deoldify.mp4``.

    NOTE(review): a second ``run_deoldify`` (with an extra ``force_bw``
    parameter) is defined later in this same cell and replaces this one when
    the cell runs, so this version is effectively dead code.
    """
    deoldify_path = os.path.join(out_dir, f"{name}_deoldify.mp4")
    # Reuse an existing output file as a cache.
    if os.path.exists(deoldify_path):
        print(f"[CACHE] Using cached DeOldify: {deoldify_path}")
        return deoldify_path

    cap = cv2.VideoCapture(input_path)
    # Fall back to 25 when the reported frame rate truncates to 0.
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(deoldify_path, fourcc, fps, (width, height))

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=total_frames, desc="DeOldify", unit="frame") as pbar:
        while True:
            ret, frame_bgr = cap.read()
            if not ret:
                break
            # OpenCV frames are BGR; deoldify_inference is fed RGB and its
            # result is converted back to BGR before writing.
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            deold = deoldify_inference(frame_rgb)
            writer.write(cv2.cvtColor(deold, cv2.COLOR_RGB2BGR))
            pbar.update(1)
    cap.release()
    writer.release()
    return deoldify_path



def run_deoldify(input_path, out_dir, name, force_bw=True):
    """Colorize every frame of a video with DeOldify.

    The result is written to ``<out_dir>/<name>_deoldify.mp4``; an existing
    file at that path is reused as a cache.

    Args:
        input_path (str): Path to input video.
        out_dir (str): Output directory.
        name (str): Base name for output file.
        force_bw (bool): If True (the default), convert each input frame to
            grayscale before DeOldify. If False, feed the original color frame.

    Returns:
        str: Path of the colorized video.
    """
    deoldify_path = os.path.join(out_dir, f"{name}_deoldify.mp4")
    if os.path.exists(deoldify_path):
        print(f"[CACHE] Using cached DeOldify: {deoldify_path}")
        return deoldify_path

    cap = cv2.VideoCapture(input_path)
    writer = None
    try:
        # Keep fractional rates (e.g. 29.97); truncating to int changes the duration.
        fps = cap.get(cv2.CAP_PROP_FPS) or 25
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(deoldify_path, fourcc, fps, (width, height))

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        with tqdm(total=total_frames, desc="DeOldify", unit="frame") as pbar:
            while True:
                ret, frame_bgr = cap.read()
                if not ret:
                    break

                if force_bw:
                    # Drop existing color first, then expand back to 3-channel RGB.
                    frame_gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
                    frame_rgb = cv2.cvtColor(frame_gray, cv2.COLOR_GRAY2RGB)
                else:
                    # Use the original color frame.
                    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

                deold = deoldify_inference(frame_rgb)
                writer.write(cv2.cvtColor(deold, cv2.COLOR_RGB2BGR))
                pbar.update(1)
    finally:
        # Release the handles even when inference fails part-way through.
        cap.release()
        if writer is not None:
            writer.release()
    return deoldify_path


# ---- Stage 3: SAM (with frame folder + resume support) ----






# ---- Stage 4: Fusion ----
def run_fusion(input_path, out_dir, name, yolo_results, deoldify_path, sam_path):
    """Fuse original and DeOldify frames inside the (SAM mask AND YOLO person box) area.

    Writes ``<out_dir>/<name>_final.mp4`` and returns its path; an existing
    file is reused as a cache.

    NOTE(review): a second ``run_fusion`` with the same signature is defined
    later in this same cell and replaces this one when the cell runs, so this
    version is effectively dead code.
    """
    fusion_path = os.path.join(out_dir, f"{name}_final.mp4")
    if os.path.exists(fusion_path):
        print(f"[CACHE] Using cached Fusion: {fusion_path}")
        return fusion_path

    cap_input = cv2.VideoCapture(input_path)      # original frames
    cap_deold = cv2.VideoCapture(deoldify_path)   # deoldify video
    cap_sam = cv2.VideoCapture(sam_path)          # sam masks

    fps = int(cap_input.get(cv2.CAP_PROP_FPS)) or 25
    width = int(cap_input.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap_input.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(fusion_path, fourcc, fps, (width,height))

    # Process only as many frames as every source can provide.
    total_frames = int(min(
        len(yolo_results),
        cap_input.get(cv2.CAP_PROP_FRAME_COUNT),
        cap_deold.get(cv2.CAP_PROP_FRAME_COUNT),
        cap_sam.get(cv2.CAP_PROP_FRAME_COUNT)
    ))

    frame_idx = 0
    with tqdm(total=total_frames, desc="Fusion", unit="frame") as pbar:
        while frame_idx < total_frames:
            # The three videos are read in lockstep, one frame each per iteration.
            ret_in, frame_in = cap_input.read()
            ret_deold, frame_deold = cap_deold.read()
            ret_sam, frame_sam = cap_sam.read()
            if not (ret_in and ret_deold and ret_sam):
                break

            # SAM mask
            # Every pixel with gray value above 1 counts as mask.
            gray = cv2.cvtColor(frame_sam, cv2.COLOR_BGR2GRAY)
            _, sam_mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)

            # YOLO mask
            yolo_mask = np.zeros_like(sam_mask, dtype=np.uint8)
            for box, conf, cls in zip(
                yolo_results[frame_idx]["boxes"],
                yolo_results[frame_idx]["conf"],
                yolo_results[frame_idx]["cls"]
            ):
                if int(cls) != 0 or conf < CONF_THRESHOLD:  # only "person"
                    continue
                x1, y1, x2, y2 = map(int, box)
                yolo_mask[y1:y2, x1:x2] = 255

            # Intersection
            intersect = cv2.bitwise_and(sam_mask, yolo_mask)
            mask_bool = intersect > 127

            # Fusion: base is ORIGINAL frame
            fusion_frame = frame_in.copy()
            fusion_frame[mask_bool] = frame_deold[mask_bool]

            writer.write(fusion_frame)
            frame_idx += 1
            pbar.update(1)

    cap_input.release()
    cap_deold.release()
    cap_sam.release()
    writer.release()

    print(f"[INFO] Fusion video saved: {fusion_path}")
    return fusion_path


def run_fusion(input_path, out_dir, name, yolo_results, deoldify_path, sam_path):
    """Fuse original and DeOldify frames inside the SAM mask, limited to enlarged YOLO boxes.

    For each frame, the SAM mask (gray value > 1) is intersected with the
    union of YOLO boxes of class 0 ("person") whose confidence is at least
    ``CONF_THRESHOLD``. Each box is grown by ``ENLARGE_PERCENT`` of its
    width/height on every side and clamped to the frame. Inside the
    intersection, colour is transferred from the DeOldify frame through
    ``apply_region_color_transfer``; everywhere else the original pixels stay.

    Args:
        input_path: Path of the original video.
        out_dir: Directory receiving ``<name>_final.mp4``.
        name: Base name used for the output file.
        yolo_results: Per-frame dicts with ``boxes``, ``conf`` and ``cls``.
        deoldify_path: Path of the DeOldify video.
        sam_path: Path of the SAM mask video.

    Returns:
        Path of the fused video (an existing file is reused as a cache).
    """
    fusion_path = os.path.join(out_dir, f"{name}_final.mp4")
    if os.path.exists(fusion_path):
        print(f"[CACHE] Using cached Fusion: {fusion_path}")
        return fusion_path

    cap_input = cv2.VideoCapture(input_path)
    cap_deold = cv2.VideoCapture(deoldify_path)
    cap_sam = cv2.VideoCapture(sam_path)
    writer = None
    try:
        # Keep fractional rates (e.g. 29.97) so the final video keeps the
        # duration of the source; truncating to int slows it down.
        fps = cap_input.get(cv2.CAP_PROP_FPS) or 25
        width = int(cap_input.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap_input.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(fusion_path, fourcc, fps, (width, height))

        # Process only as many frames as every source can provide.
        total_frames = int(min(
            len(yolo_results),
            cap_input.get(cv2.CAP_PROP_FRAME_COUNT),
            cap_deold.get(cv2.CAP_PROP_FRAME_COUNT),
            cap_sam.get(cv2.CAP_PROP_FRAME_COUNT),
        ))

        frame_idx = 0
        with tqdm(total=total_frames, desc="Fusion", unit="frame") as pbar:
            while frame_idx < total_frames:
                ret_in, frame_in = cap_input.read()
                ret_deold, frame_deold = cap_deold.read()
                ret_sam, frame_sam = cap_sam.read()
                if not (ret_in and ret_deold and ret_sam):
                    break

                # --- SAM mask ---
                gray = cv2.cvtColor(frame_sam, cv2.COLOR_BGR2GRAY)
                _, sam_mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)

                # --- YOLO mask with global enlargement ---
                yolo_mask = np.zeros_like(sam_mask, dtype=np.uint8)
                detections = yolo_results[frame_idx]
                for box, conf, cls in zip(
                    detections["boxes"], detections["conf"], detections["cls"]
                ):
                    if int(cls) != 0 or conf < CONF_THRESHOLD:  # only "person"
                        continue
                    x1, y1, x2, y2 = map(int, box)

                    # expand box by percentage, clamped to the frame
                    dx = int((x2 - x1) * ENLARGE_PERCENT)
                    dy = int((y2 - y1) * ENLARGE_PERCENT)
                    x1 = max(0, x1 - dx)
                    y1 = max(0, y1 - dy)
                    x2 = min(width, x2 + dx)
                    y2 = min(height, y2 + dy)

                    yolo_mask[y1:y2, x1:x2] = 255

                # --- Intersection ---
                intersect = cv2.bitwise_and(sam_mask, yolo_mask)

                # --- Fusion with optional region-averaged A/B ---
                fusion_frame = apply_region_color_transfer(frame_in, frame_deold, intersect)
                writer.write(fusion_frame)

                frame_idx += 1
                pbar.update(1)
    finally:
        # Release every handle even when a frame fails part-way through.
        cap_input.release()
        cap_deold.release()
        cap_sam.release()
        if writer is not None:
            writer.release()

    print(f"[INFO] Fusion video saved: {fusion_path}")
    return fusion_path



# ---- Main Pipeline ----
def process_video(input_path):
    """Run the four pipeline stages on one video and return the output paths.

    Outputs go to ``OUTPUT_ROOT/<video stem>/``. Stages run in this order:
    YOLO, DeOldify, SAM, fusion.

    Returns:
        Dict with keys ``yolo``, ``deoldify``, ``sam`` and ``final`` mapping
        to the corresponding file paths.
    """
    name = os.path.splitext(os.path.basename(input_path))[0]
    out_dir = os.path.join(OUTPUT_ROOT, name)
    os.makedirs(out_dir, exist_ok=True)

    outputs = {}
    outputs["yolo"], detections = run_yolo(input_path, out_dir, name)
    outputs["deoldify"] = run_deoldify(input_path, out_dir, name)
    outputs["sam"] = run_sam(input_path, out_dir, name)
    outputs["final"] = run_fusion(
        input_path, out_dir, name, detections, outputs["deoldify"], outputs["sam"]
    )

    print(f"[INFO] Outputs written to {out_dir}")
    return outputs

def process_video_cached(input_path):
    """Return the final fused video for ``input_path``, running the pipeline only if needed.

    If ``OUTPUT_ROOT/<stem>/<stem>_final.mp4`` already exists its path is
    returned directly; otherwise ``process_video`` runs and the ``"final"``
    entry of its result is returned.
    """
    stem = os.path.splitext(os.path.basename(input_path))[0]
    final_path = os.path.join(OUTPUT_ROOT, stem, f"{stem}_final.mp4")

    if not os.path.exists(final_path):
        return process_video(input_path)["final"]

    print(f"[CACHE] Final output exists: {final_path}")
    return final_path


# # ---- Example ----
# if __name__ == "__main__":
#     input_video = "input_videos/thatha_manavadu_test.mp4"
#     outputs = process_video(input_video)
#     print("Pipeline outputs:")
#     for k, v in outputs.items():
#         print(f" - {k}: {v}")


CUDA available: True
CUDA device count: 1
Using: NVIDIA GeForce RTX 5060 Ti


  import pkg_resources
  warn("""Your validation set is empty. If this is by design, use `split_none()`


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /workspace/models/torch/hub/checkpoints/resnet34-b627a593.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 83.3M/83.3M [00:01<00:00, 44.6MB/s]
  WeightNorm.apply(module, name, dim)


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL functools.partial was not an allowed global by default. Please use `torch.serialization.add_safe_globals([functools.partial])` or the `torch.serialization.safe_globals([functools.partial])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [None]:
image without yolo clipping

In [7]:
GDINO_PROMPT = "clothes" # grounding dino prompt
GDINO_THRESHOLD = 0.30   # grounding dino threshold

# Colour-transfer switches read by apply_region_color_transfer() below.
# USE_CUSTOM_AB takes precedence: when it is True the DeOldify averages are not used.
USE_REGION_AB_AVERAGE = True     # Use DeOldify's averaged A/B if True
USE_CUSTOM_AB = True            # If True, override with fixed values
CUSTOM_A = 128                # Example fixed A channel value (0–255)
CUSTOM_B = 115                  # Example fixed B channel value (0–255)


def apply_region_color_transfer(f_in, f_de, mask):
    """Recolour ``f_in`` wherever ``mask`` exceeds 127.

    Modes, selected by module-level flags:
      * both ``USE_REGION_AB_AVERAGE`` and ``USE_CUSTOM_AB`` off: masked
        pixels are copied one-to-one from ``f_de``;
      * ``USE_CUSTOM_AB`` on: every connected mask region gets the fixed
        LAB chroma ``CUSTOM_A`` / ``CUSTOM_B``;
      * otherwise: every connected mask region gets the integer-truncated
        mean A/B of ``f_de`` over that region.
    In the two LAB modes the L channel of ``f_in`` is kept.

    Args:
        f_in: Base frame (treated as BGR); not modified.
        f_de: Colour source frame (treated as BGR).
        mask: Array compared against 127 to select pixels.

    Returns:
        A copy of ``f_in`` with the masked pixels recoloured.
    """
    selected = mask > 127
    result = f_in.copy()

    if not USE_REGION_AB_AVERAGE and not USE_CUSTOM_AB:
        # Plain per-pixel DeOldify transfer.
        result[selected] = f_de[selected]
        return result

    base_lab = cv2.cvtColor(f_in, cv2.COLOR_BGR2LAB)
    color_lab = cv2.cvtColor(f_de, cv2.COLOR_BGR2LAB)
    merged_lab = base_lab.copy()

    region_count, region_ids = cv2.connectedComponents(selected.astype(np.uint8))
    # Label 0 is the background, so start at 1.
    for region_id in range(1, region_count):
        in_region = region_ids == region_id

        if USE_CUSTOM_AB:
            # Same fixed chroma for every region.
            chroma_a, chroma_b = CUSTOM_A, CUSTOM_B
        else:
            # Flatten the region to DeOldify's average chroma.
            a_samples = color_lab[..., 1][in_region]
            b_samples = color_lab[..., 2][in_region]
            if a_samples.size == 0:
                continue
            chroma_a = int(np.mean(a_samples))
            chroma_b = int(np.mean(b_samples))

        merged_lab[in_region, 1] = chroma_a
        merged_lab[in_region, 2] = chroma_b

    recolored = cv2.cvtColor(merged_lab, cv2.COLOR_LAB2BGR)
    result[selected] = recolored[selected]
    return result


def process_image(input_image):
    """Run the YOLO / DeOldify / SAM / fusion pipeline on a single image.

    All outputs are written to ``OUTPUT_ROOT/<image stem>/``:
    ``_yolo.png`` (annotated detections), ``_yolo_results.pkl`` (raw boxes),
    ``_deoldify.png`` (colorized grayscale version), ``_sam.png`` (binary
    mask) and ``_final.png`` (fused result).

    SAM is best effort: if the ComfyUI call fails or its result cannot be
    read, an empty mask is used and the final image equals the input.

    Args:
        input_image: Path of the image to process.

    Returns:
        Dict with keys ``yolo``, ``deoldify``, ``sam`` and ``final`` mapping
        to the written file paths.

    Raises:
        FileNotFoundError: If ``input_image`` cannot be read.
    """
    name = os.path.splitext(os.path.basename(input_image))[0]
    out_dir = os.path.join(OUTPUT_ROOT, name)
    os.makedirs(out_dir, exist_ok=True)

    # ---- Load frame ----
    frame_bgr = cv2.imread(input_image)
    if frame_bgr is None:
        raise FileNotFoundError(f"Could not load image {input_image}")
    h, w = frame_bgr.shape[:2]

    # ---- Stage 1: YOLO ----
    yolo_results = yolo_model.predict(frame_bgr, conf=CONF_THRESHOLD, verbose=False, device=device_str)
    yolo_frame = yolo_results[0].plot()
    yolo_path = os.path.join(out_dir, f"{name}_yolo.png")
    cv2.imwrite(yolo_path, yolo_frame)

    results_dict = {
        "boxes": yolo_results[0].boxes.xyxy.cpu().numpy(),
        "conf": yolo_results[0].boxes.conf.cpu().numpy(),
        "cls": yolo_results[0].boxes.cls.cpu().numpy()
    }
    with open(os.path.join(out_dir, f"{name}_yolo_results.pkl"), "wb") as f:
        pickle.dump([results_dict], f)

    # ---- Stage 2: DeOldify ----
    # Convert from the BGR frame: applying COLOR_BGR2GRAY to an RGB array
    # would swap the red and blue luminance weights.
    f_gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
    f_gray3 = cv2.cvtColor(f_gray, cv2.COLOR_GRAY2RGB)
    deold = deoldify_inference(f_gray3)
    deold_bgr = cv2.cvtColor(deold, cv2.COLOR_RGB2BGR)
    deoldify_path = os.path.join(out_dir, f"{name}_deoldify.png")
    cv2.imwrite(deoldify_path, deold_bgr)

    # ---- Stage 3: SAM (ComfyUI workflow on full image) ----
    try:
        seg_path = run_sam_on_frame(input_image, comfy_server=COMFY)
        seg_img = cv2.imread(seg_path)
        if seg_img is None:
            raise RuntimeError(f"Could not read SAM output {seg_path}")
        sam_gray = cv2.cvtColor(seg_img, cv2.COLOR_BGR2GRAY)
        _, sam_mask = cv2.threshold(sam_gray, 1, 255, cv2.THRESH_BINARY)
        if sam_mask.shape[:2] != (h, w):
            # The workflow output may not match the input resolution;
            # bitwise_and below needs identical shapes.
            sam_mask = cv2.resize(sam_mask, (w, h), interpolation=cv2.INTER_NEAREST)
    except Exception as e:
        # Deliberate best effort: fall back to an empty mask.
        print(f"‚ö†Ô∏è SAM failed on image: {e}")
        sam_mask = np.zeros((h, w), dtype=np.uint8)

    sam_path = os.path.join(out_dir, f"{name}_sam.png")
    cv2.imwrite(sam_path, sam_mask)

    # ---- Stage 4: Fusion (YOLO AND SAM mask) ----
    yolo_mask = np.zeros((h, w), dtype=np.uint8)
    for box, conf, cls in zip(results_dict["boxes"], results_dict["conf"], results_dict["cls"]):
        if int(cls) != 0 or conf < CONF_THRESHOLD:  # only "person"
            continue
        x1, y1, x2, y2 = map(int, box)
        yolo_mask[y1:y2, x1:x2] = 255

    intersect = cv2.bitwise_and(sam_mask, yolo_mask)
    fusion_frame = apply_region_color_transfer(frame_bgr, deold_bgr, intersect)

    fusion_path = os.path.join(out_dir, f"{name}_final.png")
    cv2.imwrite(fusion_path, fusion_frame)

    print(f"[INFO] Outputs written to {out_dir}")
    return {
        "yolo": yolo_path,
        "deoldify": deoldify_path,
        "sam": sam_path,
        "final": fusion_path
    }


# ---- Example ----
# Run the single-image pipeline on a sample frame and list every output path.
if __name__ == "__main__":
    input_image ="input_videos/first_frame.jpg"
    outputs = process_image(input_image)
    print("Pipeline outputs:")
    for k, v in outputs.items():
        print(f" - {k}: {v}")


[INFO] Outputs written to outputs/first_frame
Pipeline outputs:
 - yolo: outputs/first_frame/first_frame_yolo.png
 - deoldify: outputs/first_frame/first_frame_deoldify.png
 - sam: outputs/first_frame/first_frame_sam.png
 - final: outputs/first_frame/first_frame_final.png


In [None]:
import cv2

# Read an image (BGR format by default in OpenCV)
img_bgr = cv2.imread("input_videos/input_100.jpg")
# cv2.imread returns None instead of raising when the file is missing or
# unreadable; fail here with a clear message rather than inside cvtColor.
if img_bgr is None:
    raise FileNotFoundError("Could not load image input_videos/input_100.jpg")

# Convert BGR -> Grayscale
img_gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

# Save the grayscale image
cv2.imwrite("input_videos/input_100_g.jpg", img_gray)

# (Optional) Display images; blocks until a key is pressed in one of the windows
cv2.imshow("Original BGR", img_bgr)
cv2.imshow("Grayscale", img_gray)
cv2.waitKey(0)
cv2.destroyAllWindows()


In [None]:
video with yolo clipping

In [3]:
import os
import cv2
import torch
# Startup diagnostics: report whether torch can see a CUDA device and which one.
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Using:", torch.cuda.get_device_name(0))
import numpy as np
import pickle
from tqdm import tqdm
from ultralytics import YOLO
from deoldify.visualize import get_image_colorizer
from deoldify import device
from deoldify.device_id import DeviceId
from PIL import Image
import uuid, json, requests, time



# ==== CONFIG ====
YOLO_MODEL_PATH = "models/yolo11x-seg.pt"  # weights file loaded by YOLO() below
CONF_THRESHOLD = 0.6  # min YOLO confidence (passed to predict and re-checked per box)
OUTPUT_ROOT = "outputs"  # per-input output folders are created under this directory
COMFY = "http://192.168.27.13:23476"  # base URL of the ComfyUI server
WORKFLOW_JSON = "ClothesDetect_api.json"  # workflow file loaded for every SAM request

# ---- NEW CONFIG ----
BBOX_ENLARGE = 0.2       # pad each side of a YOLO box by 20% of its width/height
TOP_K_BBOX = 6           # number of top boxes per frame to run SAM
GDINO_PROMPT = "clothes" # grounding dino prompt
GDINO_THRESHOLD = 0.35   # grounding dino threshold
# =================


def patch_groundingdino_node(prompt_dict, new_prompt=None, new_threshold=None):
    """Override prompt/threshold on the first GroundingDinoSAMSegment node.

    Nodes are matched case-insensitively by a ``class_type`` that starts with
    ``groundingdinosamsegment``. Only the first match is modified, in place.

    Args:
        prompt_dict: Workflow mapping of node id -> node dict.
        new_prompt: Replacement for ``inputs["prompt"]``; ``None`` keeps it.
        new_threshold: Replacement for ``inputs["threshold"]``; ``None`` keeps it.

    Returns:
        True if a matching node was found, otherwise False.
    """
    overrides = {"prompt": new_prompt, "threshold": new_threshold}
    for node in prompt_dict.values():
        class_name = node.get("class_type", "").lower()
        if not class_name.startswith("groundingdinosamsegment"):
            continue
        for key, value in overrides.items():
            if value is not None:
                node["inputs"][key] = value
        return True
    return False



# =================

# ---- Setup DeOldify ----
# NOTE(review): DeOldify is always pointed at GPU0 here, even though the YOLO
# setup below falls back to CPU when CUDA is unavailable — confirm intended.
device.set(device=DeviceId.GPU0)
# Module-level colorizer shared by deoldify_inference().
colorizer = get_image_colorizer(artistic=True)

# ---- Setup YOLO ----
# Use CUDA when torch reports it as available, otherwise run YOLO on the CPU.
device_str = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {device_str}")
yolo_model = YOLO(YOLO_MODEL_PATH).to(device_str)


# ---- DeOldify inference ----
def deoldify_inference(frame_rgb):
    """Colorize a single frame with the module-level DeOldify ``colorizer``.

    Args:
        frame_rgb: Image array accepted by ``PIL.Image.fromarray``; callers in
            this file pass frames in RGB channel order.

    Returns:
        The colorizer's output converted to a NumPy array.
    """
    source = Image.fromarray(frame_rgb).convert("RGB")
    colorized = colorizer.get_transformed_image(
        source,
        render_factor=16,
        post_process=True,
    )
    return np.array(colorized)


# ---- ComfyUI helpers ----
def upload_image_to_comfy(local_path, server=COMFY, *, dest_name=None, folder_type="input"):
    """Upload a local file to the ComfyUI server via ``/upload/image``.

    Args:
        local_path: Path of the file to send.
        server: Base URL of the ComfyUI server.
        dest_name: Name to upload under; defaults to the basename of
            ``local_path``.
        folder_type: Value sent as the ``type`` form field.

    Returns:
        The name the file was uploaded under.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    upload_name = os.path.basename(local_path) if dest_name is None else dest_name
    form_fields = {"type": folder_type, "overwrite": "true"}
    with open(local_path, "rb") as handle:
        response = requests.post(
            f"{server}/upload/image",
            files={"image": (upload_name, handle, "image/png")},
            data=form_fields,
            timeout=60,
        )
        response.raise_for_status()
    return upload_name





def patch_loadimage_node(prompt_dict, new_filename):
    """Point the first LoadImage node of a workflow at ``new_filename``.

    The node is matched by a case-insensitive ``class_type`` equal to
    ``loadimage`` and is modified in place.

    Returns:
        True if a LoadImage node was found and patched, otherwise False.
    """
    load_nodes = (
        node
        for node in prompt_dict.values()
        if node.get("class_type", "").lower() == "loadimage"
    )
    target = next(load_nodes, None)
    if target is None:
        return False
    target["inputs"]["image"] = new_filename
    return True



def queue_prompt(prompt_dict, server=COMFY):
    """Submit a workflow to ComfyUI's ``/prompt`` endpoint.

    A fresh random client id is sent with every submission.

    Returns:
        The ``prompt_id`` from the JSON response, or the generated client id
        when the response carries no ``prompt_id``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    client_id = str(uuid.uuid4())
    payload = {"prompt": prompt_dict, "client_id": client_id}
    response = requests.post(f"{server}/prompt", json=payload, timeout=120)
    response.raise_for_status()
    reply = response.json()
    return reply.get("prompt_id", client_id)

def get_history(prompt_id, server=COMFY):
    """Fetch ``/history/<prompt_id>`` from ComfyUI and return the parsed JSON.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    history_url = f"{server}/history/{prompt_id}"
    response = requests.get(history_url, timeout=60)
    response.raise_for_status()
    return response.json()

def download_image(filename, server=COMFY, folder_type="output", subfolder="", to_path=None, save_dir=None):
    """Download one image from ComfyUI's ``/view`` endpoint to disk.

    Args:
        filename: Image name as reported by the server.
        server: Base URL of the ComfyUI server.
        folder_type: Value sent as the ``type`` query parameter.
        subfolder: Value sent as the ``subfolder`` query parameter.
        to_path: Exact destination path. When omitted, the file is stored as
            ``<save_dir>/<filename>``.
        save_dir: Destination directory used only when ``to_path`` is omitted;
            defaults to ``<OUTPUT_ROOT>/comfy_downloads``.

    Returns:
        The path the image was written to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if to_path is None:
        if save_dir is None:
            save_dir = os.path.join(OUTPUT_ROOT, "comfy_downloads")
        to_path = os.path.join(save_dir, filename)

    params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    r = requests.get(f"{server}/view", params=params, timeout=60)
    r.raise_for_status()

    # Create only the directory that is actually written to. Previously the
    # default download folder was created even when ``to_path`` pointed
    # elsewhere, while the parent of ``to_path`` itself was never ensured.
    target_dir = os.path.dirname(to_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    with open(to_path, "wb") as f:
        f.write(r.content)

    return to_path


def run_sam_on_frame(frame_path, comfy_server=COMFY):
    """Run one image file through the ComfyUI workflow and fetch its mask.

    The image is uploaded, the workflow from ``WORKFLOW_JSON`` is patched to
    load it (and to use the current ``GDINO_PROMPT`` / ``GDINO_THRESHOLD``),
    and the history endpoint is polled every 0.5 s for up to 600 s.

    Args:
        frame_path: Path of the image (typically a crop) to segment.
        comfy_server: Base URL of the ComfyUI server.

    Returns:
        Path of the downloaded result, stored beside ``frame_path`` as
        ``ComfyUI_<stem>.png``.

    Raises:
        RuntimeError: If the workflow has no LoadImage node, or no output
            image was obtained before the deadline.
    """
    uploaded = upload_image_to_comfy(frame_path, server=comfy_server)

    with open(WORKFLOW_JSON, "r") as f:
        workflow = json.load(f)

    if not patch_loadimage_node(workflow, uploaded):
        raise RuntimeError("Could not patch LoadImage node in workflow JSON.")

    # Apply the module-level Grounding DINO settings (read at call time).
    patch_groundingdino_node(workflow, new_prompt=GDINO_PROMPT, new_threshold=GDINO_THRESHOLD)

    prompt_id = queue_prompt(workflow, server=comfy_server)
    deadline = time.time() + 600

    # Every download goes to the same local file next to the input image.
    stem = os.path.splitext(os.path.basename(frame_path))[0]
    local_mask_path = os.path.join(os.path.dirname(frame_path), f"ComfyUI_{stem}.png")

    seg_path = None
    while time.time() < deadline:
        entry = get_history(prompt_id, server=comfy_server).get(prompt_id)
        node_outputs = entry["outputs"] if entry and "outputs" in entry else {}
        # NOTE: the first image of *each* output node is downloaded to the
        # same path, so with several output nodes the last one wins.
        for node_out in node_outputs.values():
            for im in node_out.get("images", []):
                seg_path = download_image(
                    im["filename"],
                    server=comfy_server,
                    subfolder=im.get("subfolder", ""),
                    folder_type=im.get("type", "output"),
                    to_path=local_mask_path,
                )
                break
        if seg_path:
            break
        time.sleep(0.5)

    if not seg_path:
        raise RuntimeError(f"No outputs from ComfyUI for {frame_path}")
    return seg_path

# ---- Stage 3: SAM with YOLO BBoxes ----
def run_sam(input_path, out_dir, name, yolo_results):
    """Build a mask video by running the ComfyUI SAM workflow on YOLO crops.

    For every frame, the ``TOP_K_BBOX`` most confident YOLO boxes are
    considered; boxes that are not class id 0 or fall below
    ``CONF_THRESHOLD`` are skipped. Each remaining box is padded by
    ``BBOX_ENLARGE`` on every side, cropped, segmented through
    ``run_sam_on_frame`` and pasted back into a full-frame mask. The per-frame
    masks are OR-ed together, saved as ``ComfyUI_frame_<idx>.png`` and finally
    assembled into ``<name>_sam.mp4``.

    Args:
        input_path: Source video path.
        out_dir: Directory receiving the mask frames folder and the video.
        name: Base name used for output files.
        yolo_results: Per-frame list of dicts with ``boxes``, ``conf`` and
            ``cls`` arrays. Frames with a missing or ``None`` entry get a
            blank mask.

    Returns:
        Path of the SAM mask video (an existing file is reused as a cache).
    """
    sam_frames_dir = os.path.join(out_dir, f"{name}_sam_frames")
    os.makedirs(sam_frames_dir, exist_ok=True)
    sam_path = os.path.join(out_dir, f"{name}_sam.mp4")

    if os.path.exists(sam_path):
        print(f"[CACHE] Using cached SAM video: {sam_path}")
        return sam_path

    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    frame_idx = 0
    try:
        with tqdm(total=total_frames, desc="SAM with YOLO", unit="frame") as pbar:
            while True:
                ret, frame_bgr = cap.read()
                if not ret:
                    break

                # The results list may be shorter than the video or hold None
                # entries; treat such frames as "no detections" rather than
                # crashing with IndexError/TypeError.
                results = yolo_results[frame_idx] if frame_idx < len(yolo_results) else None
                if results is None:
                    boxes, confs, clses = (), (), ()
                    order = ()
                else:
                    boxes, confs, clses = results["boxes"], results["conf"], results["cls"]
                    # Highest-confidence boxes first, capped at TOP_K_BBOX.
                    order = np.argsort(confs)[::-1][:TOP_K_BBOX]

                # Starts blank, so a frame without usable boxes still yields a mask.
                final_mask = np.zeros((height, width), dtype=np.uint8)

                for i in order:
                    if int(clses[i]) != 0 or confs[i] < CONF_THRESHOLD:
                        continue
                    x1, y1, x2, y2 = map(int, boxes[i])

                    # Enlarge the box on every side, clamped to the frame.
                    bw = x2 - x1
                    bh = y2 - y1
                    x1 = max(0, int(x1 - BBOX_ENLARGE * bw))
                    y1 = max(0, int(y1 - BBOX_ENLARGE * bh))
                    x2 = min(width, int(x2 + BBOX_ENLARGE * bw))
                    y2 = min(height, int(y2 + BBOX_ENLARGE * bh))

                    crop = frame_bgr[y1:y2, x1:x2]
                    if crop.size == 0:
                        continue

                    crop_name = f"frame_{frame_idx:06d}_box{i}.png"
                    crop_path = os.path.join(sam_frames_dir, crop_name)
                    # run_sam_on_frame stores its result beside the crop as
                    # "ComfyUI_<crop name>". Build that path directly instead
                    # of string-replacing on the full path, which produced a
                    # wrong path whenever a directory name contained "frame_"
                    # or ".png".
                    box_seg = os.path.join(sam_frames_dir, f"ComfyUI_{crop_name}")
                    cv2.imwrite(crop_path, crop)

                    try:
                        seg_path = run_sam_on_frame(crop_path, comfy_server=COMFY)
                        seg_img = cv2.imread(seg_path)
                        if seg_img is None:
                            raise RuntimeError(f"Could not read SAM output {seg_path}")
                        seg_resized = cv2.resize(seg_img, (x2 - x1, y2 - y1))
                        mask = np.zeros((height, width), dtype=np.uint8)
                        mask[y1:y2, x1:x2] = cv2.cvtColor(seg_resized, cv2.COLOR_BGR2GRAY)
                        final_mask = cv2.bitwise_or(final_mask, mask)
                    except Exception as e:
                        # Best effort: a failed box is reported and skipped so
                        # the rest of the video can still be processed.
                        print(f"⚠️ SAM failed on frame {frame_idx}, box {i}: {e}")
                    finally:
                        # Remove the crop and the box-level SAM output.
                        for leftover in (crop_path, box_seg):
                            if os.path.exists(leftover):
                                os.remove(leftover)

                out_path = os.path.join(sam_frames_dir, f"ComfyUI_frame_{frame_idx:06d}.png")
                cv2.imwrite(out_path, final_mask)

                frame_idx += 1
                pbar.update(1)
    finally:
        # Release the capture even if a frame fails unexpectedly.
        cap.release()

    # Build the video ONLY from the final per-frame masks.
    sam_files = sorted(
        f for f in os.listdir(sam_frames_dir)
        if f.startswith("ComfyUI_frame_") and "_box" not in f
    )
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(sam_path, fourcc, fps, (width, height))
    for seg_file in tqdm(sam_files, desc="Building SAM video", unit="frame"):
        img = cv2.imread(os.path.join(sam_frames_dir, seg_file))
        writer.write(cv2.resize(img, (width, height)))
    writer.release()
    return sam_path


# ---- Stage 1: YOLO ----
def run_yolo(input_path, out_dir, name):
    """Run YOLO on every frame, writing an annotated video and raw results.

    Outputs are ``<name>_yolo.mp4`` (frames rendered with ``plot()``) and
    ``<name>_yolo_results.pkl`` (one dict per frame with ``boxes``, ``conf``
    and ``cls`` arrays). When both files already exist they are reused.

    Returns:
        Tuple ``(yolo_video_path, results_per_frame)``.
    """
    yolo_path = os.path.join(out_dir, f"{name}_yolo.mp4")
    results_path = os.path.join(out_dir, f"{name}_yolo_results.pkl")

    cache_ready = os.path.exists(yolo_path) and os.path.exists(results_path)
    if cache_ready:
        print(f"[CACHE] Using cached YOLO + results: {yolo_path}")
        with open(results_path, "rb") as f:
            cached_results = pickle.load(f)
        return yolo_path, cached_results

    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25  # fall back to 25 when unknown
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    writer = cv2.VideoWriter(yolo_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)

    results_per_frame = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=total_frames, desc="YOLO", unit="frame") as pbar:
        ok, frame = cap.read()
        while ok:
            detection = yolo_model.predict(
                frame, conf=CONF_THRESHOLD, verbose=False, device=device_str
            )[0]
            writer.write(detection.plot())
            # Keep plain NumPy copies so the results can be pickled.
            results_per_frame.append({
                "boxes": detection.boxes.xyxy.cpu().numpy(),
                "conf": detection.boxes.conf.cpu().numpy(),
                "cls": detection.boxes.cls.cpu().numpy(),
            })
            pbar.update(1)
            ok, frame = cap.read()

    cap.release()
    writer.release()

    with open(results_path, "wb") as f:
        pickle.dump(results_per_frame, f)

    return yolo_path, results_per_frame


# ---- Stage 2: DeOldify ----
def run_deoldify(input_path, out_dir, name):
    """Colorize every frame of a video with DeOldify into ``<name>_deoldify.mp4``.

    Frames are passed to DeOldify in RGB as read from the source. An existing
    output file is reused as a cache.

    NOTE: this definition is replaced by a second ``run_deoldify`` defined
    later in the same cell, so it is not the one that ends up being called.

    Returns:
        Path of the DeOldify video.
    """
    deoldify_path = os.path.join(out_dir, f"{name}_deoldify.mp4")
    if os.path.exists(deoldify_path):
        print(f"[CACHE] Using cached DeOldify: {deoldify_path}")
        return deoldify_path

    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25  # fall back to 25 when unknown
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    writer = cv2.VideoWriter(deoldify_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=total_frames, desc="DeOldify", unit="frame") as pbar:
        ok, frame_bgr = cap.read()
        while ok:
            colorized = deoldify_inference(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
            writer.write(cv2.cvtColor(colorized, cv2.COLOR_RGB2BGR))
            pbar.update(1)
            ok, frame_bgr = cap.read()
    cap.release()
    writer.release()
    return deoldify_path



def run_deoldify(input_path, out_dir, name):
    """Colorize a video with DeOldify, feeding it grayscale frames.

    Each frame is first reduced to grayscale and expanded back to three
    channels before colorization, so any colour already present in the source
    is discarded. The result is written to ``<name>_deoldify.mp4``; an
    existing file is reused as a cache.

    Returns:
        Path of the DeOldify video.
    """
    deoldify_path = os.path.join(out_dir, f"{name}_deoldify.mp4")
    if os.path.exists(deoldify_path):
        print(f"[CACHE] Using cached DeOldify: {deoldify_path}")
        return deoldify_path

    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25  # fall back to 25 when unknown
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    writer = cv2.VideoWriter(deoldify_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=total_frames, desc="DeOldify", unit="frame") as pbar:
        ok, frame_bgr = cap.read()
        while ok:
            # Grayscale first, then back to a 3-channel RGB array.
            gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
            gray_rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

            colorized = deoldify_inference(gray_rgb)
            writer.write(cv2.cvtColor(colorized, cv2.COLOR_RGB2BGR))
            pbar.update(1)
            ok, frame_bgr = cap.read()

    cap.release()
    writer.release()
    return deoldify_path



# ---- Stage 4: Fusion ----
def run_fusion(input_path, out_dir, name, yolo_results, deoldify_path, sam_path):
    """Composite DeOldify colours into the source video wherever SAM marked.

    The three videos are read in lockstep. For each frame, pixels whose SAM
    mask value is above 1 are taken from the DeOldify frame; all others keep
    the source pixel. The result is written to ``<name>_final.mp4``; an
    existing file is reused as a cache.

    Args:
        input_path: Source video path.
        out_dir: Output directory.
        name: Base name used for the output file.
        yolo_results: Per-frame YOLO results; only its length is used, as one
            of the bounds on the number of frames processed.
        deoldify_path: Path of the colorized video.
        sam_path: Path of the mask video.

    Returns:
        Path of the fused video.
    """
    fusion_path = os.path.join(out_dir, f"{name}_final.mp4")
    if os.path.exists(fusion_path):
        print(f"[CACHE] Using cached Fusion: {fusion_path}")
        return fusion_path

    cap_input = cv2.VideoCapture(input_path)
    cap_deold = cv2.VideoCapture(deoldify_path)
    cap_sam = cv2.VideoCapture(sam_path)

    fps = int(cap_input.get(cv2.CAP_PROP_FPS)) or 25  # fall back to 25 when unknown
    frame_size = (
        int(cap_input.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap_input.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    writer = cv2.VideoWriter(fusion_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)

    # Process only as many frames as every input can supply.
    frame_counts = [
        len(yolo_results),
        cap_input.get(cv2.CAP_PROP_FRAME_COUNT),
        cap_deold.get(cv2.CAP_PROP_FRAME_COUNT),
        cap_sam.get(cv2.CAP_PROP_FRAME_COUNT),
    ]
    total_frames = int(min(frame_counts))

    with tqdm(total=total_frames, desc="Fusion", unit="frame") as pbar:
        for _ in range(total_frames):
            ok_in, source = cap_input.read()
            ok_deold, colorized = cap_deold.read()
            ok_sam, mask_frame = cap_sam.read()
            if not (ok_in and ok_deold and ok_sam):
                break

            # Binarise the mask: anything above 1 counts as "inside".
            mask_gray = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY)
            _, binary = cv2.threshold(mask_gray, 1, 255, cv2.THRESH_BINARY)
            inside = binary > 127

            fused = source.copy()
            fused[inside] = colorized[inside]

            writer.write(fused)
            pbar.update(1)

    for handle in (cap_input, cap_deold, cap_sam, writer):
        handle.release()

    print(f"[INFO] Fusion video saved: {fusion_path}")
    return fusion_path


# ---- Main Pipeline ----
def process_video(input_path):
    """Run the full video pipeline: YOLO, DeOldify, SAM, then fusion.

    All artifacts go to ``<OUTPUT_ROOT>/<video stem>/``.

    Returns:
        Dict with the paths of the ``yolo``, ``deoldify``, ``sam`` and
        ``final`` videos.
    """
    name = os.path.splitext(os.path.basename(input_path))[0]
    out_dir = os.path.join(OUTPUT_ROOT, name)
    os.makedirs(out_dir, exist_ok=True)

    yolo_path, yolo_results = run_yolo(input_path, out_dir, name)
    deoldify_path = run_deoldify(input_path, out_dir, name)
    sam_path = run_sam(input_path, out_dir, name, yolo_results)
    fusion_path = run_fusion(
        input_path, out_dir, name, yolo_results, deoldify_path, sam_path
    )

    print(f"[INFO] Outputs written to {out_dir}")
    outputs = {}
    outputs["yolo"] = yolo_path
    outputs["deoldify"] = deoldify_path
    outputs["sam"] = sam_path
    outputs["final"] = fusion_path
    return outputs



def process_video_cached(input_path):
    """Return the fused video for ``input_path``, reusing it when present.

    If ``<OUTPUT_ROOT>/<stem>/<stem>_final.mp4`` already exists its path is
    returned immediately; otherwise ``process_video`` is run and the path of
    its ``final`` output is returned.
    """
    stem = os.path.splitext(os.path.basename(input_path))[0]
    final_path = os.path.join(OUTPUT_ROOT, stem, f"{stem}_final.mp4")

    if not os.path.exists(final_path):
        return process_video(input_path)["final"]

    print(f"[CACHE] Final output exists: {final_path}")
    return final_path


# ---- Example ----
# if __name__ == "__main__":
#     input_video = "input_videos/thatha_manavadu_test.mp4"
#     outputs = process_video(input_video)
#     print("Pipeline outputs:")
#     for k, v in outputs.items():
#         print(f" - {k}: {v}")





CUDA available: True
CUDA device count: 1
Using: NVIDIA GeForce RTX 4060 Ti
[INFO] Using device: cuda


In [4]:
image with yolo clipping

SyntaxError: invalid syntax (4134125518.py, line 1)

In [10]:
# Rebind the Grounding DINO settings for this cell. run_sam_on_frame reads
# these globals on each call, so later calls use the new values.
GDINO_PROMPT = "arms" # grounding dino prompt
GDINO_THRESHOLD = 0.27   # grounding dino threshold


def process_image(input_image):
    """Run the YOLO -> DeOldify -> SAM -> fusion pipeline on a single image.

    Artifacts are written to ``<OUTPUT_ROOT>/<image stem>/``:
    the annotated YOLO image and pickled detections, the DeOldify
    colorization of the grayscale input, the combined SAM mask, and the fused
    image in which masked pixels are taken from the colorized image.

    Args:
        input_image: Path of the image to process.

    Returns:
        Dict with the paths of the ``yolo``, ``deoldify``, ``sam`` and
        ``final`` images.

    Raises:
        FileNotFoundError: If the image cannot be loaded.
    """
    name = os.path.splitext(os.path.basename(input_image))[0]
    out_dir = os.path.join(OUTPUT_ROOT, name)
    os.makedirs(out_dir, exist_ok=True)

    # ---- Load frame ----
    frame_bgr = cv2.imread(input_image)
    if frame_bgr is None:
        raise FileNotFoundError(f"Could not load image {input_image}")
    h, w = frame_bgr.shape[:2]

    # ---- Stage 1: YOLO ----
    yolo_results = yolo_model.predict(frame_bgr, conf=CONF_THRESHOLD, verbose=False, device=device_str)
    yolo_frame = yolo_results[0].plot()
    yolo_path = os.path.join(out_dir, f"{name}_yolo.png")
    cv2.imwrite(yolo_path, yolo_frame)

    results_dict = {
        "boxes": yolo_results[0].boxes.xyxy.cpu().numpy(),
        "conf": yolo_results[0].boxes.conf.cpu().numpy(),
        "cls": yolo_results[0].boxes.cls.cpu().numpy()
    }
    with open(os.path.join(out_dir, f"{name}_yolo_results.pkl"), "wb") as f:
        pickle.dump([results_dict], f)

    # ---- Stage 2: DeOldify ----
    # Grayscale must be computed from the BGR frame with COLOR_BGR2GRAY.
    # The previous code applied COLOR_BGR2GRAY to an RGB copy, which swapped
    # the red and blue luminance weights.
    f_gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
    f_gray3 = cv2.cvtColor(f_gray, cv2.COLOR_GRAY2RGB)

    deold = deoldify_inference(f_gray3)
    deold_bgr = cv2.cvtColor(deold, cv2.COLOR_RGB2BGR)
    deoldify_path = os.path.join(out_dir, f"{name}_deoldify.png")
    cv2.imwrite(deoldify_path, deold_bgr)

    # ---- Stage 3: SAM ----
    sam_frames_dir = os.path.join(out_dir, f"{name}_sam")
    os.makedirs(sam_frames_dir, exist_ok=True)

    boxes, confs, clses = results_dict["boxes"], results_dict["conf"], results_dict["cls"]
    # Highest-confidence boxes first, capped at TOP_K_BBOX.
    order = np.argsort(confs)[::-1][:TOP_K_BBOX]

    # Starts blank, so an image without usable boxes still yields a mask.
    final_mask = np.zeros((h, w), dtype=np.uint8)

    for i in order:
        if int(clses[i]) != 0 or confs[i] < CONF_THRESHOLD:
            continue
        x1, y1, x2, y2 = map(int, boxes[i])
        # Enlarge the box on every side, clamped to the image.
        bw, bh = x2 - x1, y2 - y1
        x1 = max(0, int(x1 - BBOX_ENLARGE * bw))
        y1 = max(0, int(y1 - BBOX_ENLARGE * bh))
        x2 = min(w, int(x2 + BBOX_ENLARGE * bw))
        y2 = min(h, int(y2 + BBOX_ENLARGE * bh))

        crop = frame_bgr[y1:y2, x1:x2]
        if crop.size == 0:
            continue

        crop_path = os.path.join(sam_frames_dir, f"{name}_box{i}.png")
        cv2.imwrite(crop_path, crop)

        try:
            seg_path = run_sam_on_frame(crop_path, comfy_server=COMFY)
            seg_img = cv2.imread(seg_path)
            if seg_img is None:
                raise RuntimeError(f"Could not read SAM output {seg_path}")

            seg_resized = cv2.resize(seg_img, (x2 - x1, y2 - y1))
            seg_gray = cv2.cvtColor(seg_resized, cv2.COLOR_BGR2GRAY)

            # Force a binary mask: anything above 1 becomes 255.
            _, seg_bin = cv2.threshold(seg_gray, 1, 255, cv2.THRESH_BINARY)

            mask = np.zeros((h, w), dtype=np.uint8)
            mask[y1:y2, x1:x2] = seg_bin
            final_mask = cv2.bitwise_or(final_mask, mask)

        except Exception as e:
            # Best effort: a failed box is reported and skipped.
            print(f"SAM failed on box {i}: {e}")

    sam_path = os.path.join(out_dir, f"{name}_sam.png")
    cv2.imwrite(sam_path, final_mask)

    # ---- Stage 4: Fusion (same thresholding as run_fusion) ----
    _, sam_mask = cv2.threshold(final_mask, 1, 255, cv2.THRESH_BINARY)
    mask_bool = sam_mask > 127

    fusion_frame = frame_bgr.copy()
    fusion_frame[mask_bool] = deold_bgr[mask_bool]

    fusion_path = os.path.join(out_dir, f"{name}_final.png")
    cv2.imwrite(fusion_path, fusion_frame)

    print(f"[INFO] Outputs written to {out_dir}")
    return {
        "yolo": yolo_path,
        "deoldify": deoldify_path,
        "sam": sam_path,
        "final": fusion_path
    }



# ---- Example ----
if __name__ == "__main__":
    # Run the single-image pipeline on one extracted frame, then list every
    # artifact path it returned (one entry per pipeline stage).
    input_image = r"input_videos/12th_from_last_frame.jpg"
    outputs = process_image(input_image)
    print("Pipeline outputs:")
    for k, v in outputs.items():
        print(f" - {k}: {v}")

[INFO] Outputs written to outputs/12th_from_last_frame
Pipeline outputs:
 - yolo: outputs/12th_from_last_frame/12th_from_last_frame_yolo.png
 - deoldify: outputs/12th_from_last_frame/12th_from_last_frame_deoldify.png
 - sam: outputs/12th_from_last_frame/12th_from_last_frame_sam.png
 - final: outputs/12th_from_last_frame/12th_from_last_frame_final.png


In [None]:
image with yolo clipping

In [23]:
# GDINO_PROMPT = "clothes" # grounding dino prompt
# GDINO_THRESHOLD = 0.30   # grounding dino threshold


# def process_image(input_image):
#     folder, fname = os.path.split(input_image)
#     name, _ = os.path.splitext(fname)
#     out_dir = os.path.join(OUTPUT_ROOT, name)
#     os.makedirs(out_dir, exist_ok=True)

#     # ---- Load frame ----
#     frame_bgr = cv2.imread(input_image)
#     if frame_bgr is None:
#         raise FileNotFoundError(f"Could not load image {input_image}")
#     frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
#     h, w = frame_bgr.shape[:2]

#     # ---- Stage 1: YOLO ----
#     yolo_results = yolo_model.predict(frame_bgr, conf=CONF_THRESHOLD, verbose=False, device=device_str)
#     yolo_frame = yolo_results[0].plot()
#     yolo_path = os.path.join(out_dir, f"{name}_yolo.png")
#     cv2.imwrite(yolo_path, yolo_frame)

#     results_dict = {
#         "boxes": yolo_results[0].boxes.xyxy.cpu().numpy(),
#         "conf": yolo_results[0].boxes.conf.cpu().numpy(),
#         "cls": yolo_results[0].boxes.cls.cpu().numpy()
#     }
#     with open(os.path.join(out_dir, f"{name}_yolo_results.pkl"), "wb") as f:
#         pickle.dump([results_dict], f)

#     # ---- Stage 2: DeOldify ----
#     deold = deoldify_inference(frame_rgb)
#     deold_bgr = cv2.cvtColor(deold, cv2.COLOR_RGB2BGR)
#     deoldify_path = os.path.join(out_dir, f"{name}_deoldify.png")
#     cv2.imwrite(deoldify_path, deold_bgr)

#     # ---- Stage 3: SAM ----
#     sam_frames_dir = os.path.join(out_dir, f"{name}_sam")
#     os.makedirs(sam_frames_dir, exist_ok=True)

#     boxes, confs, clses = results_dict["boxes"], results_dict["conf"], results_dict["cls"]
#     order = np.argsort(confs)[::-1][:TOP_K_BBOX]
#     masks_for_frame = []

#     for i in order:
#         if int(clses[i]) != 0 or confs[i] < CONF_THRESHOLD:
#             continue
#         x1, y1, x2, y2 = map(int, boxes[i])
#         bw, bh = x2 - x1, y2 - y1
#         x1 = max(0, int(x1 - BBOX_ENLARGE * bw))
#         y1 = max(0, int(y1 - BBOX_ENLARGE * bh))
#         x2 = min(w, int(x2 + BBOX_ENLARGE * bw))
#         y2 = min(h, int(y2 + BBOX_ENLARGE * bh))

#         crop = frame_bgr[y1:y2, x1:x2]
#         if crop.size == 0:
#             continue

#         crop_path = os.path.join(sam_frames_dir, f"{name}_box{i}.png")
#         cv2.imwrite(crop_path, crop)

#         try:
#             seg_path = run_sam_on_frame(crop_path, comfy_server=COMFY)
#             seg_img = cv2.imread(seg_path)
#             seg_resized = cv2.resize(seg_img, (x2 - x1, y2 - y1))
#             mask = np.zeros((h, w), dtype=np.uint8)
#             mask[y1:y2, x1:x2] = cv2.cvtColor(seg_resized, cv2.COLOR_BGR2GRAY)
#             masks_for_frame.append(mask)
#         except Exception as e:
#             print(f"SAM failed on box {i}: {e}")

#     # Save final mask (blank if none)
#     if masks_for_frame:
#         final_mask = np.zeros((h, w), dtype=np.uint8)
#         for m in masks_for_frame:
#             final_mask = cv2.bitwise_or(final_mask, m)
#     else:
#         final_mask = np.zeros((h, w), dtype=np.uint8)

#     sam_path = os.path.join(out_dir, f"{name}_sam.png")
#     cv2.imwrite(sam_path, final_mask)

#     # ---- Stage 4: Fusion (exact same logic as run_fusion) ----
#     gray = final_mask.copy()
#     _, sam_mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
#     mask_bool = sam_mask > 127

#     fusion_frame = frame_bgr.copy()
#     fusion_frame[mask_bool] = deold_bgr[mask_bool]

#     fusion_path = os.path.join(out_dir, f"{name}_final.png")
#     cv2.imwrite(fusion_path, fusion_frame)

#     print(f"[INFO] Outputs written to {out_dir}")
#     return {
#         "yolo": yolo_path,
#         "deoldify": deoldify_path,
#         "sam": sam_path,
#         "final": fusion_path
#     }



# # ---- Example ----
# if __name__ == "__main__":
#     input_image = r"input_videos/first_frame.jpg"
#     outputs = process_image(input_image)
#     print("Pipeline outputs:")
#     for k, v in outputs.items():
#         print(f" - {k}: {v}")

In [None]:
import os
import cv2
import torch
import numpy as np
import pickle
from tqdm import tqdm
from ultralytics import YOLO
from deoldify.visualize import get_image_colorizer
from deoldify import device
from deoldify.device_id import DeviceId
from PIL import Image
import uuid, json, requests, time

# ==== CONFIG ====
YOLO_MODEL_PATH = "models/yolo11x-seg.pt"  # weights file loaded by YOLO() below
CONF_THRESHOLD = 0.6  # min YOLO confidence passed to predict()
OUTPUT_ROOT = "outputs"  # root directory for downloads and pipeline outputs
COMFY = "http://192.168.27.13:23476"    # ComfyUI server
WORKFLOW_JSON = "ClothesDetect_api.json"  # workflow file loaded for every SAM request
# =================

# ---- Setup DeOldify ----
# NOTE(review): DeOldify is always pointed at GPU0 here, even though the YOLO
# setup below falls back to CPU when CUDA is unavailable — confirm intended.
device.set(device=DeviceId.GPU0)
# Module-level colorizer shared by deoldify_inference().
colorizer = get_image_colorizer(artistic=True)

# ---- Setup YOLO ----
# Use CUDA when torch reports it as available, otherwise run YOLO on the CPU.
device_str = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {device_str}")
yolo_model = YOLO(YOLO_MODEL_PATH).to(device_str)


# ---- DeOldify inference ----
def deoldify_inference(frame_rgb):
    """Colorize a single frame with the module-level DeOldify ``colorizer``.

    Args:
        frame_rgb: Image array accepted by ``PIL.Image.fromarray``; callers in
            this file pass frames in RGB channel order.

    Returns:
        The colorizer's output converted to a NumPy array.
    """
    as_pil = Image.fromarray(frame_rgb)
    as_pil = as_pil.convert("RGB")
    result = colorizer.get_transformed_image(
        as_pil,
        render_factor=16,
        post_process=True,
    )
    return np.array(result)


# ---- ComfyUI helpers ----
def upload_image_to_comfy(local_path, server=COMFY, *, dest_name=None, folder_type="input"):
    """Upload a local file to the ComfyUI server via ``/upload/image``.

    Args:
        local_path: Path of the file to send.
        server: Base URL of the ComfyUI server.
        dest_name: Name to upload under; defaults to the basename of
            ``local_path``.
        folder_type: Value sent as the ``type`` form field.

    Returns:
        The name the file was uploaded under.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    remote_name = dest_name
    if remote_name is None:
        remote_name = os.path.basename(local_path)

    endpoint = f"{server}/upload/image"
    with open(local_path, "rb") as stream:
        reply = requests.post(
            endpoint,
            files={"image": (remote_name, stream, "image/png")},
            data={"type": folder_type, "overwrite": "true"},
            timeout=60,
        )
        reply.raise_for_status()
    return remote_name

def patch_loadimage_node(prompt_dict, new_filename):
    """Point the first LoadImage node of a workflow at ``new_filename``.

    The node is matched by a case-insensitive ``class_type`` equal to
    ``loadimage`` and is modified in place.

    Returns:
        True if a LoadImage node was found and patched, otherwise False.
    """
    patched = False
    for node in prompt_dict.values():
        class_name = node.get("class_type", "")
        if class_name.lower() != "loadimage":
            continue
        node["inputs"]["image"] = new_filename
        patched = True
        break  # only the first LoadImage node is touched
    return patched

def queue_prompt(prompt_dict, server=COMFY):
    """Submit a workflow to ComfyUI's ``/prompt`` endpoint.

    A fresh random client id is sent with every submission.

    Returns:
        The ``prompt_id`` from the JSON response, or the generated client id
        when the response carries no ``prompt_id``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    fallback_id = str(uuid.uuid4())
    reply = requests.post(
        f"{server}/prompt",
        json={"prompt": prompt_dict, "client_id": fallback_id},
        timeout=120,
    )
    reply.raise_for_status()
    return reply.json().get("prompt_id", fallback_id)

def get_history(prompt_id, server=COMFY):
    """Fetch ``/history/<prompt_id>`` from ComfyUI and return the parsed JSON.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    reply = requests.get(f"{server}/history/{prompt_id}", timeout=60)
    reply.raise_for_status()
    history = reply.json()
    return history

def download_image(filename, server=COMFY, folder_type="output", subfolder="", to_path=None, save_dir=None):
    """Download one image from ComfyUI's ``/view`` endpoint to disk.

    Args:
        filename: Image name as reported by the server.
        server: Base URL of the ComfyUI server.
        folder_type: Value sent as the ``type`` query parameter.
        subfolder: Value sent as the ``subfolder`` query parameter.
        to_path: Exact destination path. When omitted, the file is stored as
            ``<save_dir>/<filename>``.
        save_dir: Destination directory used only when ``to_path`` is omitted;
            defaults to ``<OUTPUT_ROOT>/comfy_downloads``.

    Returns:
        The path the image was written to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if to_path is None:
        if save_dir is None:
            save_dir = os.path.join(OUTPUT_ROOT, "comfy_downloads")
        to_path = os.path.join(save_dir, filename)

    params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    r = requests.get(f"{server}/view", params=params, timeout=60)
    r.raise_for_status()

    # Create only the directory that is actually written to. Previously the
    # default download folder was created even when ``to_path`` pointed
    # elsewhere, while the parent of ``to_path`` itself was never ensured.
    target_dir = os.path.dirname(to_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    with open(to_path, "wb") as f:
        f.write(r.content)

    return to_path

def run_sam_on_frame(frame_rgb, comfy_server=COMFY):
    """Run one in-memory frame through the ComfyUI workflow.

    The frame is saved to a uniquely named temporary PNG in the working
    directory, uploaded, and the workflow from ``WORKFLOW_JSON`` is queued
    with its LoadImage node pointing at the upload. The history endpoint is
    then polled every 0.5 s for up to 60 s.

    Args:
        frame_rgb: Image array accepted by ``PIL.Image.fromarray``.
        comfy_server: Base URL of the ComfyUI server.

    Returns:
        Path of the downloaded output image, or ``None`` when no output
        appeared before the deadline.
    """
    tmp_path = f"temp_frame_{uuid.uuid4().hex}.png"
    seg_path = None
    try:
        Image.fromarray(frame_rgb).save(tmp_path)
        uploaded = upload_image_to_comfy(tmp_path, server=comfy_server)

        with open(WORKFLOW_JSON, "r") as f:
            prompt = json.load(f)
        # NOTE(review): the return value is ignored, so a workflow without a
        # LoadImage node is queued unchanged — confirm this is intended.
        patch_loadimage_node(prompt, uploaded)

        prompt_id = queue_prompt(prompt, server=comfy_server)
        deadline = time.time() + 60
        while time.time() < deadline:
            hist = get_history(prompt_id, server=comfy_server)
            item = hist.get(prompt_id)
            if item and "outputs" in item:
                # The first image of each output node is downloaded; the path
                # of the last download is the one returned.
                for node_out in item["outputs"].values():
                    for im in node_out.get("images", []):
                        seg_path = download_image(
                            im["filename"], server=comfy_server,
                            subfolder=im.get("subfolder", ""),
                            folder_type=im.get("type", "output")
                        )
                        break
            if seg_path:
                break
            time.sleep(0.5)
    finally:
        # Always remove the temporary PNG. Previously it was leaked whenever
        # the upload, queueing or polling raised.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return seg_path


# ---- Stage 1: YOLO ----
def run_yolo(input_path, out_dir, name):
    """Run YOLO per frame with frame-level caching, then build a video.

    Annotated frames are stored as PNGs in ``<name>_yolo_frames``; frames
    whose PNG already exists are not re-processed. The PNGs are then
    assembled into ``<name>_yolo.mp4`` and the per-frame results are pickled
    to ``<name>_yolo_results.pkl``. When both the video and the pickle already
    exist they are reused.

    NOTE(review): a frame skipped because its PNG exists contributes ``None``
    to the results list, and that list is what gets pickled — confirm that
    consumers tolerate ``None`` entries.

    Returns:
        Tuple ``(yolo_video_path, results_per_frame)``.
    """
    yolo_path = os.path.join(out_dir, f"{name}_yolo.mp4")
    frames_dir = os.path.join(out_dir, f"{name}_yolo_frames")
    results_path = os.path.join(out_dir, f"{name}_yolo_results.pkl")
    os.makedirs(frames_dir, exist_ok=True)

    def frame_file(index):
        # Location of the annotated PNG for one frame index.
        return os.path.join(frames_dir, f"frame_{index:05d}.png")

    cache_ready = os.path.exists(yolo_path) and os.path.exists(results_path)
    if cache_ready:
        print(f"[CACHE] Using cached YOLO: {yolo_path}")
        with open(results_path, "rb") as f:
            cached_results = pickle.load(f)
        return yolo_path, cached_results

    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25  # fall back to 25 when unknown
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    results_per_frame = []
    with tqdm(total=total_frames, desc="YOLO", unit="frame") as pbar:
        for frame_idx in range(total_frames):
            ok, frame = cap.read()
            if not ok:
                break

            frame_path = frame_file(frame_idx)
            if os.path.exists(frame_path):
                # Already rendered earlier; no detection data is kept for it.
                entry = None
            else:
                detection = yolo_model.predict(
                    frame, conf=CONF_THRESHOLD, verbose=False, device=device_str
                )[0]
                cv2.imwrite(frame_path, detection.plot())
                entry = {
                    "boxes": detection.boxes.xyxy.cpu().numpy(),
                    "conf": detection.boxes.conf.cpu().numpy(),
                    "cls": detection.boxes.cls.cpu().numpy(),
                }
            results_per_frame.append(entry)
            pbar.update(1)
    cap.release()

    # Assemble the video from whichever frame PNGs can be read.
    writer = cv2.VideoWriter(yolo_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    for frame_idx in range(total_frames):
        rendered = cv2.imread(frame_file(frame_idx))
        if rendered is None:
            continue
        writer.write(rendered)
    writer.release()

    with open(results_path, "wb") as f:
        pickle.dump(results_per_frame, f)

    return yolo_path, results_per_frame


# ---- Stage 2: DeOldify ----
def run_deoldify(input_path, out_dir, name):
    """Colorize every frame of ``input_path`` with DeOldify, resumably.

    Each colorized frame is written as ``frame_XXXXX.png`` inside
    ``<out_dir>/<name>_deoldify_frames``; once the loop finishes, the PNGs are
    assembled into ``<out_dir>/<name>_deoldify.mp4``.

    Args:
        input_path: Path of the source video.
        out_dir: Directory that receives the frame folder and the video.
        name: Base name used for the output file and frame folder.

    Returns:
        Path of the DeOldify video. If that file already exists it is
        returned immediately without any processing.
    """
    deoldify_path = os.path.join(out_dir, f"{name}_deoldify.mp4")
    frames_dir = os.path.join(out_dir, f"{name}_deoldify_frames")
    os.makedirs(frames_dir, exist_ok=True)

    if os.path.exists(deoldify_path):
        print(f"[CACHE] Using cached DeOldify: {deoldify_path}")
        return deoldify_path

    cap = cv2.VideoCapture(input_path)
    try:
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        with tqdm(total=total_frames, desc="DeOldify", unit="frame") as pbar:
            for frame_idx in range(total_frames):
                frame_path = os.path.join(frames_dir, f"frame_{frame_idx:05d}.png")
                if os.path.exists(frame_path):
                    # The capture must still advance past a cached frame;
                    # otherwise the next read() would return an earlier frame
                    # and it would be saved under the wrong index.
                    if not cap.grab():
                        break
                    pbar.update(1)
                    continue

                ret, frame_bgr = cap.read()
                if not ret:
                    break
                frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                deold = deoldify_inference(frame_rgb)
                cv2.imwrite(frame_path, cv2.cvtColor(deold, cv2.COLOR_RGB2BGR))
                pbar.update(1)
    finally:
        # Release the decoder even if inference or disk I/O raised.
        cap.release()

    # Assemble the per-frame PNGs into the output video; indices with no PNG
    # (e.g. after an early break) are skipped.
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(deoldify_path, fourcc, fps, (width, height))
    try:
        for frame_idx in range(total_frames):
            img = cv2.imread(os.path.join(frames_dir, f"frame_{frame_idx:05d}.png"))
            if img is not None:
                writer.write(img)
    finally:
        writer.release()
    return deoldify_path


# ---- Stage 3: SAM ----
def run_sam(input_path, out_dir, name):
    """Produce a SAM segmentation video for ``input_path``, resumably.

    Each frame is sent through ``run_sam_on_frame``; the returned image is
    resized to the source resolution and stored as ``frame_XXXXX.png`` in
    ``<out_dir>/<name>_sam_frames``. When segmentation fails for a frame the
    original frame is stored instead. The PNGs are finally assembled into
    ``<out_dir>/<name>_sam.mp4``.

    Args:
        input_path: Path of the source video.
        out_dir: Directory that receives the frame folder and the video.
        name: Base name used for the output file and frame folder.

    Returns:
        Path of the SAM video. If that file already exists it is returned
        immediately without any processing.
    """
    sam_path = os.path.join(out_dir, f"{name}_sam.mp4")
    frames_dir = os.path.join(out_dir, f"{name}_sam_frames")
    os.makedirs(frames_dir, exist_ok=True)

    if os.path.exists(sam_path):
        print(f"[CACHE] Using cached SAM: {sam_path}")
        return sam_path

    # One capture, read sequentially: reopening the file and seeking with
    # CAP_PROP_POS_FRAMES for every frame is slow and, depending on the codec,
    # may not land on the exact requested frame.
    cap = cv2.VideoCapture(input_path)
    try:
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        with tqdm(total=total_frames, desc="SAM", unit="frame") as pbar:
            for frame_idx in range(total_frames):
                frame_path = os.path.join(frames_dir, f"frame_{frame_idx:05d}.png")
                if os.path.exists(frame_path):
                    # Skip the cached frame but keep the decoder position in
                    # step with frame_idx.
                    if not cap.grab():
                        break
                    pbar.update(1)
                    continue

                ret, frame_bgr = cap.read()
                if not ret:
                    break

                frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                try:
                    seg_path = run_sam_on_frame(frame_rgb)
                    seg_img = cv2.imread(seg_path)
                    seg_resized = cv2.resize(seg_img, (width, height)) if seg_img is not None else frame_bgr
                    cv2.imwrite(frame_path, seg_resized)
                except Exception as e:
                    # Best effort: fall back to the unsegmented frame so the
                    # output keeps one image per source frame.
                    print(f"‚ö†Ô∏è SAM failed: {e}")
                    cv2.imwrite(frame_path, frame_bgr)

                pbar.update(1)
    finally:
        cap.release()

    # Assemble the per-frame PNGs into the output video.
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(sam_path, fourcc, fps, (width, height))
    try:
        for frame_idx in range(total_frames):
            img = cv2.imread(os.path.join(frames_dir, f"frame_{frame_idx:05d}.png"))
            if img is not None:
                writer.write(img)
    finally:
        writer.release()
    return sam_path


# ---- Stage 4: Fusion ----
def run_fusion(input_path, out_dir, name, yolo_results, deoldify_path, sam_path):
    """Paste DeOldify colour onto the original video where SAM and YOLO agree.

    For every frame a binary SAM mask (pixels of the SAM video brighter than 1)
    is intersected with the union of YOLO boxes of class 0 whose confidence is
    at least ``CONF_THRESHOLD``. Pixels inside the intersection are taken from
    the DeOldify video; all others are kept from the original frame.

    Args:
        input_path: Path of the original video (base frames, fps, size).
        out_dir: Directory that receives ``<name>_final.mp4``.
        name: Base name used for the output file.
        yolo_results: Per-frame detections as produced by ``run_yolo``: either
            a dict with ``"boxes"``, ``"conf"`` and ``"cls"`` arrays, or
            ``None`` for frames whose detections were not recomputed.
        deoldify_path: Path of the DeOldify video.
        sam_path: Path of the SAM video.

    Returns:
        Path of the fused video. If that file already exists it is returned
        immediately without any processing.
    """
    fusion_path = os.path.join(out_dir, f"{name}_final.mp4")
    if os.path.exists(fusion_path):
        print(f"[CACHE] Using cached Fusion: {fusion_path}")
        return fusion_path

    cap_input = cv2.VideoCapture(input_path)      # original frames
    cap_deold = cv2.VideoCapture(deoldify_path)   # deoldify video
    cap_sam = cv2.VideoCapture(sam_path)          # sam masks

    fps = int(cap_input.get(cv2.CAP_PROP_FPS)) or 25
    width = int(cap_input.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap_input.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(fusion_path, fourcc, fps, (width, height))

    frames_without_detections = 0
    try:
        # Only fuse as many frames as every source can supply.
        total_frames = int(min(
            len(yolo_results),
            cap_input.get(cv2.CAP_PROP_FRAME_COUNT),
            cap_deold.get(cv2.CAP_PROP_FRAME_COUNT),
            cap_sam.get(cv2.CAP_PROP_FRAME_COUNT)
        ))

        frame_idx = 0
        with tqdm(total=total_frames, desc="Fusion", unit="frame") as pbar:
            while frame_idx < total_frames:
                ret_in, frame_in = cap_input.read()
                ret_deold, frame_deold = cap_deold.read()
                ret_sam, frame_sam = cap_sam.read()
                if not (ret_in and ret_deold and ret_sam):
                    break

                # SAM mask
                gray = cv2.cvtColor(frame_sam, cv2.COLOR_BGR2GRAY)
                _, sam_mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)

                # YOLO mask
                yolo_mask = np.zeros_like(sam_mask, dtype=np.uint8)
                detections = yolo_results[frame_idx]
                if detections is None:
                    # run_yolo stores None for frames it skipped on resume;
                    # treat that as "no boxes" instead of crashing on None[...].
                    frames_without_detections += 1
                else:
                    for box, conf, cls in zip(
                        detections["boxes"],
                        detections["conf"],
                        detections["cls"]
                    ):
                        if int(cls) != 0 or conf < CONF_THRESHOLD:  # only "person"
                            continue
                        x1, y1, x2, y2 = map(int, box)
                        # A negative start index would wrap around in slicing.
                        x1, y1 = max(x1, 0), max(y1, 0)
                        yolo_mask[y1:y2, x1:x2] = 255

                # Intersection
                intersect = cv2.bitwise_and(sam_mask, yolo_mask)
                mask_bool = intersect > 127

                # Fusion: base is ORIGINAL frame
                fusion_frame = frame_in.copy()
                fusion_frame[mask_bool] = frame_deold[mask_bool]

                writer.write(fusion_frame)
                frame_idx += 1
                pbar.update(1)
    finally:
        # Release everything even when a frame fails, so the partially written
        # video is closed and the decoders are freed.
        cap_input.release()
        cap_deold.release()
        cap_sam.release()
        writer.release()

    if frames_without_detections:
        print(f"[WARN] {frames_without_detections} frame(s) had no YOLO detections available and were left uncolored")
    print(f"[INFO] Fusion video saved: {fusion_path}")
    return fusion_path


# ---- Main Pipeline ----
def process_video(input_path):
    """Run the YOLO, DeOldify, SAM and fusion stages on one video.

    All artifacts go to ``OUTPUT_ROOT/<video name without extension>``.

    Returns:
        Dict mapping ``"yolo"``, ``"deoldify"``, ``"sam"`` and ``"final"`` to
        the path of the corresponding output video.
    """
    name = os.path.splitext(os.path.basename(input_path))[0]
    out_dir = os.path.join(OUTPUT_ROOT, name)
    os.makedirs(out_dir, exist_ok=True)

    artifacts = {}
    # Stages run in a fixed order; fusion consumes the results of the others.
    artifacts["yolo"], yolo_results = run_yolo(input_path, out_dir, name)
    artifacts["deoldify"] = run_deoldify(input_path, out_dir, name)
    artifacts["sam"] = run_sam(input_path, out_dir, name)
    artifacts["final"] = run_fusion(
        input_path,
        out_dir,
        name,
        yolo_results,
        artifacts["deoldify"],
        artifacts["sam"],
    )

    print(f"[INFO] Outputs written to {out_dir}")
    return artifacts


# ---- Example ----
# Entry point: run the staged pipeline on one hard-coded sample video and
# print the path of every artifact returned by process_video.
if __name__ == "__main__":
    input_video = "input_videos/thatha_manavadu_test.mp4"
    outputs = process_video(input_video)
    print("Pipeline outputs:")
    # outputs maps a stage name to the path of that stage's video.
    for k, v in outputs.items():
        print(f" - {k}: {v}")


In [12]:
import os
import cv2
import torch
import numpy as np
from tqdm import tqdm
from ultralytics import YOLO
from deoldify.visualize import get_image_colorizer
from deoldify import device
from deoldify.device_id import DeviceId
from PIL import Image
import uuid, json, requests, time

# ==== CONFIG ====
YOLO_MODEL_PATH = "models/yolo11x-seg.pt"  # weights file handed to YOLO() below
CONF_THRESHOLD = 0.6  # passed as conf= to yolo_model.predict and re-checked per box during fusion
OUTPUT_ROOT = "outputs"  # process_video creates one sub-directory per input video here
COMFY = "http://192.168.27.13:23476"    # base URL of the ComfyUI server used by the helpers below
WORKFLOW_JSON = "ClothesDetect_api.json"  # workflow file loaded by run_sam_on_frame
# =================

# ---- Setup DeOldify ----
# The device must be selected before the colorizer is built.
# NOTE(review): GPU0 is requested unconditionally, while the recorded output of
# this cell reports YOLO running on "cpu" — confirm how DeOldify behaves
# without CUDA.
device.set(device=DeviceId.GPU0)
colorizer = get_image_colorizer(artistic=True)

# ---- Setup YOLO ----
# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device_str = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {device_str}")
yolo_model = YOLO(YOLO_MODEL_PATH).to(device_str)


# ---- DeOldify inference ----
def deoldify_inference(frame_rgb):
    """Colorize one RGB frame with the module-level ``colorizer``.

    Args:
        frame_rgb: Image array accepted by ``PIL.Image.fromarray``.

    Returns:
        The colorizer's output converted with ``np.array``.
    """
    source = Image.fromarray(frame_rgb).convert("RGB")
    colorized = colorizer.get_transformed_image(
        source,
        render_factor=16,
        post_process=True,
    )
    return np.array(colorized)


# ---- ComfyUI helpers ----
def upload_image_to_comfy(local_path, server=COMFY, *, dest_name=None, folder_type="input"):
    """POST a local image to ``<server>/upload/image``.

    Args:
        local_path: File to upload.
        server: Base URL of the ComfyUI server.
        dest_name: Name to send the file under; defaults to the basename of
            ``local_path``.
        folder_type: Value of the ``type`` form field.

    Returns:
        The name the file was uploaded under.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    remote_name = os.path.basename(local_path) if dest_name is None else dest_name
    form = {"type": folder_type, "overwrite": "true"}
    with open(local_path, "rb") as handle:
        response = requests.post(
            f"{server}/upload/image",
            files={"image": (remote_name, handle, "image/png")},
            data=form,
            timeout=60,
        )
        response.raise_for_status()
    return remote_name

def patch_loadimage_node(prompt_dict, new_filename):
    """Point the first LoadImage node of a workflow at ``new_filename``.

    Nodes are matched on ``class_type`` case-insensitively; only the first
    match (in dict order) is modified, in place.

    Returns:
        True if a node was patched, False if the workflow has no such node.
    """
    load_node = next(
        (
            candidate
            for candidate in prompt_dict.values()
            if candidate.get("class_type", "").lower() == "loadimage"
        ),
        None,
    )
    if load_node is None:
        return False
    load_node["inputs"]["image"] = new_filename
    return True

def queue_prompt(prompt_dict, server=COMFY):
    """Submit a workflow to ``<server>/prompt`` and return its prompt id.

    A fresh UUID is sent as ``client_id``; it is also the fallback return
    value when the response carries no ``prompt_id``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    client_id = str(uuid.uuid4())
    payload = {"prompt": prompt_dict, "client_id": client_id}
    response = requests.post(f"{server}/prompt", json=payload, timeout=120)
    response.raise_for_status()
    reply = response.json()
    return reply.get("prompt_id", client_id)

def get_history(prompt_id, server=COMFY):
    """Fetch ``<server>/history/<prompt_id>`` and return the decoded JSON.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    history_url = f"{server}/history/{prompt_id}"
    response = requests.get(history_url, timeout=60)
    response.raise_for_status()
    return response.json()

def download_image(filename, server=COMFY, folder_type="output", subfolder="", to_path=None):
    """Download one image from ``<server>/view`` and save it locally.

    Args:
        filename: Name of the image on the server.
        server: Base URL of the ComfyUI server.
        folder_type: Value of the ``type`` query parameter.
        subfolder: Value of the ``subfolder`` query parameter.
        to_path: Local destination; defaults to ``filename``.

    Returns:
        The local path the image was written to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    query = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    response = requests.get(f"{server}/view", params=query, timeout=60)
    response.raise_for_status()
    target = filename if to_path is None else to_path
    with open(target, "wb") as out:
        out.write(response.content)
    return target

def run_sam_on_frame(frame_rgb, comfy_server=COMFY):
    """Run the ComfyUI workflow on one frame and download its output image.

    The frame is saved to a uniquely named temporary PNG, uploaded, and the
    workflow from ``WORKFLOW_JSON`` is queued with its LoadImage node pointing
    at the upload. The history endpoint is then polled every 0.5 s for up to
    60 s.

    Args:
        frame_rgb: Image array accepted by ``PIL.Image.fromarray``.
        comfy_server: Base URL of the ComfyUI server.

    Returns:
        Local path of the downloaded image, or ``None`` if no output image
        appeared before the deadline.

    Raises:
        requests.HTTPError: If any request to the server fails.
    """
    tmp_path = f"temp_frame_{uuid.uuid4().hex}.png"
    try:
        Image.fromarray(frame_rgb).save(tmp_path)
        uploaded = upload_image_to_comfy(tmp_path, server=comfy_server)

        with open(WORKFLOW_JSON, "r") as f:
            prompt = json.load(f)
        patch_loadimage_node(prompt, uploaded)

        prompt_id = queue_prompt(prompt, server=comfy_server)
        # Monotonic clock: the timeout is unaffected by wall-clock adjustments.
        deadline = time.monotonic() + 60
        seg_path = None
        while time.monotonic() < deadline:
            hist = get_history(prompt_id, server=comfy_server)
            item = hist.get(prompt_id)
            if item and "outputs" in item:
                for node_out in item["outputs"].values():
                    # Only the first image of each output node is fetched; when
                    # several nodes emit images, the last node's image is the
                    # one returned.
                    for im in node_out.get("images", []):
                        seg_path = download_image(im["filename"], server=comfy_server,
                                                  subfolder=im.get("subfolder", ""),
                                                  folder_type=im.get("type", "output"))
                        break
            if seg_path:
                break
            time.sleep(0.5)
        return seg_path
    finally:
        # Always remove the temporary upload file, including when the upload,
        # queueing or polling raised; otherwise failed frames leave PNGs behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# ---- Main Pipeline ----
def process_video(input_path):
    """Run YOLO, DeOldify, SAM and fusion on each frame in a single pass.

    Four videos are written to ``OUTPUT_ROOT/<video name without extension>``:
    the YOLO visualisation, the fully colorized DeOldify frames, the SAM
    segmentation (or the original frame when SAM fails) and the fusion, in
    which DeOldify colour is applied only where the SAM mask overlaps a YOLO
    box of class 0 with confidence of at least ``CONF_THRESHOLD``.

    Args:
        input_path: Path of the source video.

    Returns:
        Dict mapping ``"yolo"``, ``"deoldify"``, ``"sam"`` and ``"final"`` to
        the path of the corresponding output video.
    """
    fname = os.path.basename(input_path)
    name, _ = os.path.splitext(fname)
    out_dir = os.path.join(OUTPUT_ROOT, name)
    os.makedirs(out_dir, exist_ok=True)

    # Output paths
    yolo_path = os.path.join(out_dir, f"{name}_yolo.mp4")
    deoldify_path = os.path.join(out_dir, f"{name}_deoldify.mp4")
    sam_path = os.path.join(out_dir, f"{name}_sam.mp4")
    fusion_path = os.path.join(out_dir, f"{name}_final.mp4")

    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer_yolo = cv2.VideoWriter(yolo_path, fourcc, fps, (width, height))
    writer_deoldify = cv2.VideoWriter(deoldify_path, fourcc, fps, (width, height))
    writer_sam = cv2.VideoWriter(sam_path, fourcc, fps, (width, height))
    writer_fusion = cv2.VideoWriter(fusion_path, fourcc, fps, (width, height))

    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        with tqdm(total=total_frames, desc="Processing", unit="frame") as pbar:
            while True:
                ret, frame_bgr = cap.read()
                if not ret:
                    break
                frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

                # --- YOLO detections ---
                results = yolo_model.predict(frame_bgr, conf=CONF_THRESHOLD, verbose=False, device=device_str)
                writer_yolo.write(results[0].plot())

                # --- DeOldify full-frame ---
                deold = deoldify_inference(frame_rgb)
                writer_deoldify.write(cv2.cvtColor(deold, cv2.COLOR_RGB2BGR))

                # --- SAM segmentation ---
                try:
                    seg_path = run_sam_on_frame(frame_rgb)
                    seg_img = cv2.imread(seg_path)
                    if seg_img is None:
                        seg_resized = frame_bgr
                    else:
                        seg_resized = cv2.resize(seg_img, (width, height))
                    writer_sam.write(seg_resized)
                except Exception as e:
                    # Best effort: fall back to the original frame so every
                    # output video keeps the same number of frames.
                    print(f"‚ö†Ô∏è SAM failed on frame: {e}")
                    seg_resized = frame_bgr
                    writer_sam.write(frame_bgr)

                # --- Fusion: apply DeOldify only on SAM ∩ YOLO ---
                fusion_frame = frame_rgb.copy()
                if results and results[0].boxes is not None:
                    boxes = results[0].boxes.xyxy.cpu().numpy()
                    confs = results[0].boxes.conf.cpu().numpy()
                    classes = results[0].boxes.cls.cpu().numpy()

                    # Any pixel of the SAM image brighter than 1 counts as mask.
                    gray = cv2.cvtColor(seg_resized, cv2.COLOR_BGR2GRAY)
                    _, sam_mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)

                    for box, conf, cls in zip(boxes, confs, classes):
                        if int(cls) != 0 or conf < CONF_THRESHOLD:
                            continue  # only class 0 with conf >= CONF_THRESHOLD

                        x1, y1, x2, y2 = map(int, box)
                        # A negative start index would wrap around in slicing.
                        x1, y1 = max(x1, 0), max(y1, 0)
                        yolo_mask = np.zeros_like(sam_mask, dtype=np.uint8)
                        yolo_mask[y1:y2, x1:x2] = 255

                        intersect = cv2.bitwise_and(sam_mask, yolo_mask)
                        mask_bool = intersect > 127

                        fusion_frame[mask_bool] = deold[mask_bool]

                writer_fusion.write(cv2.cvtColor(fusion_frame, cv2.COLOR_RGB2BGR))

                pbar.update(1)
    finally:
        # Close the decoder and all four writers even when a stage raises, so
        # the partially written videos are closed.
        cap.release()
        writer_yolo.release()
        writer_deoldify.release()
        writer_sam.release()
        writer_fusion.release()

    print(f"[INFO] Outputs written to {out_dir}")
    return {
        "yolo": yolo_path,
        "deoldify": deoldify_path,
        "sam": sam_path,
        "final": fusion_path
    }


# --- Example call ---
# Entry point: run the single-pass pipeline on one hard-coded sample video and
# print the path of every artifact returned by process_video.
if __name__ == "__main__":
    input_video = "input_videos/Dr - Trim.mp4"
    outputs = process_video(input_video)
    print("Pipeline outputs:")
    # outputs maps a stage name to the path of that stage's video.
    for k,v in outputs.items():
        print(f" - {k}: {v}")


[INFO] Using device: cpu


Processing: 100%|██████████| 3/3 [00:11<00:00,  3.69s/frame]

[INFO] Outputs written to outputs/Dr - Trim
Pipeline outputs:
 - yolo: outputs/Dr - Trim/Dr - Trim_yolo.mp4
 - deoldify: outputs/Dr - Trim/Dr - Trim_deoldify.mp4
 - sam: outputs/Dr - Trim/Dr - Trim_sam.mp4
 - final: outputs/Dr - Trim/Dr - Trim_final.mp4





In [4]:
import os
import cv2
import torch
import numpy as np
from tqdm import tqdm
from ultralytics import YOLO
from deoldify.visualize import get_image_colorizer
from deoldify import device
from deoldify.device_id import DeviceId
from PIL import Image
import uuid, json, requests, time



# Pipeline configuration constants; the values match the CONFIG sections of the
# other pipeline cells in this notebook.
YOLO_MODEL_PATH = "models/yolo11x-seg.pt"
CONF_THRESHOLD = 0.6
OUTPUT_ROOT = "outputs"
COMFY = "http://192.168.27.13:23476"    # ComfyUI server
WORKFLOW_JSON = "ClothesDetect_api.json"

In [6]:
import os
import cv2
import torch
import numpy as np
from tqdm import tqdm
from ultralytics import YOLO
from deoldify.visualize import get_image_colorizer
from deoldify import device
from deoldify.device_id import DeviceId
from PIL import Image
import uuid, json, requests, time

# ==== CONFIG ====
YOLO_MODEL_PATH = "models/yolo11x-seg.pt"  # weights file handed to YOLO() below
CONF_THRESHOLD = 0.6  # passed as conf= to yolo_model.predict and re-checked per box in run_fusion
OUTPUT_ROOT = "outputs"  # process_video creates one sub-directory per input video here
COMFY = "http://192.168.27.13:23476"    # ComfyUI server
WORKFLOW_JSON = "ClothesDetect_api.json"  # workflow file loaded by run_sam_on_frame
# =================

# ---- Setup DeOldify ----
# The device must be selected before the colorizer is built.
# NOTE(review): GPU0 is requested unconditionally, while the recorded output of
# this cell reports YOLO running on "cpu" — confirm how DeOldify behaves
# without CUDA.
device.set(device=DeviceId.GPU0)
colorizer = get_image_colorizer(artistic=True)

# ---- Setup YOLO ----
# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device_str = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {device_str}")
yolo_model = YOLO(YOLO_MODEL_PATH).to(device_str)


# ---- DeOldify inference ----
def deoldify_inference(frame_rgb):
    """Colorize one RGB frame with the module-level ``colorizer``.

    Args:
        frame_rgb: Image array accepted by ``PIL.Image.fromarray``.

    Returns:
        The colorizer's output converted with ``np.array``.
    """
    source = Image.fromarray(frame_rgb).convert("RGB")
    colorized = colorizer.get_transformed_image(
        source,
        render_factor=16,
        post_process=True,
    )
    return np.array(colorized)


# ---- ComfyUI helpers ----
def upload_image_to_comfy(local_path, server=COMFY, *, dest_name=None, folder_type="input"):
    """POST a local image to ``<server>/upload/image``.

    Args:
        local_path: File to upload.
        server: Base URL of the ComfyUI server.
        dest_name: Name to send the file under; defaults to the basename of
            ``local_path``.
        folder_type: Value of the ``type`` form field.

    Returns:
        The name the file was uploaded under.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    remote_name = os.path.basename(local_path) if dest_name is None else dest_name
    form = {"type": folder_type, "overwrite": "true"}
    with open(local_path, "rb") as handle:
        response = requests.post(
            f"{server}/upload/image",
            files={"image": (remote_name, handle, "image/png")},
            data=form,
            timeout=60,
        )
        response.raise_for_status()
    return remote_name

def patch_loadimage_node(prompt_dict, new_filename):
    """Point the first LoadImage node of a workflow at ``new_filename``.

    Nodes are matched on ``class_type`` case-insensitively; only the first
    match (in dict order) is modified, in place.

    Returns:
        True if a node was patched, False if the workflow has no such node.
    """
    load_node = next(
        (
            candidate
            for candidate in prompt_dict.values()
            if candidate.get("class_type", "").lower() == "loadimage"
        ),
        None,
    )
    if load_node is None:
        return False
    load_node["inputs"]["image"] = new_filename
    return True

def queue_prompt(prompt_dict, server=COMFY):
    """Submit a workflow to ``<server>/prompt`` and return its prompt id.

    A fresh UUID is sent as ``client_id``; it is also the fallback return
    value when the response carries no ``prompt_id``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    client_id = str(uuid.uuid4())
    payload = {"prompt": prompt_dict, "client_id": client_id}
    response = requests.post(f"{server}/prompt", json=payload, timeout=120)
    response.raise_for_status()
    reply = response.json()
    return reply.get("prompt_id", client_id)

def get_history(prompt_id, server=COMFY):
    """Fetch ``<server>/history/<prompt_id>`` and return the decoded JSON.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    history_url = f"{server}/history/{prompt_id}"
    response = requests.get(history_url, timeout=60)
    response.raise_for_status()
    return response.json()

def download_image(filename, server=COMFY, folder_type="output", subfolder="", to_path=None):
    """Download one image from ``<server>/view`` and save it locally.

    Args:
        filename: Name of the image on the server.
        server: Base URL of the ComfyUI server.
        folder_type: Value of the ``type`` query parameter.
        subfolder: Value of the ``subfolder`` query parameter.
        to_path: Local destination; defaults to ``filename``.

    Returns:
        The local path the image was written to.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    query = {"filename": filename, "subfolder": subfolder, "type": folder_type}
    response = requests.get(f"{server}/view", params=query, timeout=60)
    response.raise_for_status()
    target = filename if to_path is None else to_path
    with open(target, "wb") as out:
        out.write(response.content)
    return target

def run_sam_on_frame(frame_rgb, comfy_server=COMFY):
    """Run the ComfyUI workflow on one frame and download its output image.

    The frame is saved to a uniquely named temporary PNG, uploaded, and the
    workflow from ``WORKFLOW_JSON`` is queued with its LoadImage node pointing
    at the upload. The history endpoint is then polled every 0.5 s for up to
    60 s.

    Args:
        frame_rgb: Image array accepted by ``PIL.Image.fromarray``.
        comfy_server: Base URL of the ComfyUI server.

    Returns:
        Local path of the downloaded image, or ``None`` if no output image
        appeared before the deadline.

    Raises:
        requests.HTTPError: If any request to the server fails.
    """
    tmp_path = f"temp_frame_{uuid.uuid4().hex}.png"
    try:
        Image.fromarray(frame_rgb).save(tmp_path)
        uploaded = upload_image_to_comfy(tmp_path, server=comfy_server)

        with open(WORKFLOW_JSON, "r") as f:
            prompt = json.load(f)
        patch_loadimage_node(prompt, uploaded)

        prompt_id = queue_prompt(prompt, server=comfy_server)
        # Monotonic clock: the timeout is unaffected by wall-clock adjustments.
        deadline = time.monotonic() + 60
        seg_path = None
        while time.monotonic() < deadline:
            hist = get_history(prompt_id, server=comfy_server)
            item = hist.get(prompt_id)
            if item and "outputs" in item:
                for node_out in item["outputs"].values():
                    # Only the first image of each output node is fetched; when
                    # several nodes emit images, the last node's image is the
                    # one returned.
                    for im in node_out.get("images", []):
                        seg_path = download_image(im["filename"], server=comfy_server,
                                                  subfolder=im.get("subfolder", ""),
                                                  folder_type=im.get("type", "output"))
                        break
            if seg_path:
                break
            time.sleep(0.5)
        return seg_path
    finally:
        # Always remove the temporary upload file, including when the upload,
        # queueing or polling raised; otherwise failed frames leave PNGs behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# ---- Stage 1: YOLO ----
def run_yolo(input_path, out_dir, name):
    """Run YOLO on every frame and write the annotated video.

    Args:
        input_path: Path of the source video.
        out_dir: Directory that receives ``<name>_yolo.mp4``.
        name: Base name used for the output file.

    Returns:
        Tuple ``(yolo_path, results_per_frame)`` where the list holds the
        first prediction result of each frame, in frame order.
    """
    yolo_path = os.path.join(out_dir, f"{name}_yolo.mp4")
    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    writer = cv2.VideoWriter(yolo_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)

    detections = []
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=frame_count, desc="YOLO", unit="frame") as pbar:
        ok, frame = cap.read()
        while ok:
            prediction = yolo_model.predict(
                frame, conf=CONF_THRESHOLD, verbose=False, device=device_str
            )
            # plot() renders the detections onto a copy of the frame.
            writer.write(prediction[0].plot())
            detections.append(prediction[0])
            pbar.update(1)
            ok, frame = cap.read()
    cap.release()
    writer.release()
    return yolo_path, detections


# ---- Stage 2: DeOldify ----
def run_deoldify(input_path, out_dir, name):
    """Colorize every frame with DeOldify and write the result as a video.

    Args:
        input_path: Path of the source video.
        out_dir: Directory that receives ``<name>_deoldify.mp4``.
        name: Base name used for the output file.

    Returns:
        Path of the DeOldify video.
    """
    deoldify_path = os.path.join(out_dir, f"{name}_deoldify.mp4")
    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    writer = cv2.VideoWriter(deoldify_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=frame_count, desc="DeOldify", unit="frame") as pbar:
        ok, frame_bgr = cap.read()
        while ok:
            # DeOldify works on RGB; OpenCV reads and writes BGR.
            colorized_rgb = deoldify_inference(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
            writer.write(cv2.cvtColor(colorized_rgb, cv2.COLOR_RGB2BGR))
            pbar.update(1)
            ok, frame_bgr = cap.read()
    cap.release()
    writer.release()
    return deoldify_path


# ---- Stage 3: SAM ----
def run_sam(input_path, out_dir, name):
    """Segment every frame via ``run_sam_on_frame`` and write the SAM video.

    The segmentation image is resized to the source resolution. If it cannot
    be read, or if segmentation raises, the original frame is written instead.

    Args:
        input_path: Path of the source video.
        out_dir: Directory that receives ``<name>_sam.mp4``.
        name: Base name used for the output file.

    Returns:
        Path of the SAM video.
    """
    sam_path = os.path.join(out_dir, f"{name}_sam.mp4")
    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    writer = cv2.VideoWriter(sam_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=frame_count, desc="SAM", unit="frame") as pbar:
        ok, frame_bgr = cap.read()
        while ok:
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            try:
                seg_img = cv2.imread(run_sam_on_frame(frame_rgb))
                if seg_img is None:
                    out_frame = frame_bgr
                else:
                    out_frame = cv2.resize(seg_img, (width, height))
                writer.write(out_frame)
            except Exception as e:
                # Keep one output frame per input frame even when SAM fails.
                print(f"‚ö†Ô∏è SAM failed: {e}")
                writer.write(frame_bgr)
            pbar.update(1)
            ok, frame_bgr = cap.read()
    cap.release()
    writer.release()
    return sam_path


# ---- Stage 4: Fusion ----
def run_fusion(input_path, out_dir, name, yolo_results):
    """Fuse DeOldify colour into the original video inside SAM ∩ YOLO regions.

    For each frame this re-runs DeOldify and the SAM workflow (the videos
    written by the earlier stages are not read back), builds a binary SAM mask
    and copies DeOldify pixels into the original frame wherever that mask
    overlaps a YOLO box of class 0 with confidence of at least
    ``CONF_THRESHOLD``.

    Args:
        input_path: Path of the source video.
        out_dir: Directory that receives ``<name>_final.mp4``.
        name: Base name used for the output file.
        yolo_results: Per-frame YOLO result objects exposing ``.boxes`` with
            ``xyxy``, ``conf`` and ``cls`` tensors; frames beyond its length
            are written unchanged.

    Returns:
        Path of the fused video.
    """
    fusion_path = os.path.join(out_dir, f"{name}_final.mp4")
    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(fusion_path, fourcc, fps, (width, height))

    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_idx = 0
        with tqdm(total=total_frames, desc="Fusion", unit="frame") as pbar:
            while True:
                ret, frame_bgr = cap.read()
                if not ret:
                    break
                frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

                # Recompute DeOldify for this frame.
                deold = deoldify_inference(frame_rgb)

                # Recompute SAM for this frame; an empty mask means no pixels
                # are colorized.
                sam_mask = np.zeros((height, width), dtype=np.uint8)
                try:
                    seg_path = run_sam_on_frame(frame_rgb)
                    seg_img = cv2.imread(seg_path)
                    if seg_img is not None:
                        # Boolean indexing below requires the mask to match
                        # the frame size, so resize as run_sam does.
                        seg_img = cv2.resize(seg_img, (width, height))
                        gray = cv2.cvtColor(seg_img, cv2.COLOR_BGR2GRAY)
                        _, sam_mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
                except Exception as e:
                    # Exception (not a bare except) so Ctrl-C still interrupts.
                    print(f"[WARN] SAM failed on frame {frame_idx}: {e}")

                # YOLO boxes
                fusion_frame = frame_rgb.copy()
                if frame_idx < len(yolo_results):
                    boxes = yolo_results[frame_idx].boxes.xyxy.cpu().numpy()
                    confs = yolo_results[frame_idx].boxes.conf.cpu().numpy()
                    classes = yolo_results[frame_idx].boxes.cls.cpu().numpy()
                    for box, conf, cls in zip(boxes, confs, classes):
                        if int(cls) != 0 or conf < CONF_THRESHOLD:
                            continue
                        x1, y1, x2, y2 = map(int, box)
                        # A negative start index would wrap around in slicing.
                        x1, y1 = max(x1, 0), max(y1, 0)
                        yolo_mask = np.zeros_like(sam_mask, dtype=np.uint8)
                        yolo_mask[y1:y2, x1:x2] = 255
                        intersect = cv2.bitwise_and(sam_mask, yolo_mask)
                        mask_bool = intersect > 127
                        fusion_frame[mask_bool] = deold[mask_bool]

                writer.write(cv2.cvtColor(fusion_frame, cv2.COLOR_RGB2BGR))
                frame_idx += 1
                pbar.update(1)
    finally:
        # Close the decoder and the writer even when a frame fails.
        cap.release()
        writer.release()
    return fusion_path


# ---- Main Pipeline ----
def process_video(input_path):
    """Run the YOLO, DeOldify, SAM and fusion stages on one video.

    All artifacts go to ``OUTPUT_ROOT/<video name without extension>``.

    Returns:
        Dict mapping ``"yolo"``, ``"deoldify"``, ``"sam"`` and ``"final"`` to
        the path of the corresponding output video.
    """
    name = os.path.splitext(os.path.basename(input_path))[0]
    out_dir = os.path.join(OUTPUT_ROOT, name)
    os.makedirs(out_dir, exist_ok=True)

    artifacts = {}
    # Stages run in a fixed order; fusion reuses the YOLO detections.
    artifacts["yolo"], yolo_results = run_yolo(input_path, out_dir, name)
    artifacts["deoldify"] = run_deoldify(input_path, out_dir, name)
    artifacts["sam"] = run_sam(input_path, out_dir, name)
    artifacts["final"] = run_fusion(input_path, out_dir, name, yolo_results)

    print(f"[INFO] Outputs written to {out_dir}")
    return artifacts


# # ---- Example ----
# if __name__ == "__main__":
#     input_video = "input_videos/THATHA MANAVADU colored Trim.mp4"
#     outputs = process_video(input_video)
#     print("Pipeline outputs:")
#     for k,v in outputs.items():
#         print(f" - {k}: {v}")


  warn("""Your validation set is empty. If this is by design, use `split_none()`
  WeightNorm.apply(module, name, dim)
  state = torch.load(tmp_file)
  state = torch.load(source, map_location=device)


[INFO] Using device: cpu


In [2]:
import torch
# Diagnostic: report whether torch can use CUDA, how many devices it counts,
# and (only when CUDA is usable) the name of device 0.
# NOTE(review): the recorded output below shows "available: False" together
# with "device count: 1" — worth checking the driver / torch build.
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Using:", torch.cuda.get_device_name(0))


CUDA available: False
CUDA device count: 1


In [11]:
import os
import cv2
import torch
import numpy as np
from tqdm import tqdm
from ultralytics import YOLO
from deoldify.visualize import get_image_colorizer
from deoldify import device
from deoldify.device_id import DeviceId
from PIL import Image

# ==== CONFIG ====
YOLO_MODEL_PATH = "models/yolo11x-seg.pt"  # weights file handed to YOLO() below
CONF_THRESHOLD = 0.6  # passed as conf= to yolo_model.predict in run_yolo
OUTPUT_ROOT = "outputs"
# =================

# ---- Setup DeOldify ----
# The device must be selected before the colorizer is built.
# NOTE(review): GPU0 is requested unconditionally — confirm how DeOldify
# behaves on a machine where torch reports no CUDA.
device.set(device=DeviceId.GPU0)
colorizer = get_image_colorizer(artistic=True)

# ---- Setup YOLO ----
# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device_str = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {device_str}")
yolo_model = YOLO(YOLO_MODEL_PATH).to(device_str)


# ---- DeOldify inference ----
def deoldify_inference(frame_rgb):
    """Colorize one RGB frame with the module-level ``colorizer``.

    Args:
        frame_rgb: Image array accepted by ``PIL.Image.fromarray``.

    Returns:
        The colorizer's output converted with ``np.array``.
    """
    source = Image.fromarray(frame_rgb).convert("RGB")
    colorized = colorizer.get_transformed_image(
        source,
        render_factor=16,
        post_process=True,
    )
    return np.array(colorized)


# ---- Stage 1: YOLO ----
def run_yolo(input_path, out_dir, name):
    """Write the YOLO-annotated version of ``input_path``, with file caching.

    Args:
        input_path: Path of the source video.
        out_dir: Directory that receives ``<name>_yolo.mp4``.
        name: Base name used for the output file.

    Returns:
        Path of the YOLO video. If that file already exists it is returned
        immediately without any processing.
    """
    yolo_path = os.path.join(out_dir, f"{name}_yolo.mp4")
    if os.path.exists(yolo_path):
        print(f"[CACHE] Using cached YOLO: {yolo_path}")
        return yolo_path

    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    writer = cv2.VideoWriter(yolo_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=frame_count, desc="YOLO", unit="frame") as pbar:
        ok, frame = cap.read()
        while ok:
            prediction = yolo_model.predict(
                frame, conf=CONF_THRESHOLD, verbose=False, device=device_str
            )
            # plot() renders the detections onto a copy of the frame.
            writer.write(prediction[0].plot())
            pbar.update(1)
            ok, frame = cap.read()
    cap.release()
    writer.release()
    return yolo_path


# ---- Stage 2: DeOldify ----
def run_deoldify(input_path, out_dir, name):
    """Write the DeOldify-colorized version of ``input_path``, with file caching.

    Args:
        input_path: Path of the source video.
        out_dir: Directory that receives ``<name>_deoldify.mp4``.
        name: Base name used for the output file.

    Returns:
        Path of the DeOldify video. If that file already exists it is
        returned immediately without any processing.
    """
    deoldify_path = os.path.join(out_dir, f"{name}_deoldify.mp4")
    if os.path.exists(deoldify_path):
        print(f"[CACHE] Using cached DeOldify: {deoldify_path}")
        return deoldify_path

    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    writer = cv2.VideoWriter(deoldify_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=frame_count, desc="DeOldify", unit="frame") as pbar:
        ok, frame_bgr = cap.read()
        while ok:
            # DeOldify works on RGB; OpenCV reads and writes BGR.
            colorized_rgb = deoldify_inference(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
            writer.write(cv2.cvtColor(colorized_rgb, cv2.COLOR_RGB2BGR))
            pbar.update(1)
            ok, frame_bgr = cap.read()
    cap.release()
    writer.release()
    return deoldify_path


# ---- Stage 3: SAM ----
def run_sam(input_path, out_dir, name):
    """Write the SAM segmentation video for ``input_path``, with file caching.

    Each frame goes through ``run_sam_on_frame``; the returned image is
    resized to the source resolution. If it cannot be read, or if segmentation
    raises, the original frame is written instead.

    Args:
        input_path: Path of the source video.
        out_dir: Directory that receives ``<name>_sam.mp4``.
        name: Base name used for the output file.

    Returns:
        Path of the SAM video. If that file already exists it is returned
        immediately without any processing.
    """
    sam_path = os.path.join(out_dir, f"{name}_sam.mp4")
    if os.path.exists(sam_path):
        print(f"[CACHE] Using cached SAM: {sam_path}")
        return sam_path

    cap = cv2.VideoCapture(input_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    writer = cv2.VideoWriter(sam_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=frame_count, desc="SAM", unit="frame") as pbar:
        ok, frame_bgr = cap.read()
        while ok:
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            try:
                seg_img = cv2.imread(run_sam_on_frame(frame_rgb))
                if seg_img is None:
                    out_frame = frame_bgr
                else:
                    out_frame = cv2.resize(seg_img, (width, height))
                writer.write(out_frame)
            except Exception as e:
                # Keep one output frame per input frame even when SAM fails.
                print(f"‚ö†Ô∏è SAM failed: {e}")
                writer.write(frame_bgr)
            pbar.update(1)
            ok, frame_bgr = cap.read()

    cap.release()
    writer.release()
    return sam_path




# ---- Stage 4: Fusion ----
def run_fusion(yolo_path, deoldify_path, sam_path, out_dir, name):
    """Overlay DeOldify colour onto the YOLO video wherever the SAM video is non-black.

    NOTE: a later ``run_fusion`` definition in this same cell uses a different
    signature and replaces this one once the cell has run.

    Args:
        yolo_path: Path of the YOLO-annotated video (base frames, fps, size).
        deoldify_path: Path of the DeOldify video.
        sam_path: Path of the SAM video used as mask source.
        out_dir: Directory that receives ``<name>_final.mp4``.
        name: Base name used for the output file.

    Returns:
        Path of the fused video. If that file already exists it is returned
        immediately without any processing.
    """
    fusion_path = os.path.join(out_dir, f"{name}_final.mp4")
    if os.path.exists(fusion_path):
        print(f"[CACHE] Using cached Fusion: {fusion_path}")
        return fusion_path

    cap_yolo = cv2.VideoCapture(yolo_path)
    cap_deold = cv2.VideoCapture(deoldify_path)
    cap_sam = cv2.VideoCapture(sam_path)

    fps = int(cap_yolo.get(cv2.CAP_PROP_FPS)) or 25
    frame_size = (
        int(cap_yolo.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap_yolo.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    writer = cv2.VideoWriter(fusion_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)

    # Only fuse as many frames as all three inputs can supply.
    frame_count = int(min(
        cap_yolo.get(cv2.CAP_PROP_FRAME_COUNT),
        cap_deold.get(cv2.CAP_PROP_FRAME_COUNT),
        cap_sam.get(cv2.CAP_PROP_FRAME_COUNT),
    ))

    with tqdm(total=frame_count, desc="Fusion", unit="frame") as pbar:
        done = 0
        while done < frame_count:
            ok_yolo, frame_yolo = cap_yolo.read()
            ok_deold, frame_deold = cap_deold.read()
            ok_sam, frame_sam = cap_sam.read()
            if not ok_yolo or not ok_deold or not ok_sam:
                break

            # Any SAM pixel brighter than 1 counts as part of the mask.
            gray = cv2.cvtColor(frame_sam, cv2.COLOR_BGR2GRAY)
            _, sam_mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
            selected = sam_mask > 127

            # Start from the YOLO frame and replace the masked pixels.
            fused = frame_yolo.copy()
            fused[selected] = frame_deold[selected]

            writer.write(fused)
            pbar.update(1)
            done += 1

    for source in (cap_yolo, cap_deold, cap_sam):
        source.release()
    writer.release()
    return fusion_path


import cv2
import numpy as np
from tqdm import tqdm

def run_fusion(input_path, yolo_results, deoldify_path, sam_path, out_dir, name, conf_thresh=0.6):
    """
    Fusion: keep the original frame and paste DeOldify color only where the
    SAM mask intersects a YOLO "person" box with conf >= conf_thresh.

    The base frame is read from ``input_path``. Previously the base was a copy
    of the DeOldify frame itself, which made the masked copy a no-op and the
    output identical to the DeOldify video.

    Args:
        input_path (str): original input video path; supplies the base frames
        yolo_results (list): one YOLO result per frame; each must expose
            ``.boxes.xyxy``, ``.boxes.conf`` and ``.boxes.cls`` tensors
        deoldify_path (str): path to DeOldify output video (also supplies fps/size)
        sam_path (str): path to SAM output video; gray values above 1 count as mask
        out_dir (str): output directory
        name (str): base name
        conf_thresh (float): YOLO confidence threshold

    Returns:
        str: path to fusion video (an existing file is returned untouched)
    """
    fusion_path = os.path.join(out_dir, f"{name}_final.mp4")
    if os.path.exists(fusion_path):
        print(f"[CACHE] Using cached Fusion: {fusion_path}")
        return fusion_path

    # Open video sources
    cap_input = cv2.VideoCapture(input_path)
    cap_deold = cv2.VideoCapture(deoldify_path)
    cap_sam = cv2.VideoCapture(sam_path)

    fps = int(cap_deold.get(cv2.CAP_PROP_FPS)) or 25
    width = int(cap_deold.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap_deold.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(fusion_path, fourcc, fps, (width, height))

    try:
        # Never read past the shortest source.
        total_frames = int(min(
            len(yolo_results),
            cap_input.get(cv2.CAP_PROP_FRAME_COUNT),
            cap_deold.get(cv2.CAP_PROP_FRAME_COUNT),
            cap_sam.get(cv2.CAP_PROP_FRAME_COUNT)
        ))

        with tqdm(total=total_frames, desc="Fusion", unit="frame") as pbar:
            for frame_idx in range(total_frames):
                ret_in, frame_in = cap_input.read()
                ret_deold, frame_deold = cap_deold.read()
                ret_sam, frame_sam = cap_sam.read()
                if not (ret_in and ret_deold and ret_sam):
                    break

                # --- Build SAM mask ---
                gray = cv2.cvtColor(frame_sam, cv2.COLOR_BGR2GRAY)
                _, sam_mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)

                # --- Build YOLO mask ---
                yolo_mask = np.zeros_like(sam_mask, dtype=np.uint8)
                frame_boxes = yolo_results[frame_idx].boxes
                boxes = frame_boxes.xyxy.cpu().numpy()
                confs = frame_boxes.conf.cpu().numpy()
                classes = frame_boxes.cls.cpu().numpy()

                for box, conf, cls in zip(boxes, confs, classes):
                    if int(cls) != 0 or conf < conf_thresh:  # only "person"
                        continue
                    # Clamp at 0: a negative slice bound would wrap around
                    # and select the wrong region.
                    x1, y1, x2, y2 = (max(0, int(v)) for v in box)
                    yolo_mask[y1:y2, x1:x2] = 255

                # --- Intersection mask (SAM and YOLO) ---
                intersect = cv2.bitwise_and(sam_mask, yolo_mask)
                mask_bool = intersect > 127

                # --- Original frame as base, DeOldify only inside the mask ---
                fusion_frame = frame_in.copy()
                fusion_frame[mask_bool] = frame_deold[mask_bool]

                writer.write(fusion_frame)
                pbar.update(1)
    finally:
        # Release handles even when a frame fails to process.
        cap_input.release()
        cap_deold.release()
        cap_sam.release()
        writer.release()

    print(f"[INFO] Fusion video saved: {fusion_path}")
    return fusion_path


# ---- Main Pipeline ----
def process_video(input_path):
    """Run the YOLO, DeOldify, SAM and Fusion stages for one video.

    All outputs go to ``OUTPUT_ROOT/<input file name without extension>``,
    which is created when missing.

    Returns:
        dict with the keys "yolo", "deoldify", "sam" and "final" holding what
        each stage returned.
    """
    stem = os.path.splitext(os.path.basename(input_path))[0]
    out_dir = os.path.join(OUTPUT_ROOT, stem)
    os.makedirs(out_dir, exist_ok=True)

    # NOTE(review): the value returned by run_yolo is forwarded as run_fusion's
    # `yolo_results` argument and also reported as the "yolo" output - confirm
    # run_yolo returns what run_fusion expects.
    yolo_out = run_yolo(input_path, out_dir, stem)
    deoldify_out = run_deoldify(input_path, out_dir, stem)
    sam_out = run_sam(input_path, out_dir, stem)
    final_out = run_fusion(input_path, yolo_out, deoldify_out, sam_out, out_dir, stem)

    print(f"[INFO] Outputs written to {out_dir}")
    return {"yolo": yolo_out, "deoldify": deoldify_out, "sam": sam_out, "final": final_out}


# ---- Example ----
if __name__ == "__main__":
    # Run the whole pipeline on one sample clip and list what it produced.
    input_video = "input_videos/THATHA MANAVADU colored Trim.mp4"
    outputs = process_video(input_video)
    print("Pipeline outputs:")
    # process_video returns a dict of stage label -> stage output.
    for k, v in outputs.items():
        print(f" - {k}: {v}")


[INFO] Using device: cpu
[CACHE] Using cached YOLO: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_yolo.mp4
[CACHE] Using cached DeOldify: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_deoldify.mp4
[CACHE] Using cached SAM: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_sam.mp4
[CACHE] Using cached Fusion: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_final.mp4
[INFO] Outputs written to outputs/THATHA MANAVADU colored Trim
Pipeline outputs:
 - yolo: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_yolo.mp4
 - deoldify: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_deoldify.mp4
 - sam: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_sam.mp4
 - final: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_final.mp4


In [4]:
import os
import cv2
import torch
import numpy as np
import pickle
from tqdm import tqdm
from ultralytics import YOLO
from deoldify.visualize import get_image_colorizer
from deoldify import device
from deoldify.device_id import DeviceId
from PIL import Image
import torch
# Report whether PyTorch sees a CUDA device, how many, and the name of device 0.
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Using:", torch.cuda.get_device_name(0))


# ==== CONFIG ====
YOLO_MODEL_PATH = "models/yolo11x-seg.pt"  # weights file handed to YOLO() below
CONF_THRESHOLD = 0.6  # confidence cut-off used by run_yolo (predict) and run_fusion
OUTPUT_ROOT = "outputs"  # process_video creates one sub-directory per video here
# =================

# ---- Setup DeOldify ----
# The device is selected before the colorizer is built; keep this order.
device.set(device=DeviceId.GPU0)
colorizer = get_image_colorizer(artistic=True)


# yolo_model = YOLO(YOLO_MODEL_PATH).to("cuda")
# print(f"YOLO running on: {yolo_model.device}")
# YOLO runs on CUDA when PyTorch reports it as available, otherwise on the CPU.
device_str = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {device_str}")
yolo_model = YOLO(YOLO_MODEL_PATH).to(device_str)

# ---- DeOldify inference ----
def deoldify_inference(frame_rgb):
    """Colorize one frame with the module-level DeOldify ``colorizer``.

    Args:
        frame_rgb: image array accepted by ``PIL.Image.fromarray`` (callers pass RGB frames).

    Returns:
        ``np.array`` of the image returned by the colorizer.
    """
    source_image = Image.fromarray(frame_rgb)
    source_image = source_image.convert("RGB")
    colorized = colorizer.get_transformed_image(
        source_image,
        render_factor=16,
        post_process=True,
    )
    return np.array(colorized)


# ---- Stage 1: YOLO ----
def run_yolo(input_path, out_dir, name):
    """Stage 1: run YOLO on every frame, write the annotated video and cache the detections.

    Args:
        input_path: video to analyse.
        out_dir: directory receiving ``{name}_yolo.mp4`` and ``{name}_yolo_results.pkl``.
        name: base name used for both output files.

    Returns:
        ``(video_path, detections)`` where ``detections`` holds one dict per frame
        with the keys "boxes", "conf" and "cls" (NumPy arrays). When both output
        files already exist they are reused instead of recomputed.
    """
    video_out = os.path.join(out_dir, f"{name}_yolo.mp4")
    cache_file = os.path.join(out_dir, f"{name}_yolo_results.pkl")

    cached = os.path.exists(video_out) and os.path.exists(cache_file)
    if cached:
        print(f"[CACHE] Using cached YOLO + results: {video_out}")
        # The pickle is produced by this function; never point it at untrusted files.
        with open(cache_file, "rb") as handle:
            return video_out, pickle.load(handle)

    reader = cv2.VideoCapture(input_path)
    frame_rate = int(reader.get(cv2.CAP_PROP_FPS)) or 25  # 25 when fps reads as 0
    frame_size = (
        int(reader.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    sink = cv2.VideoWriter(video_out, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, frame_size)

    detections = []
    expected = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=expected, desc="YOLO", unit="frame") as progress:
        ok, frame = reader.read()
        while ok:
            first = yolo_model.predict(frame, conf=CONF_THRESHOLD, verbose=False, device=device_str)[0]
            sink.write(first.plot())
            # Store only necessary info to reduce memory
            found = first.boxes
            detections.append({
                "boxes": found.xyxy.cpu().numpy(),
                "conf": found.conf.cpu().numpy(),
                "cls": found.cls.cpu().numpy(),
            })
            progress.update(1)
            ok, frame = reader.read()

    reader.release()
    sink.release()

    with open(cache_file, "wb") as handle:
        pickle.dump(detections, handle)

    return video_out, detections


# ---- Stage 2: DeOldify ----
def run_deoldify(input_path, out_dir, name):
    """Stage 2: colorize every frame with ``deoldify_inference`` and write the result.

    Args:
        input_path: video to colorize.
        out_dir: directory receiving ``{name}_deoldify.mp4``.
        name: base name used for the output file.

    Returns:
        Path of the colorized video. A file that already exists is returned untouched.
    """
    target = os.path.join(out_dir, f"{name}_deoldify.mp4")
    if os.path.exists(target):
        print(f"[CACHE] Using cached DeOldify: {target}")
        return target

    reader = cv2.VideoCapture(input_path)
    frame_rate = int(reader.get(cv2.CAP_PROP_FPS)) or 25  # 25 when fps reads as 0
    frame_size = (
        int(reader.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    sink = cv2.VideoWriter(target, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, frame_size)

    expected = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=expected, desc="DeOldify", unit="frame") as progress:
        ok, bgr = reader.read()
        while ok:
            # OpenCV frames are BGR; the colorizer is fed RGB and its output is converted back.
            colorized_rgb = deoldify_inference(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
            sink.write(cv2.cvtColor(colorized_rgb, cv2.COLOR_RGB2BGR))
            progress.update(1)
            ok, bgr = reader.read()
    reader.release()
    sink.release()
    return target


# ---- Stage 3: SAM ----
def run_sam(input_path, out_dir, name):
    """Stage 3 placeholder: copy every input frame unchanged into ``{name}_sam.mp4``.

    No segmentation happens here yet; the original frame is written as the
    fallback "mask" frame.

    Args:
        input_path: source video.
        out_dir: directory receiving the output file.
        name: base name used for the output file.

    Returns:
        Path of the written video. A file that already exists is returned untouched.
    """
    target = os.path.join(out_dir, f"{name}_sam.mp4")
    if os.path.exists(target):
        print(f"[CACHE] Using cached SAM: {target}")
        return target

    reader = cv2.VideoCapture(input_path)
    frame_rate = int(reader.get(cv2.CAP_PROP_FPS)) or 25  # 25 when fps reads as 0
    frame_size = (
        int(reader.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    sink = cv2.VideoWriter(target, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, frame_size)

    expected = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=expected, desc="SAM", unit="frame") as progress:
        ok, bgr = reader.read()
        while ok:
            # Here, you should integrate your SAM inference instead of dummy mask
            _rgb_for_sam = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # prepared but not used yet
            sink.write(bgr)  # fallback: the untouched input frame
            progress.update(1)
            ok, bgr = reader.read()

    reader.release()
    sink.release()
    return target


# ---- Stage 4: Fusion ----
def run_fusion(yolo_results, deoldify_path, sam_path, out_dir, name, input_path=None):
    """Stage 4: paste DeOldify pixels where the SAM mask and a YOLO "person" box overlap.

    Args:
        yolo_results: one dict per frame with the keys "boxes", "conf" and "cls".
        deoldify_path: path of the DeOldify video; supplies fps, size and pasted pixels.
        sam_path: path of the SAM video; gray values above 1 count as mask.
        out_dir: directory receiving ``{name}_final.mp4``.
        name: base name used for the output file.
        input_path: optional path of the original video. When given, its frames
            are the base that DeOldify pixels are pasted onto. When omitted the
            base is the DeOldify frame itself, so the masks cannot change the
            result and the output equals the DeOldify video (the historical
            behavior of this function).

    Returns:
        Path of the fused video. A file that already exists is returned untouched.
    """
    fusion_path = os.path.join(out_dir, f"{name}_final.mp4")
    if os.path.exists(fusion_path):
        print(f"[CACHE] Using cached Fusion: {fusion_path}")
        return fusion_path

    cap_deold = cv2.VideoCapture(deoldify_path)
    cap_sam = cv2.VideoCapture(sam_path)
    cap_input = cv2.VideoCapture(input_path) if input_path is not None else None

    fps = int(cap_deold.get(cv2.CAP_PROP_FPS)) or 25
    width = int(cap_deold.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap_deold.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(fusion_path, fourcc, fps, (width, height))

    try:
        # Never read past the shortest source.
        frame_counts = [
            len(yolo_results),
            cap_deold.get(cv2.CAP_PROP_FRAME_COUNT),
            cap_sam.get(cv2.CAP_PROP_FRAME_COUNT),
        ]
        if cap_input is not None:
            frame_counts.append(cap_input.get(cv2.CAP_PROP_FRAME_COUNT))
        total_frames = int(min(frame_counts))

        with tqdm(total=total_frames, desc="Fusion", unit="frame") as pbar:
            for frame_idx in range(total_frames):
                ret_deold, frame_deold = cap_deold.read()
                ret_sam, frame_sam = cap_sam.read()
                if not (ret_deold and ret_sam):
                    break

                if cap_input is not None:
                    ret_in, frame_base = cap_input.read()
                    if not ret_in:
                        break
                else:
                    # Without the original video the base is the DeOldify
                    # frame, which makes the masked copy below a no-op.
                    frame_base = frame_deold

                # SAM mask
                gray = cv2.cvtColor(frame_sam, cv2.COLOR_BGR2GRAY)
                _, sam_mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)

                # YOLO mask
                yolo_mask = np.zeros_like(sam_mask, dtype=np.uint8)
                detections = yolo_results[frame_idx]
                for box, conf, cls in zip(
                    detections["boxes"], detections["conf"], detections["cls"]
                ):
                    if int(cls) != 0 or conf < CONF_THRESHOLD:  # only "person"
                        continue
                    x1, y1, x2, y2 = map(int, box)
                    yolo_mask[y1:y2, x1:x2] = 255

                # Intersection
                intersect = cv2.bitwise_and(sam_mask, yolo_mask)
                mask_bool = intersect > 127

                # Fusion: apply DeOldify where intersection
                fusion_frame = frame_base.copy()
                fusion_frame[mask_bool] = frame_deold[mask_bool]

                writer.write(fusion_frame)
                pbar.update(1)
    finally:
        # Release handles even when a frame fails to process.
        cap_deold.release()
        cap_sam.release()
        if cap_input is not None:
            cap_input.release()
        writer.release()

    print(f"[INFO] Fusion video saved: {fusion_path}")
    return fusion_path


# ---- Main Pipeline ----
def process_video(input_path):
    """Run the YOLO, DeOldify, SAM and Fusion stages for one video.

    All outputs go to ``OUTPUT_ROOT/<input file name without extension>``,
    which is created when missing.

    Returns:
        dict with the keys "yolo", "deoldify", "sam" and "final" mapping to the
        video path of each stage.
    """
    stem = os.path.splitext(os.path.basename(input_path))[0]
    out_dir = os.path.join(OUTPUT_ROOT, stem)
    os.makedirs(out_dir, exist_ok=True)

    # run_yolo also hands back the per-frame detections that fusion needs.
    yolo_video, detections = run_yolo(input_path, out_dir, stem)
    deoldify_video = run_deoldify(input_path, out_dir, stem)
    sam_video = run_sam(input_path, out_dir, stem)
    final_video = run_fusion(detections, deoldify_video, sam_video, out_dir, stem)

    print(f"[INFO] Outputs written to {out_dir}")
    return {
        "yolo": yolo_video,
        "deoldify": deoldify_video,
        "sam": sam_video,
        "final": final_video,
    }


# ---- Example ----
if __name__ == "__main__":
    # Run the whole pipeline on one sample clip and list what it produced.
    input_video = "input_videos/THATHA MANAVADU colored Trim.mp4"
    outputs = process_video(input_video)
    print("Pipeline outputs:")
    # process_video returns a dict of stage label -> output path.
    for k, v in outputs.items():
        print(f" - {k}: {v}")


CUDA available: True
CUDA device count: 1
Using: NVIDIA GeForce RTX 4060 Ti
[INFO] Using device: cuda


YOLO: 100%|██████████| 1871/1871 [00:47<00:00, 39.15frame/s]


[CACHE] Using cached DeOldify: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_deoldify.mp4
[CACHE] Using cached SAM: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_sam.mp4


Fusion: 100%|██████████| 1871/1871 [00:24<00:00, 75.69frame/s]

[INFO] Fusion video saved: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_final.mp4
[INFO] Outputs written to outputs/THATHA MANAVADU colored Trim
Pipeline outputs:
 - yolo: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_yolo.mp4
 - deoldify: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_deoldify.mp4
 - sam: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_sam.mp4
 - final: outputs/THATHA MANAVADU colored Trim/THATHA MANAVADU colored Trim_final.mp4





In [1]:
import torch
# Report whether PyTorch sees a CUDA device, how many, and the name of device 0.
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Using:", torch.cuda.get_device_name(0))


CUDA available: True
CUDA device count: 1
Using: NVIDIA GeForce RTX 4060 Ti


In [1]:
import os
import cv2
import torch
import numpy as np
import pickle
from tqdm import tqdm
from ultralytics import YOLO
from deoldify.visualize import get_image_colorizer
from deoldify import device
from deoldify.device_id import DeviceId
from PIL import Image
import uuid, json, requests, time

# ==== CONFIG ====
YOLO_MODEL_PATH = "models/yolo11x-seg.pt"  # weights file handed to YOLO() below
CONF_THRESHOLD = 0.6  # confidence cut-off used by run_yolo (predict) and run_fusion
OUTPUT_ROOT = "outputs"  # process_video creates one sub-directory per video here
COMFY = "http://192.168.27.13:23476"    # ComfyUI server
WORKFLOW_JSON = "ClothesDetect_api.json"  # workflow file loaded by run_sam_on_frame
# =================

# ---- Setup DeOldify ----
# The device is selected before the colorizer is built; keep this order.
device.set(device=DeviceId.GPU0)
colorizer = get_image_colorizer(artistic=True)

# ---- Setup YOLO ----
# YOLO runs on CUDA when PyTorch reports it as available, otherwise on the CPU.
device_str = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {device_str}")
yolo_model = YOLO(YOLO_MODEL_PATH).to(device_str)


# ---- DeOldify inference ----
def deoldify_inference(frame_rgb):
    """Colorize one frame with the module-level DeOldify ``colorizer``.

    Args:
        frame_rgb: image array accepted by ``PIL.Image.fromarray`` (callers pass RGB frames).

    Returns:
        ``np.array`` of the image returned by the colorizer.
    """
    source_image = Image.fromarray(frame_rgb)
    source_image = source_image.convert("RGB")
    colorized = colorizer.get_transformed_image(
        source_image,
        render_factor=16,
        post_process=True,
    )
    return np.array(colorized)


# ---- ComfyUI helpers ----
def upload_image_to_comfy(local_path, server=COMFY, *, dest_name=None, folder_type="input"):
    """Upload a local image to the ComfyUI ``/upload/image`` endpoint.

    Args:
        local_path: file to upload.
        server: base URL of the ComfyUI server.
        dest_name: name to store the file under; defaults to the base name of ``local_path``.
        folder_type: value sent as the ``type`` form field.

    Returns:
        The name the file was uploaded under.

    Raises:
        requests.HTTPError: when the server answers with an error status.
    """
    remote_name = os.path.basename(local_path) if dest_name is None else dest_name
    form_fields = {"type": folder_type, "overwrite": "true"}
    with open(local_path, "rb") as stream:
        response = requests.post(
            f"{server}/upload/image",
            files={"image": (remote_name, stream, "image/png")},
            data=form_fields,
            timeout=60,
        )
        response.raise_for_status()
    return remote_name

def patch_loadimage_node(prompt_dict, new_filename):
    """Point the first LoadImage node of a workflow at ``new_filename``.

    Nodes are scanned in dict order and matched on ``class_type``
    case-insensitively; only the first match is modified, in place.

    Returns:
        True when a node was patched, False when the workflow has no LoadImage node.
    """
    first_loader = next(
        (
            candidate
            for candidate in prompt_dict.values()
            if candidate.get("class_type", "").lower() == "loadimage"
        ),
        None,
    )
    if first_loader is None:
        return False
    first_loader["inputs"]["image"] = new_filename
    return True

def queue_prompt(prompt_dict, server=COMFY):
    """Submit a workflow to the ComfyUI ``/prompt`` endpoint.

    Returns:
        The ``prompt_id`` from the JSON response, or the freshly generated
        client id when the response carries none.

    Raises:
        requests.HTTPError: when the server answers with an error status.
    """
    fallback_id = str(uuid.uuid4())
    payload = {"prompt": prompt_dict, "client_id": fallback_id}
    response = requests.post(f"{server}/prompt", json=payload, timeout=120)
    response.raise_for_status()
    reply = response.json()
    return reply.get("prompt_id", fallback_id)

def get_history(prompt_id, server=COMFY):
    """Fetch ``/history/<prompt_id>`` from the ComfyUI server and return the parsed JSON.

    Raises:
        requests.HTTPError: when the server answers with an error status.
    """
    history_url = f"{server}/history/{prompt_id}"
    response = requests.get(history_url, timeout=60)
    response.raise_for_status()
    return response.json()

def download_image(filename, server=COMFY, folder_type="output", subfolder="", to_path=None):
    """Download one image from the ComfyUI ``/view`` endpoint.

    NOTE(review): the same name is defined again right below with an extra
    ``save_dir`` parameter; that later definition is the one left bound.

    Args:
        filename: image name as reported by the server.
        server: base URL of the ComfyUI server.
        folder_type: value sent as the ``type`` query parameter.
        subfolder: value sent as the ``subfolder`` query parameter.
        to_path: local destination; defaults to ``filename`` in the working directory.

    Returns:
        The local path the image was written to.

    Raises:
        requests.HTTPError: when the server answers with an error status.
    """
    response = requests.get(
        f"{server}/view",
        params={"filename": filename, "subfolder": subfolder, "type": folder_type},
        timeout=60,
    )
    response.raise_for_status()
    destination = filename if to_path is None else to_path
    with open(destination, "wb") as sink:
        sink.write(response.content)
    return destination


def download_image(filename, server=COMFY, folder_type="output", subfolder="", to_path=None, save_dir=None):
    """Download one image from the ComfyUI ``/view`` endpoint.

    Args:
        filename: image name as reported by the server.
        server: base URL of the ComfyUI server.
        folder_type: value sent as the ``type`` query parameter.
        subfolder: value sent as the ``subfolder`` query parameter.
        to_path: explicit local destination; overrides ``save_dir``.
        save_dir: directory for the download; defaults to ``OUTPUT_ROOT/comfy_downloads``.
            It is created before the request is sent, even when ``to_path`` is given.

    Returns:
        The local path the image was written to.

    Raises:
        requests.HTTPError: when the server answers with an error status.
    """
    # If not specified, save under OUTPUT_ROOT/comfy_downloads/
    target_dir = os.path.join(OUTPUT_ROOT, "comfy_downloads") if save_dir is None else save_dir
    os.makedirs(target_dir, exist_ok=True)

    response = requests.get(
        f"{server}/view",
        params={"filename": filename, "subfolder": subfolder, "type": folder_type},
        timeout=60,
    )
    response.raise_for_status()

    destination = os.path.join(target_dir, filename) if to_path is None else to_path
    with open(destination, "wb") as sink:
        sink.write(response.content)

    return destination


def run_sam_on_frame(frame_rgb, comfy_server=COMFY):
    """Segment one frame through the ComfyUI workflow and return the downloaded result path.

    The frame is saved to a uniquely named temporary PNG, uploaded, patched
    into the LoadImage node of ``WORKFLOW_JSON`` and queued. The prompt history
    is then polled every 0.5 s for up to 60 s. The temporary PNG is removed on
    every exit path, including when a helper raises.

    Args:
        frame_rgb: image array accepted by ``PIL.Image.fromarray`` (callers pass RGB frames).
        comfy_server: base URL of the ComfyUI server.

    Returns:
        Local path of the downloaded image, or None when no image was reported
        before the deadline.

    Raises:
        requests.HTTPError: propagated from the HTTP helpers on an error status.
    """
    tmp_path = f"temp_frame_{uuid.uuid4().hex}.png"
    try:
        Image.fromarray(frame_rgb).save(tmp_path)
        uploaded = upload_image_to_comfy(tmp_path, server=comfy_server)

        with open(WORKFLOW_JSON, "r") as f:
            prompt = json.load(f)
        patch_loadimage_node(prompt, uploaded)

        prompt_id = queue_prompt(prompt, server=comfy_server)
        # Monotonic clock: the timeout must not move with wall-clock adjustments.
        deadline = time.monotonic() + 60
        seg_path = None
        while time.monotonic() < deadline:
            hist = get_history(prompt_id, server=comfy_server)
            item = hist.get(prompt_id)
            if item and "outputs" in item:
                # NOTE(review): the inner break only leaves the per-node image
                # loop, so with several output nodes the first image of each
                # node is downloaded and the last node wins. Kept unchanged;
                # confirm whether the very first image was intended.
                for node_out in item["outputs"].values():
                    for im in node_out.get("images", []):
                        seg_path = download_image(
                            im["filename"], server=comfy_server,
                            subfolder=im.get("subfolder", ""),
                            folder_type=im.get("type", "output")
                        )
                        break
            if seg_path:
                break
            time.sleep(0.5)
        return seg_path
    finally:
        # Callers swallow per-frame failures, so a leaked temp file would
        # accumulate once per failed frame.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# ---- Stage 1: YOLO ----
def run_yolo(input_path, out_dir, name):
    """Stage 1: run YOLO on every frame, write the annotated video and cache the detections.

    Args:
        input_path: video to analyse.
        out_dir: directory receiving ``{name}_yolo.mp4`` and ``{name}_yolo_results.pkl``.
        name: base name used for both output files.

    Returns:
        ``(video_path, detections)`` where ``detections`` holds one dict per frame
        with the keys "boxes", "conf" and "cls" (NumPy arrays). When both output
        files already exist they are reused instead of recomputed.
    """
    video_out = os.path.join(out_dir, f"{name}_yolo.mp4")
    cache_file = os.path.join(out_dir, f"{name}_yolo_results.pkl")

    cached = os.path.exists(video_out) and os.path.exists(cache_file)
    if cached:
        print(f"[CACHE] Using cached YOLO + results: {video_out}")
        # The pickle is produced by this function; never point it at untrusted files.
        with open(cache_file, "rb") as handle:
            return video_out, pickle.load(handle)

    reader = cv2.VideoCapture(input_path)
    frame_rate = int(reader.get(cv2.CAP_PROP_FPS)) or 25  # 25 when fps reads as 0
    frame_size = (
        int(reader.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    sink = cv2.VideoWriter(video_out, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, frame_size)

    detections = []
    expected = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=expected, desc="YOLO", unit="frame") as progress:
        ok, frame = reader.read()
        while ok:
            first = yolo_model.predict(frame, conf=CONF_THRESHOLD, verbose=False, device=device_str)[0]
            sink.write(first.plot())
            # Keep plain arrays only, not the full result object.
            found = first.boxes
            detections.append({
                "boxes": found.xyxy.cpu().numpy(),
                "conf": found.conf.cpu().numpy(),
                "cls": found.cls.cpu().numpy(),
            })
            progress.update(1)
            ok, frame = reader.read()

    reader.release()
    sink.release()

    with open(cache_file, "wb") as handle:
        pickle.dump(detections, handle)

    return video_out, detections


# ---- Stage 2: DeOldify ----
def run_deoldify(input_path, out_dir, name):
    """Stage 2: colorize every frame with ``deoldify_inference`` and write the result.

    Args:
        input_path: video to colorize.
        out_dir: directory receiving ``{name}_deoldify.mp4``.
        name: base name used for the output file.

    Returns:
        Path of the colorized video. A file that already exists is returned untouched.
    """
    target = os.path.join(out_dir, f"{name}_deoldify.mp4")
    if os.path.exists(target):
        print(f"[CACHE] Using cached DeOldify: {target}")
        return target

    reader = cv2.VideoCapture(input_path)
    frame_rate = int(reader.get(cv2.CAP_PROP_FPS)) or 25  # 25 when fps reads as 0
    frame_size = (
        int(reader.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    sink = cv2.VideoWriter(target, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, frame_size)

    expected = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=expected, desc="DeOldify", unit="frame") as progress:
        ok, bgr = reader.read()
        while ok:
            # OpenCV frames are BGR; the colorizer is fed RGB and its output is converted back.
            colorized_rgb = deoldify_inference(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
            sink.write(cv2.cvtColor(colorized_rgb, cv2.COLOR_RGB2BGR))
            progress.update(1)
            ok, bgr = reader.read()
    reader.release()
    sink.release()
    return target


# ---- Stage 3: SAM ----
def run_sam(input_path, out_dir, name):
    """Stage 3: segment every frame via ``run_sam_on_frame`` and write the results as a video.

    Each result image is resized to the input frame size. When segmentation
    raises, or the result image cannot be read, the untouched input frame is
    written instead so the output keeps one frame per input frame.

    Args:
        input_path: source video.
        out_dir: directory receiving ``{name}_sam.mp4``.
        name: base name used for the output file.

    Returns:
        Path of the written video. A file that already exists is returned untouched.
    """
    target = os.path.join(out_dir, f"{name}_sam.mp4")
    if os.path.exists(target):
        print(f"[CACHE] Using cached SAM: {target}")
        return target

    reader = cv2.VideoCapture(input_path)
    frame_rate = int(reader.get(cv2.CAP_PROP_FPS)) or 25  # 25 when fps reads as 0
    out_w = int(reader.get(cv2.CAP_PROP_FRAME_WIDTH))
    out_h = int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
    sink = cv2.VideoWriter(target, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (out_w, out_h))

    expected = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
    with tqdm(total=expected, desc="SAM", unit="frame") as progress:
        ok, bgr = reader.read()
        while ok:
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            try:
                seg_img = cv2.imread(run_sam_on_frame(rgb))
                if seg_img is None:
                    to_write = bgr  # unreadable result: fall back to the input frame
                else:
                    to_write = cv2.resize(seg_img, (out_w, out_h))
                sink.write(to_write)
            except Exception as e:
                # Best effort: one failed frame must not abort the whole video.
                print(f"‚ö†Ô∏è SAM failed: {e}")
                sink.write(bgr)
            progress.update(1)
            ok, bgr = reader.read()
    reader.release()
    sink.release()
    return target


# ---- Stage 4: Fusion ----
def run_fusion(input_path, out_dir, name, yolo_results, deoldify_path, sam_path):
    """Stage 4: gray original as base, DeOldify color where SAM mask and YOLO "person" boxes overlap.

    Args:
        input_path: original video; supplies fps, frame size and the base frames,
            which are converted to gray and back to three channels.
        out_dir: directory receiving ``{name}_final.mp4``.
        name: base name used for the output file.
        yolo_results: one dict per frame with the keys "boxes", "conf" and "cls".
        deoldify_path: DeOldify video; supplies the pasted pixels.
        sam_path: SAM video; gray values above 1 count as mask.

    Returns:
        Path of the fused video. A file that already exists is returned untouched.
    """
    target = os.path.join(out_dir, f"{name}_final.mp4")
    if os.path.exists(target):
        print(f"[CACHE] Using cached Fusion: {target}")
        return target

    src_reader = cv2.VideoCapture(input_path)        # original
    color_reader = cv2.VideoCapture(deoldify_path)   # colorized
    mask_reader = cv2.VideoCapture(sam_path)         # masks

    frame_rate = int(src_reader.get(cv2.CAP_PROP_FPS)) or 25
    frame_size = (
        int(src_reader.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(src_reader.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    sink = cv2.VideoWriter(target, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, frame_size)

    # Never read past the shortest source.
    frame_budget = int(min(
        src_reader.get(cv2.CAP_PROP_FRAME_COUNT),
        color_reader.get(cv2.CAP_PROP_FRAME_COUNT),
        mask_reader.get(cv2.CAP_PROP_FRAME_COUNT),
        len(yolo_results),
    ))

    def person_boxes_mask(template, detections):
        # 255 inside every class-0 box that passes the confidence cut-off.
        mask = np.zeros_like(template, dtype=np.uint8)
        for box, conf, cls in zip(detections["boxes"], detections["conf"], detections["cls"]):
            if int(cls) == 0 and not (conf < CONF_THRESHOLD):
                x1, y1, x2, y2 = map(int, box)
                mask[y1:y2, x1:x2] = 255
        return mask

    with tqdm(total=frame_budget, desc="Fusion", unit="frame") as progress:
        for frame_idx in range(frame_budget):
            ok_src, src_bgr = src_reader.read()
            ok_color, color_bgr = color_reader.read()
            ok_mask, seg_bgr = mask_reader.read()
            if not ok_src or not ok_color or not ok_mask:
                break

            # original gray
            gray_base = cv2.cvtColor(cv2.cvtColor(src_bgr, cv2.COLOR_BGR2GRAY), cv2.COLOR_GRAY2BGR)

            # SAM mask
            seg_gray = cv2.cvtColor(seg_bgr, cv2.COLOR_BGR2GRAY)
            _, sam_mask = cv2.threshold(seg_gray, 1, 255, cv2.THRESH_BINARY)

            # YOLO mask and intersection
            yolo_mask = person_boxes_mask(sam_mask, yolo_results[frame_idx])
            selected = cv2.bitwise_and(sam_mask, yolo_mask) > 127

            # Fusion
            fused = gray_base.copy()
            fused[selected] = color_bgr[selected]

            sink.write(fused)
            progress.update(1)

    src_reader.release()
    color_reader.release()
    mask_reader.release()
    sink.release()
    print(f"[INFO] Fusion video saved: {target}")
    return target




def run_fusion(input_path, out_dir, name, yolo_results, deoldify_path, sam_path):
    """Stage 4: original frame as base, DeOldify color where SAM mask and YOLO "person" boxes overlap.

    Args:
        input_path: original video; supplies fps, frame size and the base frames.
        out_dir: directory receiving ``{name}_final.mp4``.
        name: base name used for the output file.
        yolo_results: one dict per frame with the keys "boxes", "conf" and "cls".
        deoldify_path: DeOldify video; supplies the pasted pixels.
        sam_path: SAM video; gray values above 1 count as mask.

    Returns:
        Path of the fused video. A file that already exists is returned untouched.
    """
    result_path = os.path.join(out_dir, f"{name}_final.mp4")
    if os.path.exists(result_path):
        print(f"[CACHE] Using cached Fusion: {result_path}")
        return result_path

    original = cv2.VideoCapture(input_path)      # original frames
    colorized = cv2.VideoCapture(deoldify_path)  # deoldify video
    segmented = cv2.VideoCapture(sam_path)       # sam masks

    out_fps = int(original.get(cv2.CAP_PROP_FPS)) or 25
    out_w = int(original.get(cv2.CAP_PROP_FRAME_WIDTH))
    out_h = int(original.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out_video = cv2.VideoWriter(result_path, cv2.VideoWriter_fourcc(*"mp4v"), out_fps, (out_w, out_h))

    # Never read past the shortest source.
    limit = int(min(
        len(yolo_results),
        original.get(cv2.CAP_PROP_FRAME_COUNT),
        colorized.get(cv2.CAP_PROP_FRAME_COUNT),
        segmented.get(cv2.CAP_PROP_FRAME_COUNT),
    ))

    with tqdm(total=limit, desc="Fusion", unit="frame") as bar:
        for idx in range(limit):
            got_in, base = original.read()
            got_color, color = colorized.read()
            got_seg, seg = segmented.read()
            if not (got_in and got_color and got_seg):
                break

            # SAM mask
            _, sam_mask = cv2.threshold(
                cv2.cvtColor(seg, cv2.COLOR_BGR2GRAY), 1, 255, cv2.THRESH_BINARY
            )

            # YOLO mask
            per_frame = yolo_results[idx]
            yolo_mask = np.zeros_like(sam_mask, dtype=np.uint8)
            for box, conf, cls in zip(per_frame["boxes"], per_frame["conf"], per_frame["cls"]):
                if int(cls) != 0:  # only "person"
                    continue
                if conf < CONF_THRESHOLD:
                    continue
                left, top, right, bottom = map(int, box)
                yolo_mask[top:bottom, left:right] = 255

            # Intersection
            overlap = cv2.bitwise_and(sam_mask, yolo_mask) > 127

            # Fusion: base is ORIGINAL frame
            blended = base.copy()
            blended[overlap] = color[overlap]

            out_video.write(blended)
            bar.update(1)

    original.release()
    colorized.release()
    segmented.release()
    out_video.release()

    print(f"[INFO] Fusion video saved: {result_path}")
    return result_path




# ---- Main Pipeline ----
def process_video(input_path):
    """Run the YOLO, DeOldify, SAM and Fusion stages for one video.

    All outputs go to ``OUTPUT_ROOT/<input file name without extension>``,
    which is created when missing.

    Returns:
        dict with the keys "yolo", "deoldify", "sam" and "final" mapping to the
        video path of each stage.
    """
    stem = os.path.splitext(os.path.basename(input_path))[0]
    out_dir = os.path.join(OUTPUT_ROOT, stem)
    os.makedirs(out_dir, exist_ok=True)

    # run_yolo also hands back the per-frame detections that fusion needs.
    yolo_video, detections = run_yolo(input_path, out_dir, stem)
    deoldify_video = run_deoldify(input_path, out_dir, stem)
    sam_video = run_sam(input_path, out_dir, stem)
    final_video = run_fusion(input_path, out_dir, stem, detections, deoldify_video, sam_video)

    print(f"[INFO] Outputs written to {out_dir}")
    return {
        "yolo": yolo_video,
        "deoldify": deoldify_video,
        "sam": sam_video,
        "final": final_video,
    }


# ---- Example ----
# if __name__ == "__main__":
#     input_video = "input_videos/thatha_manavadu_test.mp4"
#     outputs = process_video(input_video)
#     print("Pipeline outputs:")
#     for k, v in outputs.items():
#         print(f" - {k}: {v}")


  warn("""Your validation set is empty. If this is by design, use `split_none()`
  WeightNorm.apply(module, name, dim)
  state = torch.load(tmp_file)
  state = torch.load(source, map_location=device)


[INFO] Using device: cpu
