In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount = True)

Mounted at /content/drive


# 31 classes 분류

데이터셋 경로 설정(본인 드라이브에 맞게 수정)

In [2]:
import os
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, ConcatDataset
import torch
from tqdm import tqdm  # tqdm 불러오기

# Train 데이터셋 경로 (processed_images와 원본 이미지파일)
train_dir_1 = '/content/drive/MyDrive/mission_data/class_processed_images'
train_dir_2 = '/content/drive/MyDrive/mission_data/class_training_images'

# Val 데이터 경로 설정
val_dir = '/content/drive/MyDrive/mission_data/class_validation_images'

데이터셋 정규화 및 증강(정규화 솔직히 차이 잘 모르겠음... 기존 평균, 표준편차로 돌려도 별 차이 없을 듯?)

이 코드 빼고 사용해도 됨

In [None]:
# 한번 돌리면 안돌려도 됨(결과 같음)

import os
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, ConcatDataset
import torch
from tqdm import tqdm  # tqdm 불러오기

# 이미지 로드 및 기본 전처리 (이미지 크기만 일단 조정)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# train 원본 데이터만 확인
train_dataset_2 = datasets.ImageFolder(root=train_dir_2, transform=transform)

# DataLoader 설정
batch_size = 32
origin_train_loader = DataLoader(train_dataset_2, batch_size=batch_size, shuffle=True, num_workers=4)

# 평균과 표준편차 계산 함수에 tqdm 추가
def calculate_mean_std(loader):
    mean = torch.zeros(3)
    std = torch.zeros(3)
    total_images_count = 0

    # tqdm을 사용하여 진행 상황을 보여줌
    for images, _ in tqdm(loader, desc="Calculating mean and std"):
        batch_samples = images.size(0)  # 배치 내 이미지 개수
        images = images.view(batch_samples, images.size(1), -1)  # [batch_size, channels, height*width]
        mean += images.mean(2).sum(0)
        std += images.std(2).sum(0)
        total_images_count += batch_samples

    mean /= total_images_count
    std /= total_images_count

    return mean, std

# 평균과 표준편차 계산
print("Calculating for train dataset...")
mean, std = calculate_mean_std(origin_train_loader)
print(f"Train Mean: {mean}")
print(f"Train Standard Deviation: {std}")

Calculating for train dataset...


Calculating mean and std: 100%|██████████████████████████████████████████████████████| 128/128 [15:43<00:00,  7.37s/it]

Train Mean: tensor([0.5498, 0.5226, 0.5052])
Train Standard Deviation: tensor([0.2600, 0.2582, 0.2620])





데이터셋 전처리 및 로드

In [3]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
import torch

# 이미지 로드 및 기본 전처리 + 정규화
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5498, 0.5226, 0.5052], std=[0.2600, 0.2582, 0.2620])  # 계산한 평균, 표준편차를 활용한 정규화
])

# 증강을 추가한 transform
augment_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=1),  # 좌우 반전
    transforms.ColorJitter(saturation=0.3, brightness=0.3),  # 채도와 밝기 변환
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5498, 0.5226, 0.5052], std=[0.2600, 0.2582, 0.2620])
])

