In [1]:
import torch, torchvision
from torch import nn
# Select the compute device in priority order: Apple-Silicon MPS, then CUDA,
# and finally the CPU when no accelerator backend is available.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    # The fallback used to be "cpur", which is not a valid device type, so this
    # cell crashed on any machine without MPS or CUDA.
    device = torch.device("cpu")
print("Device: ", device)

Device:  mps


In [2]:
# Fix PyTorch's global RNG seed so weight init and shuffling are repeatable.
torch.manual_seed(seed=42)

<torch._C.Generator at 0x10c9a2090>

In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Preprocessing applied to every image: convert to a tensor, then normalise
# with the single-channel mean / std given below.
# The pipeline gets its own name so the imported `transforms` module is not
# shadowed; rebinding `transforms` made this cell fail when executed a second
# time, because the name then referred to the Compose object.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    # Both statistics are 1-tuples (one entry per channel); `(0.3081)` without
    # the trailing comma was just a parenthesised float.
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Directory the MNIST files are downloaded to / read from.
DATA_DIR = 'data'
train_ds = datasets.MNIST(root=DATA_DIR, train=True, download=True, transform=mnist_transform)
test_ds = datasets.MNIST(root=DATA_DIR, train=False, download=True, transform=mnist_transform)

In [4]:
# Options shared by both loaders: two worker processes, and pinned host memory
# only when the selected device is CUDA.
loader_opts = dict(
    num_workers=2,
    pin_memory=(device.type=='cuda'),
)

# Training batches are small and reshuffled every epoch.
train_loader = DataLoader(
    train_ds,
    batch_size=64,
    shuffle=True,
    **loader_opts,
)

# Evaluation uses large batches in a fixed order.
test_loader = DataLoader(
    test_ds,
    batch_size=1024,
    shuffle=False,
    **loader_opts,
)

In [5]:
class MNIST_CNN(nn.Module):
    """Small CNN classifier for single-channel 28x28 images with 10 classes.

    Architecture: two 3x3 conv layers (1 -> 32 -> 64 channels), each followed
    by ReLU and 2x2 max-pooling, then a 128-unit hidden linear layer with
    dropout and a final 10-way linear output layer.
    """

    def __init__(self):
        """Create the convolutional, pooling, dropout and linear layers."""
        super().__init__()
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(1, 32, 3, padding = 1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        # Each pooling step halves height and width (28 -> 14 -> 7).
        self.pool = nn.MaxPool2d(2,2)
        # 64 channels * 7 * 7 spatial positions after the second pooling step.
        self.fc1 = nn.Linear(64*7*7, 128)
        self.drop = nn.Dropout(0.3)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: tensor of shape (N, 1, 28, 28).

        Returns:
            Tensor of shape (N, 10) holding unnormalised class scores (logits);
            no softmax is applied here.
        """
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # (N, 64*7*7)
        # The hidden layer was previously skipped, so the 3136-feature tensor
        # reached fc2 (which expects 128 features) and raised a shape error.
        x = nn.functional.relu(self.fc1(x))
        x = self.drop(x)
        x = self.fc2(x)
        return x

In [6]:
# Build the network, then move its parameters and buffers to the chosen device.
model = MNIST_CNN()
model = model.to(device)

In [7]:
# Loss: cross-entropy computed directly on the model's raw class scores.
criterion = nn.CrossEntropyLoss()

# Optimiser: AdamW over all model parameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)