In [1]:
import os.path as path
import glob
from PIL import Image
import matplotlib.pyplot as plt
import tqdm
import torch
import torchvision
import torchvision.models.detection as det
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [10]:
class PetDataset(Dataset):
    def __init__(self, root_dir, xforms, yforms):
        self.ann_dir = path.join(root_dir, "annotations", "trimaps")
        self.image_dir = path.join(root_dir, "images")
        self.image_files = glob.glob(path.join(self.image_dir, "*"))
        self.last_mrcnn_idx = 0
        self.breed_assoc = {x: self.last_mrcnn_idx + idx for 
                            idx, x in enumerate(sorted(list(set(
                                [ path.basename('_'.join(fname.split("_")[:-1])) for 
                                 fname in self.image_files]))))}
        self.xforms = xforms
        self.yforms = yforms
        
    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        imf = self.image_files[idx]
        bname = path.basename(path.splitext(imf)[0])
        ann = path.join(self.ann_dir, bname) + '.png'

        # image
        image = self.xforms(Image.open(imf)).to(device)

        # mask
        mask = self.yforms(Image.open(ann)).to(device)

        # label
        category = path.basename('_'.join(imf.split("_")[:-1]))
        labels = torch.tensor([self.breed_assoc[category]]).to(torch.int64).to(device)
        
        # box
        get_edge_pixels = lambda x: ((x* 300 ).floor() - 1) == 2.0
        edge_pixels = get_edge_pixels(mask.squeeze())
        indices = torch.nonzero(edge_pixels)
        
        if indices.numel() == 0:
            left_x = 0
            bottom_y = 0
            right_x = 224
            top_y = 224

        else:
            left_x = indices[:,0].min()
            right_x = indices[:,0].max()
            top_y = indices[:,1].max()
            bottom_y = indices[:,1].min()
            
        boxes = torch.tensor([left_x,bottom_y,right_x,top_y]).unsqueeze(0).to(device)
        
        return image, {"boxes": boxes, "labels": labels, "masks": mask}

In [11]:
transformx = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
transformy = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
ds = PetDataset(".", transformx, transformy)
to_pil = transforms.ToPILImage()

In [12]:
sample_im, label = ds[0]
get_edge_pixels = lambda x: ((x* 300 ).floor() - 1) == 2.0
mask = label["masks"].squeeze()
edge_pixels = get_edge_pixels(mask)
indices = torch.nonzero(edge_pixels)
if indices.numel() == 0:
    left_x = 0
    bottom_y = 0
    right_x = 224
    top_y = 224
else:
    left_x = indices[:,0].min()
    right_x = indices[:,0].max()
    top_y = indices[:,1].max()
    bottom_y = indices[:,1].min()
indices[:,0]

tensor([  0,   0,   0,  ..., 223, 223, 223], device='cuda:0')

In [13]:
ds.image_files[220]

'./images/Egyptian_Mau_165.jpg'

The input to the model is expected to be a list of tensors, each of shape [C, H, W],
one for each image, and should be in 0-1 range. Different images can have different sizes.

The behavior of the model changes depending on if it is in training or evaluation mode.

During training, the model expects both the input tensors and targets (list of dictionary), containn= H.

labels (Int64Tensor[N]): the class label for each ground-tuth box

masks (UInt8Tensor[N, H, W]): the segmentation binary masks for each instance

In [14]:
model = det.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT).to(device)
opt = torch.optim.Adam(model.parameters())
num_epochs = 10
for epoch in range(num_epochs):
    for idx, (x, y) in tqdm.tqdm(enumerate(ds), total=len(ds)):
        preds = model([x],[y])
        losses = sum(loss for loss in preds.values())
        opt.zero_grad()
        losses.backward()
        opt.step()

  4%|█▌                                   | 321/7393 [00:40<14:46,  7.98it/s]


RuntimeError: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]

In [17]:
ds.image_files[321]

'./images/Abyssinian_34.jpg'