# 커스텀 데이터셋 클래스 정의
class GenderStyleDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.image_paths = []
        self.gender_labels = []
        self.style_labels = []
        self.style_map = {}

        # 클래스 폴더 이름을 숫자로 매핑 (0 ~ 30)
        self.style_folders = sorted(os.listdir(root))
        for idx, folder_name in enumerate(self.style_folders):
            self.style_map[folder_name] = idx

        # 데이터셋 로드 및 성별/스타일 라벨 생성
        for folder_name in self.style_folders:
            folder_path = os.path.join(root, folder_name)
            if os.path.isdir(folder_path):
                # 성별은 폴더명에서 W (여성) 또는 M (남성)으로 구분
                gender = 0 if folder_name.startswith('W') else 1
                style_label = self.style_map[folder_name]

                for img_name in os.listdir(folder_path):
                    img_path = os.path.join(folder_path, img_name)
                    if img_name.endswith(('jpg', 'png')):
                        self.image_paths.append(img_path)
                        self.gender_labels.append(gender)
                        self.style_labels.append(style_label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        gender_label = torch.tensor(self.gender_labels[idx], dtype=torch.float32)  # 성별 레이블
        style_label = torch.tensor(self.style_labels[idx], dtype=torch.long)        # 스타일 레이블

        return image, gender_label, style_label

    def get_style_name(self, style_label):
        """숫자 레이블을 스타일 이름으로 변환"""
        for folder_name, label in self.style_map.items():
            if label == style_label:
                return folder_name
        return None


# 각 폴더에서 데이터셋 로드
processed_datasets = GenderStyleDataset(root=train_dir_1, transform=transform)  # 전처리 데이터셋
train_datasets =  GenderStyleDataset(root=train_dir_2, transform=transform)  # 원본 데이터셋
augmented_datasets = GenderStyleDataset(root=train_dir_2, transform=augment_transform)  # 증강된 데이터셋

# 데이터셋 합치기
train_dataset = ConcatDataset([processed_datasets, train_datasets, augmented_datasets])

# validation 데이터셋 로드
val_dataset = GenderStyleDataset(root=val_dir, transform=transform)

# DataLoader 설정
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

print(f"Train size: {len(train_loader.dataset)}, Val size: {len(val_loader.dataset)}")


Train size: 12210, Val size: 951


In [4]:
# 데이터셋의 첫 번째 항목 확인
image, gender_label, style_label = val_dataset[530]
print(f"Gender label: {gender_label}")  # 0(여성), 1(남성)
print(f"Style label (numeric): {style_label}")  # 0 ~ 30 중 하나의 숫자 레이블

# 스타일 이름 확인
style_name = val_dataset.get_style_name(style_label.item())
print(f"Style label (actual name): {style_name}")

Gender label: 0.0
Style label (numeric): 8
Style label (actual name): W_athleisure


resnet-18 모델 정의 및 설정

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Device 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class MultiOutputResNet(nn.Module):
    def __init__(self, num_styles=31):  # 스타일 총 31개
        super(MultiOutputResNet, self).__init__()
        self.resnet = models.resnet18(weights=None)  # ResNet-18, pretrained=False
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()  # ResNet의 Fully Connected 레이어를 제거하고 특징 추출만 수행

        # 성별 분류 레이어 (이진 분류라서 출력 뉴런 수는 1)
        self.gender_fc = nn.Linear(num_features, 1)

        # 스타일 분류 레이어 (남성과 여성 스타일을 합친 31개)
        self.style_fc = nn.Linear(num_features, num_styles)

    def forward(self, x):
        # 공통 특징 추출 (ResNet-18 사용)
        features = self.resnet(x)

        # 성별 분류
        gender_output = torch.sigmoid(self.gender_fc(features))  # 시그모이드를 사용해 성별 확률을 계산

        # 성별에 따른 스타일 분류 (전체 스타일 레이블을 사용하지만 이후 성별에 따라 제한)
        style_output = self.style_fc(features)

        return gender_output, style_output


Using device: cuda


진행상황 저장 (이것도 본인 드라이브에 맞게 경로 수정)

In [6]:
# 체크포인트 저장 함수
def save_checkpoint(state, filename="/content/drive/MyDrive/GSW/checkpoint_multi.pth", weights_only=False):
    torch.save(state, filename)


# 체크포인트 로드 함수
def load_checkpoint(filename="/content/drive/MyDrive/GSW/checkpoint_multi.pth", weights_only=False):
    checkpoint = torch.load(filename)
    epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    best_loss = checkpoint['best_loss']
    return epoch, best_loss

resnet-18 학습모델

In [7]:
# 학습 및 검증 함수 정의
def train_and_validate(model, train_loader, val_loader, gender_criterion, style_criterion, optimizer, scheduler, num_epochs, start_epoch=0, best_loss=float('inf')):
    history = {
        'train_gender_loss': [], 'train_style_loss': [], 'train_total_loss': [],
        'val_gender_loss': [], 'val_style_loss': [], 'val_total_loss': [],
        'train_gender_acc': [], 'train_style_acc': [], 'train_total_acc': [],
        'val_gender_acc': [], 'val_style_acc': [], 'val_total_acc': []
    }

    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_gender_loss = 0.0
        running_style_loss = 0.0
        running_total_loss = 0.0

        correct_gender = 0
        correct_style = 0
        total_samples = 0

        for images, gender_labels, style_labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
            images = images.to(device)
            gender_labels = gender_labels.to(device)
            style_labels = style_labels.to(device)

            # Forward pass
            gender_output, style_output = model(images)

            # 성별 예측 손실 계산
            loss_gender = gender_criterion(gender_output.view(-1), gender_labels.float())  # 성별 분류 손실

            # 성별에 따른 스타일 손실 계산
            loss_style = 0
            for i in range(len(gender_labels)):
                if gender_labels[i] == 0:  # 여성
                    loss_style += style_criterion(style_output[i, 8:], style_labels[i] - 8)  # 8~30 스타일만 사용
                else:  # 남성
                    loss_style += style_criterion(style_output[i, :8], style_labels[i])  # 0~7 스타일만 사용

            # 총 손실
            total_loss = loss_gender + loss_style

            # Backward pass 및 최적화
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            running_gender_loss += loss_gender.item()
            running_style_loss += loss_style.item()
            running_total_loss += total_loss.item()

            # 정확도 계산
            predicted_gender = (gender_output > 0.5).float().view(-1)
            correct_gender += (predicted_gender == gender_labels).sum().item()

            # 성별이 맞았을 때만 스타일 정확도 평가
            for i in range(len(gender_labels)):
                if predicted_gender[i] == gender_labels[i]:  # 성별이 맞았을 때만 스타일 평가
                    if gender_labels[i] == 0:  # 여성
                        predicted_style = torch.argmax(style_output[i, 8:]) + 8  # 8~30 스타일 중 예측
                    else:  # 남성
                        predicted_style = torch.argmax(style_output[i, :8])  # 0~7 스타일 중 예측

                    correct_style += (predicted_style == style_labels[i]).sum().item()

            total_samples += gender_labels.size(0)

        # 평균 손실 및 정확도 계산
        avg_gender_loss = running_gender_loss / len(train_loader)
        avg_style_loss = running_style_loss / len(train_loader)
        avg_total_loss = running_total_loss / len(train_loader)

        gender_acc = correct_gender / total_samples
        style_acc = correct_style / total_samples
        total_acc = (correct_gender + correct_style) / (2 * total_samples)

        # 기록
        history['train_gender_loss'].append(avg_gender_loss)
        history['train_style_loss'].append(avg_style_loss)
        history['train_total_loss'].append(avg_total_loss)
        history['train_gender_acc'].append(gender_acc)
        history['train_style_acc'].append(style_acc)
        history['train_total_acc'].append(total_acc)

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Gender Loss: {avg_gender_loss:.4f}, Train Style Loss: {avg_style_loss:.4f}, Train Total Loss: {avg_total_loss:.4f}")
        print(f"Train Gender Acc: {gender_acc:.4f}, Train Style Acc: {style_acc:.4f}, Train Total Acc: {total_acc:.4f}")

        # Validation phase
        model.eval()
        val_running_gender_loss = 0.0
        val_running_style_loss = 0.0
        val_running_total_loss = 0.0

        val_correct_gender = 0
        val_correct_style = 0
        val_total_samples = 0

        with torch.no_grad():
            for images, gender_labels, style_labels in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}"):
                images = images.to(device)
                gender_labels = gender_labels.to(device)
                style_labels = style_labels.to(device)

                # Forward pass
                gender_output, style_output = model(images)

                # 성별 손실 계산
                val_loss_gender = gender_criterion(gender_output.view(-1), gender_labels.float())

                # 성별에 따른 스타일 손실 계산 (성별이 맞았을 때만)
                val_loss_style = 0
                for i in range(len(gender_labels)):
                    if gender_labels[i] == 0:  # 여성
                        val_loss_style += style_criterion(style_output[i, 8:], style_labels[i] - 8)  # 8~30 스타일만 사용
                    else:  # 남성
                        val_loss_style += style_criterion(style_output[i, :8], style_labels[i])  # 0~7 스타일만 사용

                val_total_loss = val_loss_gender + val_loss_style
                val_running_gender_loss += val_loss_gender.item()
                val_running_style_loss += val_loss_style.item()
                val_running_total_loss += val_total_loss.item()

                # 성별이 맞았을 때만 스타일 정확도 평가
                val_predicted_gender = (gender_output > 0.5).float().view(-1)
                val_correct_gender += (val_predicted_gender == gender_labels).sum().item()

                for i in range(len(gender_labels)):
                    if val_predicted_gender[i] == gender_labels[i]:  # 성별이 맞았을 때만 스타일 평가
                        if gender_labels[i] == 0:  # 여성
                            val_predicted_style = torch.argmax(style_output[i, 8:]) + 8  # 8~30 스타일 중 예측
                        else:  # 남성
                            val_predicted_style = torch.argmax(style_output[i, :8])  # 0~7 스타일 중 예측

                        val_correct_style += (val_predicted_style == style_labels[i]).sum().item()

                val_total_samples += gender_labels.size(0)

        # Validation 손실 및 정확도 평균
        avg_val_gender_loss = val_running_gender_loss / len(val_loader)
        avg_val_style_loss = val_running_style_loss / len(val_loader)
        avg_val_total_loss = val_running_total_loss / len(val_loader)

        val_gender_acc = val_correct_gender / val_total_samples
        val_style_acc = val_correct_style / val_total_samples
        val_total_acc = (val_correct_gender + val_correct_style) / (2 * val_total_samples)

        # 기록
        history['val_gender_loss'].append(avg_val_gender_loss)
        history['val_style_loss'].append(avg_val_style_loss)
        history['val_total_loss'].append(avg_val_total_loss)
        history['val_gender_acc'].append(val_gender_acc)
        history['val_style_acc'].append(val_style_acc)
        history['val_total_acc'].append(val_total_acc)

        print(f"Validation Gender Loss: {avg_val_gender_loss:.4f}, Validation Style Loss: {avg_val_style_loss:.4f}, Validation Total Loss: {avg_val_total_loss:.4f}")
        print(f"Validation Gender Acc: {val_gender_acc:.4f}, Validation Style Acc: {val_style_acc:.4f}, Validation Total Acc: {val_total_acc:.4f}")

        # ReduceLROnPlateau 스케줄러를 사용하여 학습률 조정 (총 손실 기준)
        scheduler.step(avg_val_total_loss)
        current_lr = scheduler.optimizer.param_groups[0]['lr']  # 현재 학습률 확인
        print(f"Current Learning Rate: {current_lr}")

        # 가장 낮은 총 손실을 가진 모델 저장
        if avg_val_total_loss < best_loss:
            best_loss = avg_val_total_loss
            torch.save(model.state_dict(), 'best_model.pth')

            # 체크포인트 저장
            save_checkpoint({
                'epoch': epoch + 1,  # 현재 에폭 저장
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_loss': best_loss
            })

    return history


