In [None]:
# !pip install segmentation_models_pytorch


Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cublas_cu12-12.4.5.8-

In [None]:
from pycocotools.coco import COCO
import numpy as np
import random
import os
from PIL import Image
import torch
from torchvision import transforms
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt

In [None]:
import random
from pycocotools.coco import COCO
from collections import defaultdict

from google.colab import drive
# Make the Drive-hosted dataset reachable from the Colab runtime.
drive.mount('/content/drive')

# Dataset locations: the COCO-format annotation file and the image folder.
ann_path = '/content/drive/MyDrive/images_segm/train-300/labels.json'
train_img_dir = '/content/drive/MyDrive/images_segm/train-300/data/'
coco = COCO(ann_path)

# Classes of interest and the category IDs pycocotools resolves them to.
target_classes = ['cake', 'cat', 'dog', 'person']
cat_ids = coco.getCatIds(catNms=target_classes)

# Every image that carries at least one annotation of a target class.
matching_ids = set()
for cid in cat_ids:
    matching_ids.update(coco.getImgIds(catIds=[cid]))
image_ids = list(matching_ids)

# Seed before shuffling so the order (and any split built from it) repeats.
random.seed(42)
random.shuffle(image_ids)

# Target category ID -> set of image IDs that contain that category.
class_to_images = defaultdict(set)
for img_id in image_ids:
    anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
    for cid in {ann['category_id'] for ann in anns}:
        if cid in cat_ids:
            class_to_images[cid].add(img_id)

# Per-class image counts, to spot under-represented classes early.
for cid in cat_ids:
    class_name = coco.loadCats([cid])[0]['name']
    print(f"{class_name}: {len(class_to_images[cid])} images")

# Containers for the train / val / test image IDs (70/15/15 split below),
# plus the set of IDs that have already been assigned to some split.
train_ids = []
val_ids = []
test_ids = []
used_ids = set()

def take_balanced_subset(target_list, size, used_ids):
    """Append up to ``size`` unused images to ``target_list``, covering every target class it can.

    Candidates are the entries of the module-level ``image_ids`` (in their
    current order) that are not in ``used_ids`` and have at least one
    annotation whose category is in the module-level ``cat_ids``.

    Selection runs in two passes:

    1. Seeding: for each target class, rarest among the candidates first,
       the first candidate containing that class is taken, unless an image
       already picked in this call covers it.
    2. Filling: the remaining slots are filled with candidates in order.

    The subset therefore contains every target class that still has an
    unused image when the call starts, as long as ``size`` leaves room for
    the seeds. Nothing is held back for later calls, so an earlier split can
    still use up all images of a rare class.

    Args:
        target_list: list extended in place with the selected image IDs.
        size: maximum number of images to select; values <= 0 select nothing.
        used_ids: set of image IDs already assigned to a split; the newly
            selected IDs are added to it.

    Returns:
        None. ``target_list`` and ``used_ids`` are mutated in place.
    """
    # Guard explicitly: checking the limit only after an append would let a
    # size of 0 still take one image.
    if size <= 0:
        return

    wanted = set(cat_ids)

    # Unused images paired with the target classes they contain, kept in
    # image_ids order, plus how many candidates offer each class.
    candidates = []
    available = dict.fromkeys(wanted, 0)
    for img_id in image_ids:
        if img_id in used_ids:
            continue
        anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        labels = {ann["category_id"] for ann in anns} & wanted
        if labels:
            candidates.append((img_id, labels))
            for cid in labels:
                available[cid] += 1

    selected = []
    chosen = set()
    covered = set()

    # Pass 1: one image per class, rarest class first so that scarce classes
    # get a slot even when size is small.
    for cid in sorted(wanted, key=lambda c: (available[c], c)):
        if len(selected) >= size:
            break
        if cid in covered:
            continue
        for img_id, labels in candidates:
            if cid in labels and img_id not in chosen:
                selected.append(img_id)
                chosen.add(img_id)
                covered |= labels
                break

    # Pass 2: fill the remaining slots in candidate order.
    for img_id, _ in candidates:
        if len(selected) >= size:
            break
        if img_id not in chosen:
            selected.append(img_id)
            chosen.add(img_id)

    used_ids.update(chosen)
    target_list.extend(selected)

