In [5]:
!ls

sample_data  나이분류.zip


In [1]:
import zipfile
import os

# 압축 파일 경로
zip_path = './나이분류.zip'

# 압축을 풀 폴더 경로
extract_path = './'

# 압축 풀기
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

# 압축이 풀린 파일 목록 확인
os.listdir(extract_path)


['.config', '나이분류', '나이분류.zip', 'sample_data']

In [15]:
!ls -al

total 341556
drwxr-xr-x 5 root root      4096 Sep 23 00:16 나이분류
drwxr-xr-x 1 root root      4096 Sep 23 00:16 .
drwxr-xr-x 1 root root      4096 Sep 23 00:00 ..
drwxr-xr-x 4 root root      4096 Sep 19 13:25 .config
drwxr-xr-x 1 root root      4096 Sep 19 13:25 sample_data
-rw-r--r-- 1 root root 349728946 Sep 23 00:15 나이분류.zip


In [2]:
cd 나이분류

/content/나이분류


In [3]:
pwd

'/content/나이분류'

In [6]:
import os
import torch
import torchvision.models as models
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm  # TQDM import

# Custom dataset
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = os.listdir(root_dir)
        self.data = []

        for label in range(len(self.classes)):
            class_folder = os.path.join(root_dir, self.classes[label])
            for filename in os.listdir(class_folder):
                img_path = os.path.join(class_folder, filename)
                self.data.append((img_path, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert('RGB')  # 이미지 RGB로 변환
        if self.transform:
            image = self.transform(image)
        return image, label

# 경로 및 배치 크기 설정
data_dir = "."
batch_size = 32

# 데이터 증강 포함한 이미지 전처리
transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),  # 랜덤 가로 뒤집기
    T.RandomRotation(10),  # 랜덤 회전
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # 색상 변형
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 학습 및 검증 데이터셋 생성
train_dataset = CustomDataset(os.path.join(data_dir, 'train'), transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

valid_dataset = CustomDataset(os.path.join(data_dir, 'valid'), transform=transform)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# 사전 학습된 resnext50_32x4d 모델 사용
model = models.resnext50_32x4d(pretrained=True)

# 전이 학습을 위해 일부 가중치 고정 (freeze)
for param in model.parameters():
    param.requires_grad = False

# 출력 레이어를 분류하려는 클래스 수에 맞게 수정
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.5),  # Dropout 추가
    torch.nn.Linear(model.fc.in_features, len(train_dataset.classes))
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 손실 함수 및 옵티마이저 설정
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

# 학습률 스케줄러 추가 (학습률 점진적 감소)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# 학습 과정 (TQDM으로 진행 상황 시각화)
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    # TQDM으로 학습 진행 표시
    train_loader_iter = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for images, labels in train_loader_iter:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        train_loader_iter.set_postfix(loss=running_loss / len(train_loader))

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")

    # 학습률 스케줄러 업데이트
    scheduler.step()

# 모델 저장
torch.save(model.state_dict(), 'resnext_model.pth')

# 검증 평가
model.eval()
correct = 0
total = 0
with torch.no_grad():
    valid_loader_iter = tqdm(valid_loader, desc="Validating")  # 검증도 TQDM으로 진행 상황 시각화
    for images, labels in valid_loader_iter:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Validation Accuracy: {accuracy}%')


Downloading: "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth" to /root/.cache/torch/hub/checkpoints/resnext50_32x4d-7cdf4587.pth
100%|██████████| 95.8M/95.8M [00:00<00:00, 155MB/s]
Epoch 1/10: 100%|██████████| 431/431 [03:14<00:00,  2.22it/s, loss=1.02]


Epoch 1/10, Loss: 1.0211076253804696


Epoch 2/10: 100%|██████████| 431/431 [03:15<00:00,  2.20it/s, loss=0.976]


Epoch 2/10, Loss: 0.9755381224050322


Epoch 3/10: 100%|██████████| 431/431 [03:14<00:00,  2.21it/s, loss=0.964]


Epoch 3/10, Loss: 0.9640533044553965


Epoch 4/10: 100%|██████████| 431/431 [03:14<00:00,  2.21it/s, loss=0.965]


Epoch 4/10, Loss: 0.9648061083530328


Epoch 5/10: 100%|██████████| 431/431 [03:14<00:00,  2.21it/s, loss=0.954]


Epoch 5/10, Loss: 0.9536708980591281


Epoch 6/10: 100%|██████████| 431/431 [03:16<00:00,  2.19it/s, loss=0.966]


Epoch 6/10, Loss: 0.9662154117759032


Epoch 7/10: 100%|██████████| 431/431 [03:14<00:00,  2.21it/s, loss=0.965]


Epoch 7/10, Loss: 0.964905275961086


Epoch 8/10: 100%|██████████| 431/431 [03:14<00:00,  2.21it/s, loss=0.939]


Epoch 8/10, Loss: 0.9388825762022952


Epoch 9/10: 100%|██████████| 431/431 [03:15<00:00,  2.20it/s, loss=0.941]


Epoch 9/10, Loss: 0.9413373774552843


Epoch 10/10: 100%|██████████| 431/431 [03:15<00:00,  2.20it/s, loss=0.93]


Epoch 10/10, Loss: 0.9297795649747561


Validating: 100%|██████████| 54/54 [00:23<00:00,  2.33it/s]

Validation Accuracy: 55.8584686774942%





In [7]:
import os
import torch
import torchvision.models as models
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm  # TQDM import

# Custom dataset
class CustomDataset(Dataset): 
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = os.listdir(root_dir)
        self.data = []

        for label in range(len(self.classes)):
            class_folder = os.path.join(root_dir, self.classes[label])
            for filename in os.listdir(class_folder):
                img_path = os.path.join(class_folder, filename)
                self.data.append((img_path, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert('RGB')  # 이미지 RGB로 변환
        if self.transform:
            image = self.transform(image)
        return image, label

# 경로 및 배치 크기 설정
data_dir = "."
batch_size = 32

# 데이터 증강 포함한 이미지 전처리
transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),  # 랜덤 가로 뒤집기
    T.RandomRotation(10),  # 랜덤 회전
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # 색상 변형
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 학습 및 검증 데이터셋 생성
train_dataset = CustomDataset(os.path.join(data_dir, 'train'), transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

valid_dataset = CustomDataset(os.path.join(data_dir, 'valid'), transform=transform)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# 사전 학습된 resnext50_32x4d 모델 사용
model = models.resnext50_32x4d(pretrained=True)

# 모든 레이어 학습 가능하도록 설정 (고정 부분 제거)
# 이제 모든 파라미터가 학습에 참여함
for param in model.parameters():
    param.requires_grad = True

# 출력 레이어를 분류하려는 클래스 수에 맞게 수정
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.5),  # Dropout 추가
    torch.nn.Linear(model.fc.in_features, len(train_dataset.classes))
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 손실 함수 및 옵티마이저 설정
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.00005, momentum=0.9)

# 학습률 스케줄러 추가 (학습률 점진적 감소)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# 학습 과정 (TQDM으로 진행 상황 시각화)
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    # TQDM으로 학습 진행 표시
    train_loader_iter = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for images, labels in train_loader_iter:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        train_loader_iter.set_postfix(loss=running_loss / len(train_loader))

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")

    # 학습률 스케줄러 업데이트
    scheduler.step()

# 모델 저장
torch.save(model.state_dict(), 'resnext_model.pth')

# 검증 평가
model.eval()
correct = 0
total = 0
with torch.no_grad():
    valid_loader_iter = tqdm(valid_loader, desc="Validating")  # 검증도 TQDM으로 진행 상황 시각화
    for images, labels in valid_loader_iter:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Validation Accuracy: {accuracy}%')


Epoch 1/10: 100%|██████████| 431/431 [05:27<00:00,  1.32it/s, loss=1.08]


Epoch 1/10, Loss: 1.0825531928002696


Epoch 2/10: 100%|██████████| 431/431 [05:26<00:00,  1.32it/s, loss=0.962]


Epoch 2/10, Loss: 0.96152822006053


Epoch 3/10: 100%|██████████| 431/431 [05:26<00:00,  1.32it/s, loss=0.863]


Epoch 3/10, Loss: 0.8625074017905304


Epoch 4/10: 100%|██████████| 431/431 [05:26<00:00,  1.32it/s, loss=0.795]


Epoch 4/10, Loss: 0.7950573925905604


Epoch 5/10: 100%|██████████| 431/431 [05:26<00:00,  1.32it/s, loss=0.744]


Epoch 5/10, Loss: 0.744026968692129


Epoch 6/10: 100%|██████████| 431/431 [05:26<00:00,  1.32it/s, loss=0.717]


Epoch 6/10, Loss: 0.7173273258447094


Epoch 7/10: 100%|██████████| 431/431 [05:26<00:00,  1.32it/s, loss=0.698]


Epoch 7/10, Loss: 0.6979419654040768


Epoch 8/10: 100%|██████████| 431/431 [05:26<00:00,  1.32it/s, loss=0.675]


Epoch 8/10, Loss: 0.6747959918599671


Epoch 9/10: 100%|██████████| 431/431 [05:26<00:00,  1.32it/s, loss=0.662]


Epoch 9/10, Loss: 0.6621346672145507


Epoch 10/10: 100%|██████████| 431/431 [05:26<00:00,  1.32it/s, loss=0.669]


Epoch 10/10, Loss: 0.66901260303234


Validating: 100%|██████████| 54/54 [00:23<00:00,  2.34it/s]

Validation Accuracy: 69.3155452436195%





In [2]:
import os
import torch
import torchvision.models as models
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from PIL import Image
from tqdm import tqdm  # TQDM import
import torch.nn as nn
from collections import Counter

# Custom dataset
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = os.listdir(root_dir)
        self.data = []
        
        for label in range(len(self.classes)):
            class_folder = os.path.join(root_dir, self.classes[label])
            for filename in os.listdir(class_folder):
                img_path = os.path.join(class_folder, filename)
                self.data.append((img_path, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert('RGB')  # 이미지 RGB로 변환
        if self.transform:
            image = self.transform(image)
        return image, label

# 경로 및 배치 크기 설정
data_dir = "."
batch_size = 32

# 데이터 증강 포함한 이미지 전처리
transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.RandomRotation(15),  # 더 큰 회전 각도
    T.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # 이미지 이동
    T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),  # 색상 변형
    T.GaussianBlur(kernel_size=3),  # Gaussian Blur 추가
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 학습 및 검증 데이터셋 생성
train_dataset = CustomDataset(os.path.join(data_dir, 'train'), transform=transform)
valid_dataset = CustomDataset(os.path.join(data_dir, 'valid'), transform=transform)

# 각 클래스의 데이터 개수 계산 (Counter 사용)
class_counts = Counter([label for _, label in train_dataset])

# 클래스별로 가중치를 부여 (데이터 개수의 역수 사용)
class_weights = {cls: 1.0 / count for cls, count in class_counts.items()}

# 각 샘플의 가중치를 리스트로 변환
sample_weights = [class_weights[label] for _, label in train_dataset]

# WeightedRandomSampler 생성
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights))

