In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader 
from torchvision import datasets, transforms
from torch.optim import AdamW

In [None]:
class LeNet(nn.Module):
    """LeNet-style convolutional classifier for single-channel 28x28 images.

    Spatial sizes for a 28x28 input:
    conv5 -> 24, maxpool3 -> 8, conv5 -> 4, maxpool2 -> 2.
    The flattened feature vector therefore has 64 * 2 * 2 = 256 entries,
    which is what the first ``nn.Linear`` expects.

    Args:
        num_classes: number of output logits (10 for MNIST digits).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 32, 5),
            nn.MaxPool2d(3),
            nn.ReLU(),
            nn.Conv2d(32, 64, 5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.mlp = nn.Sequential(
            nn.Linear(256, 200),
            nn.ReLU(),
            nn.Linear(200, num_classes),
        )

    def forward(self, x, y=None):
        """Compute class logits and, when targets are given, the loss.

        Args:
            x: image batch, expected shape (N, 1, 28, 28).
            y: optional integer class targets of shape (N,).

        Returns:
            (logits, loss) where logits has shape (N, num_classes) and loss
            is the mean cross-entropy against ``y``, or None when ``y`` is None.
        """
        # flatten(start_dim=1) keeps the batch dimension intact. With a wrong
        # input size, nn.Linear raises a shape error instead of silently
        # spreading samples across extra rows, as view(-1, 256) would.
        features = self.conv_block(x).flatten(start_dim=1)
        logits = self.mlp(features)
        loss = None if y is None else F.cross_entropy(logits, y)
        return logits, loss

In [None]:
# Fix the global RNG so weight init and DataLoader shuffling are reproducible
# on Restart & Run All. torch.manual_seed seeds every device's generator.
# It does not force deterministic cuDNN kernels.
SEED = 0
torch.manual_seed(SEED)

# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# MNIST training split (downloaded into ./data on first run). Images go
# through ToTensor; batches of 128 are reshuffled on every pass.
mnist = datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
dataloader = DataLoader(dataset=mnist, shuffle=True, batch_size=128)

In [None]:
# Fresh network moved to the selected device; AdamW with its default hyperparameters.
model = LeNet()
model = model.to(device)
opt = AdamW(params=model.parameters())

In [None]:
# Held-out MNIST test split. Evaluation order does not matter, so batches
# are served in dataset order. The last expression displays the split size.
mnist_test = datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataloader = DataLoader(dataset=mnist_test, batch_size=128, shuffle=False)
len(mnist_test)

10000

In [None]:
@torch.inference_mode()
def eval(net=None, loader=None):
    """Return the top-1 accuracy of ``net`` over every sample in ``loader``.

    Args:
        net: classifier whose forward returns ``(logits, loss)``. Defaults to
            the notebook-global ``model``, so a bare ``eval()`` still works.
        loader: iterable of ``(X, Y)`` batches. Defaults to the global
            ``test_dataloader``.

    Returns:
        Fraction of correctly classified samples in [0, 1], or 0.0 when the
        loader yields no samples.

    The network is switched to eval mode for the pass, and its previous
    train/eval mode is restored afterwards, even if an exception is raised.
    """
    net = model if net is None else net
    loader = test_dataloader if loader is None else loader
    # Batches are sent to wherever the network's weights live.
    target = next(net.parameters()).device

    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        for X, Y in loader:
            X, Y = X.to(target), Y.to(target)
            logits, _ = net(X)
            correct += (logits.argmax(dim=-1) == Y).sum().item()
            # Count the samples actually scored rather than trusting an
            # external dataset length.
            total += Y.numel()
    finally:
        net.train(was_training)
    return correct / total if total else 0.0

In [None]:
LOG_EVERY = 100  # batches between progress reports

# One pass (epoch) over the training set, logging the current mini-batch
# loss and the test-set accuracy every LOG_EVERY batches.
model.train()
for step, (X, Y) in enumerate(dataloader, start=1):
    X, Y = X.to(device), Y.to(device)
    _, loss = model(X, Y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % LOG_EVERY == 0:
        # `loss` is this single mini-batch's loss, so it is noisy between reports.
        print(f'batch {step}: loss={loss.item():.4f} test_acc={eval():.4f}')

100th batch loss.item()=0.293359
0.9276
200th batch loss.item()=0.162600
0.962
300th batch loss.item()=0.099568
0.9716
400th batch loss.item()=0.088467
0.9704
