In [1]:
import torch
import torch.nn as nn

from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader


In [2]:
# Default to the CPU and switch to the GPU only when PyTorch reports CUDA as usable.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

class CNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    `net` is the convolutional feature extractor (ending in a flatten) and
    `classification` is the fully connected head that emits 10 scores per
    sample. The head's input width is measured from `net` rather than
    hard-coded.
    """

    def __init__(self):
        super().__init__()

        # Shape after each stage for a (b, 1, 28, 28) input:
        feature_layers = [
            nn.Conv2d(1, 8, 3),              # (b, 8, 26, 26)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),              # (b, 8, 13, 13)
            nn.Conv2d(8, 16, 3),             # (b, 16, 11, 11)
            nn.ReLU(),
            nn.MaxPool2d(2, 2, padding=1),   # (b, 16, 6, 6)
            nn.Flatten(),
        ]
        self.net = nn.Sequential(*feature_layers)

        # The probe needs `self.net` to exist, so it runs before the head is built.
        flat_width = self.get_net_dims()

        head_layers = [
            nn.Linear(flat_width, 100),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, 10),
        ]
        self.classification = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the 10 class scores produced for input batch `x`."""
        return self.classification(self.net(x))

    def get_net_dims(self):
        """Return the number of flattened features `net` yields for one 1x28x28 sample."""
        probe = torch.randn(1, 1, 28, 28)
        return self.net(probe).numel()
    

In [3]:
# Build the network on the selected device and restore the saved weights.
model = CNN().to(device)
# map_location remaps the stored tensors onto `device`; without it a checkpoint
# saved from a CUDA run cannot be opened on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load('./models/mnist_cnn_model.pth', map_location=device))
# Switch to evaluation mode; as the last expression of the cell this also
# displays the module structure.
model.eval()

CNN(
  (net): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (classification): Sequential(
    (0): Linear(in_features=576, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=10, bias=True)
  )
)

In [24]:
batchSize = 4

# ToTensor produces values in [0, 1]; Normalize with mean 0.5 and std 0.5
# rescales them to [-1, 1].
# NOTE(review): this should match the preprocessing used when the loaded
# weights were trained — confirm against the training notebook.
mytf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=mytf)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=mytf)

# Shuffle only the training data. Evaluation metrics do not depend on sample
# order, so the test loader iterates in a fixed, reproducible order.
train_loader = DataLoader(train_data, batch_size=batchSize, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batchSize, shuffle=False)

In [25]:
# Count the top-1 predictions of `model` that match the labels over the test set.
model.eval()  # do not rely on an earlier cell having left the model in eval mode
correct_count = 0
with torch.no_grad():  # evaluation needs no autograd graph
    # `images` instead of `input`, so the builtin input() is not shadowed.
    for images, target in test_loader:
        images, target = images.to(device), target.to(device)
        output = model(images)

        # Tensor reduction instead of Python's builtin sum(), which walks the
        # tensor one element at a time.
        correct_count += (output.argmax(1) == target).sum().item()

print(f"{correct_count} correct out of {len(test_data)}")
print(f'Accuracy = {correct_count/len(test_data) * 100}%')
    

1399 correct out of 10000
Accuracy = 13.99%


# Fine Tuning (if desired)

In [32]:
# Restore model and optimizer state from a training checkpoint so that
# fine-tuning continues where the previous run stopped.
# map_location lets a checkpoint written on a GPU be opened on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# NOTE(review): the file is named fashion_mnist while the loaders in this
# notebook read datasets.MNIST — confirm this is the intended checkpoint.
checkpoints = torch.load('./checkpoints/fashion_mnist_10.pth', map_location=device)

# The code that wrote the checkpoint is not visible here. The original cell read
# the misspelled key 'eopch', so that key is still preferred; the correctly
# spelled 'epoch' is accepted as a fallback.
start_epoch = checkpoints['eopch'] if 'eopch' in checkpoints else checkpoints['epoch']
model.load_state_dict(checkpoints['model_state_dict'])

# The optimizer is created over the restored parameters and then given its saved state.
optimizer = torch.optim.Adam(model.parameters())
optimizer.load_state_dict(checkpoints['optimizer_state_dict'])

criteria = nn.CrossEntropyLoss()

In [None]:
# Number of additional epochs to train in this session.
epochs = 5

# Per-epoch history: mean training loss, mean validation loss, validation accuracy.
loss_list, val_loss_list, accuracies = [], [], []

def accuracy(outputs, labels):
    """Return the top-1 accuracy of one batch as a percentage (0-100).

    Args:
        outputs: tensor of per-class scores; the prediction is the argmax over dim 1.
        labels: tensor of target class indices, one per row of `outputs`.

    Returns:
        Percentage of rows whose predicted class equals the label, as a float.
        An empty batch yields 0.0 instead of raising.
    """
    if len(outputs) == 0:
        return 0.0
    predictions = outputs.argmax(1)
    # Tensor reduction instead of Python's builtin sum(), which iterates the
    # tensor element by element.
    correct = (predictions == labels).sum().item()
    return correct / len(predictions) * 100


# Fine-tune for `epochs` more epochs, continuing the epoch numbering stored in
# the checkpoint. Each epoch appends one entry to loss_list, val_loss_list and
# accuracies.
for ep in range(start_epoch, epochs + start_epoch):
    # ---- training pass ----
    model.train()
    running_loss = 0.0

    # `inputs` instead of `input`, so the builtin input() is not shadowed.
    for inputs, label in train_loader:
        inputs, label = inputs.to(device), label.to(device)

        output = model(inputs)
        loss = criteria(output, label)
        # Accumulate a plain float. Adding the loss tensors themselves keeps
        # autograd history attached to the running total for the whole epoch.
        running_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_loss = running_loss / len(train_loader)
    loss_list.append(avg_loss)

    # ---- validation pass ----
    model.eval()
    running_val_loss = 0.0
    running_accuracy = 0

    with torch.no_grad():  # neither the forward pass nor the loss needs gradients
        for inputs, label in test_loader:
            inputs, label = inputs.to(device), label.to(device)
            output = model(inputs)

            running_val_loss += criteria(output, label).item()
            running_accuracy += accuracy(output, label)

    # Both metrics are means of per-batch values (each batch weighted equally).
    avg_val_loss = running_val_loss / len(test_loader)
    avg_accuracy = running_accuracy / len(test_loader)
    val_loss_list.append(avg_val_loss)
    accuracies.append(avg_accuracy)

    print(f"Ep {ep+1}/{epochs + start_epoch}: Loss = {avg_loss:.5f}; \tVal Loss = {avg_val_loss:.5f}; \tAcc = {avg_accuracy}")
    

Ep 11/5: Loss = 0.15200; 	Val Loss = 20.54652; 	Acc = 12.66
Ep 12/5: Loss = 0.14505; 	Val Loss = 18.17535; 	Acc = 14.23
Ep 13/5: Loss = 0.13860; 	Val Loss = 23.02817; 	Acc = 9.83
Ep 14/5: Loss = 0.13249; 	Val Loss = 19.22543; 	Acc = 10.87
Ep 15/5: Loss = 0.12606; 	Val Loss = 25.00200; 	Acc = 10.57
