In [1]:
import cv2
from time import time
import re
import math
import numpy as np
import pandas as pd
import random
from collections import defaultdict
import html
import os
gpu_id = 1
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
bounding_box_size = 100

import torch
import matplotlib.pyplot as plt
import torchvision.transforms as T
from garuda.od import ConfusionMatrix
from typing import List
from garuda.core import obb_iou
from dataclasses import dataclass
# from my_metrics.metrics import ConfusionMatrix as MyConfusionMatrix

from geochat.model.builder import load_pretrained_model
from geochat.mm_utils import  get_model_name_from_path
from geochat.conversation import conv_templates, Chat

from PIL import Image, ImageEnhance
from glob import glob

torch.cuda.empty_cache()

[2025-05-27 01:34:57,143] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
def escape_markdown(text):
    # List of Markdown special characters that need to be escaped
    md_chars = ['<', '>']

    # Escape each special character
    for char in md_chars:
        text = text.replace(char, '\\' + char)

    return text

def reverse_escape(text):
    md_chars = ['\\<', '\\>']

    for char in md_chars:
        text = text.replace(char, char[1:])

    return text

def bbox_and_angle_to_polygon(x1, y1, x2, y2, a):
    # Calculate center coordinates
    x_ctr = (x1 + x2) / 2
    y_ctr = (y1 + y2) / 2
    
    # Calculate width and height
    w = abs(x2 - x1)
    h = abs(y2 - y1)
    
    # Calculate the angle in radians
    angle_rad = math.radians(a)
    
    # Calculate coordinates of the four corners of the rotated bounding box
    cos_a = math.cos(angle_rad)
    sin_a = math.sin(angle_rad)
    
    x1_rot = cos_a * (-w / 2) - sin_a * (-h / 2) + x_ctr
    y1_rot = sin_a * (-w / 2) + cos_a * (-h / 2) + y_ctr
    
    x2_rot = cos_a * (w / 2) - sin_a * (-h / 2) + x_ctr
    y2_rot = sin_a * (w / 2) + cos_a * (-h / 2) + y_ctr
    
    x3_rot = cos_a * (w / 2) - sin_a * (h / 2) + x_ctr
    y3_rot = sin_a * (w / 2) + cos_a * (h / 2) + y_ctr
    
    x4_rot = cos_a * (-w / 2) - sin_a * (h / 2) + x_ctr
    y4_rot = sin_a * (-w / 2) + cos_a * (h / 2) + y_ctr
    
    # Return the polygon coordinates
    polygon_coords = np.array((x1_rot, y1_rot, x2_rot, y2_rot, x3_rot, y3_rot, x4_rot, y4_rot))
    
    return polygon_coords

def rotate_bbox(top_right, bottom_left, angle_degrees):
    # Convert angle to radians
    angle_radians = np.radians(angle_degrees)

    # Calculate the center of the rectangle
    center = ((top_right[0] + bottom_left[0]) / 2, (top_right[1] + bottom_left[1]) / 2)

    # Calculate the width and height of the rectangle
    width = top_right[0] - bottom_left[0]
    height = top_right[1] - bottom_left[1]

    # Create a rotation matrix
    rotation_matrix = cv2.getRotationMatrix2D(center, angle_degrees, 1)

    # Create an array of the rectangle corners
    rectangle_points = np.array([[bottom_left[0], bottom_left[1]],
                                 [top_right[0], bottom_left[1]],
                                 [top_right[0], top_right[1]],
                                 [bottom_left[0], top_right[1]]], dtype=np.float32)

    # Rotate the rectangle points
    rotated_rectangle = cv2.transform(np.array([rectangle_points]), rotation_matrix)[0]

    return rotated_rectangle
def extract_substrings(string):
    # first check if there is no-finished bracket
    index = string.rfind('}')
    if index != -1:
        string = string[:index + 1]

    pattern = r'<p>(.*?)\}(?!<)'
    matches = re.findall(pattern, string)
    substrings = [match for match in matches]

    return substrings


def is_overlapping(rect1, rect2):
    x1, y1, x2, y2 = rect1
    x3, y3, x4, y4 = rect2
    return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)


def computeIoU(bbox1, bbox2):
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    intersection_x1 = max(x1, x3)
    intersection_y1 = max(y1, y3)
    intersection_x2 = min(x2, x4)
    intersection_y2 = min(y2, y4)
    intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
    bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
    union_area = bbox1_area + bbox2_area - intersection_area
    iou = intersection_area / union_area
    return iou


def save_tmp_img(visual_img, img_name):
    # file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
    # file_path = "/tmp/gradio" + file_name
    file_path = "/tmp/gradio" + img_name
    visual_img.save(file_path)
    return file_path


def mask2bbox(mask):
    if mask is None:
        return ''
    mask = mask.resize([100, 100], resample=Image.NEAREST)
    mask = np.array(mask)[:, :, 0]

    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)

    if rows.sum():
        # Get the top, bottom, left, and right boundaries
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]
        bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
    else:
        bbox = ''

    return bbox


def escape_markdown(text):
    # List of Markdown special characters that need to be escaped
    md_chars = ['<', '>']

    # Escape each special character
    for char in md_chars:
        text = text.replace(char, '\\' + char)

    return text


def reverse_escape(text):
    md_chars = ['\\<', '\\>']

    for char in md_chars:
        text = text.replace(char, char[1:])

    return text


colors = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (210, 210, 0),
    (255, 0, 255),
    (0, 255, 255),
    (114, 128, 250),
    (0, 165, 255),
    (0, 128, 0),
    (144, 238, 144),
    (238, 238, 175),
    (255, 191, 0),
    (0, 128, 0),
    (226, 43, 138),
    (255, 0, 255),
    (0, 215, 255),
]

color_map = {
    f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
    color_id, color in enumerate(colors)
}

used_colors = colors


