In [5]:
import torch
import os
import numpy as np
from pycocotools.coco import COCO
from PIL import Image

**Reference:** [tutorial](https://www.kaggle.com/code/armanasgharpoor1993/coco-dataset-tutorial-image-segmentation#Step-16:-Generating-Image-and-Mask-Datasets)

### Tworzenie generatora do wsadowego generowania wstępnie przetworzonych obrazów i masek.

In [6]:
class CocoDataset(torch.utils.data.Dataset):
    def __init__(self, images_path, masks_path, annotation_dir, image_size=224, batch_size=1):
        """
        CocoDataset class for generating batches of images, masks, bounding boxes, and labels.

        Args:
            images_path (str): Path to the directory containing the original images.
            masks_path (str): Path to the directory containing the masks.
            annotation_dir (object): COCO-like annotations object containing bounding boxes and labels.
            image_size (int): Target size for resizing images and masks (default: 224).
            batch_size (int): Number of samples in each batch (default: 1).
        """
        self.images_path = images_path
        self.masks_path = masks_path
        self.coco = annotation_dir
        self.image_size = image_size
        self.batch_size = batch_size
        self.image_filenames = self.get_matching_filenames()
        self.cat_ids = self.coco.getCatIds()

    def get_matching_filenames(self):
        """
        Get the list of matching filenames between images and masks.

        Returns:
            list: List of matching filenames.
        """
        image_files = set([os.path.splitext(filename)[0] for filename in os.listdir(self.images_path)])
        mask_files = set([os.path.splitext(filename)[0] for filename in os.listdir(self.masks_path)])
        matching_files = list(image_files.intersection(mask_files))
        return matching_files

    def __len__(self):
        """
        Get the number of batches in the dataset.

        Returns:
            int: Total number of batches.
        """
        return int(np.ceil(len(self.image_filenames) / self.batch_size))

    def __getitem__(self, idx):
        """
        Get a batch of preprocessed samples of images, masks, bounding boxes, and labels.

        Args:
            idx (int): Batch index.

        Returns:
            dict: Dictionary containing batch of images, masks, bounding boxes, and labels.
        """
        batch_start = idx * self.batch_size
        batch_end = min((idx + 1) * self.batch_size, len(self.image_filenames))
        batch_filenames = self.image_filenames[batch_start:batch_end]

        images = []
        masks = []
        bounding_boxes = []
        labels = []

        for filename in batch_filenames:
            # Load image and mask
            image_path = os.path.join(self.images_path, filename + '.jpg')
            mask_path = os.path.join(self.masks_path, filename + '.jpg')

            image = Image.open(image_path).convert("RGB")
            mask = Image.open(mask_path)

            # Save original dimensions for scaling bounding boxes
            original_width, original_height = image.size

            # Resize image and mask to target size
            image = image.resize((self.image_size, self.image_size))
            mask = mask.resize((self.image_size, self.image_size))

            # Convert image and mask to numpy arrays
            image = np.array(image) / 255.0  # Normalize image
            mask = np.array(mask)  # Binary or multi-class mask

            # Retrieve annotations for the image
            img_id = self.coco.getImgIds()[self.image_filenames.index(filename)]

            ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=self.cat_ids, iscrowd=None)
            anns = self.coco.loadAnns(ann_ids)
            img_bboxes = []
            img_labels = []

            # Process annotations (scale bounding boxes to new image size)
            for ann in anns:
                bbox = ann['bbox']
                category_id = ann['category_id']

                # Scale bounding box to resized image dimensions
                x, y, width, height = bbox
                x = x * (self.image_size / original_width)
                y = y * (self.image_size / original_height)
                width = width * (self.image_size / original_width)
                height = height * (self.image_size / original_height)

                img_bboxes.append([x, y, x + width, y + height])
                img_labels.append(category_id)

            # Append data for this image to the batch
            images.append(image.transpose(2, 0, 1))  # Convert to channels-first format
            masks.append(mask)
            bounding_boxes.append(img_bboxes)
            labels.append(img_labels)

        # Convert all batch data to tensors
        images = torch.tensor(images, dtype=torch.float32)
        masks = torch.tensor(masks, dtype=torch.float32).unsqueeze(1)  # Add channel dimension for masks
        bounding_boxes = [torch.tensor(bboxes, dtype=torch.float32) for bboxes in bounding_boxes]
        labels = [torch.tensor(lbls, dtype=torch.int64) for lbls in labels]

        return {
            "images": images,  # Shape: (batch_size, 3, 224, 224)
            "masks": masks,  # Shape: (batch_size, 1, 224, 224)
            "bounding_boxes": bounding_boxes,  # List of tensors (1 per image)
            "labels": labels  # List of tensors (1 per image)
        }


