# Part 1: Complete Project Setup (Run This Cell Only Once Per Session)

In [None]:
import os
import sys
from google.colab import drive

# --- 1. Define Key Paths ---
drive_mount_point = '/content/drive'
master_data_dir_on_drive = os.path.join(drive_mount_point, 'My Drive/colab_data/yolo2_data')
master_video_path_on_drive = os.path.join(master_data_dir_on_drive, 'soccer.mp4')
local_data_dir = '/content/data'
local_video_path = os.path.join(local_data_dir, 'soccer.mp4')

# --- 2. Mount Google Drive ---
print("Mounting Google Drive...")
drive.mount(drive_mount_point, force_remount=True)

# --- 3. High-Speed Data Sync Logic ---
print("\nSyncing data for high-speed access...")
os.makedirs(local_data_dir, exist_ok=True)
if not os.path.exists(local_video_path):
    print("Local data not found. Checking Google Drive...")
    
    if not os.path.exists(master_video_path_on_drive):
        print("Data not found on Drive. Performing ONE-TIME download from internet...")
        os.makedirs(master_data_dir_on_drive, exist_ok=True)
        # The gdown command to download your specific data to your Drive
        !gdown --id 1-2S26402YUn_S2aG_2S1i-tWbW0QASgB -O '{master_video_path_on_drive}'
        print("✓ Dataset downloaded to Google Drive.")
    else:
        print("✓ Data found on Google Drive.")

    # The key step: Copy from slow Drive to fast local SSD
    print("Copying data from Google Drive to local VM for high-speed access...")
    !cp '{master_video_path_on_drive}' '{local_video_path}'
    print("✓ Data is now on the local SSD.")
else:
    print("✓ High-speed local data already exists on the VM.")

# --- 4. GitHub Repo and Environment Setup ---
print("\nSetting up project repository and dependencies...")
if not os.path.exists('/content/yolo2'):
    !git clone https://github.com/victornaguiar/yolo2.git /content/yolo2
else:
    print("Repository already exists.")

project_dir = '/content/yolo2'
%cd {project_dir}

if project_dir not in sys.path:
    sys.path.insert(0, project_dir)

!pip install -q -r requirements.txt
!pip install -q boxmot

# --- 5. Link the Local Data to the Project ---
project_data_dir = os.path.join(project_dir, 'data')
os.makedirs(project_data_dir, exist_ok=True)
linked_video_path = os.path.join(project_data_dir, 'soccer.mp4')

if os.path.lexists(linked_video_path):
    os.remove(linked_video_path)

# Create the symlink to the FAST, LOCAL copy of the data
os.symlink(local_video_path, linked_video_path)

# --- Verification ---
print("\n======================================================================")
print("✓✓✓ ENVIRONMENT IS FULLY PREPARED FOR HIGH-PERFORMANCE RUN ✓✓✓")
print("======================================================================")
print(f"Model will read data from (fast local SSD): {os.path.realpath(linked_video_path)}")
!ls -l {project_data_dir}

# Part 2: Simple Tracking Demo

## 1. Imports

In [None]:
import cv2
from pathlib import Path

from boxmot import DeepOCSORT
from boxmot.utils import ROOT, WEIGHTS
from boxmot.tracker_zoo import create_tracker
from boxmot.utils.checks import TestRequirements
from boxmot.utils.torch_utils import select_device
from boxmot.utils.plotting import Colors, Annotator
from boxmot.yolo.utils.files import increment_path

from ultralytics import YOLO

## 2. Define parameters

In [None]:
# define some parameters
yolo_model = Path('yolov8n.pt')
tracking_method = 'deepocsort'  # deepocsort, botsort, strongsort, ocsort, bytetrack
reid_model = Path('osnet_x0_25_msmt17.pt')
# The setup cell only links `data/soccer.mp4` into the project (the working directory
# after `%cd`), so point the demo at that file rather than a video that is never fetched.
source = Path('data/soccer.mp4')
device = 'cpu'  # 'cuda:0', 'cuda:1', ..
save = True

## 3. Create instances

In [None]:
# Build the two components of the demo: the YOLO detector and a boxmot tracker of the
# type selected in the parameters cell.
model = YOLO(yolo_model)

# NOTE(review): earlier revisions listed further tuning options here (asso_func,
# delta_t, asso_thresh, min_hits, inertia, use_byte); check the installed boxmot
# version's `create_tracker` signature before passing any of them.
tracker = create_tracker(
    tracker_type=tracking_method,
    device=device,            # 'cpu', 'cuda:0', 'cuda:1', ...
    model_weights=reid_model, # ReID model weights
    fp16=False,               # whether to run the ReID model in half precision
)

