In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm
import torch.optim as optim
import time

BATCH_SIZE = 128
NUM_EPOCHS = 10

In [2]:
# preprocessing
normalize = transforms.Normalize(mean=[.5], std=[.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# download and load the data
train_dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./mnist/', train=False, transform=transform, download=False)

# encapsulate them into dataloader form
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

# using CPU
DEVICE = torch.device("cpu")

In [3]:
class SimpleNet(nn.Module):
# TODO:define model
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=(6, 6))  # input:(1,28,28) output:(10,24,24)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=(2, 2))  # input:(10,12,12) output:(20,10,10)
        self.fc1 = nn.Linear(20*10*10, 500)
        self.fc2 = nn.Linear(500, 10)  # AX+B

    def forward(self, x):
        in_size = x.size(0)
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, (2, 2))
        out = self.conv2(out)
        out = F.relu(out)
        out = out.view(in_size, -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.log_softmax(out, dim=1)
        return out
    
    
model = SimpleNet().to(DEVICE)   # use CPU

# TODO:define loss function and optimiter
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [7]:
# train and evaluate
for epoch in range(NUM_EPOCHS):
    model.train()
    correct = 0
    for images, labels in tqdm(train_loader):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标
        correct += pred.eq(labels.view_as(pred)).sum().item()
        
print('\nTrain Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(train_loader.dataset),
        100.*correct / len(train_loader.dataset)))

for epoch in range(NUM_EPOCHS):
    #model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            output = model(images)
            pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标
            correct += pred.eq(labels.view_as(pred)).sum().item()  
            
print('\nTest Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(test_loader.dataset),
        100.*correct / len(test_loader.dataset)))       
        
        
        
    # evaluate
    # TODO:calculate the accuracy using traning and testing dataset

100%|██████████| 468/468 [00:15<00:00, 30.45it/s]
100%|██████████| 468/468 [00:17<00:00, 26.50it/s]
100%|██████████| 468/468 [00:18<00:00, 25.73it/s]
100%|██████████| 468/468 [00:16<00:00, 28.58it/s]
100%|██████████| 468/468 [00:18<00:00, 25.33it/s]
100%|██████████| 468/468 [00:18<00:00, 24.89it/s]
100%|██████████| 468/468 [00:18<00:00, 25.15it/s]
100%|██████████| 468/468 [00:18<00:00, 25.38it/s]
100%|██████████| 468/468 [00:18<00:00, 24.89it/s]
100%|██████████| 468/468 [00:18<00:00, 24.97it/s]
  5%|▌         | 4/78 [00:00<00:02, 33.21it/s]


Train Accuracy: 59421/60000 (99%)



100%|██████████| 78/78 [00:01<00:00, 48.74it/s]
100%|██████████| 78/78 [00:01<00:00, 51.91it/s]
100%|██████████| 78/78 [00:01<00:00, 51.80it/s]
100%|██████████| 78/78 [00:01<00:00, 51.03it/s]
100%|██████████| 78/78 [00:01<00:00, 51.47it/s]
100%|██████████| 78/78 [00:01<00:00, 51.51it/s]
100%|██████████| 78/78 [00:01<00:00, 51.55it/s]
100%|██████████| 78/78 [00:01<00:00, 51.46it/s]
100%|██████████| 78/78 [00:01<00:00, 51.33it/s]
100%|██████████| 78/78 [00:01<00:00, 51.88it/s]


Test Accuracy: 9848/10000 (98%)




