# Imports and path setups

In [None]:
from PIL import Image
import torch
import numpy as np
from scipy import ndimage
from skimage.filters import gaussian
from skimage.measure import label, regionprops, find_contours
import cv2
import matplotlib.pyplot as plt

In [None]:
import os

In [None]:
# Dataset root and the per-artifact sub-directories (inputs and outputs).
root = '../dataset/street_obstacle_sequences/'
raw_path, target_path, ood_score_path, ood_prediction_tracked_path = (
    os.path.join(root, sub_dir)
    for sub_dir in ('raw_data', 'semantic_ood', 'ood_score', 'ood_prediction_tracked')
)

In [None]:
# List the entries of raw_data (expected: one `sequence_*` directory per sequence); shown as cell output.
sorted(os.listdir(raw_path))

In [1]:
# Sequences to process: ['all'] evaluates every sequence; otherwise give a list
# with the sequence names. Passed as `sequences=` to StreetObstacles.
predicted_sequences = ['all']

## Load SOS dataset

In [None]:
from Obstacle_Sequence_Challenge.datasets.street_obstacle_sequences import StreetObstacles

In [None]:
# Street Obstacle Sequences dataset; its `.images` (frame file paths) drive the processing loop.
sos_dataset = StreetObstacles(root, sequences=predicted_sequences)

## Loading Grounding DINO and Segment Anything models

In [None]:
import sys

In [None]:
# Make the locally cloned GroundingDINO and segment-anything repositories importable.
sys.path.extend(
    os.path.join(os.getcwd(), repo_dir)
    for repo_dir in ("GroundingDINO", "segment-anything")
)

In [None]:
import copy

import numpy as np
import torch
from PIL import Image
from torchvision.ops import box_convert

# Grounding DINO
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.inference import load_model, load_image, predict, annotate

# segment anything
from segment_anything import build_sam, SamPredictor
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# # Uncomment to download the weights of the models
# !mkdir weights
# %cd weights

# !wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
# !wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
ROOT = '.'

