In [433]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision import models
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

In [462]:
# 하이퍼파라미터 설정
batch_size = 32
epochs = 150
learning_rate = 0.0001
num_classes = 7  # 얼굴 클래스 개수
val_split = 0.2  # 검증 데이터 비율
early_stop_patience = 3  # 조기 종료 기준 (연속으로 개선되지 않은 epoch 수)

In [463]:
transform = transforms.Compose([
    transforms.RandomRotation(degrees= 10),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p = 0.5),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
    transforms.GaussianBlur(kernel_size=(5,5)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [464]:
# 데이터셋 로드 (훈련 데이터만 사용)
dataset = datasets.ImageFolder(root='훈련데이터 경로/train_face_bts', transform=transform)
train_size = int((1 - val_split) * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [465]:
print(f"클래스 목록 : {dataset.classes}")

클래스 목록 : ['jhope', 'jimin', 'jin', 'jungkook', 'rm', 'suga', 'v']


In [466]:
# ResNet 모델 불러오기 및 수정
model = models.resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_classes)  # 마지막 층 수정
model = model.cuda() if torch.cuda.is_available() else model

In [467]:
# 손실 함수 및 옵티마이저 설정
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [468]:
# 학습 및 검증 루프
def train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, epochs, early_stop_patience):
    best_val_acc = 0.0
    no_improve_epochs = 0
    
    for epoch in range(epochs):
        # 학습 단계
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} - Training')
        
        for images, labels in train_bar:
            images, labels = images.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            train_bar.set_postfix(loss=running_loss / (train_bar.n + 1), acc=100 * correct / total)
        
        avg_train_loss = running_loss / len(train_loader)
        train_acc = 100 * correct / total
        print(f'Epoch {epoch+1}, Loss: {avg_train_loss:.4f}, Train Accuracy: {train_acc:.2f}%')
        
        # 검증 단계
        model.eval()
        correct, total = 0, 0
        val_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} - Validation')
        
        with torch.no_grad():
            for images, labels in val_bar:
                images, labels = images.cuda(), labels.cuda()
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                val_bar.set_postfix(acc=100 * correct / total)
        
        val_acc = 100 * correct / total
        print(f'Validation Accuracy: {val_acc:.2f}%')
        
        # 조기 종료 조건 확인
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            no_improve_epochs = 0
            torch.save(model.state_dict(), 'resnet_best.pth')  # 최적 모델 저장
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= early_stop_patience:
                print(f'Early stopping at epoch {epoch+1}. No improvement for {early_stop_patience} epochs.')
                break

In [469]:
# 실행
train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, epochs, early_stop_patience)

Epoch 1/150 - Training: 100%|█████████████████████████████████████| 51/51 [00:12<00:00,  3.93it/s, acc=24.4, loss=1.87]


Epoch 1, Loss: 1.8678, Train Accuracy: 24.42%


Epoch 1/150 - Validation: 100%|████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.69it/s, acc=31.9]


Validation Accuracy: 31.87%


Epoch 2/150 - Training: 100%|█████████████████████████████████████| 51/51 [00:12<00:00,  4.08it/s, acc=43.2, loss=1.52]


Epoch 2, Loss: 1.5250, Train Accuracy: 43.19%


Epoch 2/150 - Validation: 100%|████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.73it/s, acc=38.5]


Validation Accuracy: 38.46%


Epoch 3/150 - Training: 100%|█████████████████████████████████████| 51/51 [00:12<00:00,  4.07it/s, acc=59.8, loss=1.13]


Epoch 3, Loss: 1.1296, Train Accuracy: 59.75%


Epoch 3/150 - Validation: 100%|████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.84it/s, acc=41.8]


Validation Accuracy: 41.76%


Epoch 4/150 - Training: 100%|████████████████████████████████████| 51/51 [00:12<00:00,  4.10it/s, acc=68.1, loss=0.947]


Epoch 4, Loss: 0.9475, Train Accuracy: 68.10%


Epoch 4/150 - Validation: 100%|████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.93it/s, acc=51.6]


Validation Accuracy: 51.65%


Epoch 5/150 - Training: 100%|████████████████████████████████████| 51/51 [00:12<00:00,  4.08it/s, acc=75.5, loss=0.716]


Epoch 5, Loss: 0.7160, Train Accuracy: 75.46%


Epoch 5/150 - Validation: 100%|████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.83it/s, acc=57.1]


