In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
class UTKFaceDataset:
    """Load age-labelled face images and turn them into model-ready arrays.

    File names are expected to start with the age followed by an underscore
    (``<age>_...``); only that leading field is parsed.

    Attributes:
        dataset_dir: Directory that is scanned for image files.
        image_size: ``(width, height)`` handed to ``PIL.Image.resize``.
        data: One ``{'image': <PIL image>, 'age': <int>}`` record per file
            that was loaded successfully.
        df: DataFrame built from ``data`` (columns ``image`` and ``age``);
            ``None`` until it is first needed.
    """

    def __init__(self, dataset_dir, image_size=(64, 64)):
        """Store the configuration; no files are touched here."""
        self.dataset_dir = dataset_dir
        self.image_size = image_size
        self.data = []
        self.df = None

    def _build_frame(self):
        """Return ``self.data`` as a DataFrame with fixed columns.

        Passing ``columns`` explicitly keeps the ``image`` / ``age`` columns
        present even when no record was loaded.
        """
        return pd.DataFrame(self.data, columns=['image', 'age'])

    def load_data(self):
        """Read every image in ``dataset_dir`` and append it to ``self.data``.

        Files whose name does not start with an integer age, and files that
        cannot be opened as images, are reported on stdout and skipped.
        Records are appended, so calling this again adds to what is already
        loaded.
        """
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # sample order reproducible between runs and machines.
        for filename in sorted(os.listdir(self.dataset_dir)):
            try:
                age = int(filename.split('_')[0])  # leading field is the age
            except ValueError:
                print(f"Skipping invalid file: {filename}")
                continue

            image_path = os.path.join(self.dataset_dir, filename)
            try:
                with Image.open(image_path) as img:
                    # convert() returns a new, fully loaded image, so it is
                    # safe to use after the file is closed. Forcing RGB gives
                    # every sample three channels; a single grayscale or RGBA
                    # file would otherwise make the stacked array ragged.
                    image = img.convert('RGB')
            except Exception as e:  # best effort: one bad file must not abort the scan
                print(f"Error loading image {filename}: {e}")
                continue

            self.data.append({'image': image, 'age': age})

        # Any previously built frame no longer reflects self.data.
        self.df = None

    def preprocess_data(self):
        """Resize and scale the loaded images.

        Returns:
            A tuple ``(images, labels)``. ``images`` is a float32 array with
            one resized image per record, scaled by 1/255; ``labels`` holds
            the matching ages. When nothing has been loaded, ``images`` has
            shape ``(0, height, width, 3)`` and ``labels`` has shape ``(0,)``.
        """
        self.df = self._build_frame()
        width, height = self.image_size

        if self.df.empty:
            empty_images = np.empty((0, height, width, 3), dtype='float32')
            empty_labels = np.empty((0,), dtype='int64')
            return empty_images, empty_labels

        images = np.array(
            [np.asarray(img.resize(self.image_size)) for img in self.df['image']]
        )
        # Cast first so the division runs in float32 instead of creating a
        # float64 copy of the whole dataset that is thrown away right after.
        images = images.astype('float32') / 255.0
        labels = np.array(self.df['age'])
        return images, labels

    def visualize_data(self):
        """Plot a histogram (with KDE curve) of the number of images per age."""
        if self.df is None:
            # Allow plotting straight after load_data(), without requiring
            # preprocess_data() to have been called first.
            self.df = self._build_frame()

        plt.figure(figsize=(8, 6))
        sns.histplot(self.df['age'], kde=True)
        plt.xlabel("age")
        plt.ylabel("Number of images")
        plt.show()