In [7]:
# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import warnings
import os, json, cv2, random
from PIL import Image, ImageOps
from os import listdir
from os.path import isfile, join
from multiprocessing import Pool
from tqdm import notebook

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

warnings.filterwarnings("ignore")

In [8]:

mypath = "train"
# Sorted so chunking and processing order are reproducible across runs:
# os.listdir returns directory entries in arbitrary order.
onlyfiles = sorted(f for f in listdir(mypath) if isfile(join(mypath, f)))

In [9]:
class PetfinderPrep(object):
    """Strip image backgrounds with a COCO Mask R-CNN and pad results to a square canvas.

    Images are read from ``path`` and written to ``save_path`` under the same file
    name. Pixels outside the masks of detections whose class id is in
    ``class_labels`` are set to white (255).
    """

    # detectron2 model-zoo entry used for both the config and the pretrained weights.
    MODEL_CONFIG = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
    # Detections scoring below this are dropped by the predictor.
    SCORE_THRESH = 0.5

    def __init__(self, path="train", save_path="prep", class_labels=None, device="cpu"):
        """
        Args:
            path: directory the source images are read from.
            save_path: directory prepped images are written to (created on demand).
            class_labels: predicted class ids whose masks are kept. Defaults to
                [15, 16, 17] -- presumably cat/dog/horse in COCO's contiguous id
                order; TODO confirm against MetadataCatalog ``thing_classes``.
            device: device string assigned to ``cfg.MODEL.DEVICE``.
        """
        self.predictor = None
        self.cfg = None
        self.path = path
        self.save_path = save_path
        self.device = device
        self.class_labels = [15, 16, 17] if class_labels is None else list(class_labels)
        self.load_model()

    def to_PIL(self, img_cv2):
        """Convert an OpenCV (BGR channel order) array into an RGB PIL image."""
        img_rgb = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
        return Image.fromarray(img_rgb)

    def load_model(self):
        """Build the detectron2 config and a ``DefaultPredictor`` with model-zoo weights."""
        cfg = get_cfg()
        # add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
        cfg.merge_from_file(model_zoo.get_config_file(self.MODEL_CONFIG))
        # Assigned after merging so the YAML can never override the requested device.
        cfg.MODEL.DEVICE = getattr(self, "device", "cpu")
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = self.SCORE_THRESH
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(self.MODEL_CONFIG)
        self.predictor = DefaultPredictor(cfg)
        self.cfg = cfg

    def predict(self, img):
        """Run the predictor on an OpenCV-loaded image and return its output dict."""
        return self.predictor(img)

    def is_mask(self, masks, j, i, relevant_masks):
        """Return True if pixel (row ``j``, column ``i``) is set in any relevant mask.

        Indices in ``relevant_masks`` outside ``range(len(masks))`` are ignored.
        """
        return any(bool(masks[x][j][i]) for x in relevant_masks if 0 <= x < len(masks))

    def relevant_masks(self, pred_classes, class_labels=None):
        """Return the indices of predictions whose class is in ``class_labels``.

        ``class_labels`` defaults to ``self.class_labels``.
        """
        if class_labels is None:
            class_labels = self.class_labels
        return [idx for idx, label in enumerate(pred_classes) if label in class_labels]

    def delete_background(self, img, outputs=None):
        """Return a copy of ``img`` whose pixels outside every relevant mask are 255.

        Args:
            img: image array as loaded by OpenCV (height x width x channels).
            outputs: optional predictor output already computed for ``img``; when
                omitted the model is run here.

        If nothing relevant is detected the whole image becomes white. A warning is
        emitted when the model detects no instances at all.
        """
        if outputs is None:
            outputs = self.predict(img)
        instances = outputs["instances"]
        if len(instances.pred_classes) < 1:
            warnings.warn("NO CLASSES FOUND", stacklevel=2)
        relevant = self.relevant_masks(instances.pred_classes, self.class_labels)

        masks = instances.pred_masks
        if hasattr(masks, "cpu"):
            # torch tensors must live on the CPU before NumPy can view them.
            masks = masks.cpu()
        masks = np.asarray(masks, dtype=bool)

        # Vectorized union of the relevant masks (replaces a per-pixel Python loop).
        if relevant:
            foreground = masks[relevant].any(axis=0)
        else:
            foreground = np.zeros(img.shape[:2], dtype=bool)

        result = img.copy()
        result[~foreground] = 255
        return result

    def resize_img(self, img, size=(1280, 1280), fill=(255, 255, 255), outputs=None):
        """Background-strip ``img`` and letterbox it onto a ``size`` canvas (RGB PIL image).

        ``outputs`` may carry a prediction already computed for ``img`` so the model
        is not run a second time.
        """
        img_white = self.delete_background(img, outputs=outputs)
        return self.resize_with_padding(self.to_PIL(img_white), size, fill)

    def resize_with_padding(self, img, expected_size, fill):
        """Shrink (never enlarge) ``img`` to fit ``expected_size``, then pad evenly with ``fill``.

        Works on a copy, so the caller's PIL image is not modified by ``thumbnail``.
        """
        img = img.copy()
        img.thumbnail((expected_size[0], expected_size[1]))
        delta_width = expected_size[0] - img.size[0]
        delta_height = expected_size[1] - img.size[1]
        pad_width = delta_width // 2
        pad_height = delta_height // 2
        # Any odd leftover pixel goes to the right/bottom edge.
        padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
        return ImageOps.expand(img, padding, fill=fill)

    def viz(self, img, outputs):
        """Draw ``outputs`` over ``img`` with detectron2's Visualizer; return an RGB PIL image."""
        v = Visualizer(img[:, :, ::-1], MetadataCatalog.get(self.cfg.DATASETS.TRAIN[0]), scale=1.2)
        out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
        # Flip channels back because to_PIL expects OpenCV (BGR) order.
        return self.to_PIL(out.get_image()[:, :, ::-1])

    def wrapper(self, image):
        """Run the full pipeline on one image file for inspection.

        Returns ``(original, prepped, visualization, outputs)``: three RGB PIL
        images and the raw predictor output. The prediction is computed once and
        reused for both the prepped image and the visualization.
        """
        img = self._read_image(image)
        out = self.predict(img)
        viz = self.viz(img, out)
        return self.to_PIL(img), self.resize_img(img, outputs=out), viz, out

    def prep_img(self, image, save=True):
        """Prep ``image`` (a file name inside ``self.path``); optionally save it to ``self.save_path``."""
        img = self._read_image(os.path.join(self.path, image))
        img_rs = self.resize_img(img)
        if save:
            os.makedirs(self.save_path, exist_ok=True)
            img_rs.save(os.path.join(self.save_path, image))
        return img_rs

    @staticmethod
    def _read_image(file_path):
        """Load ``file_path`` with OpenCV, raising instead of returning ``None`` on failure."""
        img = cv2.imread(file_path)
        if img is None:
            # cv2.imread returns None for missing/unreadable files; fail here with a
            # clear message instead of an obscure error inside the predictor.
            raise FileNotFoundError(f"Could not read image: {file_path}")
        return img

