In [43]:
import torch, torchvision
from torchvision import transforms
import numpy as np
import cv2
import os
import shutil
import json
import re
import glob
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer
from multiprocessing import Process, Manager
import gc
from PIL import Image, ImageDraw, ImageFont


In [44]:
# Setting CUDA devices as visible
cuda_devices = "2"
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_devices

In [45]:
available_devices = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
print("Available CUDA devices:")
for i, device_name in enumerate(available_devices):
    print(f"  {i}: {device_name}")

Available CUDA devices:
  0: Tesla P100-PCIE-16GB


In [46]:
torch.cuda.empty_cache()
gc.collect()

1594

In [47]:
def get_soi(str1, start_char, end_char):
    str1 = str(str1)
    offst = len(start_char)
    ind1 = str1.find(start_char)
    ind2 = str1.find(end_char)
    s_str = str1[ind1+offst:ind2]
    return s_str

In [48]:
def convert_bbox_xywh(b):
    x1, y1, x2, y2 = b
    x = x1
    y = y1
    w = x2 - x1
    h = y2 - y1
    return [x, y, w, h]

In [49]:

def createDataDict (fn, outputs):
    img_shape = list(outputs["instances"].image_size)
    img_h = int(img_shape[0])
    img_w = int(img_shape[1])
    ann_list = []

    class_list = get_soi(outputs["instances"].pred_classes, "[", "]").split(",")
    
    if class_list[0] != "":

        class_list_new = []
        for each in class_list:
            if each.strip().isdigit():
                class_list_new.append(int(each.strip()))
            else:
                print(f"Invalid class ID: {each}")

        bbox_list = get_soi(outputs["instances"].pred_boxes, "[[", "]]").split("]")
        bbox_list_new = []
        for each in bbox_list:
            bbox = re.sub("['[,\n]", "", each).split(" ")
            bbox_new = []
            for item in bbox:
                if item != "":
                    bbox_new.append(float(item))
            bbox_new = convert_bbox_xywh(bbox_new)
            bbox_list_new.append(bbox_new)

        for i in range(0, len(class_list)):
            # og was "bbox_mode": "<BoxMode.XYWH_ABS: 1>"
            ann_list.append({"iscrowd": 0, "bbox": bbox_list_new[i], "category_id": class_list_new[i], "bbox_mode": 0})
    
    data_dict = {
        "file_name": fn,
        "height": img_h,
        "width": img_w, 
        "annotations": ann_list
    }
 
    return data_dict

In [50]:
def crop_image(file_path, bounding_box, padding):
    
    with Image.open(file_path) as img:
        
        x_min, y_min, width, height = bounding_box

        # Calculate padding in pixels
        pad_width = int(width * padding)
        pad_height = int(height * padding)

        # Adjust the bounding box with padding
        x_min = max(x_min - pad_width, 0)
        y_min = max(y_min - pad_height, 0)
        x1 = min(x_min + width + 2 * pad_width, img.width)
        y1 = min(y_min + height + 2 * pad_height, img.height)
        
        cropped_img = img.crop((x_min, y_min, x1, y1))
        
        return cropped_img

In [51]:
def paste_to_bg(image, background_color, bg_width, bg_height):
    
    # Create a new image with the specified background color and dimensions
    background = Image.new('RGB', (bg_width, bg_height), background_color)

    # Calculate the position to paste the image so it's centered
    x = (bg_width - image.width) // 2
    y = (bg_height - image.height) // 2

    # Paste the image onto the background
    background.paste(image, (x, y), image if image.mode == 'RGBA' else None)

    return background

In [52]:
def resize_ar_lock(img, target_size):

    original_width, original_height = img.size
    target_width, target_height = target_size

    # Calculate scaling factor
    scaling_factor = min(target_width / original_width, target_height / original_height)

    # Calculate new dimensions
    new_width = max(int(original_width * scaling_factor), 1)
    new_height = max(int(original_height * scaling_factor), 1)

    # Resize the image
    resized_img = img.resize((new_width, new_height))

    return resized_img

In [53]:
def gen_rand_str(length):
    characters = string.ascii_letters + string.digits
    random_string = ''.join(random.choice(characters) for i in range(length))
    return random_string

In [54]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [55]:
def process_image(img_fp, bbox_list, padding, bg_color, border):
    
    # Create an empty list to store processed images
    processed_images = []

    for j, bbox in enumerate(bbox_list):

        try:
            elem_img = crop_image(img_fp, bbox, padding)
            e_w = elem_img.size[0]
            e_h = elem_img.size[1]

            if e_w < e_h:
                elem_img = paste_to_bg(elem_img, bg_color, e_h + border, e_h + border)
            elif e_w > e_h:
                elem_img = paste_to_bg(elem_img, bg_color, e_w + border, e_w + border)
                
            # elem_img = transform(elem_img)
            processed_images.append(elem_img)

        except Exception as e:
            print(img_fp)
            print(e)

    # Return the list of processed images
    return processed_images


In [56]:
def draw_bounding_boxes(image_path, bbox_list, label_list, output_path):
    # Open the image
    image = Image.open(image_path)
    draw = ImageDraw.Draw(image)

    # Load a font
    font = ImageFont.load_default()

    # Draw bounding boxes and labels
    for bbox, label in zip(bbox_list, label_list):
        x, y, w, h = bbox
        draw.rectangle([x, y, x+w, y+h], outline="red")
        draw.text((x, y), label, fill="red", font=font)

    # Save the new image
    image.save(output_path)

