In [None]:
from src.models import ObjectDetectionModel

from dataclasses import dataclass, field

import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

import torch
import random

import numpy as np
from torch import nn
import torch.nn.functional as F
import cv2

import plotly.graph_objects as go
import numpy as np

In [None]:
int_to_label = {
    0: "person",
    1: "birds",
    2: "parking meter",
    3: "stop sign",
    4: "street sign",
    5: "fire hydrant",
    6: "traffic light",
    7: "motorcycle",
    8: "bicycle",
    9: "LMVs",
    10: "HMVs",
    11: "animals",
    12: "poles",
    13: "barricades",
    14: "traffic cones",
    15: "mailboxes",
    16: "stones",
    17: "small walls",
    18: "bins",
    19: "furniture",
    20: "pot plant",
    21: "sign boards",
    22: "boxes",
    23: "trees",
}

# Random (not guaranteed distinct) RGB colours, drawn from a dedicated seeded RNG so the
# visualisations look identical on every Restart & Run All without touching the global `random` state.
COLOR_SEED = 0
_color_rng = random.Random( COLOR_SEED )

# One colour per class id (bounding-box outlines).
int_to_color = { i: tuple( _color_rng.randint( 0, 255 ) for _ in range( 3 ) ) for i in range( len( int_to_label ) ) }
# One colour per detector / query id; 300 matches the number of queries (Detection.nq).
det_to_color = { i: tuple( _color_rng.randint( 0, 255 ) for _ in range( 3 ) ) for i in range( 300 ) }

In [None]:

def resize_image(image_tensor, size):
    """
    Bilinearly resize a single floating-point image tensor of shape (C, H, W).

    Parameters:
        image_tensor: tensor of shape (C, H, W). A batch dimension is added temporarily
            because ``F.interpolate`` works on (N, C, H, W) input.
        size: target size; either a scalar for a square ``size x size`` output or an
            ``(height, width)`` tuple/list.

    Returns:
        Tensor of shape (C, height, width).
    """
    if isinstance( size, ( tuple, list ) ):
        target = tuple( size )
    else:
        target = ( size, size )
    return F.interpolate( image_tensor.unsqueeze( 0 ), size = target,
                          mode = 'bilinear', align_corners = False ).squeeze( 0 )

@dataclass
class ConfigBackbone:
    """
    Backbone hyper-parameters (512-dim variant).

    Not referenced by ``Config`` in this notebook, which uses ``BackboneConfig``.
    All counts are integers, and ``model`` is a real dataclass field (as in
    ``BackboneConfig``) so it can be set through the constructor.
    """
    in_channels: int = 3
    embed_dim: int = 512
    num_heads: int = 8
    depth: int = 4
    num_tokens: int = 4096
    model: str = ''

@dataclass
class BackboneConfig:
    """Backbone hyper-parameters used for ``Config.backbone``; all counts are integers."""
    in_channels: int = 3
    embed_dim: int = 384
    num_heads: int = 8
    depth: int = 4
    num_tokens: int = 4096
    model: str = "linear"

@dataclass
class ComputePrecision:
    """Numeric-precision options, nested into ``Config.compute_precision``."""
    # presumably enables a gradient scaler for mixed-precision training — confirm in src.models
    grad_scaler: bool = True

@dataclass
class HungarianLoss:
    """Weights and settings for the Hungarian-matching loss, nested into ``Config.hungarian_loss``."""
    lambda_bbox: float = 1.0
    lambda_cls: float = 1.0
    image_width: float = 512
    image_height: float = 512
    # One more than Detection.nc (24); presumably an extra "no object" class — confirm in src.models.
    num_classes: int = 25

@dataclass
class Detection:
    """Detection-head / decoder hyper-parameters, nested into ``Config.detection``."""
    nc: int = 24
    ch: tuple = (384, 384, 384)
    hd: int = 256  # hidden dim
    nq: int = 300  # num queries
    ndp: int = 4  # num decoder points
    nh: int = 8  # num head
    ndl: int = 6  # num decoder layers
    d_ffn: int = 1024  # dim of feedforward
    dropout: float = 0.0
    # default_factory gives every Detection its own activation module instead of one
    # nn.ReLU instance shared by all configs (and every model built from them).
    act: nn.Module = field( default_factory = nn.ReLU )
    eval_idx: int = -1
    # Training args
    learnt_init_query: bool = False

@dataclass
class Config:
    """
    Top-level configuration passed to ``ObjectDetectionModel``.

    Nested sub-configs use ``default_factory`` so each ``Config()`` gets its own fresh instances.
    """
    log_dir: str = "./models/"  # presumably where checkpoints/logs live — confirm in src.models
    name: str = "detection_v5_small"
    backbone_name: str = "encoder_v5"
    compute_precision: ComputePrecision = field( default_factory = ComputePrecision )
    backbone: BackboneConfig = field( default_factory = BackboneConfig )
    hungarian_loss: HungarianLoss = field( default_factory = HungarianLoss )
    detection: Detection = field( default_factory = Detection )

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Default configuration: every sub-config keeps its dataclass defaults.
cfg = Config()

In [None]:
# Checkpoint step to restore (e.g. 150000), or None to keep the weights the model is built with.
# NOTE(review): with None nothing is loaded here — confirm ObjectDetectionModel restores trained
# weights on its own, otherwise every prediction below comes from an untrained model.
CHECKPOINT = None
model = ObjectDetectionModel( cfg, device )
if CHECKPOINT is not None:
    model.load( chpt = CHECKPOINT )

In [None]:
# Move the model to the selected device and switch it to inference mode
# (assumes ObjectDetectionModel follows nn.Module semantics for .to / .eval — confirm).
# The cell's last expression displays the model's repr.
model.to( device )
model.eval()

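## Hooks

Forward hooks capture intermediate activations whenever the model runs a forward pass (e.g. inside `model.predict`). They record the encoder proposal scores (`enc_score_head`) and several backbone feature maps in `hook_outputs`, which the visualisations below read.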

In [None]:
# Latest input/output captured by each forward hook, keyed by capture name.
# Each hook overwrites its entry, so this only ever holds the most recent forward pass.
hook_outputs = {}

def hook_fn(capture):
    """Return a forward hook that stores a module's latest input and output under ``capture``."""
    def execute_hook(module, input, output):
        # Item assignment on the module-level dict: no `global` statement is needed.
        hook_outputs[capture] = {
            'input': input,
            'output': output
        }
    return execute_hook

# Remove hooks registered by a previous run of this cell; otherwise every re-run stacks
# another copy of each hook on the same modules.
for _old_handle in globals().get( 'hook_handles', [] ):
    _old_handle.remove()

backbone_layers = model.backbone.backbone.net.model.layers
hook_targets = {
    'proposal_queries': model.detection.enc_score_head,   # Proposal Queries ( Tokens )
    'features': backbone_layers[0],                       # Backbone
    'features_large_objects': backbone_layers[31],        # Backbone
    'features_medium_objects': backbone_layers[28],       # Backbone
    'features_small_objects': backbone_layers[25],        # Backbone
}
# Keep the handles so the next run of this cell (or a manual cleanup) can remove the hooks.
hook_handles = [ module.register_forward_hook( hook_fn( name ) ) for name, module in hook_targets.items() ]

In [None]:
# Functions to extract the information from the hooks

def occurrence_indices(strings):
    """
    For each element, return how many times that same value appeared before it.

    Example: ['a', 'b', 'a'] -> [0, 0, 1]. Used to give repeated labels unique suffixes.
    """
    seen_so_far = {}
    positions = []
    for value in strings:
        positions.append( seen_so_far.setdefault( value, 0 ) )
        seen_so_far[value] += 1
    return positions

def process_detections_tokens(shapes, num_queries=300, outputs=None):
    """
    Split the top-scoring encoder proposal tokens by feature-map level.

    Reads the logits captured by the 'proposal_queries' hook (the output of
    ``model.detection.enc_score_head``) and scores each token by its best class
    probability (sigmoid). It keeps the ``num_queries`` best tokens of the first
    batch element and assigns each one to the level of ``shapes`` whose slice of
    the flattened token sequence contains it.

    Parameters:
        shapes: sequence of (h, w) feature-map sizes, in the order the tokens were
            flattened (e.g. [(80, 80), (40, 40), (20, 20)]). Tokens beyond
            sum(h * w) are not assigned to any level.
        num_queries: number of top tokens to keep. It is clamped to the number of
            tokens so that small inputs do not make ``torch.topk`` fail.
        outputs: hook-capture dict to read from; defaults to the global ``hook_outputs``.

    Returns:
        (all_top_indices, all_top_values, all_top_indices_idx): three lists with one
        1-D tensor per level. They hold, respectively, the token index relative to
        the start of its level, its score, and its position in the score-sorted
        top-k list.
    """
    if outputs is None:
        outputs = hook_outputs

    # Hooked logits, assumed [bs, h*w, num_classes]; only batch element 0 is used.
    enc_outputs_scores = outputs['proposal_queries']['output'].detach().to('cpu')
    hw_total = enc_outputs_scores.shape[1]

    # Sigmoid (independent per-class probabilities, not softmax), then the best class per token.
    scores = enc_outputs_scores.sigmoid().max(-1).values[0]  # shape: (hw_total,)

    # Top-k tokens across all levels at once; topk returns them in descending score order.
    k = min(num_queries, hw_total)
    topk_values, topk_indices = torch.topk(scores, k)
    query_idx = torch.arange(k)

    all_top_indices, all_top_values, all_top_indices_idx = [], [], []
    level_start = 0
    for h, w in shapes:
        level_end = level_start + h * w
        # This level's tokens occupy [level_start, level_end) of the flattened sequence.
        mask = (topk_indices >= level_start) & (topk_indices < level_end)
        all_top_indices.append(topk_indices[mask] - level_start)
        all_top_values.append(topk_values[mask])
        all_top_indices_idx.append(query_idx[mask])
        level_start = level_end

    return all_top_indices, all_top_values, all_top_indices_idx

def extrcat_features(feats):
    """
    Turn a captured hook entry into keyword arguments for ``interactive_channel_viewer``.

    Takes ``feats['output'][0]`` (presumably the first batch element, shaped (C, H, W);
    TODO confirm) and adds its standard deviation along the first axis to every slice
    along that axis. The result is converted to NumPy.

    Returns:
        dict with 'image' (np.ndarray) and 'layers' (length of the first axis).
    """
    # detach() so .numpy() also works if the forward pass was run with autograd enabled.
    sample = feats['output'][0].detach()
    img = sample + sample.std( 0 )
    return {
        'image': img.cpu().numpy(),
        'layers': len( img )
    }

In [None]:
# Function to visualize the features

def get_proposals_heatmap_overlay(image, labels, indices, detectors, int_to_label, process_detections_tokens, det_to_color,
                                  shapes=((80, 80), (40, 40), (20, 20)), alpha=0.7, padding=40):
    """
    Creates an object proposals heatmap overlay with detector annotations.

    Each top-scoring encoder token returned by ``process_detections_tokens`` is drawn as a
    filled circle at the centre of its feature-map cell, rescaled to ``image``'s resolution.
    The circle's intensity is the token score, and where circles overlap the highest score
    wins. Tokens whose query index is one of ``detectors`` are also outlined and labelled.

    Parameters:
        image (np.array): The original RGB image, uint8 H x W x 3. It must match the
            dtype of the uint8 heatmap for ``cv2.addWeighted``.
        labels (iterable): Class ids of the final detections (elements must support ``.item()``).
        indices (iterable): Per-class occurrence index of each detection, used for unique legend names.
        detectors (iterable): Query (detector) ids of the final detections.
        int_to_label (dict): Mapping from integer label to text label.
        process_detections_tokens (callable): Called as ``process_detections_tokens(shapes)``; must return
            (all_top_indices, all_top_values, all_top_indices_idx) with one entry per scale.
        det_to_color (dict): Mapping from detector id to a color tuple (R, G, B) in 0-255 range.
        shapes (sequence of (h, w)): Feature-map size of each scale, in flattened-token order.
            The default presumably corresponds to a 640x640 input at strides 8/16/32 — confirm.
        alpha (float): Weight of the heatmap when blending it with ``image``.
        padding (int): Width of the black border added around the result, so labels near the edges stay visible.

    Returns:
        np.array: The blended RGB image with the heatmap and detector annotations,
        padded by ``padding`` pixels on every side.
    """
    # Legend names such as "person_0", "person_1", keyed by detector (query) id as a plain int.
    legend = {int(d): int_to_label[l.item()] + '_' + str(i) for l, i, d in zip(labels, indices, detectors)}
    # Plain-int set for membership tests; avoids comparing tensors against an ndarray.
    detector_ids = set(legend)

    all_top_indices, all_top_values, all_top_indices_idx = process_detections_tokens(shapes)

    original_h, original_w = image.shape[0:2]

    # One circle per token across all scales: (x, y, radius, score, query id).
    circles = []
    for (h, w), s_indices, values, dxs in zip(shapes, all_top_indices, all_top_values, all_top_indices_idx):
        scale_y, scale_x = original_h / h, original_w / w
        circle_radius = max(1, int(min(scale_y, scale_x) // 2))
        for idx, val, dx in zip(s_indices, values, dxs):
            feat_y, feat_x = divmod(idx.item(), w)
            # Centre of the feature-map cell in image pixels, clamped inside the image.
            orig_y = min(max(0, int((feat_y + 0.5) * scale_y)), original_h - 1)
            orig_x = min(max(0, int((feat_x + 0.5) * scale_x)), original_w - 1)
            circles.append((orig_x, orig_y, circle_radius, val.item(), dx.item()))

    # Paint in ascending score order so overlapping circles keep the highest score.
    # Painting in scale order would let large coarse-scale circles overwrite stronger fine-scale ones.
    heatmap_overlay = np.zeros((original_h, original_w), dtype=np.float32)
    for cx, cy, radius, score_val, _ in sorted(circles, key=lambda c: c[3]):
        cv2.circle(heatmap_overlay, (cx, cy), radius, score_val, -1)

    # Normalise to [0, 1] so the colormap spans its full range.
    peak = heatmap_overlay.max()
    if peak > 0:
        heatmap_overlay /= peak

    # applyColorMap returns BGR; convert to RGB to match `image`.
    heatmap_color = cv2.applyColorMap((heatmap_overlay * 255).astype(np.uint8), cv2.COLORMAP_HOT)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)

    # addWeighted writes a new array, so `image` itself is left untouched.
    blended = cv2.addWeighted(image, 1 - alpha, heatmap_color, alpha, 0)
    blended = cv2.copyMakeBorder(blended, padding, padding, padding, padding, cv2.BORDER_CONSTANT, value=(0, 0, 0))

    # Outline and label the tokens that became final detections (coordinates shifted by the padding).
    font = cv2.FONT_HERSHEY_SIMPLEX
    for cx, cy, radius, _, query_id in circles:
        if query_id not in detector_ids:
            continue
        center = (cx + padding, cy + padding)
        cv2.circle(blended, center, radius * 3, det_to_color[query_id], thickness=2)
        cv2.putText(blended, legend[query_id], (center[0] + 10, center[1] - 15), font, 0.4, (255, 255, 255),
                    thickness=1, lineType=cv2.LINE_AA)

    return blended

def interactive_channel_viewer(image, layers=None):
    """
    Build a Plotly heatmap with a slider that steps through the channels of a feature map.

    Parameters:
        image (np.ndarray): array where ``image[i]`` is the 2-D (H, W) map of channel i.
        layers (int, optional): number of channels exposed on the slider; defaults to ``len(image)``.

    Returns:
        plotly.graph_objects.Figure: call ``.show()`` on it to render.
    """
    if layers is None:
        layers = len(image)

    # One shared colour range for every channel so intensities are comparable across slider steps.
    zmin, zmax = float(np.min(image)), float(np.max(image))

    fig = go.Figure(
        data=go.Heatmap(
            z=image[0],
            colorscale='gray',
            zmin=zmin,
            zmax=zmax,
            colorbar=dict(title="Intensity")
        )
    )

    # Each slider step swaps the heatmap's z data for one channel.
    steps = [
        dict(method="update", args=[{"z": [image[i]]}], label=f"Channel {i}")
        for i in range(layers)
    ]

    sliders = [dict(
        active=0,
        currentvalue={"prefix": "Channel: "},
        pad={"t": 50},
        steps=steps
    )]

    fig.update_layout(
        title="Interactive Channel Viewer",
        sliders=sliders,
        xaxis=dict(title="Width"),
        # Plotly heatmaps put row 0 at the bottom by default. Reverse the y axis so the map has
        # the same orientation as the input image (row 0 at the top), and keep pixels square.
        yaxis=dict(title="Height", autorange="reversed", scaleanchor="x")
    )

    return fig


In [None]:
import cv2

# Load the test image as RGB (OpenCV decodes to BGR) and build a [0, 1] float CHW tensor.
IMAGE_PATH = 'test_images/8.png'
image = cv2.imread( IMAGE_PATH )
if image is None:
    # cv2.imread signals a missing or unreadable file by returning None instead of raising.
    raise FileNotFoundError( f"Could not read image: {IMAGE_PATH}" )
image = cv2.cvtColor( image, cv2.COLOR_BGR2RGB )
torch_image = torch.from_numpy( image ).permute( 2, 0, 1 ).float() / 255.0
torch_image = torch_image.to( device )

# resize the image to 640x640
torch_image = resize_image( torch_image, 640 )

In [None]:
# Inference settings passed to model.predict.
STRIDE_SLICES = 32
CONFIDENCE_THRESHOLD = 0.3
IOU_THRESHOLD = 0.3

boxes, scores, labels, detectors = model.predict(
    torch_image,
    stride_slices = STRIDE_SLICES,
    confidence_threshold = CONFIDENCE_THRESHOLD,
    iou_threshold = IOU_THRESHOLD,
)
# Detector (query) ids as a NumPy array, plus a per-class running index for every detection.
detectors = detectors.cpu().numpy()
indices = occurrence_indices( labels.cpu().numpy() )

In [None]:
# Shape of the first element captured by the 'features' hook (output of backbone layer 0).
hook_outputs['features']['output'][0].size()

In [None]:
# Build an interactive per-channel viewer for the backbone layer-0 features.
fig = interactive_channel_viewer(
    **extrcat_features( feats = hook_outputs['features'] )
)

In [None]:
# Render the interactive channel viewer built in the previous cell.
fig.show()

In [None]:
# Inspect the predicted class ids alongside their detector (query) ids.
( labels, detectors )

In [None]:
from PIL import ImageDraw, ImageFont
from torchvision.transforms import ToPILImage

# Annotate the original-resolution RGB image with every prediction: box, class name, score and detector id.
image_height, image_width = image.shape[0:2]
image_ = ToPILImage()( image )
draw = ImageDraw.Draw( image_ )
font = ImageFont.load_default()

# Convert predictions to plain Python numbers first. If the model returns CUDA tensors,
# building NumPy arrays or PIL coordinates from them directly fails.
boxes_cxcywh = boxes.detach().cpu().tolist()
scores_list = scores.detach().cpu().tolist()
labels_list = labels.detach().cpu().tolist()

for ( x_center, y_center, width, height ), score, label, detector in zip( boxes_cxcywh, scores_list, labels_list, detectors ):

    # Boxes are treated as normalised (cx, cy, w, h); convert to pixel corner coordinates.
    x_min = ( x_center - width / 2 ) * image_width
    y_min = ( y_center - height / 2 ) * image_height
    x_max = ( x_center + width / 2 ) * image_width
    y_max = ( y_center + height / 2 ) * image_height

    label_name = int_to_label[label]
    score = score * 100
    detector_number = int( detector )
    draw.rectangle( [ x_min, y_min, x_max, y_max ], outline = int_to_color[label], width = 2 )
    draw.text( ( x_min, y_min ), f"{label_name} \n {score:.2f}% \n {detector_number}", fill = "red", font = font )

image_.save( "test_images/8_p.png" )
# Display inline as the cell's last expression; PIL's .show() would open an external viewer instead.
image_

In [None]:
# NOTE: this rebinds `image` to the blended overlay, so re-run the image-loading cell
# before re-running any earlier cell that expects the raw photo.
image = get_proposals_heatmap_overlay(
    image, labels, indices, detectors,
    int_to_label = int_to_label,
    process_detections_tokens = process_detections_tokens,
    det_to_color = det_to_color,
)

In [None]:
# draw the image
plt.imshow( image )
plt.axis( 'off' )
plt.title( 'Proposals Heatmap Overlay' )
plt.show()