In [10]:
def process_images(images):
    """Prep and save every file name in ``images`` with a freshly built PetfinderPrep.

    Intended as a ``Pool.map`` worker: each call constructs its own model inside
    the worker. A failing image (e.g. unreadable file) is recorded and skipped so it
    does not abort the rest of the chunk.

    Returns:
        list of ``(image, "ExcType: message")`` pairs for images that failed. Plain
        strings are used so the result pickles cleanly back to the parent process.
    """
    pp = PetfinderPrep()
    failures = []
    for image in images:
        try:
            pp.prep_img(image)
        except Exception as exc:  # report and continue; KeyboardInterrupt still propagates
            failures.append((image, f"{type(exc).__name__}: {exc}"))
    return failures
        
def chunks(lst, n):
    """Lazily split the sequence ``lst`` into consecutive slices of ``n`` items.

    The final slice holds whatever remains and may be shorter than ``n``.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

In [11]:
# Files per Pool task. process_images builds a new model for every task, so larger
# chunks amortise that start-up cost; smaller ones balance load across workers.
CHUNK_SIZE = 500
ch = list(chunks(onlyfiles, CHUNK_SIZE))

In [None]:
# Fan the chunks out to a process pool (default size: os.cpu_count()); each task
# constructs its own PetfinderPrep inside process_images.
# NOTE(review): workers can only see functions defined in this notebook under the
# 'fork' start method; under 'spawn' (Windows/macOS default) this cell fails.
with Pool() as pool:
    pool.map(process_images, ch)

The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m
The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m
The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m
The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m
The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m
The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m
The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generat

In [74]:
# Single in-process instance for the interactive inspection cells below;
# construction loads the Mask R-CNN config and weights via load_model().
pp = PetfinderPrep()

The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


In [None]:
# Inspect one sample end to end: write its prepped version to disk, then render
# the background-stripped, padded result inline (Image.show() would open an
# external viewer instead of displaying in the notebook).
SAMPLE_FILE = "6285b27404f2b2bf3df0ecaecc5c8b89.jpg"
pp.prep_img(SAMPLE_FILE)
org, resz, viz, out = pp.wrapper(join(pp.path, SAMPLE_FILE))
display(resz)

In [51]:
# Serial alternative to the Pool run: prep every file with the shared instance.
# Delegates to prep_img so input/output folders come from pp.path / pp.save_path
# instead of being hard-coded again here.
for file in notebook.tqdm(onlyfiles):
    pp.prep_img(file)

  0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [45]:
# Spot-check the canvas size of one prepped file (should match resize_img's `size`).
# Names the file explicitly rather than relying on a loop variable leaked from
# another cell, and closes the handle promptly.
with Image.open(join(pp.save_path, onlyfiles[-1])) as prepped:
    prepped_size = prepped.size
prepped_size

(1500, 1500)

In [47]:
# Survey original image dimensions (presumably to pick resize_img's canvas size).
wl = []
hl = []
for file in notebook.tqdm(onlyfiles):
    # `with` closes each file immediately instead of waiting for garbage collection.
    with Image.open(join(mypath, file)) as im:
        w, h = im.size
    wl.append(w)
    hl.append(h)

  0%|          | 0/9912 [00:00<?, ?it/s]

In [48]:
# Widest original image; compare with resize_img's default 1280-pixel canvas.
max(wl)

1280

In [49]:
# Tallest original image; compare with resize_img's default 1280-pixel canvas.
max(hl)

1280