<a href="https://colab.research.google.com/github/shredder0812/endocv/blob/main/track0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Load video and model
## Video
- UTDD (ung thư dạ dày — gastric cancer): BVK042.mp4

- UTTQ (ung thư thực quản — esophageal cancer): CS201.mp4, BVK037.mp4

- test_6s.mp4

- test_3s.mp4

## Model
- thucquan.pt: ['2_Viem_thuc_quan', '5_Ung_thu_thuc_quan']

- daday.pt: ['3_Viem_da_day_HP_am', '4_Viem_da_day_HP_duong', '6_Ung_thu_da_day']

- htt.pt: ['7_Loet_HTT']

- best0903.pt, best2602.pt (and any weights not listed above — default labels): ['polyp', 'esophagael cancer']


In [None]:
from google.colab import drive

# Make the ENDOCV videos and weights in MyDrive reachable under /content/drive.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [17]:
!cp /content/drive/MyDrive/ENDOCV/video_test/UTDD/BVK042.mp4 /content/drive/MyDrive/ENDOCV/video_test/UTTQ/CS201.mp4 /content/drive/MyDrive/ENDOCV/video_test/UTTQ/BVK037.mp4 /content

In [18]:
!cp /content/drive/MyDrive/ENDOCV/model_pt/model_yolo/thucquan.pt /content/drive/MyDrive/ENDOCV/model_pt/model_yolo/daday.pt /content

In [None]:
!cp /content/drive/MyDrive/ENDOCV/model_pt/best0903.pt /content/drive/MyDrive/ENDOCV/model_pt/best2602.pt /content/drive/MyDrive/ENDOCV/video_test/test_3s.mp4 /content/drive/MyDrive/ENDOCV/video_test/test_6s.mp4 /content

# Install requirements

In [None]:
!pip install torchvision==0.17.1+cu121 -f https://download.pytorch.org/whl/torch_stable.html
!pip install ultralytics
!pip install boxmot

In [None]:
!git clone https://github.com/mikel-brostrom/yolo_tracking.git
!pip install -v -e .

In [None]:
!git clone https://github.com/KeeganFernandesWork/yolo_tracking
%cd yolo_tracking
!pip install -r requirements.txt
!pip install .

# Track

In [None]:
from ultralytics import YOLO
from pathlib import Path
from time import perf_counter
import cv2
import numpy as np
import torch
from boxmot import (OCSORT, BoTSORT, BYTETracker, DeepOCSORT, StrongSORT, create_tracker, get_tracker_config)
from pathlib import Path
import sys
import datetime
import pandas as pd
from google.colab.patches import cv2_imshow

In [39]:
test_vid = "/content/CS201.mp4"
model_weights = "/content/best2602.pt"

# Output files are named after the input video's file stem (e.g. "CS201").
# Path.stem keeps inner dots ("a.b.mp4" -> "a.b"), unlike split(".")[0].
input_video_name = Path(test_vid).stem

# Class labels shown for each known weight file. Keyed by file name so the
# lookup works regardless of which directory the weights were copied to.
model_classes_dict = {
    "thucquan.pt": ['2_Viem_thuc_quan', '5_Ung_thu_thuc_quan'],
    "daday.pt": ['3_Viem_da_day_HP_am', '4_Viem_da_day_HP_duong', '6_Ung_thu_da_day'],
    "htt.pt": ['7_Loet_HTT'],
}

# Any other weights (e.g. best0903.pt / best2602.pt) fall back to these labels.
model_classes = model_classes_dict.get(Path(model_weights).name, ['polyp', 'esophagael cancer'])

print("Input Video Name:", input_video_name)
print("Model Classes:", model_classes)

Input Video Name: CS201
Model Classes: ['polyp', 'esophagael cancer']


In [33]:
class Colors:
    """Palette of distinct BGR colors, looked up by class id.

    Hues are spaced evenly over OpenCV's 8-bit hue range [0, 180) with
    saturation and value set to 255, then converted to BGR so the colors can be
    passed straight to cv2 drawing functions.
    """

    def __init__(self, num_colors=80):
        """Build a palette with ``num_colors`` entries.

        Raises:
            ValueError: if ``num_colors`` is less than 1 (an empty palette
                cannot color anything).
        """
        if num_colors < 1:
            raise ValueError(f"num_colors must be >= 1, got {num_colors}")
        self.num_colors = num_colors
        self.color_palette = self.generate_color_palette()

    def generate_color_palette(self):
        """Return a ``(num_colors, 3)`` uint8 array of BGR colors."""
        hsv_palette = np.zeros((self.num_colors, 1, 3), dtype=np.uint8)
        hsv_palette[:, 0, 0] = np.linspace(0, 180, self.num_colors, endpoint=False)
        hsv_palette[:, :, 1:] = 255  # full saturation and value
        bgr_palette = cv2.cvtColor(hsv_palette, cv2.COLOR_HSV2BGR)
        return bgr_palette.reshape(-1, 3)

    def __call__(self, class_id):
        """Return the BGR color for ``class_id`` as a tuple of Python ints.

        Ids outside ``[0, len(palette))`` wrap around the palette instead of
        raising ``IndexError``, so an unexpected class id is still drawn.
        Negative ids map to the same entries as plain numpy indexing would.
        """
        row = self.color_palette[int(class_id) % len(self.color_palette)]
        return tuple(int(channel) for channel in row)

In [40]:
class ObjectDetection:
    """YOLO detection + StrongSORT tracking over a video, saving annotated output.

    Construction loads the YOLO and ReID weights and opens the capture. It also
    creates an mp4 writer named ``strongsort_<video stem>.mp4`` in the working
    directory. Calling the instance processes the stream and writes:
    - ``seqinfo.ini``;
    - ``results.txt``, with one line per track per frame;
    - the annotated frames.
    """

    def __init__(self, model_weights="yolov8s.pt", capture_index=0, classes=None,
                 reid_weights="osnet_x0_25_msmt17.pt"):
        """
        Args:
            model_weights: path to the YOLO ``.pt`` weights.
            capture_index: video path or camera index for ``cv2.VideoCapture``.
            classes: label names indexed by class id. When ``None`` the
                notebook-level ``model_classes`` list is used (original behavior).
            reid_weights: ReID weights file passed to StrongSORT.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Using Device: ", self.device)
        self.model = self.load_model(model_weights)
        # Labels come from the caller or the notebook config, not self.model.names.
        self.classes = classes if classes is not None else model_classes
        self.colors = Colors(len(self.classes))
        self.font = cv2.FONT_HERSHEY_SIMPLEX
        self.capture_index = capture_index
        self.cap = self.load_capture()
        self.tracker = StrongSORT(Path(reid_weights),
                                  torch.device(self.device),
                                  fp16=False,
                                  )

    def load_model(self, weights):
        """Load YOLO weights and fuse layers for faster inference."""
        model = YOLO(weights)
        model.fuse()
        return model

    def predict(self, frame, conf=0.8):
        """Run YOLO on one frame and return the ultralytics streaming results.

        ``conf`` is the minimum detection confidence (0.8 was the original
        hard-coded value).
        """
        return self.model(frame, stream=True, verbose=False, conf=conf, line_width=1)

    def draw_tracks(self, frame, tracks, txt_file, overlap_threshold=0.5):
        """Draw each track's box and label on ``frame`` and log it to ``txt_file``.

        Args:
            frame: BGR image; annotated in place and returned.
            tracks: 2-D array whose rows are read as
                (x1, y1, x2, y2, track_id, conf, class_id, ...).
            txt_file: open text file. One MOTChallenge-style line is written per
                track: ``frame, id, x, y, w, h, conf, -1, -1, -1``.
            overlap_threshold: unused; kept for interface compatibility.
        """
        # After read(), CAP_PROP_POS_FRAMES is the 0-based index of the *next*
        # frame, i.e. the 1-based number of the frame just read.
        frame_no = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES))
        for track in tracks:
            x1, y1, x2, y2 = (int(v) for v in track[:4])
            track_id = int(track[4])
            conf = round(float(track[5]), 2)
            class_id = int(track[6])
            color = self.colors(class_id)
            label = f'{self.classes[class_id]} {conf}'
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            # Filled background above the box so the label text stays readable.
            (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 1.5, 3)
            cv2.rectangle(frame, (x1, y1 - h - 15), (x1 + w, y1), color, -1)
            cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 255, 255), 3)
            txt_file.write(f"{frame_no}, {track_id}, {x1}, {y1}, {x2 - x1}, {y2 - y1}, {conf}, -1, -1, -1\n")
        return frame

    @staticmethod
    def _output_stem(capture_index):
        """Base name for output files.

        Returns the video's file stem, or ``str(index)`` for a camera index.
        """
        return Path(str(capture_index)).stem

    @staticmethod
    def _compute_fps(elapsed):
        """Frames per second for a frame that took ``elapsed`` seconds.

        Returns 0.0 when ``elapsed <= 0``. Avoids the inf/OverflowError of
        ``1/np.round(elapsed, 2)`` on frames faster than 5 ms.
        """
        return 1.0 / elapsed if elapsed > 0 else 0.0

    def load_capture(self):
        """Open ``self.capture_index`` and create ``self.writer`` for the annotated video.

        Returns:
            The opened ``cv2.VideoCapture``.
        Raises:
            IOError: if the source cannot be opened.
        """
        cap = cv2.VideoCapture(self.capture_index)
        if not cap.isOpened():
            raise IOError(f"Cannot open video source: {self.capture_index!r}")
        # Request 1280x720 *before* reading the size, so the writer matches the
        # frames actually delivered. File-backed captures typically ignore this.
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # some sources report 0 (or NaN); VideoWriter needs a positive rate
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Name the output after this capture, not a notebook global that may be stale.
        video_name = "strongsort_" + self._output_stem(self.capture_index) + ".mp4"
        self.writer = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        return cap

    def write_seqinfo_ini(self, seq_name, seq_length, frame_rate, im_width, im_height, im_ext, im_dir):
        """Write ``seqinfo.ini`` (a ``[Sequence]`` section) into the working directory."""
        with open("seqinfo.ini", "w") as f:
            f.write("[Sequence]\n")
            f.write(f"name={seq_name}\n")
            f.write(f"imDir={im_dir}\n")  # replace with the real image folder if needed
            f.write(f"frameRate={frame_rate}\n")
            f.write(f"seqLength={seq_length}\n")
            f.write(f"imWidth={im_width}\n")
            f.write(f"imHeight={im_height}\n")
            f.write(f"imExt={im_ext}\n")

    def __call__(self):
        """Process every frame: detect, track, annotate and write all outputs.

        The capture and writer are released even if processing fails part-way,
        so the partially written video is still finalized.
        """
        tracker = self.tracker
        try:
            # Sequence metadata taken from the source video.
            seq_name = "StrongSort"
            im_dir = "img1"
            seq_length = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
            frame_rate = self.cap.get(cv2.CAP_PROP_FPS)
            im_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            im_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            im_ext = ".jpg"  # image extension recorded in seqinfo.ini
            self.write_seqinfo_ini(seq_name, seq_length, frame_rate, im_width, im_height, im_ext, im_dir)

            with open("results.txt", "w") as txt_file:
                while True:
                    start_time = perf_counter()
                    ret, frame = self.cap.read()
                    if not ret:
                        break
                    # Draw on a copy so the detector and the ReID tracker get the
                    # unmodified frame (the FPS banner must not hide image content).
                    annotated = frame.copy()
                    cv2.rectangle(annotated, (0, 30), (220, 80), (255, 255, 255), -1)
                    for dets in self.predict(frame):
                        tracks = tracker.update(dets.boxes.data.to("cpu").numpy(), frame)
                        # Only draw well-formed output: 8 columns, of which
                        # draw_tracks reads the first seven.
                        if len(tracks.shape) == 2 and tracks.shape[1] == 8:
                            annotated = self.draw_tracks(annotated, tracks, txt_file)
                    fps = self._compute_fps(perf_counter() - start_time)
                    cv2.putText(annotated, f'FPS: {int(fps)}', (20, 70), self.font, 1.5, (0, 255, 0), 5)
                    self.writer.write(annotated)
                    if cv2.waitKey(5) & 0xFF == 27:  # ESC stops early
                        break
        finally:
            self.cap.release()
            self.writer.release()
            cv2.destroyAllWindows()

In [None]:
# Build the detector for the configured weights and video, then process the whole clip.
detector = ObjectDetection(model_weights=model_weights, capture_index=test_vid)
detector()



In [None]:
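# Example: run on another video/model. Edit test_vid / model_weights in the
# configuration cell and re-run it first, so model_classes (and input_video_name)
# match daday.pt; otherwise the default ['polyp', 'esophagael cancer'] labels are used.
# test_vid = "/content/BVK042.mp4"
# model_weights = "/content/daday.pt"
# detector = ObjectDetection(model_weights, test_vid)
# detector()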

In [None]:
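# Frame trước ổn định nhưng sau đó box gặp vấn đề: vấn đề 1 là đối tượng biến dạng mạnh nhưng box không track theo kịp; vấn đề 2 là có nhiều box track cùng 1 đối tượng duy nhất
# Observation: tracking is stable in the early frames, but later the boxes degrade:
# (1) when the object deforms strongly, the box does not keep up with it;
# (2) several boxes end up tracking the same single object.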