In [None]:
from dotenv import load_dotenv
from pathlib import Path
import sys


sys.path.append(Path("..").resolve().as_posix())
_ = load_dotenv()

# import supervision as sv
from datasets import Dataset, Image
from collections import defaultdict
import PIL

# from PIL import Image
from tqdm import tqdm
from pycocotools.coco import COCO
import numpy as np
import cv2
import albumentations as A

In [None]:
# Preprocessing applied to every image together with its masks and boxes:
# resize so the shorter side is 386 px, then center-crop to a 386x386 square.
# `p=1.0` is used instead of the deprecated `always_apply=True`; with the
# default Compose probability both transforms still run on every call.
transform = A.Compose(
    [
        A.SmallestMaxSize(max_size=386, p=1.0),
        A.CenterCrop(height=386, width=386, p=1.0),
    ],
    # Boxes are given as [x_min, y_min, x_max, y_max] in pixels and are clipped
    # to the image; the entries of `class_labels` travel with their box.
    bbox_params=A.BboxParams(
        format="pascal_voc", label_fields=["class_labels"], clip=True, min_area=1
    ),
)

# Root folder of the TACO dataset; image file names in the annotation file are
# resolved relative to it.
dataset_path = Path("/data/trash_demo/TACO/data/")
annotations_file = dataset_path / "annotations.json"
coco = COCO(annotations_file.as_posix())

image_ids = coco.getImgIds()
# Category names, in the order returned by getCatIds().
categories = [category["name"] for category in coco.loadCats(coco.getCatIds())]

# Every example shares one prompt that enumerates all category names.
prefix = "segment " + " ; ".join(categories)
# Column name -> list of per-example values, filled by the loop below.
dataset_dict = defaultdict(list)


# Build one dataset row per image: the resized/cropped image, the shared
# prompt, and the boxes, masks and class names of the annotations that are
# still visible after the crop. The three per-annotation lists are kept
# index-aligned.
for image_id in tqdm(image_ids):
    image_path = dataset_path.joinpath(coco.loadImgs(image_id)[0]["file_name"])
    annotations = coco.loadAnns(coco.getAnnIds(image_id))
    if not annotations:
        # Nothing to segment in this image.
        continue

    # COCO stores boxes as [x, y, w, h]; the transform expects [x1, y1, x2, y2].
    xyxy_bboxes = [
        [x, y, x + w, y + h] for x, y, w, h in (ann["bbox"] for ann in annotations)
    ]
    # Look names up by category id instead of using the id as a list position,
    # which is only correct when ids happen to be 0..N-1 in getCatIds() order.
    class_names = [coco.cats[ann["category_id"]]["name"] for ann in annotations]

    image = cv2.imread(image_path.as_posix())
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    masks = [coco.annToMask(ann) for ann in annotations]

    # The transform drops boxes (and their label entry) that end up outside the
    # crop, but it does not drop the matching masks. Passing the annotation
    # index as the label lets us recover which annotations survived.
    transformed = transform(
        image=image,
        masks=masks,
        bboxes=xyxy_bboxes,
        class_labels=list(range(len(annotations))),
    )

    kept_indices = []
    kept_bboxes = []
    for ann_index, bbox in zip(transformed["class_labels"], transformed["bboxes"]):
        # Truncate to integer pixel coordinates, then discard boxes that
        # collapsed to zero width or height.
        x1, y1, x2, y2 = (int(value) for value in bbox[:4])
        if x2 - x1 > 0 and y2 - y1 > 0:
            # int(): labels may come back as floats depending on the
            # albumentations version.
            kept_indices.append(int(ann_index))
            kept_bboxes.append([x1, y1, x2, y2])

    if not kept_indices:
        continue

    image = PIL.Image.fromarray(transformed["image"])
    masks = np.array([transformed["masks"][i] for i in kept_indices], dtype=bool)
    xyxy_bboxes = np.array(kept_bboxes, dtype=int)
    classes = [class_names[i] for i in kept_indices]

    assert len(masks.shape) == 3
    assert (
        len(xyxy_bboxes.shape) == 2 and xyxy_bboxes.shape[1] == 4
    ), f"{xyxy_bboxes.shape}, {len(masks)}, {len(xyxy_bboxes)}"
    assert (
        len(masks) == len(xyxy_bboxes) == len(classes)
    ), f"{len(masks)}, {len(xyxy_bboxes)}, {len(classes)}"

    dataset_dict["image"].append(image)
    dataset_dict["prompt"].append(prefix)
    dataset_dict["xyxy_bboxes"].append(xyxy_bboxes)
    dataset_dict["masks"].append(masks)
    dataset_dict["classes"].append(classes)


# Turn the collected columns into a Dataset and store the "image" column with
# the Image feature type.
dataset = Dataset.from_dict(dataset_dict).cast_column("image", Image())

# Record the class list in the dataset metadata, then persist to ./taco_trash.
class_names_description = f"class_names: {' ; '.join(categories)}"
dataset.info.dataset_name = "taco_trash"
dataset.info.description = class_names_description

dataset.save_to_disk("taco_trash")

In [None]:
from dotenv import load_dotenv
from pathlib import Path
import sys

sys.path.append(Path("..").resolve().as_posix())
_ = load_dotenv()

from datasets import Dataset
import PIL
import numpy as np
from training_toolkit.src.common.tokenization_utils.segmentation import (
    SegmentationTokenizer,
)

In [None]:
# Tokenizer that converts boxes/masks to and from the text suffix format.
segmentation_tokenizer = SegmentationTokenizer()

# Reload the dataset saved to ./taco_trash and have it return torch tensors.
dataset = Dataset.load_from_disk("taco_trash").with_format("torch")