## 4. Run detection and tracking

In [None]:
# Run detection and tracking.
# Ultralytics' `model.track()` only accepts its own tracker configs ('botsort.yaml',
# 'bytetrack.yaml'), so a boxmot method name such as 'deepocsort' cannot be passed to
# it, and the boxmot `tracker` created above would never be used. Instead, stream YOLO
# detections frame by frame and feed them to the boxmot tracker.
detections = model.predict(
    source=str(source),
    stream=True,   # generator: results are produced one frame at a time
    conf=0.3,
    iou=0.5,
    classes=0,     # track people only
    device=device,
    verbose=False,
)

# One entry per frame: the array returned by `tracker.update` for that frame.
results = []
for det in detections:
    # boxmot expects a numpy array of detections: x1, y1, x2, y2, conf, cls.
    dets = det.boxes.data.cpu().numpy()
    results.append(tracker.update(dets, det.orig_img))

## 5. Process results

In [None]:
# Walk over the per-frame outputs collected in `results` (one element per video frame).
# To inspect a frame interactively, add `print(frame)` inside the loop; ultralytics
# `Results` objects additionally offer `frame.show()`.
for frame in results:
    continue

# Part 3: Soccer Tracking Pipeline

## 1. Imports

In [None]:
import cv2
import torch
import numpy as np
from pathlib import Path

from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors

from boxmot import DeepOCSORT
from boxmot.utils import ROOT, WEIGHTS
from boxmot.tracker_zoo import create_tracker
from boxmot.utils.checks import TestRequirements
from boxmot.utils.torch_utils import select_device
from boxmot.utils.plotting import Colors, Annotator
from boxmot.yolo.utils.files import increment_path

## 2. Utils

In [None]:
def get_video_info(video_path):
    """Read basic properties of a video file with OpenCV.

    Args:
        video_path: Path to the video file, as ``str`` or ``pathlib.Path``.

    Returns:
        A dict with keys ``'width'`` and ``'height'`` (int), ``'fps'`` (float, as
        reported by OpenCV) and ``'frame_count'`` (int, as reported by OpenCV), or
        ``None`` if the file could not be opened.
    """
    # str() lets callers pass pathlib.Path objects as well as plain strings.
    cap = cv2.VideoCapture(str(video_path))
    try:
        # Check if the video file was opened successfully
        if not cap.isOpened():
            print("Error: Could not open video file.")
            return None

        return {
            'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            'fps': cap.get(cv2.CAP_PROP_FPS),
            'frame_count': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
        }
    finally:
        # Release the capture on every path, including the failed-open one.
        cap.release()

def create_video_writer(video_info, output_path, fourcc='mp4v'):
    """Open a ``cv2.VideoWriter`` whose frame size and rate come from ``video_info``.

    Args:
        video_info: Dict with ``'width'``, ``'height'`` and ``'fps'`` keys, as built by
            ``get_video_info``.
        output_path: Destination file for the encoded video.
        fourcc: Four-character codec code (e.g. ``'mp4v'``, or ``'XVID'`` for .avi).

    Returns:
        The opened ``cv2.VideoWriter``; the caller is responsible for ``release()``.
    """
    codec = cv2.VideoWriter_fourcc(*fourcc)
    frame_size = (video_info['width'], video_info['height'])
    return cv2.VideoWriter(output_path, codec, video_info['fps'], frame_size)

def ensure_dir(path):
    """Create directory ``path`` including any missing parents; no-op if it exists."""
    target = Path(path)
    target.mkdir(exist_ok=True, parents=True)

## 3. Parameters

In [None]:
# define some parameters
yolo_model = Path('../models/yolov8n.pt')
source = Path('../data/soccer.mp4')
# Use the first GPU when one is available; otherwise fall back to CPU instead of
# failing later on a hard-coded 'cuda:0'.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # or 'cuda:1', ..
project = '../output'
name = 'exp'
save = True
show = False
# tracking parameters
conf = 0.5  # minimum detection confidence passed to YOLO
iou = 0.7   # IoU threshold passed to YOLO
classes = [0, 32]  # track person and ball

## 4. Dataloader

In [None]:
# setup dataloader
from boxmot.yolo.utils.dataloaders import LoadStreams, LoadImages
from boxmot.yolo.utils.torch_utils import select_device

device = select_device(device)

