In [1]:
import os
import torch
import numpy as np
from PIL import Image

In [22]:
# 定义一个Dataset
class PennFudanDataset(torch.utils.data.Dataset):
    def __init__(self, root, transformers):
        self.root = root
        self.transformer = transformers
        self.images = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))

    def __getitem__(self, index):
        # load images and masks
        img_path = os.path.join(self.root, "PNGImages", self.images[index])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[index])
        # open image and mask
        img = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path)
        mask = np.array(mask)
        mask_index = np.unique(mask)
        mask_index = mask_index[1:]
        # 分割图像标记，有mask的地方为true，并且按照标记数量将图像分层
        masks = mask == mask_index[:, None, None]
        # number of mask
        mask_number = len(mask_index)
        boxs = []
        for i in range(mask_number):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxs.append([xmin,ymin, xmax, ymax])
        # convert everything into tensor
        boxs = torch.as_tensor(boxs, dtype=torch.float32)
        labels = torch.as_tensor((mask_number,), dtype=torch.int64)   
        masks = torch.as_tensor(masks, dtype=torch.uint8)
        image_id = torch.tensor([index])
        area = (boxs[:, 3] - boxs[:, 1]) * (boxs[:, 2] - boxs[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((mask_number,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxs
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)

In [23]:
# test dataset
pennFudanDataset = PennFudanDataset("C:/Users/CHENCHEN/Project/ServerBackup/data/PennFudanPed", "transformer")
pennFudanDataset[2]

(1, 445, 479)


In [32]:
a = np.arange(0,100,10)
b = np.where(a < 50)
c = torch.ones(20)
print(c)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.])
