In [5]:
%%writefile real_env.py

import os
import dataclasses
from dataclasses import dataclass
from typing import List, Dict
import numpy as np
import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

# Workspace crop applied to the camera frames in RealEnv.get_obs, as (start, stop) pixel bounds.
# NOTE(review): WS_CROP_X slices the FIRST (row) axis and WS_CROP_Y the SECOND (column) axis of
# the image arrays, although x usually denotes columns in image coordinates — confirm intended.
WS_CROP_X = (0, 300)
WS_CROP_Y = (0, 300)

@dataclass
class EnvState:
    """Snapshot of the environment as built by ``RealEnv.get_obs``."""

    color_im: np.ndarray  # color frame from the camera (get_obs stores the uncropped frame)
    depth_im: np.ndarray  # depth frame from the camera (get_obs stores the uncropped frame)
    objects: Dict[str, np.ndarray]  # label -> 2D box (xmin, ymin, xmax, ymax) in crop pixel coordinates

class RealEnv():
    """Real-world environment wrapper: captures camera frames and detects objects with OWL-ViT.

    Args:
        bin_cam: camera object exposing ``get_camera_data()`` (returns a color image
            and a depth image) and ``intr`` with ``ppx``, ``ppy``, ``fx``, ``fy``
            attributes (used to back-project box centres in ``get_obs``).
        task: task name; also the default output directory name.
        all_objects: text queries passed to OWL-ViT.
        task_objects: labels that ``plot_preds`` annotates below their box.
        output_name: optional output directory name under ``real_world/outputs/``.
    """

    def __init__(self, bin_cam, task: str, all_objects: List[str], task_objects: List[str], output_name: str = None):
        self.bin_cam = bin_cam
        self.task = task
        self.all_objects = all_objects
        self.task_objects = task_objects
        # Default the output directory to the task name when no explicit name is given.
        if output_name is None:
            self.output_name = f"real_world/outputs/{self.task}/"
        else:
            self.output_name = f"real_world/outputs/{output_name}/"
        os.makedirs(self.output_name, exist_ok=True)

        self.robot_name = 'Bob'
        self.human_name = 'Alice'

        # Load OWL-ViT model and its matching text/image processor.
        self.model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
        self.processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

        self.timestep = 0

    @staticmethod
    def _select_best_boxes(scores, labels, boxes, names: List[str], threshold: float = 0.5) -> Dict[str, np.ndarray]:
        """Return ``{name: box}`` with each label's highest-scoring box whose score >= ``threshold``.

        Each detection ``i`` is judged by its own score ``scores[i]`` (not by the
        score stored at index ``labels[i]``). Labels with no detection at or above
        the threshold are omitted. Keys appear in order of each label's first
        qualifying detection.
        """
        best_boxes = {}
        best_scores = {}
        for score, label, box in zip(scores, labels, boxes):
            if score < threshold:
                continue
            name = names[int(label)]
            if name not in best_scores or score > best_scores[name]:
                best_scores[name] = score
                best_boxes[name] = box
        return best_boxes

    @staticmethod
    def _deproject(u, v, depth, intr):
        """Back-project pixel ``(u, v)`` at ``depth`` through pinhole intrinsics ``intr``.

        Returns ``(x, y, z)`` with ``z == depth``; units follow the depth image.
        """
        point_x = (u - intr.ppx) * depth / intr.fx
        point_y = (v - intr.ppy) * depth / intr.fy
        return (point_x, point_y, depth)

    def get_obs(self, save=False, score_threshold: float = 0.5) -> "EnvState":
        """Capture a frame, detect ``self.all_objects`` in the workspace crop and return an ``EnvState``.

        For every kept box, the 3D point under the box centre is computed from the
        depth image and printed (boxes whose centre lies outside the crop or has
        zero depth are skipped for this step).

        Args:
            save: if True, save the cropped color image as ``img_<timestep>.png``.
            score_threshold: minimum per-detection OWL-ViT score for a box to be kept.

        Returns:
            EnvState with the uncropped color/depth frames and ``{label: box}``
            where box is ``(xmin, ymin, xmax, ymax)`` in crop pixel coordinates.
        """
        # Get color/depth frames from the camera (kinect.py).
        color_im, depth_im = self.bin_cam.get_camera_data()
        # WS_CROP_X slices the row axis, WS_CROP_Y the column axis.
        ws_color_im = color_im[WS_CROP_X[0]:WS_CROP_X[1], WS_CROP_Y[0]:WS_CROP_Y[1]]
        ws_depth_im = depth_im[WS_CROP_X[0]:WS_CROP_X[1], WS_CROP_Y[0]:WS_CROP_Y[1]]

        image = Image.fromarray(ws_color_im)
        text = self.all_objects

        # Run object detection; inference only, so skip autograd bookkeeping.
        inputs = self.processor(text=[text], images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Rescale predicted boxes to the crop's (height, width); PIL size is (width, height).
        target_sizes = torch.Tensor([image.size[::-1]])
        results = self.processor.post_process(outputs=outputs, target_sizes=target_sizes)
        pred_scores = results[0]["scores"].detach().numpy()
        pred_labels = results[0]["labels"].detach().numpy()
        pred_boxes = results[0]["boxes"].detach().numpy()

        objects = self._select_best_boxes(pred_scores, pred_labels, pred_boxes, text, score_threshold)

        # 3D coordinates of each box centre.
        crop_h, crop_w = ws_depth_im.shape[:2]
        object_coords = []
        for label, box in objects.items():
            xmin, ymin, xmax, ymax = map(int, box)
            center_x = (xmin + xmax) // 2
            center_y = (ymin + ymax) // 2
            # Predicted boxes may extend past the crop; an out-of-range centre would
            # raise IndexError, or silently wrap to the opposite edge if negative.
            if not (0 <= center_x < crop_w and 0 <= center_y < crop_h):
                continue

            depth = ws_depth_im[center_y, center_x]
            if depth == 0:  # presumably a missing/invalid depth reading
                continue

            # The intrinsics presumably describe the uncropped frame, so shift the
            # crop-relative pixel by the crop origin (currently (0, 0), i.e. a no-op).
            coords = self._deproject(center_x + WS_CROP_Y[0], center_y + WS_CROP_X[0], depth, self.bin_cam.intr)
            object_coords.append({"label": label, "coords": coords})

        self.timestep += 1
        if save:
            image.save(f"{self.output_name}/img_{self.timestep}.png")

        # Print the 3D coordinates.
        for obj in object_coords:
            print(f"Label: {obj['label']}, 3D Coordinates: {obj['coords']}")

        obs = EnvState(
            color_im=color_im,
            depth_im=depth_im,
            objects=objects,
        )
        return obs

    def plot_preds(self, color_im, objects, save=False, show=True):
        """Draw every box in ``objects`` on ``color_im`` with its label.

        Labels in ``self.task_objects`` are drawn below-left of the box, others
        above it. When ``save`` is True the figure is written to
        ``pred_<timestep>.png`` in the output directory.
        """
        fig, ax = plt.subplots(figsize=(12, 12 * color_im.shape[0] / color_im.shape[1]))
        ax.imshow(color_im)
        colors = sns.color_palette('muted', len(objects))
        for label, c in zip(objects, colors):
            (xmin, ymin, xmax, ymax) = map(int, objects[label])
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
            if label in self.task_objects:
                ax.text(xmin-30, ymax+15, label, fontsize=22, bbox=dict(facecolor='white', alpha=0.8))
            else:
                ax.text(xmin, ymin-10, label, fontsize=22, bbox=dict(facecolor='white', alpha=0.8))
        ax.axis('off')
        fig.tight_layout()

        # Write the file before plt.show(), which may block until the window is closed.
        if save:
            fig.savefig(f"{self.output_name}/pred_{self.timestep}.png")
        if show:
            plt.show()


Overwriting real_env.py


In [2]:
# 다현언니 코드 기반 잘돌아감!


%%writefile real_env.py

#real_env.py 코드 파일입니다.
import os
import dataclasses
from dataclasses import dataclass
from typing import List, Dict
import numpy as np
import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

# Workspace crop applied to the camera frames in RealEnv.get_obs, as (start, stop) pixel bounds.
# NOTE(review): WS_CROP_X slices the FIRST (row) axis and WS_CROP_Y the SECOND (column) axis of
# the image arrays, although x usually denotes columns in image coordinates — confirm intended.
WS_CROP_X = (0, 300)
WS_CROP_Y = (0, 300)

# 환경 상태를 저장할 데이터 클래스 정의
@dataclass
class EnvState:
    """Environment state snapshot as built by ``RealEnv.get_obs``."""

    color_im: np.ndarray  # color frame from the camera (get_obs stores the uncropped frame)
    depth_im: np.ndarray  # depth frame from the camera (get_obs stores the uncropped frame)
    objects: Dict[str, np.ndarray] # 2D boxes: label -> (xmin, ymin, xmax, ymax), the rectangle enclosing the object.
    # e.g. {"apple": np.array([50, 30, 200, 150]), "banana": np.array([120, 60, 250, 180])}

class RealEnv():
    """Real-world environment wrapper: captures camera frames and detects objects with OWL-ViT.

    Args:
        bin_cam: camera object exposing ``get_camera_data()``, which returns a
            color image and a depth image.
        task: task name; also the default output directory name.
        all_objects: text queries passed to OWL-ViT.
        task_objects: labels that ``plot_preds`` annotates below their box.
        output_name: optional output directory name under ``real_world/outputs/``.
    """

    def __init__(self, bin_cam, task: str, all_objects: List[str], task_objects: List[str], output_name: str = None):
        self.bin_cam = bin_cam
        self.task = task
        self.all_objects = all_objects
        self.task_objects = task_objects
        # output_name defaults to None; if not given, the directory is named after self.task.
        if output_name is None:
            self.output_name = f"real_world/outputs/{self.task}/"
        else:
            self.output_name = f"real_world/outputs/{output_name}/"
        os.makedirs(self.output_name, exist_ok=True)  # create the output directory

        self.robot_name = 'Bob'  # robot name
        self.human_name = 'Alice'  # human name

        # Load the OWL-ViT model and the processor that prepares image + text inputs.
        self.model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
        self.processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

        self.timestep = 0

    @staticmethod
    def _max_score_boxes(scores, labels, boxes, names: List[str], min_score: float = None) -> Dict[str, np.ndarray]:
        """For every label present in ``labels``, return its highest-scoring box as ``{name: box}``.

        Keys are ordered by ascending label index (``np.unique`` order); ties go
        to the first occurrence (``np.argmax``). If ``min_score`` is given, labels
        whose best score is below it are dropped; otherwise no score filtering
        is applied.
        """
        best = {}
        for label in np.unique(labels):
            mask = labels == label  # build the per-label mask once
            label_scores = scores[mask]
            top = np.argmax(label_scores)
            if min_score is not None and label_scores[top] < min_score:
                continue
            best[names[label]] = boxes[mask][top]
        return best

    def get_obs(self, save=False, min_score: float = None) -> "EnvState":
        """Capture a frame, detect ``self.all_objects`` in the workspace crop and return an ``EnvState``.

        Args:
            save: if True, save the cropped color image as ``img_<timestep>.png``.
            min_score: optional minimum score for a label's best box. With the
                default (None) every label predicted for at least one box gets an
                entry, regardless of confidence.

        Returns:
            EnvState with the uncropped color/depth frames and ``{label: box}``
            where box is ``(xmin, ymin, xmax, ymax)`` in crop pixel coordinates.
        """
        # Get color/depth frames from the camera (kinect.py).
        color_im, depth_im = self.bin_cam.get_camera_data()
        # WS_CROP_X slices the row axis, WS_CROP_Y the column axis.
        ws_color_im = color_im[WS_CROP_X[0]:WS_CROP_X[1], WS_CROP_Y[0]:WS_CROP_Y[1]]

        image = Image.fromarray(ws_color_im)
        text = self.all_objects

        # Run object detection; inference only, so skip autograd bookkeeping.
        inputs = self.processor(text=[text], images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Rescale predicted boxes to the crop's (height, width); PIL size is (width, height).
        target_sizes = torch.Tensor([image.size[::-1]])
        results = self.processor.post_process(outputs=outputs, target_sizes=target_sizes)
        pred_scores = results[0]["scores"].detach().numpy()
        pred_labels = results[0]["labels"].detach().numpy()
        pred_boxes = results[0]["boxes"].detach().numpy()

        # Keep the max-probability box for each label (selection logic originally from Dahyeon).
        objects = self._max_score_boxes(pred_scores, pred_labels, pred_boxes, text, min_score)

        self.timestep += 1
        if save:
            image.save(f"{self.output_name}/img_{self.timestep}.png")
        obs = EnvState(
            color_im=color_im,
            depth_im=depth_im,
            objects=objects,
        )
        return obs

    def plot_preds(self, color_im, objects, save=False, show=True):
        """Draw every box in ``objects`` on ``color_im`` with its label.

        Labels in ``self.task_objects`` are drawn below-left of the box, others
        above it. When ``save`` is True the figure is written to
        ``pred_<timestep>.png`` in the output directory.
        """
        fig, ax = plt.subplots(figsize=(12, 12 * color_im.shape[0] / color_im.shape[1]))
        ax.imshow(color_im)
        colors = sns.color_palette('muted', len(objects))
        for label, c in zip(objects, colors):
            (xmin, ymin, xmax, ymax) = objects[label]
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
            if label in self.task_objects:
                ax.text(xmin-30, ymax+15, label, fontsize=22, bbox=dict(facecolor='white', alpha=0.8))
            else:
                ax.text(xmin, ymin-10, label, fontsize=22, bbox=dict(facecolor='white', alpha=0.8))
        ax.axis('off')
        fig.tight_layout()

        # Write the file before plt.show(), which may block until the window is closed.
        if save:
            fig.savefig(f"{self.output_name}/pred_{self.timestep}.png")
        if show:
            plt.show()
           

Overwriting real_env.py
