In [1]:
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import datasets, transforms, models


######################################
# 0. Settings
######################################

# Root folder containing one sub-folder per class (abnormal / normal).
# Override with the MEALWORM_DATA_DIR environment variable to run the notebook
# on a machine without the original author's directory layout.
data_dir = os.environ.get("MEALWORM_DATA_DIR", r"C:\dataset\mealworm_crops")
num_classes = 2
batch_size = 32
num_epochs = 10
learning_rate = 1e-4
val_ratio = 0.2   # fraction of images held out for validation
seed = 42         # global torch seed (shuffling, augmentation, head init)

# Seed torch's global RNG on all devices so DataLoader shuffling, the random
# augmentations and the freshly initialised classifier head are reproducible
# across "Restart & Run All". (The train/val split uses its own generator.)
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Output paths for the best model weights and the resumable last checkpoint.
best_model_path = "best_resnet50_mealworm.pth"
checkpoint_path = "checkpoint_resnet50_mealworm_last.pth"


######################################
# 1. Dataset / preprocessing
######################################

# Training preprocessing: resize + light augmentation, ImageNet normalisation
# (matching the pretrained backbone), then random erasing on the tensor.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.5),
])

# Validation preprocessing: deterministic resize + the same normalisation.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Two views of the same directory: one with training augmentation, one with
# deterministic validation preprocessing. Each index refers to the same image
# in both, which lets the validation Subset reuse the split indices below.
full_dataset = datasets.ImageFolder(root=data_dir, transform=train_transform)
val_base_dataset = datasets.ImageFolder(root=data_dir, transform=val_transform)

class_names = full_dataset.classes
print("Classes:", class_names)

# If the two scans disagree (e.g. files added between them), the validation
# indices would point at different images than those excluded from training.
assert val_base_dataset.samples == full_dataset.samples, (
    "train/val ImageFolder scans disagree; did data_dir change between them?"
)

# Train / val split
num_total = len(full_dataset)
num_val = int(num_total * val_ratio)
num_train = num_total - num_val

# An empty side would otherwise surface later as a ZeroDivisionError.
if num_train == 0 or num_val == 0:
    raise ValueError(
        f"Cannot split {num_total} images with val_ratio={val_ratio}: "
        f"got {num_train} train / {num_val} val samples"
    )

# Dedicated, fixed-seed generator: the split does not depend on global RNG state.
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_dataset,
    [num_train, num_val],
    generator=generator
)

# Point the validation indices at the un-augmented view.
val_dataset = Subset(val_base_dataset, val_dataset.indices)

# Pinned host memory only speeds up host->GPU copies; on CPU it just warns.
pin_memory = device.type == "cuda"

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=pin_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=pin_memory
)


######################################
# 2. Model definition (ResNet-50 fine-tuning)
######################################

# Start from the ImageNet-pretrained ResNet-50 checkpoint (IMAGENET1K_V2).
pretrained_weights = models.ResNet50_Weights.IMAGENET1K_V2
resnet50 = models.resnet50(weights=pretrained_weights)

# Swap the final fully connected layer for one emitting `num_classes` logits.
# No layer is frozen, so the whole network is fine-tuned.
in_features = resnet50.fc.in_features
resnet50.fc = nn.Linear(in_features=in_features, out_features=num_classes)
resnet50 = resnet50.to(device)

