In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
from torchvision import datasets, transforms
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:

tr = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_data = datasets.MNIST('./data', download=True, train=True, transform=tr)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=50, shuffle=True)
print(train_data)

test_data = datasets.MNIST('./data', train=False, transform=tr)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=50, shuffle=True)


total_batch = len(train_loader)
print(f"{total_batch=}")

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )
total_batch=1200


In [23]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 4, padding=1)
        
        self.norm1 = nn.BatchNorm2d(16)
        self.norm2 = nn.BatchNorm2d(32)
        self.norm3 = nn.BatchNorm2d(64)
        
        self.norm10 = nn.BatchNorm1d(128)
        self.norm20 = nn.BatchNorm1d(64)
        
        self.fc1 = nn.Linear(36 * 64, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        
        self.pool = nn.MaxPool2d(2,2)
        
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x):
        x = F.relu(self.norm1(self.conv1(x)))
        x = self.pool(self.norm2(F.relu(self.conv2(x))))
        x = self.dropout(self.pool(self.norm3(F.relu(self.conv3(x)))))
        x = x.view(-1, self.num_flat_features(x))
#         print(x.size())
        x = F.relu(self.norm10(self.fc1(x)))
        x = F.relu(self.dropout(self.norm20(self.fc2(x))))
        x = self.fc3(x)
        return x
    
    def num_flat_features(self, x):
        size = x.size()[1:]
        return np.prod(size)

In [24]:

net = Net().to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, threshold=0.01, patience=1, mode='min')   

print(net)

Net(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (norm1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (norm3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (norm10): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (norm20): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=2304, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout): Dropout(p=0.3, inplace=Fa

In [25]:

net.train()
for epoch in range(20):
    total_loss = 0
    for x, target in train_loader:
        x = x.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = net(x)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss
    scheduler.step(loss) 
        
    print(f"{epoch=} train_loss={100 * total_loss / total_batch}%")
    
    correct = 0
    total = 0
    with torch.no_grad():
        net.eval()
        for (x, label) in test_loader:
            x = x.to(device)
            label = label.to(device)
            output = net(x)
            _, predicted = torch.max(output.data, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    print(f"Accuracy: {100 * correct / total}%")

torch.save(net.state_dict(), "./mnist.pth")
print("train finished")

epoch=0 train_loss=16.215747833251953%
Accuracy: 98.95%
epoch=1 train_loss=3.9626197814941406%
Accuracy: 98.52%
epoch=2 train_loss=2.7247776985168457%
Accuracy: 98.94%
epoch=3 train_loss=1.9666361808776855%
Accuracy: 99.17%
epoch=4 train_loss=1.5548884868621826%
Accuracy: 98.67%
epoch=5 train_loss=1.3974908590316772%
Accuracy: 98.96%
epoch=6 train_loss=1.3099182844161987%
Accuracy: 99.36%
epoch=7 train_loss=0.9125514030456543%
Accuracy: 99.16%
epoch=8 train_loss=1.0905722379684448%
Accuracy: 99.25%
epoch=9 train_loss=0.7924409508705139%
Accuracy: 99.08%
epoch=10 train_loss=0.7038989663124084%
Accuracy: 99.37%
epoch=11 train_loss=0.17871232330799103%
Accuracy: 99.5%
epoch=12 train_loss=0.03924359381198883%
Accuracy: 99.56%
epoch=13 train_loss=0.016178198158740997%
Accuracy: 99.53%
epoch=14 train_loss=0.00799106527119875%
Accuracy: 99.55%
epoch=15 train_loss=0.004029366187751293%
Accuracy: 99.54%
epoch=16 train_loss=0.002108413726091385%
Accuracy: 99.58%
epoch=17 train_loss=0.00103383813

In [26]:
correct = 0
total = 0
net = Net().to(device)
net.load_state_dict(torch.load("./mnist.pth"))
with torch.no_grad():
    net.eval()
    for (x, label) in test_loader:
        x = x.to(device)
        label = label.to(device)
        output = net(x)
        _, predicted = torch.max(output.data, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()
print(f"Accuracy: {100 * correct / total}%")

Accuracy: 99.57%
