In [1]:
# `IPython.display` is the public home of these helpers; importing `display`
# from `IPython.core.display` is deprecated in recent IPython releases.
from IPython.display import display, HTML

# Inject CSS that widens the notebook's `.container` element (classic Notebook UI).
display(HTML("<style>.container { width:140% !important; }</style>"))

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://github.com/philhoonoh/blog_git/blob/main/comp_dataset_2.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View Source</a>
  </td>
</table>

# PyTorch Dataset 심화 part2)

> PyTorch Dataset 심화 part1) 에 이은 part2) 입니다.   
> 해당 코드는 [https://boostcamp.connect.or.kr/](https://boostcamp.connect.or.kr/) 에서 참조했음을 알려드립니다.   
> MaskBaseDataset Class 를 활용하여 MaskStratifiedDataset, ThreeWayStratifiedDataset 등의 새로운 Dataset을 정의합니다.  

In [3]:
import os 
import random
from collections import defaultdict
from typing import Tuple, List

from enum import Enum

from PIL import Image
from torch.utils.data import Dataset, Subset, random_split
from torchvision import transforms
from torchvision.transforms import *
import numpy as np

In [6]:
class MaskLabels(int, Enum):
    """Mask-wearing label assigned from an image's file stem.

    Members are ``int`` subclasses, so they can be combined arithmetically
    into the 18-way multi-class label.
    """

    MASK = 0       # files named mask1 ... mask5
    INCORRECT = 1  # file named incorrect_mask
    NORMAL = 2     # file named normal
    

class GenderLabels(int, Enum):
    """Binary gender label parsed from the profile folder name."""

    MALE = 0
    FEMALE = 1

    @classmethod
    def from_str(cls, value: str) -> int:
        """Convert a case-insensitive ``"male"`` / ``"female"`` string to a member.

        Raises:
            ValueError: if the lower-cased string is neither accepted value.
        """
        value = value.lower()
        lookup = {"male": cls.MALE, "female": cls.FEMALE}
        if value not in lookup:
            raise ValueError(f"Gender value should be either 'male' or 'female', {value}")
        return lookup[value]


class AgeLabels(int, Enum):
    """Age bucket: under 30, 30 to 59, and 60 or older."""

    YOUNG = 0
    MIDDLE = 1
    OLD = 2

    @classmethod
    def from_number(cls, value: str) -> int:
        """Bucket an age given as a string or number.

        Raises:
            ValueError: if ``value`` cannot be converted with ``int()``.
        """
        try:
            age = int(value)
        except Exception:
            raise ValueError(f"Age value should be numeric, {value}")

        # Check from the oldest bucket down.
        if age >= 60:
            return cls.OLD
        if age >= 30:
            return cls.MIDDLE
        return cls.YOUNG

In [7]:
class MaskBaseDataset(Dataset):
    """Image dataset that yields ``(transformed_image, multi_class_label)``.

    Expected layout (as parsed by ``setup``):
    ``data_dir/<id>_<gender>_<race>_<age>/<stem>.<ext>``, where ``<stem>`` is a
    key of ``_file_names``. The label is ``mask * 6 + gender * 3 + age``
    (see ``encode_multi_class``), giving ``num_classes`` = 18 classes.
    A transform must be injected with ``set_transform`` before indexing.
    """

    num_classes = 3 * 2 * 3

    # Image file stem -> mask label; files with any other stem are skipped.
    _file_names = {
        "mask1": MaskLabels.MASK,
        "mask2": MaskLabels.MASK,
        "mask3": MaskLabels.MASK,
        "mask4": MaskLabels.MASK,
        "mask5": MaskLabels.MASK,
        "incorrect_mask": MaskLabels.INCORRECT,
        "normal": MaskLabels.NORMAL
    }

    def __init__(self, data_dir, mean=(0.548, 0.504, 0.479), std=(0.237, 0.247, 0.246), val_ratio=0.2):
        """
        Args:
            data_dir: directory containing one sub-folder per profile.
            mean, std: per-channel statistics on a 0-1 scale. If either is
                ``None``, both are estimated from the images (see ``calc_statistics``).
            val_ratio: fraction of samples assigned to the validation split.
        """
        self.data_dir = data_dir
        self.mean = mean
        self.std = std
        self.val_ratio = val_ratio

        # Per-instance storage. Declaring these as class attributes would make
        # every instance (and every subclass) append into the same shared lists.
        self.image_paths = []
        self.mask_labels = []
        self.gender_labels = []
        self.age_labels = []

        self.transform = None
        self.setup()
        self.calc_statistics()

    def setup(self):
        """Scan ``data_dir`` and fill the image-path and label lists in parallel."""
        for profile in os.listdir(self.data_dir):
            if profile.startswith("."):  # skip hidden entries
                continue

            img_folder = os.path.join(self.data_dir, profile)
            if not os.path.isdir(img_folder):  # a stray file here would make listdir fail
                continue

            for file_name in os.listdir(img_folder):
                stem, _ = os.path.splitext(file_name)
                if stem not in self._file_names:  # skip hidden / unrecognized files
                    continue

                img_path = os.path.join(img_folder, file_name)  # e.g. data_dir/000004_male_Asian_54/mask1.jpg
                mask_label = self._file_names[stem]

                # Folder name is "<id>_<gender>_<race>_<age>"; id and race are unused.
                _, gender, _, age = profile.split("_")
                gender_label = GenderLabels.from_str(gender)
                age_label = AgeLabels.from_number(age)

                self.image_paths.append(img_path)
                self.mask_labels.append(mask_label)
                self.gender_labels.append(gender_label)
                self.age_labels.append(age_label)

    def calc_statistics(self):
        """Estimate per-channel mean/std (0-1 scale) when either is missing.

        Uses at most the first 3000 images. Variance is computed as
        E[x^2] - E[x]^2 in raw 0-255 units, then the std is rescaled by 1/255.
        """
        has_statistics = self.mean is not None and self.std is not None
        if has_statistics:
            return

        print("[Warning] Calculating statistics... It can take a long time depending on your CPU machine")
        sums = []
        squared = []
        for image_path in self.image_paths[:3000]:
            # Close the file handle as soon as the pixels are loaded.
            with Image.open(image_path) as img:
                image = np.array(img).astype(np.int32)
            sums.append(image.mean(axis=(0, 1)))
            squared.append((image ** 2).mean(axis=(0, 1)))

        mean_raw = np.mean(sums, axis=0)
        mean_sq_raw = np.mean(squared, axis=0)
        # Keep both moments in raw units; clamp tiny negative round-off before sqrt.
        variance_raw = np.maximum(mean_sq_raw - mean_raw ** 2, 0.0)
        self.mean = mean_raw / 255
        self.std = np.sqrt(variance_raw) / 255

    def set_transform(self, transform):
        """Set the callable applied to every image in ``__getitem__``."""
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(self.transform(image), multi_class_label)`` for ``index``."""
        assert self.transform is not None, ".set_transform 메소드를 이용하여 transform 을 주입해주세요"

        image = self.read_image(index)
        mask_label = self.get_mask_label(index)
        gender_label = self.get_gender_label(index)
        age_label = self.get_age_label(index)
        multi_class_label = self.encode_multi_class(mask_label, gender_label, age_label)

        image_transform = self.transform(image)
        return image_transform, multi_class_label

    def __len__(self):
        return len(self.image_paths)

    def get_mask_label(self, index) -> MaskLabels:
        return self.mask_labels[index]

    def get_gender_label(self, index) -> GenderLabels:
        return self.gender_labels[index]

    def get_age_label(self, index) -> AgeLabels:
        return self.age_labels[index]

    def read_image(self, index):
        """Open the image at ``index`` with PIL (pixels are loaded lazily)."""
        image_path = self.image_paths[index]
        return Image.open(image_path)

    @staticmethod
    def encode_multi_class(mask_label, gender_label, age_label) -> int:
        """Combine the three labels into one class id in ``[0, 18)``."""
        return mask_label * 6 + gender_label * 3 + age_label

    @staticmethod
    def decode_multi_class(multi_class_label) -> Tuple[MaskLabels, GenderLabels, AgeLabels]:
        """Inverse of ``encode_multi_class``; returns plain ints ``(mask, gender, age)``."""
        mask_label = (multi_class_label // 6) % 3
        gender_label = (multi_class_label // 3) % 2
        age_label = multi_class_label % 3
        return mask_label, gender_label, age_label

    @staticmethod
    def denormalize_image(image, mean, std):
        """Undo ``(x / 255 - mean) / std`` normalization and return a uint8 array.

        Assumes ``image`` is a float numpy array whose last axis is the channel
        axis, so that ``mean`` / ``std`` broadcast over it (e.g. HxWx3).
        """
        img_cp = image.copy()
        img_cp *= std
        img_cp += mean
        img_cp *= 255.0
        img_cp = np.clip(img_cp, 0, 255).astype(np.uint8)
        return img_cp

    def split_dataset(self) -> Tuple[Subset, Subset]:
        """Randomly split into ``(train, val)`` Subsets with ``torch.utils.data.random_split``.

        The validation split holds ``int(len(self) * val_ratio)`` samples.
        """
        n_val = int(len(self) * self.val_ratio)
        n_train = len(self) - n_val
        train_set, val_set = random_split(self, [n_train, n_val])
        return train_set, val_set

### MaskStratifiedDataset

> MaskStratifiedDataset 같은 경우 target 클래스의 분포(distribution)에 따라 train, validation 을 나눕니다. 

> MaskBaseDataset 을 상속(inherit)받아 쓰시면 됩니다.   
> 바꾸어야 할 부분은 setup 과 split_dataset 함수입니다.  
> setup 함수에서 dictionary 를 활용하여 클래스별 indices 를 저장합니다.  
> split_dataset 에서 dictionary 의 클래스별로 random sample 을 하여 train, validation set을 나누면 됩니다. 

In [17]:
class MaskStratifiedDataset(MaskBaseDataset):
    """MaskBaseDataset whose train/val split is stratified by the 18-way class.

    Each class contributes ``int(len(class_samples) * val_ratio)`` samples to
    the validation set, so both splits keep roughly the class distribution.
    """

    def __init__(self, data_dir, mean=(0.548, 0.504, 0.479), std=(0.237, 0.247, 0.246), val_ratio=0.2):
        # Must exist before super().__init__(), which calls self.setup().
        self.multi_label_dict = defaultdict(list)
        super().__init__(data_dir, mean, std, val_ratio)

    def setup(self):
        """Scan ``data_dir`` and index every sample by its multi-class label."""
        # Start from fresh per-instance containers so the indices stored below
        # refer only to this dataset's samples, even if the base class declares
        # these lists as shared class attributes.
        self.image_paths = []
        self.mask_labels = []
        self.gender_labels = []
        self.age_labels = []
        self.multi_label_dict = defaultdict(list)

        # Reuse the base-class scan instead of duplicating it.
        super().setup()

        labels = zip(self.mask_labels, self.gender_labels, self.age_labels)
        for index, (mask_label, gender_label, age_label) in enumerate(labels):
            multi_class_label = self.encode_multi_class(mask_label, gender_label, age_label)
            self.multi_label_dict[multi_class_label].append(index)

    def split_dataset(self) -> Tuple[Subset, Subset]:
        """Split into ``(train, val)`` Subsets, stratified by multi-class label."""
        train_indices = []
        val_indices = []

        for indices in self.multi_label_dict.values():
            # Shuffle a copy so the stored per-class index lists stay untouched.
            shuffled = random.sample(indices, len(indices))
            n_val = int(len(shuffled) * self.val_ratio)
            val_indices.extend(shuffled[:n_val])
            train_indices.extend(shuffled[n_val:])

        # Mix the classes together once, after all of them have been collected.
        random.shuffle(val_indices)
        random.shuffle(train_indices)

        return Subset(self, train_indices), Subset(self, val_indices)

### ThreeWayStratifiedDataset

> ThreeWayStratifiedDataset 같은 경우, age, gender, mask 를 각각 label 해서 데이터를 생성합니다.    

> MaskBaseDataset 같은 경우, 하나의 label 을 return 하며 18개의 class 가 있습니다.   
> ThreeWayStratifiedDataset 같은 경우, 3개(age, gender, mask)의 label 을 return 하며 각각 3, 2, 3개의 class 가 있습니다.  

> MaskBaseDataset 을 상속(inherit)받아 쓰시면 됩니다.   
> 바꾸어야 할 부분은 \_\_getitem\_\_, setup, split_dataset 함수입니다.  
> \_\_getitem\_\_ 함수가 3가지 label 을 각각 return 하도록 바꾸어 줍니다 (아래 구현은 [age, mask, gender] 순서로 반환합니다). -> Model Architecture 및 훈련 방법 또한 달라질 것입니다.  
> setup 함수에서 dictionary 를 활용하여 클래스별 indices 를 저장합니다.  
> split_dataset 에서 dictionary 의 클래스별로 random sample 을 하여 train, validation set을 나누면 됩니다.  

- 더욱 간단하게 구현하려면 MaskStratifiedDataset 를 상속받아 구현하실 수 있습니다. 

In [19]:
class ThreeWayStratifiedDataset(MaskBaseDataset):
    """Dataset returning separate age, mask and gender labels per image.

    ``__getitem__`` yields ``(image, [age_label, mask_label, gender_label])``.
    The train/val split is still stratified by the joint 18-way class so
    every label combination keeps its share of validation samples.
    """

    def __init__(self, data_dir, mean=(0.548, 0.504, 0.479), std=(0.237, 0.247, 0.246), val_ratio=0.2):
        # Must exist before super().__init__(), which calls self.setup().
        self.multi_label_dict = defaultdict(list)
        super().__init__(data_dir, mean, std, val_ratio)

    def __getitem__(self, index):
        """Return ``(self.transform(image), [age_label, mask_label, gender_label])``."""
        assert self.transform is not None, ".set_transform 메소드를 이용하여 transform 을 주입해주세요"

        image = self.read_image(index)
        mask_label = self.get_mask_label(index)
        gender_label = self.get_gender_label(index)
        age_label = self.get_age_label(index)

        image_transform = self.transform(image)
        return image_transform, [age_label, mask_label, gender_label]

    def setup(self):
        """Scan ``data_dir`` and index every sample by its joint multi-class label."""
        # Start from fresh per-instance containers so the indices stored below
        # refer only to this dataset's samples, even if the base class declares
        # these lists as shared class attributes.
        self.image_paths = []
        self.mask_labels = []
        self.gender_labels = []
        self.age_labels = []
        self.multi_label_dict = defaultdict(list)

        # The base scan parses "<id>_<gender>_<race>_<age>" folder names
        # correctly; a 3-way unpack of that 4-field name raises ValueError.
        super().setup()

        labels = zip(self.mask_labels, self.gender_labels, self.age_labels)
        for index, (mask_label, gender_label, age_label) in enumerate(labels):
            multi_class_label = self.encode_multi_class(mask_label, gender_label, age_label)
            self.multi_label_dict[multi_class_label].append(index)

    def split_dataset(self) -> Tuple[Subset, Subset]:
        """Split into ``(train, val)`` Subsets, stratified by the joint class label."""
        train_indices = []
        val_indices = []

        for indices in self.multi_label_dict.values():
            # Shuffle a copy so the stored per-class index lists stay untouched.
            shuffled = random.sample(indices, len(indices))
            n_val = int(len(shuffled) * self.val_ratio)
            val_indices.extend(shuffled[:n_val])
            train_indices.extend(shuffled[n_val:])

        # Mix the classes together once, after all of them have been collected.
        random.shuffle(val_indices)
        random.shuffle(train_indices)

        return Subset(self, train_indices), Subset(self, val_indices)