## 0. Config

In [1]:
# 테스트 데이터셋 폴더 경로를 지정해주세요.
TRAIN_DIR = "/Users/yeongseon/Downloads/train/train/images"
MODEL_DIR = "./model"

In [2]:
params = { 
    "seed": 42,
    "epochs": 3,
    "dataset": "MaskBaseDataset",
    "augmentation": "BaseAugmentation",
    "resize": [128, 96],
    "train_batch_size": 64,
    "valid_batch_size": 1000,
    "model" : "BaseModel",
    "optimizer": "SGD",
    "lr": 1e-3,
    "val_ratio": 0.2,
    "criterion": "cross_entropy",
    "lr_decay_step": 20,
    "log_interval": 20,
    "name": "exp",
    "data_dir": TRAIN_DIR,
    "model_dir": MODEL_DIR
}

In [3]:
import torch
import numpy as np
import random

seed = params["seed"]

def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

# Initialize random seed
seed_everything(seed)

# device optimization
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

## 1. Test Dataset 정의

In [4]:
import numpy as np
from enum import Enum
from typing import Tuple
from torch.utils.data import Dataset, Subset, random_split

class MaskLabels(int, Enum):
    MASK = 0
    INCORRECT = 1
    NORMAL = 2


class GenderLabels(int, Enum):
    MALE = 0
    FEMALE = 1

    @classmethod
    def from_str(cls, value: str) -> int:
        value = value.lower()
        if value == "male":
            return cls.MALE
        elif value == "female":
            return cls.FEMALE
        else:
            raise ValueError(
                f"Gender value should be either 'male' or 'female', {value}"
            )


class AgeLabels(int, Enum):
    YOUNG = 0
    MIDDLE = 1
    OLD = 2

    @classmethod
    def from_number(cls, value: str) -> int:
        try:
            value = int(value)
        except Exception:
            raise ValueError(f"Age value should be numeric, {value}")

        if value < 30:
            return cls.YOUNG
        elif value < 60:
            return cls.MIDDLE
        else:
            return cls.OLD


