In [7]:
import os, sys
#sys.path.append('/home/biaslab/Zhen/HDC_DINO')
import torch, json
import numpy as np
import cv2
from PIL import Image
from main import build_model_main
from util.slconfig import SLConfig
from datasets import build_dataset
from util.visualizer import COCOVisualizer
from util import box_ops
from torchvision.transforms import functional as F

In [8]:
# Locations of the DINO model config and of the pretrained checkpoint.
# Change these to match your machine; the Model Zoo section in README.md
# lists the available pretrained models.
model_config_path = (
    "/home/biaslab/Zhen/HDC_DINO"
    "/config/DINO/DINO_4scale.py"
)
model_checkpoint_path = (
    "/home/biaslab/Zhen/HDC_DINO"
    "/checkpoint/checkpoint0033_4scale.pth"
)

In [9]:
# Build DINO from its config and load the pretrained checkpoint.
# NOTE: `postprocessors` created here is also read as a notebook global by vis_frame,
# so this cell must be run before vis_frame / main are used.
args = SLConfig.fromfile(model_config_path)
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, criterion, postprocessors = build_model_main(args)
checkpoint = torch.load(model_checkpoint_path, map_location=args.device)
state_dict = checkpoint['model']
# strict=False tolerates key mismatches; the returned report lists any
# missing / unexpected keys, so it is displayed below instead of discarded.
load_report = model.load_state_dict(state_dict, strict=False)
# This notebook only runs inference, so switch the model out of training mode.
_ = model.eval()
load_report

  from .autonotebook import tqdm as notebook_tqdm
  checkpoint = torch.load(model_checkpoint_path, map_location=args.device)


<All keys matched successfully>

In [10]:
# Load the COCO category-id -> class-name mapping. JSON object keys are always
# strings, so they are cast back to int to match the model's integer labels.
with open('/home/biaslab/Zhen/HDC_DINO/util/coco_id2name.json') as f:
    id2name = {int(class_id): class_name for class_id, class_name in json.load(f).items()}

In [11]:
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
from pycocotools import mask as maskUtils
from matplotlib import transforms

def addtgt_cv(frame, tgt):
    """
    Draw bounding boxes, optional per-box labels and an optional caption onto an OpenCV frame.

    Box fills are alpha-blended (alpha = 0.1). Borders and label text are drawn
    after the blend at full opacity, so the blend does not fade them.

    Args:
        frame (np.ndarray): Image of shape (H, W, C) in BGR order. It is drawn on in place.
        tgt (dict): A dictionary containing:
            - 'boxes': Tensor of shape [num_boxes, 4] in normalized (cx, cy, w, h) format.
            - 'size': Tensor or sequence [H, W] of the canonical image the boxes refer to.
            - 'box_label': (optional) sequence with exactly one label per box.
            - 'caption': (optional) caption string drawn at the top-left of the frame.
    Returns:
        np.ndarray: The annotated frame (the same array object as ``frame``).
    Raises:
        AssertionError: if 'box_label' is given and its length differs from the number of boxes.
    """
    # Canonical (prediction) size vs. actual frame size; as_tensor also accepts a plain list.
    target_h, target_w = torch.as_tensor(tgt['size']).tolist()
    frame_h, frame_w = frame.shape[:2]
    scale = torch.tensor([frame_w / target_w, frame_h / target_h] * 2)

    # Normalized (cx, cy, w, h) -> pixel (x, y, w, h) on the canonical image (top-left
    # anchored), then rescaled to the frame and rounded to integer pixel coordinates.
    boxes_xywh = tgt['boxes'].detach().cpu() * torch.tensor([target_w, target_h, target_w, target_h])
    boxes_xywh[:, :2] -= boxes_xywh[:, 2:] / 2
    boxes = (boxes_xywh * scale).round().int().tolist()

    # One random light BGR color per box (channel values roughly in [102, 255]).
    colors = [tuple(int(v) for v in (np.random.rand(3) * 0.6 + 0.4) * 255) for _ in boxes]

    labels = tgt['box_label'] if 'box_label' in tgt else None
    if labels is not None:
        assert len(labels) == len(boxes), f"got {len(labels)} labels for {len(boxes)} boxes"

    # Translucent fills: draw them on a copy and blend it back into the frame.
    overlay = frame.copy()
    for (x, y, w, h), color in zip(boxes, colors):
        cv2.rectangle(overlay, (x, y), (x + w, y + h), color, -1)
    alpha = 0.1  # opacity of the box fills
    cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)

    # Opaque borders and labels on top of the blended result.
    for i, ((x, y, w, h), color) in enumerate(zip(boxes, colors)):
        cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)
        if labels is not None:
            # putText's origin is the text baseline; clamp it so the label stays
            # visible for boxes touching the top or left edge of the frame.
            text_org = (max(x, 0), max(y - 5, 12))
            cv2.putText(frame, str(labels[i]), text_org,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)

    if 'caption' in tgt:
        cv2.putText(frame, tgt['caption'], (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2, cv2.LINE_AA)

    return frame


def addtgt(tgt):
    """
    Draw boxes (with optional labels and caption) on the current matplotlib axes.

    - tgt: dict with keys:
        - boxes: Tensor [num_boxes, 4], normalized (cx, cy, w, h) in [0, 1].
        - size: Tensor or sequence [H, W] of the displayed image, used to un-normalize the boxes.
        - box_label: (optional) exactly one label per box.
        - caption: (optional) string used as the axes title.

    Raises AssertionError if 'boxes' is missing or the label count differs from the box count.
    """
    assert 'boxes' in tgt
    ax = plt.gca()
    H, W = torch.as_tensor(tgt['size']).tolist()
    numbox = tgt['boxes'].shape[0]

    color = []
    polygons = []
    boxes = []
    for box in tgt['boxes'].cpu():
        # Normalized (cx, cy, w, h) -> pixel (x, y, w, h), anchored at the top-left corner.
        unnormbbox = box * torch.Tensor([W, H, W, H])
        unnormbbox[:2] -= unnormbbox[2:] / 2
        bbox_x, bbox_y, bbox_w, bbox_h = unnormbbox.tolist()
        boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
        corners = np.array([
            [bbox_x, bbox_y],
            [bbox_x, bbox_y + bbox_h],
            [bbox_x + bbox_w, bbox_y + bbox_h],
            [bbox_x + bbox_w, bbox_y],
        ])
        polygons.append(Polygon(corners))
        # Random light RGB color with channels in [0.4, 1.0].
        color.append((np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0])

    # Translucent fills first, then opaque outlines on top.
    ax.add_collection(PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1))
    ax.add_collection(PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2))

    if 'box_label' in tgt:
        assert len(tgt['box_label']) == numbox, (
            f"expected one label per box: got {len(tgt['box_label'])} labels for {numbox} boxes")
        for (bbox_x, bbox_y, _, _), bl, c in zip(boxes, tgt['box_label'], color):
            ax.text(bbox_x, bbox_y, str(bl), color='black', bbox={'facecolor': c, 'alpha': 0.6, 'pad': 1})

    if 'caption' in tgt:
        ax.set_title(tgt['caption'], wrap=True)

