# In [None]:
# Advanced vehicle detection, tracking, counting and speed estimation (YOLOv5)
from google.colab.patches import cv2_imshow
import cv2
import numpy as np
import torch
from collections import OrderedDict
import math
import csv
import matplotlib.pyplot as plt
import pandas as pd
import time

# Load the pretrained YOLOv5-small model from the Ultralytics hub repository.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

VEHICLE_CLASSES = ["car", "truck", "bus", "motorcycle"]  # model class names treated as vehicles
ALERT_CLASSES = ["bus"]  # classes that trigger the on-frame alert banner
CONFIDENCE_THRESHOLD = 0.5  # detections with confidence <= this are discarded
IOU_THRESHOLD = 0.4  # IoU threshold for the model's non-max suppression (applied below)
# Pixel rows below are in the coordinates of the 640x360 resized frame.
LINE_POSITION = 250  # counting line (y, px)
SPEED_LINE_POSITION = 150  # Rough line before the main line
REAL_WORLD_DISTANCE = 10  # in meters (rough estimate between speed lines)
TRACKING_ZONE_HEIGHT = 50  # tracking starts this many px above LINE_POSITION

# IOU_THRESHOLD used to be declared but never passed to the model, so NMS ran
# with the hub wrapper's default IoU. The AutoShape wrapper returned by
# torch.hub exposes it as the `iou` attribute.
model.iou = IOU_THRESHOLD

class Tracker:
    """Centroid tracker that keeps vehicle IDs stable between processed frames.

    Each tracked object lives in ``center_points`` as
    ``id -> (cx, cy, class_name, counted, t_entry)``. The caller sets
    ``counted`` once the object has been counted. ``t_entry`` is the
    ``current_time`` passed to ``update`` when the ID was created.

    Args:
        max_distance: Maximum centroid displacement (pixels) between two
            updates for a detection to keep an existing ID.
        zone_top: Detections whose ``cy`` is above this row (smaller y) are
            ignored. ``None`` (default) means
            ``LINE_POSITION - TRACKING_ZONE_HEIGHT``, evaluated at update time.
    """

    def __init__(self, max_distance=50, zone_top=None):
        self.center_points = OrderedDict()  # id: (cx, cy, class, counted, t_entry)
        self.id_count = 0
        self.max_distance = max_distance
        self.zone_top = zone_top

    def update(self, detections, current_time):
        """Match detections to existing IDs and return the new track table.

        Same-class (detection, track) pairs closer than ``max_distance`` are
        assigned greedily, nearest first. Each existing ID is claimed by at
        most one detection, so two nearby vehicles no longer collapse into a
        single ID. Unmatched detections get fresh IDs. Tracks with no
        matching detection in this call are dropped.

        Args:
            detections: Iterable of ``(cx, cy, class_name)`` centroids.
            current_time: Timestamp stored as ``t_entry`` for new IDs.

        Returns:
            Dict ``id -> (cx, cy, class_name, counted, t_entry)`` in
            detection order. ``self.center_points`` is replaced by a copy.
        """
        zone_top = (self.zone_top if self.zone_top is not None
                    else LINE_POSITION - TRACKING_ZONE_HEIGHT)
        candidates = [det for det in detections if det[1] >= zone_top]

        # Every admissible (distance, detection, track) pair, nearest first.
        pairs = []
        for det_idx, (cx, cy, cls) in enumerate(candidates):
            for obj_id, (px, py, pcls, _, _) in self.center_points.items():
                dist = math.hypot(cx - px, cy - py)
                if cls == pcls and dist < self.max_distance:
                    pairs.append((dist, det_idx, obj_id))
        pairs.sort()

        # Greedy one-to-one assignment: each detection and each ID used once.
        assignment = {}
        claimed = set()
        for _, det_idx, obj_id in pairs:
            if det_idx in assignment or obj_id in claimed:
                continue
            assignment[det_idx] = obj_id
            claimed.add(obj_id)

        updated_ids = {}
        for det_idx, (cx, cy, cls) in enumerate(candidates):
            obj_id = assignment.get(det_idx)
            if obj_id is None:
                updated_ids[self.id_count] = (cx, cy, cls, False, current_time)
                self.id_count += 1
            else:
                # Keep the counted flag and original entry time of the track.
                _, _, _, counted, t_entry = self.center_points[obj_id]
                updated_ids[obj_id] = (cx, cy, cls, counted, t_entry)

        self.center_points = updated_ids.copy()
        return updated_ids

# Pipeline state shared with process_video() at module scope.
tracker = Tracker()  # centroid tracker reused across processed frames
speed_dict = {}  # tracker ID -> speed (km/h) recorded when the ID crosses the counting line
log_data = []  # one CSV row per logged detection