class MaskBaseDataset(Dataset):
    num_classes = 3 * 2 * 3

    _file_names = {
        "mask1": MaskLabels.MASK,
        "mask2": MaskLabels.MASK,
        "mask3": MaskLabels.MASK,
        "mask4": MaskLabels.MASK,
        "mask5": MaskLabels.MASK,
        "incorrect_mask": MaskLabels.INCORRECT,
        "normal": MaskLabels.NORMAL,
    }

    image_paths = []
    mask_labels = []
    gender_labels = []
    age_labels = []

    def __init__(
        self,
        data_dir,
        mean=(0.548, 0.504, 0.479),
        std=(0.237, 0.247, 0.246),
        val_ratio=0.2,
    ):
        self.data_dir = data_dir
        self.mean = mean
        self.std = std
        self.val_ratio = val_ratio

        self.transform = None
        self.setup()
        self.calc_statistics()

    def setup(self):
        profiles = os.listdir(self.data_dir)
        for profile in profiles:
            if profile.startswith("."):  # "." 로 시작하는 파일은 무시합니다
                continue

            img_folder = os.path.join(self.data_dir, profile)
            for file_name in os.listdir(img_folder):
                _file_name, ext = os.path.splitext(file_name)
                if (
                    _file_name not in self._file_names
                ):  # "." 로 시작하는 파일 및 invalid 한 파일들은 무시합니다
                    continue

                img_path = os.path.join(
                    self.data_dir, profile, file_name
                )  # (resized_data, 000004_male_Asian_54, mask1.jpg)
                mask_label = self._file_names[_file_name]

                id, gender, race, age = profile.split("_")
                gender_label = GenderLabels.from_str(gender)
                age_label = AgeLabels.from_number(age)

                self.image_paths.append(img_path)
                self.mask_labels.append(mask_label)
                self.gender_labels.append(gender_label)
                self.age_labels.append(age_label)

    def calc_statistics(self):
        has_statistics = self.mean is not None and self.std is not None
        if not has_statistics:
            print(
                "[Warning] Calculating statistics... It can take a long time depending on your CPU machine"
            )
            sums = []
            squared = []
            for image_path in self.image_paths[:3000]:
                image = np.array(Image.open(image_path)).astype(np.int32)
                sums.append(image.mean(axis=(0, 1)))
                squared.append((image ** 2).mean(axis=(0, 1)))

            self.mean = np.mean(sums, axis=0) / 255
            self.std = (np.mean(squared, axis=0) - self.mean ** 2) ** 0.5 / 255

    def set_transform(self, transform):
        self.transform = transform

    def __getitem__(self, index):
        assert self.transform is not None, ".set_tranform 메소드를 이용하여 transform 을 주입해주세요"

        image = self.read_image(index)
        mask_label = self.get_mask_label(index)
        gender_label = self.get_gender_label(index)
        age_label = self.get_age_label(index)
        multi_class_label = self.encode_multi_class(mask_label, gender_label, age_label)

        image_transform = self.transform(image)
        return image_transform, multi_class_label

    def __len__(self):
        return len(self.image_paths)

    def get_mask_label(self, index) -> MaskLabels:
        return self.mask_labels[index]

    def get_gender_label(self, index) -> GenderLabels:
        return self.gender_labels[index]

    def get_age_label(self, index) -> AgeLabels:
        return self.age_labels[index]

    def read_image(self, index):
        image_path = self.image_paths[index]
        return Image.open(image_path)

    @staticmethod
    def encode_multi_class(mask_label, gender_label, age_label) -> int:
        return mask_label * 6 + gender_label * 3 + age_label

    @staticmethod
    def decode_multi_class(
        multi_class_label,
    ) -> Tuple[MaskLabels, GenderLabels, AgeLabels]:
        mask_label = (multi_class_label // 6) % 3
        gender_label = (multi_class_label // 3) % 2
        age_label = multi_class_label % 3
        return mask_label, gender_label, age_label

    @staticmethod
    def denormalize_image(image, mean, std):
        img_cp = image.copy()
        img_cp *= std
        img_cp += mean
        img_cp *= 255.0
        img_cp = np.clip(img_cp, 0, 255).astype(np.uint8)
        return img_cp

    def split_dataset(self) -> Tuple[Subset, Subset]:
        """
        데이터셋을 train 과 val 로 나눕니다,
        pytorch 내부의 torch.utils.data.random_split 함수를 사용하여
        torch.utils.data.Subset 클래스 둘로 나눕니다.
        구현이 어렵지 않으니 구글링 혹은 IDE (e.g. pycharm) 의 navigation 기능을 통해 코드를 한 번 읽어보는 것을 추천드립니다^^
        """
        n_val = int(len(self) * self.val_ratio)
        n_train = len(self) - n_val
        train_set, val_set = random_split(self, [n_train, n_val])
        return train_set, val_set

SyntaxError: invalid syntax (4293227174.py, line 49)

In [None]:
from PIL import Image
from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize


class BaseAugmentation:
    def __init__(self, resize, mean, std, **args):
        self.transform = transforms.Compose(
            [
                Resize(resize, Image.BILINEAR),
                ToTensor(),
                Normalize(mean=mean, std=std),
            ]
        )

    def __call__(self, image):
        return self.transform(image)

In [None]:
# number of subprocess to use for data loading
num_workers = 0

# how manhy sample per batch to load
train_batch_size = params["train_batch_size"]
valid_batch_size = params["valid_batch_size"]

# 
val_ratio = params["val_ratio"]

resize = params["resize"]

dataset = MaskBaseDataset(data_dir=TRAIN_DIR, val_ratio=val_ratio)

# Argumentation
transform = BaseAugmentation(
    resize=resize,
    mean=dataset.mean,
    std=dataset.std,
)
dataset.set_transform(transform)

# data loader
train_set, val_set = dataset.split_dataset()

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=train_batch_size,
    num_workers=num_workers,
    shuffle=True,
    pin_memory=use_cuda,
    drop_last=True,
)

val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=valid_batch_size,
    num_workers=num_workers,
    shuffle=False,
    pin_memory=use_cuda,
    drop_last=True,
)

## 2. Model 정의

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class BaseModel(nn.Module):
    def __init__(self, num_classes: int = 1000):
        super(BaseModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Initialize the Model
model = BaseModel(num_classes=dataset.num_classes)
model = torch.nn.DataParallel(model)

### Loss Function Optimizer 정의 하기

In [None]:
lr = params["lr"]

# Specifiy loss funtion
criterion = nn.CrossEntropyLoss() # categorical cross-entryopy

# Specify optimer
optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=lr,
    weight_decay=5e-4,
)

## 3. 학습하기

In [None]:
import random
import matplotlib.pyplot as plt

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]

def grid_image(np_images, gts, preds, n=16, shuffle=False):
    batch_size = np_images.shape[0]
    assert n <= batch_size

    choices = random.choices(range(batch_size), k=n) if shuffle else list(range(n))
    figure = plt.figure(
        figsize=(12, 18 + 2)
    )  # cautions: hardcoded, 이미지 크기에 따라 figsize 를 조정해야 할 수 있습니다. T.T
    plt.subplots_adjust(
        top=0.8
    )  # cautions: hardcoded, 이미지 크기에 따라 top 를 조정해야 할 수 있습니다. T.T
    n_grid = np.ceil(n ** 0.5)
    tasks = ["mask", "gender", "age"]
    for idx, choice in enumerate(choices):
        gt = gts[choice].item()
        pred = preds[choice].item()
        image = np_images[choice]
        # title = f"gt: {gt}, pred: {pred}"
        gt_decoded_labels = MaskBaseDataset.decode_multi_class(gt)
        pred_decoded_labels = MaskBaseDataset.decode_multi_class(pred)
        title = "\n".join(
            [
                f"{task} - gt: {gt_label}, pred: {pred_label}"
                for gt_label, pred_label, task in zip(
                    gt_decoded_labels, pred_decoded_labels, tasks
                )
            ]
        )

        plt.subplot(n_grid, n_grid, idx + 1, title=title)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(image, cmap=plt.cm.binary)

    return figure

