
# Burger Detector — YOLOv8 Custom Object Detection
This notebook walks through training a YOLOv8 model to detect burgers using a custom dataset.

## Setup
Run the first code cell to install dependencies. Then set `DATASET_PATH` to your dataset directory, which must contain a `data.yaml` file.


In [None]:

# Install dependencies
!pip install ultralytics pillow matplotlib

# Imports
import glob
import os
from typing import Optional

import matplotlib.pyplot as plt
from PIL import Image
from ultralytics import YOLO


In [None]:

# Root directory of the dataset. train_yolo_model expects a `data.yaml`
# directly inside it. Edit this to point at your local copy.
DATASET_PATH: str = "data/"


In [None]:

def train_yolo_model(
    dataset_path: str,
    model_name: str = "yolov8n.pt",
    output_dir: Optional[str] = None,
    epochs: int = 50,
    imgsz: int = 640,
    batch: int = 16,
    **train_kwargs,
):
    """Train a YOLO model on the dataset described by ``<dataset_path>/data.yaml``.

    Args:
        dataset_path: Directory that contains the dataset's ``data.yaml``.
        model_name: Checkpoint or model spec passed to ``YOLO(...)`` as the
            starting point.
        output_dir: Directory passed to Ultralytics as ``project``. Defaults
            to ``<dataset_path>/model`` and is created if missing.
        epochs: Number of training epochs.
        imgsz: Training image size.
        batch: Batch size.
        **train_kwargs: Extra keyword arguments forwarded unchanged to
            ``model.train`` (e.g. ``seed``, ``device``, ``patience``).

    Returns:
        The ``YOLO`` model object, after ``model.train(...)`` has completed.

    Raises:
        FileNotFoundError: If ``<dataset_path>/data.yaml`` is not a regular file.
    """
    yaml_path = os.path.join(dataset_path, "data.yaml")
    # isfile (not exists) rejects a directory named data.yaml up front,
    # instead of letting it fail later inside the trainer.
    if not os.path.isfile(yaml_path):
        raise FileNotFoundError(f"data.yaml not found in {dataset_path}")

    output_dir = output_dir or os.path.join(dataset_path, "model")
    os.makedirs(output_dir, exist_ok=True)

    model = YOLO(model_name)
    # The value returned by train() was never used, so it is not bound.
    # The trained model object itself is what callers receive.
    model.train(
        data=yaml_path,
        epochs=epochs,
        imgsz=imgsz,
        batch=batch,
        project=output_dir,
        # The run sub-directory inside output_dir. exist_ok=True lets
        # re-runs reuse this same run directory.
        name="weights",
        exist_ok=True,
        **train_kwargs,
    )
    return model


In [None]:

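# Example training call. It is commented out so that "Run All" does not start a long training job.
# Uncomment it to train. By default, outputs are written under DATASET_PATH/model.
# burger_model = train_yolo_model(DATASET_PATH, model_name="yolov8s.pt", epochs=50, imgsz=640)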


In [None]:

def load_model(weights_path: str):
    """Build a ``YOLO`` model from a weights path that exists on disk.

    Raises:
        FileNotFoundError: If nothing exists at ``weights_path``.
    """
    weights_present = os.path.exists(weights_path)
    if weights_present:
        return YOLO(weights_path)
    raise FileNotFoundError(f"Weights not found at: {weights_path}")

def run_inference(model, image_path: str, save: bool = False, conf: float = 0.25, **predict_kwargs):
    """Call ``model`` on ``image_path`` and return its output unchanged.

    Args:
        model: A callable model, e.g. the ``YOLO`` object returned by ``load_model``.
        image_path: Image path (or other source the model accepts) to run on.
        save: Forwarded to the model as ``save=``.
        conf: Confidence threshold, forwarded as ``conf=``.
        **predict_kwargs: Extra keyword arguments forwarded unchanged to the
            model call (e.g. ``iou``, ``imgsz``, ``device``).

    Returns:
        Whatever the model call returns. For a YOLO model this is a list of
        ``Results`` objects.
    """
    return model(image_path, save=save, conf=conf, **predict_kwargs)


In [None]:

def show_side_by_side(original_path: str, results, figsize=(12, 6)):
    """Display the input image next to the model's annotated prediction.

    Args:
        original_path: Path of the image that was passed to the model.
        results: Inference output (e.g. from ``run_inference``). Only
            ``results[0]``, the first image's result, is drawn.
        figsize: Matplotlib figure size in inches.
    """
    # Load the pixels fully, then release the file handle instead of leaking it.
    with Image.open(original_path) as img:
        original = img.copy()

    # Render the annotated prediction in memory. The previous version wrote
    # "temp_pred.jpg" into the working directory, never removed it, and
    # overwrote any existing file with that name. Results.plot() returns a
    # BGR array, so the channel axis is reversed to get RGB for imshow.
    predicted = results[0].plot()[..., ::-1]

    fig, (ax_orig, ax_pred) = plt.subplots(1, 2, figsize=figsize)
    for ax, title, image in (
        (ax_orig, "Original", original),
        (ax_pred, "Prediction", predicted),
    ):
        ax.set_title(title)
        ax.imshow(image)
        ax.axis("off")

    fig.tight_layout()
    plt.show()