In [57]:
in_dir = "/mnt/nis_lab_research/data/coco_files/test/test1"
res_out_dir = "./res_out_dir"
vis_out_dir = "./res_out_dir/vis"
cat_path = "/mnt/nis_lab_research/data/elem_cat/cat_neg_27.json"
bg_color = "white"
padding = 0.05
border = 0
remove_neg = True
neg_class_name = "Random"

In [58]:
if os.path.exists(res_out_dir):
    shutil.rmtree(res_out_dir)
os.makedirs(vis_out_dir)

In [59]:
setup_logger()
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"))
# cfg.MODEL.WEIGHTS = os.path.join("/home/dtron2_user/ls_dtron2_full/model/output", "model_final.pth")
cfg.MODEL.WEIGHTS = os.path.join("/mnt/nis_lab_research/data/pth", "far_shah_b1-b5_b8_train_EOI.pth")
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.50

In [60]:
classifier_pth = "/mnt/nis_lab_research/data/pth/far_shah_b1-b5_b8_train_neg_ep25.pth"

In [61]:
img_path_list = [os.path.join(in_dir, "images", img_path) for img_path in os.listdir(os.path.join(in_dir, "images"))]

In [62]:
obj_det_pred = DefaultPredictor(cfg)
classifier = torch.load(classifier_pth)
classifier.eval()

[32m[04/11 23:28:05 d2.checkpoint.detection_checkpoint]: [0m[DetectionCheckpointer] Loading from /mnt/nis_lab_research/data/pth/far_shah_b1-b5_b8_train_EOI.pth ...


DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
      

In [63]:
with open(cat_path, 'r') as f:
    cats = json.load(f)

In [64]:
neg_class_id = ""

for key, value in cats.items():
    if value == neg_class_name:
        neg_class_id = key


In [65]:
master_dict = []

for i, img_path in enumerate(img_path_list):
    # Creating master dictionary of detected elements
    img = cv2.imread(img_path)
    outputs = obj_det_pred(img)
    
    print(i, img_path)
    
    data_dict = createDataDict(img_path, outputs)
    bbox_list = [ann["bbox"] for ann in data_dict["annotations"]]
    elem_img_list = process_image(data_dict["file_name"], bbox_list, padding, bg_color, border)
    
    pred_ids = []
    pred_classes = []
    remove_list = []
    
    for j, img in enumerate(elem_img_list):
        
        img_t = transform(img.convert('RGB')).unsqueeze(0).to('cuda')
        
        with torch.no_grad():
            output = classifier(img_t)
        _, predicted = torch.max(output, 1)
        pred_class_id = str(predicted.item() + 1)
        pred_class_name = cats[pred_class_id]
        
        if remove_neg and pred_class_id == neg_class_id:
            remove_list.append(j)

        pred_ids.append(pred_class_id)
        pred_classes.append(pred_classes)
        # data_dict["annotations"][j]["category_id"] = int(pred_class_id)
        data_dict["annotations"][j]["category_id"] = pred_class_name
        
    print("predicted number:", len(data_dict["annotations"]))
    
    if remove_neg and remove_list:
         for ind in sorted(remove_list, reverse=True):
            data_dict["annotations"].pop(ind)
            bbox_list.pop(ind)
            pred_ids.pop(ind)
            pred_classes.pop(ind)
            
    print("cleaned number:", len(data_dict["annotations"]))
        
    vis_outpath = os.path.join(vis_out_dir, os.path.basename(img_path))
    draw_bounding_boxes(img_path, bbox_list, pred_ids, vis_outpath)
    master_dict.append(data_dict)
    
res_outpath = os.path.join(res_out_dir, "results.json")
print("writing out results to", res_outpath)
with open(res_outpath, 'w+') as f:
    json.dump(master_dict, f, indent=4)

0 /mnt/nis_lab_research/data/coco_files/test/test1/images/0mpdgmPb3OKP7GwG-investorideas_ss.png
predicted number: 42
cleaned number: 40
1 /mnt/nis_lab_research/data/coco_files/test/test1/images/hLv3ijm31G1URmH6-navyfederal_ss.png
predicted number: 31
cleaned number: 30
2 /mnt/nis_lab_research/data/coco_files/test/test1/images/L9oDbA9sni99lgdr-brightspace_ss.png
predicted number: 25
cleaned number: 24
3 /mnt/nis_lab_research/data/coco_files/test/test1/images/tof4mTWQkTh1zRSe-oasiscannabis_ss.png
predicted number: 12
cleaned number: 12
4 /mnt/nis_lab_research/data/coco_files/test/test1/images/QNdJuEakVFnJTMgq-uptodown_ss.png
predicted number: 24
cleaned number: 23
5 /mnt/nis_lab_research/data/coco_files/test/test1/images/LgTCcgKP5TdxPSX4-gettyimages_ss.png
predicted number: 32
cleaned number: 29
6 /mnt/nis_lab_research/data/coco_files/test/test1/images/b12AVGFM3lA4v2C0-shortcutsolutions_ss.png
predicted number: 21
cleaned number: 21
7 /mnt/nis_lab_research/data/coco_files/test/test1/imag