In [None]:
import os
import glob
import re
from pathlib import Path

def increment_path(path, exist_ok=False):
    """Automatically increment path, i.e. runs/exp --> runs/exp0, runs/exp1 etc.

    Args:
        path (str or pathlib.Path): f"{model_dir}/{args.name}".
        exist_ok (bool): whether increment path (increment if False).
    """
    path = Path(path)
    if (path.exists() and exist_ok) or (not path.exists()):
        return str(path)
    else:
        dirs = glob.glob(f"{path}*")
        matches = [re.search(rf"%s(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]
        n = max(i) + 1 if i else 2
        return f"{path}{n}"

In [None]:
import json
from torch.utils.tensorboard import SummaryWriter


epochs = params["epochs"]
model_dir = params["model_dir"]
log_interval = params["log_interval"]


scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=20, gamma=0.5
)

save_dir = increment_path(os.path.join(model_dir, "exp"))

# writer will output to save_dir
writer = SummaryWriter(log_dir=save_dir)

with open(os.path.join(save_dir, "config.json"), "w", encoding="utf-8") as f:
    json.dump(params, f, ensure_ascii=False, indent=4)

best_val_acc = 0
best_val_loss = np.inf

# epoch training loop
for epoch in range(epochs):
    # train loop
    model.train()
    loss_value = 0
    matches = 0

    # batch training loop
    for idx, train_batch in enumerate(train_loader):
        inputs, labels = train_batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outs = model(inputs)
        preds = torch.argmax(outs, dim=-1)
        loss = criterion(outs, labels)

        loss.backward()
        optimizer.step()

        loss_value += loss.item()
        matches += (preds == labels).sum().item()
        if (idx + 1) % log_interval == 0:
            train_loss = loss_value / log_interval
            train_acc = matches / train_batch_size / log_interval
            current_lr = get_lr(optimizer)
            print(
                f"Epoch[{epoch}/{epochs}]({idx + 1}/{len(train_loader)}) || "
                f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
            )
            writer.add_scalar(
                "Train/loss", train_loss, epoch * len(train_loader) + idx
            )
            writer.add_scalar(
                "Train/accuracy", train_acc, epoch * len(train_loader) + idx
            )

            loss_value = 0
            matches = 0

    scheduler.step()

    # val loop
    with torch.no_grad():
        print("Calculating validation results...")
        model.eval()
        val_loss_items = []
        val_acc_items = []
        figure = None

        for inputs, labels in val_loader:
            # Move the batch to the device
            inputs = inputs.to(device)
            labels = labels.to(device)

            outs = model(inputs)
            preds = torch.argmax(outs, dim=-1)

            loss_item = criterion(outs, labels).item()
            acc_item = (labels == preds).sum().item()
            val_loss_items.append(loss_item)
            val_acc_items.append(acc_item)

            if figure is None:
                inputs_np = (
                    torch.clone(inputs).detach().cpu().permute(0, 2, 3, 1).numpy()
                )
                inputs_np = MaskBaseDataset.denormalize_image(
                    inputs_np, dataset.mean, dataset.std
                )
                figure = grid_image(
                    inputs_np,
                    labels,
                    preds,
                    n=16,
                    shuffle=True,
                )

        val_loss = np.sum(val_loss_items) / len(val_loader)
        val_acc = np.sum(val_acc_items) / len(val_set)
        best_val_loss = min(best_val_loss, val_loss)
        if val_acc > best_val_acc:
            print(
                f"New best model for val accuracy : {val_acc:4.2%}! saving the best model.."
            )
            torch.save(model.module.state_dict(), f"{save_dir}/best.pth")
            best_val_acc = val_acc

        torch.save(model.module.state_dict(), f"{save_dir}/last.pth")
        print(
            f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.2} || "
            f"best acc : {best_val_acc:4.2%}, best loss: {best_val_loss:4.2}"
        )

        writer.add_scalar("Val/loss", val_loss, epoch)
        writer.add_scalar("Val/accuracy", val_acc, epoch)
        writer.add_figure("results", figure, epoch)
        print()

writer.close()