In [None]:
import os
import subprocess
from pathlib import Path

import mlflow
from ultralytics import YOLO
from app.yolo.utils import download_s3_folder

# S3 bucket that holds the YOLO training data (images/labels under train/yolo/).
TRAINING_BUCKET: str = "big9-project-02-training-bucket"
# S3 bucket that receives the trained YOLO weights.
MODEL_BUCKET: str = "big9-project-02-model-bucket"

def train_yolo_model(epochs=20, imgsz=640):
    """Download YOLO data from S3, train a YOLOv8n model, and publish the best weights.

    The training run is tracked in MLflow. The best checkpoint is logged as an
    MLflow artifact under ``models/yolo`` and copied to
    ``s3://{MODEL_BUCKET}/yolo/best.pt`` with the AWS CLI.

    Args:
        epochs: Number of training epochs (default 20).
        imgsz: Training image size passed to Ultralytics (default 640).

    Raises:
        FileNotFoundError: If training finished but no best checkpoint was written.
        subprocess.CalledProcessError: If the ``aws s3 cp`` upload exits non-zero.
    """
    # Download training images and labels from S3 into the local data directory.
    os.makedirs("./data/yolo/images", exist_ok=True)
    os.makedirs("./data/yolo/labels", exist_ok=True)
    download_s3_folder(TRAINING_BUCKET, 'train/yolo/images/', './data/yolo/images')
    download_s3_folder(TRAINING_BUCKET, 'train/yolo/labels/', './data/yolo/labels')

    # MLflow setup. Honour MLFLOW_TRACKING_URI (docker-compose sets it) instead of a
    # hard-coded placeholder host; fall back to a locally published server.
    mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"))
    mlflow.set_experiment("YOLO Training")

    with mlflow.start_run(run_name="YOLO Training Run"):
        model = YOLO("yolov8n.pt")
        # Ultralytics does not accept a user-supplied `save_dir`; the output directory
        # is <project>/<name>. exist_ok=True reuses ./results/yolo/train on re-runs
        # instead of creating train2, train3, ...
        model.train(
            data="./app/yolo/config.yaml",
            epochs=epochs,
            imgsz=imgsz,
            project="./results/yolo",
            name="train",
            exist_ok=True,
        )

        # Ask the trainer where it actually wrote best.pt rather than guessing the path.
        best_weights = Path(model.trainer.best)
        if not best_weights.is_file():
            raise FileNotFoundError(f"Expected trained weights at {best_weights}")

        # Record results and upload the model.
        mlflow.log_artifact(str(best_weights), artifact_path="models/yolo")
        s3_uri = f"s3://{MODEL_BUCKET}/yolo/best.pt"
        # Argument list (no shell) plus check=True, so a failed upload raises instead
        # of being silently reported as success.
        subprocess.run(["aws", "s3", "cp", str(best_weights), s3_uri], check=True)
        print(f"Uploaded YOLO model to {s3_uri}")


In [None]:
version: "3.8"
services:
  yolo-training:
    container_name: yolo_training
    build: .
    command:
      - python
      - app/yolo/train_yolo.py
    volumes:
      # Bind-mount the host ./data directory so the container can use local training data.
      - ./data:/app/data
    ports:
      # NOTE(review): the command above runs a training script, not an MLflow server.
      # Confirm something in this container listens on 5000; otherwise this mapping can be dropped.
      - "5000:5000"


In [None]:
version: "3.8"
services:
  yolo-training:
    build: .
    container_name: yolo_training
    environment:
      # MLflow server URI. Uses the compose service name, which is resolvable on the
      # default compose network and is a valid hostname (no underscore).
      MLFLOW_TRACKING_URI: "http://mlflow-server:5000"
    volumes:
      - ./data:/app/data  # YOLO training data
    depends_on:
      # Starts the MLflow server container first. This does not wait for the server to be ready.
      - mlflow-server
    command: ["python", "app/train.py", "yolo"]  # run YOLO training

  mlflow-server:
    image: python:3.8-slim
    container_name: mlflow_server
    working_dir: /mlflow
    volumes:
      - ./mlruns:/mlflow/mlruns  # MLflow artifact storage
      # MLflow backend database file. NOTE: if ./mlflow.db does not exist on the host,
      # Docker creates a directory with that name; create the file first (e.g. `touch mlflow.db`).
      - ./mlflow.db:/mlflow/mlflow.db
    ports:
      - "5000:5000"  # MLflow server access
    # The python base image does not include mlflow, so install it before starting the
    # server (pin a version here for reproducible builds).
    # --serve-artifacts/--artifacts-destination make clients upload artifacts through
    # the server into /mlflow/mlruns. A local --default-artifact-root would instead make
    # other containers write to that path on their own filesystem.
    command:
      - sh
      - -c
      - >-
        pip install --no-cache-dir mlflow &&
        mlflow server
        --backend-store-uri sqlite:///mlflow.db
        --serve-artifacts
        --artifacts-destination ./mlruns
        --host 0.0.0.0
