In [1]:
import os
id = 3
os.environ['CUDA_VISIBLE_DEVICES'] = str(id)

In [2]:
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
from glob import glob
import re

import torch 
from torch.utils.data import Dataset, DataLoader

from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration

import supervision as sv
from supervision.metrics import MeanAveragePrecision, MetricTarget

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class TifPatchDataset(Dataset):
    def __init__(self, labels_file_path: str, image_directory_path: str, prompt: str, start : int, end : int, type : str):
        self.prompt = prompt
        self.start = start
        self.end = end
        self.type = type
        self.non_bg_img = []
        self.non_bg_img_name = []
        self.labels_file_path = labels_file_path
        self.image_directory_path = image_directory_path
        self.entries = self._load_entries()
        
                
    def _convert_to_paligemma_format(self, pil_image, absolute_bbox):

        image_width, image_height = pil_image.size
        tmp = []

        if len(absolute_bbox) == 0:
            tmp.append(f'<loc0000><loc0000><loc1024><loc1024> background')

        for bbox in absolute_bbox:
            xmin = min(bbox[[1,3,5,7]])
            ymin = min(bbox[[2,4,6,8]])
            xmax = max(bbox[[1,3,5,7]])
            ymax = max(bbox[[2,4,6,8]])
            
            xmin = (xmin)/image_width * 1024
            ymin = (ymin)/image_width * 1024
            xmax = (xmax)/image_width * 1024
            ymax = (ymax)/image_width * 1024

            ymin, xmin, ymax, xmax = ['0'*(4-len(x))+str(x) if len(x) < 4 else x for x in [str(int(ymin)), str(int(xmin)), str(int(ymax)), str(int(xmax))] ]
            tmp.append(f'<loc{ymin}><loc{xmin}><loc{ymax}><loc{xmax}> {self.prompt}')
        
        tmp = ' ; '.join(tmp)
        # create a dictionary with PIL image in one key and its bbox in another key suffix, make sure to bring the bbox range in (0, 1024)
        paligemma_labels = {'image': pil_image, 'suffix': tmp, 'prefix': f'detect {self.prompt} ; background' }
        
        return paligemma_labels

    def _load_entries(self): # load the bbox of all patches
        entries = []
        for idx, image_path in enumerate(sorted(glob(self.image_directory_path+'/*'))[self.start:self.end]):
            img_name = os.path.basename(image_path)
            img = Image.open(image_path)
            w, h = img.size
            label_path = os.path.join(self.labels_file_path,img_name[:-4]+'.txt')
            if os.path.exists(label_path):
                bbox = np.loadtxt(label_path, ndmin = 2)
                bbox = bbox*w
                self.non_bg_img.append(idx)
                self.non_bg_img_name.append(image_path)
            else:
                bbox = []
            entries.append(self._convert_to_paligemma_format(img, bbox))

        return entries

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx: int):
        if idx < 0 or idx >= len(self.entries):
            raise IndexError("Index out of range")

        entry = self.entries[idx]
        image = self.entries[idx]['image']
        return image, entry

In [5]:
def parse_bbox_and_labels(detokenized_output: str):
  matches = re.finditer(
      '<loc(?P<y0>\d\d\d\d)><loc(?P<x0>\d\d\d\d)><loc(?P<y1>\d\d\d\d)><loc(?P<x1>\d\d\d\d)>'
      ' (?P<label>(\w+\s?)+\w)',
      detokenized_output,
  )
  labels, boxes = [], []
  fmt = lambda x: float(x) / 1024.0
  for m in matches:
    d = m.groupdict()
    # print(d)
    boxes.append([fmt(d['y0']), fmt(d['x0']), fmt(d['y1']), fmt(d['x1'])])
    labels.append(d['label'])
  return np.array(boxes), np.array(labels)

