In [1]:
from pathlib import Path
import sys
from torchvision import transforms

# Make the repository root and the label_anything package directories importable.
# NOTE(review): assumes the notebook runs from a direct subdirectory of the repo root.
_repo_root = Path.cwd().parent
for _subdir in ((), ("label_anything",), ("label_anything", "data")):
    sys.path.append(str(_repo_root.joinpath(*_subdir)))

import os
from label_anything.data.test import LabelAnythingTestDataset
from label_anything.data.utils import BatchKeys
from torchvision.transforms import ToTensor
from PIL import Image
import json
import torch
from pycocotools import mask as mask_utils

  from .autonotebook import tqdm as notebook_tqdm


In [152]:
class BrainMriTestDataset(LabelAnythingTestDataset):
    """Test dataset for the LGG brain-MRI segmentation data, read from a COCO-style JSON.

    Each item is ``({BatchKeys.IMAGES, BatchKeys.DIMS}, {"mask", "bbox"})``.
    """

    num_classes = 2

    def __init__(self, annotations, img_dir, transform=None):
        """
        Args:
            annotations: path to a JSON file with an ``"images"`` list (entries
                carry a ``"url"`` relative to ``img_dir``) and an
                ``"annotations"`` list (entries carry ``"segmentation"``, as
                consumed by ``pycocotools.mask.decode``, and ``"bbox"``).
            img_dir: directory the image ``"url"`` entries are relative to.
            transform: optional callable applied to each loaded PIL image
                (e.g. ``transforms.ToTensor()``). When set, masks and bboxes
                are also converted to tensors.
        """
        super().__init__()
        with open(annotations, "r") as f:
            self.annotations = json.load(f)
        self.img_dir = img_dir  # e.g. data/raw/lgg-mri-segmentation/kaggle_3m/
        self.transform = transform

    def __len__(self):
        return len(self.annotations["images"])

    def _get_image(self, image_info):
        """Load one image and return ``(image, (height, width))``.

        The size describes the returned image, i.e. it is measured after
        ``transform``.
        """
        image_path = os.path.join(self.img_dir, image_info["url"])
        # Copy the pixels out and close the file now instead of leaving the
        # handle open until the lazily-loaded PIL object is garbage-collected.
        with Image.open(image_path) as pil_img:
            img = pil_img.copy()
        if self.transform:
            img = self.transform(img)  # 3 x h x w with ToTensor
        if isinstance(img, Image.Image):
            # No (tensor-producing) transform: PIL images have no ``.shape``
            # and report their size as (width, height).
            width, height = img.size
            return img, (height, width)
        return img, (img.shape[1], img.shape[2])

    def _get_gt(self, annotation_info):
        """Decode the ground-truth mask and bounding box of one annotation.

        Without a transform, returns the decoded array and the raw bbox. With a
        transform, the mask becomes a float tensor ``(1, H, W)`` (channels
        first if the decoded array is 3-D) holding the decoded label values,
        and the bbox a tensor of shape ``(1, 4)``.
        """
        mask = mask_utils.decode(annotation_info["segmentation"])
        # NOTE(review): the one-element tuple gives bbox a leading dim of 1
        # (shape (1, 4) once tensorised, (B, 1, 4) after collation); kept
        # as-is because downstream code may rely on that shape.
        bbox = (annotation_info["bbox"],)  # [x, y, w, h]
        if self.transform:
            # Do not feed the label mask to the *image* transform: ToTensor
            # divides uint8 arrays by 255 (label 1 -> ~0.0039), and photometric
            # ops such as Normalize must never touch labels.
            mask = torch.as_tensor(mask, dtype=torch.float32)
            mask = mask.unsqueeze(0) if mask.ndim == 2 else mask.permute(2, 0, 1)
            bbox = torch.tensor(bbox)
        return {"mask": mask, "bbox": bbox}

    def __getitem__(self, idx):
        """Return ``({IMAGES, DIMS}, gt)`` for the ``idx``-th image.

        NOTE(review): the annotation is picked by list position, which assumes
        ``annotations["annotations"][i]`` belongs to ``annotations["images"][i]``
        (one annotation per image, same order). Matching on ``image_id`` would
        be safer if that does not hold — TODO confirm against the JSON.
        """
        image_info = self.annotations["images"][idx]
        annotation_info = self.annotations["annotations"][idx]
        image, size = self._get_image(image_info)
        gt = self._get_gt(annotation_info)
        return {
            BatchKeys.IMAGES: image,
            BatchKeys.DIMS: size,
        }, gt

In [None]:
from torch.nn.functional import one_hot


