In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

from collections import Counter

In [2]:
dataset = datasets.ImageFolder(
    'dataset_v2',
    transforms.Compose([
        transforms.Grayscale(),
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
)

In [3]:
train_set_size = int(len(dataset) * 0.8)
valid_set_size = len(dataset) - train_set_size
train_set, valid_set = torch.utils.data.random_split(dataset, [train_set_size, valid_set_size])

In [4]:
print('Train data set:', len(train_set))
print('Valid data set:', len(valid_set))

Train data set: 536
Valid data set: 135


In [5]:
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    valid_set,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

In [6]:
class_indexes = train_loader.dataset.dataset.class_to_idx
idx_to_class = {v:k for k,v in class_indexes.items()}
idx_to_class

{0: 'forward',
 1: 'full-left',
 2: 'full-right',
 3: 'half-left',
 4: 'half-right',
 5: 'reverse'}

In [7]:
# Counting the number of items in validation set
valid_set_image_count = Counter([label for _, label in valid_set])
valid_set_image_count = {idx_to_class[k]: v for k,v in valid_set_image_count.items()}

In [8]:
# dataiter = iter(train_loader)
# images, labels = dataiter.next()
# images[0].shape

In [9]:
classes = 6

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, 6)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 64 * 56 * 56)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.softmax(self.fc2(x), dim=1)
        return x

    
model = Net()

device = torch.device('cpu')
model = model.to(device)

In [10]:
NUM_EPOCHS = 10
BEST_MODEL_PATH = 'best_model_p4.pth'
best_accuracy = 0.0
rate_learning = 0.01

#optimizer = optim.Adam(model.parameters(), lr=rate_learning)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


for epoch in range(NUM_EPOCHS):
    
    for images, labels in iter(train_loader):
        
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

    test_error_count = 0.0
    wrong_items_count = {'forward':0, 'full-left':0, 'full-right':0, 'half-left':0, 'half-right':0, 'reverse':0}
    for images, labels in iter(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        
        predicted = outputs.argmax(1)
        mask=(predicted != labels)
        wrong_predicted =torch.masked_select(predicted,mask).tolist()
        wrong_labels = torch.masked_select(labels,mask).tolist()
        
        # Decoding the wrong target variables
        wrong_labels = [idx_to_class[label] for label in wrong_labels]
        
        for wrong_label in wrong_labels:
            wrong_items_count[wrong_label] = wrong_items_count[wrong_label] + 1
        
        test_error_count += float(torch.count_nonzero(torch.abs(labels - outputs.argmax(1))))
       
    print("Wrong Count By Class: ", wrong_items_count)
    print("Wrong Count By Class(%)", {class_: round(float(missed*100)/valid_set_image_count[class_], 1) for class_, missed in wrong_items_count.items()})
    
    test_accuracy = 1.0 - float(test_error_count) / float(len(valid_set))
    print('%d: %f' % (epoch, test_accuracy))
    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = test_accuracy

Wrong Count By Class:  {'forward': 0, 'full-left': 12, 'full-right': 16, 'half-left': 12, 'half-right': 12, 'reverse': 49}
Wrong Count By Class(%) {'forward': 0.0, 'full-left': 100.0, 'full-right': 100.0, 'half-left': 100.0, 'half-right': 100.0, 'reverse': 100.0}
0: 0.251852
Wrong Count By Class:  {'forward': 0, 'full-left': 12, 'full-right': 16, 'half-left': 12, 'half-right': 12, 'reverse': 49}
Wrong Count By Class(%) {'forward': 0.0, 'full-left': 100.0, 'full-right': 100.0, 'half-left': 100.0, 'half-right': 100.0, 'reverse': 100.0}
1: 0.251852
Wrong Count By Class:  {'forward': 22, 'full-left': 12, 'full-right': 16, 'half-left': 12, 'half-right': 12, 'reverse': 1}
Wrong Count By Class(%) {'forward': 64.7, 'full-left': 100.0, 'full-right': 100.0, 'half-left': 100.0, 'half-right': 100.0, 'reverse': 2.0}
2: 0.444444
Wrong Count By Class:  {'forward': 23, 'full-left': 12, 'full-right': 16, 'half-left': 12, 'half-right': 12, 'reverse': 0}
Wrong Count By Class(%) {'forward': 67.6, 'full-le

In [None]:
valid_set_image_count

In [None]:
best_accuracy