In [1]:
import datetime
from IPython import display
import os

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms

from tqdm import tqdm

In [2]:
# --- Device -----------------------------------------------------------------
# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using GPU" if use_cuda else "Not using GPU")

# --- Data -------------------------------------------------------------------
# Both pipelines resize to 128x128 and convert to tensors; only the training
# pipeline adds random augmentation (horizontal flip, rotation, colour jitter).
transform = {
    'train': transforms.Compose(
        [
            transforms.Resize((128, 128)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(30),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.05),
            transforms.ToTensor(),
        ]),
    'test': transforms.Compose(
        [
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
        ])
}
train_dataset = torchvision.datasets.ImageFolder('data/train', transform=transform['train'])
test_dataset  = torchvision.datasets.ImageFolder('data/test', transform=transform['test'])

class_names = train_dataset.classes
print(f'{len(train_dataset)} training images')
print(f'{len( test_dataset)} test images')
# Loads a single training sample; `image` / `label` are not used by the cells
# visible here (presumably kept as a quick smoke test of the pipeline).
image, label = train_dataset[50]

# NOTE(review): the progress bar recorded for the training cell shows 88
# batches per epoch, which does not match a batch size of 1 -- confirm which
# value is intended before re-running the notebook top to bottom.
BATCH_SIZE = 1
NUM_WORKERS = 10

# Pinned host memory only helps host-to-GPU copies, so request it only when a
# GPU is actually in use (on a CPU-only machine it just produces a warning).
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                              shuffle=True, pin_memory=use_cuda)

test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                             shuffle=False, pin_memory=use_cuda)

print(f"Train/test dataloaders have {len(train_dataloader)} and {len(test_dataloader)} batches")

Using GPU
5630 training images
1332 test images
Train/test dataloaders have 5630 and 1332 batches


In [3]:
class TonyNet(torch.nn.Module):
    """Classifier built on the leading layers of a VGG-style feature extractor.

    The first ``last_feature_idx + 1`` modules of ``vgg.features`` are reused
    (the same module objects, not copies) as the convolutional backbone. Their
    output is adaptively average-pooled to 7x7, flattened and fed through a
    three-layer fully connected head with ReLU and dropout in between.

    Args:
        vgg: Object exposing an iterable ``features`` attribute of modules
            (e.g. a torchvision VGG). The head's first linear layer expects
            the kept layers to output 512 channels.
        num_classes: Output size of the final linear layer. Defaults to 3.
        last_feature_idx: Index (inclusive) of the last module of
            ``vgg.features`` to keep. Defaults to 23.
    """

    def __init__(self, vgg, num_classes=3, last_feature_idx=23):
        super().__init__()
        kept_layers = [
            module
            for idx, module in enumerate(vgg.features)
            if idx <= last_feature_idx
        ]
        self.features = torch.nn.Sequential(*kept_layers)
        self.avg_pool = torch.nn.AdaptiveAvgPool2d(output_size=(7, 7))
        self.classifier = torch.nn.Sequential(torch.nn.Linear(7 * 7 * 512, 2048),
                                              torch.nn.ReLU(inplace=True),
                                              torch.nn.Dropout(),
                                              torch.nn.Linear(2048, 1024),
                                              torch.nn.ReLU(inplace=True),
                                              torch.nn.Dropout(),
                                              torch.nn.Linear(1024, num_classes))

    def forward(self, x):
        """Return the raw class scores (no softmax applied) for a batch ``x``."""
        feats = self.features(x)
        pooled = self.avg_pool(feats)
        # Keep the batch dimension and collapse everything after it.
        flat = torch.flatten(pooled, start_dim=1)
        return self.classifier(flat)

