In [1]:
from google.colab import drive
drive.mount('/content/drive')

MessageError: Error: credential propagation was unsuccessful

In [None]:
import os
from multiprocessing import Pool
from functools import partial
import subprocess

def unzip_file(zip_info):
    zip_file_path, extract_path = zip_info
    os.makedirs(extract_path, exist_ok=True)

    # subprocess를 사용하여 unzip 명령어 실행
    subprocess.run(['unzip', '-q', zip_file_path, '-d', extract_path], check=True)
    return f'{zip_file_path} 압축이 {extract_path}에 성공적으로 해제되었습니다.'

# 압축 해제할 경로와 파일 목록
zip_files = [
    ("/content/drive/Shareddrives/Data/dataset.zip", "/tmp"),
]

# 병렬 처리로 압축 해제
if __name__ == '__main__':
    with Pool(processes=len(zip_files)) as pool:  # 파일 개수만큼 프로세스 생성
        results = pool.map(unzip_file, zip_files)

    # 결과 출력
    for result in results:
        print(result)

# 다시


# 전처리 셀 1

In [None]:
import os, cv2, numpy as np
from pathlib import Path
import shutil
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
from sklearn.model_selection import train_test_split

# 경로 설정
TRAIN_DIR = '/tmp/train_zip/train'
TEST_DIR = '/tmp/test_zip/test'
PREPROCESS_DIR = '/tmp/preprocess_data'
TARGET_SIZE = 256

random.seed(42)  # 재현성 있는 랜덤

def create_directories():
    if os.path.exists(PREPROCESS_DIR):
        shutil.rmtree(PREPROCESS_DIR)
    os.makedirs(os.path.join(PREPROCESS_DIR, 'train'), exist_ok=True)
    os.makedirs(os.path.join(PREPROCESS_DIR, 'val'), exist_ok=True)  # 검증 데이터용 디렉토리 추가
    os.makedirs(os.path.join(PREPROCESS_DIR, 'test'), exist_ok=True)

def remove_background(image):
    # 기존 코드와 동일
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    masks = [
        cv2.inRange(hsv, np.array([0, 20, 20]), np.array([15, 255, 255])),   # red1
        cv2.inRange(hsv, np.array([165, 20, 20]), np.array([180, 255, 255])),# red2
        cv2.inRange(hsv, np.array([15, 20, 20]), np.array([40, 255, 255])),  # yellow
        cv2.inRange(hsv, np.array([5, 20, 20]), np.array([25, 255, 255]))    # orange
    ]
    color_mask = masks[0]
    for mask in masks[1:]:
        color_mask = cv2.bitwise_or(color_mask, mask)

    edges = cv2.Canny(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), 100, 200)
    final_mask = cv2.bitwise_or(color_mask, edges)

    kernel = np.ones((5, 5), np.uint8)
    final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_CLOSE, kernel)
    final_mask = cv2.dilate(final_mask, kernel, iterations=1)

    return cv2.bitwise_and(image, image, mask=final_mask)

def augment_with_noise(image):
    """이미지에 랜덤 노이즈 추가"""
    noise = np.random.normal(0, random.uniform(5, 20), image.shape).astype(np.uint8)
    noisy_img = cv2.add(image, noise)
    return noisy_img

