In [8]:
import torch
# Smoke test for the Metal (MPS) backend: allocate a one-element tensor on
# the device and print it, or report that the backend is unavailable.
if not torch.backends.mps.is_available():
    print('MPS device not found.')
else:
    mps_device = torch.device('mps')
    x = torch.ones(1, device=mps_device)
    print(x)

tensor([1.], device='mps:0')


In [9]:
# build dataset and load data
import torch, torchvision
from torchvision import transforms
from torchvision.datasets import MNIST

# Directory that already contains the MNIST files (nothing is downloaded).
image_path = './Data/'

# Only conversion to tensors is applied; no augmentation or normalisation.
transform = transforms.Compose([transforms.ToTensor()])

# Training split of MNIST; it is divided into train/validation subsets below.
mnist_dataset = MNIST(
    root=image_path,
    train=True,
    transform=transform,
    download=False,
)

from torch.utils.data import Subset
from torch.utils.data import DataLoader

def get_default_device():
    """Return the ``mps`` device when that backend is available, else ``cpu``."""
    device_name = 'mps' if torch.backends.mps.is_available() else 'cpu'
    return torch.device(device_name)

def to_device(data, device):
    """Move ``data`` onto ``device``.

    A list or tuple is processed element by element (recursively) and is
    always returned as a plain list; anything else is moved with
    ``data.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)

    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

class DeviceDataLoader(DataLoader):
    """Wrap a batch iterable so that every batch it yields is moved to ``device``.

    ``DataLoader.__init__`` is never called, so the wrapper holds no loader
    state of its own: iteration and ``len()`` are forwarded to the wrapped
    object, and any attribute that is not found on the wrapper itself
    (``dataset``, ``batch_size``, ...) is looked up on the wrapped object.

    Args:
        dl: iterable of batches that also supports ``len()``.
        device: target handed to ``to_device`` for each batch.
    """

    def __init__(self, dl, device):
        # Intentionally no super().__init__(): the wrapped loader already
        # owns the dataset, sampler and worker configuration.
        self.dl = dl
        self.device = device

    def __getattr__(self, name):
        # Invoked only when normal lookup fails. Reading ``dl`` straight from
        # the instance dict avoids infinite recursion when the wrapper has not
        # been initialised yet (e.g. while it is being copied or unpickled).
        try:
            wrapped = self.__dict__['dl']
        except KeyError:
            raise AttributeError(name) from None
        return getattr(wrapped, name)

    def __iter__(self):
        for b in self.dl:
            yield to_device(b, self.device)

    def __len__(self):
        return len(self.dl)

# The first 10,000 samples of the training split form the validation set;
# the remaining samples are used for training.
mnist_valid_dataset = Subset(mnist_dataset, torch.arange(10000))
mnist_train_dataset = Subset(
    mnist_dataset,
    torch.arange(10000, len(mnist_dataset)),
)
mnist_test_dataset = MNIST(
    root=image_path,
    train=False,
    transform=transform,
    download=False,
)

batch_size = 64
# Fix the global RNG seed before the loaders are created.
torch.manual_seed(1)
train_dl = DataLoader(mnist_train_dataset, batch_size=batch_size, shuffle=True)
valid_dl = DataLoader(mnist_valid_dataset, batch_size=batch_size, shuffle=False)

# Construct the model
import torch.nn as nn
# Two conv/ReLU/max-pool stages followed by a two-layer classifier head.
# With kernel_size=5 and padding=2 each convolution keeps the spatial size,
# and each 2x2 max-pool halves it, so a 1x28x28 input becomes 64x7x7.
model = nn.Sequential()
model.add_module(
    'conv1',
    nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
)
model.add_module('relu1', nn.ReLU())
model.add_module('pool1', nn.MaxPool2d(kernel_size=2))
model.add_module(
    'conv2',
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
)
model.add_module('relu2', nn.ReLU())
model.add_module('pool2', nn.MaxPool2d(kernel_size=2))
model.add_module('flatten', nn.Flatten())
# 64 channels * 7 * 7 spatial positions = 3136 flattened features.
model.add_module('fc1', nn.Linear(64 * 7 * 7, 1024))
model.add_module('relu3', nn.ReLU())
model.add_module('dropout', nn.Dropout(p=0.5))
# One output logit per digit class.
model.add_module('fc2', nn.Linear(1024, 10))

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Misspelled alias kept because train() looks the optimizer up under this name.
optiminzer = optimizer

def _run_epoch(model, data_loader, training):
    """Run one pass over ``data_loader`` and return ``(mean_loss, accuracy)``.

    Uses the module-level ``loss_fn``; when ``training`` is true it also
    back-propagates and steps the module-level ``optiminzer`` after every
    batch. The caller is responsible for ``model.train()`` / ``model.eval()``
    and for disabling gradients during evaluation.
    """
    total_loss = 0.0
    # Kept on the CPU so the returned accuracy never lives on the accelerator.
    total_correct = torch.zeros(())
    n_seen = 0
    for x_batch, y_batch in data_loader:
        pred = model(x_batch)
        loss = loss_fn(pred, y_batch)
        if training:
            loss.backward()
            optiminzer.step()
            optiminzer.zero_grad()
        n_batch = y_batch.size(0)
        # The loss is a per-batch mean; weight it by the batch size so the
        # epoch value is a per-sample mean even when the last batch is short.
        total_loss += loss.item() * n_batch
        is_correct = (torch.argmax(pred, dim=1) == y_batch).float()
        total_correct += is_correct.sum().cpu()
        n_seen += n_batch
    # Samples are counted while iterating, so any iterable of (x, y) batches
    # works; max() avoids a division by zero for an empty loader.
    n_seen = max(n_seen, 1)
    return total_loss / n_seen, total_correct / n_seen


def train(model, num_epochs, train_dl, valid_dl):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Relies on the module-level ``loss_fn`` and ``optiminzer``.

    Args:
        model: network to optimise; updated in place.
        num_epochs: number of passes over ``train_dl``.
        train_dl: iterable yielding ``(x_batch, y_batch)`` training batches.
        valid_dl: iterable yielding ``(x_batch, y_batch)`` validation batches.

    Returns:
        ``(loss_hist_train, loss_hist_valid, accuracy_hist_train,
        accuracy_hist_valid)``, each a list with one entry per epoch. Losses
        are Python floats; accuracies are 0-dim CPU tensors.
    """
    loss_hist_train = []
    accuracy_hist_train = []
    loss_hist_valid = []
    accuracy_hist_valid = []
    for epoch in range(num_epochs):
        model.train()
        train_loss, train_acc = _run_epoch(model, train_dl, training=True)
        loss_hist_train.append(train_loss)
        accuracy_hist_train.append(train_acc)

        model.eval()
        with torch.no_grad():
            valid_loss, valid_acc = _run_epoch(model, valid_dl, training=False)
        loss_hist_valid.append(valid_loss)
        accuracy_hist_valid.append(valid_acc)

        print(f'Epoch {epoch+1} accuracy: {accuracy_hist_train[epoch]:.4f} val_accuracy:'
              f'{accuracy_hist_valid[epoch]:.4f}')
    return loss_hist_train, loss_hist_valid, accuracy_hist_train, accuracy_hist_valid