In [4]:
def train(model, train_dataloader, val_dataloader, opt, criterion, n_epochs=100, chckpnt_path='./checkpoint.pth'):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Every epoch runs one optimisation pass over ``train_dataloader`` followed
    by a gradient-free evaluation pass over ``val_dataloader``. Per-batch
    losses and per-epoch accuracies are written to TensorBoard under
    ``runs/<timestamp>/train`` and ``runs/<timestamp>/test``. After each
    validation pass a ``ReduceLROnPlateau`` scheduler (factor 0.5) is stepped
    with the mean of the per-batch validation losses.

    Args:
        model: Module to train. Batches are moved to the device of its first
            parameter (CPU if it has no parameters).
        train_dataloader: Sized iterable of ``(images, labels)`` training batches.
        val_dataloader: Sized iterable of ``(images, labels)`` validation batches.
        opt: Optimizer that updates ``model``'s parameters.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        n_epochs: Number of epochs to run.
        chckpnt_path: File that receives ``model.state_dict()`` whenever the
            validation accuracy exceeds the best value seen so far.

    Returns:
        tuple: ``(train_acc, val_acc)`` of the last completed epoch
        (``(0.0, 0.0)`` when ``n_epochs`` is 0).
    """
    date = datetime.datetime.now().strftime("%b-%d-%Y-%H:%M:%S")
    writer_train = SummaryWriter(f'runs/{date}/train')
    writer_test = SummaryWriter(f'runs/{date}/test')
    scheduler = ReduceLROnPlateau(opt, factor=0.5)

    # Follow the model's own device instead of depending on a notebook global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    n_train_batches = len(train_dataloader)
    n_val_batches = len(val_dataloader)
    best_acc = 0
    train_acc, val_acc = 0.0, 0.0

    try:
        for epoch in range(n_epochs):
            # ---- training pass ----
            model.train()
            correct, total = 0, 0
            for step, (images, labels) in enumerate(tqdm(train_dataloader)):
                images, labels = images.to(device), labels.to(device)
                logits = model(images)
                loss = criterion(logits, labels)

                # Use the optimizer handed in by the caller.
                opt.zero_grad()
                loss.backward()
                opt.step()

                with torch.no_grad():
                    total += len(labels)
                    correct += (logits.argmax(dim=1) == labels).sum().item()
                writer_train.add_scalar('Loss', loss.item(), epoch * n_train_batches + step)
            train_acc = correct / total if total else 0.0
            writer_train.add_scalar('Accuracy', train_acc, epoch)

            # ---- validation pass (no autograd graph is needed) ----
            model.eval()
            correct, total = 0, 0
            val_loss_sum, val_batches_seen = 0.0, 0
            with torch.no_grad():
                for step, (images, labels) in enumerate(tqdm(val_dataloader)):
                    images, labels = images.to(device), labels.to(device)
                    logits = model(images)
                    total += len(labels)
                    correct += (logits.argmax(dim=1) == labels).sum().item()
                    val_loss = criterion(logits, labels).item()
                    val_loss_sum += val_loss
                    val_batches_seen += 1
                    # Rescale the validation batch index onto the training
                    # step axis so both loss curves share one x-axis; the
                    # step must be an integer.
                    global_step = int((epoch * n_val_batches + step) * n_train_batches / n_val_batches)
                    writer_test.add_scalar('Loss', val_loss, global_step)
            val_acc = correct / total if total else 0.0
            writer_test.add_scalar('Accuracy', val_acc, epoch)

            # Step the plateau scheduler once per epoch on the epoch-level loss.
            if val_batches_seen:
                scheduler.step(val_loss_sum / val_batches_seen)

            display.clear_output(True)
            print(f'Epoch number: {epoch}')
            print(f'Train accuracy: {train_acc}')
            print(f'Validation accuracy: {val_acc}')
            if val_acc > best_acc:
                torch.save(model.state_dict(), chckpnt_path)
                best_acc = val_acc
    finally:
        # Flush and release the event files even if training is interrupted.
        writer_train.close()
        writer_test.close()

    return train_acc, val_acc

In [5]:
# Learning rate handed to the Adam optimiser below.
learning_rate = 1e-4

# Wrap a torchvision VGG16 (constructed with no arguments) in TonyNet and move
# the result to the selected device.
vgg16 = torchvision.models.vgg16()
tony_net = TonyNet(vgg16)
tony_net.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(tony_net.parameters(), lr=learning_rate)

# The test dataloader is passed as the validation set for this run.
train_acc, test_acc = train(
    model=tony_net,
    train_dataloader=train_dataloader,
    val_dataloader=test_dataloader,
    opt=optimizer,
    criterion=criterion,
    n_epochs=400,
)

  0%|          | 0/88 [00:00<?, ?it/s]

Epoch number: 53
Train accuracy: 0.9877442273534636
Validation accuracy: 0.9677177177177178


 70%|███████   | 62/88 [00:14<00:05,  4.73it/s]

KeyboardInterrupt: 

In [6]:
# Reload the saved checkpoint into a fresh TonyNet and count correct
# predictions over the test dataloader, entirely on the CPU.
device = 'cpu'

vgg16 = torchvision.models.vgg16()
criterion = torch.nn.CrossEntropyLoss()

tony_net_loaded = TonyNet(vgg16)
tony_net_loaded.load_state_dict(torch.load('checkpoint.pth', map_location='cpu'))
tony_net_loaded = tony_net_loaded.to(device).eval()

correct = total = 0
# Gradients are not needed for evaluation, so disable them for the whole loop.
with torch.no_grad():
    for images, labels in tqdm(test_dataloader):
        labels = labels.to(device)
        probs = tony_net_loaded(images.to(device))
        predictions = probs.max(dim=1).indices
        total += labels.shape[0]
        correct += int((predictions == labels).sum())
        val_loss = criterion(probs, labels)

100%|██████████| 1332/1332 [01:03<00:00, 20.89it/s]


In [7]:
# Turn the counters from the evaluation loop into a percentage and report it.
val_accuracy = (100 * correct) / total
message = "Validation accuracy: %.2f%%" % val_accuracy
print(message)

Validation accuracy: 97.45%


In [5]:
# Fresh torchvision VGG16, constructed with no arguments; the %%timeit cell
# below wraps it in TonyNet.
# NOTE(review): this cell's recorded execution count is out of sequence with
# its position in the notebook -- confirm the notebook still runs top to bottom.
vgg16 = torchvision.models.vgg16()

In [8]:
%%timeit 
# Measures one full "prepare for inference" cycle: build TonyNet around the
# existing `vgg16`, read 'checkpoint.pth' with tensors mapped to the CPU, load
# it into the model, move the model to `device` and switch it to eval mode.
tony_net_loaded = TonyNet(vgg16)
device = 'cpu'
tony_net_loaded.load_state_dict(torch.load('checkpoint.pth', map_location='cpu'))
tony_net_loaded.to(device)
tony_net_loaded.eval()

369 ms ± 3.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
