In [2]:
import os
import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # or any {'0', '1', '2'}


def mnist_dataset():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    trainset = datasets.MNIST('./.pytorch/MNIST_data/', download=True, train=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
    testset = datasets.MNIST('./.pytorch/MNIST_data/', download=True, train=False, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

    return trainloader, testloader

In [48]:
class torchModel(nn.Module):
    def __init__ (self):
        super(torchModel, self).__init__()
        self.fc1 = nn.Linear(784, 98).to(torch.float32)
        self.fc2 = nn.Linear(98, 10).to(torch.float32)
    
    def forward(self, x):
        x = x.view(-1, 784).to(torch.float32)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
model = torchModel()
optimizer = optim.Adam(model.parameters(), lr=0.003)

In [49]:
def compute_loss(logits, labels):
    return F.cross_entropy(logits, labels)

def compute_accuracy(logits, labels):
    return torch.mean((logits.argmax(dim=1) == labels).float())

def train_one_step(model, optimizer, x, y):
    model.train()
    logits = model(x)
    loss = compute_loss(logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

def test(model, x, y):
    model.eval()
    with torch.no_grad():
        logits = model(x)
        loss = compute_loss(logits, y)
        accuracy = compute_accuracy(logits, y)
    return loss, accuracy

In [50]:
# train this model
trainloader, testloader = mnist_dataset()
for epoch in range(10):
    for x, y in trainloader:
        loss = train_one_step(model, optimizer, x.to(torch.float32), y.to(torch.float32))
    print(f'epoch {epoch}, loss {loss}')
    
for x, y in testloader:
    loss, accuracy = test(model, x.to(torch.float32), y.to(torch.float32))
print(f'test, loss {loss}, accuracy {accuracy}')


epoch 0, loss 0.38105282187461853
epoch 1, loss 0.03815656155347824
epoch 2, loss 0.05254775658249855
epoch 3, loss 0.17444223165512085
epoch 4, loss 0.4422112703323364
epoch 5, loss 0.2748570740222931
epoch 6, loss 0.05122112110257149
epoch 7, loss 0.14479100704193115
epoch 8, loss 0.10320556908845901
epoch 9, loss 0.06474336981773376
test, loss 0.09843818843364716, accuracy 0.9375


In [None]:
# Save model
torch.save(model.state_dict(), 'model.pth')
# Load model
model = torchModel()
model.load_state_dict(torch.load('model.pth'))