def visualize_all_bbox_together(image, generation):
    
    if image is None:
        return None, ''

    generation = html.unescape(generation)
    image_width, image_height = image.size
    image = image.resize([500, int(500 / image_width * image_height)]) # if image width and hight are same then image is of size 500x500
    image_width, image_height = image.size

    string_list = extract_substrings(generation)
    # print(f'String list {string_list}')
    if string_list:  # it is grounding or detection
        mode = 'all'
        entities = defaultdict(list)
        i = 0
        j = 0
        for string in string_list:
            try:
                obj, string = string.split('</p>')
            except ValueError:
                print('wrong string: ', string)
                continue
            if "}{" in string:
                string=string.replace("}{","}<delim>{")
            bbox_list = string.split('<delim>')
            # print(f'bbox_list {bbox_list}')
            flag = False
            for bbox_string in bbox_list:
                integers = re.findall(r'-?\d+', bbox_string)
                if len(integers)==4:
                    angle=0
                else:
                    angle=integers[4]
                integers=integers[:-1]
                
                if len(integers) == 4:
                    x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
                    left = x0 / bounding_box_size * image_width
                    bottom = y0 / bounding_box_size * image_height
                    right = x1 / bounding_box_size * image_width
                    top = y1 / bounding_box_size * image_height

                    entities[obj].append([left, bottom, right, top,angle])

                    j += 1
                    flag = True
            if flag:
                i += 1
    else:
        integers = re.findall(r'-?\d+', generation)
        # if len(integers)==4:
        angle=0
        # else:
            # angle=integers[4]
        integers=integers[:-1]
        if len(integers) == 4:  # it is refer
            mode = 'single'

            entities = list()
            x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
            left = x0 / bounding_box_size * image_width
            bottom = y0 / bounding_box_size * image_height
            right = x1 / bounding_box_size * image_width
            top = y1 / bounding_box_size * image_height
            entities.append([left, bottom, right, top,angle])
        else:
            # don't detect any valid bbox to visualize
            return None, '', None

    if len(entities) == 0:
        return None, '', None

    if isinstance(image, Image.Image):
        image_h = image.height
        image_w = image.width
        image = np.array(image)

    else:
        raise ValueError(f"invalid image format, {type(image)} for {image}")

    indices = list(range(len(entities)))

    new_image = image.copy()

    previous_bboxes = []
    # size of text
    text_size = 0.4
    # thickness of text
    text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
    box_line = 2
    (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
    base_height = int(text_height * 0.675)
    text_offset_original = text_height - base_height
    text_spaces = 2


    used_colors = colors  # random.sample(colors, k=num_bboxes)

    color_id = -1
    # print(f'entities {entities}')
    for entity_idx, entity_name in enumerate(entities):
        if mode == 'single' or mode == 'identify':
            bbox_coords = []
            bboxes = entity_name
            bboxes = [bboxes]
        else:
            bbox_coords = defaultdict(list)
            bboxes = entities[entity_name]
        color_id += 1
        for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm,angle) in enumerate(bboxes):
            skip_flag = False
            orig_x1, orig_y1, orig_x2, orig_y2,angle = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm), int(angle)

            color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist())
            top_right=(orig_x1,orig_y1)
            bottom_left=(orig_x2,orig_y2)
            angle=angle
            rotated_bbox = rotate_bbox(top_right, bottom_left, angle)
            # print(f'rotated_bbox {rotated_bbox}')
            if mode == 'single' or mode == 'identify':
                bbox_coords.append(rotated_bbox/500)
            else:
                bbox_coords[entity_name].append(rotated_bbox/500)
            new_image=cv2.polylines(new_image, [rotated_bbox.astype(np.int32)], isClosed=True, thickness=2, color=color)


            if mode == 'all':
                l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1

                x1 = orig_x1 - l_o
                y1 = orig_y1 - l_o

                if y1 < text_height + text_offset_original + 2 * text_spaces:
                    y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
                    x1 = orig_x1 + r_o

                # add text background
                (text_width, text_height), _ = cv2.getTextSize(f"  {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
                                                               text_line)
                text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
                            text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
                
                if not skip_flag:
                    alpha = 0.5
                    for i in range(text_bg_y1, text_bg_y2):
                        for j in range(text_bg_x1, text_bg_x2):
                            if i < image_h and j < image_w:
                                if j < text_bg_x1 + 1.35 * c_width:
                                    # original color
                                    bg_color = color
                                else:
                                    # white
                                    bg_color = [255, 255, 255]
                                new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
                                    np.uint8)

                    cv2.putText(
                        new_image, f"  {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
                        cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
                    )

                    previous_bboxes.append(
                        {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})

    if mode == 'all':
        def color_iterator(colors):
            while True:
                for color in colors:
                    yield color

        color_gen = color_iterator(colors)

        # Add colors to phrases and remove <p></p>
        def colored_phrases(match):
            phrase = match.group(1)
            color = next(color_gen)
            return f'<span style="color:rgb{color}">{phrase}</span>'

        generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
        generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
    else:
        generation_colored = ''

    pil_image = Image.fromarray(new_image)
    # print(f'bbox_coords {bbox_coords}')
    return pil_image, generation_colored, bbox_coords

