In [1]:
import os
# Index of the physical GPU this notebook may use. Named GPU_ID (not `id`) so the
# builtin id() is not shadowed for every later cell. The variable is exported
# before `import torch` below so the selection is in place when torch is loaded.
GPU_ID = 1
os.environ['CUDA_VISIBLE_DEVICES'] = f'{GPU_ID}'

from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
from PIL import Image, ImageEnhance
import requests
import torch
import numpy as np
import re
from glob import glob
import os
from time import time
import pandas as pd
import matplotlib.pyplot as plt
# from my_metrics.metrics import ConfusionMatrix as MyConfusionMatrix
from garuda.od import ConfusionMatrix




from torch.utils.data import Dataset
from transformers import PaliGemmaProcessor
import supervision as sv
import json

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Load the model

In [2]:
# Checkpoint used for the tiled zero-shot detection run further down.
model_id = "google/paligemma-3b-pt-448"
# model_id = 'google/paligemma-3b-ft-rsvqa-hr-448'

# Set to True to load the weights through bitsandbytes in 8-bit.
quantize_model = False

if not quantize_model:
    # Full-precision path: load the checkpoint, then move it to DEVICE.
    model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to(DEVICE)
else:
    # Quantised path: no explicit .to(DEVICE) is issued here, only .eval().
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        quantization_config=quantization_config,
    ).eval()

# The processor bundles the tokenizer and image preprocessing for the same checkpoint.
processor = AutoProcessor.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
def parse_bbox_and_labels(detokenized_output: str):
  """Extract detection boxes and labels from decoded PaliGemma text.

  Each detection is expected as four ``<locNNNN>`` tokens (y0, x0, y1, x1)
  followed by a space and the label; several detections are separated by
  `` ;``. Text that does not match this pattern (for example an echoed
  prompt) is ignored.

  Args:
    detokenized_output: Decoded model output.

  Returns:
    A tuple ``(boxes, labels)``. ``boxes`` is a float array of shape (N, 4)
    holding ``[y0, x0, y1, x1]`` with every token value divided by 1024;
    ``labels`` is an array of the N label strings. With no detections
    ``boxes`` has shape (0, 4) and ``labels`` is empty.
  """
  # Raw strings keep `\d` a regex escape instead of an invalid string escape.
  matches = re.finditer(
      r'<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      r' (?P<label>.+?)( ;|$)',
      detokenized_output,
  )
  labels, boxes = [], []
  fmt = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
    labels.append(d['label'])
  # reshape keeps a consistent (N, 4) layout even when nothing was detected.
  return np.array(boxes, dtype=float).reshape(-1, 4), np.array(labels)

In [4]:
def plot_results(user_message, gr_imgs_path, new_predicted_results, new_target_results, region, plot=False):
    """Draw ground-truth (green) and predicted (red) quadrilaterals on every image.

    One subplot row is created per image matched by ``gr_imgs_path`` (sorted),
    and the figure is saved as ``paligemma_output_refer_<region>.png`` with
    ``/`` in ``region`` replaced by ``_``. Nothing is displayed.

    Args:
        user_message: Text used as the figure title.
        gr_imgs_path: Glob pattern selecting the images.
        new_predicted_results: Per-image arrays with rows
            ``[class, x1, y1, x2, y2, x3, y3, x4, y4, conf]``.
        new_target_results: Per-image arrays with rows
            ``[class, x1, y1, x2, y2, x3, y3, x4, y4]``.
        region: Region identifier used to build the output file name.
        plot: When False (default) the function returns immediately.

    Returns:
        None.
    """
    if not plot:
        return
    gr_img_path = sorted(glob(gr_imgs_path))
    n = len(gr_img_path)
    if n == 0:
        # plt.subplots() rejects zero rows; there is nothing to draw anyway.
        return
    # squeeze=False always yields a 2-D axes array, so flatten() also works for n == 1.
    fig, ax = plt.subplots(nrows=n, ncols=1, figsize=(120, 120), squeeze=False)
    ax = ax.flatten()
    for i in range(n):
        img = Image.open(gr_img_path[i]).convert('RGB') # image
        w, h = img.size
        ax[i].imshow(img)
        # Coordinates are treated as fractions of the image size (the original code
        # multiplied them by the width); x is scaled by width and y by height so
        # non-square images are drawn correctly. Class/confidence are not scaled.
        for bbox in new_target_results[i]:
            xs = bbox[1:9:2] * w
            ys = bbox[2:9:2] * h
            ax[i].plot([*xs, xs[0]], [*ys, ys[0]], color = 'green')
        for bbox in new_predicted_results[i]:
            xs = bbox[1:9:2] * w
            ys = bbox[2:9:2] * h
            ax[i].plot([*xs, xs[0]], [*ys, ys[0]], color = 'red')
        ax[i].set_axis_off()

    fig.suptitle(f'{user_message}')
    region = region.replace('/','_')
    fig.savefig(f'paligemma_output_refer_{region}.png')
    plt.close(fig) # will not display the plot

