# Imports

In [None]:
from pathlib import Path
from typing import Any, Sequence

from IPython.display import display # type: ignore
import ipywidgets as widgets # type: ignore

import torch
import supervision as sv
from PIL import Image

import tt

# Setup

In [None]:
# Pick the compute device once for the rest of the notebook.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# GPU cache/memory diagnostics only mean something when CUDA is actually in use.
if device == "cuda":
    torch.cuda.empty_cache()
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

In [None]:
# Root of the local datasets and the folder of mob images used below.
data_path = Path.home() / "src/data"
mobs_path = data_path / "mobs1/640"

# Every entry of the folder, in a deterministic (sorted) order.
IMAGE_FILES = sorted(mobs_path.iterdir())

# Mob classes to detect. Sorted so the list order -- and therefore any index
# into it -- does not depend on how the literal below happens to be written.
CLASSES = sorted([
    'chicken',
    'cow',
    'creeper',
    'enderman',
    'pig',
    'player',
    'sheep',
    'skeleton',
    'spider',
    'villager',
    'zombie',
])
# The same classes phrased as "minecraft <mob>" for use as detector text prompts.
CLASSES_MINECRAFT = [f"minecraft {x}" for x in CLASSES]


In [None]:
# Interactive widget for paging through the images in mobs_path before running
# any detector (ImageDirViewer comes from the `tt` helper module imported at the top).
viewer = tt.ImageDirViewer(mobs_path)
viewer.show_widget()

# OWLv2 (Hugging Face transformers)

In [None]:
from transformers import Owlv2Processor, Owlv2ForObjectDetection

class Owl:
    """Open-vocabulary object detector wrapping Hugging Face's OWLv2.

    Every string in ``classes`` is sent as a text query, and predictions are
    filtered by ``threshold`` during post-processing.
    """

    # model_id = "google/owlv2-base-patch16-ensemble"
    model_id = "google/owlv2-large-patch14-ensemble"

    def __init__(self, classes: list[str], threshold: float = 0.1):
        """Load the OWLv2 processor and model and move the model to ``device``.

        Args:
            classes: Text queries; each detection's label is an index into this list.
            threshold: Score threshold passed to ``post_process_object_detection``.
        """
        self.processor = Owlv2Processor.from_pretrained(self.model_id)
        self.model = Owlv2ForObjectDetection.from_pretrained(self.model_id).to(device)
        self.classes = classes
        self.threshold = threshold
        # Built once and reused by every infer() call instead of once per image.
        self.box_annotator = sv.BoxAnnotator()
        self.label_annotator = sv.LabelAnnotator()

    def detect(self, image: Image.Image) -> dict[Any, Any]:
        """Run OWLv2 on one image, print each detection, and return the result.

        Returns:
            The post-processed result for the single image: ``boxes`` in
            Pascal VOC format (xmin, ymin, xmax, ymax) rescaled to the image
            size, ``scores``, and ``labels`` (indices into ``self.classes``).
        """
        # The processor takes one list of text queries per image; we submit one image.
        inputs = self.processor(text=[self.classes], images=image, return_tensors="pt")
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Target size per image as (height, width), used to rescale the boxes.
        target_sizes = torch.Tensor([image.size[::-1]])
        results = self.processor.post_process_object_detection(
            outputs=outputs, target_sizes=target_sizes, threshold=self.threshold
        )
        result = results[0]  # only one image was submitted

        for box, score, label in zip(result["boxes"], result["scores"], result["labels"]):
            coords = [round(c, 2) for c in box.tolist()]
            print(f"Detected {self.classes[label]} with confidence {round(score.item(), 3)} at location {coords}")

        return result

    def infer(self, image_file: str | Path) -> Image.Image:
        """Detect objects in ``image_file`` and return an annotated copy of the image."""
        # Context manager releases the file handle; convert() returns an independent copy.
        with Image.open(Path(image_file)) as raw:
            image = raw.convert("RGB")

        detections = sv.Detections.from_transformers(self.detect(image))

        labels = [
            f"{self.classes[class_id]}: {conf:.2f}"
            for class_id, conf in zip(detections.class_id, detections.confidence)
        ]

        annotated_image = self.box_annotator.annotate(scene=image.copy(), detections=detections)
        return self.label_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)