In [3]:
@dataclass
class MeanAveragePrecision:
    """
    Mean Average Precision for object detection tasks.

    Attributes:
        map50_95 (float): Mean Average Precision (mAP) calculated over IoU thresholds
            ranging from `0.50` to `0.95` with a step size of `0.05`.
        map50 (float): Mean Average Precision (mAP) calculated specifically at
            an IoU threshold of `0.50`.
        map75 (float): Mean Average Precision (mAP) calculated specifically at
            an IoU threshold of `0.75`.
        per_class_ap50_95 (np.ndarray): Average Precision (AP) values calculated over
            IoU thresholds ranging from `0.50` to `0.95` with a step size of `0.05`,
            provided for each individual class.
    """

    map50_95: float
    map50: float
    map75: float
    per_class_ap50_95: np.ndarray

    @classmethod
    def from_tensors(
        cls,
        predictions: List[np.ndarray],
        targets: List[np.ndarray],
    ):
        """
        Calculate Mean Average Precision based on predicted and ground-truth
            detections at different threshold.

        Args:
            predictions (List[np.ndarray]): Each element of the list describes
                a single image and has `shape = (M, 6)` where `M` is
                the number of detected objects. Each row is expected to be
                in `(x_min, y_min, x_max, y_max, class, conf)` format.
            targets (List[np.ndarray]): Each element of the list describes a single
                image and has `shape = (N, 5)` where `N` is the
                number of ground-truth objects. Each row is expected to be in
                `(x_min, y_min, x_max, y_max, class)` format.
        Returns:
            MeanAveragePrecision: New instance of MeanAveragePrecision.

        Example:
            ```python
            import supervision as sv
            import numpy as np

            targets = (
                [
                    np.array(
                        [
                            [0.0, 0.0, 3.0, 3.0, 1],
                            [2.0, 2.0, 5.0, 5.0, 1],
                            [6.0, 1.0, 8.0, 3.0, 2],
                        ]
                    ),
                    np.array([[1.0, 1.0, 2.0, 2.0, 2]]),
                ]
            )

            predictions = [
                np.array(
                    [
                        [0.0, 0.0, 3.0, 3.0, 1, 0.9],
                        [0.1, 0.1, 3.0, 3.0, 0, 0.9],
                        [6.0, 1.0, 8.0, 3.0, 1, 0.8],
                        [1.0, 6.0, 2.0, 7.0, 1, 0.8],
                    ]
                ),
                np.array([[1.0, 1.0, 2.0, 2.0, 2, 0.8]])
            ]

            mean_average_precision = sv.MeanAveragePrecision.from_tensors(
                predictions=predictions,
                targets=targets,
            )

            print(mean_average_precision.map50_95)
            # 0.6649
            ```
        """
        # validate_input_tensors(predictions, targets)
        iou_thresholds = np.linspace(0.5, 0.95, 10)
        stats = []

        # Gather matching stats for predictions and targets
        for true_objs, predicted_objs in zip(targets, predictions):
            if predicted_objs.shape[0] == 0:
                if true_objs.shape[0]:
                    stats.append(
                        (
                            np.zeros((0, iou_thresholds.size), dtype=bool),
                            *np.zeros((2, 0)),
                            true_objs[:, 0], # index 0 is class
                        )
                    )
                continue

            if true_objs.shape[0]:
                matches = cls._match_detection_batch(
                    predicted_objs, true_objs, iou_thresholds
                )
                stats.append(
                    (
                        matches,
                        predicted_objs[:, -1],
                        predicted_objs[:, 0],
                        true_objs[:, 0],
                    )
                )

        # Compute average precisions if any matches exist
        if stats:
            concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
            average_precisions = cls._average_precisions_per_class(*concatenated_stats)
            map50 = average_precisions[:, 0].mean()
            map75 = average_precisions[:, 5].mean()
            map50_95 = average_precisions.mean()
        else:
            map50, map75, map50_95 = 0, 0, 0
            average_precisions = []

        return cls(
            map50_95=map50_95,
            map50=map50,
            map75=map75,
            per_class_ap50_95=average_precisions,
        )

    
    @staticmethod
    def _match_detection_batch(
        predictions: np.ndarray, targets: np.ndarray, iou_thresholds: np.ndarray
    ) -> np.ndarray:
        """
        Match predictions with target labels based on IoU levels.

        Args:
            predictions (np.ndarray): Batch prediction. Describes a single image and
                has `shape = (M, 6)` where `M` is the number of detected objects.
                Each row is expected to be in
                `(x_min, y_min, x_max, y_max, class, conf)` format.
            targets (np.ndarray): Batch target labels. Describes a single image and
                has `shape = (N, 5)` where `N` is the number of ground-truth objects.
                Each row is expected to be in
                `(x_min, y_min, x_max, y_max, class)` format.
            iou_thresholds (np.ndarray): Array contains different IoU thresholds.

        Returns:
            np.ndarray: Matched prediction with target labels result.
        """
        num_predictions, num_iou_levels = predictions.shape[0], iou_thresholds.shape[0]
        correct = np.zeros((num_predictions, num_iou_levels), dtype=bool)
        iou = obb_iou(targets[:, 1:9].reshape(-1,4,2), predictions[:, 1:9].reshape(-1,4,2))
        correct_class = targets[:, 0:1] == predictions[:, 0]

        for i, iou_level in enumerate(iou_thresholds):
            matched_indices = np.where((iou >= iou_level) & correct_class)

            if matched_indices[0].shape[0]:
                combined_indices = np.stack(matched_indices, axis=1)
                iou_values = iou[matched_indices][:, None]
                matches = np.hstack([combined_indices, iou_values])

                if matched_indices[0].shape[0] > 1:
                    matches = matches[matches[:, 2].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]

                correct[matches[:, 1].astype(int), i] = True

        return correct
    
    @staticmethod
    def compute_average_precision(recall: np.ndarray, precision: np.ndarray) -> float:
        """
        Compute the average precision using 101-point interpolation (COCO), given
            the recall and precision curves.

        Args:
            recall (np.ndarray): The recall curve.
            precision (np.ndarray): The precision curve.

        Returns:
            float: Average precision.
        """
        extended_recall = np.concatenate(([0.0], recall, [1.0]))
        extended_precision = np.concatenate(([1.0], precision, [0.0]))
        max_accumulated_precision = np.flip(
            np.maximum.accumulate(np.flip(extended_precision))
        )
        interpolated_recall_levels = np.linspace(0, 1, 101)
        interpolated_precision = np.interp(
            interpolated_recall_levels, extended_recall, max_accumulated_precision
        )
        average_precision = np.trapz(interpolated_precision, interpolated_recall_levels)
        return average_precision

    @staticmethod
    def _average_precisions_per_class(
        matches: np.ndarray,
        prediction_confidence: np.ndarray,
        prediction_class_ids: np.ndarray,
        true_class_ids: np.ndarray,
        eps: float = 1e-16,
    ) -> np.ndarray:
        """
        Compute the average precision, given the recall and precision curves.
        Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.

        Args:
            matches (np.ndarray): True positives.
            prediction_confidence (np.ndarray): Objectness value from 0-1.
            prediction_class_ids (np.ndarray): Predicted object classes.
            true_class_ids (np.ndarray): True object classes.
            eps (float): Small value to prevent division by zero.

        Returns:
            np.ndarray: Average precision for different IoU levels.
        """
        sorted_indices = np.argsort(-prediction_confidence)
        matches = matches[sorted_indices]
        prediction_class_ids = prediction_class_ids[sorted_indices]

        unique_classes, class_counts = np.unique(true_class_ids, return_counts=True)
        num_classes = unique_classes.shape[0]

        average_precisions = np.zeros((num_classes, matches.shape[1]))

        for class_idx, class_id in enumerate(unique_classes):
            is_class = prediction_class_ids == class_id
            total_true = class_counts[class_idx]
            total_prediction = is_class.sum()

            if total_prediction == 0 or total_true == 0:
                continue

            false_positives = (1 - matches[is_class]).cumsum(0)
            true_positives = matches[is_class].cumsum(0)
            recall = true_positives / (total_true + eps)
            precision = true_positives / (true_positives + false_positives)

            for iou_level_idx in range(matches.shape[1]):
                average_precisions[class_idx, iou_level_idx] = (
                    MeanAveragePrecision.compute_average_precision(
                        recall[:, iou_level_idx], precision[:, iou_level_idx]
                    )
                )

        return average_precisions