# Split sizes: 70% train, 15% val, and whatever is left for test.
# Sizing the test split as the remainder (rather than a second int(N * 0.15))
# keeps the images that the truncating int() calls would otherwise leave out
# of every split.
N = len(image_ids)
take_balanced_subset(train_ids, int(N * 0.7), used_ids)
take_balanced_subset(val_ids, int(N * 0.15), used_ids)
take_balanced_subset(test_ids, N - len(used_ids), used_ids)

# Every image collected above contains a target class, so all of them must
# have been assigned to exactly one split.
assert len(train_ids) + len(val_ids) + len(test_ids) == N, "split does not cover all images"

print(f"Train: {len(train_ids)}, Val: {len(val_ids)}, Test: {len(test_ids)}")

# Optional: report which target classes are present in a given split
def check_class_coverage(split_ids, name):
    """Print the sorted names of the target classes found in a split.

    Args:
        split_ids: iterable of image IDs making up the split.
        name: label for the split, used as the prefix of the printed line.

    Reads the module-level ``coco`` and ``cat_ids``. Returns None.
    """
    found_ids = set()
    for image_id in split_ids:
        for annotation in coco.loadAnns(coco.getAnnIds(imgIds=image_id)):
            category = annotation['category_id']
            # Only categories listed in cat_ids count towards coverage.
            if category in cat_ids:
                found_ids.add(category)

    found_names = sorted(coco.loadCats([category])[0]['name'] for category in found_ids)
    print(f"{name} contains classes: {found_names}")

# Print the class coverage of each split, in train / val / test order.
for split_name, split_ids in (("Train", train_ids), ("Val", val_ids), ("Test", test_ids)):
    check_class_coverage(split_ids, split_name)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
loading annotations into memory...
Done (t=0.06s)
creating index...
index created!
cake: 1 images
cat: 4 images
dog: 10 images
person: 272 images
Train: 193, Val: 41, Test: 41
Train contains classes: ['cake', 'cat', 'dog', 'person']
Val contains classes: ['dog', 'person']
Test contains classes: ['cat', 'dog', 'person']


In [None]:
def load_image_and_mask(image_id):
    """Load one image and build its per-pixel class-index mask.

    Reads the module-level ``coco``, ``train_img_dir``, ``cat_ids`` and
    ``cat_id_to_index``.
    NOTE(review): ``cat_id_to_index`` is not defined in the cells visible
    here; it must exist before this function is called. Presumably it maps
    each target category ID to a non-zero index, with 0 left for
    background (confirm).

    Args:
        image_id: COCO image ID to load.

    Returns:
        A ``(image, mask)`` tuple. ``image`` is the RGB picture passed
        through torchvision's ``ToTensor``. ``mask`` is a ``torch.long``
        tensor built from a ``(height, width)`` array sized from the COCO
        image record; it is 0 where no target annotation applies, and where
        annotations overlap the larger class index wins.
    """
    meta = coco.loadImgs([image_id])[0]

    # Decode the file as RGB and keep the pixels as a NumPy array.
    img_path = os.path.join(train_img_dir, meta['file_name'])
    rgb = np.array(Image.open(img_path).convert('RGB'))

    # Start from an all-zero label map sized from the annotation metadata.
    label_map = np.zeros((meta["height"], meta["width"]), dtype=np.uint8)
    for annotation in coco.loadAnns(coco.getAnnIds(imgIds=image_id, catIds=cat_ids)):
        label = cat_id_to_index[annotation['category_id']]
        binary = coco.annToMask(annotation)
        # Element-wise maximum merges this annotation into the label map.
        label_map = np.maximum(label_map, binary * label)

    return transforms.ToTensor()(rgb), torch.tensor(label_map, dtype=torch.long)
