In [1]:
import torchvision
from torchvision import transforms
import torch
from torch.utils.data import DataLoader
import torch.nn as nn

In [2]:
# Directory passed as `root=` to the torchvision MNIST dataset constructors.
IMG_PATH = "./mnist_images/"

In [3]:
# Per-sample preprocessing pipeline: the only step converts the image to a tensor.
preprocessing_steps = [transforms.ToTensor()]
transform = transforms.Compose(preprocessing_steps)

In [4]:
# Both splits share the same storage root, preprocessing and download flag;
# only the `train` switch differs between them.
mnist_kwargs = dict(root=IMG_PATH, transform=transform, download=True)
mnist_train_dataset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
mnist_test_dataset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

In [5]:
# Seed the global RNG before building the shuffling training loader.
torch.manual_seed(1)
batch_size = 32
train_dl = DataLoader(
	dataset=mnist_train_dataset,
	batch_size=batch_size,
	shuffle=True,
)

In [6]:
# Build a fully connected classifier: Flatten -> [Linear -> ReLU] per hidden width -> Linear.
hidden_units = [32,16]
num_classes = 10  # MNIST digits 0-9

# Number of input features = total element count of one sample tensor
# (784 for the model shown in the output below).
img_size = mnist_test_dataset[0][0].shape
input_size = img_size[0]*img_size[1]* img_size[2]

all_layers = [nn.Flatten()]
for hidden_unit in hidden_units:
	layer = nn.Linear(input_size, hidden_unit)
	all_layers.append(layer)
	all_layers.append(nn.ReLU())
	# The next layer consumes this layer's output width.
	input_size = hidden_unit

# Use the running `input_size` rather than `hidden_units[-1]`: it equals the last
# hidden width when hidden layers exist, and still works (a plain linear
# classifier on the flattened image) when `hidden_units` is empty.
all_layers.append(nn.Linear(input_size, num_classes))

model = nn.Sequential(*all_layers)
model


Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=32, bias=True)
  (2): ReLU()
  (3): Linear(in_features=32, out_features=16, bias=True)
  (4): ReLU()
  (5): Linear(in_features=16, out_features=10, bias=True)
)

In [7]:
# Display the model's layer structure as this cell's output.
model

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=32, bias=True)
  (2): ReLU()
  (3): Linear(in_features=32, out_features=16, bias=True)
  (4): ReLU()
  (5): Linear(in_features=16, out_features=10, bias=True)
)

In [8]:
# Train the classifier with Adam on cross-entropy loss, printing the
# training accuracy accumulated over each epoch.
num_epochs = 20
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr = 0.001)
torch.manual_seed(1)

for epoch in range(num_epochs):
	# Running count of correctly classified training samples in this epoch.
	acc_hist_train = 0
	for inputs, targets in train_dl:
		logits = model(inputs)
		loss = loss_fn(logits, targets)

		# Backpropagate, apply the update, then clear gradients for the next batch.
		loss.backward()
		optimizer.step()
		optimizer.zero_grad()

		hits = logits.argmax(dim=1).eq(targets).float()
		acc_hist_train += hits.sum()

	# Turn the correct-count into a fraction of the whole training set.
	acc_hist_train = acc_hist_train / len(train_dl.dataset)
	print(f"Epoch: {epoch} Accuracy: {acc_hist_train: .4f}")


Epoch: 0 Accuracy:  0.8762
Epoch: 1 Accuracy:  0.9379
Epoch: 2 Accuracy:  0.9493
Epoch: 3 Accuracy:  0.9564
Epoch: 4 Accuracy:  0.9603
Epoch: 5 Accuracy:  0.9640
Epoch: 6 Accuracy:  0.9673
Epoch: 7 Accuracy:  0.9699
Epoch: 8 Accuracy:  0.9717
Epoch: 9 Accuracy:  0.9736
Epoch: 10 Accuracy:  0.9752
Epoch: 11 Accuracy:  0.9764
Epoch: 12 Accuracy:  0.9777
Epoch: 13 Accuracy:  0.9779
Epoch: 14 Accuracy:  0.9797
Epoch: 15 Accuracy:  0.9811
Epoch: 16 Accuracy:  0.9814
Epoch: 17 Accuracy:  0.9826
Epoch: 18 Accuracy:  0.9825
Epoch: 19 Accuracy:  0.9841


In [9]:
# `torch` is already imported in the notebook's first cell, so it is not re-imported here.
# Report whether PyTorch can see a CUDA device. The cells shown in this notebook
# never move the model or batches to a device, so this is informational only.
print(torch.cuda.is_available())

True


In [10]:
# Evaluate on the full test split in a single forward pass.
# `.data` holds the raw uint8 images, so scale them to [0, 1] by hand to match
# what `transforms.ToTensor()` does for the training batches.
model.eval()  # inference mode for any mode-dependent layers (none in this model, but safe)
with torch.no_grad():  # no gradients needed: avoids building a graph over all test images
	pred = model(mnist_test_dataset.data/255.0)
is_correct = (torch.argmax(pred, dim = 1) == mnist_test_dataset.targets).float()
print(f"Test Accuracy: {is_correct.mean():.4f}")

Test Accuracy: 0.9651
