# YOLO Training

In [None]:
# IPython setup: enable the autoreload extension so edits to imported project
# modules (e.g. the `src` package) are picked up without restarting the kernel.
# Mode 2 reloads modules before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import os

from src.datasets import SkyFusionDataset, TrafficSignsDataset, FruitsDataset
from src.filesystem import FileSystem
from src.models import YOLO
from src.transforms import IMAGES_RESOLUTION

# Registry mapping the short dataset names accepted by `dataset_name` to the
# dataset objects imported from `src.datasets`.
dataset_choices = dict(
    skyfusion=SkyFusionDataset,
    traffic_signs=TrafficSignsDataset,
    fruits=FruitsDataset,
)

In [None]:
# Name of the dataset to train on; must be one of the keys of `dataset_choices`.
dataset_name = "skyfusion"

# Fail fast with the list of valid names instead of a bare KeyError.
if dataset_name not in dataset_choices:
    raise ValueError(
        f"Unknown dataset {dataset_name!r}; expected one of {sorted(dataset_choices)}"
    )
# The registered dataset object itself (it is not instantiated here).
dataset = dataset_choices[dataset_name]

In [None]:
# Export the selected dataset to YOLO format and prepare the output directory.
# NOTE(review): export_to_yolo() is assumed to return the path of the generated
# data YAML file (it is opened below) — confirm in src.datasets.
yaml_file = dataset.export_to_yolo()
save_dir = FileSystem.LOGS_DIR / dataset_name
os.makedirs(save_dir, exist_ok=True)

# Hyper-parameters consumed by the training cell.
epochs = 100
batch_size = 32

# Summarize the run configuration, with paths shown relative to the project root.
print("yaml_file:", os.path.relpath(yaml_file, FileSystem.PROJECT_ROOT))
print("yaml_content:")
with open(yaml_file, "r", encoding="utf-8") as f:
    yaml_lines = f.read().splitlines()
# Print each YAML line indented by two spaces. Printing line by line guarantees
# a trailing newline even when the file lacks one. Otherwise the next
# "save_dir:" line would be glued onto the last YAML line.
for yaml_line in yaml_lines:
    print(f"  {yaml_line}")
print("save_dir:", os.path.relpath(save_dir, FileSystem.PROJECT_ROOT))
print("epochs:", epochs)
print("batch_size:", batch_size)

In [None]:
import torch

# Load the pretrained YOLOv8-medium checkpoint.
# NOTE(review): models_path presumably controls where weights for this dataset
# are stored — confirm in src.models.YOLO.
yolo_model = YOLO("yolov8m.pt", models_path=FileSystem.MODELS_DIR / dataset_name)

# Use the GPU when one is present. Otherwise fall back to CPU rather than
# crashing on a hard-coded "cuda" device.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("WARNING: CUDA is not available; using CPU, training will be much slower.")
yolo_model = yolo_model.to(device)
yolo_model.info()

In [None]:
# Collect the training configuration in one place, then launch training.
# Outputs are written under `save_dir` via the `project` argument.
train_kwargs = dict(
    data=yaml_file,
    epochs=epochs,
    batch=batch_size,
    imgsz=IMAGES_RESOLUTION,
    project=save_dir,
    optimizer="Adam",
)
results = yolo_model.train(**train_kwargs)