# In [None]:
import optuna
import torch
import torch.nn as nn
import torch.optim as optim
from optuna.pruners import MedianPruner
from timm import create_model
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Device selection: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset (CIFAR10 as an example; replace with any dataset you like).
# Images are resized to 224x224 — presumably the default input resolution of
# the timm model created in `objective`; confirm if you switch architectures.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Hyperparameters are selected on a held-out slice of the *training* split.
# The official CIFAR10 test split is kept out of the search entirely: choosing
# hyperparameters on it would make the reported best accuracy optimistically
# biased. NOTE: train/val subsets share `transform`; if you add training-time
# augmentation later, give the validation subset its own (non-augmenting) view.
VAL_FRACTION = 0.1
SPLIT_SEED = 42

full_train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
n_val = int(len(full_train_dataset) * VAL_FRACTION)
train_dataset, val_dataset = random_split(
    full_train_dataset,
    [len(full_train_dataset) - n_val, n_val],
    # A fixed generator keeps the split identical across trials/re-runs so
    # validation scores are comparable.
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Reserved for a single, final evaluation of the chosen configuration.
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Pinned host memory speeds up host->GPU copies; it brings nothing on CPU-only runs.
pin_memory = device.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2, pin_memory=pin_memory)

def evaluate(model, dataloader, device=None):
    """Return the top-1 accuracy of ``model`` over every batch in ``dataloader``.

    The model is switched to eval mode and run without gradient tracking.
    The predicted class of each sample is the argmax of the model output
    along dimension 1.

    Args:
        model: classifier producing one score per class for each sample.
        dataloader: iterable yielding ``(inputs, targets)`` batches.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model), so the
            batches always land where the model lives.

    Returns:
        Fraction of correctly classified samples, in ``[0, 1]``.

    Raises:
        ValueError: if ``dataloader`` yields no samples (accuracy is undefined).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            # argmax returns the first maximal index, same as torch.max(..., 1)[1].
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    if total == 0:
        raise ValueError("evaluate() received a dataloader that yielded no samples")
    return correct / total

def train_one_epoch(model, optimizer, criterion, loader, device=None):
    """Run one optimisation pass over ``loader`` and return the mean training loss.

    For every batch: move it to ``device``, clear gradients, compute
    ``criterion(model(x), y)``, back-propagate and take an optimizer step.

    Args:
        model: network being trained; switched to train mode first.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss callable ``criterion(outputs, targets) -> scalar tensor``.
        loader: iterable yielding ``(inputs, targets)`` batches.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model).

    Returns:
        Sample-weighted mean of the per-batch losses as a Python float, or
        ``float('nan')`` if ``loader`` yields no samples. Callers that ignore
        the return value behave exactly as before.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    loss_sum, seen = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        # Assumes criterion uses 'mean' reduction (the nn.CrossEntropyLoss
        # default), so re-weight by batch size to get a per-sample mean.
        # Accumulate as a detached tensor to avoid a host sync every step.
        loss_sum = loss_sum + loss.detach() * batch_size
        seen += batch_size

    return float(loss_sum) / seen if seen else float('nan')

def objective(trial, model_name='coat_lite_medium', num_epochs=5):
    """Optuna objective: briefly train a timm model and return validation accuracy.

    Samples ``lr`` and ``weight_decay`` log-uniformly and ``drop_path_rate``
    uniformly. It then trains with AdamW + cross-entropy on ``train_loader``
    for ``num_epochs`` epochs. After every epoch it reports ``val_loader``
    accuracy so the study's pruner can stop unpromising trials early.

    Args:
        trial: the ``optuna.trial.Trial`` supplied by ``study.optimize``.
        model_name: timm architecture to build. timm registers
            coat_tiny / coat_mini / coat_small and coat_lite_{tiny,mini,small,medium}.
            There is no ``'coat_medium'``, so the previous name made every trial
            fail inside ``create_model``.
            NOTE(review): confirm ``coat_lite_medium`` is the intended variant.
        num_epochs: training epochs per trial; kept short for fast tuning.

    Returns:
        Validation accuracy after the final epoch.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner stops the trial.
        ValueError: if ``num_epochs`` is less than 1 (no accuracy to return).
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    # Hyperparameter sampling. suggest_float(..., log=True) is the
    # non-deprecated replacement for suggest_loguniform (same distribution).
    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-2, log=True)
    drop_path_rate = trial.suggest_float('drop_path_rate', 0.0, 0.3)

    # Build the model from scratch (no pretrained weights); 10 output classes for CIFAR10.
    model = create_model(
        model_name,
        pretrained=False,
        num_classes=10,
        drop_path_rate=drop_path_rate
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        train_one_epoch(model, optimizer, criterion, train_loader)
        val_acc = evaluate(model, val_loader)
        # Report the intermediate value with the epoch as the step, for pruning.
        trial.report(val_acc, epoch)

        # Abort this trial if the pruner judges it unpromising.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return val_acc

# Launch the Optuna search, maximising the value returned by `objective`.
# MedianPruner compares each trial's intermediate values with those of earlier
# trials at the same step. It does no pruning during the first 2 steps
# (n_warmup_steps) of a trial.
median_pruner = MedianPruner(n_warmup_steps=2)
study = optuna.create_study(pruner=median_pruner, direction='maximize')
# Stop launching new trials after 30 trials or 3600 seconds, whichever comes first.
study.optimize(objective, timeout=3600, n_trials=30)

print("Best trial:")
best_params = study.best_trial.params
print(best_params)
