<a href="https://colab.research.google.com/github/shredder0812/endocv/blob/main/done_strongsort.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Copy the test video and the trained detector weights (best.pt) from Google Drive into /content,
# where the tracking cell below expects them.
# NOTE(review): assumes Google Drive is already mounted at /content/drive — no mount cell is visible here.
!cp /content/drive/MyDrive/ENDOCV/Daday_3.mp4 /content/drive/MyDrive/ENDOCV/2602/detect/train2/weights/best.pt /content

In [None]:
# Install a CUDA 11.8 build of PyTorch, then the detector (ultralytics) and tracker (boxmot) packages.
# NOTE(review): nothing is version-pinned, so re-running later may resolve different releases;
# consider pinning (e.g. `%pip install -q ultralytics==X.Y.Z`) once a working set is known.
# NOTE(review): the next cell installs boxmot again from a git fork (`pip install .`), which presumably
# replaces this PyPI install — confirm which boxmot is intended and drop the other.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install ultralytics
!pip install boxmot


In [None]:
# Install boxmot from a source checkout of the KeeganFernandesWork/yolo_tracking fork.
# NOTE(review): not re-runnable as-is — a second `git clone` fails because the directory already exists,
# and a second `%cd yolo_tracking` would look for a nested directory.
!git clone https://github.com/KeeganFernandesWork/yolo_tracking
# `%cd` changes the kernel's working directory for the rest of the session, so later relative paths
# (e.g. the ReID weights file) resolve inside this checkout.
%cd yolo_tracking
!pip install -r requirements.txt
!pip install .

In [None]:
from boxmot import (OCSORT, BoTSORT, BYTETracker, DeepOCSORT, StrongSORT,
                    create_tracker, get_tracker_config)
from pathlib import Path
import cv2
import sys
import numpy as np
import datetime
from ultralytics import YOLO

import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

In [None]:
from ultralytics import YOLO
from boxmot import StrongSORT
from pathlib import Path
from time import perf_counter
import cv2
import numpy as np
import torch

class Colors:
    """Palette of visually distinct BGR colours, indexed by class id.

    Hues are spread evenly over OpenCV's 8-bit hue range (0-180) at full
    saturation and value, then converted to BGR so the resulting tuples can be
    passed straight to ``cv2`` drawing functions.

    Args:
        num_colors: number of distinct colours in the palette.
    """

    def __init__(self, num_colors=80):
        self.num_colors = num_colors
        self.color_palette = self.generate_color_palette()

    def generate_color_palette(self):
        """Return a ``(num_colors, 3)`` uint8 array of BGR colours."""
        hsv_palette = np.zeros((self.num_colors, 1, 3), dtype=np.uint8)
        hsv_palette[:, 0, 0] = np.linspace(0, 180, self.num_colors, endpoint=False)
        hsv_palette[:, :, 1:] = 255  # full saturation and value
        bgr_palette = cv2.cvtColor(hsv_palette, cv2.COLOR_HSV2BGR)
        return bgr_palette.reshape(-1, 3)

    def __call__(self, class_id):
        """Return the colour for ``class_id`` as a tuple of Python ints.

        Ids outside ``[0, len(palette))`` wrap around the palette instead of
        raising ``IndexError``. As a result, a model with more classes than
        colours still gets a (repeating) colour for every class. Negative ids
        map to the same entries numpy indexing would give.
        """
        index = int(class_id) % len(self.color_palette)
        return tuple(int(channel) for channel in self.color_palette[index])

