In [21]:
# %load oxford_pets_train_script.py
import glob
import os.path as path
import pickle
import random
import torch
import torchvision
import torchvision.models.detection as det
import torchvision.transforms as transforms
import tqdm
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


class PetDataset(Dataset):
    """Pet images with trimap annotations, served as Mask R-CNN style samples.

    Images are read from ``<root_dir>/images`` and their trimaps from
    ``<root_dir>/annotations/trimaps`` (same base name, ``.png`` extension).
    Only ``.jpg`` files that PIL reports as JPEG-encoded RGB images are kept.
    The breed is the part of the file name before its last underscore and is
    mapped to an integer class id starting at ``last_mrcnn_idx``.

    Args:
        root_dir: Directory containing ``images`` and ``annotations/trimaps``.
        xforms: Callable applied to the PIL image; expected to return a
            ``(C, H, W)`` tensor.
        yforms: Callable applied to the PIL trimap; expected to return a
            ``(1, H, W)`` tensor with the same spatial size as the image.
        augs: Callable applied to the image and mask stacked along the channel
            axis, so both go through the same augmentation.

    Each item is ``(image, target)`` where ``target`` has the keys ``"boxes"``
    (float32, shape ``(1, 4)``, ``[x1, y1, x2, y2]``), ``"labels"`` (int64,
    shape ``(1,)``) and ``"masks"`` (float, 1.0 inside the pet, 0.5 on the
    trimap boundary, 0.0 elsewhere).  All tensors are moved to the module-level
    ``device``.
    """

    def __init__(self, root_dir, xforms, yforms, augs):
        self.ann_dir = path.join(root_dir, "annotations", "trimaps")
        self.image_dir = path.join(root_dir, "images")

        # NOTE: glob returns files in filesystem order, so the index -> file
        # mapping (and any seeded split built on it) can differ between machines.
        jpg_files = [
            x
            for x in glob.glob(path.join(self.image_dir, "*"))
            if path.splitext(x)[1] == ".jpg"
        ]
        self.image_files = [x for x in jpg_files if self._is_rgb_jpeg(x)]

        # First class id handed to a breed.
        # NOTE(review): presumably chosen so breed ids follow the 91 COCO
        # categories of the pretrained Mask R-CNN head -- confirm.
        self.last_mrcnn_idx = 91

        # Sorted so that the breed -> id mapping is independent of file order.
        breeds = sorted({self._breed_of(x) for x in self.image_files})
        self.breed_assoc = {
            breed: self.last_mrcnn_idx + idx for idx, breed in enumerate(breeds)
        }
        # With no images there are no breed ids; fall back to last_mrcnn_idx
        # instead of crashing on max() of an empty sequence.
        self.num_classes = (
            max(self.breed_assoc.values(), default=self.last_mrcnn_idx - 1) + 1
        )
        self.xforms = xforms
        self.yforms = yforms
        self.augs = augs

    @staticmethod
    def _is_rgb_jpeg(image_file):
        """Return True when PIL identifies ``image_file`` as a JPEG in RGB mode."""
        # The context manager closes the file handle PIL keeps open lazily.
        with Image.open(image_file) as candidate:
            return candidate.format == "JPEG" and candidate.mode == "RGB"

    @staticmethod
    def _breed_of(image_file):
        """Return the breed encoded in a file name such as ``breed_name_12.jpg``."""
        stem = path.splitext(path.basename(image_file))[0]
        # Everything before the last underscore; "" when there is no underscore.
        return stem.rpartition("_")[0]

    @staticmethod
    def _mask_to_box(mask):
        """Return the ``(1, 4)`` float32 box ``[x1, y1, x2, y2]`` around ``mask``.

        ``mask`` is a ``(..., H, W)`` tensor; every non-zero pixel counts as
        foreground.  An all-zero mask yields a box covering the whole image.
        """
        height, width = mask.shape[-2:]
        occupied = (mask.reshape(-1, height, width) != 0).any(dim=0)
        # nonzero() indexes as (row, col): rows are y coordinates, cols are x.
        rows, cols = torch.nonzero(occupied, as_tuple=True)
        if rows.numel() == 0:
            corners = [0.0, 0.0, float(width), float(height)]
        else:
            corners = [
                cols.min().item(),
                rows.min().item(),
                cols.max().item(),
                rows.max().item(),
            ]
        return torch.tensor([corners], dtype=torch.float32)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        imf = self.image_files[idx]
        bname = path.splitext(path.basename(imf))[0]
        ann = path.join(self.ann_dir, bname + ".png")

        # PIL images have no .to(); transform to tensors first, then move them.
        with Image.open(imf) as pil_image:
            image = self.xforms(pil_image).to(device)
        with Image.open(ann) as pil_trimap:
            mask = self.yforms(pil_trimap).to(device)

        # Augment image and mask together so geometric changes stay aligned,
        # then split the channels apart again.
        image_chans = image.shape[0]
        xformed_chans = self.augs(torch.cat([image, mask], dim=0))
        image = xformed_chans[:image_chans]
        mask = xformed_chans[image_chans:]

        # Recover the integer trimap codes.  Assumes yforms scaled 8-bit codes
        # to code/255 (as ToTensor does); multiplying by 300 instead of 255
        # leaves head-room so floor() still lands on 1, 2 or 3.
        unnormed_mask = (mask * 300).floor()
        unnormed_boundary = (unnormed_mask == 3.0).to(torch.float)
        unnormed_interior = (unnormed_mask == 1.0).to(torch.float)
        mask = 0.5 * unnormed_boundary + unnormed_interior

        labels = torch.tensor(
            [self.breed_assoc[self._breed_of(imf)]], dtype=torch.int64, device=device
        )
        boxes = self._mask_to_box(mask).to(device)

        return image, {"boxes": boxes, "labels": labels, "masks": mask}


