## Overhead camera: multi-colour pixel-to-world localisation

In [None]:
# overhead_multicolor_pixel_to_world.py
# Requirements: pip install mujoco glfw PyOpenGL pillow numpy

import numpy as np
import mujoco as mj
from mujoco.glfw import glfw
import OpenGL.GL as gl
from PIL import Image, ImageDraw

# --------- User config ---------
# Path of the MuJoCo XML model that main() loads.
XML_PATH = r"C:\Users\ptfc0\Downloads\LabratoryPickNPlace-main\LabratoryPickNPlace-main\franka_panda_w_objs.xml"
# Name of the model camera used for the fixed-camera render.
CAM_NAME = "overhead_cam"
# Requested (width, height) of the GLFW window that is rendered into.
WINDOW_SIZE = (1200, 900)     # good coverage and resolution
USE_CV_CONVENTION = True      # X right, Y down, Z forward (image coords)

# Table height (printed in metres) used when no known box geom is found.
Z_TABLE_FALLBACK = 1.0

# Colour classes to detect. A pixel belongs to a class when every channel is
# within `tol` of the corresponding `ref_rgb` channel.
# NOTE(review): "teal_block" uses pure green (0, 255, 0) as its reference —
# confirm this matches the object's rendered colour.
COLOR_CLASSES = {
    "teal_block": {"ref_rgb": (0, 255, 0), "tol": 55},
    "red_target": {"ref_rgb": (255, 0, 0), "tol": 60},
}

# Number of grid cells along each image axis used to bin mask pixels.
GRID = 48
# Minimum number of mask pixels a grid cell must hold to be kept.
MIN_PIXELS = 80
# --------------------------------


def intrinsics_from_fovy(fovy_deg, width, height):
    """Return pinhole intrinsics ``(fx, fy, cx, cy)`` from a vertical FOV.

    Args:
        fovy_deg: Vertical field of view in degrees.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        Tuple ``(fx, fy, cx, cy)``. Both focal lengths are equal (square
        pixels) and the principal point is the centre of the pixel grid,
        ``((width - 1) / 2, (height - 1) / 2)``.
    """
    half_angle = np.deg2rad(float(fovy_deg)) / 2.0
    # The image height spans the full vertical FOV, which fixes the focal length.
    focal = height / (2.0 * np.tan(half_angle))
    center_x = (width - 1) / 2.0
    center_y = (height - 1) / 2.0
    return focal, focal, center_x, center_y


def pixel_to_world_xy(u, v, W, H, fovy_deg, cam_pos, cam_xmat, z_plane=0.0, use_cv=True):
    """Intersect the viewing ray through pixel ``(u, v)`` with ``z = z_plane``.

    Args:
        u, v: Pixel column and row.
        W, H: Image width and height in pixels.
        fovy_deg: Vertical field of view in degrees.
        cam_pos: Camera position in world coordinates (length-3 array).
        cam_xmat: Camera rotation, reshaped here to 3x3 and applied to the
            camera-frame ray to obtain the world-frame ray.
        z_plane: World height of the horizontal plane to intersect.
        use_cv: When true, the image Y axis is flipped while building the
            camera-frame ray; otherwise it is used as is. The ray always
            points along camera -Z.

    Returns:
        The world-space intersection point as a length-3 array, or ``None``
        when the ray is parallel to the plane or the plane lies behind the
        camera along the ray.
    """
    fx, fy, cx, cy = intrinsics_from_fovy(fovy_deg, W, H)
    norm_x = (u - cx) / fx
    norm_y = (v - cy) / fy

    # The two conventions differ only in the sign of the Y component.
    y_sign = -1.0 if use_cv else 1.0
    ray_cam = np.array([norm_x, y_sign * norm_y, -1.0], dtype=np.float32)
    ray_cam /= np.linalg.norm(ray_cam) + 1e-12

    rotation = cam_xmat.reshape(3, 3)
    ray_world = ray_cam @ rotation.T

    dz = ray_world[2]
    if abs(dz) < 1e-12:
        # Ray runs parallel to the plane: no intersection.
        return None

    distance = (z_plane - cam_pos[2]) / dz
    if distance < 0:
        # The plane is behind the camera along this ray.
        return None
    return cam_pos + distance * ray_world