In [5]:
def add_class_confidence(predicted_results):
    """Convert per-image predictions to ``[class, 8 corner coords, confidence]`` rows.

    Two input layouts are accepted per image:

    * 4 columns ``[ymin, xmin, ymax, xmax]`` (raw PaliGemma box order): expanded
      to the corners (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin).
    * 8 columns ``[x1, y1, x2, y2, x3, y3, x4, y4]`` (what
      ``get_predicted_target_box_patch`` emits): used as is. Previously these
      were indexed as if they were 4-column boxes, which scrambled the corners.

    Every detection gets class 0 and confidence 1. An image with no detections
    gets a single all-zero row with class 1 ("no detection") and confidence 1.

    Args:
        predicted_results: List of per-image arrays (possibly empty).

    Returns:
        List of float32 arrays of shape (N, 10), one per image.

    Raises:
        ValueError: If a non-empty array has neither 4 nor 8 columns.
    """
    new_predicted_results = []
    for res in predicted_results:
        res = np.asarray(res, dtype=np.float32)
        if res.size == 0:
            placeholder = np.zeros((1, 10), dtype=np.float32)
            placeholder[:, 0] = 1 # class label 1 for no detection
            placeholder[:, -1] = 1
            new_predicted_results.append(placeholder)
            continue

        res = np.atleast_2d(res)
        n_cols = res.shape[1]
        if n_cols == 4:
            corners = res[:, [1, 0, 1, 2, 3, 2, 3, 0]]
        elif n_cols == 8:
            corners = res
        else:
            raise ValueError(f'expected 4 or 8 coordinates per box, got {n_cols}')

        n_boxes = len(res)
        class_col = np.zeros((n_boxes, 1), dtype=np.float32)  # class label 0
        conf_col = np.ones((n_boxes, 1), dtype=np.float32)    # confidence score 1
        new_predicted_results.append(np.hstack([class_col, corners, conf_col]).astype(np.float32))
    return new_predicted_results

In [6]:
def modify_class(target_results):
    """Replace each ground-truth quadrilateral by its axis-aligned bounding quad.

    Input rows are ``[class, x1, y1, x2, y2, x3, y3, x4, y4]``. The class is set
    to 0 and the corners are rewritten as
    ``(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)``.

    The input arrays are left untouched (the previous version overwrote their
    class column in place). An empty label array yields an empty (0, 9) array.

    Args:
        target_results: List of per-image label arrays.

    Returns:
        List of float32 arrays of shape (N, 9), one per image.
    """
    new_target_results = []
    for res in target_results:
        res = np.asarray(res, dtype=np.float32)
        if res.size == 0:
            # Image without labels: nothing to convert.
            new_target_results.append(np.zeros((0, 9), dtype=np.float32))
            continue
        xs = res[:, [1, 3, 5, 7]]
        ys = res[:, [2, 4, 6, 8]]
        xmin, xmax = xs.min(axis=1), xs.max(axis=1)
        ymin, ymax = ys.min(axis=1), ys.max(axis=1)
        class_col = np.zeros(len(res), dtype=np.float32)  # all class labels become 0
        quad = np.vstack([class_col, xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]).T
        new_target_results.append(quad.astype(np.float32))
    return new_target_results

In [7]:
def calculate_confusion_matrix(new_predicted_results, new_target_results):
    """Build the kiln-vs-background confusion matrix with garuda's ConfusionMatrix.

    Coordinates (columns 1..8) of both predictions and targets are multiplied by
    500 before being handed to ``ConfusionMatrix.from_obb_tensors``. The scaling
    is done on copies, so calling this function twice on the same inputs gives
    the same result (the previous version scaled the caller's arrays in place).

    Args:
        new_predicted_results: Per-image arrays with rows
            ``[class, 8 coords, confidence]``.
        new_target_results: Per-image arrays with rows ``[class, 8 coords]``.

    Returns:
        ``(cm, df)``: the ConfusionMatrix object and a 2x2 DataFrame of
        ``cm.matrix`` labelled true/predicted kilns vs background.
    """
    # NOTE(review): presumably maps normalised coordinates onto a 500-unit frame
    # for the IoU computation; confirm the intended value.
    coord_scale = 500

    cm_predicted_results = []
    for res in new_predicted_results:
        scaled = np.array(res)  # copy; keeps the caller's array unchanged
        scaled[:, 1:9] = scaled[:, 1:9] * coord_scale
        cm_predicted_results.append(scaled)

    cm_target_results = []
    for res in new_target_results:
        scaled = np.array(res)  # copy; keeps the caller's array unchanged
        scaled[:, 0] = 0 # convert class labels to 0
        scaled[:, 1:9] = scaled[:, 1:9] * coord_scale
        cm_target_results.append(scaled.astype(np.float32))

    classes, conf_threshold, iou_threshold = ['brick_kilns'], 0.25, 0.1
    cm = ConfusionMatrix.from_obb_tensors(cm_predicted_results, cm_target_results, classes, conf_threshold, iou_threshold)
    # cm = MyConfusionMatrix.from_tensors(cm_predicted_results, cm_target_results, classes, conf_threshold, iou_threshold)
    df = pd.DataFrame(cm.matrix, columns = ['predicted kilns','predicted_bg'], index=['true kilns','true_bg'])
    print(f'conf_threshold = {conf_threshold}, iou_threshold = {iou_threshold}')
    # print(cm.summary)
    # print(df.to_markdown())
    return cm, df