dataset = LoadImages(source, img_size=640, stride=32, auto=True)
# NOTE(review): in this YOLOv5-style loader `dataset.nf` appears to count input *files*
# (1 for a single video), not frames, so it is not usable as a progress-bar total.
# Read the frame count from the video itself instead.
_source_info = get_video_info(str(source))
nr_frames = _source_info['frame_count'] if _source_info else None  # tqdm accepts total=None

## 5. Instantiate models

In [None]:
# Load the YOLO detector from the configured weights and move it to the selected device.
yolo = YOLO(model=yolo_model)
yolo.to(device)

## 6. Run tracking

In [None]:
# instantiate the tracker
from boxmot import DeepOCSORT

# Use the device configured above instead of a hard-coded 'cuda:0', and only enable
# half precision when that device is a CUDA device.
tracker_device = str(device)
use_fp16 = tracker_device.startswith('cuda')

tracker = DeepOCSORT(
    model_weights=Path('osnet_x0_25_msmt17.pt'),  # which ReID model to use
    device=tracker_device,  # 'cpu', 'cuda:0', 'cuda:1', ...
    fp16=use_fp16,  # whether to run the ReID model with fp16
)

In [None]:
from tqdm import tqdm

tracking_results = {}
yolo_preds = []
all_tracks = []

for frame_idx, (path, im, im0s, vid_cap, s) in enumerate(
    tqdm(dataset, desc='Processing frames for DeepOCSORT', total=nr_frames)
):
    # Detect on the original BGR frame `im0s`, not on `im`: `im` is the loader's
    # letterboxed CHW array, which ultralytics does not expect. Boxes must also be in
    # `im0s` pixel coordinates, because the tracker receives `im0s` below.
    preds = yolo.predict(
        im0s,
        conf=conf,
        iou=iou,
        classes=classes,
        verbose=False,
    )
    # NOTE(review): each ultralytics Results object keeps its frame (`orig_img`), so this
    # list grows with video length; consider storing only the boxes for long videos.
    yolo_preds.append(preds[0])

    # boxmot expects a numpy array of detections (x1, y1, x2, y2, conf, cls).
    dets = preds[0].boxes.data.cpu().numpy()
    tracks = tracker.update(dets, im0s)
    if tracks.size > 0:
        # Prepend the 0-based frame index to every track row. Per the boxmot README each
        # row is (x1, y1, x2, y2, id, conf, cls[, det_ind]) — NOTE(review): confirm for
        # the installed boxmot version.
        frame_col = np.full((tracks.shape[0], 1), frame_idx)
        all_tracks.append(np.concatenate((frame_col, tracks), axis=1))

# np.concatenate raises on an empty list, so handle a run that produced no tracks.
all_tracks = np.concatenate(all_tracks, axis=0) if all_tracks else np.empty((0, 8))
tracking_results['ocsort'] = all_tracks

## 7. Generate Videos

In [None]:
# Generate tracking videos for visualization.
# The original condition referenced an undefined `data_loader`, and the LoadImages
# iterator used for tracking has already been consumed, so the source video is re-read
# here with OpenCV. Frames are annotated and written one at a time instead of being
# buffered in a list, which keeps memory flat for long videos.
if 'tracking_results' in locals() and tracking_results:
    video_output_dir = '../output/videos'
    ensure_dir(video_output_dir)

    video_info = get_video_info(str(source))
    if video_info is None:
        raise RuntimeError(f'Could not open source video {source}')
    # Named `palette` so the ultralytics `colors` helper imported in Part 3 is not shadowed.
    palette = Colors()

    for tracker_name, tracker_tracks in tracking_results.items():
        print(f'\nGenerating video for {tracker_name} tracker...')

        # Group track rows by their frame index (column 0) once, instead of scanning the
        # whole array for every frame.
        tracks_by_frame = {}
        for row in tracker_tracks:
            tracks_by_frame.setdefault(int(row[0]), []).append(row)

        output_video_path = f'{video_output_dir}/{tracker_name}.mp4'
        video_writer = create_video_writer(video_info, output_video_path)
        cap = cv2.VideoCapture(str(source))
        try:
            with tqdm(total=video_info['frame_count'], desc=f'Annotating frames for {tracker_name}') as pbar:
                frame_idx = 0
                while True:
                    ok, frame_bgr = cap.read()
                    if not ok:
                        break
                    annotator = Annotator(frame_bgr, 2, "Arial.ttf")

                    # Row layout: frame index, then (x1, y1, x2, y2, id, conf, cls[, det_ind])
                    # per the boxmot README — NOTE(review): confirm for the installed version.
                    # Distinct local names avoid clobbering the global `conf` threshold.
                    for row in tracks_by_frame.get(frame_idx, ()):
                        x1, y1, x2, y2, track_id, _track_conf, cls = row[1:8]
                        annotator.box_label([x1, y1, x2, y2], f'{int(track_id)}', color=palette(int(cls), True))

                    video_writer.write(annotator.result())
                    frame_idx += 1
                    pbar.update(1)
        finally:
            cap.release()
            video_writer.release()
        print(f'Video for {tracker_name} tracker saved to {output_video_path}')

