# In [None]:
import os
import random
import fnmatch
import datetime
import pickle
import numpy as np
import cv2
from imgaug import augmenters as img_aug
import matplotlib.pyplot as plt
from PIL import Image
from google.colab import files
uploaded = files.upload()

# Move the uploaded files into the 'data' directory, which is where
# FashionStyleClassifier looks for training images.
import shutil
os.makedirs('data', exist_ok=True)
for filename in uploaded:
    shutil.move(filename, os.path.join('data', filename))

class FashionStyleClassifier:
    """Train a small CNN that classifies fashion photos by style label.

    Training images are ``*.jpg`` files in a flat directory. Each file's style
    label is the third ``_``-separated field of its name with the extension
    removed (e.g. ``"0001_user_casual.jpg"`` -> ``"casual"``).
    """

    def __init__(self, data_dir='data', model_output_dir='output', valid_ratio=0.2):
        """Index the images in ``data_dir`` and split them into train/validation sets.

        Args:
            data_dir: Directory scanned (non-recursively) for ``*.jpg`` files.
            model_output_dir: Where checkpoints, the final model and the
                training history are written; created if it does not exist.
            valid_ratio: Fraction of the images held out for validation; must
                be strictly between 0 and 1.

        Raises:
            ValueError: If ``valid_ratio`` is out of range or fewer than two
                correctly named images are found.
        """
        if not 0 < valid_ratio < 1:
            raise ValueError("valid_ratio must be between 0 and 1, got %r" % (valid_ratio,))

        pattern = "*.jpg"
        self.model_output_dir = model_output_dir
        os.makedirs(self.model_output_dir, exist_ok=True)

        image_paths = []
        labels = []
        # Sorted so the file order (and therefore a seeded split) is reproducible.
        for filename in sorted(os.listdir(data_dir)):
            if not fnmatch.fnmatch(filename, pattern):
                continue
            # Strip the extension first so a trailing style field is not read
            # as e.g. "casual.jpg".
            fields = os.path.splitext(filename)[0].split('_')
            if len(fields) < 3:
                print("Skipping %s: expected at least three '_'-separated fields" % filename)
                continue
            image_paths.append(os.path.join(data_dir, filename))
            labels.append(fields[2])  # style label

        if len(image_paths) < 2:
            raise ValueError(
                "Need at least two labelled *.jpg images in %r, found %d"
                % (data_dir, len(image_paths)))

        # Fixed, sorted class order so each label maps to a stable one-hot index.
        self.class_names = sorted(set(labels))
        self.class_to_index = {name: index for index, name in enumerate(self.class_names)}

        # Random hold-out split: ceil(valid_ratio * n) validation images,
        # clamped so that neither the training nor the validation set is empty.
        order = list(range(len(image_paths)))
        random.shuffle(order)
        n_valid = int(np.ceil(len(order) * valid_ratio))
        n_valid = min(max(n_valid, 1), len(order) - 1)
        valid_idx, train_idx = order[:n_valid], order[n_valid:]
        self.X_train = [image_paths[i] for i in train_idx]
        self.y_train = [labels[i] for i in train_idx]
        self.X_valid = [image_paths[i] for i in valid_idx]
        self.y_valid = [labels[i] for i in valid_idx]
        print("Training data: %d\nValidation data: %d" % (len(self.X_train), len(self.X_valid)))

    def random_augment(self, image):
        """Apply pan, zoom, blur and brightness changes, each independently with probability 0.5."""
        if np.random.rand() < 0.5:
            image = self.pan(image)
        if np.random.rand() < 0.5:
            image = self.zoom(image)
        if np.random.rand() < 0.5:
            image = self.blur(image)
        if np.random.rand() < 0.5:
            image = self.adjust_brightness(image)
        return image

    def my_imread(self, image_path):
        """Read an image file with OpenCV and return it in RGB channel order.

        Raises:
            OSError: If OpenCV cannot read the file (``cv2.imread`` returns None).
        """
        image = cv2.imread(image_path)
        if image is None:
            # Without this check cvtColor fails with an unhelpful assertion error.
            raise OSError("Could not read image: %s" % image_path)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def zoom(self, image):
        """Randomly scale the image up by a factor in [1.0, 1.3]."""
        zoom = img_aug.Affine(scale=(1, 1.3))
        return zoom.augment_image(image)

    def pan(self, image):
        """Randomly translate the image by up to 10% along each axis."""
        pan = img_aug.Affine(translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)})
        return pan.augment_image(image)

    def adjust_brightness(self, image):
        """Randomly multiply pixel intensities by a factor in [0.7, 1.3]."""
        brightness = img_aug.Multiply((0.7, 1.3))
        return brightness.augment_image(image)

    def blur(self, image):
        """Box-blur the image with a random square kernel of size 1-5 (1 is a no-op)."""
        kernel_size = random.randint(1, 5)
        return cv2.blur(image, (kernel_size, kernel_size))

    def img_preprocess(self, image):
        """Turn an RGB image into the (66, 200, 3) float input the model expects.

        Steps: keep the bottom half of the frame, convert RGB to YUV, apply a
        3x3 Gaussian blur, resize to 200x66 (width x height) and divide by 255.
        """
        height, _, _ = image.shape
        # NOTE(review): keeping only the bottom half discards the upper half of
        # the photo; confirm this crop is intended for fashion images.
        image = image[int(height / 2):, :, :]
        image = cv2.cvtColor(image, cv2.COLOR_RGB2YUV)
        image = cv2.GaussianBlur(image, (3, 3), 0)
        image = cv2.resize(image, (200, 66))  # cv2 takes (width, height)
        return image / 255

    def resnet18_model(self, num_classes):
        """Build and compile the classification CNN with a ``num_classes``-way softmax.

        NOTE(review): despite its name this is a small plain CNN without
        residual blocks, not a ResNet-18; the name is kept for compatibility.
        Sequential, Conv2D, MaxPooling2D, Flatten, Dense, Dropout and Adam are
        not imported in this file. They are presumably from tensorflow.keras;
        confirm they are imported in another cell.
        """
        model = Sequential(name='ResNet18_Fashion_Classifier')
        model.add(Conv2D(64, (7, 7), strides=(2, 2), input_shape=(66, 200, 3), activation='relu'))
        model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
        model.add(Conv2D(128, (3, 3), activation='relu'))
        model.add(Flatten())
        model.add(Dense(512, activation='relu'))
        model.add(Dropout(0.5))
        model.add(Dense(num_classes, activation='softmax'))
        # `lr` is a deprecated alias that newer Keras versions reject.
        optimizer = Adam(learning_rate=1e-4)
        model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
        return model

    def _one_hot(self, batch_labels):
        """Encode style labels as a float32 (len(batch_labels), n_classes) one-hot matrix."""
        encoded = np.zeros((len(batch_labels), len(self.class_names)), dtype=np.float32)
        for row, label in enumerate(batch_labels):
            encoded[row, self.class_to_index[label]] = 1.0
        return encoded

    def image_data_generator(self, image_paths, labels, batch_size, is_training):
        """Yield ``(images, one_hot_labels)`` batches forever.

        Each batch samples ``batch_size`` images uniformly with replacement.
        Training batches are randomly augmented before preprocessing. Labels
        are one-hot encoded, as ``categorical_crossentropy`` requires.
        """
        while True:
            batch_images = []
            batch_labels = []
            for _ in range(batch_size):
                random_index = random.randint(0, len(image_paths) - 1)
                image = self.my_imread(image_paths[random_index])
                if is_training:
                    image = self.random_augment(image)
                batch_images.append(self.img_preprocess(image))
                batch_labels.append(labels[random_index])
            yield np.asarray(batch_images), self._one_hot(batch_labels)

    def train(self, num_classes=None):
        """Build the model, train it from the generators and save the results.

        Args:
            num_classes: Size of the softmax output. Defaults to the number of
                distinct style labels found in the data. If given, it must
                match that number, because labels are one-hot encoded over
                ``self.class_names``.

        Raises:
            ValueError: If ``num_classes`` disagrees with the labels found.

        Writes ``best_model.h5`` (best checkpoint), ``final_model.h5`` and
        ``history.pickle`` into ``self.model_output_dir``. NOTE(review):
        ModelCheckpoint is not imported in this file either; confirm.
        """
        if num_classes is None:
            num_classes = len(self.class_names)
        elif num_classes != len(self.class_names):
            raise ValueError(
                "num_classes=%d does not match the %d style labels found: %s"
                % (num_classes, len(self.class_names), self.class_names))

        model = self.resnet18_model(num_classes)
        model.summary()  # prints itself and returns None
        checkpoint_callback = ModelCheckpoint(
            filepath=os.path.join(self.model_output_dir, 'best_model.h5'),
            verbose=1,
            save_best_only=True,
        )

        history = model.fit(
            self.image_data_generator(self.X_train, self.y_train, batch_size=32, is_training=True),
            steps_per_epoch=200,
            epochs=50,
            validation_data=self.image_data_generator(self.X_valid, self.y_valid, batch_size=32, is_training=False),
            validation_steps=100,
            verbose=1,
            callbacks=[checkpoint_callback]
        )

        # Save the final model.
        model.save(os.path.join(self.model_output_dir, 'final_model.h5'))

        # Save the per-epoch training metrics.
        history_path = os.path.join(self.model_output_dir, 'history.pickle')
        with open(history_path, 'wb') as f:
            pickle.dump(history.history, f, pickle.HIGHEST_PROTOCOL)

if __name__ == '__main__':
    classifier = FashionStyleClassifier()
    # One softmax output per distinct style label found in the data.
    num_classes = len(set(classifier.y_train) | set(classifier.y_valid))
    classifier.train(num_classes)
    print("Training finished!")
