# YOLO Pretraining on ImageNet-VID

## Imports

In [2]:
# Install third-party dependencies into the running kernel's environment (%pip, not !pip).
# NOTE(review): versions are unpinned — consider pinning them for reproducible runs.
%pip install huggingface_hub ultralytics torchinfo pygwalker comet_ml clearml

Collecting pygwalker
  Downloading pygwalker-0.4.9.13-py3-none-any.whl.metadata (20 kB)
Collecting anywidget (from pygwalker)
  Downloading anywidget-0.9.13-py3-none-any.whl.metadata (7.2 kB)
Collecting astor (from pygwalker)
  Downloading astor-0.8.1-py2.py3-none-any.whl.metadata (4.2 kB)
Collecting duckdb<2.0.0,>=0.10.1 (from pygwalker)
  Downloading duckdb-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (966 bytes)
Collecting gw-dsl-parser==0.1.49.1 (from pygwalker)
  Downloading gw_dsl_parser-0.1.49.1-py3-none-any.whl.metadata (1.2 kB)
Collecting ipylab<=1.0.0 (from pygwalker)
  Downloading ipylab-1.0.0-py3-none-any.whl.metadata (6.7 kB)
Collecting kanaries-track==0.0.5 (from pygwalker)
  Downloading kanaries_track-0.0.5-py3-none-any.whl.metadata (913 bytes)
Collecting quickjs (from pygwalker)
  Downloading quickjs-1.19.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (590 bytes)