하이퍼파라미터 설정 및 학습

In [8]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# 하이퍼파라미터 설정
learning_rate = 0.01
num_epochs = 100

# 모델 정의 (MultiOutputResNet 모델)
model = MultiOutputResNet(num_styles=31).to(device)

# Loss 및 Optimizer 설정
gender_criterion = nn.BCEWithLogitsLoss()  # 성별 이진 분류를 위한 손실 함수
style_criterion = nn.CrossEntropyLoss()  # 스타일 분류 손실은 다중 클래스 분류

optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)  # AdamW 옵티마이저 사용
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-6)



# 학습 재개를 위한 체크포인트 로드
try:
    start_epoch, best_loss = load_checkpoint()
    print(f"Checkpoint loaded. Resuming from epoch {start_epoch} with lowest loss {best_loss:.4f}.")
except FileNotFoundError:
    print("No checkpoint found, starting from scratch.")
    start_epoch, best_loss = 0, float('inf')

# 학습 및 검증 시작
history = train_and_validate(model, train_loader, val_loader, gender_criterion, style_criterion, optimizer, scheduler, num_epochs, start_epoch, best_loss)

  checkpoint = torch.load(filename)


No checkpoint found, starting from scratch.


Training Epoch 1/100: 100%|██████████| 382/382 [23:21<00:00,  3.67s/it]