In [12]:
from util.visualizer import COCOVisualizer, renorm
from PIL import Image
import datetime
import datasets.transforms as T
import torchvision.transforms as TF
# Per-frame preprocessing: PIL image -> tensor, then per-channel mean/std
# normalization using the standard ImageNet statistics.
transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean
            [0.229, 0.224, 0.225],  # per-channel std
        ),
    ]
)

def vis_frame(frame, model, savedir='/home/biaslab/Zhen/HDC_DINO/Demo/processed_frame', thershold=0.3,
              caption=None, postprocessors=None):
    """Run the detector on one BGR video frame and draw the detections onto it.

    Args:
        frame (np.ndarray): BGR image (e.g. from ``cv2.VideoCapture.read``). It is
            annotated in place and also returned.
        model: DINO detector. It is moved to CUDA (CPU if CUDA is unavailable) on every call.
        savedir: Unused; kept only for backward compatibility.
        thershold (float): Minimum score for a detection to be drawn. The misspelled
            name is kept because callers pass it by keyword.
        caption (str, optional): If given, it is drawn at the top-left of the frame.
        postprocessors (dict, optional): Postprocessors returned by ``build_model_main``.
            Defaults to the notebook-global ``postprocessors``.

    Returns:
        np.ndarray: The annotated frame.

    Side effects:
        Overwrites ``top25.txt`` in the working directory and prints the 25
        highest-scoring predictions on every call.
    """
    if postprocessors is None:
        # Backward-compatible fallback: the global built in the model-loading cell.
        postprocessors = globals()['postprocessors']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    image = TF.Resize((368, 600))(image)
    image, _ = transform(image, None)
    with torch.no_grad():
        output = model.to(device)(image[None].to(device))
        # A (1, 1) target size keeps the post-processed boxes normalized to [0, 1].
        output = postprocessors['bbox'](output, torch.ones(1, 2, device=device))[0]

    scores = output['scores']
    labels = output['labels']
    boxes = box_ops.box_xyxy_to_cxcywh(output['boxes'])

    def label_name(label):
        # Class ids missing from coco_id2name.json fall back to the raw id instead of raising.
        return id2name.get(int(label), str(int(label)))

    # Debug log of the 25 highest-scoring predictions, regardless of the threshold.
    top_idx = torch.argsort(scores, descending=True)[:25]
    top25 = [(scores[i], label_name(labels[i])) for i in top_idx]
    with open('top25.txt', 'w') as f:
        for score, label in top25:
            f.write(f"{score.item()}: {label}\n")
    print(top25)

    select_mask = scores > thershold
    pred_dict = {
        'boxes': boxes[select_mask],
        'size': torch.Tensor([image.shape[1], image.shape[2]]),
        'box_label': [label_name(item) for item in labels[select_mask]],
    }
    if caption is not None:
        pred_dict['caption'] = caption
    return addtgt_cv(frame, pred_dict)