# Part 4: Evaluation and Analysis

## 1. Imports

In [None]:
import sys
from pathlib import Path

import motmetrics as mm
import numpy as np
import pandas as pd
from tqdm import tqdm

## 2. Load GT

In [None]:
# Ground truth for the ball-challenge sequence, read in motmetrics' 'mot15-2D' format.
gt_path = Path('../data/gt/mot_challenge/ball-challenge-train/ball/gt/gt.txt')
gt = mm.io.loadtxt(gt_path, fmt='mot15-2D')
# (Re)index by (FrameId, Id) so per-frame slices can be taken in the evaluation loop.
gt_df = (
    gt
    .reset_index()
    .set_index(keys=['FrameId', 'Id'])
)

## 3. Convert tracking results to MOT format

In [None]:
# convert tracking results to the MOT format
# Each row of `all_tracks` is the 0-based frame index (prepended by the tracking loop)
# followed by a boxmot track row. Per the boxmot README that row is
# (x1, y1, x2, y2, id, conf, cls[, det_ind]) — NOTE(review): confirm for the installed
# boxmot version. Only the first 8 columns are used, so an optional det_ind is ignored.
raw_cols = ['FrameIdx', 'X1', 'Y1', 'X2', 'Y2', 'Id', 'Conf', 'ClassId']
tracks_raw = pd.DataFrame(np.asarray(all_tracks)[:, :len(raw_cols)], columns=raw_cols)

ts_df = pd.DataFrame({
    # NOTE(review): assumes gt.txt follows the MOTChallenge convention of 1-based frame
    # numbers, while the tracking loop counts frames from 0.
    'FrameId': tracks_raw['FrameIdx'].astype(int) + 1,
    'Id': tracks_raw['Id'].astype(int),
    # motmetrics' iou_matrix takes boxes as (x, y, width, height): top-left corner + size.
    'X': tracks_raw['X1'],
    'Y': tracks_raw['Y1'],
    'Width': tracks_raw['X2'] - tracks_raw['X1'],
    'Height': tracks_raw['Y2'] - tracks_raw['Y1'],
    'ClassId': tracks_raw['ClassId'].astype(int),
    'Conf': tracks_raw['Conf'],
}).set_index(['FrameId', 'Id'])

# only evaluate on the ball class
ts_df = ts_df[ts_df['ClassId'] == 32]

## 4. Run MOT evaluation

In [None]:
acc = mm.MOTAccumulator(auto_id=True)

box_cols = ['X', 'Y', 'Width', 'Height']
# Frames with at least one hypothesis, computed once instead of on every iteration.
ts_frame_ids = set(ts_df.index.get_level_values('FrameId'))
# Frames without hypotheses still need the box columns and an 'Id' index: a bare
# pd.DataFrame() has neither, so selecting box_cols from it raised KeyError.
empty_ts_frame = ts_df.iloc[:0].droplevel('FrameId')

# iterate over all frames of the video
for frame_id in tqdm(gt_df.index.get_level_values('FrameId').unique()):

    # get the gt and ts objects for the current frame (each indexed by object Id)
    gt_frame = gt_df.xs(frame_id, level='FrameId')
    ts_frame = ts_df.xs(frame_id, level='FrameId') if frame_id in ts_frame_ids else empty_ts_frame

    # compute the IoU-based distance between the gt and ts objects
    C = mm.distances.iou_matrix(gt_frame[box_cols], ts_frame[box_cols], max_iou=0.5)

    # update the accumulator with the results for the current frame
    acc.update(
        gt_frame.index.get_level_values('Id').tolist(),
        ts_frame.index.get_level_values('Id').tolist(),
        C,
    )

## 5. Get metrics

In [None]:
mh = mm.metrics.create()
summary = mh.compute(
    acc, 
    metrics=mm.metrics.motchallenge_metrics, 
    name='acc'
)

print(mm.io.render_summary(
    summary, 
    formatters=mh.formatters, 
