In [1]:
import cv2
import os
from ultralytics import YOLO

# Traffic monitoring script: counts vehicles crossing a virtual line and flags
# red-light and wrong-way violations in "traffic.mp4". An annotated copy of the
# video is written to "output.mp4"; a crop of every offending vehicle is saved
# under "violations/".

# Load two models
model_detect = YOLO("best.pt")  # For vehicle detection & counting
model_violation = YOLO("violation.pt")  # For signal detection (red, green)

# Create directory to store violation images
os.makedirs("violations", exist_ok=True)

# Open video and fail loudly if it cannot be read; otherwise frame.shape below
# would raise an opaque AttributeError on None.
cap = cv2.VideoCapture("traffic.mp4")
if not cap.isOpened():
    raise OSError("Could not open input video 'traffic.mp4'")
ret, frame = cap.read()
if not ret:
    cap.release()
    raise OSError("Could not read the first frame of 'traffic.mp4'")
height, width = frame.shape[:2]
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind so the first frame is processed too

# Save output video at the source frame rate so playback speed matches the
# input; fall back to 15 fps when the container does not report one.
fps = cap.get(cv2.CAP_PROP_FPS)
if not fps > 0:
    fps = 15.0
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter("output.mp4", fourcc, fps, (width, height))
if not out.isOpened():
    cap.release()
    raise OSError("Could not open 'output.mp4' for writing")

# Setup
vehicle_types = ['car', 'bus', 'truck', 'motorcycle', 'vehicle']
# The detector's per-frame log reports a "motorbike" class, which would never
# match 'motorcycle' above; map it so those detections are not dropped.
label_aliases = {'motorbike': 'motorcycle'}
type_counts = {vt: 0 for vt in vehicle_types}
vehicle_count = 0
red_light_violations = 0
wrong_way_violations = 0

line_position = height // 2          # y of the counting line
stop_line_y = line_position + 30     # y of the stop line

counted_centers = []         # (cx, cy, label) recorded when a vehicle is counted
counted_ids = set()          # tracker ids already counted at the line
track_history = {}           # tracker id -> list of (cx, cy) centres
counted_violations = set()   # tracker ids already flagged for a red light
wrong_way_ids = set()        # tracker ids already flagged as wrong-way

frame_number = 0


def _save_crop(image, x1, y1, x2, y2, path):
    """Write the (x1, y1)-(x2, y2) region of *image* to *path*.

    Coordinates are clamped to the image bounds; nothing is written when the
    clamped region is empty, because cv2.imwrite rejects empty images.
    """
    img_h, img_w = image.shape[:2]
    x1, y1 = max(x1, 0), max(y1, 0)
    x2, y2 = min(x2, img_w), min(y2, img_h)
    if x2 > x1 and y2 > y1:
        cv2.imwrite(path, image[y1:y2, x1:x2])


try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_number += 1

        # Signal detection (from second model)
        signal_results = model_violation(frame)
        red_light_active = any(
            model_violation.names[int(sig.cls[0])] == "red_light"
            for sig in signal_results[0].boxes
        )

        # Vehicle detection (from first model). The tracker assigns each
        # vehicle an id that persists across frames; a position-derived key
        # changes as soon as the vehicle moves, which breaks both the
        # de-duplication and the direction check.
        results = model_detect.track(frame, persist=True)
        boxes = results[0].boxes

        # Unannotated copy used for the violation crops.
        clean = frame.copy()

        if red_light_active:
            # Hershey fonts only cover ASCII, so the label must not use emoji.
            cv2.putText(frame, "RED SIGNAL", (width - 250, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3)

        for box in boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            cls = int(box.cls[0])
            conf = float(box.conf[0])
            raw_label = model_detect.names[cls]
            label = label_aliases.get(raw_label, raw_label)

            if conf < 0.5 or label not in vehicle_types:
                continue

            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
            cv2.circle(frame, (cx, cy), 4, (0, 0, 255), -1)

            # Boxes have no id until the tracker has confirmed a track; such
            # detections are drawn but not counted.
            if box.id is None:
                continue
            track_id = int(box.id[0])
            history = track_history.setdefault(track_id, [])
            history.append((cx, cy))

            # Count vehicles: once per track, while near the counting line.
            if abs(cy - line_position) < 30 and track_id not in counted_ids:
                vehicle_count += 1
                type_counts[label] += 1
                counted_ids.add(track_id)
                counted_centers.append((cx, cy, label))

            # Red light violation: vehicle centre above the stop line while
            # the signal is red, reported once per track.
            if (red_light_active and cy < stop_line_y
                    and track_id not in counted_violations):
                red_light_violations += 1
                counted_violations.add(track_id)
                cv2.putText(frame, "RED LIGHT VIOLATION!", (x1, y1 - 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
                _save_crop(clean, x1, y1, x2, y2,
                           f"violations/redlight_{red_light_violations}.jpg")

            # Wrong way: centre moved up by more than 10 px since the previous
            # sighting of this track, reported once per track.
            # NOTE(review): assumes legal traffic moves down the image - confirm
            # against the camera orientation.
            if len(history) >= 2 and track_id not in wrong_way_ids:
                y_prev = history[-2][1]
                y_curr = history[-1][1]
                if y_curr < y_prev - 10:
                    wrong_way_violations += 1
                    wrong_way_ids.add(track_id)
                    cv2.putText(frame, "WRONG WAY!", (x1, y1 - 50),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
                    _save_crop(clean, x1, y1, x2, y2,
                               f"violations/wrongway_{wrong_way_violations}.jpg")

        # Draw lines
        cv2.line(frame, (0, line_position), (width, line_position), (255, 0, 255), 2)
        cv2.line(frame, (0, stop_line_y), (width, stop_line_y), (0, 0, 255), 2)

        # Stats
        cv2.putText(frame, f"Total Vehicles: {vehicle_count}", (20, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 200), 2)
        y_offset = 80
        for vt in vehicle_types:
            cv2.putText(frame, f"{vt.capitalize()}s: {type_counts[vt]}", (20, y_offset),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 100, 255), 2)
            y_offset += 30

        cv2.putText(frame, f"Red Light Violations: {red_light_violations}", (20, y_offset + 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2)
        cv2.putText(frame, f"Wrong-Way Violations: {wrong_way_violations}", (20, y_offset + 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 255), 2)

        out.write(frame)
finally:
    # Release the capture and finalise the output file even if a frame fails.
    cap.release()
    out.release()

print("✅ Done: output.mp4 and violation images saved")



0: 384x640 2 red_lights, 2 vehicles, 92.7ms
Speed: 6.6ms preprocess, 92.7ms inference, 11.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 8 cars, 5 motorbikes, 3 persons, 33.4ms
Speed: 2.3ms preprocess, 33.4ms inference, 3.0ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 2 vehicles, 35.3ms
Speed: 1.8ms preprocess, 35.3ms inference, 0.9ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 bicycle, 7 cars, 4 motorbikes, 6 persons, 34.0ms
Speed: 1.3ms preprocess, 34.0ms inference, 1.1ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 2 vehicles, 37.3ms
Speed: 1.2ms preprocess, 37.3ms inference, 1.1ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 bicycle, 7 cars, 5 motorbikes, 6 persons, 33.2ms
Speed: 1.5ms preprocess, 33.2ms inference, 0.6ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 red_light, 2 vehicles, 1 yellow_light, 34.0ms
Speed: 1.2ms preprocess, 34.0ms inference, 0.7ms postprocess per image a