In [8]:
def calculate_precision_recall(confusion_matrix, new_predicted_results, new_target_results):
    """Compute precision, recall and F1 for the kiln class.

    True positives are read from the confusion-matrix DataFrame; the number of
    predicted positives and ground-truth boxes is counted from the arrays.

    Args:
        confusion_matrix: DataFrame with index label ``'true kilns'`` and column
            ``'predicted kilns'``.
        new_predicted_results: Per-image arrays whose column 0 is the class
            (1 marks the "no detection" placeholder row and is not counted).
        new_target_results: Per-image arrays whose column 0 is the class
            (0 marks a brick kiln).

    Returns:
        ``(precision, recall, f1_score)``. Recall is 0.0 when there are no
        ground-truth boxes (instead of dividing by zero).
    """
    tp = confusion_matrix.loc['true kilns', 'predicted kilns']

    # class = 1 means background class
    predicted_positives = sum(int((res[:, 0] != 1).sum()) for res in new_predicted_results)
    # class = 0 means brick kiln class
    ground_truth = sum(int((res[:, 0] == 0).sum()) for res in new_target_results)

    precision = tp / (predicted_positives + 1e-9)
    recall = tp / ground_truth if ground_truth else 0.0
    f1_score = 2 * (precision * recall) / (precision + recall + 1e-9)

    return precision, recall, f1_score

In [18]:
def patch_boxes(image_height , image_width, patch_size, overlap):
    """Tile an image into overlapping square patches, row by row.

    Patches advance by ``patch_size - overlap`` pixels. A patch that would run
    past the right or bottom edge is shifted back so that it ends exactly on
    the edge (and is clipped at 0 when the image is smaller than one patch).

    Args:
        image_height: Image height in pixels.
        image_width: Image width in pixels.
        patch_size: Side length of a patch in pixels.
        overlap: Number of pixels shared by neighbouring patches.

    Returns:
        ``(slice_bboxes, offsets)`` where ``slice_bboxes`` is a list of
        ``[xmin, ymin, xmax, ymax]`` and ``offsets`` the matching
        ``[xmin, ymin]`` top-left corners.

    Raises:
        ValueError: If the tiling cannot advance (``overlap >= patch_size``)
            although the image is larger than a single patch; this case
            previously looped forever.
    """
    needs_stepping = image_height > patch_size or image_width > patch_size
    if patch_size - overlap <= 0 and needs_stepping:
        raise ValueError(
            f'overlap ({overlap}) must be smaller than patch_size ({patch_size})'
        )

    slice_bboxes = []
    offsets = []
    y_max = y_min = 0

    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + patch_size
        while x_max < image_width:
            x_max = x_min + patch_size
            # Clamp to the image; for an interior patch this is a no-op because
            # xmax - patch_size == x_min and ymax - patch_size == y_min.
            xmax = min(image_width, x_max)
            ymax = min(image_height, y_max)
            xmin = max(0, xmax - patch_size)
            ymin = max(0, ymax - patch_size)
            slice_bboxes.append([xmin, ymin, xmax, ymax])
            offsets.append([xmin, ymin])
            x_min = x_max - overlap
        y_min = y_max - overlap
    return slice_bboxes, offsets

def get_predicted_target_box_patch(gr_imgs_path, gr_labels, type, prompt, patch_size, overlap):
    """Run tiled PaliGemma detection on every image and load the matching labels.

    Each image matched by ``gr_imgs_path`` is split with ``patch_boxes``; every
    patch is sent through the module-level ``processor`` / ``model`` with
    ``prompt``; parsed boxes are mapped back to full-image coordinates and
    normalised by the image width (x) and height (y). Per-image predictions are
    also written to ``<type>/prompt<idx>/<image stem>.txt``.

    Args:
        gr_imgs_path: Glob pattern selecting the images.
        gr_labels: Directory holding one ``<image stem>.txt`` label file per image.
        type: Name of the output directory for the saved predictions.
        prompt: Detection prompt; after removing ``<image>`` and surrounding
            whitespace it must be one of ``detect chimney``,
            ``detect brick kilns with chimney`` or ``detect factories``.
        patch_size: Patch side length in pixels.
        overlap: Overlap between neighbouring patches in pixels.

    Returns:
        ``(predicted_results, target_results, total_execution_time)`` where
        ``predicted_results`` holds one (N, 8) array per image with rows
        ``[xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]``, ``target_results``
        the arrays loaded from the label files, and the time is in seconds.

    Raises:
        ValueError: If ``prompt`` is not one of the supported prompts.
    """
    # Resolve the output sub-directory up front. The prompt is normalised so that
    # "<image>detect ..." and "<image> detect ..." (with or without trailing
    # whitespace) map to the same index; the previous exact-string comparison
    # left `idx` undefined for the prompts defined in this notebook.
    prompt_indices = {
        'detect chimney': 1,
        'detect brick kilns with chimney': 2,
        'detect factories': 3,
    }
    prompt_key = prompt.replace('<image>', '').strip()
    if prompt_key not in prompt_indices:
        raise ValueError(f'Unsupported prompt: {prompt!r}')
    idx = prompt_indices[prompt_key]
    output_dir = os.path.join(type, f'prompt{idx}')
    os.makedirs(output_dir, exist_ok=True)  # np.savetxt does not create directories

    target_results = []
    predicted_results = []

    start_time = time()
    for img_path in sorted(glob(gr_imgs_path)):
        gr_img = Image.open(img_path).convert('RGB')
        img_width, img_height = gr_img.size  # PIL reports (width, height)
        patches, offsets = patch_boxes(img_height, img_width, patch_size, overlap)
        tmp = []
        for patch_box, offset in zip(patches, offsets):
            patch = gr_img.crop(patch_box)
            # A patch is smaller than patch_size when the image itself is smaller.
            patch_width, patch_height = patch.size
            # preprocess the prompt and the image, then move the tensors to the model's device
            model_inputs = processor(text=prompt, images=patch, return_tensors="pt").to(model.device)

            with torch.inference_mode():
                generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
            # The full sequence (prompt included) is decoded; the parser only
            # picks up the <loc...> groups.
            decoded = processor.decode(generation[0], skip_special_tokens=True)
            print(f'decoded {decoded}') # print the generated output
            boxes, labels = parse_bbox_and_labels(decoded)
            for ymin, xmin, ymax, xmax in boxes: # predicted boxes are in ymin, xmin, ymax, xmax format.
                # patch-relative fraction -> absolute pixel -> fraction of the full image
                ymin = (ymin * patch_height + offset[1]) / img_height
                xmin = (xmin * patch_width + offset[0]) / img_width
                ymax = (ymax * patch_height + offset[1]) / img_height
                xmax = (xmax * patch_width + offset[0]) / img_width
                tmp.append([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])

        img_name = os.path.basename(img_path)
        img_stem = os.path.splitext(img_name)[0]
        image_predictions = np.array(tmp, dtype=float).reshape(-1, 8)
        np.savetxt(os.path.join(output_dir, f'{img_stem}.txt'), image_predictions)
        predicted_results.append(image_predictions)

        target_path = os.path.join(gr_labels, f'{img_stem}.txt')
        target_results.append(np.loadtxt(target_path, ndmin=2))

    total_execution_time = time() - start_time

    return predicted_results, target_results, total_execution_time