Epoch [1/100], Train Gender Loss: 0.6884, Train Style Loss: 82.6726, Train Total Loss: 83.3610
Train Gender Acc: 0.4482, Train Style Acc: 0.0347, Train Total Acc: 0.2414


Validation Epoch 1/100: 100%|██████████| 30/30 [02:10<00:00,  4.35s/it]


Validation Gender Loss: 0.6906, Validation Style Loss: 79.9396, Validation Total Loss: 80.6302
Validation Gender Acc: 0.4479, Validation Style Acc: 0.0463, Validation Total Acc: 0.2471
Current Learning Rate: 0.01


Training Epoch 2/100: 100%|██████████| 382/382 [14:18<00:00,  2.25s/it]


Epoch [2/100], Train Gender Loss: 0.6882, Train Style Loss: 79.9257, Train Total Loss: 80.6139
Train Gender Acc: 0.4477, Train Style Acc: 0.0376, Train Total Acc: 0.2426


Validation Epoch 2/100: 100%|██████████| 30/30 [00:59<00:00,  1.98s/it]


Validation Gender Loss: 0.6887, Validation Style Loss: 79.9504, Validation Total Loss: 80.6391
Validation Gender Acc: 0.4479, Validation Style Acc: 0.0484, Validation Total Acc: 0.2482
Current Learning Rate: 0.01


Training Epoch 3/100: 100%|██████████| 382/382 [14:25<00:00,  2.27s/it]


Epoch [3/100], Train Gender Loss: 0.6875, Train Style Loss: 79.8882, Train Total Loss: 80.5757
Train Gender Acc: 0.4477, Train Style Acc: 0.0366, Train Total Acc: 0.2421


Validation Epoch 3/100: 100%|██████████| 30/30 [01:00<00:00,  2.01s/it]


Validation Gender Loss: 0.6881, Validation Style Loss: 79.3782, Validation Total Loss: 80.0664
Validation Gender Acc: 0.4479, Validation Style Acc: 0.0557, Validation Total Acc: 0.2518
Current Learning Rate: 0.01


Training Epoch 4/100: 100%|██████████| 382/382 [15:17<00:00,  2.40s/it]


Epoch [4/100], Train Gender Loss: 0.6877, Train Style Loss: 79.8525, Train Total Loss: 80.5402
Train Gender Acc: 0.4477, Train Style Acc: 0.0414, Train Total Acc: 0.2446


Validation Epoch 4/100: 100%|██████████| 30/30 [01:04<00:00,  2.13s/it]


Validation Gender Loss: 0.6885, Validation Style Loss: 79.7062, Validation Total Loss: 80.3947
Validation Gender Acc: 0.4479, Validation Style Acc: 0.0305, Validation Total Acc: 0.2392
Current Learning Rate: 0.01


Training Epoch 5/100: 100%|██████████| 382/382 [16:16<00:00,  2.56s/it]


Epoch [5/100], Train Gender Loss: 0.6874, Train Style Loss: 79.8098, Train Total Loss: 80.4972
Train Gender Acc: 0.4475, Train Style Acc: 0.0399, Train Total Acc: 0.2437


Validation Epoch 5/100: 100%|██████████| 30/30 [01:14<00:00,  2.48s/it]


Validation Gender Loss: 0.6877, Validation Style Loss: 79.5857, Validation Total Loss: 80.2734
Validation Gender Acc: 0.4479, Validation Style Acc: 0.0484, Validation Total Acc: 0.2482
Current Learning Rate: 0.01


Training Epoch 6/100: 100%|██████████| 382/382 [16:13<00:00,  2.55s/it]


Epoch [6/100], Train Gender Loss: 0.6875, Train Style Loss: 79.7313, Train Total Loss: 80.4188
Train Gender Acc: 0.4483, Train Style Acc: 0.0392, Train Total Acc: 0.2438


Validation Epoch 6/100: 100%|██████████| 30/30 [01:04<00:00,  2.13s/it]


Validation Gender Loss: 0.6860, Validation Style Loss: 79.3721, Validation Total Loss: 80.0581
Validation Gender Acc: 0.4501, Validation Style Acc: 0.0326, Validation Total Acc: 0.2413
Current Learning Rate: 0.01


Training Epoch 7/100: 100%|██████████| 382/382 [14:55<00:00,  2.34s/it]


Epoch [7/100], Train Gender Loss: 0.6872, Train Style Loss: 79.6765, Train Total Loss: 80.3637
Train Gender Acc: 0.4483, Train Style Acc: 0.0405, Train Total Acc: 0.2444


Validation Epoch 7/100: 100%|██████████| 30/30 [01:03<00:00,  2.13s/it]


Validation Gender Loss: 0.6864, Validation Style Loss: 79.6127, Validation Total Loss: 80.2991
Validation Gender Acc: 0.4774, Validation Style Acc: 0.0641, Validation Total Acc: 0.2708
Current Learning Rate: 0.01