# Image pipeline: fixed 224x224 size, tensor conversion, then per-channel
# normalisation with the usual ImageNet mean / std.
transformx = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Trimap pipeline: the pixels are categorical codes, so resize with
# nearest-neighbour.  The default bilinear filter blends neighbouring codes
# (e.g. 1 and 3 average to 2) and would invent pixels of the wrong class.
transformy = transforms.Compose(
    [
        transforms.Resize(
            (224, 224), interpolation=transforms.InterpolationMode.NEAREST
        ),
        transforms.ToTensor(),
    ]
)

# Joint image+mask augmentation; currently the identity.
# NOTE(review): transforms.AutoAugment() was tried here.  Per the torchvision
# docs it expects uint8 images with 1 or 3 channels, so it cannot be applied
# as-is to a stacked, normalised image+mask tensor.
aug = transforms.Compose([])
    
# Build the dataset rooted at the current directory, then carve it into an
# 80/20 train/test partition drawn with a generator seeded at 42.
ds = PetDataset(root_dir=".", xforms=transformx, yforms=transformy, augs=aug)
train_len = int(0.8 * len(ds))
test_len = len(ds) - train_len

train_dataset, test_dataset = random_split(
    dataset=ds,
    lengths=[train_len, test_len],
    generator=torch.Generator().manual_seed(42),
)

def unnormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo a per-channel ``(x - mean) / std`` normalisation.

    Args:
        tensor: Channel-first tensor, e.g. ``(C, H, W)``.
        mean: Per-channel means that were subtracted.
        std: Per-channel standard deviations that were divided by.

    Returns:
        A new tensor with ``x * std + mean`` applied channel by channel.  The
        input is left untouched, so calling this repeatedly on the same tensor
        (e.g. when a notebook cell is re-run) always gives the same result.
        Channels beyond ``len(mean)`` / ``len(std)`` are copied unchanged.
    """
    restored = tensor.clone()
    # Iterating over a tensor yields views, so the in-place ops below write
    # into ``restored`` (the copy), never into the caller's tensor.
    for channel, m, s in zip(restored, mean, std):
        channel.mul_(s).add_(m)
    return restored

def custom_collate(batch):
    """Collate ``(image, target)`` samples into ``(stacked_images, targets)``.

    Images are stacked along a new leading batch axis; the per-sample targets
    are returned as a plain list in the same order, without being merged.
    """
    images = []
    annotations = []
    for sample in batch:
        images.append(sample[0])
        annotations.append(sample[1])

    stacked = torch.stack(images, dim=0)
    return stacked, annotations


# Number of samples per batch for both loaders.
bs = 1

# Training batches are drawn in shuffled order; the test loader keeps the
# dataset order.  Both use custom_collate to assemble a batch.
dl = DataLoader(
    dataset=train_dataset,
    batch_size=bs,
    shuffle=True,
    collate_fn=custom_collate,
)
tl = DataLoader(
    dataset=test_dataset,
    batch_size=bs,
    collate_fn=custom_collate,
)

# Converts tensors back into PIL images for display.
to_pil = transforms.ToPILImage()


In [None]:
# NOTE(review): "checkpoint_8" is unpickled as a whole model object, and
# unpickling can execute arbitrary code -- only load checkpoints you trust.
# Recent PyTorch releases may also require weights_only=False for this.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model = torch.load("checkpoint_8", map_location=device).to(device)
model.eval()

tdata, tlabel = test_dataset[105]
with torch.no_grad():
    preds = model([tdata])


def draw_box(ax, box, color):
    """Outline a box given as [x1, y1, x2, y2] (torchvision's layout) on ``ax``."""
    x1, y1, x2, y2 = (float(v) for v in box)
    # Rectangle takes (x, y), width, height -- width is the x extent.
    ax.add_patch(
        plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor=color, facecolor="none"
        )
    )


# Un-normalise a copy so tdata itself is still valid model input afterwards.
tim = to_pil(unnormalize(tdata.clone()))

plt.imshow(tim)
lbox = tlabel['boxes'][0]
draw_box(plt.gca(), lbox, 'g')
# The model may return no detections at all; only index when there is one.
if len(preds[0]['boxes']) > 0:
    pbox = preds[0]['boxes'][0]
    draw_box(plt.gca(), pbox, 'r')
plt.title("Label box (green) vs. first predicted box (red)")
plt.show()

if len(preds[0]["masks"]) > 0:
    first_mask = preds[0]["masks"][0]
    plt.imshow(to_pil((first_mask > 0.5).to(torch.uint8)))
    plt.title("First predicted mask (thresholded at 0.5)")
    plt.show()

# Show the whole label mask rather than a partial crop of its rows.
lmask = tlabel['masks'][0]
plt.imshow(to_pil(lmask))
plt.title("Label mask")
plt.show()


In [None]:
# Recompute the bounding box of the label mask by hand, as a sanity check.
lmask = tlabel['masks']
plane = lmask.squeeze()

# torch.nonzero returns (row, col) pairs: column 0 is the y coordinate and
# column 1 is the x coordinate.
indices = torch.nonzero(plane)
if indices.numel() == 0:
    # Empty mask: fall back to a box covering the whole mask.
    left_x, bottom_y = 0, 0
    right_x, top_y = plane.shape[-1], plane.shape[-2]
else:
    bottom_y = indices[:, 0].min().item()
    top_y = indices[:, 0].max().item()
    left_x = indices[:, 1].min().item()
    right_x = indices[:, 1].max().item()

# Box layout is [x_min, y_min, x_max, y_max].
boxes = torch.tensor([left_x, bottom_y, right_x, top_y]).unsqueeze(0).to(device)
boxes

In [None]:
# Display the dimensions of the label mask tensor.
lmask.size()