In [4]:
def get_predicted_target_bbox(chat, CONV_VISION, gr_imgs_path, gr_labels, type, user_message):    
    image_path_with_detected_objects = []

    target_results = []
    predicted_results = []

    total_execution_time = 0
    for img_path in sorted(glob(gr_imgs_path)):
        gr_img = Image.open(img_path).convert('RGB')
        img_list = []
        chat_state = CONV_VISION.copy()
        llm_message = chat.upload_img(gr_img, chat_state, img_list = img_list)
        # print(llm_message)
        chat.ask(user_message, chat_state) # ask the question grounding, refer, expression, scene classification etc

        if len(img_list) > 0:
            if not isinstance(img_list[0], torch.Tensor):
                chat.encode_img(img_list)
        time_start = time()     
        streamer = chat.stream_answer(conv=chat_state,
                                        img_list=img_list,
                                        temperature=0.5,
                                        max_new_tokens=500,
                                        max_length=2000)
        time_end = time()
        total_execution_time += time_end - time_start
        # print(f'Streamer {streamer}')    
        output = ''
        for new_output in streamer:
            # print(new_output)
            output=output+new_output
        # print(output)

        # output = escape_markdown(output)
        chat_state.messages[-1][1] = '</s>'

        # output = reverse_escape(output)
        # print(output)
        visual_img, generation_color, bbox = visualize_all_bbox_together(gr_img, output) # None, dictionary, list
        # print(visual_img)
        img_name = os.path.basename(img_path)
        if visual_img is not None:
            file_path = save_tmp_img(visual_img, img_name)
            image_path_with_detected_objects.append(file_path)
        # print(generation_color)

        if bbox is None:
            # predicted_results.append(np.zeros((1, 8)))
            predicted_results.append(np.array([])) # no bbox detected
        elif isinstance(bbox, dict): # grounding return bboxes in dict type
            tmp = []
            for key, values in bbox.items():
                for value in values:
                    tmp.append(np.array(value).reshape(-1))
            predicted_results.append(np.array(tmp))
        else: # [refer] return bboxes in list type
            tmp = []
            for value in bbox:
                    tmp.append(np.array(value).reshape(-1))
                    
            # save the predictions in txt file - folder name (image_type_prompt), file name (image_name)
            # np.savetxt(f'{type}/{user_message}/{img_name[:-4]}.txt',np.array(tmp))
            predicted_results.append(np.array(tmp))
        if '.tif' in img_name:
            target_path = os.path.join(gr_labels, img_name.replace('.tif', '.txt'))
        else:
            target_path = os.path.join(gr_labels, img_name.replace('.png', '.txt'))
        target_results.append(np.loadtxt(target_path, ndmin=2))

    return predicted_results, target_results, image_path_with_detected_objects, total_execution_time

In [5]:
def add_class_confidence(predicted_results):
    new_predicted_results = []
    for res in predicted_results:
        if len(res):
            res = np.hstack([np.zeros((len(res),1)), res, np.ones((len(res),1))], dtype=np.float32) # add class label 0 at index 0 and confidence score 1 at last index
            new_predicted_results.append(res)
        else:
            res = np.zeros((1, 10))
            res[:, 0] = 1
            res[:, -1] = 1
            new_predicted_results.append(res.astype(np.float32))
    return new_predicted_results

In [6]:
def modify_class(target_results):
    new_target_results = []
    for res in target_results:
        res[:,0] = 0 # convert class labels to 0
        res = res.astype(np.float32)
        new_target_results.append(res)
    return new_target_results

In [7]:
def calculate_confusion_matrix(new_predicted_results, new_target_results):
    cm_predicted_results = []
    for res in new_predicted_results:
        res[:,1:9] = res[:,1:9]*500
        cm_predicted_results.append(res)


    cm_target_results = []
    for res in new_target_results:
        res[:,0] = 0 # convert class labels to 0
        res[:,1:9] = res[:,1:9]*500
        res = res.astype(np.float32)
        cm_target_results.append(res)

    classes, conf_threshold, iou_threshold = ['brick_kilns'], 0.25, 0.5
    cm = ConfusionMatrix.from_obb_tensors(cm_predicted_results, cm_target_results, classes, conf_threshold, iou_threshold)
    # cm = MyConfusionMatrix.from_tensors(cm_predicted_results, cm_target_results, classes, conf_threshold, iou_threshold)
    df = pd.DataFrame(cm.matrix, columns = ['predicted kilns','predicted_bg'], index=['true kilns','true_bg'])
    print(f'conf_threshold = {conf_threshold}, iou_threshold = {iou_threshold}')
    # print(cm.summary)
    # print(df.to_markdown())
    return cm, df