Training Epoch 8/100: 100%|██████████| 382/382 [14:55<00:00,  2.34s/it]


Epoch [8/100], Train Gender Loss: 0.6868, Train Style Loss: 79.5761, Train Total Loss: 80.2629
Train Gender Acc: 0.4480, Train Style Acc: 0.0408, Train Total Acc: 0.2444


Validation Epoch 8/100: 100%|██████████| 30/30 [01:03<00:00,  2.11s/it]


Validation Gender Loss: 0.6859, Validation Style Loss: 78.6667, Validation Total Loss: 79.3526
Validation Gender Acc: 0.4669, Validation Style Acc: 0.0620, Validation Total Acc: 0.2645
Current Learning Rate: 0.01


Training Epoch 9/100: 100%|██████████| 382/382 [14:08<00:00,  2.22s/it]


Epoch [9/100], Train Gender Loss: 0.6866, Train Style Loss: 79.5606, Train Total Loss: 80.2472
Train Gender Acc: 0.4507, Train Style Acc: 0.0416, Train Total Acc: 0.2462


Validation Epoch 9/100: 100%|██████████| 30/30 [01:00<00:00,  2.02s/it]


Validation Gender Loss: 0.6869, Validation Style Loss: 78.7565, Validation Total Loss: 79.4434
Validation Gender Acc: 0.4479, Validation Style Acc: 0.0515, Validation Total Acc: 0.2497
Current Learning Rate: 0.01


Training Epoch 10/100: 100%|██████████| 382/382 [14:04<00:00,  2.21s/it]


Epoch [10/100], Train Gender Loss: 0.6868, Train Style Loss: 79.4796, Train Total Loss: 80.1664
Train Gender Acc: 0.4492, Train Style Acc: 0.0405, Train Total Acc: 0.2449


Validation Epoch 10/100: 100%|██████████| 30/30 [01:00<00:00,  2.03s/it]


Validation Gender Loss: 0.6863, Validation Style Loss: 78.6605, Validation Total Loss: 79.3468
Validation Gender Acc: 0.4479, Validation Style Acc: 0.0589, Validation Total Acc: 0.2534
Current Learning Rate: 0.01


Training Epoch 11/100: 100%|██████████| 382/382 [14:26<00:00,  2.27s/it]


Epoch [11/100], Train Gender Loss: 0.6866, Train Style Loss: 79.3893, Train Total Loss: 80.0760
Train Gender Acc: 0.4490, Train Style Acc: 0.0429, Train Total Acc: 0.2459


Validation Epoch 11/100: 100%|██████████| 30/30 [01:01<00:00,  2.04s/it]


Validation Gender Loss: 0.6847, Validation Style Loss: 78.7111, Validation Total Loss: 79.3958
Validation Gender Acc: 0.4501, Validation Style Acc: 0.0599, Validation Total Acc: 0.2550
Current Learning Rate: 0.01


Training Epoch 12/100: 100%|██████████| 382/382 [14:05<00:00,  2.21s/it]


Epoch [12/100], Train Gender Loss: 0.6863, Train Style Loss: 79.3087, Train Total Loss: 79.9950
Train Gender Acc: 0.4491, Train Style Acc: 0.0426, Train Total Acc: 0.2459


Validation Epoch 12/100: 100%|██████████| 30/30 [01:00<00:00,  2.02s/it]


Validation Gender Loss: 0.6856, Validation Style Loss: 78.8224, Validation Total Loss: 79.5080
Validation Gender Acc: 0.4479, Validation Style Acc: 0.0578, Validation Total Acc: 0.2529
Current Learning Rate: 0.005


Training Epoch 13/100: 100%|██████████| 382/382 [14:02<00:00,  2.21s/it]


Epoch [13/100], Train Gender Loss: 0.6854, Train Style Loss: 78.9329, Train Total Loss: 79.6183
Train Gender Acc: 0.4496, Train Style Acc: 0.0470, Train Total Acc: 0.2483


Validation Epoch 13/100: 100%|██████████| 30/30 [01:00<00:00,  2.03s/it]


Validation Gender Loss: 0.6851, Validation Style Loss: 78.1897, Validation Total Loss: 78.8747
Validation Gender Acc: 0.4479, Validation Style Acc: 0.0620, Validation Total Acc: 0.2550
Current Learning Rate: 0.005


Training Epoch 14/100: 100%|██████████| 382/382 [13:48<00:00,  2.17s/it]


Epoch [14/100], Train Gender Loss: 0.6840, Train Style Loss: 78.6320, Train Total Loss: 79.3160
Train Gender Acc: 0.4511, Train Style Acc: 0.0472, Train Total Acc: 0.2491


Validation Epoch 14/100: 100%|██████████| 30/30 [00:59<00:00,  1.99s/it]


Validation Gender Loss: 0.6825, Validation Style Loss: 78.1306, Validation Total Loss: 78.8131
Validation Gender Acc: 0.4585, Validation Style Acc: 0.0631, Validation Total Acc: 0.2608
Current Learning Rate: 0.005


Training Epoch 15/100: 100%|██████████| 382/382 [14:56<00:00,  2.35s/it]


Epoch [15/100], Train Gender Loss: 0.6830, Train Style Loss: 78.1833, Train Total Loss: 78.8662
Train Gender Acc: 0.4616, Train Style Acc: 0.0544, Train Total Acc: 0.2580


