# In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms, datasets, models
from sklearn.model_selection import KFold
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Data preprocessing
# Per-channel normalization statistics shared by both pipelines (these are the
# ImageNet statistics that torchvision's pretrained weights are documented to
# expect).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Fixed (height, width) fed to the network. An explicit pair forces every image
# to the same shape; a bare int would only rescale the shorter edge and leave
# non-square images with differing sizes that cannot be stacked into a batch.
_INPUT_SIZE = (256, 256)

data_transforms = {
    # Training pipeline: resize, random flips for augmentation, then normalize.
    'train': transforms.Compose([
        transforms.Resize(_INPUT_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]),
    # Validation pipeline: deterministic, same spatial size as training so the
    # model is evaluated at the resolution it was trained on.
    'val': transforms.Compose([
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]),
}

# Dataset location (one sub-directory per class, as ImageFolder expects).
data_dir = '/content/drive/MyDrive/TOOTH_DAMAGE_CALSSIFICATION'
# The whole dataset is loaded once with the training transform pipeline.
image_datasets = datasets.ImageFolder(
    root=data_dir,
    transform=data_transforms['train'],
)

# Model definition
def get_model(model_name):
    """Build a pretrained backbone whose final layer matches the dataset classes.

    Args:
        model_name: One of ``'resnet'`` (ResNet-18), ``'efficientnet'``
            (EfficientNet-B3) or ``'densenet'`` (DenseNet-161).

    Returns:
        The torchvision model, loaded with ``pretrained=True``, with its final
        linear layer replaced by a new ``nn.Linear`` that outputs
        ``len(image_datasets.classes)`` logits.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Reject unknown names up front so the branches below only handle valid ones.
    if model_name not in ('resnet', 'efficientnet', 'densenet'):
        raise ValueError('Invalid model name')

    n_classes = len(image_datasets.classes)

    if model_name == 'resnet':
        net = models.resnet18(pretrained=True)
        net.fc = nn.Linear(net.fc.in_features, n_classes)
        return net

    if model_name == 'efficientnet':
        net = models.efficientnet_b3(pretrained=True)
        # classifier[1] is the linear head being replaced.
        in_dim = net.classifier[1].in_features
        net.classifier[1] = nn.Linear(in_dim, n_classes)
        return net

    # Only 'densenet' remains.
    net = models.densenet161(pretrained=True)
    net.classifier = nn.Linear(net.classifier.in_features, n_classes)
    return net

# Training and validation
def train_model(model, criterion, optimizer, dataloaders, device, num_epochs=25):
    """Train ``model`` and return it restored to its best validation epoch.

    Each epoch runs a ``'train'`` phase (with gradient updates) followed by a
    ``'val'`` phase (no gradients). The weights from the epoch with the highest
    validation accuracy are snapshotted and loaded back into the model before
    returning.

    Args:
        model: The ``nn.Module`` to train; it must already be on ``device``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        dataloaders: Mapping with ``'train'`` and ``'val'`` keys, each a
            DataLoader yielding ``(inputs, labels)`` batches.
        device: Device that every batch is moved to.
        num_epochs: Number of passes over the training data.

    Returns:
        Tuple ``(model, train_losses, val_losses, train_accuracies,
        val_accuracies)``; each list holds one float per epoch.
    """
    import copy  # function-scope import keeps this block self-contained

    # state_dict() returns references to the live parameter tensors, so the
    # snapshot must be a deep copy; otherwise later optimizer steps overwrite
    # the "best" weights and the final load_state_dict() becomes a no-op.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(num_epochs):
        epoch_train_loss = 0.0
        epoch_val_loss = 0.0
        epoch_train_acc = 0.0
        epoch_val_acc = 0.0

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Weight the batch loss by batch size so the epoch figure is a
                # per-sample average (assumes the criterion returns a batch mean).
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            dataset_size = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / dataset_size
            epoch_acc = running_corrects / dataset_size

            if is_train:
                train_losses.append(epoch_loss)
                train_accuracies.append(epoch_acc)
                epoch_train_loss = epoch_loss
                epoch_train_acc = epoch_acc
            else:
                val_losses.append(epoch_loss)
                val_accuracies.append(epoch_acc)
                epoch_val_loss = epoch_loss
                epoch_val_acc = epoch_acc

                # Keep an independent copy of the best-performing weights.
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {epoch_train_loss:.4f} Acc: {epoch_train_acc:.4f}')
        print(f'Val Loss: {epoch_val_loss:.4f} Acc: {epoch_val_acc:.4f}')

    model.load_state_dict(best_model_wts)

    return model, train_losses, val_losses, train_accuracies, val_accuracies

# Cross-validation and evaluation
def _plot_fold_history(model_name, fold, train_losses, val_losses,
                       train_accuracies, val_accuracies):
    """Plot the per-epoch loss and accuracy curves of one fold side by side."""
    plt.figure(figsize=(10, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title(f'{model_name} - Fold {fold+1} Loss')

    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Train Accuracy')
    plt.plot(val_accuracies, label='Val Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title(f'{model_name} - Fold {fold+1} Accuracy')

    plt.show()


def _show_misclassified(error_samples, class_names, max_images=None):
    """Display misclassified images, one per row, titled with predicted/true class."""
    if max_images is not None:
        error_samples = error_samples[:max_images]

    # plt.subplots() rejects zero rows, so there is nothing to draw here.
    if not error_samples:
        print('No misclassified samples to display.')
        return

    n_rows = len(error_samples)
    # squeeze=False always yields a 2-D axes array, so a single misclassified
    # sample is indexed the same way as many.
    fig, axes = plt.subplots(n_rows, 1, figsize=(12, n_rows * 4), squeeze=False)
    fig.suptitle('Sample Misclassified Images')
    for ax, (img, pred, label) in zip(axes[:, 0], error_samples):
        # CHW tensor -> HWC array, then undo the Normalize() applied on load.
        img = img.cpu().numpy().transpose((1, 2, 0))
        img = np.clip(img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406], 0, 1)
        ax.imshow(img)
        ax.set_title(f'Pred: {class_names[pred]} / True: {class_names[label]}')
        ax.axis('off')
    plt.show()


def cross_validate(model_name, num_epochs=25, batch_size=32, max_error_images=None):
    """Run 5-fold cross-validation for one architecture and report the results.

    For every fold a fresh model is built with ``get_model`` and trained with
    ``train_model``, its learning curves are plotted, and its predictions on
    the held-out fold are recorded. After all folds, weighted
    precision/recall/F1 are printed and a confusion matrix and the
    misclassified images are shown.

    Args:
        model_name: Architecture name passed to ``get_model``.
        num_epochs: Training epochs per fold.
        batch_size: Batch size for both the train and validation loaders.
        max_error_images: Optional non-negative cap on how many misclassified
            images are displayed; ``None`` (the default) shows all of them.

    Returns:
        None. Results are printed and plotted.
    """
    # NOTE(review): both subsets index the module-level ``image_datasets``,
    # which is built with the *training* transforms, so validation images are
    # randomly flipped as well and evaluation is not fully deterministic.
    # Confirm whether a separate dataset with the 'val' transforms is intended.
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    all_labels = np.array(image_datasets.targets)
    all_preds = np.zeros_like(all_labels)
    error_samples = []

    # KFold only needs the sample count; an explicit index array avoids relying
    # on how scikit-learn treats a Dataset object.
    sample_indices = np.arange(len(image_datasets))

    for fold, (train_idx, val_idx) in enumerate(kf.split(sample_indices)):
        print(f'Fold {fold + 1}')

        train_subset = Subset(image_datasets, train_idx)
        val_subset = Subset(image_datasets, val_idx)

        dataloaders = {
            'train': DataLoader(train_subset, batch_size=batch_size, shuffle=True),
            # shuffle=False keeps predictions aligned with the order of val_idx.
            'val': DataLoader(val_subset, batch_size=batch_size, shuffle=False)
        }

        model = get_model(model_name)
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

        model, train_losses, val_losses, train_accuracies, val_accuracies = train_model(
            model, criterion, optimizer, dataloaders, device, num_epochs)

        val_preds = []

        # Predict on the held-out fold
        model.eval()
        with torch.no_grad():
            for inputs, labels in dataloaders['val']:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                val_preds.extend(preds.cpu().numpy())

                # Collect misclassified samples; move them to the CPU and store
                # plain ints so they neither pin device memory across folds nor
                # need tensor indexing later.
                wrong = (preds != labels).nonzero(as_tuple=True)[0].tolist()
                error_samples.extend(
                    (inputs[j].cpu(), int(preds[j]), int(labels[j])) for j in wrong
                )

        all_preds[val_idx] = val_preds

        _plot_fold_history(model_name, fold, train_losses, val_losses,
                           train_accuracies, val_accuracies)

    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted')
    cm = confusion_matrix(all_labels, all_preds)

    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1 Score: {f1:.4f}')

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=image_datasets.classes, yticklabels=image_datasets.classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix - {model_name}')
    plt.show()

    # Show the misclassified images
    _show_misclassified(error_samples, image_datasets.classes, max_error_images)

# Run: cross-validate each architecture listed here.
for model_name in ('resnet',):
    print('\nTraining {} model:'.format(model_name))
    cross_validate(model_name, batch_size=8, num_epochs=35)


# Output hidden; open in https://colab.research.google.com to view.