In [None]:
# Sanity check: display the first mask of the first example.
example = dataset[0]
first_mask = example["masks"][0].numpy()
PIL.Image.fromarray(first_mask)

In [None]:
# Encode the first example's boxes, masks and classes into the text suffix.
example = dataset[0]
encode_inputs = [
    example[column] for column in ("image", "xyxy_bboxes", "masks", "classes")
]
suffix = segmentation_tokenizer.encode(*encode_inputs)

suffix

In [None]:
# Round-trip check: decode the suffix again (386 matches the crop size used
# when the dataset was built) and display the first mask, binarized at 0.5.
decoded = segmentation_tokenizer.decode(suffix, 386, 386)

first_decoded_mask = decoded[0]["mask"] > 0.5
PIL.Image.fromarray(first_decoded_mask.astype(np.uint8) * 255)

In [None]:
from dotenv import load_dotenv
from pathlib import Path
import sys

sys.path.append(Path("..").resolve().as_posix())
_ = load_dotenv()

from training_toolkit import paligemma_image_preset, image_segmentation_preset, build_trainer

In [None]:
# Training-argument overrides applied on top of the preset's values.
training_overrides = {
    "per_device_train_batch_size": 12,
    "per_device_eval_batch_size": 12,
    "num_train_epochs": 8,
}
for arg_name, arg_value in training_overrides.items():
    paligemma_image_preset.training_args[arg_name] = arg_value

# Combine the model preset with the segmentation data preset pointed at the
# dataset saved under "taco_trash".
model_kwargs = paligemma_image_preset.as_kwargs()
data_kwargs = image_segmentation_preset.with_path("taco_trash").as_kwargs()
trainer = build_trainer(**model_kwargs, **data_kwargs)

In [None]:
# Run training with the trainer built in the previous cell.
trainer.train()

In [None]:
from dotenv import load_dotenv
from pathlib import Path
import sys


sys.path.append(Path("..").resolve().as_posix())
_ = load_dotenv()

from peft import AutoPeftModelForCausalLM
from transformers import AutoProcessor
import PIL
import numpy as np
import cv2
import supervision as sv

from training_toolkit.src.common.tokenization_utils.segmentation import (
    SegmentationTokenizer,
)

In [None]:
# Checkpoint directory from a previous training run.
CHECKPOINT_PATH = "paligemma_2024-08-01_16-13-08/checkpoint-200"

# Processor and PEFT model are both loaded from the same checkpoint path.
segmentation_tokenizer = SegmentationTokenizer()
processor = AutoProcessor.from_pretrained(CHECKPOINT_PATH)
model = AutoPeftModelForCausalLM.from_pretrained(CHECKPOINT_PATH)

In [None]:
# Load the demo image with Pillow. The bare name `Image` is not imported in
# this section (and `datasets.Image`, imported earlier, has no `open`), so the
# Pillow submodule is imported and referenced explicitly.
import PIL.Image

image = PIL.Image.open("assets/trash1.jpg")

In [None]:
prefix = "segment trash"

PROMPT = prefix

# Request PyTorch tensors explicitly so the prompt length can be read from
# `input_ids` below.
inputs = processor(images=image, text=PROMPT, return_tensors="pt")

# NOTE(review): do_sample=True makes the output differ between runs; consider
# do_sample=False (or seeding) if reproducible masks are needed.
generated_ids = model.generate(**inputs, max_new_tokens=256, do_sample=True)

# The generated sequence starts with the prompt (image tokens plus the text
# prompt). Chop off exactly as many tokens as were fed in, rather than
# re-counting image and text tokens by hand, then decode the rest to a string.
num_prompt_tokens = inputs["input_ids"].shape[1]
generated_text = processor.batch_decode(
    generated_ids[:, num_prompt_tokens:],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)[0]

# Decode the text into segmentation results at the original image size.
w, h = image.size

generated_segmentation = segmentation_tokenizer.decode(generated_text, w, h)

In [None]:
# Display the first predicted mask, binarized at 0.5, as a black/white image.
first_generated_mask = generated_segmentation[0]["mask"] > 0.5
PIL.Image.fromarray(first_generated_mask.astype(np.uint8) * 255)

In [None]:
# Convert the decoded segmentation results into a supervision Detections object.
xyxy = []
mask = []
class_id = []
class_name = []

for r in generated_segmentation:
    xyxy.append(r["xyxy"])
    # Binarize the mask at 0.5.
    _, m = cv2.threshold(r["mask"], 0.5, 1.0, cv2.THRESH_BINARY)
    mask.append(m)
    # No class list is available at inference time, so every detection gets
    # id 0; the readable name is attached separately below.
    class_id.append(0)
    class_name.append(r["name"].strip())

detections = sv.Detections(
    # reshape(-1, 4) keeps the (n, 4) shape even when nothing was detected.
    xyxy=np.array(xyxy).reshape(-1, 4).astype(int),
    # An empty list cannot form an (n, H, W) array, so pass None in that case.
    mask=np.array(mask).astype(bool) if mask else None,
    class_id=np.array(class_id).astype(int),
)

detections["class_name"] = class_name

In [None]:
# Draw boxes, masks and labels on a copy so `image` stays the clean input:
# re-running this cell (or the inference cell) must not see earlier annotations.
annotated_image = image.copy()

annotated_image = sv.BoxAnnotator().annotate(annotated_image, detections)
annotated_image = sv.MaskAnnotator().annotate(annotated_image, detections)

label_annotator = sv.LabelAnnotator(
    text_scale=2,
    text_thickness=4,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.BLACK,
)
annotated_image = label_annotator.annotate(annotated_image, detections)

annotated_image