Validation Accuracy: 57.14%


Epoch 6/150 - Training: 100%|████████████████████████████████████| 51/51 [00:12<00:00,  4.09it/s, acc=81.6, loss=0.526]


Epoch 6, Loss: 0.5258, Train Accuracy: 81.60%


Epoch 6/150 - Validation: 100%|████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.66it/s, acc=54.9]


Validation Accuracy: 54.95%


Epoch 7/150 - Training: 100%|████████████████████████████████████| 51/51 [00:12<00:00,  4.05it/s, acc=85.3, loss=0.417]


Epoch 7, Loss: 0.4171, Train Accuracy: 85.28%


Epoch 7/150 - Validation: 100%|████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.61it/s, acc=57.1]


Validation Accuracy: 57.14%


Epoch 8/150 - Training: 100%|████████████████████████████████████| 51/51 [00:12<00:00,  4.07it/s, acc=88.2, loss=0.369]


Epoch 8, Loss: 0.3685, Train Accuracy: 88.22%


Epoch 8/150 - Validation: 100%|████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.80it/s, acc=64.8]


Validation Accuracy: 64.84%


Epoch 9/150 - Training: 100%|████████████████████████████████████| 51/51 [00:12<00:00,  4.05it/s, acc=87.9, loss=0.361]


Epoch 9, Loss: 0.3609, Train Accuracy: 87.85%


Epoch 9/150 - Validation: 100%|████████████████████████████████████████████████| 6/6 [00:01<00:00,  5.90it/s, acc=58.2]


Validation Accuracy: 58.24%


Epoch 10/150 - Training: 100%|███████████████████████████████████| 51/51 [00:12<00:00,  4.09it/s, acc=89.4, loss=0.283]


Epoch 10, Loss: 0.2832, Train Accuracy: 89.45%


Epoch 10/150 - Validation: 100%|███████████████████████████████████████████████| 6/6 [00:01<00:00,  5.76it/s, acc=68.1]


Validation Accuracy: 68.13%


Epoch 11/150 - Training: 100%|███████████████████████████████████| 51/51 [00:12<00:00,  4.05it/s, acc=91.9, loss=0.276]


Epoch 11, Loss: 0.2761, Train Accuracy: 91.90%


Epoch 11/150 - Validation: 100%|███████████████████████████████████████████████| 6/6 [00:01<00:00,  5.78it/s, acc=68.1]


Validation Accuracy: 68.13%


Epoch 12/150 - Training: 100%|███████████████████████████████████| 51/51 [00:12<00:00,  4.08it/s, acc=94.8, loss=0.174]


Epoch 12, Loss: 0.1739, Train Accuracy: 94.85%


Epoch 12/150 - Validation: 100%|███████████████████████████████████████████████| 6/6 [00:01<00:00,  5.73it/s, acc=70.3]


Validation Accuracy: 70.33%


Epoch 13/150 - Training: 100%|███████████████████████████████████| 51/51 [00:12<00:00,  4.08it/s, acc=94.4, loss=0.181]


Epoch 13, Loss: 0.1805, Train Accuracy: 94.36%


Epoch 13/150 - Validation: 100%|███████████████████████████████████████████████| 6/6 [00:01<00:00,  5.76it/s, acc=65.9]


Validation Accuracy: 65.93%


Epoch 14/150 - Training: 100%|███████████████████████████████████| 51/51 [00:12<00:00,  4.06it/s, acc=96.2, loss=0.134]


Epoch 14, Loss: 0.1338, Train Accuracy: 96.20%


Epoch 14/150 - Validation: 100%|███████████████████████████████████████████████| 6/6 [00:01<00:00,  5.95it/s, acc=69.2]


Validation Accuracy: 69.23%


Epoch 15/150 - Training: 100%|███████████████████████████████████| 51/51 [00:12<00:00,  4.07it/s, acc=97.2, loss=0.113]


Epoch 15, Loss: 0.1127, Train Accuracy: 97.18%


Epoch 15/150 - Validation: 100%|███████████████████████████████████████████████| 6/6 [00:01<00:00,  5.88it/s, acc=64.8]

Validation Accuracy: 64.84%
Early stopping at epoch 15. No improvement for 3 epochs.





In [397]:
# 최종 모델 저장
torch.save(model.state_dict(), 'resnet_final7.pth')

In [470]:
# 최종 모델 저장
torch.save(model.state_dict(), 'resnet_bts1.pth')