In [1]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import warnings
warnings.filterwarnings('ignore')

import torch
from torch import nn, optim
from torch.utils.tensorboard import SummaryWriter
from timm import create_model
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

import pandas as pd
from datetime import datetime
from sklearn.metrics import accuracy_score

from dataloaders.loaders import get_dataloaders
from models.train_utils import train_one_epoch, evaluate, EarlyStopping, multiclass_log_loss
from analysis.result_plotter import analyze_model_output

import matplotlib.pyplot as plt
import random
import numpy as np

# Fix all random seeds
def seed_everything(seed=28):
    """Seed every random number generator this notebook relies on.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator and PyTorch's CUDA
    generator. cuDNN is then switched to deterministic mode with benchmarking
    (autotuning) turned off.

    Args:
        seed: Integer seed shared by all generators. Defaults to 28.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    # Prefer repeatable cuDNN kernel selection over autotuned speed.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Seed all RNGs before any data loader or model is created further down.
seed_everything(28)

# Korean font setup so Hangul text in matplotlib figures can be rendered.
# NOTE(review): 'Malgun Gothic' is presumably only installed on Windows — confirm
# the font exists on the machine that runs this notebook.
plt.rcParams['font.family'] = 'Malgun Gothic'
# Render minus signs with the ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False

# ✅ Configuration
# Dataset root handed to get_dataloaders.
data_root = '../data/train2'
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Upper bound on epochs per model; early stopping may end training sooner.
num_epochs = 50
# Passed to EarlyStopping(patience=...) in the training loop.
patience = 5
# Passed to get_dataloaders(batch_size=...).
batch_size = 32
# Parent directory of the per-model TensorBoard run directories.
base_log_dir = "../runs"
os.makedirs(base_log_dir, exist_ok=True)

# ✅ Build the data loaders
# The result is iterated as {model_name: loaders} by the training loop, where each
# `loaders` entry is indexed with 'train', 'val' and 'classes'.
dataloaders = get_dataloaders(data_root, batch_size=batch_size)

# Report the dataset sizes of the first model's loaders
first_model_name = list(dataloaders.keys())[0]
train_loader = dataloaders[first_model_name]['train']
val_loader = dataloaders[first_model_name]['val']
print(f"학습 데이터 개수: {len(train_loader.dataset)}")
print(f"검증 데이터 개수: {len(val_loader.dataset)}")
print(f"총 데이터 개수: {len(train_loader.dataset) + len(val_loader.dataset)}")

# One summary row per trained model; a single timestamp tags every run started by this cell.
results = []
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')


def _validation_metrics(model, val_loader, device):
    """Run ``evaluate`` on ``val_loader`` and score the predictions.

    Args:
        model: Network to evaluate.
        val_loader: Validation data loader passed straight to ``evaluate``.
        device: Device passed straight to ``evaluate``.

    Returns:
        tuple: ``(y_pred, y_prob, y_true, acc, logloss)``. The first three are
        returned by ``evaluate`` unchanged, ``acc`` is ``accuracy_score`` and
        ``logloss`` is ``multiclass_log_loss`` computed on the same predictions.
    """
    y_pred, y_prob, y_true, y_id, _ = evaluate(model, val_loader, device)
    acc = accuracy_score(y_true, y_pred)

    # multiclass_log_loss is fed two ID-keyed frames: one probability column per
    # class index (named "0", "1", ...) and the true labels as strings.
    class_list = [str(i) for i in range(y_prob.shape[1])]
    prob_df = pd.DataFrame(y_prob, columns=class_list)
    prob_df.insert(0, 'ID', y_id)
    label_df = pd.DataFrame({'ID': y_id, 'label': [str(l) for l in y_true]})
    logloss = multiclass_log_loss(label_df, prob_df)
    return y_pred, y_prob, y_true, acc, logloss


for model_name, loaders in dataloaders.items():
    print(f"\n🚀 학습 시작: {model_name}")

    model = create_model(model_name, pretrained=True, num_classes=len(loaders['classes'])).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    # Cosine schedule with warm restarts: the first cycle lasts 10 epochs and each
    # following cycle is twice as long as the previous one.
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
    criterion = nn.CrossEntropyLoss()
    early_stopping = EarlyStopping(patience=patience)

    log_dir = os.path.join(base_log_dir, f"{model_name}_{timestamp}")
    writer = SummaryWriter(log_dir=log_dir)
    try:
        for epoch in range(num_epochs):
            train_loss = train_one_epoch(model, loaders['train'], criterion, optimizer, device)
            _, _, _, acc, logloss = _validation_metrics(model, loaders['val'], device)

            writer.add_scalar("Loss/Train", train_loss, epoch)
            writer.add_scalar("Loss/Validation", logloss, epoch)
            writer.add_scalar("Accuracy/Validation", acc, epoch)
            # Logged before scheduler.step(), so this is the rate used during this epoch.
            writer.add_scalar("LearningRate", scheduler.get_last_lr()[0], epoch)

            print(f"📘 Epoch {epoch + 1} ✅ Acc: {acc:.4f} | LogLoss: {logloss:.4f}")
            scheduler.step()
            early_stopping(logloss, model)
            if early_stopping.early_stop:
                print("⛔ Early stopping triggered.")
                break
    finally:
        # Flush and close the event file even if training raises or is interrupted.
        writer.close()

    # Restore the best checkpoint recorded by early stopping.
    # NOTE(review): assumes best_model_state is missing/None until a checkpoint has
    # been recorded — confirm against EarlyStopping.
    best_state = getattr(early_stopping, 'best_model_state', None)
    if best_state is not None:
        model.load_state_dict(best_state)

    # Re-score the restored weights. The values left over from the loop belong to
    # the *last* epoch, which differs from the best one whenever early stopping
    # fired, so reusing them would report metrics of a model that was discarded.
    # This costs one extra validation pass per model.
    y_pred, y_prob, y_true, acc, logloss = _validation_metrics(model, loaders['val'], device)

    # ✅ Save the analysis artefacts automatically
    # A Subset-style dataset exposes `indices` plus the wrapped `dataset`; anything
    # else is treated as the base dataset itself.
    # NOTE(review): assumes the base dataset has ImageFolder-style `samples`
    # ((path, label) pairs) and that the validation loader is not shuffled, so the
    # paths line up with the prediction order — confirm.
    val_dataset = loaders['val'].dataset
    if hasattr(val_dataset, 'indices'):
        val_indices = val_dataset.indices
        base_dataset = val_dataset.dataset
    else:
        val_indices = range(len(val_dataset))
        base_dataset = val_dataset
    image_paths = [base_dataset.samples[i][0] for i in val_indices]

    analyze_model_output(
        model_name=model_name,
        timestamp=timestamp,
        image_paths=image_paths,
        y_pred=y_pred,
        y_prob=y_prob,
        y_true=y_true,
        class_names=loaders['classes']
    )

    results.append({'model': model_name, 'accuracy': acc, 'log_loss': logloss, 'timestamp': timestamp})

# ✅ Final comparison, lowest log loss first
# Explicit columns keep sort_values working even when no model was trained.
df_result = pd.DataFrame(
    results, columns=['model', 'accuracy', 'log_loss', 'timestamp']
).sort_values(by='log_loss')
display(df_result)


계층적 분할 확인
학습 데이터 개수: 25113
검증 데이터 개수: 6279
총 데이터 개수: 31392

🚀 학습 시작: swin_tiny_patch4_window7_224

📘 Epoch 1


                                                        

✅ Acc: 0.5749 | LogLoss: 1.7029

📘 Epoch 2


                                                        

✅ Acc: 0.7791 | LogLoss: 0.7752

📘 Epoch 3


                                                        

✅ Acc: 0.8323 | LogLoss: 0.5452

📘 Epoch 4


                                                        

✅ Acc: 0.8720 | LogLoss: 0.3974

📘 Epoch 5


                                                        

✅ Acc: 0.8801 | LogLoss: 0.3560

📘 Epoch 6


                                                        

✅ Acc: 0.8971 | LogLoss: 0.3076

📘 Epoch 7


                                                        

✅ Acc: 0.9089 | LogLoss: 0.2673

📘 Epoch 8


                                                        

✅ Acc: 0.9108 | LogLoss: 0.2619

📘 Epoch 9


                                                        

✅ Acc: 0.9186 | LogLoss: 0.2404

📘 Epoch 10


                                                        

✅ Acc: 0.9115 | LogLoss: 0.2482
📉 EarlyStopping: 1/5

📘 Epoch 11


                                                        

✅ Acc: 0.9259 | LogLoss: 0.2133

📘 Epoch 12


                                                        

✅ Acc: 0.9202 | LogLoss: 0.2189
📉 EarlyStopping: 1/5

📘 Epoch 13


                                                        

✅ Acc: 0.9255 | LogLoss: 0.2148
📉 EarlyStopping: 2/5

📘 Epoch 14


                                                        

✅ Acc: 0.9232 | LogLoss: 0.2173
📉 EarlyStopping: 3/5

📘 Epoch 15


                                                        

✅ Acc: 0.9339 | LogLoss: 0.1860

📘 Epoch 16


                                                        

✅ Acc: 0.9374 | LogLoss: 0.1767

📘 Epoch 17


                                                        

✅ Acc: 0.9409 | LogLoss: 0.1714

📘 Epoch 18


                                                        

✅ Acc: 0.9423 | LogLoss: 0.1672

📘 Epoch 19


                                                        

✅ Acc: 0.9454 | LogLoss: 0.1659

📘 Epoch 20


                                                        

✅ Acc: 0.9446 | LogLoss: 0.1576

📘 Epoch 21


                                                        

✅ Acc: 0.9470 | LogLoss: 0.1561

📘 Epoch 22


                                                        

✅ Acc: 0.9463 | LogLoss: 0.1575
📉 EarlyStopping: 1/5

📘 Epoch 23


                                                        

✅ Acc: 0.9468 | LogLoss: 0.1476

📘 Epoch 24


                                                        

✅ Acc: 0.9489 | LogLoss: 0.1523
📉 EarlyStopping: 1/5

📘 Epoch 25


                                                        

✅ Acc: 0.9530 | LogLoss: 0.1427

📘 Epoch 26


                                                        

✅ Acc: 0.9530 | LogLoss: 0.1397

📘 Epoch 27


                                                        

✅ Acc: 0.9565 | LogLoss: 0.1336

📘 Epoch 28


                                                        

✅ Acc: 0.9522 | LogLoss: 0.1385
📉 EarlyStopping: 1/5

📘 Epoch 29


                                                        

✅ Acc: 0.9567 | LogLoss: 0.1283

📘 Epoch 30


                                                        

✅ Acc: 0.9564 | LogLoss: 0.1289
📉 EarlyStopping: 1/5

📘 Epoch 31


                                                        

✅ Acc: 0.9562 | LogLoss: 0.1289
📉 EarlyStopping: 2/5

📘 Epoch 32


                                                        

✅ Acc: 0.9595 | LogLoss: 0.1198

📘 Epoch 33


                                                        

✅ Acc: 0.9599 | LogLoss: 0.1216
📉 EarlyStopping: 1/5

📘 Epoch 34


                                                        

✅ Acc: 0.9608 | LogLoss: 0.1175

📘 Epoch 35


                                                        

✅ Acc: 0.9602 | LogLoss: 0.1195
📉 EarlyStopping: 1/5

📘 Epoch 36


                                                        

✅ Acc: 0.9632 | LogLoss: 0.1150

📘 Epoch 37


                                                        

✅ Acc: 0.9615 | LogLoss: 0.1116

📘 Epoch 38


                                                        

✅ Acc: 0.9650 | LogLoss: 0.1153
📉 EarlyStopping: 1/5

📘 Epoch 39


                                                        

✅ Acc: 0.9658 | LogLoss: 0.1084

📘 Epoch 40


                                                        

✅ Acc: 0.9631 | LogLoss: 0.1080

📘 Epoch 41


                                                        

✅ Acc: 0.9616 | LogLoss: 0.1082
📉 EarlyStopping: 1/5

📘 Epoch 42


                                                        

✅ Acc: 0.9646 | LogLoss: 0.1086
📉 EarlyStopping: 2/5

📘 Epoch 43


                                                        

✅ Acc: 0.9645 | LogLoss: 0.1078

📘 Epoch 44


                                                        

✅ Acc: 0.9654 | LogLoss: 0.1058

📘 Epoch 45


                                                        

✅ Acc: 0.9645 | LogLoss: 0.1071
📉 EarlyStopping: 1/5

📘 Epoch 46


                                                        

✅ Acc: 0.9650 | LogLoss: 0.1066
📉 EarlyStopping: 2/5

📘 Epoch 47


                                                        

✅ Acc: 0.9640 | LogLoss: 0.1065
📉 EarlyStopping: 3/5

📘 Epoch 48


                                                        

✅ Acc: 0.9648 | LogLoss: 0.1057

📘 Epoch 49


                                                        

✅ Acc: 0.9650 | LogLoss: 0.1055

📘 Epoch 50


                                                        

✅ Acc: 0.9650 | LogLoss: 0.1055


Unnamed: 0,model,accuracy,log_loss,timestamp
0,swin_tiny_patch4_window7_224,0.964963,0.105472,20250527_141339


In [2]:
# Environment check: print the installed PyTorch version, then the CUDA version
# reported by torch.version.cuda, one per line.
import torch

print(torch.__version__, torch.version.cuda, sep="\n")
  

2.7.0+cu118
11.8