In [None]:
# Candidate detection prompts; each starts with the <image> placeholder token.
prompt1, prompt2, prompt3 = (
    "<image>detect chimney \n",
    '<image>detect brick kilns with chimney',
    '<image>detect factories',
)

# Prompts that are actually evaluated by the run cell.
# Append prompt2 and/or prompt3 here to evaluate them as well.
prompts = [prompt1]

In [20]:
# Root of the shared VLM dataset; only referenced by the (currently commented-out)
# f-string entries in `regions` below.
base_path = '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data'

# `regions`, `locations` and `types` are parallel lists: the run cell zips them,
# so entry i of each list describes the same experiment. Enable or disable an
# experiment by (un)commenting the entry at the same position in all three lists;
# zip() silently stops at the shortest list.

# Image directories to evaluate.
regions = [
            # f'{base_path}/lucknow_airshed_most_15/images', 
         #   '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/swinir_data/lucknow_airshed_most_15/images',
           '/home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_high_resolution_zoom17',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_high_resolution_zoom18',
            # '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data/uttar_pradesh_most_15/images/',
        #     f'{base_path}/uttar_pradesh_most_15/swinir_images',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/uttar_pradesh_high_resolution_zoom17',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/uttar_pradesh_high_resolution_zoom18',
         #   '/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data/west_bengal_most_15/images/',
        #    f'{base_path}/west_bengal_most_15/swinir_images',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/west_bengal_high_resolution_zoom17',
        #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/west_bengal_high_resolution_zoom18'
        # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_kilns_zoom19',
        # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/uttar_pradesh_kilns_zoom19',
        # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/west_bengal_kilns_zoom19'
        ]

# Dataset folder name used by the run cell to build the ground-truth label directory.
locations = [
            #  'lucknow_airshed_most_15', 
             'lucknow_airshed_most_15', 
            #  'lucknow_airshed_most_15', 
            #  'lucknow_airshed_most_15',
            #  'uttar_pradesh_most_15',
            #  'uttar_pradesh_most_15',
            #  'uttar_pradesh_most_15', 
            #  'uttar_pradesh_most_15', 
            #  'west_bengal_most_15',
            #  'west_bengal_most_15', 
            #  'west_bengal_most_15', 
            #  'west_bengal_most_15',
             ]
        
# Experiment tag; used as the output directory name for the saved prediction files.
types = [
   #  'lucknow_airshed_most_15_planet',
   #  'lucknow_airshed_most_15_swinir',
    'lucknow_airshed_most_15_zoom17',
   #  'lucknow_airshed_most_15_zoom18',
   #  'uttar_pradesh_most_15_planet',
   #  'uttar_pradesh_most_15_swinir',
   #  'uttar_pradesh_most_15_zoom17',
   #  'uttar_pradesh_most_15_zoom18',
   #  'west_bengal_most_15_planet',
   #  'west_bengal_most_15_swinir',
   #  'west_bengal_most_15_zoom17',
   #  'west_bengal_most_15_zoom18',
   #  'lucknow_kilns_zoom19',
   #  'uttar_pradesh_kilns_zoom19',
   #  'west_bengal_kilns_zoom19'
]

In [21]:
# Tiling parameters for the patch-wise inference.
patch_size = 320
overlap = 20 # 640x640
# overlap = 75 # 2560x2560

# Evaluate every prompt on every configured (region, location, type) experiment.
# The loop variable is called `run_type` so the builtin `type` is not shadowed.
for prompt in prompts:
    for region, location, run_type in zip(regions, locations, types):
        print(f'user_message: {prompt}, type {run_type} region: {region}, ')
        gr_imgs_path = region+'/*'
        gr_labels = f'/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/data/vlm_data/{location}/labels'
        predicted_results, target_results, total_execution_time = get_predicted_target_box_patch(gr_imgs_path, gr_labels, run_type, prompt, patch_size, overlap)
        print(f'Total execution time(s): {total_execution_time}')
        # There is one entry per processed image, so an empty list means the glob
        # matched no images. Check this before indexing element 0 below.
        if len(predicted_results) == 0:
            print('No detections found')
        else:
            new_predicted_results = add_class_confidence(predicted_results)
            print(f'new predicted results {new_predicted_results[0]}')
            new_target_results = modify_class(target_results)
            print(f'new target results {new_target_results[0]}')
            plot_results(prompt, gr_imgs_path, new_predicted_results, new_target_results, region, True)
            cm, df = calculate_confusion_matrix(new_predicted_results, new_target_results)
            # print(f'Precision: {cm.precision}, Recall: {cm.recall}, F1 Score: {cm.f1_score}')
            precision, recall, f1_score = calculate_precision_recall(df, new_predicted_results, new_target_results)
            print(f'Precision: {precision}, Recall: {recall}, F1 Score: {f1_score}')
            print(df)
            
        print('---------------------------------------------------------------')