# Model files: GroundingDINO config/weights and the SAM ViT-H checkpoint.
GDINO_CONFIG_PATH = os.path.join(ROOT, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
GDINO_WEIGHTS_NAME = "groundingdino_swint_ogc.pth"
GDINO_WEIGHTS_PATH = os.path.join(ROOT, "weights", GDINO_WEIGHTS_NAME)
SAM_WEIGHTS_NAME = "sam_vit_h_4b8939.pth"
SAM_WEIGHTS_PATH = os.path.join(ROOT, "weights", SAM_WEIGHTS_NAME)

# Report which of the required files are present before trying to load them.
for _model_file in (GDINO_CONFIG_PATH, GDINO_WEIGHTS_PATH, SAM_WEIGHTS_PATH):
    print(_model_file, "; exist:", os.path.isfile(_model_file))

# Pick the device once so both models (and `predict`) use the same one.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

#
# Loading Grounding DINO Model
# `load_model` otherwise builds the model for its default "cuda" device,
# which breaks on CPU-only machines even though DEVICE fell back to CPU.
groundingdino_model = load_model(GDINO_CONFIG_PATH, GDINO_WEIGHTS_PATH, device=str(DEVICE))

#
# Loading SAM Model
sam = build_sam(checkpoint=SAM_WEIGHTS_PATH)
sam.to(device=DEVICE)
sam_predictor = SamPredictor(sam)

In [None]:
def filter_boxes(boxes, phrases, keep_labels=('object', 'small', 'small object')):
    """Select the detections to segment, with every 'road' detection first.

    Args:
        boxes: detection boxes; not read here, kept for call-site compatibility.
        phrases: per-detection text labels (sequence of str).
        keep_labels: object phrases kept in addition to 'road'. A 'road'
            entry here is ignored, since road boxes are always kept.

    Returns:
        List of int indices into ``phrases``. The 'road' detections come first,
        then the detections whose phrase is in ``keep_labels``. Both groups are
        in ascending index order, so the result does not depend on set
        iteration order. The order matters downstream, because later masks
        overwrite earlier ones.
    """
    # dtype=str keeps the comparisons element-wise even for an empty list.
    arr_phr = np.asarray(phrases, dtype=str)
    object_labels = [lbl for lbl in keep_labels if lbl != 'road']

    road_ids = np.flatnonzero(arr_phr == 'road')
    obj_ids = np.flatnonzero(np.isin(arr_phr, object_labels))

    return road_ids.tolist() + obj_ids.tolist()

def get_valid_boxes(image, boxes, logits, iou_threshold=0.9):
    """Remove near-duplicate detections with greedy non-maximum suppression.

    Boxes are visited in decreasing ``logits`` order. Each box that has not
    been suppressed yet is kept, and it suppresses every box whose IoU with it
    exceeds ``iou_threshold``. A suppressed box never suppresses others, so in
    an overlap chain A~B~C (A most confident, A and C not overlapping) both A
    and C survive. On equal logits the lower index wins.

    Args:
        image: image array; only ``image.shape[:2]`` (H, W) is read.
        boxes: (N, 4) tensor of normalized cx, cy, w, h boxes.
        logits: (N,) detection confidences.
        iou_threshold: pairs with IoU strictly above this are duplicates.

    Returns:
        Sorted list of the indices (int) of the kept boxes.
    """
    H, W = image.shape[:2]
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

    iou, _ = box_ops.box_iou(boxes_xyxy, boxes_xyxy)
    duplicates = (iou > iou_threshold).cpu().numpy()

    confidences = torch.as_tensor(logits).detach().cpu().double().numpy().reshape(-1)
    # Python's sort is stable: ties keep ascending index order (lower index wins).
    visit_order = sorted(range(len(confidences)), key=lambda k: -confidences[k])

    suppressed = np.zeros(len(confidences), dtype=bool)
    kept = []
    for k in visit_order:
        if suppressed[k]:
            continue
        kept.append(k)
        suppressed |= duplicates[k]

    return sorted(kept)

def get_object_boxes(boxes, phrases):
    """Return the boxes whose phrase is not 'road'.

    Args:
        boxes: indexable collection of boxes (numpy array or tensor), aligned
            with ``phrases``.
        phrases: per-box labels, as a list or an array.

    Returns:
        ``boxes`` restricted to the non-road entries.
    """
    # asarray: comparing a plain list to a str gives one scalar bool, not a mask.
    obj_ids = np.flatnonzero(np.asarray(phrases, dtype=str) != 'road')
    return boxes[obj_ids]

In [None]:
def get_proposal_masks(masks, scores, logits, phrases, w=(0.7, 0.3)):
    """Rasterise the per-box SAM masks into a label map and an OOD score map.

    Every mask is cleaned with morphological close/close/open operations.
    Road masks get an extra large close/open pass.

    Args:
        masks: per-box masks, read as ``masks[i][0]`` (a tensor with ``.cpu()``).
            The last two dims are (H, W). Presumably (N, 1, H, W) from SAM;
            TODO confirm.
        scores: SAM mask scores; ``scores[i].item()`` is read per box.
        logits: GroundingDINO box confidences; ``logits[i].item()`` is read per box.
        phrases: per-box labels; 'road' marks road, anything else an object.
        w: weights of (logit, SAM score) in the object OOD score.

    Returns:
        ood_mask: uint8 (H, W) map with 0 on road, 254 on objects and 255 elsewhere.
        ood_score_mask: float32 (H, W) map holding the weighted score on object
            pixels and 0 elsewhere.
        Masks are painted in ``phrases`` order, so a later mask overwrites an
        earlier one where they overlap.
    """
    H, W = masks.shape[-2:]

    ood_score_mask = np.zeros((H, W), dtype=np.float32)
    ood_mask = np.full((H, W), 255, dtype=np.uint8)

    kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    kernel_medium = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
    kernel_road = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (50, 50))

    for i, phrase in enumerate(phrases):
        mask = masks[i][0].cpu().numpy().astype(np.uint8)

        # Fill small then medium holes, then drop isolated specks.
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel_small)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel_medium)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel_small)

        if phrase == 'road':
            # The road surface is smoothed much more aggressively.
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel_road)
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel_road)
            ood_mask[mask.astype(bool)] = 0
        else:
            obj_px = mask.astype(bool)
            ood_mask[obj_px] = 254
            ood_score_mask[obj_px] = w[0] * logits[i].item() + w[1] * scores[i].item()

    return ood_mask, ood_score_mask