In [13]:
def main(video_path='/home/biaslab/Zhen/HDC_DINO/videos/0001.mp4',
         output_path='/home/biaslab/Zhen/HDC_DINO/output_hd_100.mp4',
         class_head_path='/home/biaslab/Zhen/HDC_DINO/Demo/hd_ckpt_100/hd_classification_head_epoch_11.pth',
         bbox_head_path='/home/biaslab/Zhen/HDC_DINO/Demo/hd_ckpt_100/regression_head_epoch_11.pth',
         threshold=0.3):
    """Run DINO with the custom heads on a video and write an annotated copy.

    Args:
        video_path: Input video, opened with ``cv2.VideoCapture``.
        output_path: Destination of the annotated video (``mp4v`` codec, same fps/size as input).
        class_head_path: State dict loaded into ``model.class_embed``.
        bbox_head_path: State dict loaded into ``model.bbox_embed``.
        threshold: Minimum detection score drawn on each frame.

    Raises:
        RuntimeError: If a head checkpoint's keys do not match the model's head modules.
        IOError: If the input video cannot be opened or the output video cannot be created.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    args = SLConfig.fromfile(model_config_path)
    args.device = str(device)
    # Only the model is used here; vis_frame reads the notebook-global `postprocessors`.
    model, _criterion, _postprocessors = build_model_main(args)

    # Load the complete pretrained checkpoint first. It is intentionally NOT filtered on
    # "class_embed". NOTE(review): that substring may also match other heads in the
    # checkpoint (e.g. DINO's `transformer.enc_out_class_embed`) that must keep their
    # pretrained weights — verify before filtering.
    checkpoint = torch.load(model_checkpoint_path, map_location=args.device)
    model.load_state_dict(checkpoint['model'], strict=False)

    # Overwrite the classification and box-regression heads with the custom ones.
    # NOTE(review): the saved class head has keys like "0.encoder.weight" / "0.model.weight",
    # whereas model.class_embed expects "0.weight" / "0.bias". model.class_embed must be
    # replaced by the matching head module before this load can succeed.
    model.class_embed.load_state_dict(torch.load(class_head_path, map_location=args.device))
    model.bbox_embed.load_state_dict(torch.load(bbox_head_path, map_location=args.device))
    model.to(device)
    model.eval()

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open input video: {video_path}")
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, frame_size)
    if not out.isOpened():
        cap.release()
        raise IOError(f"Could not create output video: {output_path}")

    try:
        while True:
            ret, frame = cap.read()
            if not ret:  # end of stream (or read error)
                break
            out.write(vis_frame(frame, model, thershold=threshold))
    finally:
        # Always release both handles so the output file is finalized even if a frame fails.
        cap.release()
        out.release()
    print(f"Processed video saved at {output_path}")

In [14]:
# Run the whole pipeline: build the model, load the pretrained checkpoint and the
# custom heads, then annotate every frame of the input video and write the result.
main()

Using device: cuda


  checkpoint = torch.load(model_checkpoint_path, map_location=args.device)
  custom_class_embed = torch.load("/home/biaslab/Zhen/HDC_DINO/Demo/hd_ckpt_100/hd_classification_head_epoch_11.pth", map_location=args.device)
  custom_bbox_embed = torch.load("/home/biaslab/Zhen/HDC_DINO/Demo/hd_ckpt_100/regression_head_epoch_11.pth", map_location=args.device)


RuntimeError: Error(s) in loading state_dict for ModuleList:
	Missing key(s) in state_dict: "0.weight", "0.bias", "1.weight", "1.bias", "2.weight", "2.bias", "3.weight", "3.bias", "4.weight", "4.bias", "5.weight", "5.bias". 
	Unexpected key(s) in state_dict: "0.encoder.weight", "0.encoder.bias", "0.model.weight", "0.model.bias", "1.encoder.weight", "1.encoder.bias", "1.model.weight", "1.model.bias", "2.encoder.weight", "2.encoder.bias", "2.model.weight", "2.model.bias", "3.encoder.weight", "3.encoder.bias", "3.model.weight", "3.model.bias", "4.encoder.weight", "4.encoder.bias", "4.model.weight", "4.model.bias", "5.encoder.weight", "5.encoder.bias", "5.model.weight", "5.model.bias". 