In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

In [2]:
dataset = datasets.ImageFolder(
    'dataset',
    transforms.Compose([
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
)

In [3]:
train_set_size = int(len(dataset) * 0.8)
valid_set_size = len(dataset) - train_set_size
train_set, valid_set = torch.utils.data.random_split(dataset, [train_set_size, valid_set_size])

#train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 50, 50])

In [4]:
# After
print('Train data set:', len(train_set))
print('Valid data set:', len(valid_set))

Train data set: 327
Valid data set: 82


In [5]:
#train_set.dataset.samples

In [6]:
#dir(train_set.dataset)

In [7]:
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    valid_set,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

In [8]:
classes = 6
model = models.alexnet(pretrained=True)
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, classes)

device = torch.device('cpu')
model = model.to(device)

In [9]:
len(valid_set)

82

In [11]:
BEST_MODEL_PATH

'best_model.pth'

In [10]:
NUM_EPOCHS = 60
BEST_MODEL_PATH = 'best_model.pth'
best_accuracy = 0.0

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(NUM_EPOCHS):
    
    for images, labels in iter(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
    
    test_error_count = 0.0
    for images, labels in iter(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        test_error_count += float(torch.count_nonzero(torch.abs(labels - outputs.argmax(1))))
        
    test_accuracy = 1.0 - float(test_error_count) / float(len(valid_set))
    print('%d: %f' % (epoch, test_accuracy))
    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = test_accuracy

0: 0.292683
1: 0.317073
2: 0.378049
3: 0.500000
4: 0.463415
5: 0.463415
6: 0.560976
7: 0.658537
8: 0.609756
9: 0.682927
10: 0.658537
11: 0.560976
12: 0.731707
13: 0.609756
14: 0.670732
15: 0.695122
16: 0.719512
17: 0.695122
18: 0.731707
19: 0.621951
20: 0.743902
21: 0.756098
22: 0.743902
23: 0.621951
24: 0.634146
25: 0.695122
26: 0.792683
27: 0.731707
28: 0.719512
29: 0.719512
30: 0.695122
31: 0.707317
32: 0.634146
33: 0.743902
34: 0.743902
35: 0.768293
36: 0.731707
37: 0.719512
38: 0.731707
39: 0.731707
40: 0.707317
41: 0.743902
42: 0.743902
43: 0.731707
44: 0.792683
45: 0.756098
46: 0.658537
47: 0.707317
48: 0.743902
49: 0.707317
50: 0.743902
51: 0.756098
52: 0.792683
53: 0.695122
54: 0.731707
55: 0.670732
56: 0.731707
57: 0.756098
58: 0.731707
59: 0.719512