def get_proposal_regions(mask):
    """Label the connected object and road components of a proposal map.

    ``mask`` encodes road as 0, OOD-object candidates as 254 and background as 255.

    Returns:
        (obj_label_img, obj_regions, road_label_img, road_regions): the skimage
        label image and ``regionprops`` list for the object (== 254) pixels,
        then the same two for the road (== 0) pixels.
    """
    def _components(binary):
        # Connected-component labelling plus per-component properties.
        labelled = label(binary)
        return labelled, regionprops(labelled)

    obj_label_img, obj_regions = _components(mask == 254)
    road_label_img, road_regions = _components(mask == 0)
    return obj_label_img, obj_regions, road_label_img, road_regions

def get_ood_regions(obj_label_img, obj_regions, road_label_img, road_regions, inter_th=0.15):
    """Keep the object regions whose convex hull overlaps a road region's hull.

    For each object, compare its convex hull with every road hull. The overlap
    ratio is the pixel intersection of the two hulls divided by the smaller
    hull area. The object is selected as soon as one road gives a ratio of at
    least ``inter_th``. A pair where either hull is empty is skipped.

    Args:
        obj_label_img: labelled object image (not read; kept for API compatibility).
        obj_regions: object regions; each needs ``bbox`` (min_row, min_col,
            max_row, max_col), ``image_convex`` (hull cropped to ``bbox``) and
            ``label``, as provided by skimage ``regionprops``.
        road_label_img: labelled road image (not read; kept for API compatibility).
        road_regions: road regions with ``bbox`` and ``image_convex``.
        inter_th: minimum intersection / smaller-area ratio.

    Returns:
        (selected_obj_ids, selected_obj_bbxs, selected_obj_instance_lbls):
        - selected_obj_ids: indices into ``obj_regions``.
        - selected_obj_bbxs: boxes as [min_col, min_row, max_col, max_row]
          (left-top, right-bottom).
        - selected_obj_instance_lbls: the selected regions' ``label`` values.
    """
    # Road hulls (with their bbox and area) do not depend on the object: compute once.
    road_hulls = [
        (road.bbox, road.image_convex, np.count_nonzero(road.image_convex))
        for road in road_regions
    ]

    selected_obj_ids = []
    selected_obj_bbxs = []
    selected_obj_instance_lbls = []

    for i, obj in enumerate(obj_regions):
        minr, minc, maxr, maxc = obj.bbox
        obj_hull = obj.image_convex
        obj_px_area = np.count_nonzero(obj_hull)

        for (rd_minr, rd_minc, rd_maxr, rd_maxc), road_hull, road_px_area in road_hulls:
            min_px_area = min(road_px_area, obj_px_area)
            if min_px_area == 0:
                continue

            # Each hull is zero outside its bbox, so only the bbox overlap can intersect.
            r0, r1 = max(minr, rd_minr), min(maxr, rd_maxr)
            c0, c1 = max(minc, rd_minc), min(maxc, rd_maxc)
            if r0 >= r1 or c0 >= c1:
                intersection_px_area = 0
            else:
                obj_crop = obj_hull[r0 - minr:r1 - minr, c0 - minc:c1 - minc]
                road_crop = road_hull[r0 - rd_minr:r1 - rd_minr, c0 - rd_minc:c1 - rd_minc]
                intersection_px_area = np.count_nonzero(obj_crop & road_crop)

            if intersection_px_area / min_px_area >= inter_th:
                selected_obj_ids.append(i)
                selected_obj_bbxs.append([minc, minr, maxc, maxr])  # lt, rb
                selected_obj_instance_lbls.append(obj.label)
                break  # one matching road is enough

    return selected_obj_ids, selected_obj_bbxs, selected_obj_instance_lbls
    

def get_filtered_proposal_mask(proposal_mask, obj_label_img, obj_instance_lbls, road_label_img):
    """Rebuild the proposal label map, keeping only the selected object instances.

    Returns a uint8 map shaped like ``proposal_mask``:
    - 0 where ``road_label_img > 0``;
    - 254 where ``obj_label_img`` holds one of ``obj_instance_lbls``, drawn
      over the road;
    - 255 everywhere else.
    """
    filtered = np.full_like(proposal_mask, 255, dtype=np.uint8)
    filtered[road_label_img > 0] = 0
    filtered[np.isin(obj_label_img, obj_instance_lbls)] = 254
    return filtered
    