In [None]:
# Query OWLv2 with a reduced vocabulary: every class listed in owl_remove is dropped.
owl_remove = ["creeper", "enderman", "player", "skeleton", "villager", "zombie"]
owl_classes = list(filter(lambda name: name not in owl_remove, CLASSES))

# Prompt each remaining class as "minecraft <mob>".
owl_classes_minecraft = ["minecraft " + name for name in owl_classes]

owl1 = Owl(classes=owl_classes_minecraft, threshold=0.4)
fname = mobs_path / '06092b21-2024-10-20_22.22.09.png'
image = owl1.infer(fname)
display(image)

In [None]:
# Interactive widget over mobs_path; presumably calls owl1.infer on each image
# it shows (InferViewer comes from the `tt` helper module -- confirm there).
viewer = tt.InferViewer(owl1.infer, mobs_path)
viewer.show_widget()

# YOLO-World (Ultralytics)

In [None]:
from ultralytics import YOLOWorld

fname = mobs_path / '06092b21-2024-10-20_22.22.09.png'
image = Image.open(fname).convert("RGB")

# Quick YOLO-World check on one image with a hand-picked vocabulary.
model = YOLOWorld("yolov8x-worldv2.pt")  # or "yolov8l-worldv2.pt"
model.set_classes(["spider", "pig", "minecraft chicken", "cow", "creeper", "zombie", "skeleton"])

# Very low confidence cut-off so that weak detections are still drawn.
results = model.predict(image, conf=0.04)
annotated_bgr = results[0].plot()
# Reverse the channel axis (BGR -> RGB) before handing the array to PIL.
annotated_rgb = annotated_bgr[..., ::-1]
display(Image.fromarray(annotated_rgb))


In [None]:
class Yolo:
    """Open-vocabulary object detector wrapping Ultralytics YOLO-World."""

    model_id = "yolov8x-worldv2.pt"
    # model_id = "yolov8l-worldv2.pt"

    def __init__(self, classes: list[str], conf: float = 0.25):
        """Load the YOLO-World weights and set the text vocabulary.

        Args:
            classes: Class prompts passed to ``set_classes``.
            conf: Confidence threshold forwarded to ``predict``.
        """
        self.model = YOLOWorld(self.model_id)
        self.model.set_classes(classes)
        self.conf = conf

    def detect(self, image: Image.Image) -> list[Any]:
        """Run the model on one image.

        Returns:
            The list returned by ``predict`` (Ultralytics result objects);
            ``infer`` uses its first entry.
        """
        return self.model.predict(image, conf=self.conf)

    def infer(self, image_file: str | Path) -> Image.Image:
        """Detect objects in ``image_file`` and return the rendered annotations as a PIL image."""
        # Context manager releases the file handle; convert() returns an independent copy.
        with Image.open(Path(image_file)) as raw:
            image = raw.convert("RGB")
        results = self.detect(image)

        # plot() output is treated as BGR; reverse the channel axis for PIL.
        annotated_bgr = results[0].plot()
        return Image.fromarray(annotated_bgr[..., ::-1])

In [None]:
# YOLO-World over every class, prompted as "minecraft <mob>", with a very low confidence cut-off.
yolo1 = Yolo(classes=CLASSES_MINECRAFT, conf=0.03)
viewer = tt.InferViewer(yolo1.infer, mobs_path)
viewer.show_widget()

# Grounding DINO (Hugging Face transformers)

In [None]:
from transformers import AutoProcessor, GroundingDinoProcessor, AutoModelForZeroShotObjectDetection # type: ignore