In [8]:
def calculate_precision_recall(confusion_matrix, new_predicted_results, new_target_results):
    tp = confusion_matrix.loc['true kilns']['predicted kilns']
    predicted_positives = 0
    for res in new_predicted_results:
        predicted_positives += np.where(res[:,0] == 1, 0, 1).sum() # class = 1 means background class
    
    ground_truth = 0
    for res in new_target_results:
        ground_truth += np.where(res[:,0] == 0, 1, 0).sum() # class = 0 means brick kiln class

    precision = tp / predicted_positives
    recall = tp/ground_truth
    f1_score = 2 * (precision * recall) / (precision + recall)

    return precision, recall, f1_score

In [17]:
def plot_results(user_message, gr_imgs_path, new_predicted_results, new_target_results, iou_threshold, image_path_with_detected_objects, region, plot=False):
    if not plot:
        return
    # n = len(image_path_with_detected_objects)
    gr_img_path = sorted(glob(gr_imgs_path))
    n = len(gr_img_path)
    cols = 5
    rows = math.ceil(n / cols)
    fig, ax = plt.subplots(nrows = rows, ncols = cols ,figsize=(cols*8, rows*8))
    ax = ax.flatten()
    for i in range(n):
        # img = Image.open(image_path_with_detected_objects[i]).convert('RGB') # predicted image
        img = Image.open(gr_img_path[i]).convert('RGB') # planet image
        w, h = img.size
        ax[i].imshow(img) 
        for bbox in new_target_results[i]:
            classvalue, x1, y1, x2, y2, x3, y3, x4, y4 = bbox*w 
            ax[i].plot([x1, x2, x3, x4, x1], [y1, y2, y3, y4, y1], color = 'green', linewidth = 8)
        for bbox in new_predicted_results[i]:
            classvalue, x1, y1, x2, y2, x3, y3, x4, y4, conf = bbox*w 
            ax[i].plot([x1, x2, x3, x4, x1], [y1, y2, y3, y4, y1], color = 'red', linewidth = 8)
        ax[i].set_axis_off()

    # Hide unused subplots
    for j in range(n, len(ax)):
        ax[j].axis('off')
    # location = region.split('/data/')[1]
    # location = location.replace('/','_')
    fig.suptitle(f'{user_message}')
    # plt.savefig(f'geochat_output_refer_{location}_planet.png')
    region = region.replace('/','_')
    plt.savefig(f'geochat_output_refer_{region}.png')
    # plt.close() # will not display the plot

In [12]:
model_path = 'MBZUAI/geochat-7B'
model_name = get_model_name_from_path(model_path)
print(model_name)
device = 'cuda:{}'.format(0)
# set device_map = None to use single GPU, otherwise 'auto' to load model in all GPUs (auto will do sharding it will load layers, weights in different GPUs for better memory efficiency).
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, device_map = None, device = device)

geochat-7B
Loading GeoChat......




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [13]:
model = model.eval()
chat = Chat(model.to(device), image_processor,tokenizer, device=device)
CONV_VISION = conv_templates['llava_v1'].copy()
# chat_state = CONV_VISION.copy()
# chat_state

In [14]:
conv_templates

