In [1]:
import sys

# Make the local DETR checkout importable (`from models import build_model` below).
# Guarded so re-running this cell does not keep appending duplicate entries.
DETR_REPO_DIR = '/teamspace/studios/this_studio/detr'
if DETR_REPO_DIR not in sys.path:
    sys.path.append(DETR_REPO_DIR)

In [4]:
!pip install opencv-python


Collecting opencv-python
  Downloading opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Downloading opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (62.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.5/62.5 MB[0m [31m142.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: opencv-python
Successfully installed opencv-python-4.10.0.84


In [2]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

from models import build_model

# 1. Load the model
def load_model(checkpoint_path, device):
    """Build a DETR model and load its weights from ``checkpoint_path``.

    The hyper-parameters in ``Args`` mirror the training configuration and must
    match the checkpoint's architecture (backbone, hidden size, number of
    queries, encoder/decoder depth), otherwise the strict ``load_state_dict``
    call fails on missing/mismatched keys.

    Args:
        checkpoint_path: path to a file written by ``torch.save``. Either a dict
            holding the weights under ``'model'`` (DETR's checkpoint format) or a
            bare state dict.
        device: device string or ``torch.device``; the model is moved there and
            the checkpoint tensors are mapped onto it.

    Returns:
        The model in eval mode, on ``device``.
    """
    # Stand-in for DETR's argparse namespace; build_model reads its attributes.
    # NOTE(review): training-only fields (lr, epochs, batch_size, ...) are
    # presumably unused at inference but kept to match the training args.
    class Args:
        def __init__(self):
            self.device = device
            self.backbone = 'resnet50'
            self.enc_layers = 6
            self.dec_layers = 6
            self.hidden_dim = 256
            self.num_queries = 100
            self.dropout = 0.1
            self.position_embedding = 'sine'
            self.lr = 1e-4
            self.lr_backbone = 1e-5
            self.batch_size = 2
            self.epochs = 300
            self.lr_drop = 200
            self.set_cost_class = 1
            self.set_cost_bbox = 5
            self.set_cost_giou = 2
            self.mask_loss_coef = 1
            self.dice_loss_coef = 1
            self.bbox_loss_coef = 5
            self.giou_loss_coef = 2
            self.eos_coef = 0.1
            self.dataset_file = 'coco'
            self.num_workers = 2
            self.seed = 42
            self.output_dir = ''
            self.resume = ''
            self.eval = True
            self.coco_path = ''
            self.coco_panoptic_path = ''
            self.masks = False
            self.dilation = False
            self.nheads = 8
            self.dim_feedforward = 2048
            self.pre_norm = False
            self.aux_loss = True

    # build_model returns (model, criterion, postprocessors); only the model is needed.
    args = Args()
    model, _, _ = build_model(args)
    model.to(device)

    # NOTE: torch.load unpickles the file, which can execute arbitrary code --
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Accept both {'model': state_dict, ...} (DETR format) and a bare state dict.
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()
    return model

# 2. Image preprocessing (resize + ImageNet normalization)
def preprocess_image(image):
    """Prepare a PIL image for DETR inference.

    The shorter side is resized to 800 px, the image is converted to a float
    tensor, and ImageNet mean/std normalization is applied. A leading batch
    dimension is added, so an RGB input yields a tensor of shape (1, 3, H', W').

    Returns:
        (image, image_tensor): the unmodified input image and the model input.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pipeline = T.Compose([
        T.Resize(800),
        T.ToTensor(),
        T.Normalize(imagenet_mean, imagenet_std),
    ])
    batched = pipeline(image).unsqueeze(0)
    return image, batched

# 3. Convert boxes from normalized (cx, cy, w, h) to pixel (xmin, ymin, xmax, ymax)
def box_cxcywh_to_xyxy(box, image_size):
    """Map normalized center-format boxes to pixel corner-format boxes.

    Args:
        box: tensor of shape (N, 4) holding (cx, cy, w, h) relative to the image.
        image_size: (width, height) of the image in pixels.

    Returns:
        Tensor of shape (N, 4) holding (xmin, ymin, xmax, ymax) in pixels.
    """
    img_w, img_h = image_size
    center_x, center_y, width, height = box.unbind(1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = [
        (center_x - half_w) * img_w,
        (center_y - half_h) * img_h,
        (center_x + half_w) * img_w,
        (center_y + half_h) * img_h,
    ]
    return torch.stack(corners, dim=1)

# 4. Post-process the results
def post_process_outputs(outputs, image_size, threshold=0.1):
    """Turn raw DETR outputs for the first image of a batch into detections.

    Args:
        outputs: model output dict with 'pred_logits' and 'pred_boxes'; only
            batch index 0 is used.
        image_size: (width, height) of the original image in pixels.
        threshold: detections whose best class probability is not strictly
            greater than this are dropped.

    Returns:
        (boxes, labels, scores): pixel-space (xmin, ymin, xmax, ymax) boxes, the
        argmax class id per box, and that class's probability.
    """
    logits = outputs['pred_logits'][0]
    boxes_norm = outputs['pred_boxes'][0]

    # Softmax over classes, then drop the last column (DETR's "no object" slot)
    # before picking the best class per query.
    class_probs = logits.softmax(-1)[:, :-1]
    scores, labels = class_probs.max(dim=-1)

    confident = scores > threshold
    pixel_boxes = box_cxcywh_to_xyxy(boxes_norm[confident], image_size)
    return pixel_boxes, labels[confident], scores[confident]

# 5. Visualize the results
def visualize_results(image, boxes, labels, scores):
    """Display an image with its detections drawn as labelled red boxes.

    Args:
        image: image accepted by ``imshow`` (e.g. a PIL image).
        boxes: iterable of pixel-space (xmin, ymin, xmax, ymax) box tensors.
        labels: iterable of class-id tensors, one per box.
        scores: iterable of confidence scores, one per box.
    """
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(image)

    for box, label, score in zip(boxes, labels, scores):
        xmin, ymin, xmax, ymax = box.tolist()
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color="red", linewidth=2))
        ax.text(xmin, ymin, f"Label: {label.item()} | Score: {score:.2f}",
                bbox=dict(facecolor="yellow", alpha=0.5), fontsize=10, color="black")

    ax.axis("off")
    plt.show()
    # This is called once per video frame; close the figure explicitly so
    # figures do not accumulate in memory on non-inline backends.
    plt.close(fig)

# 6. Process video frames
def _draw_detections(frame_bgr, boxes, labels, scores):
    """Draw pixel-space (xmin, ymin, xmax, ymax) detections onto a BGR frame in place."""
    for box, label, score in zip(boxes, labels, scores):
        xmin, ymin, xmax, ymax = (round(v) for v in box.tolist())
        cv2.rectangle(frame_bgr, (xmin, ymin), (xmax, ymax), (0, 0, 255), 2)
        caption = f"Label: {int(label)} | Score: {float(score):.2f}"
        # Keep the caption inside the frame when a box touches the top edge.
        text_y = max(ymin - 5, 15)
        cv2.putText(frame_bgr, caption, (xmin, text_y), cv2.FONT_HERSHEY_SIMPLEX,
                    0.5, (0, 0, 255), 1, cv2.LINE_AA)
    return frame_bgr


def process_video(video_path, model, device, output_path, threshold=0.1, visualize=True):
    """Run DETR on every frame of a video and save annotated frames as PNGs.

    Each frame is converted BGR -> RGB, preprocessed, run through ``model``,
    filtered with ``threshold``, and written to
    ``<output_path>/frame_<n>.png`` (``n`` starts at 1) with its detections
    drawn on it.

    Args:
        video_path: input video path (anything ``cv2.VideoCapture`` can open).
        model: DETR model in eval mode, already on ``device``.
        device: device the input tensors are moved to.
        output_path: directory for the output frames; created if missing.
        threshold: minimum class probability for a detection to be kept.
        visualize: if True, also display every frame that has detections
            inline via ``visualize_results``.

    Raises:
        OSError: if a frame cannot be written to ``output_path``.
    """
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        print(f"Error: Could not open video at {video_path}")
        return

    # cv2.imwrite neither creates directories nor raises: it just returns False.
    os.makedirs(output_path, exist_ok=True)

    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            print(f"Processing frame {frame_count}")

            # DETR expects RGB; OpenCV decodes frames as BGR.
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            image, image_tensor = preprocess_image(image)
            image_size = image.size  # (width, height)

            with torch.no_grad():
                outputs = model(image_tensor.to(device))

            boxes, labels, scores = post_process_outputs(outputs, image_size, threshold=threshold)

            if visualize and len(boxes) > 0:
                visualize_results(image, boxes, labels, scores)

            # Draw the detections on the BGR frame so they end up in the saved PNG.
            annotated = _draw_detections(frame, boxes, labels, scores)
            frame_file = os.path.join(output_path, f"frame_{frame_count}.png")
            if not cv2.imwrite(frame_file, annotated):
                raise OSError(f"Failed to write frame to {frame_file}")
    finally:
        # Release the capture even if inference or writing fails mid-video.
        cap.release()

    print(f"Video processing completed and saved to: {output_path}")

# Main inference process for video
def main_inference_for_video(checkpoint_path, video_path, device='cuda', output_path='/teamspace/studios/this_studio/up-detr/output'):
    """Load a DETR checkpoint and run it over every frame of a video.

    Args:
        checkpoint_path: path to the checkpoint passed to ``load_model``.
        video_path: input video path.
        device: torch device string. A CUDA device is replaced by ``'cpu'``
            (with a printed warning) when CUDA is not available, instead of
            failing while building the model.
        output_path: directory that receives the processed frames.
    """
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        print("Warning: CUDA is not available; falling back to CPU (this will be slow).")
        device = 'cpu'

    model = load_model(checkpoint_path, device)
    process_video(video_path, model, device, output_path)

# Run the inference on a given video
# DETR ResNet-50 weights to run.
checkpoint_path = '/teamspace/studios/this_studio/detr-r50-e632da11.pth'
# Video whose frames will be processed.
video_path = '/teamspace/studios/this_studio/video_20241128_161442_edit.mp4'

main_inference_for_video(checkpoint_path=checkpoint_path, video_path=video_path)




Processing frame 1
