In [1]:
from models.cnn import CNN
import os
import torch
import torch.nn as nn
from torch.optim import Adam, lr_scheduler
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# script still starts on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
root = "./datasets"
# Ask torchvision to download only when the MNIST folder is missing under
# `root`. (The flag was previously inverted: it requested a download only when
# the data already existed, so a fresh checkout failed with "Dataset not found".)
download = not os.path.exists(os.path.join(root, "MNIST"))

model = CNN()
model.to(device)

# Both splits share the same root and transform; ToTensor() is the only
# preprocessing applied.
trainset = datasets.MNIST(root=root, download=download, train=True, transform=transforms.ToTensor())
testset = datasets.MNIST(root=root, download=download, train=False, transform=transforms.ToTensor())
trainloader = DataLoader(trainset, shuffle=True, batch_size=1024)
testloader = DataLoader(testset, shuffle=False, batch_size=128)

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-3)

# Train for a fixed number of epochs. On the last training batch of every
# epoch a short validation pass runs inside the same tqdm bar, so the final
# progress line of the epoch shows both training and validation metrics.
epochs = 100
best_val_acc = 0

# Make sure the checkpoint directory exists before the first torch.save call;
# torch.save does not create missing parent directories.
checkpoint_path = "./model_weights/mnist_cnn_clean.pt"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Index of the final training batch; loop-invariant, so computed once.
last_batch = len(trainloader) - 1

for e in range(epochs):
    with tqdm(trainloader,desc=f"{e+1}/{epochs} epochs") as t:
        # Per-epoch running sums; loss is weighted by batch size so the
        # reported value is a per-sample average.
        running_correct = 0
        running_loss = 0
        running_total = 0
        model.train()
        for i, (x,y) in enumerate(t):
            # Move the labels to the device once and reuse them for both the
            # loss and the accuracy count.
            targets = y.to(device)
            out = model(x.to(device))
            loss = loss_fn(out,targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_correct += (out.max(dim=1)[1]==targets).sum().item()
            running_loss += loss.item()*x.size(0)
            running_total += x.size(0)
            if i < last_batch:
                t.set_postfix({
                    "train_acc": running_correct/running_total,
                    "train_loss": running_loss/running_total
                })
            else:
                # Last training batch of the epoch: switch to eval mode and
                # run validation. model.train() is called again at the start
                # of the next epoch.
                model.eval()
                val_correct = 0
                val_loss = 0
                val_total = 0
                with torch.no_grad():
                    # NOTE(review): only the first 4 batches of `testloader`
                    # are evaluated, so val_acc/val_loss are estimates on a
                    # subset, not on the full split. Confirm this is intended.
                    for _, (x,y) in zip(range(4),testloader):
                        targets = y.to(device)
                        out = model(x.to(device))
                        loss = loss_fn(out,targets)
                        val_correct += (out.max(dim=1)[1]==targets).sum().item()
                        val_loss += loss.item()*x.size(0)
                        val_total += x.size(0)
                        t.set_postfix({
                            "train_acc": running_correct/running_total,
                            "train_loss": running_loss/running_total,
                            "val_acc": val_correct/val_total,
                            "val_loss": val_loss/val_total
                        })
                    # Keep the weights with the best validation accuracy seen
                    # so far; `>=` means a tie overwrites the earlier
                    # checkpoint with the more recent one.
                    # NOTE(review): `testloader` wraps the MNIST test split, so
                    # checkpoint selection here is tuned on test data.
                    val_acc = val_correct/val_total
                    if val_acc >= best_val_acc:
                        best_val_acc = val_acc
                        torch.save(model.state_dict(), checkpoint_path)

1/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.58it/s, train_acc=0.793, train_loss=0.658, val_acc=0.951, val_loss=0.154] 
2/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.38it/s, train_acc=0.938, train_loss=0.21, val_acc=0.975, val_loss=0.0944]
3/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.29it/s, train_acc=0.957, train_loss=0.147, val_acc=0.979, val_loss=0.0642]
4/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.39it/s, train_acc=0.966, train_loss=0.116, val_acc=0.982, val_loss=0.0554]
5/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.30it/s, train_acc=0.97, train_loss=0.103, val_acc=0.98, val_loss=0.0494] 
6/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.69it/s, train_acc=0.974, train_loss=0.0875, val_acc=0.986, val_loss=0.0395]
7/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.53it/s, train_acc=0.976, train_loss=0.0793, val_acc=0.988, val_loss=0.0366]
8/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.50it/s, train_acc=0.978, train_loss=0.0716,

63/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.72it/s, train_acc=0.995, train_loss=0.0165, val_acc=0.996, val_loss=0.0125] 
64/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.14it/s, train_acc=0.994, train_loss=0.0166, val_acc=0.996, val_loss=0.0108] 
65/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.17it/s, train_acc=0.994, train_loss=0.0173, val_acc=0.994, val_loss=0.0145] 
66/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.76it/s, train_acc=0.995, train_loss=0.0159, val_acc=0.994, val_loss=0.00969]
67/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.75it/s, train_acc=0.995, train_loss=0.0176, val_acc=0.994, val_loss=0.0142] 
68/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.67it/s, train_acc=0.994, train_loss=0.0175, val_acc=0.996, val_loss=0.0111] 
69/100 epochs: 100%|██████████| 59/59 [00:05<00:00, 11.77it/s, train_acc=0.995, train_loss=0.0163, val_acc=0.992, val_loss=0.0126] 
70/100 epochs: 100%|██████████| 59/59 [00:04<00:00, 12.70it/s, train_acc=0.9