# Adam over every parameter, with a plain cross-entropy objective.
optimizer = optim.Adam(params=resnet50.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()


######################################
# 3. Load checkpoint to resume training
######################################

best_val_acc = 0.0
start_epoch = 0   # 0-based index of the first epoch to run

if os.path.exists(checkpoint_path):
    print(f"Ï≤¥ÌÅ¨Ìè¨Ïù∏Ìä∏ Î∞úÍ≤¨: {checkpoint_path}, ÌïôÏäµÏùÑ Ïû¨Í∞úÌï©ÎãàÎã§.")
    # NOTE(review): torch.load unpickles the file, which can execute code.
    # Only resume from checkpoints you produced yourself; on torch >= 1.13
    # consider passing weights_only=True.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    resnet50.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    best_val_acc = checkpoint.get("best_val_acc", 0.0)
    # "epoch" is the 0-based index of the last finished epoch; resume after it.
    start_epoch = checkpoint.get("epoch", 0) + 1

    # Use the already-resolved start_epoch (== finished epoch + 1) rather than
    # indexing checkpoint['epoch'] again, which would raise KeyError exactly
    # when the .get() default above was needed.
    print(f"Ïù¥Ï†Ñ ÌïôÏäµ Í∏∞Î°ù: epoch {start_epoch}/{num_epochs}, "
          f"best_val_acc={best_val_acc:.4f}")
else:
    print("Ï≤¥ÌÅ¨Ìè¨Ïù∏Ìä∏ ÏóÜÏùå. Ï≤òÏùåÎ∂ÄÌÑ∞ ÌïôÏäµÏùÑ ÏãúÏûëÌï©ÎãàÎã§.")


######################################
# 4. Train / validation functions
######################################

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: network to train; switched to train mode here.
        dataloader: iterable yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``. Assumed to
            return the batch mean (the ``nn.CrossEntropyLoss`` default), since
            it is multiplied back by the batch size below.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        device: device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        fraction of samples whose argmax over dim 1 matches the label.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in dataloader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n = labels.size(0)
        # Undo the per-batch mean so a short final batch is weighted correctly.
        running_loss += loss.item() * n
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += n

    if total == 0:
        raise ValueError("train_one_epoch: dataloader yielded no samples")

    epoch_loss = running_loss / total
    epoch_acc = correct / total

    return epoch_loss, epoch_acc


def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Args:
        model: network to evaluate; switched to eval mode here.
        dataloader: iterable yielding ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``. Assumed to
            return the batch mean (the ``nn.CrossEntropyLoss`` default), since
            it is multiplied back by the batch size below.
        device: device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        fraction of samples whose argmax over dim 1 matches the label.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    # Inference only. There is deliberately no backward()/optimizer.step()
    # here: the weights must never be updated from validation data.
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            n = labels.size(0)
            running_loss += loss.item() * n
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += n

    if total == 0:
        raise ValueError("validate: dataloader yielded no samples")

    epoch_loss = running_loss / total
    epoch_acc = correct / total

    return epoch_loss, epoch_acc


######################################
# 5. Training loop (supports resuming)
######################################

def _atomic_torch_save(obj, path):
    """torch.save ``obj`` to ``path`` without ever leaving a half-written file.

    The data is written to ``<path>.tmp`` and then moved over ``path`` with
    os.replace. Interrupting the kernel mid-save therefore leaves the previous
    file intact, so training can still be resumed from it.
    """
    tmp_path = f"{path}.tmp"
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


if start_epoch >= num_epochs:
    print(f"start_epoch({start_epoch}) >= num_epochs({num_epochs}) Ïù¥ÎØÄÎ°ú "
          f"Ï∂îÍ∞Ä ÌïôÏäµ ÏóÜÏù¥ Ï¢ÖÎ£åÌï©ÎãàÎã§.")
else:
    print(f"ÌïôÏäµ ÏãúÏûë: epoch {start_epoch+1} ~ {num_epochs}")
    for epoch in range(start_epoch, num_epochs):
        # perf_counter is monotonic, so elapsed time is immune to clock changes.
        start_time = time.perf_counter()

        train_loss, train_acc = train_one_epoch(
            resnet50, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc = validate(
            resnet50, val_loader, criterion, device
        )

        elapsed = time.perf_counter() - start_time

        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} | "
              f"Time: {elapsed:.1f}s")

        # Save the weights whenever validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            _atomic_torch_save(resnet50.state_dict(), best_model_path)
            print(f">> Best model updated! (val_acc={best_val_acc:.4f}) "
                  f"saved to {best_model_path}")

        # Resumable checkpoint after every epoch. "epoch" is the 0-based index
        # of the epoch just finished; the resume code starts at epoch + 1.
        checkpoint = {
            "epoch": epoch,
            "model_state": resnet50.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "best_val_acc": best_val_acc,
        }
        _atomic_torch_save(checkpoint, checkpoint_path)

    print("Training finished.")
    print(f"Best Val Acc: {best_val_acc:.4f}")


Using device: cuda
Classes: ['abnormal', 'normal']
🆕 체크포인트가 없습니다. 처음부터 학습을 시작합니다.
🚀 학습 시작: epoch 1 ~ 10
Epoch [1/10] Train Loss: 0.0300 Acc: 0.9904 | Val Loss: 0.0105 Acc: 0.9976 | Time: 340.3s
>> Best model updated! (val_acc=0.9976) saved to best_resnet50_mealworm.pth
Epoch [2/10] Train Loss: 0.0091 Acc: 0.9972 | Val Loss: 0.0039 Acc: 0.9990 | Time: 354.3s
>> Best model updated! (val_acc=0.9990) saved to best_resnet50_mealworm.pth
Epoch [3/10] Train Loss: 0.0069 Acc: 0.9978 | Val Loss: 0.0073 Acc: 0.9978 | Time: 382.2s
Epoch [4/10] Train Loss: 0.0064 Acc: 0.9979 | Val Loss: 0.0084 Acc: 0.9970 | Time: 380.4s
Epoch [5/10] Train Loss: 0.0049 Acc: 0.9984 | Val Loss: 0.0045 Acc: 0.9986 | Time: 381.3s
Epoch [6/10] Train Loss: 0.0038 Acc: 0.9987 | Val Loss: 0.0055 Acc: 0.9984 | Time: 380.4s
Epoch [7/10] Train Loss: 0.0037 Acc: 0.9987 | Val Loss: 0.0020 Acc: 0.9996 | Time: 360.4s
>> Best model updated! (val_acc=0.9996) saved to best_re