In [None]:
#You may download weights from here https://www.dropbox.com/scl/fi/tcdyhar7u5ru53lct3nr5/cls_model.pth?rlkey=ctyuiet6ralm5rqvzmkt46z4p&st=vkjyd0wz&dl=0

In [20]:
from torchvision import datasets, models, transforms
import torch
import cv2
import numpy as np
from PIL import Image
import os

# Define your modified ResNet model
model = models.resnet50(pretrained=False).to('cpu')
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 18)  # Assuming you have 18 classes now
model = model.to("cpu")
# Load the state dictionary partially
state_dict = torch.load('cls_model.pth',map_location='cpu')
model.load_state_dict(state_dict)  # Set strict=False to ignore missing keys

# Set the model to evaluation mode
model.eval()

OUTPUT_DIR = "classification_outputs"

class_names = ['autocomplete',
 'button',
 'checkboxlist',
 'combobox',
 'currencyedit',
 'datagrid',
 'datetimeedit',
 'dropdownbutton',
 'fileuploader',
 'maskedit',
 'memoedit',
 'negative',
 'numericedit',
 'pager',
 'passwordedit',
 'percentedit',
 'radiobuttonlist',
 'textedit']

  state_dict = torch.load('cls_model.pth',map_location='cpu')


In [21]:
def pad_image(img, target_size):
    
    # Calculate padding dimensions
    width, height = img.size
    target_width, target_height = target_size
    pad_width = max(target_width - width, 0)
    pad_height = max(target_height - height, 0)

    # Calculate padding
    left_pad = pad_width // 2
    right_pad = pad_width - left_pad
    top_pad = pad_height // 2
    bottom_pad = pad_height - top_pad

    # Apply padding
    padded_img = Image.new(img.mode, target_size, (255, 255, 255))  # Create a white canvas
    padded_img.paste(img, (left_pad, top_pad))

    return padded_img

In [22]:
def transform_roi(img):

        #wide and tall
        if img.shape[0] >= 244 and img.shape[1] >= 244: 
            upper_left = img[0:80, 0:120]
            lower_left = img[(img.shape[0]-80):img.shape[0],     0:120]
            upper_right = img[0:80, (img.shape[1]-120):img.shape[1]]
            lower_right = img[img.shape[0]-80:img.shape[0], img.shape[1]-120:img.shape[1]]
            center = img[img.shape[0]//2-40:img.shape[0]//2+40, img.shape[1]//2-120:img.shape[1]//2+120]
            upper_combined = np.concatenate((upper_left, upper_right), axis=1)
            lower_combined = np.concatenate((lower_left, lower_right), axis=1)
            combined_parts = np.concatenate((upper_combined, center, lower_combined), axis=0)
            img = Image.fromarray(combined_parts)

        elif (80 <= img.shape[0] < 244) and img.shape[1] >= 244: 
            upper_left = img[0:80, 0:120]
            lower_left = img[(img.shape[0]-80):img.shape[0], 0:120]
            upper_right = img[0:80, (img.shape[1]-120):img.shape[1]]
            lower_right = img[img.shape[0]-80:img.shape[0], img.shape[1]-120:img.shape[1]]
            # center = img[img.shape[0]//2-40:img.shape[0]//2+40, img.shape[1]//2-120:img.shape[1]//2+120]
            upper_combined = np.concatenate((upper_left, upper_right), axis=1)
            lower_combined = np.concatenate((lower_left, lower_right), axis=1)
            combined_parts = np.concatenate((upper_combined, lower_combined), axis=0)
            img = Image.fromarray(combined_parts)

        elif img.shape[0] <= 80 and img.shape[1] >= 488:
            left = img[0:80, 0:244]
            right = img[0:80, img.shape[1]-244:img.shape[1]]
            center = img[0:80, img.shape[1]//2-122:img.shape[1]//2+122]
            combined_parts = np.concatenate((left, center, right), axis=0)
            img = Image.fromarray(combined_parts)
        elif img.shape[0] <= 80 and img.shape[1] >= 300:
            left = img[0:80, 0:244]
            right = img[0:80, img.shape[1]-244:img.shape[1]]
            # center = img[0:img.shape[0], img.shape[1]//2-122:img.shape[1]//2+122]
            combined_parts = np.concatenate((left, right), axis=0)
            img = Image.fromarray(combined_parts)
       
        else:
            img = Image.fromarray(img)

        return img
        

In [24]:
img =  cv2.imread(r"test_img.png")

height, width, _ = img.shape

img = img[0:height//1, 0:width]

gray = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)

alpha = 1.0  # Contrast control (1.0-3.0)
beta = 0     # Brightness control (0-100)
enhanced_image = cv2.convertScaleAbs(gray, alpha=alpha, beta=beta)

os.makedirs(OUTPUT_DIR, exist_ok=True)

thresh_inv = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY_INV)[1]
blur = cv2.GaussianBlur(thresh_inv,(3,3),0)
kernel = np.ones((3,3), np.uint8)
dilated = cv2.dilate(blur, kernel, iterations=1)
thresh = cv2.threshold(dilated, 50, 255, cv2.THRESH_BINARY)[1]
contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]

cv2.imwrite(os.path.join(OUTPUT_DIR,"thresh_inv.png"),thresh_inv)
cv2.imwrite(os.path.join(OUTPUT_DIR,"dilated.png"),dilated)
cv2.imwrite(os.path.join(OUTPUT_DIR,"thresh.png"),thresh)

mask = np.ones(img.shape[:2], dtype="uint8") * 255
for i,c in enumerate(contours):
    # get the bounding rect
    x, y, w, h = cv2.boundingRect(c)
    if 100<w*h<1000000:
        cv2.rectangle(mask, (x, y), (x+w, y+h), (0, 0, 255), -1)

        roi_box = img[y:y+h, x:x+w]
        transformed_roi = transform_roi(roi_box)
        
        target_size = (244, 244)
        transformed_roi = pad_image(transformed_roi, target_size)

        transform = transforms.Compose([transforms.ToTensor()])

        input_tensor = transform(transformed_roi).unsqueeze(0)  # Add batch dimension

        # Forward pass
        with torch.no_grad():
            output = model(input_tensor)

        # Get predicted class probabilities
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # Get predicted class index
        predicted_class = torch.argmax(probabilities).item()

        cv2.putText(img, class_names[predicted_class], (x,y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color = (255, 0, 0))
        cv2.imwrite(os.path.join(OUTPUT_DIR,str(i)+class_names[predicted_class]+".png"),np.array(transformed_roi))
        

res_final = cv2.bitwise_and(img, img, mask=cv2.bitwise_not(mask))
cv2.imwrite(os.path.join(OUTPUT_DIR,"final.png"),res_final)

True