def pixel_to_world_from_depth(u, v, depth_m, W, H, fovy_deg, cam_pos, cam_xmat, use_cv=True):
    """Back-project pixel ``(u, v)`` with a depth value into world coordinates.

    ``depth_m`` is interpreted as the distance along the camera's optical
    axis, which is what a linearised z-buffer provides. The camera-frame ray
    is therefore scaled so that its forward component equals ``depth_m``
    rather than being normalised to unit length.

    Args:
        u, v: Pixel column and row.
        depth_m: Depth along the optical axis at that pixel.
        W, H: Image width and height in pixels.
        fovy_deg: Vertical field of view in degrees.
        cam_pos: Camera position in world coordinates (length-3 array).
        cam_xmat: Camera rotation (9 values or 3x3) mapping camera-frame
            vectors to the world frame, as used by ``pixel_to_world_xy``.
        use_cv: When true, the image Y axis is flipped while building the
            camera-frame point; otherwise it is used as is. The point always
            lies along camera -Z.

    Returns:
        The world-space point as a length-3 float64 array.
    """
    fx, fy, cx, cy = intrinsics_from_fovy(fovy_deg, W, H)
    x = (u - cx) / fx
    y = (v - cy) / fy

    # Both conventions look down camera -Z; they differ only in the Y sign.
    y_sign = -1.0 if use_cv else 1.0
    # Forward component is exactly 1, so multiplying by depth_m places the
    # point at the requested depth along the optical axis.
    p_cam = np.array([x, y_sign * y, -1.0], dtype=np.float64) * float(depth_m)

    R = np.asarray(cam_xmat, dtype=np.float64).reshape(3, 3)
    # Camera -> world is R @ p (equivalently p @ R.T); R.T @ p would apply
    # the inverse rotation.
    return np.asarray(cam_pos, dtype=np.float64) + R @ p_cam


def mask_from_ref_rgb(rgb, ref_rgb, tol):
    """Return a boolean mask of pixels close to a reference colour.

    Args:
        rgb: Image array with three channels on its last axis.
        ref_rgb: Reference colour as three channel values.
        tol: Largest allowed per-channel absolute deviation (inclusive).

    Returns:
        Boolean array over the first two axes of ``rgb``; true where every
        channel is within ``tol`` of the reference.
    """
    # int16 keeps uint8 differences from wrapping around.
    target = np.array(ref_rgb, dtype=np.int16).reshape(1, 1, 3)
    deviation = np.abs(rgb.astype(np.int16) - target)
    worst_channel = deviation.max(axis=2)
    return worst_channel <= int(tol)