In [7]:
images_path = 'coco10/train2017_subset/images'
masks_path = 'coco10/train2017_subset/masks'
annotation_dir = "coco10/train2017_subset/coco10_train_annotations.json"
batch_size = 8

# Initialize the data generator
train_data_generator = CocoDataset(
    images_path=images_path,
    masks_path=masks_path,
    annotation_dir = COCO(annotation_dir),
    image_size=224,
    batch_size=batch_size,
)

# Fetch a batch of data
batch = train_data_generator[0]

# Inspect shapes of images and masks
print(f"Images shape: {batch['images'].shape}")  # (batch_size, 3, 224, 224)
print(f"Masks shape: {batch['masks'].shape}")    # (batch_size, 1, 224, 224)

loading annotations into memory...
Done (t=0.39s)
creating index...
index created!
Images shape: torch.Size([8, 3, 224, 224])
Masks shape: torch.Size([8, 1, 224, 224])


  images = torch.tensor(images, dtype=torch.float32)


In [8]:
def validate_image_shapes(generator):
    """
    Print the shapes of preprocessed images, masks, bounding boxes, and labels from the provided generator.

    Args:
        generator (CustomDataGenerator): Instance of the CustomDataGenerator class.
    """
    for i in range(len(generator)):
        # Get a batch of preprocessed samples from the generator
        batch = generator[i]  # batch is a dictionary
        
        images = batch['images']
        masks = batch['masks']
        bounding_boxes = batch['bounding_boxes']
        labels = batch['labels']

        # Print shapes for each item in the batch
        for j in range(len(images)):
            image = images[j]
            mask = masks[j]
            bbox = bounding_boxes[j]
            lbl = labels[j]

            print(f"Sample {i * generator.batch_size + j}:")
            print(f"  Shape of preprocessed image: {image.shape}")
            print(f"  Shape of preprocessed mask: {mask.shape}")
            print(f"  Number of bounding boxes: {len(bbox)}")
            print(f"  Bounding boxes: {bbox.tolist()}")
            print(f"  Labels: {lbl.tolist()}")
        
        # Break after a few batches for brevity (optional)
        if i >= 1:  # Check only the first two batches
            break

# Call the validation function
validate_image_shapes(train_data_generator)

Sample 0:
  Shape of preprocessed image: torch.Size([3, 224, 224])
  Shape of preprocessed mask: torch.Size([1, 224, 224])
  Number of bounding boxes: 2
  Bounding boxes: [[1.0149999856948853, 0.7886666655540466, 128.1595001220703, 197.81533813476562], [110.9990005493164, 0.0, 223.68850708007812, 199.73333740234375]]
  Labels: [24, 24]
Sample 1:
  Shape of preprocessed image: torch.Size([3, 224, 224])
  Shape of preprocessed mask: torch.Size([1, 224, 224])
  Number of bounding boxes: 1
  Bounding boxes: [[1.8976943492889404, 0.9240000247955322, 288.7558288574219, 147.07000732421875]]
  Labels: [23]
Sample 2:
  Shape of preprocessed image: torch.Size([3, 224, 224])
  Shape of preprocessed mask: torch.Size([1, 224, 224])
  Number of bounding boxes: 6
  Bounding boxes: [[44.432498931884766, 67.0373764038086, 92.8375015258789, 168.77639770507812], [88.99800109863281, 66.0144271850586, 160.3314971923828, 126.09574127197266], [140.6439971923828, 28.98885154724121, 179.89300537109375, 79.4019