In [1]:
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import Adam

In [2]:
class AgePredictionModel:
    """CNN regressor that predicts a single age value from an image.

    The network is four Conv2D + MaxPooling2D stages followed by a small MLP
    head with dropout. It is compiled with the Adam optimizer and a
    mean-squared-error loss, and it tracks mean absolute error ('mae').
    """

    def __init__(self, input_shape=(64, 64, 3)):
        """Build and compile the model.

        Args:
            input_shape: Shape of one input sample, excluding the batch axis.
                Defaults to (64, 64, 3); presumably (height, width, channels)
                -- TODO confirm against the data pipeline.
        """
        self.input_shape = tuple(input_shape)
        # Populated by train(); None until the model has been fitted.
        self.history = None
        self.model = self._build_model(self.input_shape)

    def _build_model(self, input_shape):
        """Create and compile the CNN + MLP regression network.

        Args:
            input_shape: Per-sample input shape (no batch axis).

        Returns:
            A compiled Keras Sequential model with a single output unit.
        """
        model = models.Sequential([
            # An explicit Input layer is the form Keras recommends for
            # Sequential models; passing input_shape= to the first layer
            # triggers a warning in Keras 3.
            layers.Input(shape=input_shape),
            layers.Conv2D(32, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(128, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(256, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            layers.Flatten(),
            layers.Dense(256, activation='relu'),
            layers.Dropout(0.5),
            layers.Dense(128, activation='relu'),
            layers.Dropout(0.5),
            # ReLU keeps the predicted age non-negative. A ReLU output unit
            # gets no gradient while its pre-activation is negative, so a
            # linear output is a common alternative if training stalls at a
            # constant zero prediction.
            layers.Dense(1, activation='relu'),
        ])
        model.compile(optimizer=Adam(), loss='mean_squared_error', metrics=['mae'])
        return model

    def check_model(self):
        """Print the Keras summary of the underlying model."""
        self.model.summary()

    def train(self, x_train, y_train, batch_size=128, epochs=30, validation_split=0.2):
        """Fit the model and remember the resulting training history.

        Args:
            x_train: Training inputs passed straight to ``Model.fit``.
            y_train: Training targets (ages) passed straight to ``Model.fit``.
            batch_size: Samples per gradient update.
            epochs: Number of passes over the training data.
            validation_split: Fraction of the training data held out for
                validation; 0 disables validation metrics.

        Returns:
            The Keras ``History`` object returned by ``Model.fit``.
        """
        self.history = self.model.fit(
            x_train, y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=validation_split,
        )
        return self.history

    @staticmethod
    def _ensure_batch(image, sample_ndim=3):
        """Return *image* as an array with a leading batch axis.

        An input whose rank equals ``sample_ndim`` (a single sample) is
        expanded to a batch of one. Any other input is converted to an array
        and returned with its shape unchanged.
        """
        import numpy as np  # local import: numpy is not imported at file level

        array = np.asarray(image)
        if array.ndim == sample_ndim:
            array = array[np.newaxis, ...]
        return array

    def predict(self, image):
        """Print and return the rounded predicted age.

        Args:
            image: Either a single sample of shape ``input_shape`` or a batch
                whose first axis is the batch dimension. A single sample is
                given a batch axis before being passed to the model.

        Returns:
            The rounded prediction for the first sample, as a Python int.
        """
        batch = self._ensure_batch(image, len(self.input_shape))
        pred = self.model.predict(batch)
        # Only the first sample is reported. float() + int() give a plain
        # Python int regardless of how numpy scalars implement round().
        age = int(round(float(pred[0][0])))
        print(f"Predicted Age: {age}")
        return age

    def plot_training_history(self):
        """Plot training (and, if available, validation) MAE per epoch.

        Raises:
            RuntimeError: If ``train()`` has not been called yet.
        """
        history = getattr(self, 'history', None)
        if history is None:
            raise RuntimeError("No training history available; call train() first.")
        metrics = history.history
        plt.plot(metrics['mae'], label='Train MAE')
        # 'val_mae' only exists when training used a validation split.
        if 'val_mae' in metrics:
            plt.plot(metrics['val_mae'], label='Validation MAE')
        plt.xlabel('Epochs')
        plt.ylabel('Mean Absolute Error')
        plt.legend()
        plt.show()