Validation Epoch 15/100: 100%|██████████| 30/30 [01:06<00:00,  2.23s/it]


Validation Gender Loss: 0.6783, Validation Style Loss: 76.9614, Validation Total Loss: 77.6397
Validation Gender Acc: 0.4711, Validation Style Acc: 0.0810, Validation Total Acc: 0.2760
Current Learning Rate: 0.005


Training Epoch 16/100: 100%|██████████| 382/382 [14:59<00:00,  2.36s/it]


Epoch [16/100], Train Gender Loss: 0.6791, Train Style Loss: 77.6816, Train Total Loss: 78.3608
Train Gender Acc: 0.4941, Train Style Acc: 0.0663, Train Total Acc: 0.2802


Validation Epoch 16/100: 100%|██████████| 30/30 [01:03<00:00,  2.10s/it]


Validation Gender Loss: 0.6758, Validation Style Loss: 77.1843, Validation Total Loss: 77.8601
Validation Gender Acc: 0.5174, Validation Style Acc: 0.0810, Validation Total Acc: 0.2992
Current Learning Rate: 0.005


Training Epoch 17/100: 100%|██████████| 382/382 [15:00<00:00,  2.36s/it]


Epoch [17/100], Train Gender Loss: 0.6770, Train Style Loss: 76.8957, Train Total Loss: 77.5727
Train Gender Acc: 0.5142, Train Style Acc: 0.0722, Train Total Acc: 0.2932


Validation Epoch 17/100: 100%|██████████| 30/30 [01:04<00:00,  2.16s/it]


Validation Gender Loss: 0.6727, Validation Style Loss: 75.9180, Validation Total Loss: 76.5907
Validation Gender Acc: 0.5195, Validation Style Acc: 0.0852, Validation Total Acc: 0.3023
Current Learning Rate: 0.005


Training Epoch 18/100: 100%|██████████| 382/382 [15:24<00:00,  2.42s/it]


Epoch [18/100], Train Gender Loss: 0.6750, Train Style Loss: 75.8160, Train Total Loss: 76.4911
Train Gender Acc: 0.5219, Train Style Acc: 0.0859, Train Total Acc: 0.3039


Validation Epoch 18/100: 100%|██████████| 30/30 [01:13<00:00,  2.46s/it]


Validation Gender Loss: 0.6688, Validation Style Loss: 76.6706, Validation Total Loss: 77.3394
Validation Gender Acc: 0.5121, Validation Style Acc: 0.0873, Validation Total Acc: 0.2997
Current Learning Rate: 0.005


Training Epoch 19/100: 100%|██████████| 382/382 [17:01<00:00,  2.67s/it]


Epoch [19/100], Train Gender Loss: 0.6749, Train Style Loss: 73.9410, Train Total Loss: 74.6158
Train Gender Acc: 0.5230, Train Style Acc: 0.0964, Train Total Acc: 0.3097


Validation Epoch 19/100: 100%|██████████| 30/30 [01:14<00:00,  2.47s/it]


Validation Gender Loss: 0.6698, Validation Style Loss: 76.0202, Validation Total Loss: 76.6900
Validation Gender Acc: 0.5647, Validation Style Acc: 0.1073, Validation Total Acc: 0.3360
Current Learning Rate: 0.005


Training Epoch 20/100: 100%|██████████| 382/382 [16:25<00:00,  2.58s/it]


Epoch [20/100], Train Gender Loss: 0.6725, Train Style Loss: 72.0544, Train Total Loss: 72.7270
Train Gender Acc: 0.5347, Train Style Acc: 0.1061, Train Total Acc: 0.3204


Validation Epoch 20/100: 100%|██████████| 30/30 [01:03<00:00,  2.11s/it]


Validation Gender Loss: 0.6671, Validation Style Loss: 74.1203, Validation Total Loss: 74.7874
Validation Gender Acc: 0.5426, Validation Style Acc: 0.1167, Validation Total Acc: 0.3297
Current Learning Rate: 0.005


Training Epoch 21/100: 100%|██████████| 382/382 [14:33<00:00,  2.29s/it]


Epoch [21/100], Train Gender Loss: 0.6712, Train Style Loss: 69.5564, Train Total Loss: 70.2276
Train Gender Acc: 0.5420, Train Style Acc: 0.1262, Train Total Acc: 0.3341


Validation Epoch 21/100: 100%|██████████| 30/30 [01:02<00:00,  2.09s/it]


Validation Gender Loss: 0.6737, Validation Style Loss: 73.5008, Validation Total Loss: 74.1745
Validation Gender Acc: 0.5058, Validation Style Acc: 0.1094, Validation Total Acc: 0.3076
Current Learning Rate: 0.005


Training Epoch 22/100: 100%|██████████| 382/382 [14:37<00:00,  2.30s/it]


Epoch [22/100], Train Gender Loss: 0.6710, Train Style Loss: 66.2388, Train Total Loss: 66.9098
Train Gender Acc: 0.5395, Train Style Acc: 0.1450, Train Total Acc: 0.3422


Validation Epoch 22/100: 100%|██████████| 30/30 [01:03<00:00,  2.10s/it]


Validation Gender Loss: 0.6637, Validation Style Loss: 70.9861, Validation Total Loss: 71.6498
Validation Gender Acc: 0.6015, Validation Style Acc: 0.1661, Validation Total Acc: 0.3838
Current Learning Rate: 0.005