In [6]:
def perform_inference(model, processor, CLASSES, test_dataloader):
    images = []
    targets = []
    predictions = []

    with torch.inference_mode():
        for imgs, test_inputs, suffixes in tqdm(test_dataloader):
            # print(test_inputs['input_ids'])
            prefix_length = test_inputs["input_ids"].shape[-1]

            generation = model.generate(**test_inputs, max_new_tokens=256, do_sample=False)
            generation = generation[:, prefix_length:]
            generated_texts = processor.batch_decode(generation, skip_special_tokens=True)
            w, h = imgs[0].size
            for generated_text in generated_texts:
                prediction = sv.Detections.from_lmm(
                    lmm='paligemma',
                    result=generated_text,
                    resolution_wh=(w, h),
                    classes=CLASSES)

                prediction.class_id = np.array([CLASSES.index(class_name) for class_name in prediction['class_name']])
                prediction.confidence = np.ones(len(prediction))
                predictions.append(prediction)

            for suffix in suffixes:
                target = sv.Detections.from_lmm(
                    lmm='paligemma',
                    result=suffix,
                    resolution_wh=(w, h),
                    classes=CLASSES)

                target.class_id = np.array([CLASSES.index(class_name) for class_name in target['class_name']])
                targets.append(target)
            images += list(imgs)
    
    return predictions, targets, images

In [7]:
def test_domain(model_path, model, processor, CLASSES, test_dataset, df, TORCH_DTYPE):

    def collate_test_fn(batch):
        images, labels = zip(*batch)

        prefixes = ["<image>" + label["prefix"] for label in labels]
        suffixes = [label["suffix"] for label in labels]
        inputs = processor(
            text=prefixes,
            images=images,
            return_tensors="pt",
            padding="longest"
        ).to(TORCH_DTYPE).to(DEVICE)

        return images, inputs, suffixes

    test_dataloader = DataLoader(test_dataset, batch_size=8, collate_fn= collate_test_fn, shuffle=False)

    predictions, targets, images = perform_inference(model, processor, CLASSES, test_dataloader)
    
    # @title Calculate mAP
    map_metric = MeanAveragePrecision(metric_target=MetricTarget.BOXES)
    map_result = map_metric.update(predictions, targets).compute()
    # print(map_result)
    # map_result.plot()

    # @title Calculate Confusion Matrix
    # default confidence threshold = 0.3
    for iou in [0.1, 0.3, 0.5, 0.7]:
        confusion_matrix = sv.ConfusionMatrix.from_detections(
            predictions=predictions,
            targets=targets,
            classes=CLASSES,
            conf_threshold = 0.25, #  paligemma does not give confidence score so no use of conf_threshold  
            iou_threshold=iou
        )
        # _ = confusion_matrix.plot()

        tp = confusion_matrix.matrix[0][0]
        pp = confusion_matrix.matrix[:,0].sum()
        ap = confusion_matrix.matrix[0, :].sum()
        precision = tp/(pp + 1e-9)
        recall = tp/(ap + 1e-9)
        f1_score = (2*precision*recall)/(precision+recall+1e-9)
        print(f'Model path {model_path} IoU {iou}')
        print(f'Precision {precision}, Recall {recall}, F1 score {f1_score}, Predicted positive {pp}, True positive {tp}')
        tmp = pd.DataFrame({'model_path': model_path, 'iou': iou, 'precision': precision, 'recall': recall, 'f1_score': f1_score, 'predicted_kilns': pp, 'true_positive': tp}, index = [0])
        df = pd.concat([df, tmp])
    print('\n')
    return df


In [8]:
test_dataset = TifPatchDataset(
    labels_file_path=f"/home/shataxi.dubey/shataxi_work/vlm_on_planet/data/processed_data/wb_small_airshed/labels", 
    image_directory_path=f"/home/shataxi.dubey/shataxi_work/vlm_on_planet/data/processed_data/wb_small_airshed/images", # planet images

    # labels_file_path = '/home/shataxi.dubey/shataxi_work/vlm_on_planet/test/labels',
    # image_directory_path='/home/shataxi.dubey/shataxi_work/vlm_on_planet/test/images',
    prompt = 'brick kilns with chimney',
    start = 0,
    end = 649,
    type = 'test'
) 

In [9]:
len(test_dataset)

459

In [10]:
sum = 0
for i in range(len(test_dataset)):
    box, label = parse_bbox_and_labels(test_dataset.entries[i]['suffix'])
    for classname in label:
        if 'background' in classname:
            sum += 1
print(f'Total number of kilns in the test set {sum}')

Total number of kilns in the test set 382


In [11]:
CLASSES = test_dataset[0][1]['prefix'].replace("detect ", "").split(" ; ")
CLASSES

['brick kilns with chimney', 'background']

In [12]:
MODEL_ID ="google/paligemma2-3b-pt-448"
processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)

