In [31]:
import torch.nn as nn
import torch
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

In [32]:
# --- Configuration -----------------------------------------------------------
input_size = 28    # image side length in pixels (matches the 28x28 reshape used when plotting)
num_classes = 10   # number of target classes
num_epochs = 100   # full passes over the training set
batch_size = 64    # samples per mini-batch for both data loaders

# MNIST train/test splits; downloaded into ./data on first use and converted
# to tensors by ToTensor().
train_dataset = datasets.MNIST(root="./data", train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transforms.ToTensor(), download=True)

# Pick the best available accelerator instead of hard-coding "mps": a fixed
# "mps" device makes every later `.to(device)` call fail on machines without
# Apple's Metal backend. MPS stays first so behavior on Apple silicon is unchanged.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [33]:
# Training batches are reshuffled on every pass over the loader.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation always consumes the whole test set, so ordering cannot change the
# measured accuracy; shuffling here only costs a permutation per pass and
# consumes the global RNG, so it is disabled.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [34]:
class CNN(nn.Module):
    """Two-stage convolutional classifier for single-channel 28x28 images.

    Each stage is Conv(5x5, stride 1, padding 2) -> ReLU -> MaxPool(2). The
    convolutions preserve spatial size and each pool halves it, so a 28x28
    input becomes 14x14 and then 7x7, which is what the ``32 * 7 * 7`` input
    width of the final linear layer relies on.

    Args:
        num_classes: Number of output logits. Defaults to 10, the value that
            was previously hard-coded.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Stage 1: 1 -> 16 channels, spatial size 28 -> 14.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Stage 2: 16 -> 32 channels, spatial size 14 -> 7.
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Linear head producing one raw logit per class (no softmax applied).
        self.out = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for ``x`` of shape ``(batch, 1, 28, 28)``."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Collapse (channels, height, width) into one feature axis per sample.
        # flatten() also copes with non-contiguous inputs, where view() raises.
        x = x.flatten(start_dim=1)
        return self.out(x)

In [35]:
def accuracy(predictions, labels):
    """Count how many rows of ``predictions`` have their largest score at the labelled class.

    Args:
        predictions: 2-D tensor of per-class scores, one row per sample.
        labels: Tensor of integer class indices, one per sample.

    Returns:
        A ``(rights, total)`` tuple: ``rights`` is a 0-dim integer tensor with
        the number of correct predictions (on the same device as the inputs)
        and ``total`` is ``len(labels)`` as a Python int.
    """
    # argmax yields integer indices that carry no autograd history, so the
    # legacy `.data` attribute is not needed to detach anything here.
    pred = predictions.argmax(dim=1)
    rights = pred.eq(labels.view_as(pred)).sum()
    return rights, len(labels)

In [36]:
# Build the model on the selected device and set up the loss and optimizer.
net = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

for epoch in range(num_epochs):
    # (correct, seen) pairs for every training batch processed so far this epoch.
    train_rights = []
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        # Re-enter training mode every step: the periodic evaluation below
        # switches the network to eval mode.
        net.train()
        output = net(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_rights.append(accuracy(output, target))

        # Every 100 batches, score the whole test set and log progress.
        if batch_idx % 100 == 0:
            net.eval()
            val_rights = []

            # no_grad() stops autograd from recording the evaluation forward
            # passes, which would otherwise keep a graph alive per test batch.
            # Separate val_* names keep the training batch variables intact.
            with torch.no_grad():
                for val_data, val_target in test_loader:
                    val_data, val_target = val_data.to(device), val_target.to(device)
                    val_output = net(val_data)
                    val_rights.append(accuracy(val_output, val_target))

            # Running training accuracy for this epoch so far, and test accuracy.
            train_correct = sum(rights for rights, _ in train_rights)
            train_seen = sum(count for _, count in train_rights)
            val_correct = sum(rights for rights, _ in val_rights)
            val_seen = sum(count for _, count in val_rights)

            print("current epoch:{} [{}/{} ({:.0f}%)]\tloss:{:.6f}\taccuracy in train:{:.2f}%\taccuracy in test:{:.2f}%".format(
                epoch,
                batch_idx * batch_size,
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item(),
                100. * train_correct.item() / train_seen,
                100. * val_correct.item() / val_seen))



























In [None]:
# Preview the first few training digits, one subplot per sample.
# Looping over the entire dataset and calling plt.imshow on a single axes
# stacks every image on top of the previous one, so only the last would be
# visible; a small, bounded grid shows distinct samples instead.
num_preview = 8
fig, axes = plt.subplots(1, num_preview, figsize=(2 * num_preview, 2))
for ax, sample_idx in zip(axes, range(num_preview)):
    data, label = train_dataset[sample_idx]
    ax.imshow(data.numpy().reshape(28, 28), cmap="Greys")
    ax.set_title(str(label))
    ax.axis("off")

In [None]:
# Display any matplotlib figures that are currently open.
plt.show()