In [165]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,random_split
from torchvision import datasets,transforms
from torch.optim import SGD
from tqdm.notebook import tqdm

In [166]:
# Preprocessing applied to every image: tensor conversion followed by
# normalisation of the single channel.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std - confirm.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
]
transform = transforms.Compose(preprocessing_steps)

In [167]:
# download=True only fetches the files when they are missing from the given
# root, so this cell also works on a machine with no local copy of MNIST and
# is a no-op when the data is already on disk.
data_train = datasets.MNIST(
    '../data/MNIST/train',
    train=True,
    transform=transform,
    download=True,
)
data_test = datasets.MNIST(
    '../data/MNIST/test',
    train=False,
    transform=transform,
    download=True,
)

In [168]:
# This cell rebinds `data_train` to the training subset. If it is executed a
# second time, `data_train` is already a Subset, so go back to its parent
# dataset first; otherwise random_split would reject the lengths.
full_train = (
    data_train.dataset
    if isinstance(data_train, torch.utils.data.Subset)
    else data_train
)

# Hold out a fixed number of samples for validation; the rest is used for training.
val_size = 10000
data_train, data_val = random_split(
    full_train,
    [len(full_train) - val_size, val_size],
    # Fixed seed so the train/validation partition is the same on every run.
    generator=torch.Generator().manual_seed(13),
)

In [169]:
# One batch size shared by all three loaders.
batch_size = 256

# Only the training data is reshuffled every epoch. Validation and test are
# read in a fixed order: order does not affect accuracy, and a fixed order
# keeps the per-batch readouts repeatable.
train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(data_val, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(data_test, batch_size=batch_size, shuffle=False)

In [None]:
class LeNet5(nn.Module):
    """LeNet-5 style CNN: three conv layers followed by a two-layer classifier.

    Args:
        num_class: number of output logits produced by the last linear layer.

    The classifier expects 120 flattened features, which is what the conv
    stack produces for a single-channel 28x28 input ([batch, 120, 1, 1]).
    """

    def __init__(self, num_class=10):
        super().__init__()
        self.num_class = num_class

        # Convolutional feature extractor (conv -> ReLU -> avg-pool, twice,
        # then a final 5x5 conv that collapses the spatial dimensions).
        conv_layers = [
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 120, kernel_size=5),
        ]
        self.cnn_net = nn.Sequential(*conv_layers)

        # Fully connected head mapping the 120 features to class logits.
        dense_layers = [
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, self.num_class),
        ]
        self.fcn_net = nn.Sequential(*dense_layers)

    def forward(self, x):
        """Return the raw class logits for a batch of images."""
        features = self.cnn_net(x)
        # Flatten everything after the batch dimension:
        # [batch, 120, 1, 1] -> [batch, 120 * 1 * 1].
        flat = features.flatten(start_dim=1)
        return self.fcn_net(flat)

In [194]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Network, loss and optimiser; the network and loss are moved to `device`.
model = LeNet5()
model = model.to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = SGD(params=model.parameters(), lr=0.01, momentum=0.9)

In [195]:
# Number of full passes over the training loader (each followed by one validation pass).
num_epochs = 10

In [None]:
# Train for `num_epochs` epochs; after each training pass, evaluate on the
# validation loader. Running accuracy is shown on the progress bars.
for epoch in range(num_epochs):
    # Per-epoch counters for the running accuracy shown in the progress bars.
    train_correct = 0
    train_total = 0
    val_correct = 0
    val_total = 0

    # Explicitly select training mode: a later cell switches the model to
    # eval mode, and re-running this cell must not silently train in it.
    model.train()
    train_pbar = tqdm(train_loader)
    # The label only depends on the epoch, so set it once per epoch.
    train_pbar.set_description(f'Traing:[{epoch+1}]/[{num_epochs}]')
    for X,y in train_pbar:
        X = X.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        # Call the module itself (not .forward) so registered hooks run.
        y_pred = model(X)
        loss = criterion(y_pred,y)
        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest logit along dim 1.
        train_correct += (y_pred.argmax(dim=1) == y).sum().item()
        train_total += y.size(0)
        train_pbar.set_postfix_str(f'Training loss:{loss.item():.4f} acc:{train_correct/train_total:.4f}')

    # Validation pass: eval mode, no gradient tracking, no parameter updates.
    model.eval()
    with torch.no_grad():
        val_pbar = tqdm(val_loader)
        val_pbar.set_description(desc=f'Val:[{epoch+1}]/[{num_epochs}]')

        for X,y in val_pbar:
            X = X.to(device)
            y = y.to(device)
            y_pred = model(X)
            val_correct += (y_pred.argmax(dim=1) == y).sum().item()
            val_total += y.size(0)
            # The displayed loss is that of the current batch only; the
            # accuracy is cumulative over the batches seen so far.
            loss = criterion(y_pred,y)
            val_pbar.set_postfix_str(f'Val loss:{loss.item():.4f} acc:{val_correct/val_total:.4f}')

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

In [197]:
# Counters for the test-set evaluation: samples seen and samples classified correctly.
test_total, test_correct = 0, 0

In [198]:
# Reset the counters inside this cell so that re-running it reports the
# accuracy of the current model only, instead of adding to earlier counts.
test_total = 0
test_correct = 0

# Evaluation mode and no gradient tracking for the test pass.
model.eval()
with torch.no_grad():
    for X,y in test_loader:
        X = X.to(device)
        y = y.to(device)
        # Call the module itself (not .forward) so registered hooks run.
        output = model(X)
        # Predicted class = index of the largest logit along dim 1.
        test_correct += (output.argmax(dim=1) == y).sum().item()
        test_total += y.size(0)

print(f'Test acc:{test_correct/test_total}')

Test acc:0.9847
