In [7]:
# !pip install "numpy<2.0" --upgrade
# !pip install torchreid
# !pip install torch==2.0.1 torchvision==0.15.2
# !pip install gdown
# !pip uninstall torchreid -y
# !pip install git+https://github.com/KaiyangZhou/deep-person-reid.git

In [None]:
import os
from getpass import getpass

from roboflow import Roboflow

# Never hardcode API keys in a notebook: read the key from the environment and
# fall back to an interactive prompt. The key previously written inline here
# was committed in plain text and should be revoked/rotated.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY") or getpass("Roboflow API key: ")

rf = Roboflow(api_key=ROBOFLOW_API_KEY)
project = rf.workspace().project("walking-staff-detection-ms3uf")
model = project.version(2).model

loading Roboflow workspace...
loading Roboflow project...


In [None]:
import torch
from torchreid.utils import FeatureExtractor

device = "cuda" if torch.cuda.is_available() else "cpu"

# OSNet person re-identification backbone used to embed person crops.
# FeatureExtractor does not accept a `pretrained` keyword (passing one raises
# TypeError); it loads the library's pretrained weights on its own whenever
# no `model_path` is supplied.
extractor = FeatureExtractor(
    model_name="osnet_x1_0",
    device=device,
)

In [None]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Embeddings of people confirmed as staff. Filled by the video-scan loop
# (staff_db.append) and searched by is_staff via cosine similarity.
staff_db = []

def get_embedding(image, bgr=False):
    """Return the re-ID feature vector for a single image crop.

    Args:
        image: numpy array of shape (H, W, C). torchreid's FeatureExtractor
            converts ndarrays with ``ToPILImage``, which reads the channels as
            RGB, so BGR data (e.g. raw OpenCV frames) must be converted first,
            either by the caller or by passing ``bgr=True``.
        bgr: if True, reverse the channel order (BGR -> RGB) before extraction.

    Returns:
        numpy array: the first (only) row of the extractor's output, moved to CPU.

    Raises:
        ValueError: if ``image`` is an empty array, e.g. a crop that lies
            outside the frame.
    """
    if isinstance(image, np.ndarray):
        if image.size == 0:
            raise ValueError("get_embedding received an empty image crop")
        if bgr:
            # Contiguous copy: a reversed view has negative strides.
            image = np.ascontiguousarray(image[..., ::-1])
    return extractor(image)[0].cpu().numpy()

def is_staff(feature, threshold=0.7):
    """Decide whether `feature` belongs to a known staff member.

    The feature is compared with every embedding stored in the module-level
    `staff_db` using cosine similarity. It counts as staff when the best
    similarity is strictly greater than `threshold`. An empty database never
    matches.
    """
    if len(staff_db) == 0:
        return False
    best_similarity = cosine_similarity([feature], staff_db)[0].max()
    return best_similarity > threshold


In [None]:
def boxes_overlap(box1, box2):
    """Return True when two axis-aligned boxes share a region of positive area.

    Args:
        box1, box2: sequences ``[x1, y1, x2, y2]`` (top-left and bottom-right
            corners, with ``x1 <= x2`` and ``y1 <= y2``).

    Returns:
        True if the intersection has positive width and height, i.e. IoU > 0.
        Boxes that only touch along an edge or at a corner do not overlap.
    """
    inter_x1 = max(box1[0], box2[0])
    inter_y1 = max(box1[1], box2[1])
    inter_x2 = min(box1[2], box2[2])
    inter_y2 = min(box1[3], box2[3])
    # Strict comparison: a zero-width or zero-height intersection has IoU == 0.
    return inter_x2 > inter_x1 and inter_y2 > inter_y1


In [None]:
import cv2

video_path = "test.mp4"  # video uploaded to the Colab runtime

# Roboflow class names. The model's own output (see the batch-video results
# printed further down) labels people as 'people', not 'person', so accept both;
# otherwise the person filter is always empty and staff_db never grows.
PERSON_CLASSES = {"person", "people"}
TAG_CLASSES = {"tag"}


