### 多分类问题

In [105]:
import torch

def default_device():
    """Return the preferred torch device for this machine.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        torch.device: ``cuda`` if a CUDA GPU is available, otherwise ``mps``
        if the Metal backend is available, otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    # torch.backends.mps is absent on older PyTorch releases; look it up
    # defensively so those installs fall back to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    return torch.device('cpu')

# Device used for the model and every batch in this notebook.
device = default_device()

In [106]:
import torchvision
import torchvision.transforms as transforms
# Shared image transform: ToTensor converts each PIL image to a float tensor
# (per torchvision docs: C x H x W, scaled to [0, 1]).
transformation = transforms.ToTensor()
# Build the training split first, then the test split, with identical
# settings apart from the `train` flag. Both are downloaded to ./data if missing.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(
        root='./data', train=is_train, transform=transformation, download=True
    )
    for is_train in (True, False)
)


In [107]:
# Mini-batch size shared by both loaders.
batch_size = 64
# Training batches are shuffled; test batches keep a fixed order so
# evaluation is reproducible.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=batch_size, shuffle=do_shuffle)
    for split, do_shuffle in ((train_dataset, True), (test_dataset, False))
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
# Preview one training batch: print its tensor shapes and plot up to 8
# images with their labels as titles.
images, labels = next(iter(train_loader))
print('Image batch dimensions:', images.shape)
print('Image label dimensions:', labels.shape)
# Guard against a batch holding fewer than 8 samples (images[i] would
# otherwise raise IndexError).
num_preview = min(8, len(images))
plt.figure(figsize=(16, 8))
for i in range(num_preview):
    plt.subplot(1, num_preview, i + 1)
    # squeeze() drops size-1 dimensions so imshow gets a 2-D array.
    plt.imshow(images[i].numpy().squeeze(), cmap='gray')
    plt.title(labels[i].item())
    plt.axis('off')
plt.show()


In [109]:
import torch.nn as nn

class Model(nn.Module):
    """Classifier made of a single linear layer that outputs raw logits.

    No softmax is applied in ``forward``; the logits are meant to be fed to a
    loss such as ``nn.CrossEntropyLoss``. With that loss, the model amounts
    to multinomial logistic regression.

    Args:
        input_size: Number of features in each (already flattened) input row.
        num_classes: Number of output scores per row.

    Note:
        To turn this into a small MLP, add a hidden layer, e.g.
        ``Linear(input_size, 128) -> ReLU -> Linear(128, num_classes)``.
    """

    def __init__(self, input_size=28 * 28, num_classes=10):
        super().__init__()
        # Attribute name kept as `layer1` so state_dict keys stay stable.
        self.layer1 = nn.Linear(input_size, num_classes)

    def forward(self, x):
        """Map ``x`` (trailing dim == input_size) to per-class logits."""
        return self.layer1(x)

In [110]:
# Each 28x28 image is flattened into a 784-feature row; 10 output classes.
input_size = 28 * 28
num_classes = 10
model = Model(input_size=input_size, num_classes=num_classes).to(device)

In [111]:
# Cross-entropy on the model's raw logits, optimized with plain SGD at a
# fixed learning rate of 0.01.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

In [112]:
def evaluate_model(model, test_loader, device=None):
    """Return the classification accuracy of ``model`` over ``test_loader``.

    Args:
        model: Classifier mapping a (batch, features) tensor to per-class
            scores of shape (batch, num_classes).
        test_loader: Iterable of ``(images, labels)`` batches. Each image
            batch is flattened to (batch, -1) before the forward pass.
        device: Device to move each batch to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        float: Fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: If ``test_loader`` yields no samples.

    The model's train/eval mode on entry is restored before returning, so
    calling this between training epochs does not leave the model in eval mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in test_loader:
                # Flatten per sample rather than relying on a global input size.
                images = images.reshape(images.size(0), -1).to(device)
                labels = labels.to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError('test_loader yielded no samples')
    return correct / total

In [None]:
# Train for a fixed number of epochs, printing test accuracy after each one.
num_epochs = 10
for epoch in range(num_epochs):
    # Explicitly enter training mode every epoch rather than relying on the
    # mode a previous evaluation left the model in (matters once layers such
    # as dropout or batch-norm are added).
    model.train()
    for images, labels in train_loader:
        # Flatten each image into a row of input_size features.
        images = images.view(-1, input_size).to(device)
        # CrossEntropyLoss expects integer class indices as targets.
        labels = labels.long().to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    accuracy = evaluate_model(model, test_loader)
    print('Accuracy: {:.2f}%'.format(accuracy*100))