In [1]:
# Absolute path of the input clip; it is passed as `source` to `model.predict` in a later cell.
# NOTE(review): the /content prefix suggests a Colab runtime - adjust when running elsewhere.
SOURCE_VIDEO_PATH = "/content/baseball.mp4"

In [3]:
!pip install ultralytics
from ultralytics import YOLO

Collecting ultralytics
  Downloading ultralytics-8.3.58-py3-none-any.whl.metadata (35 kB)
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.13-py3-none-any.whl.metadata (9.4 kB)
Downloading ultralytics-8.3.58-py3-none-any.whl (905 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m905.3/905.3 kB[0m [31m24.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ultralytics_thop-2.0.13-py3-none-any.whl (26 kB)
Installing collected packages: ultralytics-thop, ultralytics
Successfully installed ultralytics-8.3.58 ultralytics-thop-2.0.13
Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


In [18]:
import requests
import os
from tqdm import tqdm
import zipfile
import io
from typing import Optional, Union
import shutil

class LoadTools:
    """
    Download and cache baseball computer-vision models and datasets, using either the
    BallDataLab API or a text file that contains a direct download link.

    Attributes:
        session (requests.Session): Session object for making requests.
        chunk_size (int): Size in bytes of the chunks used when streaming downloads.
        timeout (tuple): (connect, read) timeouts in seconds applied to every download request.
        BDL_MODEL_API (str): Base URL for the BallDataLab model API.
        BDL_DATASET_API (str): Base URL for the BallDataLab dataset API.
        yolo_model_aliases (dict): YOLO model alias -> path of the .txt link file.
        florence_model_aliases (dict): Florence2 model alias -> path of the .txt link file.
        dataset_aliases (dict): Dataset alias -> path of the .txt link file.

    Methods:
        load_model(model_alias: str, model_type: str = 'YOLO', use_bdl_api: Optional[bool] = True, model_txt_path: Optional[str] = None) -> str:
            Loads a given baseball computer vision model into the repository.
        load_dataset(dataset_alias: str, use_bdl_api: Optional[bool] = True, file_txt_path: Optional[str] = None) -> str:
            Loads a zipped dataset and extracts it to a folder.
        _download_files(url: str, dest: Union[str, os.PathLike], is_folder: bool = False, is_labeled: bool = False) -> bool:
            Protected method to handle model and dataset downloads.
        _get_url(alias: str, txt_path: str, use_bdl_api: bool, api_endpoint: str) -> str:
            Protected method to obtain the download URL from the BDL API or a text file.
    """

    def __init__(self):
        self.session = requests.Session()
        self.chunk_size = 1024
        # requests never times out by default; without this a stalled connection
        # blocks the notebook cell indefinitely.
        self.timeout = (10, 60)
        self.BDL_MODEL_API = "https://balldatalab.com/api/models/"
        self.BDL_DATASET_API = "https://balldatalab.com/api/datasets/"
        self.yolo_model_aliases = {
            'phc_detector': 'models/YOLO/pitcher_hitter_catcher_detector/model_weights/pitcher_hitter_catcher_detector_v4.txt',
            'bat_tracking': 'models/YOLO/bat_tracking/model_weights/bat_tracking.txt',
            'ball_tracking': './models/YOLO/ball_tracking/model_weights/ball_tracking.txt',
            'glove_tracking': 'models/YOLO/glove_tracking/model_weights/glove_tracking.txt',
            'ball_trackingv4': '/content/models/YOLO/ball_tracking/model_weights/ball_trackingv4.txt'
        }
        self.florence_model_aliases = {
            'ball_tracking': 'models/FLORENCE2/ball_tracking/model_weights/florence_ball_tracking.txt',
            'florence_ball_tracking': 'models/FLORENCE2/ball_tracking/model_weights/florence_ball_tracking.txt'
        }
        self.dataset_aliases = {
            'okd_nokd': 'datasets/yolo/OKD_NOKD.txt',
            'baseball_rubber_home_glove': 'datasets/yolo/baseball_rubber_home_glove.txt',
            'baseball_rubber_home': 'datasets/yolo/baseball_rubber_home.txt',
            'broadcast_10k_frames': 'datasets/raw_photos/broadcast_10k_frames.txt',
            'broadcast_15k_frames': 'datasets/raw_photos/broadcast_15k_frames.txt',
            'baseball_rubber_home_COCO': 'datasets/COCO/baseball_rubber_home_COCO.txt',
            'baseball_rubber_home_glove_COCO': 'datasets/COCO/baseball_rubber_home_glove_COCO.txt',
            'baseball': 'datasets/yolo/baseball.txt'
        }

    @staticmethod
    def _is_cached(path: Union[str, os.PathLike]) -> bool:
        """Return True only for a non-empty directory or a non-empty regular file.

        An empty file or directory is what a failed or interrupted download leaves
        behind, so it must not be reported as an already-downloaded artifact.
        """
        if os.path.isdir(path):
            return bool(os.listdir(path))
        return os.path.isfile(path) and os.path.getsize(path) > 0

    @staticmethod
    def _write_atomically(chunks, dest: str, progress_bar) -> None:
        """Stream `chunks` into `dest` through a sibling '.part' file.

        The final path only appears once every chunk has been written, and the
        partial file is removed if anything (including a kernel interrupt) stops
        the transfer.
        """
        partial = f"{dest}.part"
        try:
            with open(partial, 'wb') as out:
                for data in chunks:
                    progress_bar.update(out.write(data))
            os.replace(partial, dest)
        except BaseException:
            # BaseException so that KeyboardInterrupt also cleans up.
            if os.path.exists(partial):
                os.remove(partial)
            raise

    @staticmethod
    def _extract_archive(archive, dest: str, is_labeled: bool) -> None:
        """Extract the in-memory zip `archive` into `dest`.

        With `is_labeled` the archive's directory layout is kept; otherwise every
        file is written directly into `dest` under its base name and directories
        left empty are removed. macOS metadata entries are skipped.
        """
        with zipfile.ZipFile(archive) as zip_ref:
            os.makedirs(dest, exist_ok=True)
            for member in zip_ref.namelist():
                # Skip '__MACOSX' entries and AppleDouble '._*' files at any depth.
                if member.startswith('__MACOSX') or any(part.startswith('._') for part in member.split('/')):
                    continue
                if is_labeled or '/' not in member:
                    zip_ref.extract(member, dest)
                    continue
                filename = member.rsplit('/', 1)[-1]
                if filename:  # an empty base name means a directory entry
                    with zip_ref.open(member) as source, open(os.path.join(dest, filename), 'wb') as target:
                        shutil.copyfileobj(source, target)

        if not is_labeled:
            # Bottom-up walk so nested empty directories are removed before their parents are checked.
            for root, dirs, _files in os.walk(dest, topdown=False):
                for sub in dirs:
                    dir_path = os.path.join(root, sub)
                    if not os.listdir(dir_path):
                        os.rmdir(dir_path)

    def _download_files(self, url: str, dest: Union[str, os.PathLike], is_folder: bool = False, is_labeled: bool = False) -> bool:
        """
        Download `url` to `dest`, showing a progress bar.

        Args:
            url (str): Address to download from.
            dest (Union[str, os.PathLike]): Target file (is_folder=False) or target directory (is_folder=True).
            is_folder (bool): Treat the payload as a zip archive and extract it into `dest`.
            is_labeled (bool): When extracting, keep the archive's directory structure instead of flattening it.

        Returns:
            bool: True when the download (and extraction) completed, False when the
                server answered with a non-200 status (a message is printed in that case).

        Raises:
            zipfile.BadZipFile: If `is_folder` is set and the payload is not a zip archive.
            Exception: Network and file-system errors propagate to the caller.
        """
        dest = os.fspath(dest)
        # Context managers guarantee the connection and the progress bar are released
        # even when the transfer is interrupted.
        with self.session.get(url, stream=True, timeout=self.timeout) as response:
            if response.status_code != 200:
                print(f"Download failed. STATUS: {response.status_code}")
                return False

            total_size = int(response.headers.get('content-length', 0))
            archive = io.BytesIO() if is_folder else None
            with tqdm(total=total_size, unit='iB', unit_scale=True, desc=f"Downloading {os.path.basename(dest)}") as progress_bar:
                chunks = response.iter_content(chunk_size=self.chunk_size)
                if is_folder:
                    for data in chunks:
                        progress_bar.update(archive.write(data))
                else:
                    self._write_atomically(chunks, dest, progress_bar)

        if not is_folder:
            print(f"Model downloaded to {dest}")
            return True

        self._extract_archive(archive, dest, is_labeled)
        print(f"Dataset downloaded and extracted to {dest}")
        return True

    def _get_url(self, alias: str, txt_path: str, use_bdl_api: bool, api_endpoint: str) -> str:
        """
        Build the download URL for `alias`.

        Returns `api_endpoint` + `alias` when `use_bdl_api` is truthy; otherwise the
        stripped contents of the text file at `txt_path`.
        """
        if use_bdl_api:
            return f"{api_endpoint}{alias}"
        with open(txt_path, 'r') as file:
            return file.read().strip()

    def load_model(self, model_alias: str, model_type: str = 'YOLO', use_bdl_api: Optional[bool] = True, model_txt_path: Optional[str] = None) -> str:
        '''
        Loads a given baseball computer vision model into the repository.

        Args:
            model_alias (str): Alias of the model to load.
            model_type (str): The type of the model to utilize ('YOLO' or 'FLORENCE2'). Defaults to YOLO.
            use_bdl_api (Optional[bool]): Whether to use the BallDataLab API.
            model_txt_path (Optional[str]): Path to .txt file containing download link to model weights.
                                            Only used if use_bdl_api is specified as False.

        Returns:
            model_weights_path (str):  Path to where the model weights are saved within the repo
                (a .pt file for YOLO, a directory for FLORENCE2).

        Raises:
            ValueError: If `model_type` is unknown, or no .txt path could be resolved for the alias.
        '''
        if model_type == 'YOLO':
            aliases = self.yolo_model_aliases
        elif model_type == 'FLORENCE2':
            aliases = self.florence_model_aliases
        else:
            raise ValueError(f"Invalid model type: {model_type}")

        if use_bdl_api:
            model_txt_path = aliases.get(model_alias)
        if not model_txt_path:
            raise ValueError(f"Invalid alias: {model_alias}")

        is_folder = model_type == 'FLORENCE2'
        base_dir = os.path.dirname(model_txt_path)
        base_name = os.path.splitext(os.path.basename(model_txt_path))[0]
        # os.path.join keeps a bare file name relative; a hand-built "{dir}/{name}"
        # would turn an empty directory part into a path at the filesystem root.
        model_weights_path = os.path.join(base_dir, base_name if is_folder else f"{base_name}.pt")

        # The cache check must run BEFORE anything is created on disk; otherwise a
        # freshly created (empty) FLORENCE2 directory would count as "found".
        if self._is_cached(model_weights_path):
            print(f"Model found at {model_weights_path}")
            return model_weights_path

        parent_dir = os.path.dirname(model_weights_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        url = self._get_url(model_alias, model_txt_path, use_bdl_api, self.BDL_MODEL_API)
        self._download_files(url, model_weights_path, is_folder=is_folder)

        return model_weights_path

    def load_dataset(self, dataset_alias: str, use_bdl_api: Optional[bool] = True, file_txt_path: Optional[str] = None) -> str:
        '''
        Loads a zipped dataset and extracts it to a folder.

        Args:
            dataset_alias (str): Alias of the dataset to load that corresponds to a dataset folder to download
            use_bdl_api (Optional[bool]): Whether to use the BallDataLab API. Defaults to True.
            file_txt_path (Optional[str]): Path to .txt file containing download link to zip file containing dataset.
                                           Only used if use_bdl_api is specified as False.

        Returns:
            dir_name (str): Path to the folder containing the dataset.

        Raises:
            ValueError: If no .txt path could be resolved for the alias.
        '''
        txt_path = self.dataset_aliases.get(dataset_alias) if use_bdl_api else file_txt_path
        if not txt_path:
            raise ValueError(f"Invalid alias or missing path: {dataset_alias}")

        base = os.path.splitext(os.path.basename(txt_path))[0]
        # NOTE(review): 'raw_photos' is tested against the file's base name, not its
        # directory - confirm whether txt_path was intended. Left unchanged so existing
        # cached folder names stay valid.
        dir_name = "unlabeled_" + base if 'raw_photos' in base or 'frames' in base or 'frames' in dataset_alias else base

        # An empty directory is a leftover of a failed download, not a dataset.
        if self._is_cached(dir_name):
            print(f"Dataset found at {dir_name}")
            return dir_name

        url = self._get_url(dataset_alias, txt_path, use_bdl_api, self.BDL_DATASET_API)
        # The target directory is created by _download_files once the archive has
        # arrived, so a failed request leaves nothing behind.
        # NOTE(review): is_labeled is never passed, so every dataset is flattened on
        # extraction - confirm that this is intended for labeled datasets.
        self._download_files(url, dir_name, is_folder=True)

        return dir_name

# Build the loader, resolve the 'ball_trackingv4' alias to a local weights path
# (load_model downloads the file when it is not on disk yet), and wrap it in a YOLO model.
# NOTE(review): `model` is reassigned by the next cell, which loads 'ball_tracking' instead.
load_tools = LoadTools()
model_weights = load_tools.load_model(model_alias='ball_trackingv4')
model = YOLO(model_weights)

Downloading ball_trackingv4.pt:   0%|          | 0.00/114M [00:45<?, ?iB/s]


Model found at /content/models/YOLO/ball_tracking/model_weights/ball_trackingv4.pt


In [19]:
import requests
import os
from tqdm import tqdm
import zipfile
import io
from typing import Optional, Union
import shutil

class LoadTools:
    """
    Download and cache baseball computer-vision models and datasets, using either the
    BallDataLab API or a text file that contains a direct download link.

    Attributes:
        session (requests.Session): Session object for making requests.
        chunk_size (int): Size in bytes of the chunks used when streaming downloads.
        timeout (tuple): (connect, read) timeouts in seconds applied to every download request.
        BDL_MODEL_API (str): Base URL for the BallDataLab model API.
        BDL_DATASET_API (str): Base URL for the BallDataLab dataset API.
        yolo_model_aliases (dict): YOLO model alias -> path of the .txt link file.
        florence_model_aliases (dict): Florence2 model alias -> path of the .txt link file.
        dataset_aliases (dict): Dataset alias -> path of the .txt link file.

    Methods:
        load_model(model_alias: str, model_type: str = 'YOLO', use_bdl_api: Optional[bool] = True, model_txt_path: Optional[str] = None) -> str:
            Loads a given baseball computer vision model into the repository.
        load_dataset(dataset_alias: str, use_bdl_api: Optional[bool] = True, file_txt_path: Optional[str] = None) -> str:
            Loads a zipped dataset and extracts it to a folder.
        _download_files(url: str, dest: Union[str, os.PathLike], is_folder: bool = False, is_labeled: bool = False) -> bool:
            Protected method to handle model and dataset downloads.
        _get_url(alias: str, txt_path: str, use_bdl_api: bool, api_endpoint: str) -> str:
            Protected method to obtain the download URL from the BDL API or a text file.
    """

    def __init__(self):
        self.session = requests.Session()
        self.chunk_size = 1024
        # requests never times out by default; without this a stalled connection
        # blocks the notebook cell indefinitely.
        self.timeout = (10, 60)
        self.BDL_MODEL_API = "https://balldatalab.com/api/models/"
        self.BDL_DATASET_API = "https://balldatalab.com/api/datasets/"
        self.yolo_model_aliases = {
            'phc_detector': 'models/YOLO/pitcher_hitter_catcher_detector/model_weights/pitcher_hitter_catcher_detector_v4.txt',
            'bat_tracking': 'models/YOLO/bat_tracking/model_weights/bat_tracking.txt',
            'ball_tracking': '/content/models/YOLO/ball_tracking/model_weights/ball_tracking.txt',
            'glove_tracking': 'models/YOLO/glove_tracking/model_weights/glove_tracking.txt',
            'ball_trackingv4': '/content/models/YOLO/ball_tracking/model_weights/ball_trackingv4.txt'
        }
        self.florence_model_aliases = {
            'ball_tracking': 'models/FLORENCE2/ball_tracking/model_weights/florence_ball_tracking.txt',
            'florence_ball_tracking': 'models/FLORENCE2/ball_tracking/model_weights/florence_ball_tracking.txt'
        }
        self.dataset_aliases = {
            'okd_nokd': 'datasets/yolo/OKD_NOKD.txt',
            'baseball_rubber_home_glove': 'datasets/yolo/baseball_rubber_home_glove.txt',
            'baseball_rubber_home': 'datasets/yolo/baseball_rubber_home.txt',
            'broadcast_10k_frames': 'datasets/raw_photos/broadcast_10k_frames.txt',
            'broadcast_15k_frames': 'datasets/raw_photos/broadcast_15k_frames.txt',
            'baseball_rubber_home_COCO': 'datasets/COCO/baseball_rubber_home_COCO.txt',
            'baseball_rubber_home_glove_COCO': 'datasets/COCO/baseball_rubber_home_glove_COCO.txt',
            'baseball': 'datasets/yolo/baseball.txt'
        }

    @staticmethod
    def _is_cached(path: Union[str, os.PathLike]) -> bool:
        """Return True only for a non-empty directory or a non-empty regular file.

        An empty file or directory is what a failed or interrupted download leaves
        behind, so it must not be reported as an already-downloaded artifact.
        """
        if os.path.isdir(path):
            return bool(os.listdir(path))
        return os.path.isfile(path) and os.path.getsize(path) > 0

    @staticmethod
    def _write_atomically(chunks, dest: str, progress_bar) -> None:
        """Stream `chunks` into `dest` through a sibling '.part' file.

        The final path only appears once every chunk has been written, and the
        partial file is removed if anything (including a kernel interrupt) stops
        the transfer.
        """
        partial = f"{dest}.part"
        try:
            with open(partial, 'wb') as out:
                for data in chunks:
                    progress_bar.update(out.write(data))
            os.replace(partial, dest)
        except BaseException:
            # BaseException so that KeyboardInterrupt also cleans up.
            if os.path.exists(partial):
                os.remove(partial)
            raise

    @staticmethod
    def _extract_archive(archive, dest: str, is_labeled: bool) -> None:
        """Extract the in-memory zip `archive` into `dest`.

        With `is_labeled` the archive's directory layout is kept; otherwise every
        file is written directly into `dest` under its base name and directories
        left empty are removed. macOS metadata entries are skipped.
        """
        with zipfile.ZipFile(archive) as zip_ref:
            os.makedirs(dest, exist_ok=True)
            for member in zip_ref.namelist():
                # Skip '__MACOSX' entries and AppleDouble '._*' files at any depth.
                if member.startswith('__MACOSX') or any(part.startswith('._') for part in member.split('/')):
                    continue
                if is_labeled or '/' not in member:
                    zip_ref.extract(member, dest)
                    continue
                filename = member.rsplit('/', 1)[-1]
                if filename:  # an empty base name means a directory entry
                    with zip_ref.open(member) as source, open(os.path.join(dest, filename), 'wb') as target:
                        shutil.copyfileobj(source, target)

        if not is_labeled:
            # Bottom-up walk so nested empty directories are removed before their parents are checked.
            for root, dirs, _files in os.walk(dest, topdown=False):
                for sub in dirs:
                    dir_path = os.path.join(root, sub)
                    if not os.listdir(dir_path):
                        os.rmdir(dir_path)

    def _download_files(self, url: str, dest: Union[str, os.PathLike], is_folder: bool = False, is_labeled: bool = False) -> bool:
        """
        Download `url` to `dest`, showing a progress bar.

        Args:
            url (str): Address to download from.
            dest (Union[str, os.PathLike]): Target file (is_folder=False) or target directory (is_folder=True).
            is_folder (bool): Treat the payload as a zip archive and extract it into `dest`.
            is_labeled (bool): When extracting, keep the archive's directory structure instead of flattening it.

        Returns:
            bool: True when the download (and extraction) completed, False when the
                server answered with a non-200 status (a message is printed in that case).

        Raises:
            zipfile.BadZipFile: If `is_folder` is set and the payload is not a zip archive.
            Exception: Network and file-system errors propagate to the caller.
        """
        dest = os.fspath(dest)
        # Context managers guarantee the connection and the progress bar are released
        # even when the transfer is interrupted.
        with self.session.get(url, stream=True, timeout=self.timeout) as response:
            if response.status_code != 200:
                print(f"Download failed. STATUS: {response.status_code}")
                return False

            total_size = int(response.headers.get('content-length', 0))
            archive = io.BytesIO() if is_folder else None
            with tqdm(total=total_size, unit='iB', unit_scale=True, desc=f"Downloading {os.path.basename(dest)}") as progress_bar:
                chunks = response.iter_content(chunk_size=self.chunk_size)
                if is_folder:
                    for data in chunks:
                        progress_bar.update(archive.write(data))
                else:
                    self._write_atomically(chunks, dest, progress_bar)

        if not is_folder:
            print(f"Model downloaded to {dest}")
            return True

        self._extract_archive(archive, dest, is_labeled)
        print(f"Dataset downloaded and extracted to {dest}")
        return True

    def _get_url(self, alias: str, txt_path: str, use_bdl_api: bool, api_endpoint: str) -> str:
        """
        Build the download URL for `alias`.

        Returns `api_endpoint` + `alias` when `use_bdl_api` is truthy; otherwise the
        stripped contents of the text file at `txt_path`.
        """
        if use_bdl_api:
            return f"{api_endpoint}{alias}"
        with open(txt_path, 'r') as file:
            return file.read().strip()

    def load_model(self, model_alias: str, model_type: str = 'YOLO', use_bdl_api: Optional[bool] = True, model_txt_path: Optional[str] = None) -> str:
        '''
        Loads a given baseball computer vision model into the repository.

        Args:
            model_alias (str): Alias of the model to load.
            model_type (str): The type of the model to utilize ('YOLO' or 'FLORENCE2'). Defaults to YOLO.
            use_bdl_api (Optional[bool]): Whether to use the BallDataLab API.
            model_txt_path (Optional[str]): Path to .txt file containing download link to model weights.
                                            Only used if use_bdl_api is specified as False.

        Returns:
            model_weights_path (str):  Path to where the model weights are saved within the repo
                (a .pt file for YOLO, a directory for FLORENCE2).

        Raises:
            ValueError: If `model_type` is unknown, or no .txt path could be resolved for the alias.
        '''
        if model_type == 'YOLO':
            aliases = self.yolo_model_aliases
        elif model_type == 'FLORENCE2':
            aliases = self.florence_model_aliases
        else:
            raise ValueError(f"Invalid model type: {model_type}")

        if use_bdl_api:
            model_txt_path = aliases.get(model_alias)
        if not model_txt_path:
            raise ValueError(f"Invalid alias: {model_alias}")

        is_folder = model_type == 'FLORENCE2'
        base_dir = os.path.dirname(model_txt_path)
        base_name = os.path.splitext(os.path.basename(model_txt_path))[0]
        # os.path.join keeps a bare file name relative; a hand-built "{dir}/{name}"
        # would turn an empty directory part into a path at the filesystem root.
        model_weights_path = os.path.join(base_dir, base_name if is_folder else f"{base_name}.pt")

        # The cache check must run BEFORE anything is created on disk; otherwise a
        # freshly created (empty) FLORENCE2 directory would count as "found".
        if self._is_cached(model_weights_path):
            print(f"Model found at {model_weights_path}")
            return model_weights_path

        parent_dir = os.path.dirname(model_weights_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        url = self._get_url(model_alias, model_txt_path, use_bdl_api, self.BDL_MODEL_API)
        self._download_files(url, model_weights_path, is_folder=is_folder)

        return model_weights_path

    def load_dataset(self, dataset_alias: str, use_bdl_api: Optional[bool] = True, file_txt_path: Optional[str] = None) -> str:
        '''
        Loads a zipped dataset and extracts it to a folder.

        Args:
            dataset_alias (str): Alias of the dataset to load that corresponds to a dataset folder to download
            use_bdl_api (Optional[bool]): Whether to use the BallDataLab API. Defaults to True.
            file_txt_path (Optional[str]): Path to .txt file containing download link to zip file containing dataset.
                                           Only used if use_bdl_api is specified as False.

        Returns:
            dir_name (str): Path to the folder containing the dataset.

        Raises:
            ValueError: If no .txt path could be resolved for the alias.
        '''
        txt_path = self.dataset_aliases.get(dataset_alias) if use_bdl_api else file_txt_path
        if not txt_path:
            raise ValueError(f"Invalid alias or missing path: {dataset_alias}")

        base = os.path.splitext(os.path.basename(txt_path))[0]
        # NOTE(review): 'raw_photos' is tested against the file's base name, not its
        # directory - confirm whether txt_path was intended. Left unchanged so existing
        # cached folder names stay valid.
        dir_name = "unlabeled_" + base if 'raw_photos' in base or 'frames' in base or 'frames' in dataset_alias else base

        # An empty directory is a leftover of a failed download, not a dataset.
        if self._is_cached(dir_name):
            print(f"Dataset found at {dir_name}")
            return dir_name

        url = self._get_url(dataset_alias, txt_path, use_bdl_api, self.BDL_DATASET_API)
        # The target directory is created by _download_files once the archive has
        # arrived, so a failed request leaves nothing behind.
        # NOTE(review): is_labeled is never passed, so every dataset is flattened on
        # extraction - confirm that this is intended for labeled datasets.
        self._download_files(url, dir_name, is_folder=True)

        return dir_name

# Rebuild the loader from the class defined in this cell, resolve the 'ball_tracking'
# alias to a local weights path (downloaded when not on disk), and wrap it in a YOLO model.
# This rebinds `load_tools`, `model_weights` and `model`; the prediction cell below uses this `model`.
load_tools = LoadTools()
model_weights = load_tools.load_model(model_alias='ball_tracking')
model = YOLO(model_weights)


Downloading ball_tracking.pt:   0%|          | 0.00/114M [00:00<?, ?iB/s][A
Downloading ball_tracking.pt:   0%|          | 56.3k/114M [00:00<03:23, 563kiB/s][A
Downloading ball_tracking.pt:   0%|          | 132k/114M [00:00<02:57, 644kiB/s] [A
Downloading ball_tracking.pt:   0%|          | 517k/114M [00:00<00:57, 1.99MiB/s][A
Downloading ball_tracking.pt:   2%|▏         | 2.04M/114M [00:00<00:16, 6.77MiB/s][A
Downloading ball_tracking.pt:   6%|▌         | 6.47M/114M [00:00<00:05, 19.1MiB/s][A
Downloading ball_tracking.pt:  10%|▉         | 11.2M/114M [00:00<00:03, 27.4MiB/s][A
Downloading ball_tracking.pt:  14%|█▍        | 16.0M/114M [00:00<00:03, 32.3MiB/s][A
Downloading ball_tracking.pt:  18%|█▊        | 20.7M/114M [00:00<00:02, 35.5MiB/s][A
Downloading ball_tracking.pt:  23%|██▎       | 25.8M/114M [00:00<00:02, 40.2MiB/s][A
Downloading ball_tracking.pt:  27%|██▋       | 31.0M/114M [00:01<00:01, 43.5MiB/s][A
Downloading ball_tracking.pt:  32%|███▏      | 36.9M/114M [00:01<

Model downloaded to /content/models/YOLO/ball_tracking/model_weights/ball_tracking.pt


In [20]:
# Run the detector over the source video; save=True asks Ultralytics to write the
# annotated output to disk (the next cell reads it from /content/runs/detect/predict).
# NOTE(review): the captured log recommends stream=True for large sources or long videos;
# that would make `results` a generator, so it is left as-is for this 40-frame clip.
results = model.predict(source=SOURCE_VIDEO_PATH, save=True)



errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.

Example:
    results = model(source=..., stream=True)  # generator of Results objects
    for r in results:
        boxes = r.boxes  # Boxes object for bbox outputs
        masks = r.masks  # Masks object for segment masks outputs
        probs = r.probs  # Class probabilities for classification outputs

video 1/1 (frame 1/40) /content/baseball.mp4: 416x640 2 homeplates, 1 baseball, 65.5ms
video 1/1 (frame 2/40) /content/baseball.mp4: 416x640 2 homeplates, 1 baseball, 58.6ms
video 1/1 (frame 3/40) /content/baseball.mp4: 416x640 2 homeplates, 1 baseball, 40.8ms
video 1/1 (frame 4/40) /content/baseball.mp4: 416x640 2 homeplates, 2 baseballs, 40.8ms
video 1/1 (frame 5/40) /content/baseball.mp4: 416x640 2 homeplates, 2 baseballs, 39.6ms
video 1/1 (frame 6/40) /content/baseball.mp4: 416x640 2 homeplates, 1 baseball, 39.5ms
video 1/1 (frame 7/40) /content/baseball.mp4: 41

In [21]:
import moviepy.editor
# NOTE(review): %cd changes the kernel's working directory for every later cell, and the
# hard-coded "predict" folder presumably only matches the first prediction run - confirm.
%cd /content/runs/detect/predict
# Open the annotated video written by the prediction cell (path is relative to the directory above)
video = moviepy.editor.VideoFileClip(filename="baseball.avi")
# Downscale to a fixed (width, height) before embedding it in the notebook
resized_video = video.resize((640, 360))  # Example: resizing to 640x360
# Embed the clip inline; per the captured log moviepy writes a __temp__.mp4 file to do so
moviepy.editor.ipython_display(resized_video)




[A[A[A

[A[A



Downloading ball_trackingv4.pt:   0%|          | 0.00/114M [13:45<?, ?iB/s]


Downloading ball_trackingv4.pt:   0%|          | 0.00/114M [01:56<?, ?iB/s][A[A[A

Downloading ball_tracking.pt:   0%|          | 0.00/114M [04:32<?, ?iB/s][A[A






[A[A[A

[A[A



Downloading ball_trackingv4.pt:   0%|          | 0.00/114M [13:45<?, ?iB/s]


Downloading ball_trackingv4.pt:   0%|          | 0.00/114M [01:56<?, ?iB/s][A[A[A

Downloading ball_tracking.pt:   0%|          | 0.00/114M [04:32<?, ?iB/s][A[A



Downloading ball_trackingv4.pt:   0%|          | 0.00/114M [01:05<?, ?iB/s][A[A[A[A

/content/runs/detect/predict
Moviepy - Building video __temp__.mp4.
Moviepy - Writing video __temp__.mp4




t:   0%|          | 0/40 [00:00<?, ?it/s, now=None][A
t:  20%|██        | 8/40 [00:00<00:00, 77.36it/s, now=None][A
t:  40%|████      | 16/40 [00:00<00:00, 75.81it/s, now=None][A
t:  60%|██████    | 24/40 [00:00<00:00, 65.23it/s, now=None][A
t:  78%|███████▊  | 31/40 [00:00<00:00, 65.24it/s, now=None][A
t:  98%|█████████▊| 39/40 [00:00<00:00, 68.18it/s, now=None][A



[A[A[A

[A[A



Downloading ball_trackingv4.pt:   0%|          | 0.00/114M [13:46<?, ?iB/s]


Downloading ball_trackingv4.pt:   0%|          | 0.00/114M [01:56<?, ?iB/s][A[A[A

Downloading ball_tracking.pt:   0%|          | 0.00/114M [04:33<?, ?iB/s][A[A






[A[A[A

[A[A



Downloading ball_trackingv4.pt:   0%|          | 0.00/114M [13:46<?, ?iB/s]


Downloading ball_trackingv4.pt:   0%|          | 0.00/114M [01:56<?, ?iB/s][A[A[A

Downloading ball_tracking.pt:   0%|          | 0.00/114M [04:33<?, ?iB/s][A[A



Downloading ball_trackingv4.pt:   0%|          | 0.00/114M [01:06<?, ?iB/s][A[A[

Moviepy - Done !
Moviepy - video ready __temp__.mp4