Collecting segment-analytics-python==2.2.3 (from pygwalk

In [None]:
import os
import shutil
from os import path

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset

from torchvision import datasets, utils, transforms
from torchinfo import summary
from ultralytics import YOLO

import numpy as np
import pandas as pd

from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import pygwalker as pyg

import comet_ml
#from clearml import Task, browser_login

In [2]:
# Attach the Comet ML experiment logger to the "APT" project.
COMET_PROJECT = "APT"
comet_ml.login(project_name=COMET_PROJECT)

In [3]:
# ClearML server endpoints, exported as environment variables for the ClearML client.
# The ClearML login / task creation below is currently commented out (disabled).
%env CLEARML_WEB_HOST=https://app.clear.ml/
%env CLEARML_API_HOST=https://api.clear.ml
%env CLEARML_FILES_HOST=https://files.clear.ml
# Do not commit real credentials here; provide the keys through the environment instead.
#%env CLEARML_API_ACCESS_KEY=<Your API access key>
#%env CLEARML_API_SECRET_KEY=<Your API secret key>

#browser_login()
#task = Task.init(project_name="APT", task_name="Model Pretaining")

env: CLEARML_WEB_HOST=https://app.clear.ml/
env: CLEARML_API_HOST=https://api.clear.ml
env: CLEARML_FILES_HOST=https://files.clear.ml


### Check GPU Availability

In [4]:
# Show visible GPUs and their current memory use, to help choose DEVICE_NUM below.
!nvidia-smi

Wed Feb 26 00:25:35 2025       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05    Driver Version: 495.29.05    CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  On   | 00000000:04:00.0 Off |                    0 |
| N/A   39C    P0    34W / 250W |   7320MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  On   | 00000000:06:00.0 Off |                    0 |
| N/A   39C    P0    30W / 250W |   4028MiB / 16280MiB |      0%      Default |
|       

In [5]:
# Primary CUDA device index (0-based) and the number of *extra* consecutive GPUs
# (DEVICE_NUM+1 ... DEVICE_NUM+ADDITIONAL_GPU) to use for multi-GPU training.
DEVICE_NUM = 3
ADDITIONAL_GPU = 1

if torch.cuda.is_available():
    gpu_count = torch.cuda.device_count()
    last_gpu = DEVICE_NUM + ADDITIONAL_GPU
    # Fail early with a clear message instead of an obscure CUDA error later on.
    if not 0 <= DEVICE_NUM <= last_gpu < gpu_count:
        raise ValueError(
            f"Requested CUDA devices {DEVICE_NUM}..{last_gpu}, "
            f"but only {gpu_count} device(s) are visible."
        )
    if ADDITIONAL_GPU:
        # Make DEVICE_NUM the current device so a bare "cuda" refers to it.
        torch.cuda.set_device(DEVICE_NUM)
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{DEVICE_NUM}")
else:
    device = torch.device("cpu")
    # DEVICE_NUM == -1 marks CPU mode; extra GPUs are meaningless without CUDA.
    DEVICE_NUM = -1
    ADDITIONAL_GPU = 0

device_label = f"cuda:{DEVICE_NUM}" if device.type == "cuda" else "cpu"
print(f"INFO: Using device - {device_label}")

INFO: Using device - cuda:3


In [6]:
# Toggle multi-process data loading; set to False if DataLoader workers misbehave.
MULTI_PROCESSING = True

from platform import system

# Workers stay disabled on Windows: it uses the "spawn" start method, under which
# worker processes typically cannot import objects defined inside a notebook.
use_workers = MULTI_PROCESSING and system() != "Windows"
if not use_workers:
    cpu_cores = 0
    print("INFO: Using DataLoader without multi-processing.")
else:
    import multiprocessing
    cpu_cores = multiprocessing.cpu_count()
    print(f"INFO: Number of CPU cores - {cpu_cores}")

INFO: Number of CPU cores - 48


## Define Dataset
ImageNet-VID (ILSVRC2015 VID) video object-detection dataset, downloaded from the Hugging Face Hub.

In [10]:
from typing import Callable, Optional
from dataclasses import dataclass
from pathlib import Path
import huggingface_hub
import json


@dataclass
class DataPair:
    """A pair of values for the training and validation splits (e.g. file names or cached data)."""

    # Value associated with the training split (any type; `object` rather than the builtin `any()`).
    train: object
    # Value associated with the validation split.
    val: object


class ImageNetVIDDataset(datasets.ImageFolder):
    """
    ImageNet-VID dataset for Object Detection and Tracking.

    On construction the dataset is downloaded from the Hugging Face Hub as split
    ``.tar.gz`` parts. The parts are re-assembled into one archive and extracted
    into ``root``. One split directory is then loaded as a torchvision ``ImageFolder``.
    Extraction shells out to ``tar``, so a ``tar`` executable must be on PATH
    (intended for Linux).

    :ref: https://huggingface.co/datasets/guanxiongsun/imagenetvid
    """

    download_method = huggingface_hub.snapshot_download
    dataset_name = "ILSVRC2015_VID"
    dataset_id = "guanxiongsun/imagenetvid"
    obj_classes = [
        "airplane", "antelope", "bear", "bicycle", "bird", "bus", "car", "cattle",
        "dog", "domestic cat", "elephant", "fox", "giant panda", "hamster", "horse",
        "lion", "lizard", "monkey", "motorcycle", "rabbit", "red panda",
        "sheep", "snake", "squirrel", "tiger", "train", "turtle",
        "watercraft", "whale", "zebra"
    ]
    # Lazily loaded annotation JSON, shared by every instance (class-level cache).
    __cached_annotation = DataPair(None, None)
    annotation_files = DataPair("imagenet_vid_train.json", "imagenet_vid_val.json")
    # Marker file written into ``root`` after a successful extraction, so later
    # instances skip the (very large) download/combine/extract pipeline.
    extracted_marker = ".extracted"

    def __init__(
        self,
        root: str,
        force_download: bool = True,
        train: bool = True,
        valid: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None
        ):
        """
        :param root: Directory the dataset is downloaded to and extracted in.
        :param force_download: If True, DELETE everything inside ``root`` and start over.
            Note the default is True — pass False to reuse an existing copy.
        :param train: Load the train/val split (True) or the test split (False).
        :param valid: With ``train=True``, load ``val`` instead of ``train``.
        :param transform: Transform applied to each loaded image.
        :param target_transform: Transform applied to each target.
        """
        self.download(root, force=force_download)

        # Pick the split sub-directory: train/val when ``train`` is True, otherwise test.
        if train:
            split = "val" if valid else "train"
        else:
            split = "test"

        # NOTE(review): splits are resolved directly under ``root`` (e.g. ./data/val); the
        # extracted archive may nest them deeper (e.g. under ILSVRC2015_VID) — confirm the layout.
        super().__init__(root=path.join(root, split), transform=transform, target_transform=target_transform)

    def apply_transform(
        self, transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None
        ):
        """
        Replace the image and target transforms after construction.

        :param transform: Transform applied to each loaded image.
        :param target_transform: Transform applied to each target.
        """
        from torchvision.datasets.vision import StandardTransform

        self.transform = transform
        self.target_transform = target_transform
        # Keep the combined ``transforms`` attribute consistent with how torchvision's
        # VisionDataset initialises it (None when neither transform is given).
        if transform is None and target_transform is None:
            self.transforms = None
        else:
            self.transforms = StandardTransform(transform, target_transform)

    @classmethod
    def download(cls, root: str, force: bool = False):
        """
        Fetch, re-assemble and extract the dataset into ``root``.

        Nothing is done when a previous call already completed (the extraction
        marker exists in ``root``), unless ``force`` is set.

        :param root: Directory to download into and extract under (created if missing).
        :param force: If True, delete EVERYTHING inside ``root`` first and start over.
        :raises RuntimeError: If no archive parts are found or ``tar`` fails.
        """
        import subprocess

        root = Path(root)
        archive = root / f"{cls.dataset_name}.tar.gz"
        marker = root / cls.extracted_marker

        # Clean up the existing dataset if force is flagged
        if force and root.exists():
            print(f"INFO: Cleaning up the existing dataset at {root} (Force-download is flagged)")
            for item in root.iterdir():
                # Unlink files and symlinks; only real directories go through rmtree.
                if item.is_dir() and not item.is_symlink():
                    shutil.rmtree(item)
                else:
                    item.unlink()
            print("INFO: Dataset cleaned successfully.")

        if marker.exists():
            print("INFO: Extracted dataset found in the root directory. Skipping download.")
            return
        root.mkdir(parents=True, exist_ok=True)

        print(f"Downloading {cls.dataset_name} from huggingface...")
        cls.download_method(
            repo_id=cls.dataset_id,
            repo_type="dataset",
            local_dir=root,
            ignore_patterns=["*.git*", "*.md", "*ILSVRC2017*"]
        )
        print("INFO: Dataset downloaded successfully.")

        # Combine split archives
        if archive.exists():
            print("INFO: Combined archives found in the root directory. Skipping combination.")
        else:
            cls._combine_parts(archive)

        # Extract the dataset. No -v: listing every frame would flood the notebook output.
        print("INFO: Extracting the dataset...")
        completed = subprocess.run(["tar", "-xzf", str(archive), "-C", str(root)])
        if completed.returncode != 0:
            raise RuntimeError(f"Failed to extract {archive} (tar exited with code {completed.returncode}).")
        marker.touch()
        print("INFO: Dataset extracted successfully.")

    @staticmethod
    def _combine_parts(archive: Path) -> None:
        """Concatenate the parts ``<archive>.aa``, ``<archive>.ab``, ... (in name order) into ``archive``."""
        parts = sorted(archive.parent.glob(f"{archive.name}.a*"))
        if not parts:
            raise RuntimeError(
                f"Failed to combine split archives: no '{archive.name}.a*' parts found in {archive.parent}."
            )
        print("INFO: Combining seperated archives...")
        # Write to a temporary file and rename at the end, so an interrupted run never
        # leaves a truncated archive that a later call would mistake for a complete one.
        partial = archive.with_name(archive.name + ".partial")
        try:
            with open(partial, "wb") as dst:
                for part in parts:
                    with open(part, "rb") as src:
                        shutil.copyfileobj(src, dst, 16 * 1024 * 1024)
            partial.replace(archive)
        except BaseException:
            partial.unlink(missing_ok=True)
            raise
        print("INFO: Split archives combined successfully.")

    @classmethod
    def bbox(cls, img_index: int, train: bool = True):
        """
        Return the ``bbox`` entry of annotation ``img_index`` as a tuple.

        The annotation JSON is loaded on first use and cached on the class.
        NOTE(review): the JSON files are opened relative to the current working directory,
        not ``root`` — confirm where ``annotations.tar.gz`` is unpacked.

        :param img_index: Position in the JSON's ``annotation`` list.
        :param train: Read the training (True) or validation (False) annotation file.
        """
        if train and cls.__cached_annotation.train is None:
            with open(cls.annotation_files.train, "r", encoding="utf-8") as f:
                cls.__cached_annotation.train = json.load(f)
        if not train and cls.__cached_annotation.val is None:
            with open(cls.annotation_files.val, "r", encoding="utf-8") as f:
                cls.__cached_annotation.val = json.load(f)

        cache = cls.__cached_annotation.train if train else cls.__cached_annotation.val
        return tuple(cache['annotation'][img_index]['bbox'])

    @property
    def df(self) -> pd.DataFrame:
        """One row per sample: image ``path`` and its class-name ``label``."""
        return pd.DataFrame(dict(path=[d[0] for d in self.samples], label=[self.classes[lb] for lb in self.targets]))

    @property
    def sample_output(self):
        """Placeholder; currently always returns None."""
        return None  # TODO: Implement sample output img view for the dataset

In [11]:
# Square input resolution used by the torchvision transforms below.
IMG_SIZE = 640

# Per-channel normalization statistics (ImageNet style).
IMG_NORM = dict(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

# Training image pipeline: resize, light photometric augmentation, tensor conversion, normalization.
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Adjust brightness/contrast
    transforms.ToTensor(),
    transforms.Normalize(**IMG_NORM)
])


