In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

from collections import Counter

In [2]:
# Per-image preprocessing: light colour augmentation, resize to 224x224,
# conversion to a tensor, then per-channel normalisation.
preprocess = transforms.Compose([
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# One sub-folder of 'dataset' per class; labels are assigned by ImageFolder.
dataset = datasets.ImageFolder('dataset', transform=preprocess)

In [3]:
# Hold out 20% of the images for validation; the remainder is used for training.
SPLIT_SEED = 42

train_set_size = int(len(dataset) * 0.8)
valid_set_size = len(dataset) - train_set_size

# A seeded generator keeps the split identical across kernel restarts, so the
# validation numbers printed further down are comparable between runs.
train_set, valid_set = torch.utils.data.random_split(
    dataset,
    [train_set_size, valid_set_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# NOTE(review): both subsets wrap the same `dataset` object, so its transform
# (including ColorJitter) is applied to validation images as well as training
# images. Consider a separate, augmentation-free dataset for validation.

In [4]:
# Report how many samples ended up in each subset after the split.
for caption, subset in (('Train data set:', train_set), ('Valid data set:', valid_set)):
    print(caption, len(subset))

Train data set: 327
Valid data set: 82


In [5]:
#train_set.dataset.samples

In [6]:
#dir(train_set.dataset)

In [7]:
# Mini-batch size shared by the training and validation loaders.
BATCH_SIZE = 8

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# Validation metrics do not depend on sample order, so the loader iterates in
# a fixed order; this keeps per-batch results repeatable when debugging.
test_loader = torch.utils.data.DataLoader(
    valid_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

In [8]:
# Invert the dataset's class-name -> index mapping so numeric labels can be
# turned back into readable class names.
class_indexes = train_loader.dataset.dataset.class_to_idx
idx_to_class = dict((index, name) for name, index in class_indexes.items())
idx_to_class

{0: 'forward',
 1: 'full-left',
 2: 'full-right',
 3: 'half-left',
 4: 'half-right',
 5: 'reverse'}

In [9]:
# Count the validation images per class.
#
# The labels are read from the underlying dataset's `targets` list through the
# subset's indices. Iterating `valid_set` itself would load and transform
# every image just to discard it and keep the label.
all_targets = valid_set.dataset.targets
valid_set_label_count = Counter(all_targets[i] for i in valid_set.indices)

# Re-key by class name; only classes that occur in the validation set appear.
valid_set_image_count = {idx_to_class[k]: v for k, v in valid_set_label_count.items()}

In [10]:
# Number of output classes, taken from the dataset so the classifier head
# always matches the labels produced by ImageFolder.
classes = len(dataset.classes)

# Start from the pretrained AlexNet and replace its last fully connected
# layer with one that has `classes` outputs.
model = models.alexnet(pretrained=True)
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, classes)

device = torch.device('cpu')
model = model.to(device)

In [None]:
NUM_EPOCHS = 100
BEST_MODEL_PATH = 'best_model.pth'
best_accuracy = 0.0

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(NUM_EPOCHS):

    # ---- Training pass -------------------------------------------------
    # Re-enable training behaviour (e.g. dropout) after the previous
    # epoch's validation pass switched the model to eval mode.
    model.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

    # ---- Validation pass -----------------------------------------------
    # eval() disables dropout so the accuracy reflects the deterministic
    # network; no_grad() skips autograd bookkeeping that is not needed here.
    model.eval()
    test_error_count = 0
    # Misclassified validation images, keyed by their true class name.
    wrong_items_count = {class_name: 0 for class_name in idx_to_class.values()}
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(1)

            # True where the predicted class differs from the ground truth.
            mask = predicted != labels
            for wrong_label in labels[mask].tolist():
                wrong_items_count[idx_to_class[wrong_label]] += 1

            test_error_count += int(mask.sum())

    print("Wrong Count By Class: ", wrong_items_count)
    # A class with no validation images has nothing to miss; report 0.0 for it
    # instead of failing on a missing key or a division by zero.
    print("Wrong Count By Class(%)", {
        class_: (round(float(missed * 100) / valid_set_image_count[class_], 1)
                 if valid_set_image_count.get(class_) else 0.0)
        for class_, missed in wrong_items_count.items()
    })

    test_accuracy = 1.0 - float(test_error_count) / float(len(valid_set))
    print('%d: %f' % (epoch, test_accuracy))

    # Keep only the weights of the best-scoring epoch so far.
    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = test_accuracy

Wrong Count By Class:  {'forward': 12, 'full-left': 7, 'full-right': 2, 'half-left': 9, 'half-right': 8, 'reverse': 11}
Wrong Count By Class(%) {'forward': 38.7, 'full-left': 100.0, 'full-right': 100.0, 'half-left': 90.0, 'half-right': 88.9, 'reverse': 47.8}
0: 0.402439
