In [1]:
import warnings
warnings.filterwarnings('ignore')
import os
cwd = os.getcwd()
# Paths used below (projects/configs/..., ckpt/...) are relative to the repo root,
# so step up one directory when the notebook is launched from the tutorial folder.
launched_from_tutorial = cwd.endswith("tutorial")
if launched_from_tutorial:
    os.chdir(os.pardir)

from pprint import pprint
import numpy as np
import matplotlib.pyplot as plt
import torch
from mmcv import Config
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmcv.parallel import scatter
from mmcv.cnn.utils.flops_counter import add_flops_counting_methods

from projects.mmdet3d_plugin.datasets.builder import build_dataloader
from projects.mmdet3d_plugin.datasets.utils import draw_lidar_bbox3d

In [2]:
gpu_id = 0
config = "sparse4dv3_temporal_r50_1x8_bs6_256x704"
# The checkpoint file shares its base name with the config.
checkpoint = f"ckpt/{config}.pth"
# checkpoint = "ckpt/sparse4dv3_r50.pth"

config_path = f"projects/configs/{config}.py"
cfg = Config.fromfile(config_path)
# Uncomment to turn off the deformable-aggregation op (presumably needed when that
# custom op has not been compiled — confirm against the project README).
# cfg.model["use_deformable_func"] = False
# cfg.model["head"]["deformable_model"]["use_deformable_func"] = False

# Per-channel normalization statistics; used later to turn network inputs back into images.
norm_stats = cfg.img_norm_cfg
img_norm_mean = np.array(norm_stats["mean"])
img_norm_std = np.array(norm_stats["std"])

In [3]:
# Validation split, unshuffled, one sample per batch, no worker processes.
dataset = build_dataset(cfg.data.val)
loader_kwargs = dict(
    samples_per_gpu=1,
    workers_per_gpu=0,
    dist=False,
    shuffle=False,
)
dataloader = build_dataloader(dataset, **loader_kwargs)

# Pull one batch for inspection and scatter it onto `gpu_id`.
# Note: `data_iter` is now positioned past the first sample.
data_iter = iter(dataloader)
first_batch = next(data_iter)
data = scatter(first_batch, [gpu_id])[0]

{'version': 'v1.0-trainval'}


In [4]:
model = build_detector(cfg.model)
model = model.cuda(gpu_id)

# Load tensors onto the CPU first so the device they were saved from does not matter;
# load_state_dict then copies them into the (already on-GPU) parameters.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
ckpt = torch.load(checkpoint, map_location="cpu")
# This checkpoint nests its weights under "state_dict"; also accept a bare state dict.
state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
load_result = model.load_state_dict(state_dict, strict=False)
# strict=False tolerates key mismatches; report them instead of silently discarding them
# (print, not warnings.warn, because warnings are filtered out in the first cell).
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Missing keys ({len(load_result.missing_keys)}): {load_result.missing_keys[:10]}")
    print(f"Unexpected keys ({len(load_result.unexpected_keys)}): {load_result.unexpected_keys[:10]}")
model = model.eval()
# assert model.use_deformable_func, "Please compile deformable aggregation first !!!"

In [5]:
import warnings
warnings.filterwarnings('ignore')
import os
import numpy as np
import torch
import av  # PyAV library
from mmcv.parallel import scatter
from PIL import Image, ImageDraw, ImageFont

# Directory for the per-frame PNGs written by the rendering loop.
output_dir = 'output'
# exist_ok avoids the check-then-create race and makes the cell safely re-runnable.
os.makedirs(output_dir, exist_ok=True)

# Class names paired with their RGB display colors. Order is assumed to match the
# model's class-index order (predicted class ids index into ID_COLOR_MAP) — verify
# against the config's class list. Keeping pairs in one mapping prevents drift.
_CLASS_COLORS_RGB = {
    "car": (59, 59, 238),                 # blue
    "truck": (0, 255, 0),                 # green
    "construction_vehicle": (255, 0, 0),  # red
    "bus": (255, 255, 0),                 # yellow
    "trailer": (0, 255, 255),             # cyan
    "barrier": (255, 0, 255),             # magenta
    "motorcycle": (255, 255, 255),        # white
    "bicycle": (255, 127, 0),             # orange
    "pedestrian": (71, 130, 255),         # light blue
    "traffic_cone": (127, 127, 0),        # olive
}
CLASSES = tuple(_CLASS_COLORS_RGB)
ID_COLOR_MAP = list(_CLASS_COLORS_RGB.values())

# Reorder (R, G, B) triples into (B, G, R), e.g. for drawing with OpenCV on BGR images.
def convert_rgb_to_bgr(color_map):
    """Return a new list with each (r, g, b) color reversed to (b, g, r).

    Every entry must have exactly three components; other lengths raise ValueError.
    """
    swapped = []
    for red, green, blue in color_map:
        swapped.append((blue, green, red))
    return swapped

# BGR copy of the class palette. The legend and box drawing in this notebook use the
# RGB ID_COLOR_MAP directly; this variant is only for consumers that expect BGR.
ID_COLOR_MAP_BGR = convert_rgb_to_bgr(color_map=ID_COLOR_MAP)