def centroids_from_mask_grid(mask, grid=48, min_pixels=80):
    """Find blob centroids in a boolean mask using a coarse occupancy grid.

    Mask pixels are binned into a ``grid`` x ``grid`` lattice of cells. Cells
    holding at least ``min_pixels`` mask pixels are kept and merged into
    4-connected components. Each component yields the mean pixel position of
    the mask pixels that fall inside its cells.

    Args:
        mask: 2-D boolean array.
        grid: Number of cells along each image axis.
        min_pixels: Minimum mask-pixel count for a cell to be kept.

    Returns:
        List of ``(u, v)`` float tuples (column, row), ordered by the
        smallest row-major cell index of each component. Returns ``[]`` for
        an empty mask. If no cell reaches ``min_pixels``, returns a single
        centroid: the mean of all mask pixels.
    """
    height, width = mask.shape
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        return []

    # Row-major grid-cell id of every mask pixel.
    cell_col = np.clip((cols * grid) // width, 0, grid - 1)
    cell_row = np.clip((rows * grid) // height, 0, grid - 1)
    pixel_cell = (cell_row * grid + cell_col).astype(np.int32)

    cell_ids, cell_counts = np.unique(pixel_cell, return_counts=True)
    dense_cells = cell_ids[cell_counts >= int(min_pixels)]  # ascending ids
    if dense_cells.size == 0:
        # Nothing dense enough: fall back to one centroid over the whole mask.
        return [(float(cols.mean()), float(rows.mean()))]

    # Flood-fill dense cells into 4-connected components. Seeds are visited
    # in ascending cell id so the component numbering is deterministic.
    unassigned = {int(c) for c in dense_cells}
    component_of = {}
    n_components = 0
    for seed in (int(c) for c in dense_cells):
        if seed not in unassigned:
            continue
        unassigned.discard(seed)
        component_of[seed] = n_components
        stack = [seed]
        while stack:
            r, c = divmod(stack.pop(), grid)
            for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
                # The bounds test stops flat ids wrapping onto the next row.
                if not (0 <= nr < grid and 0 <= nc < grid):
                    continue
                neighbour = nr * grid + nc
                if neighbour in unassigned:
                    unassigned.discard(neighbour)
                    component_of[neighbour] = n_components
                    stack.append(neighbour)
        n_components += 1

    # Map every pixel to its component; pixels in sparse cells are dropped.
    cell_label = np.array([component_of[int(c)] for c in dense_cells])
    slot = np.searchsorted(dense_cells, pixel_cell)
    slot[slot == dense_cells.size] = 0  # out of range; rejected just below
    in_dense = dense_cells[slot] == pixel_cell
    labels = cell_label[slot[in_dense]]

    # Per-component pixel counts and coordinate sums give the means.
    counts = np.bincount(labels, minlength=n_components)
    sum_u = np.bincount(labels, weights=cols[in_dense], minlength=n_components)
    sum_v = np.bincount(labels, weights=rows[in_dense], minlength=n_components)
    return [
        (float(sum_u[k] / counts[k]), float(sum_v[k] / counts[k]))
        for k in range(n_components)
    ]


def detect_colors_centroids(rgb, color_classes, grid=48, min_pixels=80):
    """Detect blob centroids for every configured colour class.

    Args:
        rgb: Image array with three channels on its last axis.
        color_classes: Mapping of class name to a spec dict with a required
            ``"ref_rgb"`` entry and an optional ``"tol"`` (default 60).
        grid: Grid resolution forwarded to ``centroids_from_mask_grid``.
        min_pixels: Per-cell pixel threshold forwarded likewise.

    Returns:
        Dict mapping each class name to its list of ``(u, v)`` centroids.
    """
    return {
        label: centroids_from_mask_grid(
            mask_from_ref_rgb(rgb, cfg["ref_rgb"], int(cfg.get("tol", 60))),
            grid=grid,
            min_pixels=min_pixels,
        )
        for label, cfg in color_classes.items()
    }


def estimate_table_z_from_known_geom(model, data):
    """Estimate the table height from the first known box geom in the model.

    Tries the geom names ``box_geom``, ``box_geom2`` and ``box_geom3`` in
    that order. For the first one found, runs forward kinematics and returns
    the geom's world Z position minus the third entry of its size vector.

    NOTE(review): this equals the underside of the box only if the geom is a
    box whose local Z axis is aligned with world Z — confirm against the XML.

    Args:
        model: MuJoCo model.
        data: MuJoCo data; updated in place by ``mj_forward``.

    Returns:
        The estimated height as a float, or ``None`` if none of the
        candidate geoms exists.
    """
    for geom_name in ("box_geom", "box_geom2", "box_geom3"):
        geom_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_GEOM, geom_name)
        if geom_id < 0:
            continue
        # World poses are only valid after forward kinematics has run.
        mj.mj_forward(model, data)
        center_z = float(data.geom_xpos[geom_id, 2])
        half_z = float(model.geom_size[geom_id, 2])
        return center_z - half_z
    return None


def render_and_capture(model, data, cam_name, window_size):
    """Render the scene from a named fixed camera and read back colour and depth.

    Opens a visible GLFW window, renders one frame from ``cam_name``, reads
    the RGBA and depth buffers with ``glReadPixels`` and converts the depth
    buffer to metric depth along the optical axis.

    Side effects: overwrites ``model.vis.map.znear`` / ``zfar`` and runs
    ``mj_forward`` on ``data``.

    Args:
        model: MuJoCo model.
        data: MuJoCo data.
        cam_name: Name of the camera to render from.
        window_size: Requested ``(width, height)`` of the window. The
            returned images use the actual framebuffer size, which may differ.

    Returns:
        Tuple ``(rgb, linear_depth, fovy_deg, cam_pos, cam_xmat)``: a
        top-row-first uint8 RGB image, the matching depth map, the camera's
        vertical FOV in degrees, and copies of its world position and
        flattened rotation matrix.

    Raises:
        ValueError: If the camera name is not in the model.
        RuntimeError: If GLFW cannot be initialised or the window cannot be
            created.
    """
    W, H = window_size

    # Resolve the camera first so a bad name fails before any GL resource
    # has been allocated.
    cam_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_CAMERA, cam_name)
    if cam_id < 0:
        raise ValueError(f"Camera '{cam_name}' not found")

    if not glfw.init():
        raise RuntimeError("Failed to initialize GLFW")

    window = None
    try:
        glfw.window_hint(glfw.VISIBLE, glfw.TRUE)
        window = glfw.create_window(W, H, "Overhead Capture", None, None)
        if not window:
            raise RuntimeError("Failed to create GLFW window")
        glfw.make_context_current(window)
        glfw.swap_interval(1)

        model.vis.map.znear = 0.055
        model.vis.map.zfar = 5.0

        cam = mj.MjvCamera()
        opt = mj.MjvOption()
        mj.mjv_defaultCamera(cam)
        mj.mjv_defaultOption(opt)
        scene = mj.MjvScene(model, maxgeom=10000)
        context = mj.MjrContext(model, mj.mjtFontScale.mjFONTSCALE_150.value)

        cam.type = mj.mjtCamera.mjCAMERA_FIXED
        cam.fixedcamid = cam_id
        mj.mj_forward(model, data)

        # Use the real framebuffer size: it can differ from the requested
        # window size (for example on high-DPI displays).
        fb_w, fb_h = glfw.get_framebuffer_size(window)
        viewport = mj.MjrRect(0, 0, fb_w, fb_h)
        mj.mjv_updateScene(model, data, opt, None, cam, mj.mjtCatBit.mjCAT_ALL.value, scene)
        mj.mjr_render(viewport, scene, context)

        # Tightly packed rows so widths not divisible by 4 read back correctly.
        gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)

        # RGB capture. OpenGL rows start at the bottom, so flip vertically.
        rgba_bytes = gl.glReadPixels(0, 0, fb_w, fb_h, gl.GL_RGBA, gl.GL_UNSIGNED_BYTE)
        rgba = np.frombuffer(rgba_bytes, dtype=np.uint8).reshape(fb_h, fb_w, 4)
        rgb = np.flip(rgba[:, :, :3], axis=0).copy()

        # Depth capture (non-linear depth-buffer values).
        depth_bytes = gl.glReadPixels(0, 0, fb_w, fb_h, gl.GL_DEPTH_COMPONENT, gl.GL_FLOAT)
        depth = np.frombuffer(depth_bytes, dtype=np.float32).reshape(fb_h, fb_w)
        depth = np.flip(depth, axis=0)

        # MuJoCo expresses vis.map.znear/zfar in units of the model extent;
        # the clip-plane distances are those values scaled by stat.extent.
        extent = float(model.stat.extent)
        znear = float(model.vis.map.znear) * extent
        zfar = float(model.vis.map.zfar) * extent
        # Standard inversion of the OpenGL perspective depth mapping.
        # NOTE(review): MuJoCo builds that render with a reversed depth
        # buffer need a different inversion — confirm for the installed version.
        linear_depth = 2.0 * znear * zfar / (zfar + znear - (2.0 * depth - 1.0) * (zfar - znear))

        fovy_deg = float(model.cam_fovy[cam_id])
        cam_pos = data.cam_xpos[cam_id].copy()
        cam_xmat = data.cam_xmat[cam_id].copy()
    finally:
        # Always release the window and GLFW, including on errors.
        if window:
            glfw.destroy_window(window)
        glfw.terminate()

    return rgb, linear_depth, fovy_deg, cam_pos, cam_xmat