def get_ood_map(mask, score_mask):
    """Return a float32 map holding ``score_mask`` on object pixels (mask == 254) and 0 elsewhere."""
    is_object = mask == 254
    return np.where(is_object, score_mask, 0).astype(np.float32)

def get_track_map(masks, ids, mask_id):
    """Paint every tracked object instance with its track id.

    Args:
        masks: labelled instance image; pixels equal to ``mask_id[k]`` belong
            to entry k.
        ids: track id per entry; entries with id 0 are not painted.
        mask_id: instance label per entry of ``ids``.

    Returns:
        uint8 map over the last two dims of ``masks``. Tracked pixels hold
        their track id; all other pixels are 255.
    """
    height, width = masks.shape[-2:]
    track_mask = np.full((height, width), 255, dtype=np.uint8)

    for idx, track_id in enumerate(ids):
        instance = mask_id[idx]
        if track_id == 0:
            continue
        track_mask[masks == instance] = track_id

    return track_mask

In [None]:
# Greedy tracking: match the current frame's boxes with the previous frame's boxes by IoU.
def match_boxes(cur_box, prev_box, track_ids, iou_threshold=0.1):
    """Assign track ids to the current boxes by greedy one-to-one IoU matching.

    Every (current, previous) pair with IoU >= ``iou_threshold`` is a candidate.
    Candidates are accepted in decreasing IoU order (ties by index), and each
    current and each previous box is used at most once, so two current boxes
    never inherit the same id. Current boxes left without an id (0) get fresh
    consecutive ids. These start above the largest id in ``track_ids``, so they
    cannot collide with a previous object that went unmatched.

    Args:
        cur_box: current boxes as [x0, y0, x1, y1] rows (list or array-like).
        prev_box: previous frame's boxes, same format.
        track_ids: track id of each entry of ``prev_box``, aligned by position.
        iou_threshold: minimum IoU for a match.

    Returns:
        np.uint8 array with one track id per current box. Ids above 255 wrap,
        as before.
    """
    n_cur = len(cur_box)
    temp_track_ids = np.zeros(n_cur, dtype=np.uint8)
    if n_cur == 0:
        return temp_track_ids

    prev_ids = np.asarray(track_ids, dtype=np.int64).reshape(-1)

    if len(prev_box) > 0:
        iou, _ = box_ops.box_iou(torch.Tensor(cur_box), torch.Tensor(prev_box))
        iou = iou.cpu().numpy()

        cur_used = np.zeros(iou.shape[0], dtype=bool)
        prev_used = np.zeros(iou.shape[1], dtype=bool)
        # Visit all pairs from the highest IoU down; stop once below the threshold.
        for flat_idx in np.argsort(-iou, axis=None, kind='stable'):
            ci, pi = np.unravel_index(flat_idx, iou.shape)
            if iou[ci, pi] < iou_threshold:
                break
            if cur_used[ci] or prev_used[pi]:
                continue
            temp_track_ids[ci] = prev_ids[pi]
            cur_used[ci] = True
            prev_used[pi] = True

    needs_id = temp_track_ids == 0
    next_id = int(prev_ids.max(initial=0)) + 1
    temp_track_ids[needs_id] = np.arange(next_id, next_id + int(np.count_nonzero(needs_id)))

    return temp_track_ids

In [None]:
# First run the detection of objects and road
# GroundingDINO caption. The phrases returned by `predict` have 'the ' and 'in '
# stripped and are then kept only if they are 'road', 'object', 'small' or
# 'small object' (see filter_boxes).
TEXT_PROMPT = "small object in the road, small object, road"
# Passed to `predict` as box_threshold / text_threshold. The misspelled names
# are kept because the processing loop refers to them.
BOX_TRESHOLD = 0.15
TEXT_TRESHOLD = 0.15

## Process the dataset

