In [17]:
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import torch.nn as nn
import torch

In [18]:
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=ToTensor())
test_data = datasets.MNIST(root='./data', train=False, download=False, transform=ToTensor())

In [28]:
LR=1e-3
BATCH_SIZE = [64,128,256,512,1024]
epochs = 40

In [29]:
model = nn.Sequential(
    nn.Linear(28*28, 64),
    nn.Sigmoid(),
    nn.Linear(64, 10)
)
device = "cuda" if not torch.cuda.is_available() else "cpu"
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

In [30]:
for batch_size in BATCH_SIZE:
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    for epoch in range(epochs):
        for data,target in train_dataloader:
            data, target = data.to(device), target.to(device)
            predict = model(data.reshape(data.shape[0], -1))
            loss = loss_fn(predict, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    print(f'batch_size:{batch_size} Epoch:{epoch} Loss: {loss.item()}')

batch_size:64 Epoch:39 Loss: 0.7040919065475464
batch_size:128 Epoch:39 Loss: 0.5000953078269958
batch_size:256 Epoch:39 Loss: 0.586323082447052
batch_size:512 Epoch:39 Loss: 0.6334527134895325
batch_size:1024 Epoch:39 Loss: 0.4597504734992981


In [31]:
# 测试
for batch_size in BATCH_SIZE:
    correct = 0
    total = 0
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    with torch.no_grad():
        for data,target in test_dataloader:
            output = model(data.reshape(data.shape[0], -1))
            _,predict = torch.max(output, 1)
            total += target.size(0)
            correct += (predict == target).sum().item()
    print(f'batch_size : {batch_size} Accuracy: {100 * correct / total:.2f}%')

batch_size : 64 Accuracy: 88.21%
batch_size : 128 Accuracy: 88.21%
batch_size : 256 Accuracy: 88.21%
batch_size : 512 Accuracy: 88.21%
batch_size : 1024 Accuracy: 88.21%