Training Epoch 23/100: 100%|██████████| 382/382 [14:20<00:00,  2.25s/it]


Epoch [23/100], Train Gender Loss: 0.6714, Train Style Loss: 61.5674, Train Total Loss: 62.2387
Train Gender Acc: 0.5382, Train Style Acc: 0.1755, Train Total Acc: 0.3568


Validation Epoch 23/100: 100%|██████████| 30/30 [01:00<00:00,  2.02s/it]


Validation Gender Loss: 0.6714, Validation Style Loss: 67.3717, Validation Total Loss: 68.0431
Validation Gender Acc: 0.6120, Validation Style Acc: 0.2029, Validation Total Acc: 0.4075
Current Learning Rate: 0.005


Training Epoch 24/100: 100%|██████████| 382/382 [14:03<00:00,  2.21s/it]


Epoch [24/100], Train Gender Loss: 0.6715, Train Style Loss: 54.0372, Train Total Loss: 54.7087
Train Gender Acc: 0.5409, Train Style Acc: 0.2159, Train Total Acc: 0.3784


Validation Epoch 24/100: 100%|██████████| 30/30 [01:01<00:00,  2.04s/it]


Validation Gender Loss: 0.6626, Validation Style Loss: 63.5927, Validation Total Loss: 64.2553
Validation Gender Acc: 0.5857, Validation Style Acc: 0.2355, Validation Total Acc: 0.4106
Current Learning Rate: 0.005


Training Epoch 25/100: 100%|██████████| 382/382 [14:20<00:00,  2.25s/it]


Epoch [25/100], Train Gender Loss: 0.6709, Train Style Loss: 44.3085, Train Total Loss: 44.9794
Train Gender Acc: 0.5446, Train Style Acc: 0.2792, Train Total Acc: 0.4119


Validation Epoch 25/100: 100%|██████████| 30/30 [01:01<00:00,  2.03s/it]


Validation Gender Loss: 0.6683, Validation Style Loss: 62.2873, Validation Total Loss: 62.9556
Validation Gender Acc: 0.5100, Validation Style Acc: 0.2355, Validation Total Acc: 0.3728
Current Learning Rate: 0.005


Training Epoch 26/100: 100%|██████████| 382/382 [14:30<00:00,  2.28s/it]


Epoch [26/100], Train Gender Loss: 0.6704, Train Style Loss: 33.6901, Train Total Loss: 34.3606
Train Gender Acc: 0.5399, Train Style Acc: 0.3471, Train Total Acc: 0.4435


Validation Epoch 26/100: 100%|██████████| 30/30 [01:02<00:00,  2.07s/it]


Validation Gender Loss: 0.6708, Validation Style Loss: 62.5961, Validation Total Loss: 63.2670
Validation Gender Acc: 0.5321, Validation Style Acc: 0.2923, Validation Total Acc: 0.4122
Current Learning Rate: 0.005


Training Epoch 27/100: 100%|██████████| 382/382 [14:35<00:00,  2.29s/it]


Epoch [27/100], Train Gender Loss: 0.6715, Train Style Loss: 23.5235, Train Total Loss: 24.1951
Train Gender Acc: 0.5477, Train Style Acc: 0.4117, Train Total Acc: 0.4797


Validation Epoch 27/100: 100%|██████████| 30/30 [01:02<00:00,  2.07s/it]


Validation Gender Loss: 0.6691, Validation Style Loss: 72.7610, Validation Total Loss: 73.4301
Validation Gender Acc: 0.5405, Validation Style Acc: 0.3039, Validation Total Acc: 0.4222
Current Learning Rate: 0.005


Training Epoch 28/100: 100%|██████████| 382/382 [14:35<00:00,  2.29s/it]


Epoch [28/100], Train Gender Loss: 0.6692, Train Style Loss: 16.3858, Train Total Loss: 17.0550
Train Gender Acc: 0.5505, Train Style Acc: 0.4604, Train Total Acc: 0.5055


Validation Epoch 28/100: 100%|██████████| 30/30 [01:03<00:00,  2.11s/it]


Validation Gender Loss: 0.6611, Validation Style Loss: 66.4254, Validation Total Loss: 67.0865
Validation Gender Acc: 0.5626, Validation Style Acc: 0.3470, Validation Total Acc: 0.4548
Current Learning Rate: 0.005


Training Epoch 29/100:  29%|██▉       | 112/382 [04:17<10:20,  2.30s/it]


KeyboardInterrupt: 

In [None]:
# 체크포인트 저장된 모델 사용할거면 아래 코드 경로 수정하여 사용
load_checkpoint(filename="C:/Users/SW/Desktop/데크캠/mission1-2/checkpoint_multi.pth", weights_only=False)


