In [38]:
import torch
from utilities.dataset_tools import load_dice_dataset

dataset, dataloader = load_dice_dataset('train')
class_from_idx = {v: k for k, v in dataset.class_to_idx.items()}
n_classes = len(class_from_idx)

torch_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

def show_batch(images, targets):
    for image, t in zip(images, targets.tolist()):
        imshow(image)
        print(class_from_idx[t])


In [37]:
from architectures.die_net_v2 import DieNet

die_net = DieNet()

Train it!

In [4]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(die_net.parameters(), lr=0.001)

In [5]:
from datetime import datetime

for epoch in range(30):
    start_time = datetime.now()
    
    running_loss = 0.0
    for batch_idx, (images, targets) in enumerate(dataloader, 0):
        # show_batch(images, targets)
        optimizer.zero_grad()
        
        outputs = die_net(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if batch_idx % 100 == 99:
            print(f'[{epoch + 1}, {batch_idx + 1}] loss: {running_loss / 100}')
            running_loss = 0.0
    elapsed_time = datetime.now() - start_time
    print(f'Epoch {epoch + 1} complete after {elapsed_time.seconds} seconds')
    
    torch.save(die_net.state_dict(), f'../saved_weights/die_net_v2_epoch{epoch + 1}.pth')

print('Finished!')

[1, 100] loss: 1.7267863762378692
[1, 200] loss: 1.6351789569854736
[1, 300] loss: 1.5696976399421692
[1, 400] loss: 1.5100800091028213
[1, 500] loss: 1.4881608116626739
[1, 600] loss: 1.3805503273010253
[1, 700] loss: 1.3429132789373397
[1, 800] loss: 1.2732155811786652
Epoch 1 complete after 1803 seconds
[2, 100] loss: 1.2505623602867126
[2, 200] loss: 1.1303659921884537
[2, 300] loss: 1.118786066174507
[2, 400] loss: 1.0501513150334358
[2, 500] loss: 1.0741989076137544
[2, 600] loss: 1.0476237019896508
[2, 700] loss: 1.0580292689800261
[2, 800] loss: 0.9747830486297607
Epoch 2 complete after 1731 seconds
[3, 100] loss: 0.9206368863582611
[3, 200] loss: 0.9211164352297783
[3, 300] loss: 0.8703913089632987
[3, 400] loss: 0.8847077186405659
[3, 500] loss: 0.8679471503198147
[3, 600] loss: 0.9438100576400756
[3, 700] loss: 0.8423341657221317
[3, 800] loss: 0.8335633471608161
Epoch 3 complete after 1742 seconds
[4, 100] loss: 0.8251716339588165
[4, 200] loss: 0.7700160123407841
[4, 300] 

KeyboardInterrupt: 

In [6]:
val_dataset, val_dataloader = load_dice_dataset('valid')

In [24]:
# First index will be guessed class, second will be true class
confusion_matrix = torch.zeros(n_classes, n_classes)

print_this_many_outputs_directly = 0
with torch.no_grad():
    for val_batch, (images, targets) in enumerate(val_dataloader, 0):
    #     show_batch(images, targets)
        outputs = die_net(images)
        
        if print_this_many_outputs_directly:
            print('Some raw network outputs:')
            print(outputs)
            print_this_many_outputs_directly -= 1
        
        _, predicted = torch.max(outputs, 1)
        for guess, target in zip(predicted.tolist(), targets.tolist()):
#             print(f'Guess: {class_from_idx[guess]}, Target: {class_from_idx[target]}')
            confusion_matrix[guess, target] += 1

print('Confusion matrix:')
print(confusion_matrix)

Confusion matrix:
tensor([[231.,   6.,   0.,   2.,   1.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   1., 463.,   5.,   1.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.],
        [  1.,   6.,   1.,   9., 572.,   1.],
        [ 14., 228.,  14., 262.,   4., 217.]])


In [8]:
print(outputs)

tensor([[ 0.0000,  0.0000, 17.4824,  0.0000, 12.7442,  0.4030],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.7410,  0.0000, 59.1530,  2.4573],
        [ 0.0000,  0.0000,  0.0000,  0.0000, 16.3997,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.6910, 31.3042],
        [ 8.5705,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])


In [36]:
from utilities.confusion_tools import print_overall_accuracy, print_class_accuracies


print_overall_accuracy(confusion_matrix)
print_class_accuracies(confusion_matrix, class_from_idx=class_from_idx)

Overall accuracy: 72.73%
Accuracy for d10: 96.25%
Accuracy for d12: nan%
Accuracy for d20: 98.51%
Accuracy for d4 : nan%
Accuracy for d6 : 96.95%
Accuracy for d8 : 29.36%


In [None]:
# torch.save(die_net.state_dict(), '../saved_weights/die_net_v2_weights.pth')

In [27]:
confusion_matrix[:, 0]

tensor([231.,   0.,   0.,   0.,   1.,  14.])