# WeightedRandomSampler를 적용한 DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# 사전 학습된 resnext50_32x4d 모델 사용
from torchvision.models import ResNeXt50_32X4D_Weights

weights = ResNeXt50_32X4D_Weights.DEFAULT
model = models.resnext50_32x4d(weights=weights)

# 모든 레이어 학습 가능하도록 설정 (전이 학습 제외)
for param in model.parameters():
    param.requires_grad = True

# 출력 레이어를 분류하려는 클래스 수에 맞게 수정 (드롭아웃 비율 0.3으로 설정)
model.fc = nn.Sequential(
    nn.Dropout(0.3),  # Dropout 비율을 0.3으로 조정
    nn.Linear(model.fc.in_features, len(train_dataset.classes))
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Early Stopping 설정
best_accuracy = 0
patience = 5  # 개선되지 않는 에포크를 허용하는 최대 수
counter = 0

# 손실 함수 및 AdamW 옵티마이저 설정 (학습률 0.0001로 감소)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)

# 학습률 스케줄러 추가 (학습률 점진적 감소, gamma=0.5, step_size=10)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# 학습 과정 (TQDM으로 진행 상황 시각화, 에포크 30)
num_epochs = 30
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    # TQDM으로 학습 진행 표시
    train_loader_iter = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for images, labels in train_loader_iter:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        train_loader_iter.set_postfix(loss=running_loss / len(train_loader))
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")
    
    # 학습률 스케줄러 업데이트
    scheduler.step()

    # 검증 평가
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Validation Accuracy: {accuracy}%')
    
    # Early Stopping 기준
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        counter = 0  # 성능 개선 시 카운터 초기화
        torch.save(model.state_dict(), 'best_model.pth')  # 최적의 모델 저장
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping 적용")
            break

Epoch 1/30: 100%|████████████████████████████████████████████████████████| 863/863 [03:27<00:00,  4.16it/s, loss=0.761]


Epoch 1/30, Loss: 0.7608831753460032
Validation Accuracy: 69.77958236658932%


Epoch 2/30: 100%|████████████████████████████████████████████████████████| 863/863 [03:22<00:00,  4.26it/s, loss=0.626]


Epoch 2/30, Loss: 0.6255019191449652
Validation Accuracy: 70.70765661252901%


Epoch 3/30: 100%|█████████████████████████████████████████████████████████| 863/863 [03:21<00:00,  4.27it/s, loss=0.57]


Epoch 3/30, Loss: 0.5699071202984
Validation Accuracy: 73.20185614849188%


Epoch 4/30: 100%|████████████████████████████████████████████████████████| 863/863 [03:24<00:00,  4.23it/s, loss=0.527]


Epoch 4/30, Loss: 0.5266329214961062
Validation Accuracy: 72.62180974477958%


Epoch 5/30: 100%|████████████████████████████████████████████████████████| 863/863 [03:20<00:00,  4.31it/s, loss=0.495]


Epoch 5/30, Loss: 0.495164231899025
Validation Accuracy: 72.73781902552204%


Epoch 6/30: 100%|████████████████████████████████████████████████████████| 863/863 [03:20<00:00,  4.31it/s, loss=0.456]


Epoch 6/30, Loss: 0.45606091718684755
Validation Accuracy: 72.2737819025522%


Epoch 7/30: 100%|████████████████████████████████████████████████████████| 863/863 [03:20<00:00,  4.30it/s, loss=0.432]


Epoch 7/30, Loss: 0.432258178131398
Validation Accuracy: 74.07192575406033%


Epoch 8/30: 100%|██████████████████████████████████████████████████████████| 863/863 [03:24<00:00,  4.22it/s, loss=0.4]


Epoch 8/30, Loss: 0.3998504717373489
Validation Accuracy: 75.92807424593967%


Epoch 9/30: 100%|████████████████████████████████████████████████████████| 863/863 [03:21<00:00,  4.28it/s, loss=0.365]


Epoch 9/30, Loss: 0.3652687069559567
Validation Accuracy: 72.62180974477958%


Epoch 10/30: 100%|███████████████████████████████████████████████████████| 863/863 [03:19<00:00,  4.32it/s, loss=0.344]


Epoch 10/30, Loss: 0.3436410233447477
Validation Accuracy: 76.21809744779583%


Epoch 11/30: 100%|███████████████████████████████████████████████████████| 863/863 [03:18<00:00,  4.34it/s, loss=0.264]


Epoch 11/30, Loss: 0.2644685557886941
Validation Accuracy: 75.0%


Epoch 12/30: 100%|███████████████████████████████████████████████████████| 863/863 [03:18<00:00,  4.36it/s, loss=0.224]


Epoch 12/30, Loss: 0.22382964441813918
Validation Accuracy: 74.65197215777262%


Epoch 13/30: 100%|███████████████████████████████████████████████████████| 863/863 [03:19<00:00,  4.33it/s, loss=0.207]


Epoch 13/30, Loss: 0.2073700617208887
Validation Accuracy: 73.78190255220417%


Epoch 14/30: 100%|███████████████████████████████████████████████████████| 863/863 [03:17<00:00,  4.36it/s, loss=0.187]


Epoch 14/30, Loss: 0.18740442308013436
Validation Accuracy: 74.94199535962878%


Epoch 15/30: 100%|████████████████████████████████████████████████████████| 863/863 [03:17<00:00,  4.37it/s, loss=0.17]


Epoch 15/30, Loss: 0.1696912146303997
Validation Accuracy: 73.72389791183295%
Early stopping 적용


In [2]:
import os
import torch
import torchvision.models as models
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from PIL import Image
from tqdm import tqdm  # TQDM import
import torch.nn as nn
from collections import Counter

# Custom dataset
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = os.listdir(root_dir)
        self.data = []
        
        for label in range(len(self.classes)):
            class_folder = os.path.join(root_dir, self.classes[label])
            for filename in os.listdir(class_folder):
                img_path = os.path.join(class_folder, filename)
                self.data.append((img_path, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert('RGB')  # 이미지 RGB로 변환
        if self.transform:
            image = self.transform(image)
        return image, label

# 경로 및 배치 크기 설정
data_dir = "."
batch_size = 64

# 데이터 증강 포함한 이미지 전처리
transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.RandomRotation(15),  # 더 큰 회전 각도
    T.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # 이미지 이동
    T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),  # 색상 변형
    T.GaussianBlur(kernel_size=3),  # Gaussian Blur 추가
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 학습 및 검증 데이터셋 생성
train_dataset = CustomDataset(os.path.join(data_dir, 'train'), transform=transform)
valid_dataset = CustomDataset(os.path.join(data_dir, 'valid'), transform=transform)

# 각 클래스의 데이터 개수 계산 (Counter 사용)
class_counts = Counter([label for _, label in train_dataset])

# 클래스별로 가중치를 부여 (데이터 개수의 역수 사용)
class_weights = {cls: 1.0 / count for cls, count in class_counts.items()}

# 각 샘플의 가중치를 리스트로 변환
sample_weights = [class_weights[label] for _, label in train_dataset]

# WeightedRandomSampler 생성
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights))