In [None]:
class Dino:
    """Open-vocabulary object detector wrapping Grounding DINO via Hugging Face transformers."""

    model_id = "IDEA-Research/grounding-dino-tiny"
    # model_id = "IDEA-Research/grounding-dino-base"

    def __init__(self, classes: Sequence[str], threshold: float, text_threshold: float):
        """Load the processor and model and move the model to ``device``.

        Args:
            classes: Candidate labels passed to the processor as the text prompt.
            threshold: Forwarded as ``threshold`` to
                ``post_process_grounded_object_detection``.
            text_threshold: Forwarded as ``text_threshold`` to the same call.
        """
        self.processor: GroundingDinoProcessor = AutoProcessor.from_pretrained(self.model_id)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(self.model_id).to(device)

        self.classes = classes
        self.threshold = threshold
        self.text_threshold = text_threshold
        # Built once and reused by every infer() call instead of once per image.
        self.box_annotator = sv.BoxAnnotator()
        self.label_annotator = sv.LabelAnnotator()

    def detect(self, image: Image.Image) -> dict[Any, Any]:
        """Run Grounding DINO on one image, display/print the detections, and return them.

        Returns:
            The post-processed result for the single image; this class reads its
            ``boxes``, ``scores`` and ``text_labels`` entries.
        """
        inputs = self.processor(images=image, text=self.classes, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            threshold=self.threshold,
            text_threshold=self.text_threshold,
            target_sizes=[image.size[::-1]],  # (height, width) per image
        )

        result: dict[Any, Any] = results[0]  # only one image was submitted
        display(result)
        for box, score, labels in zip(result["boxes"], result["scores"], result["text_labels"]):
            box = [round(x, 2) for x in box.tolist()]
            print(f"Detected {labels} with confidence {round(score.item(), 3)} at location {box}")

        return result

    def fix_grounding_dino_result(self, result: dict) -> tuple[dict, dict]:
        """Attach integer ``labels`` to a Grounding DINO result so supervision can read it.

        Works around the post-processed ``result["labels"]`` not being integer
        class ids. Builds ids from ``result["text_labels"]`` and returns
        ``(result_with_int_labels, id2label)``; the input dict is not modified.

        Ids are deterministic: the configured ``self.classes`` come first, in
        order, followed by any other phrase the model produced, in first-seen
        order. A given class therefore keeps the same id -- and annotation
        colour -- across images and runs.
        """
        # Not derived from a set(): str hashing is randomized per process, so
        # set-based ids (and box colours) would change from run to run.
        label2id: dict[str, int] = {}
        for label in (*self.classes, *result["text_labels"]):
            label2id.setdefault(label, len(label2id))
        id2label = {idx: label for label, idx in label2id.items()}

        result_fixed = result.copy()  # shallow copy: leave the caller's dict untouched
        result_fixed["labels"] = torch.tensor(
            [label2id[label] for label in result["text_labels"]],
            dtype=torch.long,
            device=result["boxes"].device,
        )
        return result_fixed, id2label

    def infer(self, image_file: str | Path) -> Image.Image:
        """Detect objects in ``image_file`` and return an annotated copy of the image."""
        # Context manager releases the file handle; convert() returns an independent copy.
        with Image.open(Path(image_file)) as raw:
            image = raw.convert("RGB")
        result = self.detect(image)

        result_fixed, id2label = self.fix_grounding_dino_result(result)
        detections = sv.Detections.from_transformers(result_fixed, id2label=id2label)

        labels = [
            f"{label}: {score:.2f}"
            for label, score in zip(result["text_labels"], result["scores"])
        ]

        annotated_image = self.box_annotator.annotate(scene=image.copy(), detections=detections)
        return self.label_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)

In [None]:
fname = mobs_path / '06092b21-2024-10-20_22.22.09.png'
# Grounding DINO over the plain class names (no "minecraft" prefix).
dino1 = Dino(CLASSES, threshold=0.3, text_threshold=0.5)
image = dino1.infer(fname)
display(image)

In [None]:
# Interactive widget over mobs_path; presumably calls dino1.infer on each image
# it shows (InferViewer comes from the `tt` helper module -- confirm there).
viewer = tt.InferViewer(dino1.infer, mobs_path)
viewer.show_widget()