def center_on_black(image):
    # 기존 코드와 동일
    black_bg = np.zeros((TARGET_SIZE, TARGET_SIZE, 3), dtype=np.uint8)
    h, w = image.shape[:2]
    scale = min(TARGET_SIZE / h, TARGET_SIZE / w) * 0.8
    new_h, new_w = int(h * scale), int(w * scale)
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
    y_offset = (TARGET_SIZE - new_h) // 2
    x_offset = (TARGET_SIZE - new_w) // 2
    black_bg[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized
    return black_bg

def basic_geometric_augment(image):
    """샘플 수 맞추기용 기본 지오메트릭 변환"""
    h, w = image.shape[:2]
    center = (w // 2, h // 2)

    transforms = [
        # 회전 (최소한의 각도)
        lambda img: cv2.warpAffine(img, cv2.getRotationMatrix2D(center, np.random.uniform(-10, 10), 1.0), (w, h)),
        # 스케일 (작은 변화)
        lambda img: cv2.resize(img, None, fx=np.random.uniform(0.9, 1.1), fy=np.random.uniform(0.9, 1.1)),
        # 밝기 (최소 조정)
        lambda img: cv2.convertScaleAbs(img, alpha=np.random.uniform(0.9, 1.1), beta=np.random.randint(-10, 10))
    ]
    return random.choice(transforms)(image)

def split_train_val_files(input_dir):
    """학습/검증 데이터 분할"""
    class_files = {}
    for filename in os.listdir(input_dir):
        if filename.endswith('.jpg'):
            class_name = filename.split('_')[0]
            if class_name not in class_files:
                class_files[class_name] = []
            class_files[class_name].append(filename)

    train_files = []
    val_files = []

    for class_name, files in class_files.items():
        files.sort()  # 파일 순서 보장
        train_class_files, val_class_files = train_test_split(
            files, test_size=0.2, random_state=42, shuffle=True
        )
        train_files.extend(train_class_files)
        val_files.extend(val_class_files)

    return train_files, val_files

def preprocess_and_select(files, split='train'):
    """데이터 전처리 및 증강"""
    output_dir = os.path.join(PREPROCESS_DIR, split)
    os.makedirs(output_dir, exist_ok=True)

    # 클래스별 파일 분류
    class_files = {}
    for filename in files:
        class_name = filename.split('_')[0]
        if class_name not in class_files:
            class_files[class_name] = []
        class_files[class_name].append(filename)

    # 최대 샘플 수 계산 (학습 데이터의 경우만)
    if split == 'train':
        max_samples = max(len(files) for files in class_files.values())
        print(f"Maximum samples per class in {split}: {max_samples}")

    for class_name, class_files_list in class_files.items():
        class_dir = os.path.join(output_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)

        processed_images = []
        # 전처리 진행
        for filename in tqdm(class_files_list, desc=f"Preprocessing {class_name} ({split})", leave=False):
            image_path = os.path.join(TRAIN_DIR, filename)  # 원본 학습 데이터 경로 사용
            image = cv2.imread(image_path)
            if image is None:
                continue
            processed = remove_background(image)
            processed = center_on_black(processed)
            processed_images.append(processed)

        # 학습 데이터이고 샘플 수가 부족한 경우 증강
        if split == 'train':
            current_count = len(processed_images)
            if current_count < max_samples:
                needed = max_samples - current_count
                for _ in range(needed):
                    base_img = random.choice(processed_images)
                    processed_images.append(basic_geometric_augment(base_img))

        # 이미지 저장 (들여쓰기 수정)
        for i, img in enumerate(processed_images):
            save_name = f"{class_name}_{i+1}.jpg"
            cv2.imwrite(os.path.join(class_dir, save_name), img)

def preprocess_test_data():
    """테스트 데이터 전처리"""
    output_dir = os.path.join(PREPROCESS_DIR, 'test')
    os.makedirs(output_dir, exist_ok=True)

    for filename in tqdm(os.listdir(TEST_DIR), desc="Processing test data"):
        if not filename.endswith('.jpg'):
            continue

        class_name = filename.split('_')[0]
        class_dir = os.path.join(output_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)

        image_path = os.path.join(TEST_DIR, filename)
        image = cv2.imread(image_path)
        if image is None:
            continue

        processed = remove_background(image)
        processed = center_on_black(processed)
        cv2.imwrite(os.path.join(class_dir, filename), processed)

def final_augment_image(image):
    """최종 증강 기법"""
    h, w = image.shape[:2]
    center = (w // 2, h // 2)

    # 기본 변환 (단일 변환, 강한 강도)
    basic_transforms = {
        'rotation': lambda img: cv2.warpAffine(img, cv2.getRotationMatrix2D(center, np.random.randint(-45, 45), 1.0), (w, h)),
        'brightness': lambda img: cv2.convertScaleAbs(img, alpha=np.random.uniform(0.6, 1.4), beta=np.random.randint(-50, 50)),
        'contrast': lambda img: cv2.convertScaleAbs(img, alpha=np.random.uniform(0.5, 1.5))
    }

    # 복합 변환 (여러 변환 조합)
    composite_transforms = {
        'rotation_scale': lambda img: cv2.resize(
            cv2.warpAffine(img, cv2.getRotationMatrix2D(center, np.random.uniform(-30, 30), 1.0), (w, h)),
            None, fx=np.random.uniform(0.8, 1.3), fy=np.random.uniform(0.8, 1.3)
        ),
        #'brightness_blur': lambda img: cv2.GaussianBlur(
        #   cv2.convertScaleAbs(img, alpha=np.random.uniform(0.7, 1.3), beta=np.random.randint(-40, 40)),
        #   (5, 5), np.random.uniform(0.5, 2.0)
        #),
        'rotation_contrast': lambda img: cv2.convertScaleAbs(
            cv2.warpAffine(img, cv2.getRotationMatrix2D(center, np.random.uniform(-25, 25), 1.0), (w, h)),
            alpha=np.random.uniform(0.6, 1.4)
        )
    }

    augmented_results = []

    # 기본 변환 3개
    for key in basic_transforms.keys():
        aug_img = basic_transforms[key](image)
        augmented_results.append((f'basic_{key}', aug_img))

    # 복합 변환 3개
    for key in composite_transforms.keys():
        aug_img = composite_transforms[key](image)
        augmented_results.append((f'composite_{key}', aug_img))

    return augmented_results

def final_augmentation(split='train'):
    """최종 증강 (학습 데이터만)"""
    if split != 'train':
        return

    train_dir = os.path.join(PREPROCESS_DIR, 'train')
    for class_name in os.listdir(train_dir):
        class_path = os.path.join(train_dir, class_name)
        if not os.path.isdir(class_path):
            continue

        # 원본 이미지만 선택 (class_숫자.jpg 형태)
        images = [f for f in os.listdir(class_path) if f.endswith('.jpg') and len(f.split('_')) == 2]
        images.sort()

        for filename in tqdm(images, desc=f"Final Augmenting {class_name}", leave=False):
            img_path = os.path.join(class_path, filename)
            img = cv2.imread(img_path)
            if img is None:
                continue

            aug_imgs = final_augment_image(img)
            base_name = os.path.splitext(filename)[0]

            for (transform_name, aimg) in aug_imgs:
                aug_path = os.path.join(class_path, f"{base_name}_aug_{transform_name}.jpg")
                cv2.imwrite(aug_path, aimg)

def plot_examples_for_mixed():
    class_dir = os.path.join(PREPROCESS_DIR, 'train', 'mixed')
    if not os.path.exists(class_dir):
        print("No mixed class directory found.")
        return
    images = [f for f in os.listdir(class_dir) if f.endswith('.jpg') and '_aug' not in f]
    if not images:
        print("No images found in mixed class.")
        return
    img_path = os.path.join(class_dir, images[0])
    img = cv2.imread(img_path)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    aug_imgs = final_augment_image(img)

    plt.figure(figsize=(15, 3))
    plt.subplot(1, len(aug_imgs) + 1, 1)
    plt.imshow(img_rgb)
    plt.title("Original")
    plt.axis('off')

    for i, (name, aimg) in enumerate(aug_imgs, 1):
        plt.subplot(1, len(aug_imgs) + 1, i+1)
        plt.imshow(cv2.cvtColor(aimg, cv2.COLOR_BGR2RGB))
        plt.title(name)  # 증강 기법 이름 표시
        plt.axis('off')
    plt.show()


if __name__ == "__main__":
    create_directories()

    print("\nSplitting training data into train/val...")
    train_files, val_files = split_train_val_files(TRAIN_DIR)

    print("\nPreprocessing and augmenting training data...")
    preprocess_and_select(train_files, split='train')

    print("\nPreprocessing validation data...")
    preprocess_and_select(val_files, split='val')

    print("\nPreprocessing test data...")
    preprocess_test_data()

    print("\nApplying final augmentation to training data...")
    final_augmentation(split='train')

    # 최종 데이터 수 출력
    for split in ['train', 'val', 'test']:
        split_dir = os.path.join(PREPROCESS_DIR, split)
        for class_name in os.listdir(split_dir):
            class_path = os.path.join(split_dir, class_name)
            if os.path.isdir(class_path):
                count = len([f for f in os.listdir(class_path) if f.endswith('.jpg')])
                print(f"{split.capitalize()} - Class {class_name}: {count} images")
    print("\nShowing examples for mixed class final augmentation...")
    plot_examples_for_mixed()

# 전처리 셀 2

In [None]:
import numpy as np
import pandas as pd
import cv2
import os
from tqdm import tqdm
import pickle
from sklearn.preprocessing import LabelEncoder

# 경로 설정
PREPROCESS_DIR = '/tmp/preprocess_data'
FEATURE_DIR = '/tmp/feature_data'
os.makedirs(FEATURE_DIR, exist_ok=True)

class FeatureExtractor:
    def __init__(self, input_size=256):
        self.input_size = input_size
        self.label_encoder = LabelEncoder()

    def extract_color_features(self, img):
        """컬러 관련 특징 추출"""
        # 각 채널별 평균과 표준편차
        means = img.mean(axis=(0, 1))
        stds = img.std(axis=(0, 1))

        # 각 채널별 히스토그램
        features = []
        for i in range(3):  # RGB 각 채널
            hist = cv2.calcHist([img], [i], None, [32], [0, 256])
            hist = hist.flatten() / hist.sum()  # 정규화
            features.extend(hist)

        features.extend(means)
        features.extend(stds)
        return np.array(features)

    def extract_texture_features(self, img):
        """텍스처 특징 추출"""
        # 그레이스케일 변환
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

        # Sobel 엣지 검출
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)

        # 엣지 강도와 방향
        magnitude = np.sqrt(sobelx**2 + sobely**2)
        direction = np.arctan2(sobely, sobelx)

        # 통계적 특징
        features = [
            np.mean(magnitude),  # 평균 엣지 강도
            np.std(magnitude),   # 엣지 강도의 표준편차
            np.percentile(magnitude, 90),  # 90번째 퍼센타일
            np.mean(direction),  # 평균 엣지 방향
            np.std(direction),   # 엣지 방향의 표준편차
        ]

        # LBP와 유사한 local 패턴 분석
        kernel_size = 3
        local_mean = cv2.blur(gray, (kernel_size, kernel_size))
        pattern = (gray > local_mean).astype(np.uint8)
        pattern_hist = cv2.calcHist([pattern], [0], None, [2], [0, 2])
        pattern_hist = pattern_hist.flatten() / pattern_hist.sum()

        features.extend(pattern_hist)
        return np.array(features)

    def extract_shape_features(self, img):
        """형태 관련 특징 추출"""
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

        # 이진화 (이미 엣지 정보가 포함된 이미지이므로)
        _, binary = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)

        # 컨투어 찾기
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        if not contours:
            return np.zeros(5)

        # 가장 큰 컨투어 선택
        largest_contour = max(contours, key=cv2.contourArea)

        # 특징 추출
        area = cv2.contourArea(largest_contour)
        perimeter = cv2.arcLength(largest_contour, True)
        hull = cv2.convexHull(largest_contour)
        hull_area = cv2.contourArea(hull)

        features = [
            area / (self.input_size ** 2),  # 정규화된 면적
            perimeter / (4 * self.input_size),  # 정규화된 둘레
            hull_area / (self.input_size ** 2),  # 정규화된 컨벡스 헐 면적
            4 * np.pi * area / (perimeter ** 2) if perimeter > 0 else 0,  # 원형도
            area / hull_area if hull_area > 0 else 0  # 볼록도
        ]

        return np.array(features)

    def process_dataset(self, split='train'):
        """데이터셋 처리 및 특징 추출"""
        data_dir = os.path.join(PREPROCESS_DIR, split)
        features_list = []
        labels = []
        image_paths = []

        for class_name in os.listdir(data_dir):
            class_dir = os.path.join(data_dir, class_name)
            if not os.path.isdir(class_dir):
                continue

            print(f"Processing {split} - {class_name}")
            for img_name in tqdm(os.listdir(class_dir)):
                if not img_name.endswith('.jpg'):
                    continue

                img_path = os.path.join(class_dir, img_name)
                img = cv2.imread(img_path)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

                # 특징 추출
                color_features = self.extract_color_features(img)
                texture_features = self.extract_texture_features(img)
                shape_features = self.extract_shape_features(img)

                # 모든 특징 합치기
                all_features = np.concatenate([
                    color_features,
                    texture_features,
                    shape_features
                ])

                features_list.append(all_features)
                labels.append(class_name)
                image_paths.append(img_path)

        # 특징을 numpy 배열로 변환
        X = np.array(features_list)

        # 레이블 인코딩 (train일 때만 fit)
        if split == 'train':
            y = self.label_encoder.fit_transform(labels)
        else:
            y = self.label_encoder.transform(labels)

        return X, y, image_paths

    def prepare_data(self):
        """전체 데이터셋 준비"""
        # 이미 분할된 데이터셋 처리
        print("Processing training data...")
        X_train, y_train, train_paths = self.process_dataset('train')

        print("Processing validation data...")
        X_val, y_val, val_paths = self.process_dataset('val')

        print("Processing test data...")
        X_test, y_test, test_paths = self.process_dataset('test')

        # 데이터 저장
        data = {
            'X_train': X_train,
            'X_val': X_val,
            'X_test': X_test,
            'y_train': y_train,
            'y_val': y_val,
            'y_test': y_test,
            'train_paths': train_paths,
            'val_paths': val_paths,
            'test_paths': test_paths,
            'label_encoder': self.label_encoder
        }

        with open(os.path.join(FEATURE_DIR, 'multimodal_data.pkl'), 'wb') as f:
            pickle.dump(data, f)

        print("\nFeature shapes:")
        print(f"X_train: {X_train.shape}")
        print(f"X_val: {X_val.shape}")
        print(f"X_test: {X_test.shape}")

        print("\nSample counts:")
        print(f"Training samples: {len(X_train)}")
        print(f"Validation samples: {len(X_val)}")
        print(f"Test samples: {len(X_test)}")

        return data

# 특징 추출 실행
if __name__ == "__main__":
    extractor = FeatureExtractor()
    data = extractor.prepare_data()
    print("\nFeature extraction completed and saved!")

Processing training data...
Processing train - banana


100%|██████████| 420/420 [00:03<00:00, 110.67it/s]


Processing train - orange


100%|██████████| 420/420 [00:03<00:00, 110.56it/s]


Processing train - mixed


100%|██████████| 420/420 [00:03<00:00, 109.57it/s]


Processing train - apple


100%|██████████| 420/420 [00:03<00:00, 110.15it/s]


Processing validation data...
Processing val - banana


100%|██████████| 15/15 [00:00<00:00, 114.05it/s]


Processing val - orange


100%|██████████| 15/15 [00:00<00:00, 112.21it/s]


Processing val - mixed


100%|██████████| 4/4 [00:00<00:00, 113.74it/s]


Processing val - apple


100%|██████████| 15/15 [00:00<00:00, 111.07it/s]


Processing test data...
Processing test - banana


100%|██████████| 18/18 [00:00<00:00, 108.93it/s]


Processing test - orange


100%|██████████| 18/18 [00:00<00:00, 111.30it/s]


Processing test - mixed


100%|██████████| 5/5 [00:00<00:00, 113.13it/s]


Processing test - apple


100%|██████████| 19/19 [00:00<00:00, 111.98it/s]


Feature shapes:
X_train: (1680, 114)
X_val: (49, 114)
X_test: (60, 114)

Sample counts:
Training samples: 1680
Validation samples: 49
Test samples: 60

Feature extraction completed and saved!





# 모델 구조 및 학습 셀

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pickle
import os
import cv2
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision import transforms
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

class Config:
    IMAGE_SIZE = 256
    BATCH_SIZE = 8  # 배치 크기 감소
    NUM_WORKERS = 2

    DROPOUT_RATE = 0.4  # 드롭아웃 증가
    HIDDEN_DIM = 256
    BN_MOMENTUM = 0.1  # BatchNorm 모멘텀 추가

    LEARNING_RATE = 0.0001  # 학습률 감소
    NUM_EPOCHS = 150
    EARLY_STOP_PATIENCE = 10  # 얼리스탑 참을성 증가
    WEIGHT_DECAY = 0.001  # 정규화 강화

    LABEL_SMOOTHING = 0.1
    GRAD_CLIP = 1.0  # Gradient Clipping 추가

    WARMUP_EPOCHS = 5  # Warmup 추가

    USE_SCHEDULER = True
    SCHEDULER_PATIENCE = 5
    SCHEDULER_FACTOR = 0.3
    MIN_LR = 1e-6

    SPARSITY_WEIGHT = 0.0005

    SAVE_DIR = 'model_checkpoints'
    os.makedirs(SAVE_DIR, exist_ok=True)

class MultiModalDataset(Dataset):
    def __init__(self, images_paths, features, labels, image_size=256):
        self.image_paths = images_paths
        self.features = torch.FloatTensor(features)
        self.labels = torch.LongTensor(labels)
        self.image_size = image_size

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.image_size, self.image_size))
        image = self.transform(image)

        return {
            'image': image,
            'features': self.features[idx],
            'label': self.labels[idx]
        }

