In [None]:
import os
import cv2 as cv
from ultralytics import YOLO
import matplotlib.pyplot as plt
import shutil
import glob
import random
import yaml
import numpy as np

def detect_object_in_image(image, class_id):
    """Estimate one YOLO-format bounding box for the main object in an image.

    The image is converted to grayscale, blurred, binarized with Otsu's
    threshold and cleaned with a morphological close/open. The bounding
    rectangle of the largest external contour is taken as the object.

    Args:
        image: BGR image array, e.g. as returned by ``cv.imread``.
        class_id: Class index of the image. Not used by the detection
            itself; kept so existing callers keep working.

    Returns:
        Tuple ``(center_x, center_y, width, height)`` with every value
        divided by the image width/height. The central 80 % of the image,
        ``(0.5, 0.5, 0.8, 0.8)``, is returned instead when no contour is
        found, when the largest contour covers less than 1 % of the image,
        or when its centre is more than 30 % of the image size away from
        the image centre.
    """
    fallback_box = (0.5, 0.5, 0.8, 0.8)

    height, width = image.shape[:2]

    gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY)

    # Gaussian blur suppresses pixel noise before thresholding.
    blurred = cv.GaussianBlur(gray, (5, 5), 0)

    # Otsu picks the threshold automatically from the histogram.
    _, binary = cv.threshold(blurred, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU)

    # findContours treats non-zero pixels as foreground, and THRESH_BINARY
    # marks the *brighter* side as 255. On a light background that makes the
    # background the "object" and the largest contour becomes the whole frame.
    # Use the image border as a background sample and invert the mask when
    # the border is mostly white.
    border = np.concatenate(
        (binary[0, :], binary[-1, :], binary[:, 0], binary[:, -1])
    )
    if np.count_nonzero(border) * 2 > border.size:
        binary = cv.bitwise_not(binary)

    # Close small holes, then remove small specks.
    kernel = np.ones((3, 3), np.uint8)
    cleaned = cv.morphologyEx(binary, cv.MORPH_CLOSE, kernel)
    cleaned = cv.morphologyEx(cleaned, cv.MORPH_OPEN, kernel)

    contours, _ = cv.findContours(cleaned, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
    if not contours:
        return fallback_box

    # The largest contour is assumed to be the object.
    x, y, w, h = cv.boundingRect(max(contours, key=cv.contourArea))

    # Reject boxes that are tiny or far from the image centre; such contours
    # are treated as unreliable and replaced by the central 80 % box.
    too_small = w * h < width * height * 0.01
    off_center_x = abs((x + w // 2) - width // 2) > width * 0.3
    off_center_y = abs((y + h // 2) - height // 2) > height * 0.3
    if too_small or off_center_x or off_center_y:
        return fallback_box

    # YOLO format: box centre and size, normalized by the image size.
    return (x + w / 2) / width, (y + h / 2) / height, w / width, h / height

def setup_yolo_dataset(source_path="/home/trashnet", yolo_path="trash_yolo_v2",
                       yaml_path="trash_data_v2.yaml", train_ratio=0.8):
    """Build a YOLO detection dataset from a folder-per-class image dataset.

    Every readable image is resized to 640x640, given one automatically
    estimated bounding box, and written to ``<yolo_path>/images/<split>``
    with a matching label file in ``<yolo_path>/labels/<split>``. A dataset
    YAML describing the result is written to ``yaml_path``.

    Args:
        source_path: Root folder holding one sub-folder per class
            (``glass``, ``metal``, ``paper``, ``plastic``, ``trash``).
        yolo_path: Output dataset folder. It is deleted and recreated, but
            only after enough source images have been found.
        yaml_path: Where the dataset YAML is written.
        train_ratio: Fraction of the shuffled images assigned to ``train``;
            the rest go to ``val``.

    Returns:
        ``True`` when the dataset and YAML were written, ``False`` when the
        source folder is missing, fewer than 10 images were found, or no
        image could be processed.
    """
    if not os.path.isdir(source_path):
        print(f"데이터셋 폴더를 찾을 수 없습니다: {source_path}")
        return False

    classes = {'glass': 0, 'metal': 1, 'paper': 2, 'plastic': 3, 'trash': 4}
    extensions = ('.jpg', '.jpeg', '.png')

    # Collect (path, class_id) pairs. Extensions are matched
    # case-insensitively; hidden files are ignored, as glob would do.
    all_files = []
    for class_name, class_id in classes.items():
        class_dir = os.path.join(source_path, class_name)
        if not os.path.isdir(class_dir):
            continue
        images = sorted(
            os.path.join(class_dir, name)
            for name in os.listdir(class_dir)
            if not name.startswith('.') and name.lower().endswith(extensions)
        )
        print(f"{class_name}: {len(images)}개")
        all_files.extend((img_path, class_id) for img_path in images)

    if len(all_files) < 10:
        print(f"이미지가 부족합니다: {len(all_files)}개 (최소 10개 필요)")
        return False

    # Only wipe a previous dataset once we know there is something to
    # rebuild it from.
    if os.path.exists(yolo_path):
        shutil.rmtree(yolo_path)
    for split in ('train', 'val'):
        os.makedirs(os.path.join(yolo_path, 'images', split), exist_ok=True)
        os.makedirs(os.path.join(yolo_path, 'labels', split), exist_ok=True)

    random.shuffle(all_files)
    split_idx = int(len(all_files) * train_ratio)

    processed = 0
    skipped = 0
    for i, (img_path, class_id) in enumerate(all_files):
        try:
            img = cv.imread(img_path)
            if img is None:
                # Unreadable or corrupt file.
                skipped += 1
                continue

            img = cv.resize(img, (640, 640))

            # Estimate the object box on the resized image.
            cx, cy, w, h = detect_object_in_image(img, class_id)

            split = 'train' if i < split_idx else 'val'
            stem = f"{processed:05d}"

            if not cv.imwrite(f"{yolo_path}/images/{split}/{stem}.jpg", img):
                raise OSError("cv.imwrite failed")

            with open(f"{yolo_path}/labels/{split}/{stem}.txt", 'w',
                      encoding='utf-8') as f:
                f.write(f"{class_id} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}\n")

            processed += 1

        except Exception as e:
            # Best effort: one bad image must not abort the whole build,
            # but the failure is reported instead of being hidden.
            print(f"이미지 처리 실패 ({img_path}): {e}")
            skipped += 1
            continue

    if skipped:
        print(f"건너뜀: {skipped}개")

    if processed == 0:
        print("처리된 이미지가 없습니다.")
        return False

    # Dataset YAML; nc and names are derived from `classes` so they cannot
    # drift apart from the class ids written to the label files.
    yaml_content = f"""path: {os.path.abspath(yolo_path)}
train: images/train
val: images/val
nc: {len(classes)}
names: {list(classes)}
"""

    with open(yaml_path, "w", encoding="utf-8") as f:
        f.write(yaml_content)

    print(f"처리 완료: {processed}개")
    return True

def train_model():
    """Prepare the dataset and train a YOLO model loaded from 'yolov8n.pt'.

    Returns:
        The ``YOLO`` object after training, or ``None`` when dataset
        preparation reports failure.
    """
    dataset_ready = setup_yolo_dataset()
    if not dataset_ready:
        return None

    # Training configuration, passed through to model.train() unchanged.
    train_args = {
        'data': 'trash_data_v2.yaml',
        'epochs': 50,
        'imgsz': 640,
        'batch': 16,
        'lr0': 0.01,
        'momentum': 0.937,
        'weight_decay': 0.0005,
        'warmup_epochs': 3,
        'device': 'cpu',
        'workers': 2,
        'verbose': True,
    }

    model = YOLO('yolov8n.pt')
    model.train(**train_args)
    return model

def validate_model(source_path="/home/trashnet"):
    """Run the most recently trained model on sample images and plot the results.

    Loads the newest ``runs/detect/train*/weights/best.pt``, predicts on up
    to two images per class and shows them in a 5x2 matplotlib grid with the
    predicted boxes drawn on top. Returns without plotting when no trained
    weights are found.

    Args:
        source_path: Dataset root containing one sub-folder per class. When
            it does not exist, the relative ``trashnet`` folder is tried
            instead.
    """
    # Pick the weights by modification time. Sorting run names as strings
    # would rank "train9" after "train10" and load a stale model.
    weight_files = glob.glob(
        os.path.join("runs/detect", "train*", "weights", "best.pt")
    )
    if not weight_files:
        print("학습된 모델을 찾을 수 없습니다.")
        return
    model_path = max(weight_files, key=os.path.getmtime)

    # Fall back to the legacy relative dataset location.
    if not os.path.isdir(source_path) and os.path.isdir("trashnet"):
        source_path = "trashnet"

    model = YOLO(model_path)

    classes = ['glass', 'metal', 'paper', 'plastic', 'trash']
    # Colours are RGB because boxes are drawn on the RGB copy of the image.
    colors = [(0,255,0), (255,0,0), (0,0,255), (255,255,0), (255,0,255)]
    per_class = 2

    _fig, axes = plt.subplots(len(classes), per_class, figsize=(12, 20))
    cells = list(axes.flat)
    # Hide every cell up front so unused cells do not show empty axes.
    for ax in cells:
        ax.axis('off')

    shown = 0
    for class_name in classes:
        class_dir = os.path.join(source_path, class_name)
        images = sorted(
            glob.glob(f"{class_dir}/*.jpg")
            + glob.glob(f"{class_dir}/*.png")
            + glob.glob(f"{class_dir}/*.jpeg")
        )

        shown_for_class = 0
        for img_path in images:
            if shown_for_class == per_class:
                break

            img = cv.imread(img_path)
            if img is None:
                # Unreadable file: try the next one instead of crashing.
                continue
            img_rgb = cv.cvtColor(img, cv.COLOR_BGR2RGB)

            results = model.predict(img_path, conf=0.5, verbose=False)

            if results[0].boxes is not None:
                boxes = results[0].boxes.xyxy.cpu().numpy()
                cls_ids = results[0].boxes.cls.cpu().numpy()
                confs = results[0].boxes.conf.cpu().numpy()

                for box, cls_id, conf in zip(boxes, cls_ids, confs):
                    x1, y1, x2, y2 = map(int, box)
                    color = colors[int(cls_id)]

                    cv.rectangle(img_rgb, (x1, y1), (x2, y2), color, 3)

                    # cv.putText only renders ASCII (Hershey fonts), so the
                    # label must not contain Korean text. Keep the baseline
                    # inside the image for boxes touching the top edge.
                    pred_class = classes[int(cls_id)]
                    label = f"pred: {pred_class} ({conf:.2f})"
                    cv.putText(img_rgb, label, (x1, max(y1 - 10, 20)),
                               cv.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

            ax = cells[shown]
            ax.imshow(img_rgb)
            ax.set_title(f"실제: {class_name}")

            shown += 1
            shown_for_class += 1

    plt.tight_layout()
    plt.show()

def upload_test():
    """Classify user-uploaded images with the most recently trained model.

    Only works where ``google.colab`` and ``PIL`` can be imported; otherwise
    a message is printed and the function returns. Each uploaded image is
    run through the model (confidence threshold 0.3) and shown with the
    predicted boxes and labels drawn on it.
    """
    # Guard only the optional imports, so that an ImportError raised later
    # (e.g. inside the model code) is not silently swallowed.
    try:
        from google.colab import files
        from PIL import Image
    except ImportError as e:
        print(f"업로드 테스트를 건너뜁니다: {e}")
        return
    import io

    # Pick the weights by modification time. Sorting run names as strings
    # would rank "train9" after "train10" and load a stale model.
    weight_files = glob.glob(
        os.path.join("runs/detect", "train*", "weights", "best.pt")
    )
    if not weight_files:
        print("학습된 모델을 찾을 수 없습니다.")
        return
    model_path = max(weight_files, key=os.path.getmtime)

    model = YOLO(model_path)
    uploaded = files.upload()

    classes = ['glass', 'metal', 'paper', 'plastic', 'trash']
    # Colours are RGB because boxes are drawn on the RGB image.
    colors = [(0,255,0), (255,0,0), (0,0,255), (255,255,0), (255,0,255)]

    for filename, payload in uploaded.items():
        try:
            # convert("RGB") normalizes grayscale, palette, CMYK and alpha
            # images to a plain 3-channel 8-bit image.
            img = Image.open(io.BytesIO(payload)).convert("RGB")
        except OSError as e:
            # Not an image / truncated file (PIL raises OSError subclasses).
            print(f"이미지를 열 수 없습니다 ({filename}): {e}")
            continue

        # np.array() copies, so drawing on img_rgb below is safe.
        img_rgb = np.array(img)
        img_bgr = cv.cvtColor(img_rgb, cv.COLOR_RGB2BGR)

        results = model.predict(img_bgr, conf=0.3, verbose=False)

        detection_info = []
        if results[0].boxes is not None:
            boxes = results[0].boxes.xyxy.cpu().numpy()
            cls_ids = results[0].boxes.cls.cpu().numpy()
            confs = results[0].boxes.conf.cpu().numpy()

            for box, cls_id, conf in zip(boxes, cls_ids, confs):
                x1, y1, x2, y2 = map(int, box)
                color = colors[int(cls_id)]
                pred_class = classes[int(cls_id)]

                cv.rectangle(img_rgb, (x1, y1), (x2, y2), color, 4)
                label = f"{pred_class}: {conf:.2f}"
                # Keep the text baseline inside the image for boxes that
                # touch the top edge.
                cv.putText(img_rgb, label, (x1, max(y1 - 10, 30)),
                           cv.FONT_HERSHEY_SIMPLEX, 1.0, color, 3)

                detection_info.append(f"{pred_class} ({conf:.2f})")

        plt.figure(figsize=(12, 8))
        plt.imshow(img_rgb)
        plt.title(f"분류 결과: {', '.join(detection_info) if detection_info else '탐지 없음'}")
        plt.axis('off')
        plt.show()

if __name__ == "__main__":
    # Script entry point: build the dataset and train first. train_model()
    # returns None when the dataset could not be prepared.
    model = train_model()
    if model:
        # Training ran: plot predictions on sample images, then run the
        # upload test (which does nothing when google.colab is unavailable).
        validate_model()
        upload_test()