user_message: <image> detect chimney 
, type lucknow_airshed_most_15_zoom17 region: /home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_high_resolution_zoom17, 
decoded  detect chimney 


decoded  detect chimney 


decoded  detect chimney 


decoded  detect chimney 


decoded  detect chimney 


decoded  detect chimney 

 chimney
decoded  detect chimney 

 chimney
decoded  detect chimney 

<loc0000><loc0000><loc1023><loc1023> chimney
decoded  detect chimney 

<loc0000><loc0000><loc1023><loc1023> chimney
decoded  detect chimney 

 chimney
decoded  detect chimney 

<loc0000><loc0000><loc1023><loc1023> chimney
decoded  detect chimney 

<loc0000><loc0000><loc1023><loc1023> chimney
decoded  detect chimney 


decoded  detect chimney 

10
decoded  detect chimney 


decoded  detect chimney 


decoded  detect chimney 

 chimney
decoded  detect chimney 

 chimney
decoded  detect chimney 


decoded  detect chimney 


decoded  detect chimney 

<loc0000><loc0000><loc1023><loc1023> chimney
decoded

Contrast Changes

In [13]:
# import cv2
# from glob import glob

# fig, ax = plt.subplots(1,2, squeeze=False)
# image_paths  = glob(regions[0]+'/*')
# for image_path in image_paths[:5]:
#     image = cv2.imread(image_path)
#     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#     ax[0, 0].imshow(image)
#     # Adjusts the brightness by adding 10 to each pixel value 
#     brightness = 0 
#     # # Adjusts the contrast by scaling the pixel values by 2.3 
#     contrast = 1.5
#     image2 = cv2.addWeighted(image, contrast, np.zeros(image.shape, image.dtype), 0, brightness) 
#     # Create the sharpening kernel 
#     # kernel = np.array([[0, -1, 0], [-1, 5., -1], [0, -1, 0]]) 
    
#     # Sharpen the image 
#     # image2 = cv2.filter2D(image, -1, kernel) 
#     ax[0][1].imshow(image2)

In [14]:
# from PIL import ImageEnhance
# fig, ax = plt.subplots(1,2, squeeze=False)
# image_paths  = glob(regions[0]+'/*')
# for image_path in image_paths[:1]:
#     image = Image.open(image_path).convert('RGB')
#     # im3 = ImageEnhance.Brightness(image) 
#     # # showing resultant image 
#     # im3 = im3.enhance(1.5)
#     # im3 = ImageEnhance.Sharpness(image) 
#     # # showing resultant image 
#     # im3 = im3.enhance(5.5)
#     im3 = ImageEnhance.Contrast(image) 
#     # showing resultant image 
#     im3 = im3.enhance(2)
#     ax[0 ,0].imshow(image)
#     ax[0, 1].imshow(im3)
# plt.axis('off')

In [8]:
# Smoke test of the PaliGemma 2 pre-trained checkpoint on a sample image.
# NOTE: this cell rebinds the module-level names `model`, `processor`, `model_id`,
# `image` and `prompt`, so later cells use whatever was assigned last.
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image
import torch

model_id = "google/paligemma2-3b-pt-448"

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

# Load the weights in bfloat16 and let `device_map="auto"` choose the placement.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto").eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)

# Leaving the prompt blank for pre-trained models
# (the processor then warns that no `<image>` token is present and adds it itself).
prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
# Number of prompt tokens; used below to keep only the newly generated tokens.
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.


a teal car parked in front of a yellow wall


Testing

In [41]:
# Single-image detection test: download an image, save it locally and ask the
# currently loaded model to detect a zebra.
# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
url = "https://steemitimages.com/DQmVrNBAUyUifreYwn2bJ5DTyb4xt4PuVwiJP1J164HLnC6/SAM_5060.JPG"
# url = 'https://media.istockphoto.com/id/1250380933/photo/top-down-aerial-view-of-chicago-downtown-urban-grid-with-park.jpg?s=2048x2048&w=is&k=20&c=mi6DZWHjZVWbyRaPS5z-ZkH910ka6JMZcPbult9tEAw='
# A timeout stops the cell from hanging on an unresponsive host, and
# raise_for_status() surfaces HTTP errors instead of handing an error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)
image.save('/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/ztemp/images/zebra.jpg')
print(image.size)
prompt = "<image>detect zebra"
# prompt = "<image>answer en how many trees are there in the image?"
# prompt = "<image>describe the image?"
# Cast the inputs to the model's dtype before moving them to the device.
TORCH_DTYPE = model.dtype
model_inputs = processor(text=prompt, images=image, return_tensors = 'pt').to(TORCH_DTYPE).to(DEVICE)
# Number of prompt tokens; used below to keep only the newly generated tokens.
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=5, do_sample=False)
    generation = generation[:, input_len:]
    decoded = processor.batch_decode(generation, skip_special_tokens=True)
    print(decoded)

(5472, 3648)
['<loc0652><loc0063><loc0858><loc0398> zebra']