class Expert(nn.Module):
    def __init__(self, feature_dim, config):
        super().__init__()

        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c, momentum=config.BN_MOMENTUM),  # 모멘텀 적용
                nn.ReLU(),
                nn.Conv2d(out_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c, momentum=config.BN_MOMENTUM),  # 모멘텀 적용
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Dropout2d(config.DROPOUT_RATE)
            )

        self.image_encoder = nn.Sequential(
            conv_block(3, 64),
            conv_block(64, 128),
            conv_block(128, 256),
            conv_block(256, 512),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        )

        # Feature encoder capacity 증가
        self.feature_encoder = nn.Sequential(
            nn.Linear(feature_dim, config.HIDDEN_DIM * 4),
            nn.BatchNorm1d(config.HIDDEN_DIM * 4, momentum=config.BN_MOMENTUM),
            nn.ReLU(),
            nn.Dropout(config.DROPOUT_RATE),
            nn.Linear(config.HIDDEN_DIM * 4, config.HIDDEN_DIM * 2),
            nn.BatchNorm1d(config.HIDDEN_DIM * 2, momentum=config.BN_MOMENTUM),
            nn.ReLU(),
            nn.Linear(config.HIDDEN_DIM * 2, config.HIDDEN_DIM),
            nn.BatchNorm1d(config.HIDDEN_DIM, momentum=config.BN_MOMENTUM),
            nn.ReLU()
        )

        self.attention = nn.Sequential(
            nn.Linear(512 + config.HIDDEN_DIM, 256),
            nn.Tanh(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

        self.combined = nn.Sequential(
            nn.Linear(512 + config.HIDDEN_DIM, config.HIDDEN_DIM * 4),  # Capacity 증가
            nn.BatchNorm1d(config.HIDDEN_DIM * 4, momentum=config.BN_MOMENTUM),
            nn.ReLU(),
            nn.Dropout(config.DROPOUT_RATE),
            nn.Linear(config.HIDDEN_DIM * 4, config.HIDDEN_DIM * 2),
            nn.BatchNorm1d(config.HIDDEN_DIM * 2, momentum=config.BN_MOMENTUM),
            nn.ReLU(),
            nn.Linear(config.HIDDEN_DIM * 2, 1)
        )

    def forward(self, images, features):
        img_feat = self.image_encoder(images)
        num_feat = self.feature_encoder(features)

        combined = torch.cat([img_feat, num_feat], dim=1)
        attention_weights = self.attention(combined)
        attended_features = combined * attention_weights

        score = self.combined(attended_features)
        return score

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, class_weights=None):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.class_weights = torch.tensor([1.0, 1.0, 3.0, 1.0]) if class_weights is None else class_weights

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(
            inputs, targets,
            weight=self.class_weights.to(inputs.device),
            label_smoothing=Config.LABEL_SMOOTHING,  # 레이블 스무딩 적용
            reduction='none'
        )
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
        return focal_loss.mean()

class EarlyStopping:
    def __init__(self, patience=7, min_delta=0, path='best_model.pth'):
        self.patience = patience
        self.min_delta = min_delta
        self.path = path
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss > self.best_loss + self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.save_checkpoint(model)
            self.counter = 0

    def save_checkpoint(self, model):
        torch.save(model.state_dict(), self.path)

class MoEMultiModal(nn.Module):
    def __init__(self, num_classes, feature_dim, config):
        super().__init__()
        self.num_classes = num_classes

        self.experts = nn.ModuleList([
            Expert(feature_dim, config) for _ in range(num_classes)
        ])

        self.gate = nn.Sequential(
            nn.Linear(feature_dim, config.HIDDEN_DIM * 2),  # Capacity 증가
            nn.BatchNorm1d(config.HIDDEN_DIM * 2, momentum=config.BN_MOMENTUM),
            nn.ReLU(),
            nn.Dropout(config.DROPOUT_RATE),
            nn.Linear(config.HIDDEN_DIM * 2, config.HIDDEN_DIM),
            nn.BatchNorm1d(config.HIDDEN_DIM, momentum=config.BN_MOMENTUM),
            nn.ReLU(),
            nn.Linear(config.HIDDEN_DIM, num_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, images, features):
        gate_weights = self.gate(features)

        expert_outputs = []
        for expert in self.experts:
            out = expert(images, features)
            expert_outputs.append(out)

        expert_outputs = torch.cat(expert_outputs, dim=1)
        final_output = expert_outputs * gate_weights

        return final_output

def evaluate_one_epoch(model, data_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in data_loader:
            images = batch['image'].to(device)
            features = batch['features'].to(device)
            labels = batch['label'].to(device)

            outputs = model(images, features)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(data_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro')

    return avg_loss, accuracy, f1

def get_lr_multiplier(epoch, warmup_epochs):
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    return 1.0

def train_moe_model(model, train_loader, val_loader, criterion, optimizer, config, device):
    early_stopping = EarlyStopping(
        patience=config.EARLY_STOP_PATIENCE,
        path=os.path.join(config.SAVE_DIR, 'best_moe_model.pth')
    )

    if config.USE_SCHEDULER:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=config.SCHEDULER_PATIENCE,
            factor=config.SCHEDULER_FACTOR, min_lr=config.MIN_LR, verbose=True
        )

    train_losses, train_accs, train_f1s = [], [], []
    val_losses, val_accs, val_f1s = [], [], []

    for epoch in range(config.NUM_EPOCHS):
        # Learning rate warmup
        if epoch < config.WARMUP_EPOCHS:
            lr_multiplier = get_lr_multiplier(epoch, config.WARMUP_EPOCHS)
            for param_group in optimizer.param_groups:
                param_group['lr'] = config.LEARNING_RATE * lr_multiplier

        # Training
        model.train()
        train_loss = 0.0
        train_preds = []
        train_labels = []

        for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{config.NUM_EPOCHS}'):
            images = batch['image'].to(device)
            features = batch['features'].to(device)
            labels = batch['label'].to(device)

            optimizer.zero_grad()
            outputs = model(images, features)
            loss = criterion(outputs, labels)

            gate_sparsity = 0.0
            for expert in model.experts:
                gate_sparsity += torch.mean(torch.abs(expert.combined[-1].weight))

            total_loss = loss + config.SPARSITY_WEIGHT * gate_sparsity
            total_loss.backward()

            # Gradient Clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.GRAD_CLIP)

            optimizer.step()

            train_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            train_preds.extend(preds.cpu().numpy())
            train_labels.extend(labels.cpu().numpy())

        train_loss = train_loss / len(train_loader)
        train_acc = accuracy_score(train_labels, train_preds)
        train_f1 = f1_score(train_labels, train_preds, average='macro')

        # Validation
        val_loss, val_acc, val_f1 = evaluate_one_epoch(model, val_loader, criterion, device)

        # Save metrics
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        train_f1s.append(train_f1)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_f1s.append(val_f1)

        # Print progress
        print(f'\nEpoch {epoch+1}/{config.NUM_EPOCHS}:')
        print(f'Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}')
        print(f'Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}')

        if config.USE_SCHEDULER:
            scheduler.step(val_loss)

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("\nEarly stopping triggered")
            break

    # Plot training curves
    plot_metrics(train_losses, val_losses, 'Loss',
                os.path.join(config.SAVE_DIR, 'loss_curves.png'))
    plot_metrics(train_accs, val_accs, 'Accuracy',
                os.path.join(config.SAVE_DIR, 'accuracy_curves.png'))
    plot_metrics(train_f1s, val_f1s, 'F1 Score',
                os.path.join(config.SAVE_DIR, 'f1_curves.png'))

def plot_confusion_matrix(true_labels, pred_labels, class_names, save_path):
    cm = confusion_matrix(true_labels, pred_labels)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def plot_metrics(train_metrics, val_metrics, metric_name, save_path):
    plt.figure(figsize=(10, 5))
    plt.plot(train_metrics, label=f'Train {metric_name}')
    plt.plot(val_metrics, label=f'Val {metric_name}')
    plt.xlabel('Epoch')
    plt.ylabel(metric_name)
    plt.legend()
    plt.title(f'{metric_name} over Training')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def calculate_class_weights(y_train):
    class_counts = np.bincount(y_train)
    total = len(y_train)
    weights = total / (len(class_counts) * class_counts)
    return torch.FloatTensor(weights)

def main():
    torch.cuda.empty_cache()
    config = Config()

    # 데이터 로드
    with open(os.path.join(FEATURE_DIR, 'multimodal_data.pkl'), 'rb') as f:
        data = pickle.load(f)

    # 클래스 가중치 계산
    class_weights = calculate_class_weights(data['y_train'])

    # Dataset and DataLoader setup
    train_dataset = MultiModalDataset(
        data['train_paths'], data['X_train'], data['y_train'],
        image_size=config.IMAGE_SIZE
    )
    val_dataset = MultiModalDataset(
        data['val_paths'], data['X_val'], data['y_val'],
        image_size=config.IMAGE_SIZE
    )
    test_dataset = MultiModalDataset(
        data['test_paths'], data['X_test'], data['y_test'],
        image_size=config.IMAGE_SIZE
    )

    train_loader = DataLoader(
        train_dataset, batch_size=config.BATCH_SIZE,
        shuffle=True, num_workers=config.NUM_WORKERS
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config.BATCH_SIZE,
        shuffle=False, num_workers=config.NUM_WORKERS
    )
    test_loader = DataLoader(
        test_dataset, batch_size=config.BATCH_SIZE,
        shuffle=False, num_workers=config.NUM_WORKERS
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MoEMultiModal(
        num_classes=len(data['label_encoder'].classes_),
        feature_dim=data['X_train'].shape[1],
        config=config
    ).to(device)

    criterion = FocalLoss(gamma=2, class_weights=class_weights)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY
    )

    # Training
    print("Starting training...")
    train_moe_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        config=config,
        device=device
    )

    # Final evaluation
    print("\nLoading best model for final evaluation...")
    model.load_state_dict(torch.load(os.path.join(config.SAVE_DIR, 'best_moe_model.pth')))

    print("\nValidation Set Performance:")
    val_loss, val_acc, val_f1 = evaluate_one_epoch(model, val_loader, criterion, device)
    print(f"Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}, F1 Score: {val_f1:.4f}")

    print("\nTest Set Performance:")
    test_loss, test_acc, test_f1 = evaluate_one_epoch(model, test_loader, criterion, device)
    print(f"Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}, F1 Score: {test_f1:.4f}")

    # Detailed classification report for test set
    all_preds = []
    all_labels = []
    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            images = batch['image'].to(device)
            features = batch['features'].to(device)
            labels = batch['label'].to(device)

            outputs = model(images, features)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    print("\nDetailed Classification Report:")
    print(classification_report(all_labels, all_preds,
                              target_names=data['label_encoder'].classes_))

    plot_confusion_matrix(
        all_labels, all_preds,
        data['label_encoder'].classes_,
        os.path.join(config.SAVE_DIR, 'confusion_matrix.png')
    )

if __name__ == "__main__":
    main()

Starting training...


Epoch 1/150: 100%|██████████| 210/210 [00:20<00:00, 10.48it/s]



Epoch 1/150:
Train - Loss: 0.7316, Acc: 0.3881, F1: 0.3835
Val   - Loss: 0.6823, Acc: 0.4694, F1: 0.3588


Epoch 2/150: 100%|██████████| 210/210 [00:20<00:00, 10.49it/s]



Epoch 2/150:
Train - Loss: 0.5544, Acc: 0.5893, F1: 0.5759
Val   - Loss: 0.3361, Acc: 0.7347, F1: 0.6491


Epoch 3/150: 100%|██████████| 210/210 [00:19<00:00, 10.53it/s]



Epoch 3/150:
Train - Loss: 0.4155, Acc: 0.6976, F1: 0.6954
Val   - Loss: 0.2464, Acc: 0.8367, F1: 0.7309


Epoch 4/150: 100%|██████████| 210/210 [00:19<00:00, 10.51it/s]



Epoch 4/150:
Train - Loss: 0.3729, Acc: 0.7262, F1: 0.7245
Val   - Loss: 0.1736, Acc: 0.8776, F1: 0.7702


Epoch 5/150: 100%|██████████| 210/210 [00:19<00:00, 10.53it/s]



Epoch 5/150:
Train - Loss: 0.3302, Acc: 0.7732, F1: 0.7727
Val   - Loss: 0.1673, Acc: 0.9184, F1: 0.8120


Epoch 6/150: 100%|██████████| 210/210 [00:19<00:00, 10.51it/s]



Epoch 6/150:
Train - Loss: 0.3058, Acc: 0.7935, F1: 0.7925
Val   - Loss: 0.1598, Acc: 0.8980, F1: 0.7970


Epoch 7/150: 100%|██████████| 210/210 [00:19<00:00, 10.54it/s]



Epoch 7/150:
Train - Loss: 0.3060, Acc: 0.8071, F1: 0.8073
Val   - Loss: 0.1393, Acc: 0.9184, F1: 0.8520


Epoch 8/150: 100%|██████████| 210/210 [00:19<00:00, 10.53it/s]



Epoch 8/150:
Train - Loss: 0.2857, Acc: 0.8202, F1: 0.8200
Val   - Loss: 0.1527, Acc: 0.9184, F1: 0.8890


Epoch 9/150: 100%|██████████| 210/210 [00:19<00:00, 10.55it/s]



Epoch 9/150:
Train - Loss: 0.2471, Acc: 0.8351, F1: 0.8350
Val   - Loss: 0.1643, Acc: 0.8776, F1: 0.8106


Epoch 10/150: 100%|██████████| 210/210 [00:19<00:00, 10.53it/s]



Epoch 10/150:
Train - Loss: 0.2595, Acc: 0.8488, F1: 0.8488
Val   - Loss: 0.1475, Acc: 0.8980, F1: 0.8249


Epoch 11/150: 100%|██████████| 210/210 [00:19<00:00, 10.51it/s]



Epoch 11/150:
Train - Loss: 0.2334, Acc: 0.8506, F1: 0.8505
Val   - Loss: 0.1155, Acc: 0.9388, F1: 0.9042


Epoch 12/150: 100%|██████████| 210/210 [00:19<00:00, 10.55it/s]



Epoch 12/150:
Train - Loss: 0.2375, Acc: 0.8583, F1: 0.8584
Val   - Loss: 0.1668, Acc: 0.9184, F1: 0.8682


Epoch 13/150: 100%|██████████| 210/210 [00:19<00:00, 10.53it/s]



Epoch 13/150:
Train - Loss: 0.2272, Acc: 0.8619, F1: 0.8616
Val   - Loss: 0.1156, Acc: 0.9388, F1: 0.8844


Epoch 14/150: 100%|██████████| 210/210 [00:19<00:00, 10.53it/s]



Epoch 14/150:
Train - Loss: 0.2200, Acc: 0.8726, F1: 0.8727
Val   - Loss: 0.1378, Acc: 0.9388, F1: 0.8844


Epoch 15/150: 100%|██████████| 210/210 [00:19<00:00, 10.53it/s]



Epoch 15/150:
Train - Loss: 0.2040, Acc: 0.8964, F1: 0.8965
Val   - Loss: 0.1134, Acc: 0.9184, F1: 0.8520


Epoch 16/150: 100%|██████████| 210/210 [00:19<00:00, 10.53it/s]



Epoch 16/150:
Train - Loss: 0.1978, Acc: 0.8881, F1: 0.8882
Val   - Loss: 0.1412, Acc: 0.9388, F1: 0.8844


Epoch 17/150: 100%|██████████| 210/210 [00:19<00:00, 10.53it/s]



Epoch 17/150:
Train - Loss: 0.2084, Acc: 0.8917, F1: 0.8916
Val   - Loss: 0.1200, Acc: 0.9184, F1: 0.8107


Epoch 18/150: 100%|██████████| 210/210 [00:19<00:00, 10.50it/s]



Epoch 18/150:
Train - Loss: 0.1756, Acc: 0.9030, F1: 0.9030
Val   - Loss: 0.1137, Acc: 0.9388, F1: 0.8844


Epoch 19/150: 100%|██████████| 210/210 [00:20<00:00, 10.47it/s]



Epoch 19/150:
Train - Loss: 0.2037, Acc: 0.8964, F1: 0.8960
Val   - Loss: 0.1387, Acc: 0.9388, F1: 0.8844


Epoch 20/150: 100%|██████████| 210/210 [00:20<00:00, 10.46it/s]



Epoch 20/150:
Train - Loss: 0.1745, Acc: 0.9125, F1: 0.9126
Val   - Loss: 0.1045, Acc: 0.9592, F1: 0.9396


Epoch 21/150: 100%|██████████| 210/210 [00:20<00:00, 10.45it/s]



Epoch 21/150:
Train - Loss: 0.1719, Acc: 0.9119, F1: 0.9120
Val   - Loss: 0.1466, Acc: 0.9184, F1: 0.8520


Epoch 22/150: 100%|██████████| 210/210 [00:19<00:00, 10.50it/s]



Epoch 22/150:
Train - Loss: 0.1857, Acc: 0.9119, F1: 0.9121
Val   - Loss: 0.1298, Acc: 0.9388, F1: 0.8844


Epoch 23/150: 100%|██████████| 210/210 [00:20<00:00, 10.48it/s]



Epoch 23/150:
Train - Loss: 0.1590, Acc: 0.9196, F1: 0.9199
Val   - Loss: 0.1228, Acc: 0.9388, F1: 0.8844


Epoch 24/150: 100%|██████████| 210/210 [00:19<00:00, 10.53it/s]



Epoch 24/150:
Train - Loss: 0.1630, Acc: 0.9220, F1: 0.9219
Val   - Loss: 0.1264, Acc: 0.9388, F1: 0.8844


Epoch 25/150: 100%|██████████| 210/210 [00:19<00:00, 10.52it/s]



Epoch 25/150:
Train - Loss: 0.1444, Acc: 0.9262, F1: 0.9262
Val   - Loss: 0.0932, Acc: 0.9388, F1: 0.8844


Epoch 26/150: 100%|██████████| 210/210 [00:19<00:00, 10.50it/s]



Epoch 26/150:
Train - Loss: 0.1642, Acc: 0.9179, F1: 0.9179
Val   - Loss: 0.1136, Acc: 0.9388, F1: 0.8844


Epoch 27/150: 100%|██████████| 210/210 [00:19<00:00, 10.52it/s]



Epoch 27/150:
Train - Loss: 0.1493, Acc: 0.9274, F1: 0.9273
Val   - Loss: 0.1471, Acc: 0.9184, F1: 0.8102


Epoch 28/150: 100%|██████████| 210/210 [00:19<00:00, 10.53it/s]



Epoch 28/150:
Train - Loss: 0.1391, Acc: 0.9375, F1: 0.9376
Val   - Loss: 0.1760, Acc: 0.9184, F1: 0.8520


Epoch 29/150: 100%|██████████| 210/210 [00:19<00:00, 10.51it/s]



Epoch 29/150:
Train - Loss: 0.1536, Acc: 0.9280, F1: 0.9280
Val   - Loss: 0.1086, Acc: 0.9592, F1: 0.9005


Epoch 30/150: 100%|██████████| 210/210 [00:19<00:00, 10.52it/s]



Epoch 30/150:
Train - Loss: 0.1422, Acc: 0.9333, F1: 0.9333
Val   - Loss: 0.1334, Acc: 0.9388, F1: 0.8844


Epoch 31/150: 100%|██████████| 210/210 [00:19<00:00, 10.52it/s]



Epoch 31/150:
Train - Loss: 0.1895, Acc: 0.9268, F1: 0.9271
Val   - Loss: 0.1454, Acc: 0.9388, F1: 0.8844


Epoch 32/150: 100%|██████████| 210/210 [00:20<00:00, 10.49it/s]



Epoch 32/150:
Train - Loss: 0.1381, Acc: 0.9381, F1: 0.9383
Val   - Loss: 0.1402, Acc: 0.9388, F1: 0.8844


Epoch 33/150: 100%|██████████| 210/210 [00:19<00:00, 10.53it/s]



Epoch 33/150:
Train - Loss: 0.1244, Acc: 0.9488, F1: 0.9490
Val   - Loss: 0.1341, Acc: 0.9388, F1: 0.8844


Epoch 34/150: 100%|██████████| 210/210 [00:19<00:00, 10.53it/s]



Epoch 34/150:
Train - Loss: 0.1279, Acc: 0.9470, F1: 0.9471
Val   - Loss: 0.1235, Acc: 0.9388, F1: 0.8844


Epoch 35/150: 100%|██████████| 210/210 [00:19<00:00, 10.51it/s]



Epoch 35/150:
Train - Loss: 0.1264, Acc: 0.9446, F1: 0.9447
Val   - Loss: 0.1247, Acc: 0.9388, F1: 0.8844

Early stopping triggered

Loading best model for final evaluation...

Validation Set Performance:
Loss: 0.0932, Accuracy: 0.9388, F1 Score: 0.8844

Test Set Performance:
Loss: 0.1822, Accuracy: 0.9000, F1 Score: 0.8216

Detailed Classification Report:
              precision    recall  f1-score   support

       apple       0.90      0.95      0.92        19
      banana       0.89      0.94      0.92        18
       mixed       0.67      0.40      0.50         5
      orange       0.94      0.94      0.94        18

    accuracy                           0.90        60
   macro avg       0.85      0.81      0.82        60
weighted avg       0.89      0.90      0.89        60



# 찐 다시

In [None]:
import os, cv2, numpy as np
from pathlib import Path
import shutil
from tqdm import tqdm
import matplotlib.pyplot as plt
import random

# 경로 설정
TRAIN_DIR = '/tmp/train_zip/train'
TEST_DIR = '/tmp/test_zip/test'
PREPROCESS_DIR = '/tmp/preprocessed_data'
TARGET_SIZE = 256

random.seed(42)  # 재현성을 위한 시드 설정

def create_directories():
    """작업 디렉토리 생성"""
    if os.path.exists(PREPROCESS_DIR):
        shutil.rmtree(PREPROCESS_DIR)
    os.makedirs(os.path.join(PREPROCESS_DIR, 'train'), exist_ok=True)
    os.makedirs(os.path.join(PREPROCESS_DIR, 'validation'), exist_ok=True)
    os.makedirs(os.path.join(PREPROCESS_DIR, 'test'), exist_ok=True)

def remove_background(image):
    """과일 이미지에서 배경 제거"""
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    # HSV 색상 범위 정의
    masks = [
        cv2.inRange(hsv, np.array([0, 20, 20]), np.array([15, 255, 255])),    # red1
        cv2.inRange(hsv, np.array([165, 20, 20]), np.array([180, 255, 255])), # red2
        cv2.inRange(hsv, np.array([15, 20, 20]), np.array([40, 255, 255])),   # yellow-orange
        cv2.inRange(hsv, np.array([40, 20, 20]), np.array([80, 255, 255])),   # green
        cv2.inRange(hsv, np.array([100, 20, 20]), np.array([140, 255, 255]))  # blue
    ]

    # 모든 마스크 결합
    color_mask = masks[0]
    for mask in masks[1:]:
        color_mask = cv2.bitwise_or(color_mask, mask)

    # 엣지 검출 및 마스크와 결합
    edges = cv2.Canny(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), 100, 200)
    final_mask = cv2.bitwise_or(color_mask, edges)

    # 마스크 정제
    kernel = np.ones((5,5), np.uint8)
    final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_CLOSE, kernel)
    final_mask = cv2.dilate(final_mask, kernel, iterations=1)

    return cv2.bitwise_and(image, image, mask=final_mask)