def _to_corners(pred):
    """Convert a Roboflow centre/size prediction dict to an [x1, y1, x2, y2] box."""
    x, y = int(pred["x"]), int(pred["y"])
    w, h = int(pred["width"]), int(pred["height"])
    return [x - w // 2, y - h // 2, x + w // 2, y + h // 2]


cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise FileNotFoundError(f"Could not open video: {video_path}")

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Roboflow prediction on the current frame.
        results = model.predict(frame, confidence=40, overlap=30).json()
        predictions = results["predictions"]

        people = [obj for obj in predictions if obj["class"] in PERSON_CLASSES]
        tag_boxes = [_to_corners(obj) for obj in predictions if obj["class"] in TAG_CLASSES]

        frame_h, frame_w = frame.shape[:2]
        for p in people:
            person_box = _to_corners(p)
            # Only people overlapping a staff tag can be enrolled, so skip the
            # (expensive) embedding for everyone else.
            if not any(boxes_overlap(person_box, tag_box) for tag_box in tag_boxes):
                continue

            x1, y1, x2, y2 = person_box
            # Clip to the frame: a negative slice start would wrap around to the
            # opposite edge and produce a wrong or empty crop.
            person_crop = frame[max(y1, 0):min(y2, frame_h), max(x1, 0):min(x2, frame_w)]
            if person_crop.size == 0:
                continue

            # OpenCV frames are BGR; the re-ID extractor reads channels as RGB.
            person_rgb = cv2.cvtColor(person_crop, cv2.COLOR_BGR2RGB)
            feature = get_embedding(person_rgb)

            if not is_staff(feature):
                staff_db.append(feature)
finally:
    cap.release()

print(f"Staff DB size: {len(staff_db)}")


In [7]:
# Submit the whole video for asynchronous batch inference on Roboflow, then
# block until the job has finished (the SDK keeps polling; see the log below).
job_id, signed_url, expire_time = model.predict_video(
    "test.mp4", fps=25, prediction_type="batch-video"
)
results = model.poll_until_video_results(job_id)

Checking for video inference results for job 9af1d193-c809-4f43-9eab-a3242eaf1bad every 60s
(0s): Checking for inference results


In [15]:
import json

# Persist the raw batch-video results so they can be inspected later
# without re-running inference.
with open("data.json", "w") as out_file:
    out_file.write(json.dumps(results, indent=4))

In [24]:
# NOTE(review): this rebinds `project` from the Roboflow Project object (created
# in the Roboflow setup cell) to a plain string key. Re-run that cell before
# calling project.version(...) again.
project = "walking-staff-detection-ms3uf"

# Print every field of the first 11 frames (indices 0..10) of the batch results.
for frame in range(min(len(results["frame_offset"]), 11)):
    frame_result = results[project][frame]
    for key in frame_result:
        print(key, " : ", frame_result[key])

inference_id  :  7baac446-8217-41e6-8857-f72c6339be12
time  :  0.2827616059948923
image  :  {'width': 960, 'height': 720}
predictions  :  [{'x': 679.0, 'y': 94.0, 'width': 104.0, 'height': 136.0, 'confidence': 0.8160362839698792, 'class': 'people', 'class_id': 0, 'detection_id': 'aad9c306-8677-4f80-a349-2a489202e0fc'}]
inference_id  :  7baac446-8217-41e6-8857-f72c6339be12
time  :  0.28276918600022327
image  :  {'width': 960, 'height': 720}
predictions  :  [{'x': 678.0, 'y': 96.0, 'width': 100.0, 'height': 134.0, 'confidence': 0.8084633946418762, 'class': 'people', 'class_id': 0, 'detection_id': 'c7122bf5-e451-404e-9ef1-e24759de9ac3'}]
inference_id  :  7baac446-8217-41e6-8857-f72c6339be12
time  :  0.28277114700176753
image  :  {'width': 960, 'height': 720}
predictions  :  [{'x': 677.0, 'y': 97.0, 'width': 102.0, 'height': 132.0, 'confidence': 0.812053918838501, 'class': 'people', 'class_id': 0, 'detection_id': 'e2e5909f-db56-4d84-b297-e00c9133bff6'}]
inference_id  :  7baac446-8217-41e6-