def calculate_speed(entry_time, exit_time, distance_m=None):
    """Estimate speed in km/h from the time taken to cover a known distance.

    Args:
        entry_time: Time (s) at which the vehicle started the measured stretch.
        exit_time: Time (s) at which it finished the measured stretch.
        distance_m: Length of the stretch in meters. ``None`` (default) uses
            ``REAL_WORLD_DISTANCE``.

    Returns:
        Speed in km/h rounded to 2 decimals. Returns ``0`` when the elapsed
        time is not positive (same timestamp, or timestamps out of order),
        because no meaningful speed can be derived; previously a negative
        interval produced a negative speed.
    """
    time_diff = exit_time - entry_time
    if time_diff <= 0:
        return 0
    distance = REAL_WORLD_DISTANCE if distance_m is None else distance_m
    speed_mps = distance / time_diff
    return round(speed_mps * 3.6, 2)


_LOG_CSV_PATH = "vehicle_log_revised_advanced.csv"


def _extract_vehicle_detections(results):
    """Return (boxes, confidences, class_names, centers) for confident vehicle detections."""
    boxes, confidences, class_names, centers = [], [], [], []
    for det in results.pred[0]:
        x1, y1, x2, y2, conf, cls = det.tolist()
        class_name = results.names[int(cls)]
        if class_name in VEHICLE_CLASSES and conf > CONFIDENCE_THRESHOLD:
            boxes.append([int(x1), int(y1), int(x2), int(y2)])
            confidences.append(conf)
            class_names.append(class_name)
            centers.append((int((x1 + x2) / 2), int((y1 + y2) / 2), class_name))
    return boxes, confidences, class_names, centers


def _find_track_id(cx, cy, cls, tolerance=5):
    """Return the tracker ID of a same-class track within `tolerance` px of (cx, cy), else None."""
    for tid, (tcx, tcy, tcls, _, _) in tracker.center_points.items():
        if abs(cx - tcx) < tolerance and abs(cy - tcy) < tolerance and tcls == cls:
            return tid
    return None