# WeightedRandomSampler를 적용한 DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# 사전 학습된 resnext50_32x4d 모델 사용
from torchvision.models import ResNeXt50_32X4D_Weights

weights = ResNeXt50_32X4D_Weights.DEFAULT
model = models.resnext50_32x4d(weights=weights)

# 모든 레이어 학습 가능하도록 설정 (전이 학습 제외)
for param in model.parameters():
    param.requires_grad = True

# 출력 레이어를 분류하려는 클래스 수에 맞게 수정 (드롭아웃 비율 0.3으로 설정)
model.fc = nn.Sequential(
    nn.Dropout(0.3),  # Dropout 비율을 0.3으로 조정
    nn.Linear(model.fc.in_features, len(train_dataset.classes))
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Early Stopping 설정
best_accuracy = 0
patience = 30  # 개선되지 않는 에포크를 허용하는 최대 수
counter = 0

# 손실 함수 및 AdamW 옵티마이저 설정 (학습률 0.0001로 감소)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)

# 학습률 스케줄러 추가 (학습률 점진적 감소, gamma=0.5, step_size=10)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# 학습 과정 (TQDM으로 진행 상황 시각화, 에포크 30)
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    # TQDM으로 학습 진행 표시
    train_loader_iter = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for images, labels in train_loader_iter:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        train_loader_iter.set_postfix(loss=running_loss / len(train_loader))
    
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")
    
    # 학습률 스케줄러 업데이트
    scheduler.step()

    # 검증 평가
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Validation Accuracy: {accuracy}%')
    
    # Early Stopping 기준
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        counter = 0  # 성능 개선 시 카운터 초기화
        torch.save(model.state_dict(), 'best_model.pth')  # 최적의 모델 저장
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping 적용")
            break

# 베스트 모델의 혼동 행렬 시각화 및 성능 출력
if best_cm is not None:
    plt.figure(figsize=(8, 6))
    sns.heatmap(best_cm, annot=True, fmt='d', cmap='Blues', xticklabels=train_dataset.classes, yticklabels=train_dataset.classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Best Model Confusion Matrix\nValidation Accuracy: {best_accuracy:.3f}%')
    plt.show()


Epoch 1/100: 100%|███████████████████████████████████████████████████████| 216/216 [07:55<00:00,  2.20s/it, loss=0.791]


Epoch 1/100, Loss: 0.7910682476229138
Validation Accuracy: 69.60556844547564%


Epoch 2/100: 100%|███████████████████████████████████████████████████████| 216/216 [08:07<00:00,  2.26s/it, loss=0.593]


Epoch 2/100, Loss: 0.5930374179173399
Validation Accuracy: 69.66357308584686%


Epoch 3/100: 100%|███████████████████████████████████████████████████████| 216/216 [07:33<00:00,  2.10s/it, loss=0.539]


Epoch 3/100, Loss: 0.5385223522230431
Validation Accuracy: 73.25986078886311%


Epoch 4/100: 100%|███████████████████████████████████████████████████████| 216/216 [05:21<00:00,  1.49s/it, loss=0.491]


Epoch 4/100, Loss: 0.491256192188572
Validation Accuracy: 73.89791183294663%


Epoch 5/100: 100%|███████████████████████████████████████████████████████| 216/216 [05:55<00:00,  1.65s/it, loss=0.451]


Epoch 5/100, Loss: 0.4514921543498834
Validation Accuracy: 73.60788863109049%


Epoch 6/100: 100%|███████████████████████████████████████████████████████| 216/216 [05:43<00:00,  1.59s/it, loss=0.415]


Epoch 6/100, Loss: 0.41532192744866564
Validation Accuracy: 74.76798143851508%


Epoch 7/100: 100%|███████████████████████████████████████████████████████| 216/216 [05:42<00:00,  1.58s/it, loss=0.384]


Epoch 7/100, Loss: 0.3838025319079558
Validation Accuracy: 73.78190255220417%


Epoch 8/100: 100%|███████████████████████████████████████████████████████| 216/216 [05:35<00:00,  1.55s/it, loss=0.352]


Epoch 8/100, Loss: 0.35245026151339215
Validation Accuracy: 72.8538283062645%


Epoch 9/100: 100%|███████████████████████████████████████████████████████| 216/216 [06:05<00:00,  1.69s/it, loss=0.326]


Epoch 9/100, Loss: 0.32595981367760235
Validation Accuracy: 75.92807424593967%


Epoch 10/100: 100%|██████████████████████████████████████████████████████| 216/216 [05:57<00:00,  1.65s/it, loss=0.283]


Epoch 10/100, Loss: 0.28328480157587266
Validation Accuracy: 75.34802784222738%


Epoch 11/100: 100%|██████████████████████████████████████████████████████| 216/216 [04:30<00:00,  1.25s/it, loss=0.244]


Epoch 11/100, Loss: 0.24359282692549405
Validation Accuracy: 76.16009280742459%


Epoch 12/100: 100%|██████████████████████████████████████████████████████| 216/216 [03:33<00:00,  1.01it/s, loss=0.197]


Epoch 12/100, Loss: 0.19674221840169695
Validation Accuracy: 75.63805104408353%


Epoch 13/100: 100%|██████████████████████████████████████████████████████| 216/216 [03:31<00:00,  1.02it/s, loss=0.182]


Epoch 13/100, Loss: 0.18153227134435265
Validation Accuracy: 75.52204176334106%


Epoch 14/100: 100%|██████████████████████████████████████████████████████| 216/216 [03:32<00:00,  1.01it/s, loss=0.161]


Epoch 14/100, Loss: 0.16116474863762656
Validation Accuracy: 75.05800464037122%


Epoch 15/100: 100%|██████████████████████████████████████████████████████| 216/216 [04:05<00:00,  1.14s/it, loss=0.146]


Epoch 15/100, Loss: 0.1458194467963444
Validation Accuracy: 76.62412993039443%


Epoch 16/100: 100%|███████████████████████████████████████████████████████| 216/216 [03:44<00:00,  1.04s/it, loss=0.13]


Epoch 16/100, Loss: 0.12993475903446475
Validation Accuracy: 74.0139211136891%


Epoch 17/100: 100%|███████████████████████████████████████████████████████| 216/216 [03:36<00:00,  1.00s/it, loss=0.12]


Epoch 17/100, Loss: 0.12000551600768058
Validation Accuracy: 74.53596287703016%


Epoch 18/100: 100%|███████████████████████████████████████████████████████| 216/216 [03:33<00:00,  1.01it/s, loss=0.11]


Epoch 18/100, Loss: 0.1100142245111918
Validation Accuracy: 74.88399071925754%


Epoch 19/100: 100%|██████████████████████████████████████████████████████| 216/216 [03:35<00:00,  1.00it/s, loss=0.095]


Epoch 19/100, Loss: 0.09500225612255572
Validation Accuracy: 75.75406032482599%


Epoch 20/100: 100%|██████████████████████████████████████████████████████| 216/216 [03:23<00:00,  1.06it/s, loss=0.097]


Epoch 20/100, Loss: 0.09698562211081109
Validation Accuracy: 75.63805104408353%


Epoch 21/100: 100%|█████████████████████████████████████████████████████| 216/216 [03:34<00:00,  1.01it/s, loss=0.0735]


Epoch 21/100, Loss: 0.07349310338893836
Validation Accuracy: 75.52204176334106%


Epoch 22/100: 100%|██████████████████████████████████████████████████████| 216/216 [03:30<00:00,  1.03it/s, loss=0.064]


Epoch 22/100, Loss: 0.06399960341406297
Validation Accuracy: 75.34802784222738%


Epoch 23/100: 100%|█████████████████████████████████████████████████████| 216/216 [03:37<00:00,  1.01s/it, loss=0.0611]


Epoch 23/100, Loss: 0.06113607194964533
Validation Accuracy: 76.62412993039443%


Epoch 24/100: 100%|█████████████████████████████████████████████████████| 216/216 [03:30<00:00,  1.03it/s, loss=0.0532]


Epoch 24/100, Loss: 0.05321656477516862
Validation Accuracy: 75.52204176334106%


Epoch 25/100: 100%|█████████████████████████████████████████████████████| 216/216 [03:23<00:00,  1.06it/s, loss=0.0546]


Epoch 25/100, Loss: 0.0546055160467168
Validation Accuracy: 75.5800464037123%


Epoch 26/100: 100%|█████████████████████████████████████████████████████| 216/216 [03:27<00:00,  1.04it/s, loss=0.0464]


Epoch 26/100, Loss: 0.046384206754198576
Validation Accuracy: 76.56612529002321%


Epoch 27/100: 100%|█████████████████████████████████████████████████████| 216/216 [03:29<00:00,  1.03it/s, loss=0.0428]


Epoch 27/100, Loss: 0.04280897003115603
Validation Accuracy: 75.29002320185614%


Epoch 28/100: 100%|█████████████████████████████████████████████████████| 216/216 [05:30<00:00,  1.53s/it, loss=0.0436]


Epoch 28/100, Loss: 0.043597310981971935
Validation Accuracy: 75.63805104408353%


Epoch 29/100: 100%|█████████████████████████████████████████████████████| 216/216 [05:25<00:00,  1.51s/it, loss=0.0428]


Epoch 29/100, Loss: 0.042794372920912725
Validation Accuracy: 75.9860788863109%


Epoch 30/100: 100%|█████████████████████████████████████████████████████| 216/216 [05:32<00:00,  1.54s/it, loss=0.0432]


Epoch 30/100, Loss: 0.043197707899337356
Validation Accuracy: 75.23201856148492%


Epoch 31/100: 100%|█████████████████████████████████████████████████████| 216/216 [04:18<00:00,  1.20s/it, loss=0.0339]


Epoch 31/100, Loss: 0.03389811770412726
Validation Accuracy: 75.87006960556845%


Epoch 32/100: 100%|██████████████████████████████████████████████████████| 216/216 [03:37<00:00,  1.01s/it, loss=0.031]


Epoch 32/100, Loss: 0.0310487362321173
Validation Accuracy: 76.10208816705337%


Epoch 33/100: 100%|█████████████████████████████████████████████████████| 216/216 [03:28<00:00,  1.04it/s, loss=0.0289]


Epoch 33/100, Loss: 0.028880091047203342
Validation Accuracy: 76.16009280742459%


Epoch 34/100: 100%|█████████████████████████████████████████████████████| 216/216 [03:33<00:00,  1.01it/s, loss=0.0314]


Epoch 34/100, Loss: 0.03140771496044989
Validation Accuracy: 74.88399071925754%


Epoch 35/100: 100%|█████████████████████████████████████████████████████| 216/216 [04:11<00:00,  1.16s/it, loss=0.0261]


Epoch 35/100, Loss: 0.026074306739987892
Validation Accuracy: 75.34802784222738%


Epoch 36/100: 100%|█████████████████████████████████████████████████████| 216/216 [05:48<00:00,  1.61s/it, loss=0.0241]


Epoch 36/100, Loss: 0.024114842201082932
Validation Accuracy: 75.87006960556845%


Epoch 37/100: 100%|█████████████████████████████████████████████████████| 216/216 [04:52<00:00,  1.35s/it, loss=0.0298]


Epoch 37/100, Loss: 0.02975374093943241
Validation Accuracy: 75.52204176334106%


Epoch 38/100: 100%|██████████████████████████████████████████████████████| 216/216 [05:14<00:00,  1.45s/it, loss=0.022]


Epoch 38/100, Loss: 0.02202541112880378
Validation Accuracy: 75.63805104408353%


Epoch 39/100: 100%|█████████████████████████████████████████████████████| 216/216 [05:31<00:00,  1.53s/it, loss=0.0238]


Epoch 39/100, Loss: 0.023849894059283002
Validation Accuracy: 75.23201856148492%


Epoch 40/100: 100%|█████████████████████████████████████████████████████| 216/216 [05:21<00:00,  1.49s/it, loss=0.0234]


Epoch 40/100, Loss: 0.023425545124660455
Validation Accuracy: 75.69605568445476%


Epoch 41/100: 100%|█████████████████████████████████████████████████████| 216/216 [05:21<00:00,  1.49s/it, loss=0.0207]


Epoch 41/100, Loss: 0.020722217628240794
Validation Accuracy: 76.45011600928075%


Epoch 42/100: 100%|█████████████████████████████████████████████████████| 216/216 [05:06<00:00,  1.42s/it, loss=0.0221]


Epoch 42/100, Loss: 0.022119192573612695
Validation Accuracy: 76.39211136890951%


Epoch 43/100: 100%|█████████████████████████████████████████████████████| 216/216 [04:51<00:00,  1.35s/it, loss=0.0206]


Epoch 43/100, Loss: 0.020644812814242432
Validation Accuracy: 75.0%


Epoch 44/100: 100%|█████████████████████████████████████████████████████| 216/216 [03:44<00:00,  1.04s/it, loss=0.0196]


Epoch 44/100, Loss: 0.019602143966444094
Validation Accuracy: 75.92807424593967%


Epoch 45/100: 100%|█████████████████████████████████████████████████████| 216/216 [05:27<00:00,  1.51s/it, loss=0.0164]


Epoch 45/100, Loss: 0.01635756100491832
Validation Accuracy: 74.82598607888632%
Early stopping 적용