In [None]:
# Draw the zebra detection printed by the previous cell on top of the saved image.
import numpy as np
fig,ax = plt.subplots()
img = Image.open('/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/ztemp/images/zebra.jpg')
w,h = img.size
ax.imshow(img)
# <locNNNN> values index a 1024-bin grid (parse_bbox_and_labels divides by 1024
# as well), so normalise by 1024 rather than 1000. Order is ymin, xmin, ymax, xmax.
ymin, xmin , ymax, xmax = np.array([652, 63, 858,398])/1024
ymin = ymin * h
xmin = xmin * w
ymax = ymax * h
xmax = xmax * w
ax.plot([xmin, xmax, xmax, xmin, xmin], [ymin, ymin, ymax, ymax, ymin], color='red')
ax.set_axis_off()


In [39]:
# Single-image detection test on one kiln test image.
from glob import glob
# NOTE(review): glob() returns paths in arbitrary order, so index [1] is not
# guaranteed to pick the same image on every machine or run.
kiln_img = Image.open(glob('/home/shataxi.dubey/shataxi_work/vlm_on_planet/gms/lucknow_small_600_sq_km/kiln_images/test/images/*')[1])
prompt = "<image>detect brick kiln with chimney"
# Cast the inputs to the model's dtype before moving them to the device.
TORCH_DTYPE = model.dtype
model_inputs = processor(text=prompt, images=kiln_img, return_tensors = 'pt').to(TORCH_DTYPE).to(DEVICE)
# Number of prompt tokens; used below to keep only the newly generated tokens.
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=10, do_sample=False)
    generation = generation[:, input_len:]
    decoded = processor.batch_decode(generation, skip_special_tokens=True)
    print(decoded)

['<loc0000><loc0450><loc1023><loc0643> brick kiln with chimney']


In [None]:
# Draw a kiln detection box on top of `kiln_img`.
import numpy as np
fig,ax = plt.subplots()
w,h = kiln_img.size
ax.imshow(kiln_img)
# <locNNNN> values index a 1024-bin grid (parse_bbox_and_labels divides by 1024
# as well), so normalise by 1024 rather than 1000. Order is ymin, xmin, ymax, xmax.
# NOTE(review): the hard-coded xmax (636) differs from the <loc0643> printed by
# the detection cell above; confirm which run these values were copied from.
ymin, xmin , ymax, xmax = np.array([0, 450, 1023, 636])/1024
ymin = ymin * h
xmin = xmin * w
ymax = ymax * h
xmax = xmax * w
ax.plot([xmin, xmax, xmax, xmin, xmin], [ymin, ymin, ymax, ymax, ymin], color='red')


In [13]:
# Inspect the token ids of the first sequence in the most recent `generation`
# (whichever cell assigned it last).
print(generation[0])

tensor([257152, 257152, 257152,  ..., 257023,   8195,      1], device='cuda:1')


#### Zero-shot performance

In [17]:
class JSONLDataset(Dataset):
    """Dataset backed by a JSON Lines annotation file and an image directory.

    Every non-blank line of the file is parsed as one JSON object; its
    ``'image'`` value is the file name of the image inside
    ``image_directory_path``. All entries are read eagerly on construction.
    """

    def __init__(self, jsonl_file_path: str, image_directory_path: str):
        """Store the paths and load every annotation entry.

        Args:
            jsonl_file_path: Path of the JSON Lines annotation file.
            image_directory_path: Directory containing the referenced images.
        """
        self.jsonl_file_path = jsonl_file_path
        self.image_directory_path = image_directory_path
        self.entries = self._load_entries()

    def _load_entries(self):
        """Parse the annotation file and return its entries as a list.

        Blank lines (for example trailing empty lines) are skipped instead of
        raising a JSON decode error.
        """
        entries = []
        with open(self.jsonl_file_path, 'r', encoding='utf-8') as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                entries.append(json.loads(line))
        return entries

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.entries)

    def __getitem__(self, idx: int):
        """Return ``(image, entry)`` for the entry at position ``idx``.

        Raises:
            IndexError: If ``idx`` is negative or not smaller than ``len(self)``.
        """
        if idx < 0 or idx >= len(self.entries):
            raise IndexError("Index out of range")

        entry = self.entries[idx]
        image_path = os.path.join(self.image_directory_path, entry['image'])
        image = Image.open(image_path)
        return image, entry
    

# GMS imagery
# Active evaluation set. Exactly one `test_dataset = ...` block should be
# uncommented; the alternatives below are kept for switching datasets.
test_dataset = JSONLDataset(
    jsonl_file_path=f"/home/shataxi.dubey/shataxi_work/vlm_on_planet/gms/west_bengal_small_639_sq_km/kiln_images/test/paligemma2_annotations.jsonl",
    image_directory_path=f"/home/shataxi.dubey/shataxi_work/vlm_on_planet/gms/west_bengal_small_639_sq_km/kiln_images/test/images",
)

# test_dataset = JSONLDataset(
#     jsonl_file_path=f"/home/shataxi.dubey/shataxi_work/vlm_on_planet/gms/lucknow_small_600_sq_km/kiln_images/test/paligemma2_annotations.jsonl",
#     image_directory_path=f"/home/shataxi.dubey/shataxi_work/vlm_on_planet/gms/lucknow_small_600_sq_km/kiln_images/test/images",
# )

# Planet Imagery
# test_dataset = JSONLDataset(
#     jsonl_file_path=f"/home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_train_test_split/test/paligemma2_annotations.jsonl",
#     image_directory_path=f"/home/shataxi.dubey/shataxi_work/vlm_on_planet/lucknow_train_test_split/test/images",
# )