def _write_log_csv(path):
    """Write the accumulated `log_data` rows to `path` under a header row."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Frame", "ObjectID", "Class", "Confidence", "CentroidX", "CentroidY", "TotalCount", "Speed_kmph"])
        writer.writerows(log_data)


def _report_accuracy(total_vehicles):
    """Prompt for the true vehicle count and print the count-based accuracy estimate."""
    try:
        true_count = int(input("\n🔢 Enter actual vehicle count (if known): "))
    except (ValueError, EOFError, OSError, NotImplementedError):
        # Non-numeric input or no interactive stdin: accuracy is optional.
        print("ℹ️ Skipped accuracy calculation (no input).")
        return
    if true_count <= 0:
        print("ℹ️ Skipped accuracy calculation (actual count must be positive).")
        return
    accuracy = (1 - abs(true_count - total_vehicles) / true_count) * 100
    print(f"Estimated detection accuracy: {accuracy:.2f}%")


def _plot_log_summary(csv_path):
    """Plot per-class detection rows and unique tracked IDs per logged frame from the CSV log."""
    df = pd.read_csv(csv_path)
    if df.empty:
        print("ℹ️ No detections logged; skipping plots.")
        return

    # Counts logged detection rows per class (a vehicle seen in several frames
    # contributes several rows).
    class_counts = df["Class"].value_counts()
    # Untracked detections are logged with ObjectID "-"; exclude them so the
    # placeholder is not counted as a vehicle.
    tracked = df[df["ObjectID"].astype(str) != "-"]
    frame_trend = tracked.groupby("Frame")["ObjectID"].nunique()

    plt.figure(figsize=(8, 5))
    class_counts.plot(kind='bar', color='skyblue', edgecolor='black')
    plt.title("Per-Class Vehicle Count")
    plt.xlabel("Vehicle Class")
    plt.ylabel("Count")
    plt.grid(axis='y')
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(10, 5))
    plt.plot(frame_trend.index, frame_trend.values, marker='o', linestyle='-', color='red')
    plt.title("Vehicle Count Trend per Frame")
    plt.xlabel("Frame")
    plt.ylabel("Unique Vehicle Count")
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def process_video(video_path, skip_frames=5):
    """Detect, track, count and time vehicles in a video, then log and plot results.

    Every `skip_frames`-th frame is resized to 640x360 and run through `model`.
    Confident vehicle detections are tracked with the module-level `tracker`.
    A track is counted (and its speed estimated) the first time its centroid
    goes below `LINE_POSITION`. Per-detection rows are written to
    `vehicle_log_revised_advanced.csv`. The user is optionally asked for the
    true count, and summary plots are shown.

    Module-level `log_data`, `speed_dict` and `tracker` are reset at the start
    so repeated calls in one session do not mix results.

    Args:
        video_path: Path to the input video.
        skip_frames: Process one frame out of every `skip_frames` (values
            below 1 are treated as 1).
    """
    skip_frames = max(1, skip_frames)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file at {video_path}")
        return

    # Start from a clean state; otherwise a second run would re-write the
    # previous run's rows to the CSV and reuse stale IDs/speeds.
    log_data.clear()
    speed_dict.clear()
    tracker.center_points.clear()
    tracker.id_count = 0

    total_vehicles = 0
    frame_count = 0
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # also catches NaN
        print("Warning: video reports no valid FPS; assuming 30 FPS for timing.")
        fps = 30.0

    try:
        ret, frame = cap.read()
        if not ret:
            print("Error: Could not read video")
            return

        while ret:
            frame_count += 1
            current_time = frame_count / fps

            if frame_count % skip_frames != 0:
                ret, frame = cap.read()
                continue

            frame_resized = cv2.resize(frame, (640, 360))
            results = model(frame_resized)
            boxes, confidences, class_names, centers = _extract_vehicle_detections(results)

            tracked_objects = tracker.update(centers, current_time)

            for obj_id, (cx, cy, cls, counted, t_entry) in tracked_objects.items():
                if not counted and cy > LINE_POSITION:
                    # NOTE(review): t_entry is when the track was first seen
                    # (its cy is at least LINE_POSITION - TRACKING_ZONE_HEIGHT),
                    # not when it crossed SPEED_LINE_POSITION. REAL_WORLD_DISTANCE
                    # is described as the distance between the speed lines, so
                    # this estimate times a different stretch. Confirm the geometry.
                    speed_dict[obj_id] = calculate_speed(t_entry, current_time)
                    total_vehicles += 1
                    tracker.center_points[obj_id] = (cx, cy, cls, True, t_entry)

                cv2.circle(frame_resized, (cx, cy), 5, (0, 255, 0), -1)
                cv2.putText(frame_resized, f'ID {obj_id}', (cx, cy - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 2)

            # Alert for specific vehicle type (drawn once per frame).
            if any(obj[2] in ALERT_CLASSES for obj in tracked_objects.values()):
                cv2.putText(frame_resized, "ALERT: BUS DETECTED", (10, 320),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

            for (x1, y1, x2, y2), cls, conf, (cx, cy, _) in zip(boxes, class_names, confidences, centers):
                obj_id = _find_track_id(cx, cy, cls)
                counted = tracker.center_points.get(obj_id, (None, None, None, False))[3]
                total_count = total_vehicles if counted else 0
                speed_display = speed_dict.get(obj_id, '-')
                log_data.append([frame_count, obj_id if obj_id is not None else "-", cls,
                                 f"{conf:.2f}", cx, cy, total_count, speed_display])

                cv2.rectangle(frame_resized, (x1, y1), (x2, y2), (0, 255, 255), 2)
                cv2.putText(frame_resized, f'{cls} {conf:.2f}', (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
                if speed_display != "-":
                    cv2.putText(frame_resized, f'Speed: {speed_display} km/h', (x1, y2 + 20),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 165, 255), 2)

            cv2.line(frame_resized, (0, LINE_POSITION), (640, LINE_POSITION), (0, 0, 255), 2)
            cv2.line(frame_resized, (0, SPEED_LINE_POSITION), (640, SPEED_LINE_POSITION), (255, 0, 0), 1)
            text = f"Total Vehicles: {total_vehicles}"
            text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)[0]
            text_x = frame_resized.shape[1] - text_size[0] - 20
            cv2.putText(frame_resized, text, (text_x, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

            if frame_count % 10 == 0:
                cv2_imshow(frame_resized)

            ret, frame = cap.read()
    finally:
        # Release the capture on every exit path, including errors.
        cap.release()

    _write_log_csv(_LOG_CSV_PATH)
    print(f"\nLog saved as '{_LOG_CSV_PATH}' with {len(log_data)} entries.")
    print(f" Total vehicles detected: {total_vehicles}")

    _report_accuracy(total_vehicles)
    _plot_log_summary(_LOG_CSV_PATH)

# Path of the sample video processed when this file is run directly.
VIDEO_PATH = '/content/drive/MyDrive/car_detection_dataset/SRMvideo1.mp4'

# Run the pipeline only when executed as a script or notebook cell, so that
# importing this file does not start processing a video.
if __name__ == "__main__":
    process_video(VIDEO_PATH, skip_frames=3)
