In [1]:
import os

from torch import cuda, manual_seed, nn, optim, save
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from kmnist_conv import CNN
from kmnist_torch import train, test


if __name__ == "__main__":
    # Train the CNN from kmnist_conv on MNIST, report test metrics, and save
    # the trained model to MODEL_PATH.

    # Seed torch's global RNG so weight initialisation and DataLoader
    # shuffling are reproducible across runs.
    SEED = 0
    manual_seed(SEED)

    device = "cuda" if cuda.is_available() else "cpu"

    # Raw string: in a normal literal "\P" and "\d" are invalid escape
    # sequences (DeprecationWarning; SyntaxWarning since Python 3.12).
    # MNIST_DATA_ROOT may override the machine-specific default location.
    DATA_ROOT = os.environ.get("MNIST_DATA_ROOT", r"D:\ProgramData\data")
    MODEL_PATH = "models/mnist_conv.pth"

    training_data = MNIST(
        root=DATA_ROOT, train=True, download=True, transform=ToTensor()
    )

    test_data = MNIST(
        root=DATA_ROOT, train=False, download=True, transform=ToTensor()
    )

    batch_size = 128

    # Only the training loader shuffles; evaluation order does not matter.
    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)

    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    # NOTE(review): CNN comes from kmnist_conv; it is reused here on MNIST
    # (same 28x28 greyscale, 10-class format) -- confirm the architecture
    # makes no KMNIST-specific assumptions.
    mymodel = CNN().to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(mymodel.parameters(), lr=1e-3)

    epochs = 10
    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        # presumably one full pass over train_dataloader per call -- defined
        # in kmnist_torch, verify there.
        train(train_dataloader, mymodel, loss_fn, optimizer, device)

    # Evaluated once after all epochs. `correct` is rendered as a percentage
    # below, so it is assumed to be a fraction in [0, 1] -- confirm in
    # kmnist_torch.test.
    test_loss, correct = test(test_dataloader, mymodel, loss_fn, device)
    print(
        f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"  # noqa: E501
    )

    # torch.save fails if the target directory is missing, so create it first.
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    # Saving the Module object (not its state_dict) pickles the whole model;
    # loading it later requires kmnist_conv.CNN to be importable.
    save(mymodel, MODEL_PATH)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to D:\ProgramData\data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [06:54<00:00, 23887.08it/s]


Extracting D:\ProgramData\data\MNIST\raw\train-images-idx3-ubyte.gz to D:\ProgramData\data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to D:\ProgramData\data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 168027.68it/s]


Extracting D:\ProgramData\data\MNIST\raw\train-labels-idx1-ubyte.gz to D:\ProgramData\data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to D:\ProgramData\data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:40<00:00, 40733.50it/s]


Extracting D:\ProgramData\data\MNIST\raw\t10k-images-idx3-ubyte.gz to D:\ProgramData\data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to D:\ProgramData\data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 524938.10it/s]


Extracting D:\ProgramData\data\MNIST\raw\t10k-labels-idx1-ubyte.gz to D:\ProgramData\data\MNIST\raw

Epoch 0
-------------------------------


batch:  468 loss:0.153268: 100%|██████████| 469/469 [00:20<00:00, 22.69it/s]


Epoch 1
-------------------------------


batch:  468 loss:0.130494: 100%|██████████| 469/469 [00:14<00:00, 32.85it/s]


Epoch 2
-------------------------------


batch:  468 loss:0.100625: 100%|██████████| 469/469 [00:13<00:00, 33.53it/s]


Epoch 3
-------------------------------


batch:  468 loss:0.080012: 100%|██████████| 469/469 [00:13<00:00, 33.55it/s]


Epoch 4
-------------------------------


batch:  468 loss:0.026354: 100%|██████████| 469/469 [00:14<00:00, 33.50it/s]


Epoch 5
-------------------------------


batch:  468 loss:0.060833: 100%|██████████| 469/469 [00:14<00:00, 33.01it/s]


Epoch 6
-------------------------------


batch:  468 loss:0.036658: 100%|██████████| 469/469 [00:15<00:00, 30.63it/s]


Epoch 7
-------------------------------


batch:  468 loss:0.031388: 100%|██████████| 469/469 [00:13<00:00, 34.16it/s]


Epoch 8
-------------------------------


batch:  468 loss:0.002927: 100%|██████████| 469/469 [00:14<00:00, 32.66it/s]


Epoch 9
-------------------------------


batch:  468 loss:0.023206: 100%|██████████| 469/469 [00:14<00:00, 33.39it/s]


Test Error: 
 Accuracy: 98.8%, Avg loss: 0.039674 