# test_dataset = JSONLDataset(
#     jsonl_file_path=f"/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/ztemp/p.jsonl",
#     image_directory_path=f"/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/ztemp/images",
# )

# Class names are derived from the first entry's prefix: drop "detect " and split
# on " ; ". This assumes every entry shares that prefix - TODO confirm.
CLASSES = test_dataset[0][1]['prefix'].replace("detect ", "").split(" ; ")
CLASSES

['brick kilns with chimney']

In [18]:
# Display every parsed annotation entry (image name, prefix, suffix).
test_dataset.entries

[{'image': '9789942.9117_2522839.3509.png',
  'prefix': 'detect brick kilns with chimney',
  'suffix': '<loc0010><loc0161><loc0150><loc0401> brick kilns with chimney'},
 {'image': '9792083.1485_2521616.3584.png',
  'prefix': 'detect brick kilns with chimney',
  'suffix': '<loc0632><loc0537><loc0918><loc0854> brick kilns with chimney'},
 {'image': '9786579.6824_2527119.8245.png',
  'prefix': 'detect brick kilns with chimney',
  'suffix': '<loc0303><loc0644><loc0552><loc0938> brick kilns with chimney'},
 {'image': '9792694.6447_2524368.0915.png',
  'prefix': 'detect brick kilns with chimney',
  'suffix': '<loc0197><loc0015><loc0526><loc0441> brick kilns with chimney'},
 {'image': '9785662.4381_2527425.5726.png',
  'prefix': 'detect brick kilns with chimney',
  'suffix': '<loc0306><loc0754><loc0472><loc1020> brick kilns with chimney'},
 {'image': '9788719.9192_2547604.9481.png',
  'prefix': 'detect brick kilns with chimney',
  'suffix': '<loc0689><loc0260><loc0835><loc0582> brick kilns wi

In [4]:
# Load the PaliGemma 2 checkpoint used for the zero-shot evaluation below.
# NOTE: this rebinds the module-level `processor` and `model` names.
MODEL_ID ="google/paligemma2-3b-pt-448"
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)
model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID, device_map="auto")

# dtype the processor outputs are cast to before being passed to the model.
TORCH_DTYPE = model.dtype

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [19]:
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

def collate_test_fn(batch):
    """Collate ``(image, entry)`` samples into model-ready inputs.

    Each entry's ``prefix`` is prepended with the ``<image>`` token and encoded
    together with the images by the module-level ``processor``; the encoded
    batch is cast to ``TORCH_DTYPE`` and moved to ``DEVICE``.

    Args:
        batch: Sequence of ``(image, entry)`` pairs, where ``entry`` provides
            the ``"prefix"`` and ``"suffix"`` keys.

    Returns:
        ``(images, inputs, suffixes)``: the tuple of images, the processor
        output, and the list of ground-truth suffix strings.
    """
    images, labels = zip(*batch)

    prefixes = []
    suffixes = []
    for entry in labels:
        prefixes.append("<image>" + entry["prefix"])
        suffixes.append(entry["suffix"])

    encoded = processor(
        text=prefixes,
        images=images,
        return_tensors="pt",
        padding="longest",
    )
    inputs = encoded.to(TORCH_DTYPE).to(DEVICE)

    return images, inputs, suffixes

# Zero-shot evaluation: generate detections for every test image and collect
# predictions, ground-truth targets and the images themselves.
test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn= collate_test_fn, shuffle=False)

# Accumulators filled by the loop below, one entry per image.
images = []
targets = []
predictions = []

with torch.inference_mode():
    for imgs, test_inputs, suffixes in tqdm(test_dataloader):
        # print(test_inputs['input_ids'])
        # Number of prompt tokens; used to keep only the newly generated tokens.
        prefix_length = test_inputs["input_ids"].shape[-1]

        generation = model.generate(**test_inputs, max_new_tokens=256, do_sample=False)
        generation = generation[:, prefix_length:]
        generated_texts = processor.batch_decode(generation, skip_special_tokens=True)
        # Only the first image's size is used for the whole batch; this is
        # consistent with batch_size=1 set on the DataLoader above.
        w, h = imgs[0].size
        for generated_text in generated_texts:
            print(generated_text)
            prediction = sv.Detections.from_vlm(
                vlm='paligemma',
                result=generated_text,
                resolution_wh=(w, h),
                )
            
            # CLASSES.index raises ValueError if a generated class name is not in CLASSES.
            prediction.class_id = np.array([CLASSES.index(class_name) for class_name in prediction['class_name']])
            # Every prediction is assigned a confidence of 1.
            prediction.confidence = np.ones(len(prediction))
            predictions.append(prediction)

        # Ground truth is parsed from the annotation suffix with the same parser.
        for suffix in suffixes:
            target = sv.Detections.from_vlm(
                vlm='paligemma',
                result=suffix,
                resolution_wh=(w, h),
                )

            target.class_id = np.array([CLASSES.index(class_name) for class_name in target['class_name']])
            targets.append(target)
        images += list(imgs)

  4%|▎         | 1/27 [00:00<00:14,  1.86it/s]

<loc0000><loc0149><loc0189><loc0416> brick kilns with chimney


  7%|▋         | 2/27 [00:01<00:13,  1.91it/s]

<loc0000><loc0000><loc1023><loc1020> brick kilns with chimney


 11%|█         | 3/27 [00:01<00:12,  1.92it/s]

<loc0000><loc0615><loc0442><loc1023> brick kilns with chimney


 15%|█▍        | 4/27 [00:02<00:11,  1.93it/s]