def center_on_black(image):
    """이미지를 검은 배경 중앙에 위치시키기"""
    black_bg = np.zeros((TARGET_SIZE, TARGET_SIZE, 3), dtype=np.uint8)
    h, w = image.shape[:2]
    scale = min(TARGET_SIZE / h, TARGET_SIZE / w) * 0.8
    new_h, new_w = int(h * scale), int(w * scale)
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

    y_offset = (TARGET_SIZE - new_h) // 2
    x_offset = (TARGET_SIZE - new_w) // 2
    black_bg[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized

    return black_bg

def basic_transforms(image):
    """기본 증강: 회전과 대칭만 사용"""
    h, w = image.shape[:2]
    center = (w // 2, h // 2)

    # 변환 종류 선택 (회전 또는 대칭)
    transform_type = random.choice(['flip', 'rotation'])

    if transform_type == 'flip':
        # 좌우 대칭만 적용
        return cv2.flip(image.copy(), 1)
    else:
        # 90도 단위 회전
        angle = random.choice([90, 180, 270])
        M = cv2.getRotationMatrix2D(center, angle, 1.0)
        return cv2.warpAffine(image.copy(), M, (w, h),
                            flags=cv2.INTER_LINEAR,
                            borderMode=cv2.BORDER_CONSTANT)

def advanced_transforms(image):
   """고급 증강: 기하 변환 위주로 적용"""
   h, w = image.shape[:2]
   center = (w // 2, h // 2)

   def affine_transform(img):
       """어파인 변환"""
       src_pts = np.float32([[0,0], [w-1,0], [0,h-1]])
       dst_pts = src_pts + np.random.uniform(-w*0.05, w*0.05, src_pts.shape)
       M = cv2.getAffineTransform(src_pts, dst_pts)
       return cv2.warpAffine(img, M, (w,h), borderMode=cv2.BORDER_CONSTANT)

   def scale_variation(img):
       """크기 변화"""
       scale = np.random.uniform(0.9, 1.1)
       new_size = (int(w * scale), int(h * scale))
       scaled = cv2.resize(img, new_size)

       result = np.zeros_like(img)
       y_offset = (h - scaled.shape[0]) // 2
       x_offset = (w - scaled.shape[1]) // 2

       y_start = max(0, y_offset)
       y_end = min(h, y_offset + scaled.shape[0])
       x_start = max(0, x_offset)
       x_end = min(w, x_offset + scaled.shape[1])

       result[y_start:y_end, x_start:x_end] = scaled[:y_end-y_start, :x_end-x_start]
       return result

   def perspective_transform(img):
       """원근 변환"""
       src_pts = np.float32([[0,0], [w-1,0], [0,h-1], [w-1,h-1]])
       dst_pts = src_pts + np.random.uniform(-w*0.05, w*0.05, src_pts.shape)
       M = cv2.getPerspectiveTransform(src_pts, dst_pts)
       return cv2.warpPerspective(img, M, (w,h), borderMode=cv2.BORDER_CONSTANT)

   def shear_transform(img):
       """전단 변환"""
       shear_factor = np.random.uniform(-0.1, 0.1)
       M = np.float32([[1, shear_factor, 0], [0, 1, 0]])
       return cv2.warpAffine(img, M, (w,h), borderMode=cv2.BORDER_CONSTANT)

   def elastic_transform(img):
       """탄성 변형"""
       dx = np.random.uniform(-5, 5, (h,w))
       dy = np.random.uniform(-5, 5, (h,w))
       x, y = np.meshgrid(np.arange(w), np.arange(h))
       map_x = np.float32(x + dx)
       map_y = np.float32(y + dy)
       return cv2.remap(img, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)

   transforms = {
       'affine': affine_transform,
       'scale': scale_variation,
       'perspective': perspective_transform,
       'shear': shear_transform,
       'elastic': elastic_transform
   }

   # 3개의 랜덤한 기하 변환 적용
   chosen_transforms = random.sample(list(transforms.keys()), k=3)
   result_image = image.copy()

   for transform_name in chosen_transforms:
       result_image = transforms[transform_name](result_image)

   return result_image

def preprocess_and_split():
   """데이터 전처리 및 분할"""
   # 클래스별 파일 정리
   class_files = {}
   for filename in os.listdir(TRAIN_DIR):
       if filename.endswith('.jpg'):
           class_name = filename.split('_')[0]
           if class_name not in class_files:
               class_files[class_name] = []
           class_files[class_name].append(filename)

   # 8:2 분할
   train_files = {}
   val_files = {}
   for class_name, files in class_files.items():
       random.shuffle(files)
       split_idx = int(len(files) * 0.8)
       train_files[class_name] = files[:split_idx]
       val_files[class_name] = files[split_idx:]

   # 학습 세트 최대 샘플 수 확인
   train_max_samples = max(len(files) for files in train_files.values())
   print(f"Maximum samples per class in training set: {train_max_samples}")

   # 1단계: 학습 세트 처리
   for class_name, files in train_files.items():
       print(f"\nProcessing {class_name}...")
       processed_images = []
       processed_backgrounds = []

       # 원본 이미지 로드
       for filename in tqdm(files, desc="Loading original images"):
           image = cv2.imread(os.path.join(TRAIN_DIR, filename))
           if image is None:
               continue
           processed_images.append(image)

           # 배경 제거 버전도 저장
           bg_removed = remove_background(image.copy())
           centered = center_on_black(bg_removed)
           processed_backgrounds.append(centered)

       # 기본 증강으로 샘플 수 맞추기
       if len(processed_images) < train_max_samples:
           needed = train_max_samples - len(processed_images)
           original_images = processed_images.copy()

           for _ in tqdm(range(needed), desc="Basic augmentation"):
               idx = random.randint(0, len(original_images)-1)
               # 원본에 기본 증강 적용
               aug_img = basic_transforms(original_images[idx])
               # 증강된 이미지의 배경 제거 버전 생성
               aug_bg_removed = remove_background(aug_img)
               aug_centered = center_on_black(aug_bg_removed)

               processed_images.append(aug_img)
               processed_backgrounds.append(aug_centered)

       # 2단계: 고급 증강 (배경이 제거된 이미지에 적용)
       augmented_images = []
       for idx in tqdm(range(len(processed_backgrounds)), desc="Advanced augmentation"):
           base_img = processed_backgrounds[idx]
           # 각 이미지당 6개의 증강
           for _ in range(6):
               aug_img = advanced_transforms(base_img.copy())
               augmented_images.append(aug_img)
# 저장
       save_dir = os.path.join(PREPROCESS_DIR, 'train', class_name)
       os.makedirs(save_dir, exist_ok=True)

       # 기본 증강 이미지 저장 (배경 제거 버전)
       for i, img in enumerate(processed_backgrounds):
           save_name = f"{class_name}_orig_{i+1}.jpg"
           cv2.imwrite(os.path.join(save_dir, save_name), img)

       # 고급 증강 이미지 저장
       for i, img in enumerate(augmented_images):
           save_name = f"{class_name}_adv_{i+1}.jpg"
           cv2.imwrite(os.path.join(save_dir, save_name), img)

       print(f"Final count for {class_name}: {len(processed_backgrounds) + len(augmented_images)}")

   # Validation 세트 처리
   for class_name, files in val_files.items():
       save_dir = os.path.join(PREPROCESS_DIR, 'validation', class_name)
       os.makedirs(save_dir, exist_ok=True)

       for filename in tqdm(files, desc=f"Processing validation {class_name}"):
           image = cv2.imread(os.path.join(TRAIN_DIR, filename))
           if image is None:
               continue
           processed = remove_background(image)
           processed = center_on_black(processed)
           save_name = f"{class_name}_{filename}"
           cv2.imwrite(os.path.join(save_dir, save_name), processed)

def process_test_images():
   """테스트 세트 처리"""
   for filename in tqdm(os.listdir(TEST_DIR), desc="Processing test images"):
       if not filename.endswith('.jpg'):
           continue

       class_name = filename.split('_')[0]
       save_dir = os.path.join(PREPROCESS_DIR, 'test', class_name)
       os.makedirs(save_dir, exist_ok=True)

       image = cv2.imread(os.path.join(TEST_DIR, filename))
       if image is None:
           continue

       processed = remove_background(image)
       processed = center_on_black(processed)
       cv2.imwrite(os.path.join(save_dir, filename), processed)

def visualize_augmentations():
   """증강 결과 시각화"""
   train_dir = os.path.join(PREPROCESS_DIR, 'train')
   plt.figure(figsize=(20, 10))

   # 원본 이미지 선택
   sample_class = random.choice(os.listdir(TRAIN_DIR))
   image_path = os.path.join(TRAIN_DIR, random.choice([f for f in os.listdir(os.path.join(TRAIN_DIR, sample_class))
                                                      if f.endswith('.jpg')]))
   original_img = cv2.imread(image_path)

   # 원본 이미지 배경 제거 및 중앙 정렬
   bg_removed = remove_background(original_img.copy())
   centered = center_on_black(bg_removed)

   # 첫 번째 줄: 원본과 기본 증강
   plt.subplot(2, 5, 1)
   plt.imshow(cv2.cvtColor(centered, cv2.COLOR_BGR2RGB))
   plt.title("Original")
   plt.axis('off')

   # 기본 증강 4개 표시
   for i in range(4):
       plt.subplot(2, 5, i+2)
       # 원본에 기본 증강 적용
       augmented = basic_transforms(original_img.copy())
       # 증강된 이미지 배경 제거 및 중앙 정렬
       aug_bg_removed = remove_background(augmented)
       aug_centered = center_on_black(aug_bg_removed)
       plt.imshow(cv2.cvtColor(aug_centered, cv2.COLOR_BGR2RGB))
       plt.title(f"Basic Aug {i+1}")
       plt.axis('off')

   # 두 번째 줄: 고급 증강
   for i in range(5):
       plt.subplot(2, 5, i+6)
       # 배경이 제거된 이미지에 고급 증강 적용
       augmented = advanced_transforms(centered.copy())
       plt.imshow(cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB))
       plt.title(f"Advanced Aug {i+1}")
       plt.axis('off')

   plt.tight_layout()
   plt.show()

if __name__ == "__main__":
   print("Creating directories...")
   create_directories()

   print("\nPreprocessing and splitting training data...")
   preprocess_and_split()

   print("\nProcessing test data...")
   process_test_images()

   print("\nShowing augmentation examples...")
   visualize_augmentations()

   print("Done!")

Creating directories...

Preprocessing and splitting training data...
Maximum samples per class in training set: 60

Processing banana...


Loading original images: 100%|██████████| 58/58 [00:01<00:00, 37.64it/s]
Basic augmentation: 100%|██████████| 2/2 [00:00<00:00, 74.01it/s]
Advanced augmentation:   0%|          | 0/60 [00:00<?, ?it/s]


error: OpenCV(4.10.0) /io/opencv/modules/imgproc/src/imgwarp.cpp:3624: error: (-215:Assertion failed) src.checkVector(2, CV_32F) == 4 && dst.checkVector(2, CV_32F) == 4 in function 'getPerspectiveTransform'


In [None]:
import os, cv2, numpy as np
from pathlib import Path
import shutil
from tqdm import tqdm
import matplotlib.pyplot as plt
import random

# 경로 설정
TRAIN_DIR = '/tmp/train_zip/train'
TEST_DIR = '/tmp/test_zip/test'
PREPROCESS_DIR = '/tmp/preprocess_data'
TARGET_SIZE = 256

random.seed(42)  # 재현성 있는 랜덤

def create_directories():
    if os.path.exists(PREPROCESS_DIR):
        shutil.rmtree(PREPROCESS_DIR)
    os.makedirs(os.path.join(PREPROCESS_DIR, 'train'), exist_ok=True)
    os.makedirs(os.path.join(PREPROCESS_DIR, 'test'), exist_ok=True)

def remove_background(image):
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    # 컬러 마스크 범위 (느슨하게)
    masks = [
        cv2.inRange(hsv, np.array([0, 20, 20]), np.array([15, 255, 255])),   # red1
        cv2.inRange(hsv, np.array([165, 20, 20]), np.array([180, 255, 255])),# red2
        cv2.inRange(hsv, np.array([15, 20, 20]), np.array([40, 255, 255])),  # yellow
        cv2.inRange(hsv, np.array([5, 20, 20]), np.array([25, 255, 255]))    # orange
    ]
    color_mask = masks[0]
    for mask in masks[1:]:
        color_mask = cv2.bitwise_or(color_mask, mask)

    # 엄격한 엣지 검출
    edges = cv2.Canny(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), 100, 200)
    final_mask = cv2.bitwise_or(color_mask, edges)

    kernel = np.ones((5, 5), np.uint8)
    final_mask = cv2.morphologyEx(final_mask, cv2.MORPH_CLOSE, kernel)
    final_mask = cv2.dilate(final_mask, kernel, iterations=1)

    return cv2.bitwise_and(image, image, mask=final_mask)

def center_on_black(image):
    black_bg = np.zeros((TARGET_SIZE, TARGET_SIZE, 3), dtype=np.uint8)
    h, w = image.shape[:2]
    scale = min(TARGET_SIZE / h, TARGET_SIZE / w) * 0.8
    new_h, new_w = int(h * scale), int(w * scale)
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
    y_offset = (TARGET_SIZE - new_h) // 2
    x_offset = (TARGET_SIZE - new_w) // 2
    black_bg[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized
    return black_bg

# mixed를 max_samples까지 늘릴 때 사용할 특수 증강
def special_augment_for_mixed(image):
    h, w = image.shape[:2]
    center = (w // 2, h // 2)
    transforms = [
        lambda img: cv2.GaussianBlur(img, (5, 5), 1.5),
        lambda img: cv2.resize(img, None, fx=np.random.uniform(1.3, 1.7), fy=np.random.uniform(1.3, 1.7)),
        lambda img: cv2.warpAffine(img, cv2.getRotationMatrix2D(center, np.random.uniform(-15, 15), 1.0), (w, h))
    ]
    return random.choice(transforms)(image)

# 다른 클래스가 부족할 경우 사용할 단순 증강
def simple_augment(image):
    transforms = [
        lambda img: cv2.flip(img, np.random.choice([-1, 0, 1])),
        lambda img: cv2.convertScaleAbs(img, alpha=np.random.uniform(0.8, 1.2), beta=np.random.randint(-30, 30)),
        lambda img: cv2.GaussianBlur(img, (3, 3), 0),
        lambda img: cv2.resize(img, None, fx=np.random.uniform(1.1,1.3), fy=np.random.uniform(1.1,1.3))
    ]
    return random.choice(transforms)(image)

# 최종 증강 기법: 여러개 중 5개를 랜덤으로 선택, 각각 하나씩 적용
def final_augment_image(image):
    h, w = image.shape[:2]
    center = (w // 2, h // 2)
    transforms = {
        'rotation': lambda img: cv2.warpAffine(img, cv2.getRotationMatrix2D(center, np.random.randint(-30, 30), 1.0), (w, h)),
        'brightness': lambda img: cv2.convertScaleAbs(img, alpha=np.random.uniform(0.7, 1.3), beta=np.random.randint(-50, 50)),
        'flip': lambda img: cv2.flip(img, np.random.choice([-1, 0, 1])),
        'contrast': lambda img: cv2.convertScaleAbs(img, alpha=np.random.uniform(0.5, 1.5)),
        'zoom': lambda img: cv2.resize(img, None, fx=np.random.uniform(1.2, 1.5), fy=np.random.uniform(1.2, 1.5)),
        'blur': lambda img: cv2.GaussianBlur(img, (5, 5), 0),
        'shear': lambda img: cv2.warpAffine(img, cv2.getRotationMatrix2D(center, np.random.uniform(-10, 10), 1.0), (w, h))
    }

    # 5개의 랜덤 증강 기법 선택
    chosen_keys = random.sample(list(transforms.keys()), 5)
    augmented_results = []
    for key in chosen_keys:
        aug_img = transforms[key](image)
        augmented_results.append((key, aug_img))
    return augmented_results

def preprocess_and_select(train=True):
    input_dir = TRAIN_DIR if train else TEST_DIR
    output_dir = os.path.join(PREPROCESS_DIR, 'train' if train else 'test')
    os.makedirs(output_dir, exist_ok=True)

    # 원본 이미지 카운트
    class_counts = {}
    for filename in os.listdir(input_dir):
        if filename.endswith('.jpg'):
            class_name = filename.split('_')[0]
            class_counts[class_name] = class_counts.get(class_name, 0) + 1

    max_samples = max(class_counts.values())
    print(f"{'Training' if train else 'Test'} max samples per class: {max_samples}")

    for class_name, count in class_counts.items():
        class_dir = os.path.join(output_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)

        class_files = [f for f in os.listdir(input_dir) if f.startswith(class_name) and f.endswith('.jpg')]
        class_files.sort()

        # 만약 count > max_samples면 max_samples까지만 사용 (학습 데이터일 경우)
        if train and count > max_samples:
            class_files = class_files[:max_samples]

        processed_images = []
        # 전처리 진행
        for filename in tqdm(class_files, desc=f"Preprocessing {class_name}", leave=False):
            image_path = os.path.join(input_dir, filename)
            image = cv2.imread(image_path)
            if image is None:
                continue
            processed = remove_background(image)
            processed = center_on_black(processed)
            processed_images.append(processed)

        # 부족한 경우 증강해서 max_samples 맞추기
        if train:
            if class_name == 'mixed' and len(processed_images) < max_samples:
                needed = max_samples - len(processed_images)
                for _ in range(needed):
                    base_img = random.choice(processed_images)
                    processed_images.append(special_augment_for_mixed(base_img))
            elif class_name != 'mixed' and len(processed_images) < max_samples:
                needed = max_samples - len(processed_images)
                for _ in range(needed):
                    base_img = random.choice(processed_images)
                    processed_images.append(simple_augment(base_img))

            # 최종적으로 max_samples 만족
            assert len(processed_images) == max_samples, f"{class_name} does not have {max_samples} images after preprocessing."

        # test일 경우는 증강 없음, 그냥 처리된 이미지만 저장
        # train일 경우 max_samples개 이미지 저장 (증강 전 상태)
        for i, img in enumerate(processed_images):
            save_name = f"{class_name}_{i+1}.jpg"
            cv2.imwrite(os.path.join(class_dir, save_name), img)

def final_augmentation():
    """
    모든 학습 클래스에 대해:
    이미 각 클래스당 max_samples개 이미지를 갖춤.
    각 이미지당 5개 증강 기법을 랜덤 선택, 각각 1장씩 총 5장 증강 이미지 생성
    원본 1장 + 증강 5장 = 총 6장
    """
    train_dir = os.path.join(PREPROCESS_DIR, 'train')
    for class_name in os.listdir(train_dir):
        class_path = os.path.join(train_dir, class_name)
        if not os.path.isdir(class_path):
            continue
        images = [f for f in os.listdir(class_path) if f.endswith('.jpg') and '_aug' not in f]
        images.sort()

        for filename in tqdm(images, desc=f"Final Augmenting {class_name}", leave=False):
            img_path = os.path.join(class_path, filename)
            img = cv2.imread(img_path)
            if img is None:
                continue
            aug_imgs = final_augment_image(img)
            base_name = os.path.splitext(filename)[0]

            # 선택된 증강 기법 출력 (옵션)
            chosen_transforms = [t[0] for t in aug_imgs]
            print(f"{filename} -> chosen transforms: {chosen_transforms}")

            for (transform_name, aimg) in aug_imgs:
                aug_path = os.path.join(class_path, f"{base_name}_aug_{transform_name}.jpg")
                cv2.imwrite(aug_path, aimg)

    # 최종 개수 출력
    for class_name in os.listdir(train_dir):
        cpath = os.path.join(train_dir, class_name)
        if os.path.isdir(cpath):
            final_count = len([f for f in os.listdir(cpath) if f.endswith('.jpg')])
            print(f"Class {class_name} final count: {final_count}")

def plot_examples_for_mixed():
    class_dir = os.path.join(PREPROCESS_DIR, 'train', 'mixed')
    if not os.path.exists(class_dir):
        print("No mixed class directory found.")
        return
    images = [f for f in os.listdir(class_dir) if f.endswith('.jpg') and '_aug' not in f]
    if not images:
        print("No images found in mixed class.")
        return
    img_path = os.path.join(class_dir, images[0])
    img = cv2.imread(img_path)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    aug_imgs = final_augment_image(img)

    plt.figure(figsize=(15, 3))
    plt.subplot(1, len(aug_imgs) + 1, 1)
    plt.imshow(img_rgb)
    plt.title("Original")
    plt.axis('off')

    for i, (name, aimg) in enumerate(aug_imgs, 1):
        plt.subplot(1, len(aug_imgs) + 1, i+1)
        plt.imshow(cv2.cvtColor(aimg, cv2.COLOR_BGR2RGB))
        plt.title(name)  # 증강 기법 이름 표시
        plt.axis('off')
    plt.show()

if __name__ == "__main__":
    create_directories()

    print("\nPreprocessing training images...")
    preprocess_and_select(train=True)

    print("\nPreprocessing test images (no augmentation)...")
    preprocess_and_select(train=False)

    print("\nFinal augmentation for training images...")
    final_augmentation()

    print("\nShowing examples for mixed class final augmentation...")
    plot_examples_for_mixed()

