In [None]:
import torch
from torch import nn
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
import torchvision
from torchvision import transforms

# Tunable data settings, named once so the train and test splits cannot drift apart.
DATA_ROOT = './data'
BATCH_SIZE = 64

# Per-image preprocessing: convert to a tensor, then normalise the single channel
# with mean 0.1307 and std 0.3081 (the commonly quoted MNIST training-set statistics).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# download=True fetches MNIST into DATA_ROOT only when the files are missing, so the
# notebook also runs on a fresh checkout; an already-downloaded copy is reused as-is.
train_data = torchvision.datasets.MNIST(train=True,transform=transform,download=True, root=DATA_ROOT)
test_data = torchvision.datasets.MNIST(train=False,transform=transform,download=True, root=DATA_ROOT)

from torch.utils.data import DataLoader

# Only the training split is shuffled; the test split is iterated in dataset order.
train_loader = DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)
test_loader = DataLoader(dataset=test_data,batch_size=BATCH_SIZE,shuffle=False)


In [None]:
# Pull the first training example and display the shape of its image tensor.
x, label = train_data[0]
x.size()

In [None]:
class MNIST_CNN(nn.Module):
    """Small convolutional classifier: two conv/batch-norm/ReLU stages plus an MLP head.

    The defaults reproduce the original MNIST configuration (1 input channel,
    10 classes, dropout 0.3). Layer order and attribute names are fixed, so
    seeded initialisation and ``state_dict`` keys are the same as before.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits (one per class).
        dropout: Drop probability applied between the two linear layers.
    """
    def __init__(self, in_channels=1, num_classes=10, dropout=0.3):
        super().__init__()
        # Both 3x3 convolutions use padding=1 and stride=1, so they keep the spatial
        # size; the trailing 2x2 max-pool halves height and width.
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,out_channels=8, kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.Conv2d(in_channels=8,out_channels=16,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,stride=2)
        )
        # The adaptive pool always emits a 16x7x7 map, which is what fixes the
        # in_features of the first linear layer regardless of input resolution.
        self.clf = nn.Sequential(
            nn.AdaptiveAvgPool2d((7, 7)),
            nn.Flatten(),
            nn.Linear(in_features=16*7*7,out_features=256),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=256,out_features=num_classes)
        )
    def forward(self,x):
        """Return raw class logits of shape (batch, num_classes) for a (batch, in_channels, H, W) input."""
        x = self.conv_block(x)
        x = self.clf(x)
        return x

In [None]:
# Seed the global RNG before the model is built so its initial weights are repeatable.
torch.manual_seed(42)

model = MNIST_CNN()
# Cross-entropy on the model's raw outputs, optimised with Adam at a 0.005 step size.
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(params=model.parameters(), lr=5e-3)

In [None]:
n_epochs = 10
model.to(device)
for epoch in range(n_epochs):
    # ---- Training pass ----
    # Switching to train mode once per epoch is enough; nothing changes it mid-epoch.
    model.train()
    train_loss = 0.0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        # Accumulate a plain float: adding the tensor itself would keep autograd
        # history for every batch alive until the end of the epoch.
        train_loss += loss.item()
        optim.zero_grad()
        loss.backward()
        optim.step()
    # Mean of the per-batch mean losses.
    train_loss = train_loss/len(train_loader)

    # ---- Evaluation pass ----
    model.eval()
    test_loss = 0.0
    correct, total = 0, 0
    with torch.inference_mode():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            test_pred = model(X)

            test_loss += loss_fn(test_pred, y).item()
            # The predicted class is the index of the largest logit in each row.
            y_pred_labels = test_pred.argmax(dim=1)

            correct += (y_pred_labels == y).sum().item()
            total += y.size(0)

    test_loss /= len(test_loader)
    # Accuracy is counted per sample, so a smaller final batch does not skew it.
    test_acc = correct / total

    print(f"Epoch {epoch+1}:")
    print(f"Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc*100:.2f}%\n")