<loc0200><loc0128><loc0347><loc0235> brick kilns with chimney


 19%|█▊        | 5/27 [00:02<00:11,  1.93it/s]

<loc0000><loc0652><loc0057><loc0681> brick kilns with chimney


 22%|██▏       | 6/27 [00:03<00:10,  1.93it/s]

<loc0311><loc0035><loc1023><loc0867> brick kilns with chimney


 26%|██▌       | 7/27 [00:03<00:10,  1.94it/s]

<loc0000><loc0000><loc1023><loc1020> brick kilns with chimney


 30%|██▉       | 8/27 [00:04<00:09,  1.94it/s]

<loc0026><loc0193><loc0109><loc0291> brick kilns with chimney


 33%|███▎      | 9/27 [00:04<00:09,  1.94it/s]

<loc0000><loc0000><loc1023><loc1019> brick kilns with chimney


 37%|███▋      | 10/27 [00:05<00:08,  1.94it/s]

<loc0000><loc0000><loc1023><loc1020> brick kilns with chimney


 41%|████      | 11/27 [00:05<00:08,  1.94it/s]

<loc0000><loc0000><loc1023><loc1020> brick kilns with chimney


 44%|████▍     | 12/27 [00:06<00:07,  1.94it/s]

<loc0000><loc0000><loc0203><loc0247> brick kilns with chimney


 48%|████▊     | 13/27 [00:06<00:07,  1.94it/s]

<loc0340><loc0472><loc0416><loc0546> brick kilns with chimney


 52%|█████▏    | 14/27 [00:07<00:06,  1.95it/s]

<loc0112><loc0307><loc0330><loc0344> brick kilns with chimney


 56%|█████▌    | 15/27 [00:07<00:06,  1.95it/s]

<loc0000><loc0017><loc0099><loc0224> brick kilns with chimney


 59%|█████▉    | 16/27 [00:08<00:05,  1.94it/s]

<loc0000><loc0000><loc1023><loc1020> brick kilns with chimney


 63%|██████▎   | 17/27 [00:08<00:05,  1.94it/s]

<loc0393><loc0378><loc0497><loc0437> brick kilns with chimney


 67%|██████▋   | 18/27 [00:09<00:04,  1.94it/s]

<loc0000><loc0000><loc1023><loc1014> brick kilns with chimney


 70%|███████   | 19/27 [00:09<00:04,  1.95it/s]

<loc0000><loc0000><loc0340><loc0224> brick kilns with chimney


 74%|███████▍  | 20/27 [00:10<00:03,  1.95it/s]

<loc0403><loc0643><loc0474><loc0735> brick kilns with chimney


 78%|███████▊  | 21/27 [00:10<00:03,  1.95it/s]

<loc0000><loc0000><loc1023><loc1020> brick kilns with chimney


 81%|████████▏ | 22/27 [00:11<00:02,  1.95it/s]

<loc0000><loc0636><loc0079><loc0673> brick kilns with chimney


 85%|████████▌ | 23/27 [00:11<00:02,  1.95it/s]

<loc0433><loc0889><loc0502><loc0947> brick kilns with chimney


 89%|████████▉ | 24/27 [00:12<00:01,  1.95it/s]

<loc0000><loc0894><loc0071><loc0928> brick kilns with chimney


 93%|█████████▎| 25/27 [00:12<00:01,  1.95it/s]

<loc0391><loc0110><loc0587><loc0407> brick kilns with chimney


 96%|█████████▋| 26/27 [00:13<00:00,  1.95it/s]

<loc0293><loc0806><loc0521><loc0961> brick kilns with chimney


100%|██████████| 27/27 [00:13<00:00,  1.94it/s]

<loc0423><loc0746><loc0800><loc0965> brick kilns with chimney





In [25]:
# Spot-check IoU and mAP on individual samples using the project-local helper module.
import sys
# Make the local `paligemma2_utils` module importable.
sys.path.append('/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma')
from supervision.detection.utils import box_iou_batch
from paligemma2_utils import calculate_map
# NOTE(review): the IoU is computed for sample index 2 while the mAP below is
# computed on the slice [3:4] (sample index 3); confirm the indices are meant to differ.
iou = box_iou_batch(boxes_true = targets[2].xyxy, boxes_detection = predictions[2].xyxy)
print(iou)
map_result, map50, map50_95 = calculate_map(predictions[3:4], targets[3:4])
map50

[[0.19215144]]


0.0

In [26]:
# Show predictions and targets for samples 0..27 using the project-local helper.
from paligemma2_utils import visualize_predictions
visualize_predictions(images, predictions, targets, start = 0, end = 27, rows = 5, cols = 6)

In [19]:
# Pick one test image at random and print its path.
import numpy as np
from glob import glob
# arr = np.random.choice(glob('/home/shataxi.dubey/shataxi_work/vlm_on_planet/gms/lucknow_small_600_sq_km/kiln_images/test/images/*'), 10)
from PIL import Image
np.random.seed(42)
# glob() returns paths in arbitrary order; sorting first makes the seeded
# choice select the same file on every run and machine.
image_paths = sorted(glob('/home/shataxi.dubey/shataxi_work/vlm_on_planet/gms/lucknow_small_600_sq_km/kiln_images/test/images/*'))
for i in np.random.choice(image_paths, 1):
    print(i)

/home/shataxi.dubey/shataxi_work/vlm_on_planet/gms/lucknow_small_600_sq_km/kiln_images/test/images/9027712.8657_3076549.1837.png