def label_transform(train: bool = True):
    """
    Build a target transform that maps an index to its box via ``ImageNetVIDDataset.bbox``.

    This is a factory: call it (``label_transform(train=True)``) and pass the *result*
    as ``target_transform``. Passing ``label_transform`` itself would turn every target
    into a ``transforms.Lambda`` object.

    :param train: Read the training (True) or validation (False) annotation file.
    :return: A ``transforms.Lambda`` wrapping the lookup.
    """
    return transforms.Lambda(lambda x: ImageNetVIDDataset.bbox(x, train=train))


# Evaluation image pipeline: deterministic (no augmentation).
test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(**IMG_NORM)
])

In [None]:
DATA_ROOT = path.join(".", "data")

# force_download=False everywhere: reuse the existing copy instead of wiping DATA_ROOT.
train_dataset = ImageNetVIDDataset(
    root=DATA_ROOT, force_download=False, train=True,
    transform=train_transform, target_transform=label_transform(train=True)
)
# Validation uses the deterministic (non-augmented) transform and the validation annotations.
valid_dataset = ImageNetVIDDataset(
    root=DATA_ROOT, force_download=False, valid=True,
    transform=test_transform, target_transform=label_transform(train=False)
)
test_dataset = ImageNetVIDDataset(
    root=DATA_ROOT, force_download=False, train=False,
    transform=test_transform
)
# NOTE(review): ImageFolder hands the *class index* to target_transform, while
# ImageNetVIDDataset.bbox expects an annotation/image index — confirm the intended target.