In [10]:
# Training the model
torch.manual_seed(1)
num_epochs = 20
device = get_default_device()
print(f'Device: {device}')

# Wrap each loader at most once so this cell can be re-run safely: stacking a
# second DeviceDataLoader around an already wrapped loader would repeat the
# device transfer for every batch and hide the underlying loader one level
# deeper.
if not isinstance(train_dl, DeviceDataLoader):
    train_dl = DeviceDataLoader(train_dl, device)
if not isinstance(valid_dl, DeviceDataLoader):
    valid_dl = DeviceDataLoader(valid_dl, device)

# The model must live on the same device as the batches it receives.
model.to(device)

hist = train(model, num_epochs, train_dl, valid_dl)

Device: mps
Epoch 1 accuracy: 0.9486 val_accuracy:0.9549
Epoch 2 accuracy: 0.9851 val_accuracy:0.9867
Epoch 3 accuracy: 0.9894 val_accuracy:0.9839
Epoch 4 accuracy: 0.9916 val_accuracy:0.9899
Epoch 5 accuracy: 0.9932 val_accuracy:0.9894
Epoch 6 accuracy: 0.9943 val_accuracy:0.9902
Epoch 7 accuracy: 0.9950 val_accuracy:0.9900
Epoch 8 accuracy: 0.9963 val_accuracy:0.9900
Epoch 9 accuracy: 0.9970 val_accuracy:0.9909
Epoch 10 accuracy: 0.9966 val_accuracy:0.9903
Epoch 11 accuracy: 0.9969 val_accuracy:0.9908
Epoch 12 accuracy: 0.9972 val_accuracy:0.9906
Epoch 13 accuracy: 0.9975 val_accuracy:0.9909
Epoch 14 accuracy: 0.9976 val_accuracy:0.9892
Epoch 15 accuracy: 0.9976 val_accuracy:0.9910
Epoch 16 accuracy: 0.9979 val_accuracy:0.9900
Epoch 17 accuracy: 0.9982 val_accuracy:0.9901
Epoch 18 accuracy: 0.9988 val_accuracy:0.9914
Epoch 19 accuracy: 0.9977 val_accuracy:0.9889
Epoch 20 accuracy: 0.9992 val_accuracy:0.9900