- GroundingDINO detects objects and the road (following the `TEXT_PROMPT`).
- The detections are filtered to remove known objects (if they are in the `TEXT_PROMPT`) and possible duplicates (among boxes with a high overlap, only the most confident one is kept). Finally, the boxes are reordered so that 'road' boxes come first in the list.
- The boxes are passed, along with the image, to the SAM model, which provides a segmentation for each box.
- These segmentations are then labeled, with an OOD score attached to each object, using masks refined through morphological operations.
- For each object segmentation, its convex hull is computed, and only objects whose hull sufficiently intersects the road mask are kept.
- The remaining objects are then assigned a track id. Tracking across frames uses a greedy algorithm that matches the boxes of the current frame with the previous frame's boxes of highest overlap (IoU).
- Finally, both the track-id and OOD-score maps are saved.

In [None]:
# Per frame: detect -> segment -> keep road-touching objects -> track -> save.
prev_seq = None
for img_path in sos_dataset.images:
    actual_seq = os.path.basename(os.path.dirname(img_path))
    frame_name = os.path.basename(img_path)
    print(f'... processing sequence {actual_seq}, frame {frame_name}')

    # Load image: `image_source` (array) gives sizes and feeds SAM, `image` feeds GroundingDINO.
    image_source, image = load_image(img_path)
    H, W, _ = image_source.shape

    # Get predictions
    boxes, logits, phrases = predict(
        model=groundingdino_model,
        image=image,
        caption=TEXT_PROMPT,
        box_threshold=BOX_TRESHOLD,
        text_threshold=TEXT_TRESHOLD,
        device=DEVICE
    )

    # Remove prepositions and articles from the text labels
    phrases = [x.replace('the ', '').replace('in ', '') for x in phrases]

    # Keep only road and small-object boxes (road boxes first)
    keep_ids = filter_boxes(boxes, phrases)
    boxes = boxes[keep_ids]
    logits = logits[keep_ids]
    phrases = np.array(phrases)[keep_ids]

    # Drop near-duplicate boxes, keeping the most confident one
    valid_boxes = get_valid_boxes(image_source, boxes, logits)
    boxes = boxes[valid_boxes]
    logits = logits[valid_boxes]
    phrases = phrases[valid_boxes]

    if len(boxes) == 0:
        # Nothing detected: skip SAM and use an all-background proposal with the
        # same encoding as get_proposal_masks (255 background, score 0).
        proposal_mask = np.full((H, W), 255, dtype=np.uint8)
        score_mask = np.zeros((H, W), dtype=np.float32)
    else:
        # Get segmentation; boxes: normalized cxcywh -> pixel xyxy
        sam_predictor.set_image(image_source)
        boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).to(DEVICE)
        masks, scores, _ = sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        proposal_mask, score_mask = get_proposal_masks(masks, scores, logits, phrases)

    # Connected object / road regions of the proposal
    obj_label_img, obj_regions, road_label_img, road_regions = get_proposal_regions(proposal_mask)
    # Keep only objects whose convex hull intersects the road
    selected_obj_ids, selected_obj_bbxs, selected_obj_instance_lbls = get_ood_regions(
        obj_label_img, obj_regions, road_label_img, road_regions)
    # Label map restricted to the selected objects
    filtered_proposal_mask = get_filtered_proposal_mask(
        proposal_mask, obj_label_img, selected_obj_instance_lbls, road_label_img)

    # OOD score
    ood_score = get_ood_map(filtered_proposal_mask, score_mask)
    # Optional smoothing: ood_score = gaussian(ood_score, sigma=3)

    #######################
    # Tracking Process ####
    obj_boxes = selected_obj_bbxs
    cur_frame = frame_name.split('_')[0]

    # Reset the tracker at the start of every sequence. Checking the directory
    # as well as the frame number means state never leaks between sequences and
    # is always defined, even if a sequence does not start at frame '000000'.
    if cur_frame == '000000' or actual_seq != prev_seq:
        last_tracked_obj_id = 0
        prev_obj_boxes = []
        prev_track_ids = []
    prev_seq = actual_seq

    if len(obj_boxes) < 1:
        cur_track_ids = []
    else:
        if len(prev_obj_boxes) > 0:
            cur_track_ids = match_boxes(obj_boxes, prev_obj_boxes, prev_track_ids)
        else:
            cur_track_ids = np.arange(len(obj_boxes), dtype=np.uint8) + last_tracked_obj_id + 1
        last_tracked_obj_id = max(last_tracked_obj_id, np.max(cur_track_ids))

    prev_track_ids = cur_track_ids
    prev_obj_boxes = obj_boxes

    # Tracking labels
    ood_track = get_track_map(obj_label_img, cur_track_ids, selected_obj_instance_lbls)

    #######################
    # Save OOD score / tracking
    actual_frame = frame_name.replace('_raw_data.jpg', '.npy')
    actual_frame_label = frame_name.replace('_raw_data.jpg', '_label.npy')

    track_dir = os.path.join(ood_prediction_tracked_path, actual_seq)
    os.makedirs(track_dir, exist_ok=True)
    np.save(os.path.join(track_dir, actual_frame), ood_track)

    score_dir = os.path.join(ood_score_path, actual_seq)
    os.makedirs(score_dir, exist_ok=True)
    np.save(os.path.join(score_dir, actual_frame), ood_score)
    np.save(os.path.join(score_dir, actual_frame_label), filtered_proposal_mask)