def evaluate_with_misclassifications(model, test_loader, criterion, num_classes, class_labels, test_dataset):
    model.eval()  # 모델을 평가 모드로 전환
    test_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    # 클래스별 정확도를 계산하기 위한 변수
    correct_per_class = [0] * num_classes  # 각 클래스별로 맞은 개수를 저장
    total_per_class = [0] * num_classes    # 각 클래스별로 총 샘플 수를 저장

    # 잘못 분류된 샘플 기록
    misclassified = {i: [] for i in range(num_classes)}  # 각 클래스별로 잘못 분류된 샘플 기록

    # no_grad 사용
    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(test_loader, desc="Evaluating")):  # images와 labels만 사용
            images, labels = images.to(device), labels.to(device)

            # 모델에 입력 후 출력 계산
            outputs = model(images)

            # 손실 계산
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            # 정확도 계산
            _, predicted = torch.max(outputs, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

            # 각 클래스별 맞은 개수와 총 샘플 수 추적
            for j in range(len(labels)):
                label = labels[j].item()
                pred = predicted[j].item()

                # 잘못 분류된 경우 기록
                if label != pred:
                    img_path = test_dataset.samples[i * test_loader.batch_size + j][0]  # 파일 경로 추출
                    misclassified[label].append((img_path, pred))

                # 정확한 예측일 경우
                if label == pred:
                    correct_per_class[label] += 1
                total_per_class[label] += 1

    # 평균 손실 및 전체 정확도 계산
    avg_loss = test_loss / len(test_loader)
    overall_accuracy = correct_predictions / total_samples

    # 클래스별 정확도 계산
    class_accuracy = []
    for i in range(num_classes):
        if total_per_class[i] > 0:
            accuracy = correct_per_class[i] / total_per_class[i]
        else:
            accuracy = 0.0
        class_accuracy.append(accuracy)

    # 결과 출력 (클래스 라벨과 함께)
    print(f"Test Loss: {avg_loss:.4f}, Overall Test Accuracy: {overall_accuracy:.4f}")
    for i, acc in enumerate(class_accuracy):
        print(f"{class_labels[i]}: Accuracy = {acc:.4f}")

    return avg_loss, overall_accuracy, class_accuracy, misclassified


test_loader = val_loader
num_classes = 31
class_labels = train_dataset_2.classes
test_dataset = val_dataset

avg_loss, overall_accuracy, class_accuracy, misclassified = evaluate_with_misclassifications(model, test_loader, criterion, num_classes, class_labels, test_dataset)

Evaluating: 100%|██████████████████████████████████████████████████████████████████████| 30/30 [02:12<00:00,  4.41s/it]

Test Loss: 2.3989, Overall Test Accuracy: 0.5489
M_bold: Accuracy = 0.5614
M_hiphop: Accuracy = 0.5000
M_hippie: Accuracy = 0.6098
M_ivy: Accuracy = 0.6329
M_metrosexual: Accuracy = 0.4828
M_mods: Accuracy = 0.5875
M_normcore: Accuracy = 0.0980
M_sportivecasual: Accuracy = 0.4423
W_athleisure: Accuracy = 0.6429
W_bodyconscious: Accuracy = 0.6522
W_cityglam: Accuracy = 0.5000
W_classic: Accuracy = 0.6818
W_disco: Accuracy = 0.4000
W_ecology: Accuracy = 0.5294
W_feminine: Accuracy = 0.7273
W_genderless: Accuracy = 0.7500
W_grunge: Accuracy = 0.6000
W_hiphop: Accuracy = 0.5000
W_hippie: Accuracy = 0.5714
W_kitsch: Accuracy = 0.5909
W_lingerie: Accuracy = 0.6000
W_lounge: Accuracy = 0.1250
W_military: Accuracy = 0.4444
W_minimal: Accuracy = 0.5429
W_normcore: Accuracy = 0.2500
W_oriental: Accuracy = 0.6667
W_popart: Accuracy = 0.7500
W_powersuit: Accuracy = 0.7353
W_punk: Accuracy = 0.3333
W_space: Accuracy = 0.6000
W_sportivecasual: Accuracy = 0.6875





In [None]:
import os

# 지정된 클래스에서 잘못 분류된 샘플 출력
target_class = 'M_normcore'
target_class_index = class_labels.index(target_class)

if len(misclassified[target_class_index]) > 0:
    print(f"\nMisclassifications for class '{target_class}':")
    for img_path, pred_class in misclassified[target_class_index]:
        file_name = os.path.basename(img_path)  # 파일명만 추출
        print(f"File: {file_name} misclassified as {class_labels[pred_class]}")
else:
    print(f"\nNo misclassifications for class '{target_class}'.")


Misclassifications for class 'M_normcore':
File: W_00117_19_normcore_M.jpg misclassified as M_metrosexual
File: W_00551_19_normcore_M.jpg misclassified as M_ivy
File: W_00831_19_normcore_M.jpg misclassified as M_sportivecasual
File: W_01410_19_normcore_M.jpg misclassified as M_sportivecasual
File: W_01552_19_normcore_M.jpg misclassified as M_hippie
File: W_02705_19_normcore_M.jpg misclassified as W_hippie
File: W_03003_19_normcore_M.jpg misclassified as M_hiphop
File: W_06609_19_normcore_M.jpg misclassified as M_metrosexual
File: W_06860_19_normcore_M.jpg misclassified as W_ecology
File: W_06917_19_normcore_M.jpg misclassified as M_metrosexual
File: W_06966_19_normcore_M.jpg misclassified as M_sportivecasual
File: W_07058_19_normcore_M.jpg misclassified as M_metrosexual
File: W_07077_19_normcore_M.jpg misclassified as W_hiphop
File: W_07120_19_normcore_M.jpg misclassified as M_bold
File: W_10103_19_normcore_M.jpg misclassified as M_mods
File: W_12412_19_normcore_M.jpg misclassified as