# In [9]:
from PIL import Image
import os

def apply_mask_to_image(image_path, mask_path):
    """
    Make the non-mask region of an image transparent and crop it to the mask.

    Pixels whose mask value is 0 (black) become fully transparent
    ``(0, 0, 0, 0)``; every other pixel keeps its original RGBA value. The
    result is then cropped to the bounding box of the non-zero mask pixels.

    :param image_path: path of the source image (converted to RGBA).
    :param mask_path: path of the mask image (converted to 8-bit grayscale
        "L"); any non-zero value marks a pixel to keep.
    :return: the cropped RGBA image, or ``None`` (after printing a message)
        if the mask contains no non-zero pixel.
    :raises ValueError: if the image and mask sizes differ.
    """
    # Open both files inside ``with`` so the file handles are released;
    # ``convert`` loads the pixels and returns an independent copy, so the
    # converted images stay valid after the handles are closed.
    with Image.open(image_path) as src:
        image = src.convert("RGBA")
    with Image.open(mask_path) as src:
        mask = src.convert("L")  # grayscale

    # The mask must cover the image pixel-for-pixel.
    if image.size != mask.size:
        raise ValueError("图片和mask的尺寸不一致，请确保它们具有相同的宽度和高度。")

    # Binarize the mask: 0 stays 0, every non-zero value becomes 255. With a
    # strict 0/255 mask, Image.composite selects pixels exactly (no partial
    # blending): 255 -> original pixel, 0 -> transparent black. This matches
    # a per-pixel keep/clear rule but runs in C instead of a Python loop.
    keep = mask.point(lambda value: 255 if value else 0)
    transparent = Image.new("RGBA", image.size, (0, 0, 0, 0))
    image = Image.composite(image, transparent, keep)

    # Smallest rectangle containing every non-zero mask pixel.
    bbox = mask.getbbox()
    if bbox is None:
        print("未在mask中找到任何物体。")
        return None

    return image.crop(bbox)

# Sequence name -> new_dex_ycb BOP scene directory. Several sequences share
# one scene directory (e.g. mustard0 and mustard_easy_00_02 both use 000006).
dex_path = dict(
    mustard0="/baai-cwm-1/baai_cwm_ml/public_data/scenes/new_dex_ycb/bop/data/000006",
    mustard_easy_00_02="/baai-cwm-1/baai_cwm_ml/public_data/scenes/new_dex_ycb/bop/data/000006",
    bleach_hard_00_03_chaitanya="/baai-cwm-1/baai_cwm_ml/public_data/scenes/new_dex_ycb/bop/data/000194",
    bleach0="/baai-cwm-1/baai_cwm_ml/public_data/scenes/new_dex_ycb/bop/data/000194",
    cracker_box_yalehand0="/baai-cwm-1/baai_cwm_ml/public_data/scenes/new_dex_ycb/bop/data/000055",
    cracker_box_reorient="/baai-cwm-1/baai_cwm_ml/public_data/scenes/new_dex_ycb/bop/data/000055",
    sugar_box_yalehand0="/baai-cwm-1/baai_cwm_ml/public_data/scenes/new_dex_ycb/bop/data/000080",
    sugar_box1="/baai-cwm-1/baai_cwm_ml/public_data/scenes/new_dex_ycb/bop/data/000080",
    tomato_soup_can_yalehand0="/baai-cwm-1/baai_cwm_ml/public_data/scenes/new_dex_ycb/bop/data/000131",
)


# For every sequence in dex_path:
#   * read the first RGB frame  <scene_dir>/rgb/000000.jpg
#   * read its mask             <init_dir>/<scene relative to data_root>/mask/000000.png
#   * write <init_dir>/crop/<key>.png  (masked + cropped object, RGBA)
#           <init_dir>/rgb/<key>.png   (full frame, RGBA)
#           <init_dir>/mask/<key>.png  (mask, grayscale)
data_root = "/baai-cwm-1/baai_cwm_ml/public_data/scenes/new_dex_ycb/bop/data"
init_dir = "/baai-cwm-1/baai_cwm_ml/cwm/yuhao.duan/gz/FoundationPose/init_dir"

# Output directories do not depend on the sequence: create them once.
ref_output_dir = os.path.join(init_dir, "crop")
image_output_dir = os.path.join(init_dir, "rgb")
mask_output_dir = os.path.join(init_dir, "mask")
for out_dir in (ref_output_dir, image_output_dir, mask_output_dir):
    os.makedirs(out_dir, exist_ok=True)

for key, scene_dir in dex_path.items():
    raw_img = os.path.join(scene_dir, "rgb/000000.jpg")
    # Mirror the scene's location under data_root into init_dir. Building the
    # path explicitly avoids a chained str.replace('rgb', 'mask'), which would
    # rewrite every "rgb" substring anywhere in the path.
    scene_rel = os.path.relpath(scene_dir, data_root)
    mask_file = os.path.join(init_dir, scene_rel, "mask", "000000.png")

    # Mask the object and crop it to the mask's bounding box.
    cropped_image = apply_mask_to_image(raw_img, mask_file)
    if cropped_image is None:
        # Empty mask (apply_mask_to_image already printed a message): skip
        # this sequence instead of crashing on None.save() and aborting the
        # remaining ones.
        continue

    # Re-open the inputs in the formats that are saved; `with` closes the
    # file handles once convert() has produced an in-memory copy.
    with Image.open(raw_img) as src:
        image = src.convert("RGBA")
    with Image.open(mask_file) as src:
        mask = src.convert("L")  # grayscale

    out_name = f"{key}.png"
    cropped_image.save(os.path.join(ref_output_dir, out_name))  # cropped object
    image.save(os.path.join(image_output_dir, out_name))  # full original frame
    mask.save(os.path.join(mask_output_dir, out_name))  # mask