# Draw a class-color legend in the top-left corner of a PIL image.
def draw_legend(image):
    """Draw one color swatch plus class name per entry of CLASSES onto `image`.

    Swatches use the RGB colors from ID_COLOR_MAP and labels are white text.
    The image is modified in place and also returned for convenience.
    """
    canvas = ImageDraw.Draw(image)
    legend_font = ImageFont.load_default()
    left, top = 10, 10  # legend origin in pixels
    for row, (label, swatch_rgb) in enumerate(zip(CLASSES, ID_COLOR_MAP)):
        row_top = top + row * 20  # 20 px per legend row
        canvas.rectangle([left, row_top, left + 15, row_top + 15], fill=swatch_rgb)
        canvas.text((left + 20, row_top), label, fill=(255, 255, 255), font=legend_font)
    return image

# Video output settings; the frame size is taken from the first rendered frame.
frame_width = None
frame_height = None

output_video_path = "output_detections_video_pyav.mp4"
fps = 5  # Frames per second
score_threshold = 0.35  # minimum sigmoid class score for a box to be drawn
max_frames = None  # set e.g. 100 for a quick preview; None renders every sample


def _render_detections(data, score_threshold):
    """Run the model on one scattered batch and return a PIL image with boxes + legend.

    Only batch element 0 is rendered (the dataloader is built with samples_per_gpu=1).
    """
    feature_maps = model.extract_feat(data["img"], metas=data)
    model_outs = model.head(feature_maps, data)

    # Decoded boxes from the last entry of the prediction list, batch element 0.
    pred_bbox3d = model.head.decoder.decode_box(model_outs["prediction"][-1][0])

    # One max over the class logits yields both the score and the class id; sigmoid is
    # monotonic, so sigmoid(max logit) equals the max per-class sigmoid score.
    cls_logits = model_outs["classification"][-1][0]
    max_logits, predicted_classes = cls_logits.max(dim=-1)
    class_scores = max_logits.sigmoid()
    mask = class_scores > score_threshold

    # Undo input normalization; clip before the uint8 cast so values drifting slightly
    # outside [0, 255] don't wrap around.
    raw_imgs = data["img"][0].permute(0, 2, 3, 1).cpu().numpy()
    raw_imgs = np.clip(raw_imgs * img_norm_std + img_norm_mean, 0, 255).astype(np.uint8)

    # ID_COLOR_MAP is RGB; this assumes the denormalized images are RGB too
    # (img_norm_cfg to_rgb=True) — TODO confirm against the config.
    colors = [
        ID_COLOR_MAP[class_idx % len(ID_COLOR_MAP)]
        for class_idx in predicted_classes[mask].tolist()
    ]
    img_with_detections = draw_lidar_bbox3d(
        pred_bbox3d[mask],
        raw_imgs,
        data["projection_mat"][0],
        color=colors,
    )
    return draw_legend(Image.fromarray(img_with_detections.astype(np.uint8)))


def _open_video_stream(path, width, height, rate):
    """Open `path` for writing and add an MPEG-4 video stream of the given size."""
    out_container = av.open(path, mode='w')
    try:
        out_stream = out_container.add_stream('mpeg4', rate=rate)
        out_stream.width = width
        out_stream.height = height
        out_stream.pix_fmt = 'yuv420p'  # Set pixel format for compatibility
        out_stream.bit_rate = 5000000  # 5 Mbps (adjust as needed for higher quality)
        out_stream.options = {'qscale:v': '2'}  # Lower qscale = higher quality
    except Exception:
        out_container.close()
        raise
    return out_container, out_stream


# Reset the head's instance bank before the sequence (presumably clears temporal state
# cached by earlier inference calls — TODO confirm).
model.head.instance_bank.reset()

# Zero-pad frame indices to the dataset length so the PNGs sort in frame order.
frame_index_width = max(3, len(str(len(dataloader))))

container = None
stream = None
try:
    # Iterate a fresh pass over `dataloader`: `data_iter` from the loading cell has already
    # consumed the first sample, which would otherwise be silently skipped.
    with torch.no_grad():  # inference only — don't build autograd graphs
        for i, data in enumerate(dataloader):
            if max_frames is not None and i >= max_frames:
                break
            data = scatter(data, [gpu_id])[0]
            img_rgb = _render_detections(data, score_threshold)

            if stream is None:
                frame_width, frame_height = img_rgb.size
                container, stream = _open_video_stream(
                    output_video_path, frame_width, frame_height, fps
                )

            # Save the frame as an image for debugging
            img_rgb.save(os.path.join(output_dir, f"frame_{i:0{frame_index_width}d}.png"))

            for packet in stream.encode(av.VideoFrame.from_image(img_rgb)):
                container.mux(packet)

    if stream is not None:
        # Flush frames still buffered in the encoder.
        for packet in stream.encode():
            container.mux(packet)
finally:
    # Always close the container so an interrupted or failed run still releases the file.
    if container is not None:
        container.close()

if stream is None:
    print("No frames were rendered; no video was written.")
else:
    print(f"Video saved to {output_video_path}")
print(f"Frames saved to {output_dir}/")


Video saved to output_detections_video_pyav.mp4
Frames saved to output/


In [6]:
# Manual finalization — only meaningful if the rendering cell above was interrupted
# before it flushed and closed the container itself. After a normal run that cell has
# already done both, so re-running this is unnecessary and may raise.
container.mux(stream.encode())  # flush packets still buffered in the encoder
container.close()