class ObjectDetection:
    """Detect objects with YOLO, track them with StrongSORT and write an annotated video.

    For every frame read from ``capture_index``, the class:
      1. runs the detector;
      2. passes the detections to StrongSORT;
      3. draws the returned tracks;
      4. overlays an FPS counter;
      5. appends the frame to ``output_path``.

    Args:
        model_weights: weights file passed to ``ultralytics.YOLO``.
        capture_index: camera index or video path passed to ``cv2.VideoCapture``.
        output_path: file the annotated mp4 is written to.
    """

    # Fallback used when the capture reports 0 FPS (possible for live cameras),
    # which would otherwise hand cv2.VideoWriter an invalid frame rate.
    DEFAULT_FPS = 30.0

    def __init__(self, model_weights="yolov8s.pt", capture_index=0,
                 output_path="strongsort_daday3.mp4"):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Using Device: ", self.device)
        self.model = self.load_model(model_weights)
        # Rename class 0 for display. `self.model.names` is assigned without a copy,
        # so this may also rename the class on the model itself.
        # NOTE(review): confirm that is intended.
        self.classes = self.model.names
        self.classes[0] = 'polyp'
        self.colors = Colors(len(self.classes))
        self.font = cv2.FONT_HERSHEY_SIMPLEX
        self.capture_index = capture_index
        self.output_path = output_path  # must be set before load_capture(), which opens the writer
        self.cap = self.load_capture()  # also creates self.writer
        # OSNet ReID weights, resolved relative to the current working directory.
        reid_weights = Path("osnet_x0_25_msmt17.pt")
        self.tracker = StrongSORT(reid_weights, torch.device(self.device), fp16=False)

    def load_model(self, weights):
        """Load the YOLO model from ``weights`` and fuse it for inference."""
        model = YOLO(weights)
        model.fuse()
        return model

    def predict(self, frame):
        """Run the detector on one frame with a 0.45 confidence threshold.

        Because ``stream=True`` is passed, the model's results are returned as
        a lazy iterable.
        """
        return self.model(frame, stream=True, verbose=False, conf=0.45, line_width=1)

    def draw_tracks(self, frame, tracks):
        """Draw each track's box and a ``"<class> <track id>"`` label onto ``frame``.

        Rows of ``tracks`` are read as ``x1, y1, x2, y2, track_id, conf, class_id, ...``
        (``__call__`` only passes arrays with 8 columns). The frame is drawn on in
        place by OpenCV and also returned.
        """
        for track in tracks:
            x1, y1, x2, y2 = (int(v) for v in track[:4])
            track_id = int(track[4])
            class_id = int(track[6])
            color = self.colors(class_id)
            label = f'{self.classes[class_id]} {track_id}'
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            (w, h), _ = cv2.getTextSize(label, self.font, 0.5, 2)
            # Filled strip above the box so the white label text stays readable.
            cv2.rectangle(frame, (x1, y1 - h - 15), (x1 + w, y1), color, -1)
            cv2.putText(frame, label, (x1, y1 - 10), self.font, 0.5, (255, 255, 255), 2)
        return frame

    def load_capture(self):
        """Open ``self.capture_index`` and create ``self.writer`` matching its frames.

        Returns:
            The opened ``cv2.VideoCapture``.

        Raises:
            AssertionError: if the capture cannot be opened.
        """
        cap = cv2.VideoCapture(self.capture_index)
        assert cap.isOpened(), f"Could not open video source: {self.capture_index!r}"
        # Request 720p first and read the size back afterwards. For a live camera
        # the request can change the frame size, and cv2.VideoWriter typically
        # writes nothing (without raising) for frames whose size differs from the
        # one it was opened with. For video files the request usually has no effect.
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS) or self.DEFAULT_FPS
        self.writer = cv2.VideoWriter(str(self.output_path), cv2.VideoWriter_fourcc(*'mp4v'),
                                      fps, (width, height))
        return cap

    @staticmethod
    def _fps(elapsed):
        """Frames per second for a frame that took ``elapsed`` seconds; 0.0 if not measurable."""
        return 1.0 / elapsed if elapsed > 0 else 0.0

    def __call__(self):
        """Process the capture until it is exhausted, writing every annotated frame.

        The loop also stops if ``cv2.waitKey`` reports ESC. That needs an
        OpenCV window to receive key presses, and none is opened here. The
        capture and writer are released even if processing raises, so frames
        already written are finalised.
        """
        tracker = self.tracker
        try:
            while True:
                start_time = perf_counter()
                ret, frame = self.cap.read()
                if not ret:
                    break
                # Detect and track on the untouched frame. Overlays are drawn
                # afterwards so they never reach the detector or the ReID crops.
                for dets in self.predict(frame):
                    tracks = tracker.update(dets.boxes.data.to("cpu").numpy(), frame)
                    if len(tracks.shape) == 2 and tracks.shape[1] == 8:
                        frame = self.draw_tracks(frame, tracks)
                fps = self._fps(perf_counter() - start_time)
                cv2.rectangle(frame, (0, 30), (220, 80), (255, 255, 255), -1)
                cv2.putText(frame, f'FPS: {int(fps)}', (20, 70), self.font, 1.5, (0, 255, 0), 2)
                self.writer.write(frame)
                if cv2.waitKey(5) & 0xFF == 27:
                    break
        finally:
            self.cap.release()
            self.writer.release()
            cv2.destroyAllWindows()

# Run detection + StrongSORT tracking on the clip copied from Drive, using the fine-tuned weights.
model_weights = "/content/best.pt"
test_vid = "/content/Daday_3.mp4"
detector = ObjectDetection(model_weights=model_weights, capture_index=test_vid)
detector()

Using Device:  cuda
Model summary (fused): 268 layers, 68124531 parameters, 0 gradients, 257.4 GFLOPs


Downloading...
From: https://drive.google.com/uc?id=1sSwXSUlj4_tHZequ_iZ8w_Jh0VaRQMqF
To: /content/yolo_tracking/osnet_x0_25_msmt17.pt
100%|██████████| 3.06M/3.06M [00:00<00:00, 101MB/s]
2024-02-26 03:53:19.075 | SUCCESS | boxmot.appearance.reid_model_factory:load_pretrained_weights:207 - Successfully loaded pretrained weights from "osnet_x0_25_msmt17.pt"