### For result visualization

In [None]:
def get_data(sequence, frame=0, gray=False, pred=False):
    """Load one frame of a sequence, plus optionally its saved predictions.

    Args:
        sequence: sequence identifier; files are read from ``sequence_{sequence}``.
        frame: integer frame number, zero-padded to 6 digits in file names.
        gray: open the ground truth as a 16-bit single-channel image
            ("I;16") instead of RGB.
        pred: also load the saved label map and OOD score arrays.

    Returns:
        (image, target), or (image, target, prd, scr) when ``pred`` is true.
    """
    seq_dir = f'sequence_{sequence}'
    stem = f'{frame:06d}'

    image = Image.open(os.path.join(raw_path, seq_dir, f'{stem}_raw_data.jpg')).convert("RGB")
    target = Image.open(os.path.join(target_path, seq_dir, f'{stem}_semantic_ood.png')).convert("I;16" if gray else "RGB")

    if not pred:
        return image, target

    pred_dir = os.path.join(ood_score_path, seq_dir)
    prd = np.load(os.path.join(pred_dir, f'{stem}_label.npy'))
    scr = np.load(os.path.join(pred_dir, f'{stem}.npy'))
    return image, target, prd, scr

In [None]:
# frame has to be a multiple of 8 (presumably only those frames have
# ground-truth annotations — TODO confirm)
def show_gt_pair(sequence, frame=0, gray=False):
    """Plot a raw frame next to its ground-truth annotation, axes hidden."""
    fig, axes = plt.subplots(1, 2, figsize=(24, 8))

    image, target = get_data(sequence, frame, gray=gray)

    for axis, picture in zip(axes, (image, target)):
        axis.imshow(picture)
        axis.axis('off')
    plt.show()

In [None]:
def show_pair(im1, im2):
    """Display two images side by side, axes hidden."""
    fig, axes = plt.subplots(1, 2, figsize=(24, 8))

    for axis, picture in zip(axes, (im1, im2)):
        axis.imshow(picture)
        axis.axis('off')
    plt.show()

In [None]:
def to_gray(npimg):
    """Remap label values to display gray levels, in place.

    Mapping: 0 (road) -> 63, 255 (background) -> 0, 254 (OOD object) -> 255.
    All other values are left unchanged. Returns the same, mutated array.
    """
    # Select every class on the untouched values, so remapped pixels are never re-mapped.
    road_px = npimg == 0
    background_px = npimg == 255
    ood_px = npimg == 254

    npimg[road_px] = 63
    npimg[background_px] = 0
    npimg[ood_px] = 255
    return npimg

In [None]:
def show_gt_pred(sequence, frame=0, gray=False):
    """Plot a frame with its OOD score, ground truth and predicted label map.

    Layout: image | score on the top row, target | prediction on the bottom row.
    With ``gray`` both label maps are remapped to gray levels via ``to_gray``.

    Returns:
        (pred, score): the predicted label map as a PIL image and the raw score array.
    """
    fig, ax = plt.subplots(2, 2, figsize=(20, 10))

    image, target, pred, score = get_data(sequence, frame, gray=gray, pred=True)

    if gray:
        target = Image.fromarray(to_gray(np.array(target)))
        pred = to_gray(pred)
    pred = Image.fromarray(pred)

    for axis, picture in zip(ax.flat, (image, score, target, pred)):
        axis.imshow(picture)
        axis.axis('off')
    plt.show()

    return pred, score