print(f"INFO: Dataset loaded successfully. Number of samples - Train({len(train_dataset)}), Valid({len(valid_dataset)}), Test({len(test_dataset)})")

Downloading ILSVRC2015_VID from huggingface...


Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

ILSVRC2015_VID.tar.gz.ab:   0%|          | 0.00/43.7G [00:00<?, ?B/s]

ILSVRC2015_VID.tar.gz.aa:   0%|          | 0.00/48.3G [00:00<?, ?B/s]

annotations.tar.gz:   0%|          | 0.00/56.9M [00:00<?, ?B/s]

## Load Model

In [None]:
# Architecture config for a new model built from YAML (no checkpoint is loaded).
MODEL_CONFIG = "yolo11m.yaml"
# To continue an interrupted run instead, load its checkpoint: YOLO("path/to/last.pt").
model = YOLO(MODEL_CONFIG)
model

In [None]:
# Device(s) for Ultralytics: "cpu" in CPU mode (DEVICE_NUM == -1), a list of consecutive
# GPU indices for multi-GPU training, or a single GPU index.
if DEVICE_NUM < 0:
    train_device = "cpu"
elif ADDITIONAL_GPU:
    train_device = list(range(DEVICE_NUM, DEVICE_NUM + ADDITIONAL_GPU + 1))
else:
    train_device = DEVICE_NUM

# Keyword arguments for `model.train` (Ultralytics training settings).
train_configs = dict(
    data="imagenetvid.yaml",  # Ultralytics dataset config; independent of the torchvision datasets above
    epochs=100,
    patience=10,
    batch=512,  # NOTE(review): very large for 16 GB GPUs at 640 px (see nvidia-smi) — confirm it fits
    imgsz=IMG_SIZE,  # same resolution as the torchvision transforms
    save=True,
    save_period=10,
    cache=True,  # NOTE(review): presumably an in-memory image cache; ImageNet-VID may exceed RAM
    device=train_device,
    workers=cpu_cores,
    project="pretrained4imagenetvid",
    resume=True,  # NOTE(review): presumably only effective when `model` was loaded from a checkpoint — verify
    val=True,
    plots=True
)

## Train

In [None]:
# Start/Resume model training with the settings in `train_configs`.
# NOTE(review): `results` is reassigned by the prediction cell further down.
results = model.train(**train_configs)

## Validation

In [None]:
# Evaluate model performance on the validation set.
# No arguments are passed, so Ultralytics' defaults apply — presumably the dataset/settings
# of the training run above; verify.
metrics = model.val()

In [None]:
# Perform object detection on one validation frame and display the annotated result.
# Substitute any image path for SAMPLE_IMAGE to try other inputs.
SAMPLE_IMAGE = valid_dataset.samples[0][0]  # file path of the first validation sample
results = model(SAMPLE_IMAGE)
results[0].show()