In [6]:
import os
import sys
import cv2
import torch
import numpy as np
from torchvision.transforms import Compose
from tqdm import tqdm

# ----- Path Setup -----
# Put the parent of the notebook's working directory at the front of sys.path
# so the `midas` package imported below can be resolved.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if not any(entry == project_root for entry in sys.path):
    sys.path.insert(0, project_root)

from midas.midas.dpt_depth import DPTDepthModel
from midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

# ---------- Load MiDaS Model ----------
def load_midas_model(model_path="../midas/midas/weights/dpt_large_384.pt", device=None):
    """Load the MiDaS DPT-Large (384) depth model and its input transform.

    Args:
        model_path: checkpoint to load. The default is resolved relative to the
            current working directory (the notebook's directory).
        device: torch device (or device string) to place the model on. ``None``
            picks CUDA when available and otherwise falls back to CPU.

    Returns:
        ``(model, transform)``: the model in eval mode on ``device``, and a
        ``Compose`` pipeline that expects ``{"image": rgb_float_in_0_1}``. The
        pipeline resizes to 384 keeping the aspect ratio, normalises with
        mean/std 0.5, then applies ``PrepareForNet`` (presumably HWC -> CHW
        float32; confirm against the vendored transforms).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = DPTDepthModel(
        model_path,
        pretrained=True,
        backbone="vitl16_384",
        non_negative=True,
    )
    model.eval().to(device)

    # NOTE(review): upstream MiDaS inference configures Resize for DPT models
    # with ensure_multiple_of=32, resize_method="minimal" and cubic
    # interpolation; confirm the vendored Resize defaults are appropriate here.
    transform = Compose([
        Resize(384, 384, keep_aspect_ratio=True),
        NormalizeImage(mean=[0.5] * 3, std=[0.5] * 3),
        PrepareForNet(),
    ])
    return model, transform

# ---------- Estimate Depth Map ----------
def get_depth_map(img, model, transform):
    """Estimate a per-pixel relative depth map for one RGB frame.

    Args:
        img: H x W x 3 RGB image with values in [0, 255].
        model: depth network whose output indexed with ``[0]`` is an
            ``H' x W'`` prediction for the single input (e.g. the model from
            ``load_midas_model``).
        transform: preprocessing that takes ``{"image": float_img_in_0_1}`` and
            returns a dict whose ``"image"`` is a numpy array or tensor in the
            model's input layout.

    Returns:
        H x W float array min-max normalised to [0, 1]. A constant prediction
        yields an all-zero map instead of NaNs. MiDaS-style models typically
        predict relative inverse depth (larger = closer); confirm for this
        checkpoint before relying on the direction.
    """
    sample = transform({"image": img / 255.0})
    input_tensor = sample["image"]
    if isinstance(input_tensor, np.ndarray):
        input_tensor = torch.from_numpy(input_tensor)

    # Follow the model's device instead of assuming CUDA, so CPU-only runs work.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_tensor = input_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        prediction = model(input_tensor)[0]

    depth_map = prediction.cpu().numpy()
    # cv2.resize takes (width, height).
    depth_map = cv2.resize(depth_map, (img.shape[1], img.shape[0]))

    depth_min = float(depth_map.min())
    depth_range = float(depth_map.max()) - depth_min
    if depth_range <= 0:
        # Flat prediction: 0/0 would give NaN and poison every downstream
        # mean and mask.
        return np.zeros_like(depth_map)
    return (depth_map - depth_min) / depth_range

# ---------- Compute Focus from Alpha ----------
def compute_focus_from_alpha(alpha_gray, depth, threshold=127, fallback=0.5):
    """Return the focal depth as the mean depth under the subject's matte.

    Args:
        alpha_gray: 2-D matte in grey levels; pixels strictly above
            ``threshold`` count as subject.
        depth: depth map with the same height and width as ``alpha_gray``.
        threshold: matte level that separates subject from background.
        fallback: value returned when no pixel passes the threshold (e.g. the
            subject is out of frame); 0.5 is the middle of a [0, 1] depth map.

    Returns:
        The focal depth as a Python float.
    """
    subject = np.asarray(alpha_gray) > threshold
    if not subject.any():
        return float(fallback)
    return float(np.mean(np.asarray(depth)[subject]))

# ---------- Apply Depth-Aware Bokeh with Foreground Protection ----------
def apply_depth_bokeh(img, depth_map, focus=0.4, max_blur=15, alpha_mask=None,
                      steps=8, focus_range=0.05):
    """Blur an image according to how far each pixel's depth is from ``focus``.

    Pixels within ``focus_range`` of ``focus`` stay sharp. Beyond that band,
    blur grows linearly with ``|depth - focus|`` up to a ``max_blur``-sized
    Gaussian kernel. The continuous blur amount is approximated by ``steps``
    blurred copies of the image, mixed with triangular weights. These weights
    sum to one at every pixel, so no pixel is dropped or darkened whatever
    its depth.

    Args:
        img: H x W x C image with values in [0, 255].
        depth_map: H x W depth map normalised to [0, 1].
        focus: depth value that should stay sharp.
        max_blur: kernel size of the strongest blur level (rounded up to odd).
        alpha_mask: optional H x W matte (0-255). Where it is 255 the original
            pixel is kept, protecting the foreground subject.
        steps: number of discrete blur levels (must be >= 2).
        focus_range: half-width of the sharp band around ``focus``.

    Returns:
        uint8 image with the same shape as ``img``.

    Raises:
        ValueError: if ``steps < 2``.
    """
    if steps < 2:
        raise ValueError("steps must be >= 2")

    img_f = img.astype(np.float32)
    depth = np.asarray(depth_map, dtype=np.float32)

    # 0 = inside the sharp band, 1 = farthest possible from the focal plane.
    distance = np.abs(depth - focus)
    level = np.clip((distance - focus_range) / max(1.0 - focus_range, 1e-6), 0.0, 1.0)

    spacing = steps - 1
    output = np.zeros_like(img_f)
    for i in range(steps):
        center = i / spacing
        # GaussianBlur needs an odd kernel size; `| 1` rounds even sizes up.
        ksize = int(center * max_blur) | 1
        blurred = cv2.GaussianBlur(img_f, (ksize, ksize), 0)
        # Hat function with half-width equal to the level spacing, so
        # neighbouring weights always sum to exactly one.
        weight = np.clip(1.0 - np.abs(level - center) * spacing, 0.0, 1.0)
        # Feather transitions between levels; blurring is linear, so the
        # weights still sum to one.
        weight = cv2.GaussianBlur(weight, (11, 11), 0)
        output += blurred * weight[..., np.newaxis]

    # Foreground preservation using the alpha matte.
    if alpha_mask is not None:
        alpha = (alpha_mask.astype(np.float32) / 255.0)[..., np.newaxis]
        output = alpha * img_f + (1.0 - alpha) * output

    return np.clip(output, 0, 255).astype(np.uint8)

# ---------- Process Video with Matching Bokeh ----------
def process_video_with_matte(input_video_path, alpha_video_path, output_rgb_path, output_alpha_path):
    """Apply depth-aware bokeh to an RGB video and its alpha matte in lockstep.

    For every frame pair:
    - A depth map is estimated from the RGB frame and smoothed over time with
      an exponential moving average (weight 0.8 on the running value) to
      reduce flicker.
    - The focal depth comes from ``compute_focus_from_alpha`` using the matte.
    - The same bokeh is applied to the frame and to the matte, so the two stay
      aligned.

    Processing stops at the end of the shorter video. All captures and writers
    are released even if an error occurs.

    Args:
        input_video_path: RGB (foreground) video to read.
        alpha_video_path: matching alpha-matte video to read.
        output_rgb_path: where to write the blurred RGB video (mp4v).
        output_alpha_path: where to write the blurred single-channel matte (mp4v).

    Raises:
        OSError: if an input cannot be opened or an output cannot be created.
    """
    cap_rgb = cv2.VideoCapture(input_video_path)
    cap_alpha = cv2.VideoCapture(alpha_video_path)
    out_rgb = out_alpha = None
    try:
        if not cap_rgb.isOpened():
            raise OSError(f"Cannot open RGB video: {input_video_path}")
        if not cap_alpha.isOpened():
            raise OSError(f"Cannot open alpha video: {alpha_video_path}")

        frame_width = int(cap_rgb.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap_rgb.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap_rgb.get(cv2.CAP_PROP_FPS)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')

        out_rgb = cv2.VideoWriter(output_rgb_path, fourcc, fps, (frame_width, frame_height))
        out_alpha = cv2.VideoWriter(output_alpha_path, fourcc, fps, (frame_width, frame_height), isColor=False)
        if not (out_rgb.isOpened() and out_alpha.isOpened()):
            raise OSError(f"Cannot create output videos: {output_rgb_path}, {output_alpha_path}")

        model, transform = load_midas_model()
        smoothed_depth = None
        alpha_smooth = 0.8  # EMA weight on the running depth: higher = smoother, slower to react

        frame_idx = 0
        while True:
            ret_rgb, frame = cap_rgb.read()
            ret_alpha, alpha = cap_alpha.read()
            if not ret_rgb or not ret_alpha:
                break  # end of the shorter video

            alpha_gray = cv2.cvtColor(alpha, cv2.COLOR_BGR2GRAY)
            if alpha_gray.shape != frame.shape[:2]:
                # Per-pixel blending below requires the matte to match the frame.
                alpha_gray = cv2.resize(alpha_gray, (frame.shape[1], frame.shape[0]))
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            current_depth = get_depth_map(frame_rgb, model, transform)
            if smoothed_depth is None:
                smoothed_depth = current_depth
            else:
                smoothed_depth = alpha_smooth * smoothed_depth + (1 - alpha_smooth) * current_depth

            depth = smoothed_depth
            focus = compute_focus_from_alpha(alpha_gray, depth)

            # Bokeh on the frame, keeping the subject (matte) sharp.
            blurred_frame = apply_depth_bokeh(frame, depth, focus=focus, alpha_mask=alpha_gray)

            # Same bokeh on the matte so its soft edges line up with the frame.
            alpha_color = cv2.cvtColor(alpha_gray, cv2.COLOR_GRAY2BGR)
            blurred_alpha = apply_depth_bokeh(alpha_color, depth, focus=focus, alpha_mask=alpha_gray)
            blurred_alpha_gray = cv2.cvtColor(blurred_alpha, cv2.COLOR_BGR2GRAY)

            out_rgb.write(blurred_frame)
            out_alpha.write(blurred_alpha_gray)

            frame_idx += 1
            print(f"Processed frame {frame_idx}", end='\r')
    finally:
        for handle in (cap_rgb, cap_alpha, out_rgb, out_alpha):
            if handle is not None:
                handle.release()
    print("\n✅ Finished processing video and matte with consistent bokeh.")

# Example usage: blur the VideoMatte240K sample clip and its alpha matte together,
# writing both results next to this notebook.
process_video_with_matte(
    input_video_path="../data/VideoMatte240K/train/fgr/0000.mp4",
    alpha_video_path="../data/VideoMatte240K/train/pha/0000.mp4",
    output_rgb_path="output_blurred_rgb.mp4",
    output_alpha_path="output_blurred_alpha.mp4",
)


Processed frame 339
✅ Finished processing video and matte with consistent bokeh.


In [21]:
import os
import sys
import cv2
import torch
import numpy as np
from torchvision.transforms import Compose
from tqdm import tqdm

# ----- Path Setup -----
# Put the parent of the notebook's working directory at the front of sys.path
# so the `midas` package imported below can be resolved.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if not any(entry == project_root for entry in sys.path):
    sys.path.insert(0, project_root)

from midas.midas.dpt_depth import DPTDepthModel
from midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

# ---------- Load MiDaS ----------
def load_midas_model(model_path="../midas/midas/weights/dpt_large_384.pt", device=None):
    """Load the MiDaS DPT-Large (384) depth model and its input transform.

    Args:
        model_path: checkpoint to load. The default is resolved relative to the
            current working directory (the notebook's directory).
        device: torch device (or device string) to place the model on. ``None``
            picks CUDA when available and otherwise falls back to CPU.

    Returns:
        ``(model, transform)``: the model in eval mode on ``device``, and a
        ``Compose`` pipeline that expects ``{"image": rgb_float_in_0_1}``. The
        pipeline resizes to 384 keeping the aspect ratio, normalises with
        mean/std 0.5, then applies ``PrepareForNet`` (presumably HWC -> CHW
        float32; confirm against the vendored transforms).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = DPTDepthModel(
        model_path,
        pretrained=True,
        backbone="vitl16_384",
        non_negative=True,
    )
    model.eval().to(device)

    # NOTE(review): upstream MiDaS inference configures Resize for DPT models
    # with ensure_multiple_of=32, resize_method="minimal" and cubic
    # interpolation; confirm the vendored Resize defaults are appropriate here.
    transform = Compose([
        Resize(384, 384, keep_aspect_ratio=True),
        NormalizeImage(mean=[0.5] * 3, std=[0.5] * 3),
        PrepareForNet(),
    ])
    return model, transform

# ---------- Get Depth Map ----------
def get_depth_map(img, model, transform):
    """Estimate a per-pixel relative depth map for one RGB frame.

    Args:
        img: H x W x 3 RGB image with values in [0, 255].
        model: depth network whose output indexed with ``[0]`` is an
            ``H' x W'`` prediction for the single input (e.g. the model from
            ``load_midas_model``).
        transform: preprocessing that takes ``{"image": float_img_in_0_1}`` and
            returns a dict whose ``"image"`` is a numpy array or tensor in the
            model's input layout.

    Returns:
        H x W float array min-max normalised to [0, 1]. A constant prediction
        yields an all-zero map instead of NaNs. MiDaS-style models typically
        predict relative inverse depth (larger = closer); confirm for this
        checkpoint before relying on the direction.
    """
    sample = transform({"image": img / 255.0})
    input_tensor = sample["image"]
    if isinstance(input_tensor, np.ndarray):
        input_tensor = torch.from_numpy(input_tensor)

    # Follow the model's device instead of assuming CUDA, so CPU-only runs work.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_tensor = input_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        prediction = model(input_tensor)[0]

    depth_map = prediction.cpu().numpy()
    # cv2.resize takes (width, height).
    depth_map = cv2.resize(depth_map, (img.shape[1], img.shape[0]))

    depth_min = float(depth_map.min())
    depth_range = float(depth_map.max()) - depth_min
    if depth_range <= 0:
        # Flat prediction: 0/0 would give NaN and poison every downstream
        # mean and mask.
        return np.zeros_like(depth_map)
    return (depth_map - depth_min) / depth_range

# ---------- Create Static Blur Map ----------
def create_static_blur_map(depth_map, focus_mask, focus_range=0.02, steps=8, max_blur=65):
    """Build per-level blur weight maps from a depth map and a subject matte.

    The subject depth is the mean of ``depth_map`` where ``focus_mask > 127``.
    If the matte is empty, it falls back to 0.5, the middle of a normalised
    map. Each pixel's blur amount grows linearly with its distance from that
    depth beyond ``focus_range``. That amount is split across ``steps`` kernel
    sizes (0 .. ``max_blur``, rounded up to odd) using triangular weights that
    sum to one at every pixel.

    Args:
        depth_map: H x W depth map normalised to [0, 1].
        focus_mask: H x W matte (0-255) marking the subject to keep in focus.
        focus_range: half-width of the sharp band around the subject depth.
        steps: number of blur levels (must be >= 2).
        max_blur: kernel size of the strongest level.

    Returns:
        List of ``(kernel_size, weight_map)`` tuples, one per level, each with
        an odd ``kernel_size`` and a float32 H x W ``weight_map``.

    Raises:
        ValueError: if ``steps < 2``.
    """
    if steps < 2:
        raise ValueError("steps must be >= 2")

    subject = focus_mask > 127
    # An empty matte (e.g. subject not yet in frame) would make np.mean return
    # NaN and turn every weight map into NaN; use mid-depth instead.
    subject_depth = float(np.mean(depth_map[subject])) if subject.any() else 0.5
    depth_diff = np.abs(depth_map - subject_depth)
    blur_levels = np.clip((depth_diff - focus_range) / max(1.0 - focus_range, 1e-6), 0, 1)

    spacing = steps - 1
    blur_maps = []
    for i in range(steps):
        center = i / spacing
        blur_strength = int(center * max_blur) | 1  # GaussianBlur needs an odd kernel
        # Hat function whose half-width equals the level spacing, so adjacent
        # weights sum to exactly one. A narrower half-width (e.g. 1/steps)
        # leaves the sum below one between levels and darkens those pixels.
        mask = np.clip(1 - np.abs(blur_levels - center) * spacing, 0, 1)
        # Feather transitions; blurring is linear, so the sum stays one.
        mask = cv2.GaussianBlur(mask.astype(np.float32), (15, 15), 0)
        blur_maps.append((blur_strength, mask))

    return blur_maps

# ---------- Apply Precomputed Blur Map ----------
def apply_precomputed_blur(image, blur_maps):
    """Blend Gaussian-blurred copies of ``image`` using per-pixel weight maps.

    Args:
        image: H x W (greyscale) or H x W x C image with values in [0, 255].
        blur_maps: sequence of ``(kernel_size, weight_map)`` pairs, e.g. from
            ``create_static_blur_map``. Each ``kernel_size`` must be odd and
            each ``weight_map`` is an H x W float array.

    Returns:
        uint8 image with the same shape as ``image``. The blend is divided by
        the total weight at each pixel, so brightness is preserved even when
        the weights do not sum to exactly one. Pixels with zero total weight
        come out as 0.
    """
    is_gray = image.ndim == 2
    source = image.astype(np.float32)
    if is_gray:
        source = source[..., np.newaxis]

    output = np.zeros_like(source)
    total_weight = np.zeros(source.shape[:2], dtype=np.float32)

    for blur_strength, mask in blur_maps:
        blurred = cv2.GaussianBlur(source, (blur_strength, blur_strength), 0)
        if blurred.ndim == 2:
            # OpenCV returns single-channel results without the channel axis.
            blurred = blurred[..., np.newaxis]
        output += blurred * mask[..., np.newaxis]
        total_weight += mask

    output /= np.maximum(total_weight, 1e-6)[..., np.newaxis]
    output = np.clip(output, 0, 255).astype(np.uint8)
    # Drop only the channel axis we added; squeeze() could also drop H or W == 1.
    return output[..., 0] if is_gray else output

# ---------- Process Video ----------
def _iter_frame_pairs(cap_rgb, cap_alpha):
    """Yield ``(rgb_frame, alpha_frame)`` pairs until either capture runs out."""
    while True:
        ret_rgb, frame = cap_rgb.read()
        ret_alpha, alpha = cap_alpha.read()
        if not ret_rgb or not ret_alpha:
            return
        yield frame, alpha


def _matte_to_gray(alpha_bgr, frame_shape):
    """Convert a BGR matte frame to grey, resizing it to ``frame_shape[:2]`` if needed."""
    alpha_gray = cv2.cvtColor(alpha_bgr, cv2.COLOR_BGR2GRAY)
    height, width = frame_shape[:2]
    if alpha_gray.shape != (height, width):
        alpha_gray = cv2.resize(alpha_gray, (width, height))
    return alpha_gray


def process_video_static_blur_mask(input_rgb, input_alpha, output_rgb, output_alpha):
    """Blur a video and its matte with one depth-based blur layout from frame 0.

    The depth map and subject matte of the first frame define a fixed set of
    blur maps (via ``create_static_blur_map``). Every frame of the RGB video
    and of the matte is then blurred with those same maps. The blur layout
    therefore cannot flicker, but it does not adapt to later camera or subject
    motion. Processing stops at the end of the shorter video, and all handles
    are released even if an error occurs.

    Args:
        input_rgb: path of the RGB video to read.
        input_alpha: path of the matching alpha-matte video.
        output_rgb: path of the blurred RGB video to write (mp4v).
        output_alpha: path of the blurred single-channel matte to write (mp4v).

    If the first frame pair cannot be read, an error message is printed and no
    output is written.
    """
    import itertools

    cap_rgb = cv2.VideoCapture(input_rgb)
    cap_alpha = cv2.VideoCapture(input_alpha)
    out_rgb = out_alpha = None
    try:
        pairs = _iter_frame_pairs(cap_rgb, cap_alpha)

        # --- Use FIRST frame to generate static blur maps ---
        first_pair = next(pairs, None)
        if first_pair is None:
            print("Error reading reference frame")
            return
        ref_frame, ref_alpha = first_pair

        model, transform = load_midas_model()
        ref_rgb_input = cv2.cvtColor(ref_frame, cv2.COLOR_BGR2RGB)
        ref_alpha_gray = _matte_to_gray(ref_alpha, ref_frame.shape)

        print("🔍 Generating depth map from reference frame...")
        depth_map = get_depth_map(ref_rgb_input, model, transform)

        print("🌀 Creating static blur maps...")
        blur_maps = create_static_blur_map(depth_map, ref_alpha_gray)

        # Writers are created only once the reference frame is known to be usable.
        frame_width = int(cap_rgb.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap_rgb.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap_rgb.get(cv2.CAP_PROP_FPS)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_rgb = cv2.VideoWriter(output_rgb, fourcc, fps, (frame_width, frame_height))
        out_alpha = cv2.VideoWriter(output_alpha, fourcc, fps, (frame_width, frame_height), isColor=False)

        # --- Process all frames (reference first) using the same blur maps ---
        # Keep reading sequentially instead of rewinding with CAP_PROP_POS_FRAMES,
        # whose seek is not guaranteed to be honoured by every backend/codec.
        frame_idx = 0
        for frame, alpha in itertools.chain([first_pair], pairs):
            alpha_gray = _matte_to_gray(alpha, frame.shape)

            blurred_rgb = apply_precomputed_blur(frame, blur_maps)
            blurred_alpha = apply_precomputed_blur(alpha_gray, blur_maps)

            out_rgb.write(blurred_rgb)
            out_alpha.write(blurred_alpha)

            frame_idx += 1
            print(f"Processed frame {frame_idx}", end='\r')
    finally:
        for handle in (cap_rgb, cap_alpha, out_rgb, out_alpha):
            if handle is not None:
                handle.release()
    print("\n✅ Done: Full-subject blur + matching soft matte, flicker-free.")

# ✅ RUN IT: apply the first-frame blur layout to the whole sample clip and its matte.
process_video_static_blur_mask(
    input_rgb="../data/VideoMatte240K/train/fgr/0000.mp4",
    input_alpha="../data/VideoMatte240K/train/pha/0000.mp4",
    output_rgb="blurred_rgb_subject_and_bg.mp4",
    output_alpha="blurred_alpha_soft_subject_mask.mp4",
)


🔍 Generating depth map from reference frame...
🌀 Creating static blur maps...
Processed frame 339
✅ Done: Full-subject blur + matching soft matte, flicker-free.
