### Using VGGNet

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

from collections import Counter

In [2]:
dataset = datasets.ImageFolder(
    'dataset_v2',
    transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(244),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
]))

In [3]:
train_set_size = int(len(dataset) * 0.8)
valid_set_size = len(dataset) - train_set_size
train_set, valid_set = torch.utils.data.random_split(dataset, [train_set_size, valid_set_size])

print('Train data set:', len(train_set))
print('Valid data set:', len(valid_set))


train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    valid_set,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

Train data set: 536
Valid data set: 135


In [4]:
classes = 6

# Load the pre-trained VGG network
vgg16 = models.vgg16(pretrained=True)

# Modify the input layer of the network to accept grayscale images
vgg16.features[0] = nn.Conv2d(1, 64, kernel_size=3, padding=1)

# Freeze the pre-trained layers of the network
for param in vgg16.features.parameters():
    param.requires_grad = False

# Add new fully connected layers for your specific classification task
vgg16.classifier[6] = nn.Linear(4096, classes)

In [5]:
learning_rate = 0.01
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vgg16.parameters(), lr=learning_rate)

In [6]:
class_indexes = train_loader.dataset.dataset.class_to_idx
idx_to_class = {v:k for k,v in class_indexes.items()}
idx_to_class

{0: 'forward',
 1: 'full-left',
 2: 'full-right',
 3: 'half-left',
 4: 'half-right',
 5: 'reverse'}

In [7]:
# Counting the number of items in validation set
valid_set_image_count = Counter([label for _, label in valid_set])
valid_set_image_count = {idx_to_class[k]: v for k,v in valid_set_image_count.items()}

In [8]:
NUM_EPOCHS = 10
device = 'cpu'
BEST_MODEL_PATH = 'best_model_vgg.pth'
best_accuracy = 0.0
rate_learning = 0.01

for epoch in range(NUM_EPOCHS):
    
    for images, labels in iter(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = vgg16(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

    test_error_count = 0.0
    wrong_items_count = {'forward':0, 'full-left':0, 'full-right':0, 'half-left':0, 'half-right':0, 'reverse':0}
    for images, labels in iter(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = vgg16(images)
        
        predicted = outputs.argmax(1)
        mask=(predicted != labels)
        wrong_predicted =torch.masked_select(predicted,mask).tolist()
        wrong_labels = torch.masked_select(labels,mask).tolist()
        
        # Decoding the wrong target variables
        wrong_labels = [idx_to_class[label] for label in wrong_labels]
        
        for wrong_label in wrong_labels:
            wrong_items_count[wrong_label] = wrong_items_count[wrong_label] + 1
        
        test_error_count += float(torch.count_nonzero(torch.abs(labels - outputs.argmax(1))))
       
    print("Wrong Count By Class: ", wrong_items_count)
    print("Wrong Count By Class(%)", {class_: round(float(missed*100)/valid_set_image_count[class_], 1) for class_, missed in wrong_items_count.items()})
    
    test_accuracy = 1.0 - float(test_error_count) / float(len(valid_set))
    print('%d: %f' % (epoch, test_accuracy))
    if test_accuracy > best_accuracy:
        torch.save(vgg16.state_dict(), BEST_MODEL_PATH)
        best_accuracy = test_accuracy

Wrong Count By Class:  {'forward': 24, 'full-left': 19, 'full-right': 6, 'half-left': 8, 'half-right': 12, 'reverse': 21}
Wrong Count By Class(%) {'forward': 68.6, 'full-left': 95.0, 'full-right': 50.0, 'half-left': 100.0, 'half-right': 92.3, 'reverse': 44.7}
0: 0.333333
Wrong Count By Class:  {'forward': 24, 'full-left': 20, 'full-right': 8, 'half-left': 8, 'half-right': 13, 'reverse': 6}
Wrong Count By Class(%) {'forward': 68.6, 'full-left': 100.0, 'full-right': 66.7, 'half-left': 100.0, 'half-right': 100.0, 'reverse': 12.8}
1: 0.414815
Wrong Count By Class:  {'forward': 26, 'full-left': 20, 'full-right': 7, 'half-left': 8, 'half-right': 13, 'reverse': 2}
Wrong Count By Class(%) {'forward': 74.3, 'full-left': 100.0, 'full-right': 58.3, 'half-left': 100.0, 'half-right': 100.0, 'reverse': 4.3}
2: 0.437037
Wrong Count By Class:  {'forward': 15, 'full-left': 20, 'full-right': 6, 'half-left': 8, 'half-right': 13, 'reverse': 30}
Wrong Count By Class(%) {'forward': 42.9, 'full-left': 100.0,

KeyboardInterrupt: 