In [None]:
model_paths = [
                # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_1_0/checkpoint-100',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_2_0/checkpoint-100',
            #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_2_1/checkpoint-100',
            #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_2_2/checkpoint-100',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_2_5/checkpoint-100',
            #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_3_1/checkpoint-100',
            #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_4_0/checkpoint-100',
            #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_5_2/checkpoint-100',
            #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_5_10/checkpoint-100',
            #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_7_0/checkpoint-100',
            #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_10_5/checkpoint-100',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_10_20/checkpoint-100',
            #    '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_16_16/checkpoint-100',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_20_10/checkpoint-200',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_20_40/checkpoint-200',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_30_15/checkpoint-200',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_30_60/checkpoint-300',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_40_20/checkpoint-200',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_40_80/checkpoint-400',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_60_30/checkpoint-300',
               '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_100_50/checkpoint-500',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_100_200/checkpoint-1000',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_200_100/checkpoint-1000',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_200_400/checkpoint-1900',
               # '/home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_297_150/checkpoint-1400'
               ]

In [14]:
df = pd.DataFrame(columns = ['model_path', 'iou' ,'precision', 'recall', 'f1_score', 'predicted_kilns', 'true_positive'])
for model_path in model_paths:
    model = PaliGemmaForConditionalGeneration.from_pretrained(model_path, device_map = 'auto')
    TORCH_DTYPE = model.dtype
    TORCH_DTYPE
    df = test_domain(model_path, model, processor, CLASSES, test_dataset, df, TORCH_DTYPE)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
100%|██████████| 58/58 [16:21<00:00, 16.93s/it]


Model path /home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_100_50_again/checkpoint-500 IoU 0.1
Precision 0.11111111111104056, Recall 0.6653992395411962, F1 score 0.1904243740744748, Predicted positive 1575.0, True positive 175.0
Model path /home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_100_50_again/checkpoint-500 IoU 0.3
Precision 0.10730158730151917, Recall 0.6425855513283553, F1 score 0.18389553838351344, Predicted positive 1575.0, True positive 169.0
Model path /home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_100_50_again/checkpoint-500 IoU 0.5
Precision 0.07492063492058736, Recall 0.4486692015192066, F1 score 0.12840043501034207, Predicted positive 1575.0, True positive 118.0




Model path /home/shataxi.dubey/shataxi_work/vlm_on_planet/PaliGemma/paligemma2_object_detection_100_50_again/checkpoint-500 IoU 0.7
Precision 0.026666666666649734, Recall 0.1596958174898871, F1 score 0.04570184959149842, Predicted positive 1575.0, True positive 42.0




In [15]:
df

Unnamed: 0,model_path,iou,precision,recall,f1_score,predicted_kilns,true_positive
0,/home/shataxi.dubey/shataxi_work/vlm_on_planet...,0.1,0.111111,0.665399,0.190424,1575.0,175.0
0,/home/shataxi.dubey/shataxi_work/vlm_on_planet...,0.3,0.107302,0.642586,0.183896,1575.0,169.0
0,/home/shataxi.dubey/shataxi_work/vlm_on_planet...,0.5,0.074921,0.448669,0.1284,1575.0,118.0
0,/home/shataxi.dubey/shataxi_work/vlm_on_planet...,0.7,0.026667,0.159696,0.045702,1575.0,42.0


In [16]:
df.to_csv('p.csv')

In [17]:
# annotated_images = []

# for i in range(20,45):
#     image = images[i]
#     detections = predictions[i]
#     target = targets[i]
#     annotated_image = image.copy()
#     annotated_image = sv.BoxAnnotator(thickness=8, color=sv.Color(r=255, g=0, b=0)).annotate(annotated_image, detections)
#     annotated_image = sv.BoxAnnotator(thickness=4, color=sv.Color(r=0, g=255, b=0)).annotate(annotated_image, target)
#     # annotated_image = sv.LabelAnnotator(text_scale=2, text_thickness=4, smart_position=True).annotate(annotated_image, detections)
#     annotated_images.append(annotated_image)

# sv.plot_images_grid(annotated_images, (5,5))

In [18]:
# df = pd.DataFrame(columns = ['model_path', 'iou' ,'precision', 'recall', 'f1_score', 'predicted_kilns', 'true_positive'])
# tmp = pd.DataFrame({'model_path':'shdfj', 'iou':0.1, 'precision':5, 'recall':3, 'f1_score':2, 'predicted_kilns':12, 'true_positive':132}, index = [0])
# print(tmp)
# print(df)
# df = pd.concat([df, tmp])
# df