def draw_annotations(rgb, detections, color_classes, radius=8):
    """Draw a labelled circle at every detected centroid.

    Args:
        rgb: Image array accepted by ``Image.fromarray``; not modified.
        detections: Mapping of class name to a list of ``(u, v)`` centroids.
        color_classes: Mapping of class name to a spec with ``"ref_rgb"``,
            used as the drawing colour for that class.
        radius: Circle radius in pixels.

    Returns:
        A new array holding the annotated image.
    """
    canvas = Image.fromarray(rgb)
    pen = ImageDraw.Draw(canvas)
    for label, points in detections.items():
        colour = tuple(int(channel) for channel in color_classes[label]["ref_rgb"])
        for idx, (u, v) in enumerate(points):
            bounds = (u - radius, v - radius, u + radius, v + radius)
            pen.ellipse(bounds, outline=colour, width=3)
            # Label sits just up and to the right of the circle.
            pen.text((u + radius + 2, v - radius - 2), f"{label}[{idx}]", fill=colour)
    return np.array(canvas)


def main():
    """Run the overhead capture, colour detection and 3-D localisation pipeline.

    Steps: load the model from ``XML_PATH``, report a table height, render
    the ``CAM_NAME`` view, save the RGB image and depth map, detect colour
    blobs, back-project each centroid with its depth into world coordinates,
    print the results, compare the teal detections with body ``obj_box_06``
    and save an annotated image.

    Files written to the working directory: ``overhead_rgb.png``,
    ``overhead_depth.npy`` and ``overhead_rgb_annotated.png``.
    """
    print("Loading model...")
    model = mj.MjModel.from_xml_path(XML_PATH)
    data = mj.MjData(model)

    z_table = estimate_table_z_from_known_geom(model, data)
    if z_table is None:
        z_table = Z_TABLE_FALLBACK
        print(f"Using fallback table height z={z_table:.3f} m")
    else:
        print(f"Estimated table height from geom: z={z_table:.3f} m")

    print("Rendering overhead image and depth map...")
    rgb, depth_map, fovy_deg, cam_pos, cam_xmat = render_and_capture(model, data, CAM_NAME, WINDOW_SIZE)
    Image.fromarray(rgb).save("overhead_rgb.png")
    np.save("overhead_depth.npy", depth_map)
    print("Saved: overhead_rgb.png and overhead_depth.npy")

    print("Detecting colors...")
    detections = detect_colors_centroids(rgb, COLOR_CLASSES, grid=GRID, min_pixels=MIN_PIXELS)

    H, W, _ = rgb.shape
    results_by_color = {}

    for cname, cents in detections.items():
        pts_world = []
        for (u, v) in cents:
            # Sample the depth at the pixel nearest to the sub-pixel centroid.
            u_i, v_i = int(round(u)), int(round(v))
            in_bounds = 0 <= v_i < depth_map.shape[0] and 0 <= u_i < depth_map.shape[1]
            depth_m = float(depth_map[v_i, u_i]) if in_bounds else float("nan")
            if not (np.isfinite(depth_m) and depth_m > 0):
                # Report the drop: skipped centroids shift the printed indices
                # relative to the labels in the annotated image.
                print(f"  Skipping {cname} centroid at ({u:.1f}, {v:.1f}): no valid depth")
                continue
            P = pixel_to_world_from_depth(
                u, v, depth_m,
                W, H, fovy_deg,
                cam_pos=cam_pos, cam_xmat=cam_xmat,
                use_cv=USE_CV_CONVENTION
            )
            pts_world.append(P)
        results_by_color[cname] = pts_world

    for cname, pts in results_by_color.items():
        print(f"\n{cname}: found {len(pts)} object(s)")
        for i, P in enumerate(pts):
            print(f"  {cname}[{i}]  X={P[0]:.4f}, Y={P[1]:.4f}, Z={P[2]:.4f}")

    # Best-effort sanity check against the ground-truth pose of a known body.
    # A failure here is reported but must not abort the run.
    try:
        bid = mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, "obj_box_06")
        teal_pts = results_by_color.get("teal_block", [])
        if bid < 0:
            # Indexing xpos with -1 would silently compare against the last body.
            print("\nSanity check skipped: body 'obj_box_06' not found in model")
        elif teal_pts:
            mj.mj_forward(model, data)
            gt_xyz = data.xpos[bid].copy()[:3]
            dists = [np.linalg.norm(pt - gt_xyz) for pt in teal_pts]
            j = int(np.argmin(dists))
            print(f"\nSanity check vs obj_box_06: nearest teal_block[{j}] "
                  f"3D error = {dists[j]*1000:.1f} mm")
    except Exception as exc:
        print(f"\nSanity check failed: {exc!r}")

    annotated = draw_annotations(rgb, detections, COLOR_CLASSES, radius=8)
    Image.fromarray(annotated).save("overhead_rgb_annotated.png")
    print("Saved: overhead_rgb_annotated.png")

    print("\nDone. 3D world coordinates and annotated image saved.")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