def extract_prompts(self, num_classes: int = 3):
    """Build the prompt batch: images, one-hot prompt masks and class-presence flags.

    Written to be bound as a dataset method. It reads ``self.prompt_images``,
    ``self.train_channels_folder`` / ``self.train_gt_folder`` and calls
    ``self._get_image(folder, filename)``, ``self._get_gt(folder, filename)``
    and ``self._transform(image)``.
    NOTE(review): those ``(folder, filename)`` helpers do not match
    ``BrainMriTestDataset._get_image(image_info)`` / ``_get_gt(annotation_info)``.
    This looks ported from a crop/weed dataset and still needs adapting.

    Args:
        num_classes: number of label values in the ground-truth masks,
            background (label 0) included. The default 3 keeps the original
            background / crop / weed layout.

    Returns:
        dict with
        ``BatchKeys.IMAGES``: transformed images stacked along dim 0;
        ``BatchKeys.PROMPT_MASKS``: float one-hot masks ``(N, num_classes, H, W)``;
        ``BatchKeys.FLAG_MASKS``: bool ``(N, num_classes - 1)``; column ``c - 1``
        is True when label ``c`` occurs in that prompt's mask;
        ``BatchKeys.DIMS``: ``x.shape[1:]`` of each untransformed image
        (``(H, W)`` if ``_get_image`` returns ``(C, H, W)`` — assumed).

    Raises:
        ValueError: if ``num_classes < 2`` (no foreground class to flag).
    """
    if num_classes < 2:
        raise ValueError(f"num_classes must be >= 2, got {num_classes}")

    raw_images = [
        self._get_image(self.train_channels_folder, filename)
        for filename in self.prompt_images
    ]
    # Sizes are read before ``_transform`` so they describe the raw images.
    sizes = torch.stack([torch.tensor(x.shape[1:]) for x in raw_images])
    images = torch.stack([self._transform(image) for image in raw_images])

    # Integer label masks, one class index per pixel: (N, H, W).
    label_masks = torch.stack(
        [
            self._get_gt(self.train_gt_folder, filename)
            for filename in self.prompt_images
        ]
    ).long()
    # Flags must be computed on the label masks, before one-hot encoding:
    # afterwards every entry is 0/1, so ``== 2`` is never true and ``== 1``
    # also matches the background channel.
    flag_masks = torch.stack(
        [(label_masks == c).flatten(1).any(dim=1) for c in range(1, num_classes)],
        dim=1,
    )
    # (N, H, W) -> (N, H, W, C) -> (N, C, H, W)
    masks = one_hot(label_masks, num_classes).permute(0, 3, 1, 2).float()

    return {
        BatchKeys.IMAGES: images,
        BatchKeys.PROMPT_MASKS: masks,
        BatchKeys.FLAG_MASKS: flag_masks,
        BatchKeys.DIMS: sizes,
    }

In [4]:
# Data locations. Point LABEL_ANYTHING_DATA at the data root instead of
# editing a machine-specific absolute path; the default is the original one.
DATA_ROOT = os.environ.get("LABEL_ANYTHING_DATA", "/home/emanuele/LabelAnything/data")
annotations = os.path.join(DATA_ROOT, "annotations", "brain_mri.json")
# Trailing "" keeps the trailing separator of the original path.
img_dir = os.path.join(DATA_ROOT, "raw", "lgg-mri-segmentation", "kaggle_3m", "")

In [154]:
# Fixed seed so the shuffled batch inspected below is the same on every run.
SEED = 0
transform = transforms.Compose([transforms.ToTensor()])
dataset = BrainMriTestDataset(annotations, img_dir, transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

In [155]:
data_dict, gt = next(iter(dataloader))

# For inputs, then targets: print "key: shape" for tensors, the raw value otherwise.
for batch_part in (data_dict, gt):
    print(
        [
            f"{name}: {value.size() if isinstance(value, torch.Tensor) else value}"
            for name, value in batch_part.items()
        ]
    )

['images: torch.Size([4, 3, 256, 256])', 'dims: [tensor([256, 256, 256, 256]), tensor([256, 256, 256, 256])]']
['mask: torch.Size([4, 1, 256, 256])', 'bbox: torch.Size([4, 1, 4])']


In [16]:
# Load the annotation file here as well: the cell that loads it sits further
# down the notebook, so Restart & Run All would otherwise raise a NameError.
with open(annotations, "r") as f:
    annotations_file = json.load(f)
# Inspect the first image listed in the annotation file.
img = Image.open(os.path.join(img_dir, annotations_file["images"][0]["url"]))

In [17]:
# Segmentation and bbox of the first annotation entry.
first_annotation = annotations_file["annotations"][0]
seg, bbox = first_annotation["segmentation"], first_annotation["bbox"]

In [5]:
# Parse the COCO-style annotation JSON into a dict.
annotations_file = json.loads(Path(annotations).read_text())

In [19]:
from pycocotools import mask as mask_utils

# Decode the first annotation's segmentation into a mask array.
i = mask_utils.decode(seg)
# as_tensor returns an existing tensor unchanged, so re-running this cell is
# safe (torch.tensor on a tensor copy-constructs and emits a UserWarning).
bbox = torch.as_tensor(bbox)

In [10]:
def image_to_category(annotations):
    """Map each ``image_id`` to the ``category_id`` of its annotation.

    Args:
        annotations: COCO-style dict whose ``"annotations"`` list holds
            entries with ``"image_id"`` and ``"category_id"``.

    Returns:
        dict ``{image_id: category_id}``. When an image has several
        annotations, the last one in the list wins.
    """
    return {
        entry["image_id"]: entry["category_id"]
        for entry in annotations["annotations"]
    }

In [11]:
a = image_to_category(annotations=annotations_file)  # image_id -> category_id

In [None]:
# Preview instead of dumping the whole mapping (one entry per annotated image).
len(a), dict(list(a.items())[:5])