{'default': Conversation(system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=('Human', 'Assistant'), messages=(('Human', 'What are the key differences between renewable and non-renewable energy sources?'), ('Assistant', 'Renewable energy sources are those that can be replenished naturally in a relatively short amount of time, such as solar, wind, hydro, geothermal, and biomass. Non-renewable energy sources, on the other hand, are finite and will eventually be depleted, such as coal, oil, and natural gas. Here are some key differences between renewable and non-renewable energy sources:\n1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable energy sources are finite and will eventually run out.\n2. Environmental impact: Renewable energy sources have a much lower environmental impact than non-renewable sources, which can lead to air 

In [15]:
# user_message = 'Give the bounding box coordinates of all the brick kilns present in the image separated by newline'
# user_message = 'Draw bounding box around the brick kiln with chimney in the image'

# Visual question answering - given the image and a question, it generates the answer.

# Scene classification - given the image, it generates the scene category.
# user_message = 'Classify the image in one word. The classes are Church, School, Bareland, Beach, Forest'

# Region-level caption - given the bounding box on the image, it generates brief description about the object
# user_message = '[identify] What is this object?'

# Grounded description - describe the object and give the bounding box.
# user_message = 'describe the image in detail'

# Refering expressions - refer to the object by providing some attributes : large, top, close etc, will produce the bounding box
user_message1 = '[refer] Where is the brick kiln with chimney in the image? Give its bounding box'
user_message2 = '[refer] Where are the fields, factories, roads and brick kilns in the image? Give its oriented bounding box'
user_message3= '[refer] Where are the fields, factories and roads in the image? Give its oriented bounding box'
user_message4 = '[refer] Where are the fields and factories in the image? Give its oriented bounding box'
user_message5 = '[refer] Where are the factories in the image? Give its oriented bounding box'
user_message6 = '[refer] Where are the fields in the image? Give its oriented bounding box'
user_message7 ='[refer] Where are the chimneys in the image? Give its oriented bounding box'
user_message8 = '[refer] Where are the factories and roads in the image? Give its oriented bounding box'
user_message9 = '[refer] Where are the fields and roads in the image? Give its oriented bounding box'
user_message10 = '[refer] Where are the roads in the image? Give its oriented bounding box'
user_message11 = '[refer] Where is the chimneys and brick kilns in the image? Give its oriented bounding box'


user_messages = [
    user_message1, 
    # user_message2, 
    # user_message3, 
    # user_message4,
    # user_message5,
    # user_message6,
    # user_message7,
    # user_message8,
    # user_message9,
    # user_message10,
    # user_message11
    ]

Processing without patches

In [None]:
# base_path = '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data'

regions = [
            # f'{base_path}/lucknow_airshed_most_15/images', 
        #    '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/swinir_data/lucknow_airshed_most_15/images',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_high_resolution_zoom17',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_high_resolution_zoom18',
            # '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data/uttar_pradesh_most_15/images/',
        #     f'{base_path}/uttar_pradesh_most_15/swinir_images',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/uttar_pradesh_high_resolution_zoom17',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/uttar_pradesh_high_resolution_zoom18',
        #    '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data/west_bengal_most_15/images/',
        #    f'{base_path}/west_bengal_most_15/swinir_images',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/west_bengal_high_resolution_zoom17',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/west_bengal_high_resolution_zoom18'
        # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_kilns_zoom19',
        # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/uttar_pradesh_kilns_zoom19',
        # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/west_bengal_kilns_zoom19'
        '/home/shataxi.dubey/shataxi_work/vlm_on_planet/gms/lucknow_small_600_sq_km/kiln_images/test/images'
        ]

locations = [
            #  'lucknow_airshed_most_15', 
            #  'lucknow_airshed_most_15', 
            #  'lucknow_airshed_most_15', 
            #  'lucknow_airshed_most_15',
            #  'uttar_pradesh_most_15',
            #  'uttar_pradesh_most_15',
            #  'uttar_pradesh_most_15', 
            #  'uttar_pradesh_most_15', 
            #  'west_bengal_most_15',
            #  'west_bengal_most_15', 
            #  'west_bengal_most_15', 
            #  'west_bengal_most_15',
            'lucknow_gms_kiln_images'
             ]

type = ['lucknow_gms_kiln_images']

for user_message in user_messages:
    for region, location in zip(regions, locations):
        print(f'region: {region}, user_message: {user_message}')
        gr_imgs_path = region+'/*'
        # gr_labels = f'/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data/{location}/labels' # labels related to planet imagery
        gr_labels = f'/home/shataxi.dubey/shataxi_work/vlm_on_planet/gms/lucknow_small_600_sq_km/kiln_images/test/labels'
        
        predicted_results, target_results, image_path_with_detected_objects, total_execution_time = get_predicted_target_bbox(chat, CONV_VISION, gr_imgs_path, gr_labels, type, user_message)
        new_predicted_results = add_class_confidence(predicted_results)
        new_target_results = modify_class(target_results)
        plot_results(user_message, gr_imgs_path, new_predicted_results, new_target_results, 0.1, image_path_with_detected_objects, region, True)
        cm, df = calculate_confusion_matrix(new_predicted_results, new_target_results)
        # print(f'Precision: {cm.precision}, Recall: {cm.recall}, F1 Score: {cm.f1_score}')
        precision, recall, f1_score = calculate_precision_recall(df, new_predicted_results, new_target_results)
        print(f'Precision: {precision}, Recall: {recall}, F1 Score: {f1_score}')
        print(f'Total execution time(s): {total_execution_time}')
        display(df)
        
        print('---------------------------------------------------------------')

Divide the image in patch of size 320x320
because 320 divides 640, 2560 and 5120


Planet, Zoom 17, Zoom 18

Overlap of pixels

89/4.77, 89/1.1925, 89/0.59625 = (20, 75, 150)

In [None]:
def patch_boxes(image_height , image_width, patch_size, overlap):
    slice_bboxes = []
    offsets = []
    y_max = y_min = 0

    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + patch_size
        while x_max < image_width:
            x_max = x_min + patch_size
            if y_max > image_height or x_max > image_width:
                xmax = min(image_width, x_max)
                ymax = min(image_height, y_max)
                xmin = max(0, xmax - patch_size)
                ymin = max(0, ymax - patch_size)
                slice_bboxes.append([xmin, ymin, xmax, ymax])
                offsets.append([xmin, ymin])
            else:
                slice_bboxes.append([x_min, y_min, x_max, y_max])
                offsets.append([x_min, y_min])
            x_min = x_max - overlap
        y_min = y_max - overlap
    return slice_bboxes, offsets



def get_predicted_target_bbox_patch(chat, CONV_VISION, gr_imgs_path, patch_size, overlap, gr_labels, type, user_message):    
    image_path_with_detected_objects = []

    target_results = []
    predicted_results = []

    total_execution_time = 0
    for img_id, img_path in enumerate(glob(gr_imgs_path)):
        gr_img = Image.open(img_path).convert('RGB')
        # orig_img = Image.open(img_path).convert('RGB')
        # gr_img = ImageEnhance.Contrast(orig_img) 
        # gr_img = gr_img.enhance(2)
        patches, offsets = patch_boxes(gr_img.size[0], gr_img.size[1], patch_size, overlap)
        tmp = []
        num_patches = len(patches)
        # fig, ax = plt.subplots(nrows = num_patches, ncols = 1, figsize=(50,50)) # visualise plot to debug
        # ax = ax.flatten()
        for i, patch_offset in enumerate(zip(patches, offsets)):
            patch, offset = patch_offset
            patch = gr_img.crop(patch)
            # ax[i].imshow(patch) # visualise plot to debug
            img_list = []
            chat_state = CONV_VISION.copy()
            llm_message = chat.upload_img(patch, chat_state, img_list = img_list)
            # print(llm_message)
            chat.ask(user_message, chat_state) # ask the question grounding, refer, expression, scene classification etc

            if len(img_list) > 0:
                if not isinstance(img_list[0], torch.Tensor):
                    chat.encode_img(img_list)
            time_start = time()     
            streamer = chat.stream_answer(conv=chat_state,
                                            img_list=img_list,
                                            temperature=0.5,
                                            max_new_tokens=500,
                                            max_length=2000)
            time_end = time()
            total_execution_time += time_end - time_start
            output = ''
            for new_output in streamer:
                output=output+new_output


            chat_state.messages[-1][1] = '</s>'


            visual_img, generation_color, bbox = visualize_all_bbox_together(patch, output) # None, dictionary, list
            img_name = os.path.basename(img_path)
            if visual_img is not None:
                file_path = save_tmp_img(visual_img, img_name)
                image_path_with_detected_objects.append(file_path)

            if bbox is None:
                pass
            elif isinstance(bbox, dict): # grounding return bboxes in dict type
                for key, values in bbox.items():
                    for value in values:
                        # x = np.array(value).reshape(-1)
                        # ax[i].plot(x[[0,2,4,6,0]]*patch_size,x[[1,3,5,7,1]]*patch_size, color = 'red') #debug
                        value = (value*patch_size + offset)/gr_img.size[0]
                        value = np.array(value).reshape(-1)
                        tmp.append(value)
            else: # [refer] return bboxes in list type
                for value in bbox:
                        # x = np.array(value).reshape(-1)
                        # ax[i].plot(x[[0,2,4,6,0]]*patch_size,x[[1,3,5,7,1]]*patch_size, color = 'red') #debug
                        value = (value*patch_size + offset)/gr_img.size[0]
                        value = np.array(value).reshape(-1)
                        tmp.append(value)
                
        # save the predictions in txt file - folder name (image_type_prompt), file name (image_name)
        np.savetxt(f'{type}/{user_message}/{img_name[:-4]}.txt',np.array(tmp))
        predicted_results.append(np.array(tmp))

        if '.tif' in img_name:
            target_path = os.path.join(gr_labels, img_name.replace('.tif', '.txt'))
        else:
            target_path = os.path.join(gr_labels, img_name.replace('.png', '.txt'))
        target_results.append(np.loadtxt(target_path, ndmin=2))

        # fig, ax = plt.subplots()
        # ax.imshow(gr_img)
        # for bbox in target_results[img_id]:
        #     classvalue, x1, y1, x2, y2, x3, y3, x4, y4 = bbox*gr_img.size[0]
        #     ax.plot([x1, x2, x3, x4, x1], [y1, y2, y3, y4, y1], color = 'green')
        # for bbox in predicted_results[img_id]:
        #     x1, y1, x2, y2, x3, y3, x4, y4 = bbox*gr_img.size[0]
        #     ax.plot([x1, x2, x3, x4, x1], [y1, y2, y3, y4, y1], color = 'red')

    return predicted_results, target_results, image_path_with_detected_objects, total_execution_time

Patch wise processing of images

In [19]:
patch_size = 320
overlap = 20 # 640x640
# overlap = 75 # 2560x2560
base_path = '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data'

regions = [
            # f'{base_path}/lucknow_airshed_most_15/images', 
         #   '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/swinir_data/lucknow_airshed_most_15/images',
         #   '/home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_high_resolution_zoom17',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_high_resolution_zoom18',
            '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data/uttar_pradesh_most_15/images/',
         #    f'{base_path}/uttar_pradesh_most_15/swinir_images',
         #   '/home/shataxi.dubey/shataxi_work/vlm_on_planet/uttar_pradesh_high_resolution_zoom17',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/uttar_pradesh_high_resolution_zoom18',
           '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data/west_bengal_most_15/images/',
         #   f'{base_path}/west_bengal_most_15/swinir_images',
         #   '/home/shataxi.dubey/shataxi_work/vlm_on_planet/west_bengal_high_resolution_zoom17',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/west_bengal_high_resolution_zoom18'
        # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_kilns_zoom19',
        # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/uttar_pradesh_kilns_zoom19',
        # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/west_bengal_kilns_zoom19'
        ]

locations = [
            #  'lucknow_airshed_most_15', 
            #  'lucknow_airshed_most_15', 
            #  'lucknow_airshed_most_15', 
            #  'lucknow_airshed_most_15',
             'uttar_pradesh_most_15',
            #  'uttar_pradesh_most_15',
            #  'uttar_pradesh_most_15', 
            #  'uttar_pradesh_most_15', 
             'west_bengal_most_15',
            #  'west_bengal_most_15', 
            #  'west_bengal_most_15', 
            #  'west_bengal_most_15',
             ]
type = [
   #  'lucknow_airshed_most_15_planet',
   #  'lucknow_airshed_most_15_swinir',
   #  'lucknow_airshed_most_15_zoom17',
   #  'lucknow_airshed_most_15_zoom18',
    'uttar_pradesh_most_15_planet',
   #  'uttar_pradesh_most_15_swinir',
   #  'uttar_pradesh_most_15_zoom17',
   #  'uttar_pradesh_most_15_zoom18',
    'west_bengal_most_15_planet',
   #  'west_bengal_most_15_swinir',
   #  'west_bengal_most_15_zoom17',
   #  'west_bengal_most_15_zoom18',
   #  'lucknow_kilns_zoom19',
   #  'uttar_pradesh_kilns_zoom19',
   #  'west_bengal_kilns_zoom19'
]

for user_message in user_messages:
    for region, location, type in zip(regions, locations, type):
        print(f'region: {region}, user_message: {user_message}')
        gr_imgs_path = region+'/*'
        gr_labels = f'/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data/{location}/labels'

        predicted_results, target_results, image_path_with_detected_objects, total_execution_time = get_predicted_target_bbox_patch(chat, CONV_VISION, gr_imgs_path, patch_size, overlap, gr_labels, type, user_message)
        new_predicted_results = add_class_confidence(predicted_results)
        new_target_results = modify_class(target_results)
        plot_results(user_message, gr_imgs_path, new_predicted_results, new_target_results, image_path_with_detected_objects, region, True)
        cm, df = calculate_confusion_matrix(new_predicted_results, new_target_results)
        # # print(f'Precision: {cm.precision}, Recall: {cm.recall}, F1 Score: {cm.f1_score}')
        precision, recall, f1_score = calculate_precision_recall(df, new_predicted_results, new_target_results)
        print(f'Precision: {precision}, Recall: {recall}, F1 Score: {f1_score}')
        print(f'Total execution time(s): {total_execution_time}')
        print(df)

        print('---------------------------------------------------------------')

region: /home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data/uttar_pradesh_most_15/images/, user_message: [refer] Where are the factories in the image? Give its oriented bounding box
conf_threshold = 0.25, iou_threshold = 0.1
Precision: 0.1037037037037037, Recall: 0.05303030303030303, F1 Score: 0.07017543859649124
Total execution time(s): 96.04103755950928
            predicted kilns  predicted_bg
true kilns             14.0         250.0
true_bg               121.0           0.0
---------------------------------------------------------------
region: /home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data/west_bengal_most_15/images/, user_message: [refer] Where are the factories in the image? Give its oriented bounding box
conf_threshold = 0.25, iou_threshold = 0.1
Precision: 0.13178294573643412, Recall: 0.08808290155440414, F1 Score: 0.10559006211180125
Total execution time(s): 96.33750653266907
            predicted kilns  predicted_bg
true kilns   

Customise the Plot results according to IoU threshold 

In [11]:
def plot_results_based_on_iou(user_message, gr_imgs_path, new_predicted_results, new_target_results, iou_threshold, image_path_with_detected_objects, region, plot=False):
    if not plot:
        return
    # n = len(image_path_with_detected_objects)
    gr_img_path = sorted(glob(gr_imgs_path))
    n = len(gr_img_path)
    fig, ax = plt.subplots(nrows = n, ncols = 1 ,figsize=(120, 120))
    ax = ax.flatten()
    for i in range(n):
        # img = Image.open(image_path_with_detected_objects[i]).convert('RGB') # predicted image
        img = Image.open(gr_img_path[i]).convert('RGB') # original image
        w, h = img.size
        ax[i].imshow(img) 
        for bbox in new_target_results[i]:
            classvalue, x1, y1, x2, y2, x3, y3, x4, y4 = bbox*w 
            ax[i].plot([x1, x2, x3, x4, x1], [y1, y2, y3, y4, y1], color = 'green')
        for bbox in new_predicted_results[i]:
            classvalue, x1, y1, x2, y2, x3, y3, x4, y4, conf = bbox*w 
            ax[i].plot([x1, x2, x3, x4, x1], [y1, y2, y3, y4, y1], color = 'red')
        ax[i].set_axis_off()

    fig.suptitle(f'IoU:{iou_threshold}_{user_message}')
    region = region.replace('/','_')
    plt.savefig(f'geochat_output_refer_{region}.png')
    plt.close() # will not display the plot

In [66]:
region = 'lucknow_zoom17_prompt1'
# gr_imgs_path = '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/swinir_data/lucknow_airshed_most_15/images/*'
gr_imgs_path = '/home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_high_resolution_zoom17/*'
# gr_imgs_path = '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data/uttar_pradesh_most_15/swinir_images/*'
predicted_labels_path = '/home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_airshed_most_15_zoom17/prompt1/*'
ground_truth_labels_path = '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data/lucknow_airshed_most_15/labels/*'
user_message = '[refer] Where are the brick kilns with chimney in the image? Give its oriented bounding box'
iou_threshold = 0.1

labels = sorted(glob(predicted_labels_path))
new_predicted_results = []
for label in labels:
    boxes = np.loadtxt(label, ndmin = 2)
    new_predicted_results.append(boxes)

labels = sorted(glob(ground_truth_labels_path))
new_target_results = []
for label in labels:
    boxes = np.loadtxt(label, ndmin = 2)
    new_target_results.append(boxes)

new_predicted_results = add_class_confidence(new_predicted_results)
new_target_results = modify_class(new_target_results)

filtered_predictions = []
for targets, predictions in zip(new_target_results, new_predicted_results):
    num_predictions, num_iou_levels = predictions.shape[0], 1
    targets = targets*500
    predictions = predictions*500
    # print(f'total target {len(targets)}, total predictions {len(predictions)}')
    iou = obb_iou(targets[:, 1:9].reshape(-1,4,2), predictions[:, 1:9].reshape(-1,4,2)) # (n x m) n targets and m predictions
    matched_indices = np.where((iou >= iou_threshold)) # gives the indices of matched targets and predictions (2, num of matched predictions)
    if matched_indices[0].shape[0]:
        combined_indices = np.stack(matched_indices, axis=1) # gives the array of (matched target idx, matched predicted idx)
        iou_values = iou[matched_indices][:, None] # gives the iou values of matched targets and predictions
        matches = np.hstack([combined_indices, iou_values]) # horizontal stack of matched target idx, matched predicted idx and iou values

        if matched_indices[0].shape[0] > 1:
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[np.unique(matches[:, 1], return_index=True)[1]] # find the unique targets with which prediction is matched
            matches = matches[np.unique(matches[:, 0], return_index=True)[1]] # find the unique prediction with which target is matched

        print(matches)
        print(len(matches))
        predictions = predictions[matches[:, 1].astype(int)]
        filtered_predictions.append(predictions/500)
    else:
        filtered_predictions.append(np.array([]))

print('are total predictions same as total images',len(filtered_predictions))
plot_results_based_on_iou(user_message, gr_imgs_path, filtered_predictions, new_target_results, iou_threshold, [], region, plot=True)

[[  0.         105.           0.23701341]
 [  1.          94.           0.14710886]
 [  2.          80.           0.15956379]
 [  4.          72.           0.23209519]
 [  5.          61.           0.21893226]
 [  6.          59.           0.1458968 ]
 [  8.          46.           0.16139825]
 [ 10.          29.           0.21680819]
 [ 11.           2.           0.37664996]]
9
[[ 0.         97.          0.12260881]
 [ 2.         90.          0.21227805]
 [ 3.         87.          0.16453006]
 [ 4.         86.          0.24101132]
 [ 5.         85.          0.27092224]
 [ 8.         32.          0.37253849]
 [ 9.         28.          0.15158679]]
7
[[ 0.         76.          0.20611791]
 [ 1.         46.          0.41013578]
 [ 2.         32.          0.70055344]
 [ 4.         22.          0.15135311]
 [ 6.         14.          0.34886974]
 [ 7.         19.          0.23136314]]
6
[[ 1.         78.          0.19502506]
 [ 2.         91.          0.10748007]
 [ 5.         68.          0