# In [None]:
import os
import random
import fnmatch
import cv2
import numpy as np
import datetime
import pickle
import imgaug.augmenters as img_aug


class FashionModel:
    """Train a small CNN that classifies fashion images by category.

    ``*.png`` images are read from a data directory; each image's integer
    category is parsed from its file name (the token after the first
    underscore, e.g. ``img_1.png`` -> 1). The samples are split 80/20 into
    training and validation sets and fed to Keras through an infinite batch
    generator, with random augmentation on the training side.

    NOTE(review): ``train_test_split`` and the Keras names ``Sequential``,
    ``Conv2D``, ``Dropout``, ``Flatten``, ``Dense``, ``Adam`` and
    ``ModelCheckpoint`` are not imported in this cell; presumably an earlier
    notebook cell imports them (scikit-learn / Keras) — confirm.
    """

    def __init__(self, data_dir='fashion_data', model_output_dir='fashion_output'):
        """Load labelled image paths from *data_dir* and split them.

        Args:
            data_dir: Directory holding the ``*.png`` images.
            model_output_dir: Directory for checkpoints, the final model and
                the pickled training history; created if it does not exist.

        Raises:
            ValueError: if no image with a parseable category is found.
        """
        image_paths, fashion_labels = self._load_dataset(data_dir)
        if not image_paths:
            raise ValueError(f"No labelled *.png images found in {data_dir!r}")

        self.model_output_dir = model_output_dir
        os.makedirs(self.model_output_dir, exist_ok=True)

        # Split into training / validation sets (80 / 20).
        self.X_train, self.X_valid, self.y_train, self.y_valid = train_test_split(
            image_paths, fashion_labels, test_size=0.2)
        print("Training data: %d\nValidation data: %d" %
              (len(self.X_train), len(self.X_valid)))

    @staticmethod
    def _parse_category(filename):
        """Return the integer category encoded in *filename*, or None.

        The extension is removed first, then the token after the first
        underscore is parsed, so ``'img_1.png'`` -> 1 and
        ``'shirt_7_front.png'`` -> 7. Returns None when there is no such
        token or it is not an integer.
        """
        stem = os.path.splitext(filename)[0]
        parts = stem.split('_')
        if len(parts) < 2:
            return None
        try:
            return int(parts[1])
        except ValueError:
            return None

    @classmethod
    def _load_dataset(cls, data_dir, pattern="*.png"):
        """Return aligned ``(image_paths, labels)`` lists for *data_dir*.

        Files matching *pattern* whose category cannot be parsed are reported
        and skipped; a path is only recorded together with its label, so the
        two lists always have the same length.
        """
        image_paths = []
        fashion_labels = []
        # os.listdir order is arbitrary; sort so the dataset order is stable.
        for filename in sorted(os.listdir(data_dir)):
            if not fnmatch.fnmatch(filename, pattern):
                continue
            category = cls._parse_category(filename)
            if category is None:
                print(f"Error parsing category from {filename}")
                continue
            image_paths.append(os.path.join(data_dir, filename))
            fashion_labels.append(category)
        return image_paths, fashion_labels

    def random_augment(self, image):
        """Apply each of pan, zoom, blur and brightness with probability 0.5."""
        if np.random.rand() < 0.5:
            image = self.pan(image)
        if np.random.rand() < 0.5:
            image = self.zoom(image)
        if np.random.rand() < 0.5:
            image = self.blur(image)
        if np.random.rand() < 0.5:
            image = self.adjust_brightness(image)
        return image

    def my_imread(self, image_path):
        """Read *image_path* with OpenCV and return it in RGB channel order.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (``cv2.imread``
                signals failure by returning None rather than raising).
        """
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV loads BGR; convert so the model always sees RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def zoom(self, image):
        """Randomly scale the image up by a factor in [1, 1.3]."""
        zoom = img_aug.Affine(scale=(1, 1.3))
        return zoom.augment_image(image)

    def pan(self, image):
        """Randomly translate the image by up to ±10% on each axis."""
        pan = img_aug.Affine(translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)})
        return pan.augment_image(image)

    def adjust_brightness(self, image):
        """Randomly multiply pixel intensities by a factor in [0.7, 1.3]."""
        brightness = img_aug.Multiply((0.7, 1.3))
        return brightness.augment_image(image)

    def blur(self, image):
        """Box-blur with a random square kernel of size 1-5 (1 is a no-op)."""
        kernel_size = random.randint(1, 5)
        return cv2.blur(image, (kernel_size, kernel_size))

    def img_preprocess(self, image):
        """Resize to 128x128 and scale pixel values by 1/255.

        For uint8 input this maps pixel values into [0, 1].
        """
        image = cv2.resize(image, (128, 128))
        return image / 255.0

    def fashion_model(self, num_classes=10):
        """Build and compile a small CNN for ``num_classes``-way classification.

        Args:
            num_classes: Size of the softmax output layer (defaults to the
                assumed 10 fashion categories).

        Labels are integer class ids, as produced by ``image_data_generator``,
        so the model uses sparse categorical cross-entropy. The ids must lie
        in ``[0, num_classes)``.
        NOTE(review): confirm the file-name categories are 0-based.
        """
        model = Sequential(name='Fashion_Model')
        model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(128, 128, 3)))
        model.add(Conv2D(64, (3, 3), activation='relu'))
        model.add(Dropout(0.25))
        # NOTE(review): there is no pooling, so with Keras' default 'valid'
        # padding Flatten yields 124*124*64 features and the next Dense layer
        # has ~126M weights; consider adding MaxPooling2D.
        model.add(Flatten())
        model.add(Dense(128, activation='relu'))
        model.add(Dropout(0.5))
        model.add(Dense(num_classes, activation='softmax'))
        # `learning_rate` replaces the deprecated `lr` argument.
        optimizer = Adam(learning_rate=1e-3)
        model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer,
                      metrics=['accuracy'])
        return model

    def image_data_generator(self, image_paths, labels, batch_size, is_training):
        """Yield ``(images, labels)`` numpy batches forever.

        Each batch draws ``batch_size`` samples uniformly at random with
        replacement. Images are augmented only when ``is_training`` is true,
        then passed through ``img_preprocess``. ``image_paths`` must be
        non-empty.
        """
        while True:
            batch_images = []
            batch_labels = []
            for _ in range(batch_size):
                random_index = random.randint(0, len(image_paths) - 1)
                image = self.my_imread(image_paths[random_index])
                if is_training:
                    image = self.random_augment(image)
                batch_images.append(self.img_preprocess(image))
                batch_labels.append(labels[random_index])
            yield np.asarray(batch_images), np.asarray(batch_labels)

    def train(self):
        """Train the CNN and save its outputs in ``model_output_dir``.

        The outputs are:
        - the best-``val_loss`` checkpoints, saved per epoch;
        - the final model;
        - the pickled ``history.history`` dict.
        """
        model = self.fashion_model()
        # summary() prints itself and returns None; don't wrap it in print().
        model.summary()

        # Save a checkpoint whenever the monitored validation metric improves.
        checkpoint_callback = ModelCheckpoint(
            filepath=os.path.join(self.model_output_dir, 'model-{epoch:03d}.h5'),
            verbose=1, save_best_only=True)

        # Train. The generator samples randomly itself, so `shuffle` (which
        # Keras ignores for generator input) is not passed.
        history = model.fit(
            self.image_data_generator(self.X_train, self.y_train, batch_size=64, is_training=True),
            steps_per_epoch=200,
            epochs=50,
            validation_data=self.image_data_generator(self.X_valid, self.y_valid,
                                                      batch_size=32, is_training=False),
            validation_steps=50,
            verbose=1,
            callbacks=[checkpoint_callback]
        )

        model.save(os.path.join(self.model_output_dir, 'final_fashion_model.h5'))

        # Persist the per-epoch metrics for later plotting/analysis.
        history_path = os.path.join(self.model_output_dir, 'history.pickle')
        with open(history_path, 'wb') as f:
            pickle.dump(history.history, f, pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    # Script entry point: constructing the trainer loads and splits the data,
    # then train() fits the model and writes its artifacts.
    # `fashion_lab` stays module-level so later notebook cells can reuse it.
    fashion_lab = FashionModel()
    fashion_lab.train()
    print("Fashion model training finished!")
