In [5]:
import medmnist
from medmnist import INFO, Evaluator
from medmnist.dataset import TissueMNIST
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize
import matplotlib.pyplot as plt

def load_data():
    """Build the TissueMNIST training and validation data loaders.

    Both splits are downloaded on demand and share one preprocessing pipeline:
    ``ToTensor`` followed by ``Normalize(mean=[0.5], std=[0.5])``.

    Returns:
        tuple: ``(train_loader, val_loader, n_classes)``. The training loader
        shuffles, the validation loader keeps a fixed order (both use batches
        of 64), and ``n_classes`` is the number of entries in the training
        dataset's ``info['label']`` mapping.
    """
    transform = Compose([ToTensor(), Normalize(mean=[0.5], std=[0.5])])

    def _make_split(split_name):
        # Every split gets identical preprocessing and is fetched if absent.
        return TissueMNIST(split=split_name, transform=transform, download=True)

    train_set = _make_split('train')
    val_set = _make_split('val')

    loaders = (
        DataLoader(dataset=train_set, batch_size=64, shuffle=True),
        DataLoader(dataset=val_set, batch_size=64, shuffle=False),
    )
    num_labels = len(train_set.info['label'])

    return loaders[0], loaders[1], num_labels

class EnhancedCNN(nn.Module):
    """Three conv stages (conv -> ReLU -> 2x2 max-pool) followed by a 2-layer MLP head.

    Args:
        n_channels: Number of channels in the input images.
        n_classes: Number of output logits.
        input_size: Height/width of the square input images. The default of 28
            reproduces the original ``128 * 3 * 3`` flatten size; other square
            sizes (e.g. 64) are supported because ``fc1`` is sized from it.

    Raises:
        ValueError: If ``input_size`` is too small to survive three 2x2 pools.
    """

    def __init__(self, n_channels, n_classes, input_size=28):
        super().__init__()
        self.conv1 = nn.Conv2d(n_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # kernel_size=3 with padding=1 keeps the spatial size through each conv;
        # each of the three 2x2 max-pools floors it by half (28 -> 14 -> 7 -> 3).
        spatial = input_size // 2 // 2 // 2
        if spatial < 1:
            raise ValueError(
                f"input_size must be at least 8 for three 2x2 pools, got {input_size}"
            )
        self.fc1 = nn.Linear(128 * spatial * spatial, 256)
        self.fc2 = nn.Linear(256, n_classes)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map a batch of images to unnormalized class logits."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = torch.flatten(x, 1)
        # Dropout is only active in train() mode.
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.fc2(x)
        return x


def train_model(model, train_loader, val_loader, n_classes, device, save_path,
                num_epochs=20, lr=0.001):
    """Train ``model`` with Adam + cross-entropy, checkpointing the best epoch.

    After every epoch the model is scored on ``val_loader`` via ``evaluate``.
    Whenever validation accuracy beats the best seen so far, the weights are
    written to ``best_model.pth`` inside ``save_path``.

    Args:
        model: Network to train in place; expected to already be on ``device``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        n_classes: Number of classes. Not used by this function; kept so
            existing callers keep working.
        device: Device each batch is moved to.
        save_path: Directory for the checkpoint; created if it does not exist.
        num_epochs: Number of passes over ``train_loader`` (default 20).
        lr: Adam learning rate (default 0.001).

    Returns:
        float or None: Best validation accuracy in percent, or ``None`` if
        ``num_epochs`` is 0.

    Raises:
        ValueError: If ``train_loader`` yields no samples in an epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    checkpoint_path = os.path.join(save_path, "best_model.pth")
    # Start with "no best yet" so the first epoch always writes a checkpoint,
    # even if its validation accuracy is exactly 0%.
    best_accuracy = None

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        n_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            # Drop the size-1 label dimension (labels presumably arrive as
            # (batch, 1)) and cast to int64 class indices for CrossEntropyLoss.
            labels = labels.squeeze(1).long()
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        accuracy = 100 * correct / total
        print(f'Epoch {epoch+1} - Loss: {running_loss/n_batches:.4f}, '
              f'Accuracy: {accuracy:.2f}%')

        val_accuracy = evaluate(model, val_loader, device)
        print(f'Validation Accuracy: {val_accuracy:.2f}%')

        if best_accuracy is None or val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), checkpoint_path)
            print("Saved improved model")

    return best_accuracy

def evaluate(model, val_loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over ``val_loader``.

    The model is switched to eval mode and run without gradient tracking.
    Each batch is ``(images, labels)``. Labels are squeezed along dim 1 and
    cast to int64 before being compared with the arg-max prediction.

    Args:
        model: Classifier returning per-class scores along dim 1.
        val_loader: Iterable of ``(images, labels)`` batches.
        device: Device each batch is moved to.

    Returns:
        float: ``100 * correct / total``.

    Raises:
        ValueError: If ``val_loader`` yields no samples (accuracy is undefined).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            labels = labels.squeeze(1).long()
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("val_loader yielded no samples; accuracy is undefined")
    return 100 * correct / total

def main(seed=0):
    """Train EnhancedCNN on TissueMNIST, checkpointing the best validation epoch.

    Args:
        seed: Passed to ``torch.manual_seed`` before any data loading or model
            construction, so weight init, shuffling and dropout draw from a
            seeded generator. Some CUDA kernels may still be nondeterministic.
    """
    torch.manual_seed(seed)
    train_loader, val_loader, n_classes = load_data()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = EnhancedCNN(n_channels=1, n_classes=n_classes).to(device)

    save_path = "."
    train_model(model, train_loader, val_loader, n_classes, device, save_path)


if __name__ == '__main__':
    main()


Using downloaded and verified file: C:\Users\metho\.medmnist\tissuemnist.npz
Using downloaded and verified file: C:\Users\metho\.medmnist\tissuemnist.npz
Epoch 1 - Loss: 1.3321, Accuracy: 50.16%
Validation Accuracy: 56.57%
Saved improved model
Epoch 2 - Loss: 1.1578, Accuracy: 57.17%
Validation Accuracy: 58.65%
Saved improved model
Epoch 3 - Loss: 1.1001, Accuracy: 59.55%
Validation Accuracy: 60.77%
Saved improved model
Epoch 4 - Loss: 1.0637, Accuracy: 60.88%
Validation Accuracy: 62.34%
Saved improved model
Epoch 5 - Loss: 1.0346, Accuracy: 62.06%
Validation Accuracy: 62.81%
Saved improved model
Epoch 6 - Loss: 1.0155, Accuracy: 62.67%
Validation Accuracy: 62.38%
Epoch 7 - Loss: 0.9970, Accuracy: 63.40%
Validation Accuracy: 62.12%
Epoch 8 - Loss: 0.9828, Accuracy: 63.88%
Validation Accuracy: 63.37%
Saved improved model
Epoch 9 - Loss: 0.9687, Accuracy: 64.38%
Validation Accuracy: 63.06%
Epoch 10 - Loss: 0.9591, Accuracy: 64.85%
Validation Accuracy: 62.23%
Epoch 11 - Loss: 0.9440, Accu

In [6]:
def load_test_data(batch_size=64):
    """Load the TissueMNIST held-out test split as a non-shuffled DataLoader.

    Uses the same preprocessing as ``load_data`` (``ToTensor`` then
    ``Normalize(mean=[0.5], std=[0.5])``) so test inputs match training inputs.

    Args:
        batch_size: Samples per batch (default 64).

    Returns:
        DataLoader: Batches of the ``'test'`` split in fixed order.
    """
    data_transform = Compose([
        ToTensor(),
        Normalize(mean=[0.5], std=[0.5])
    ])

    # Use the dedicated 'test' split, not the validation split that was used
    # to pick the best checkpoint during training.
    test_dataset = TissueMNIST(
        split='test',
        transform=data_transform,
        download=True
    )

    return DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

def test_model(model, test_loader, device):
    """Print and return the model's top-1 accuracy on the test set.

    Delegates to ``evaluate`` (defined alongside the training code) so that
    validation and test accuracy are computed by exactly the same logic.

    Args:
        model: Trained classifier.
        test_loader: Iterable of ``(images, labels)`` test batches.
        device: Device each batch is moved to.

    Returns:
        float: Test accuracy in percent.
    """
    accuracy = evaluate(model, test_loader, device)
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

def main(model_path='./best_model.pth'):
    """Load the best training checkpoint and report its TissueMNIST test accuracy.

    Args:
        model_path: Path to the ``state_dict`` file saved by ``train_model``
            (default ``'./best_model.pth'``).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    test_loader = load_test_data()
    # Read the class count from the dataset metadata (as load_data does)
    # instead of hard-coding it, so the head always matches the checkpoint.
    n_classes = len(test_loader.dataset.info['label'])
    model = EnhancedCNN(n_channels=1, n_classes=n_classes).to(device)

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only box.
    # NOTE(review): consider weights_only=True (torch >= 1.13) if checkpoints
    # may ever come from an untrusted source.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    test_model(model, test_loader, device)


if __name__ == '__main__':
    main()

Using downloaded and verified file: C:\Users\metho\.medmnist\tissuemnist.npz
Test Accuracy: 63.48%
