# Image classification with transformers
A simple transformer based image classification workflow which classifies 
terrain from the 'Intel image classification dataset'.

The whole model is not trained rather a classification layer is added on top and finetuned.

In [1]:
import os
from pathlib import Path

import sys
from tqdm import tqdm
import time
import copy


import torch 
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, sampler, random_split
from torchvision import datasets, transforms

import torchvision
# from torchvision import datasets
from torchvision import models

In [2]:
# !pip install timm 
# LabelSmoothingCrossEntropy provides better results
import timm
from timm.loss import LabelSmoothingCrossEntropy 

In [3]:
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.current_device())

True
1
0


In [4]:
# Setup device-agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [8]:
# set dataset paths
dataset_path = Path('../datasets/intel_image_classification')

# Setup train and testing paths
train_dir = dataset_path / "seg_train/seg_train"
test_dir = dataset_path / "seg_test/seg_test"

print('train test dir: \n', train_dir, '\n', test_dir)

os.path.exists(train_dir), os.path.exists(test_dir)

train test dir: 
 ../datasets/intel_image_classification/seg_train/seg_train 
 ../datasets/intel_image_classification/seg_test/seg_test


(True, True)

In [9]:
# Write the transformations for the purpose of data augmentations
# train data transform
train_data_transform = transforms.Compose([
  
  # image transformations
  transforms.RandomHorizontalFlip(),
  transforms.RandomVerticalFlip(),
  transforms.RandomApply(torch.nn.ModuleList([transforms.ColorJitter()]), p=0.25),
  transforms.Resize((224, 224)),

  # Turn the image into a torch.Tensor
  transforms.ToTensor(),

  # normalise
  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), # imagenet means
  transforms.RandomErasing(p=0.2, value='random'),
   
])

# test data transform
test_data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # normalise with imagenet pretrained values
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [10]:
# Use ImageFolder to create standard pytorch dataset(s)
from torchvision import datasets
train_data = datasets.ImageFolder(root=train_dir,
                                  transform=train_data_transform, # a transform for the data
                                  target_transform=None) # a transform for the label/target 

test_data = datasets.ImageFolder(root=test_dir,
                                 transform=test_data_transform)

train_data, test_data

(Dataset ImageFolder
     Number of datapoints: 14034
     Root location: ../datasets/intel_image_classification/seg_train/seg_train
     StandardTransform
 Transform: Compose(
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                RandomApply(
                p=0.25
                ColorJitter(brightness=None, contrast=None, saturation=None, hue=None)
            )
                Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None)
                ToTensor()
                Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
                RandomErasing(p=0.2, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=random, inplace=False)
            ),
 Dataset ImageFolder
     Number of datapoints: 3000
     Root location: ../datasets/intel_image_classification/seg_test/seg_test
     StandardTransform
 Transform: Compose(
                Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias

In [12]:
# check classes
assert train_data.classes == test_data.classes
n_classes = len(train_data.classes)
print('n_classes, classes: ', n_classes, train_data.classes)

n_classes, classes:  6 ['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street']


In [19]:
# setup training variables
BATCH_SIZE = 8
NUM_WORKERS = 1
LR = 0.001
# scheduler
LRS_STEP_SIZE = 3
LRS_GAMMA = 0.97
EPOCHS = 10

In [20]:
# Use the datasets created in last set to create dataloader
from torch.utils.data import DataLoader
train_dataloader = DataLoader(dataset=train_data,
                              batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              shuffle=True)

test_dataloader = DataLoader(dataset=test_data,
                             batch_size=BATCH_SIZE,
                             num_workers=NUM_WORKERS,
                             shuffle=False)

train_dataloader, test_dataloader

(<torch.utils.data.dataloader.DataLoader at 0x7fef201d1fa0>,
 <torch.utils.data.dataloader.DataLoader at 0x7fef201d1be0>)

In [14]:
# check loader shape
print(next(iter(train_dataloader))[0].shape)

torch.Size([8, 3, 224, 224])


In [21]:
# put dataloader into dict to be used in training loop
dataloaders = {
    "train": train_dataloader,
    "val": test_dataloader
}
dataset_sizes = {
    "train": len(train_data),
    "val": len(test_data)
}

print('dataloaders, dataset_sizes', dataloaders, dataset_sizes)

dataloaders, dataset_sizes {'train': <torch.utils.data.dataloader.DataLoader object at 0x7fef201d1fa0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7fef201d1be0>} {'train': 14034, 'val': 3000}


In [22]:
# download a pre-trained transformer model
model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)

Using cache found in /home/bappadityadebnath/.cache/torch/hub/facebookresearch_deit_main


In [23]:
# freeze the model and add a classification head
for param in model.parameters(): #freeze model
    param.requires_grad = False

n_inputs = model.head.in_features
model.head = nn.Sequential(
    nn.Linear(n_inputs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, n_classes)
)
model = model.to(device)
print(model.head)

Sequential(
  (0): Linear(in_features=192, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=6, bias=True)
)


In [28]:
# setup training loss and optimizer
criterion = LabelSmoothingCrossEntropy()
criterion = criterion.to(device)
# optimizer = optim.Adam(model.head.parameters(), lr=0.001)
optimizer = optim.Adam(model.parameters(), lr=LR)

exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LRS_STEP_SIZE, gamma=LRS_GAMMA)

In [29]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=10):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print("-"*10)
        
        # one training and validation phase per epoch
        for phase in ['train', 'val']: 
            if phase == 'train':
                # model to training mode
                model.train() 
            else:
                # model to evaluate mode
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0.0
            
            for idx, (inputs, labels) in tqdm(enumerate(dataloaders[phase])):
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'): # no autograd makes validation go faster
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1) # used for accuracy
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # accumulate looses and correct score
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                # print('torch.sum', torch.sum(preds == labels.data), running_corrects, idx, 8 * idx)
                
            if phase == 'train':
                # LRS step at end of epoch
                scheduler.step() 
            
            # calculate total looses
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc =  running_corrects.double() / dataset_sizes[phase]
            
            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))
            
            # save best model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # keep the best model
                best_model_wts = copy.deepcopy(model.state_dict()) 
        print()
    time_elapsed = time.time() - since # slight error
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print("Best Val Acc: {:.4f}".format(best_acc))
    
    # return best model
    model.load_state_dict(best_model_wts)
    return model

In [30]:
model_ft = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=10)

Epoch 0/9
----------


1755it [00:19, 91.51it/s]

train Loss: 0.6925 Acc: 0.8792



375it [00:04, 88.40it/s]

val Loss: 0.6635 Acc: 0.8830

Epoch 1/9
----------



1755it [00:19, 88.32it/s]

train Loss: 0.6745 Acc: 0.8900



375it [00:03, 96.15it/s] 

val Loss: 0.6386 Acc: 0.9100

Epoch 2/9
----------



1755it [00:20, 87.30it/s]

train Loss: 0.6614 Acc: 0.8961



375it [00:03, 95.24it/s]

val Loss: 0.6230 Acc: 0.9163

Epoch 3/9
----------



1755it [00:19, 88.03it/s]

train Loss: 0.6554 Acc: 0.9016



375it [00:03, 97.93it/s] 

val Loss: 0.6243 Acc: 0.9200

Epoch 4/9
----------



1755it [00:19, 87.93it/s]

train Loss: 0.6475 Acc: 0.9048



375it [00:03, 94.46it/s] 

val Loss: 0.6236 Acc: 0.9133

Epoch 5/9
----------



1755it [00:20, 87.31it/s]

train Loss: 0.6400 Acc: 0.9118



375it [00:03, 97.41it/s] 

val Loss: 0.6221 Acc: 0.9137

Epoch 6/9
----------



1755it [00:19, 88.64it/s]

train Loss: 0.6388 Acc: 0.9111



375it [00:03, 96.63it/s] 

val Loss: 0.6345 Acc: 0.9100

Epoch 7/9
----------



1755it [00:19, 88.54it/s]

train Loss: 0.6352 Acc: 0.9134



375it [00:03, 95.66it/s]

val Loss: 0.6159 Acc: 0.9200

Epoch 8/9
----------



1755it [00:20, 87.29it/s]

train Loss: 0.6283 Acc: 0.9141



375it [00:03, 95.59it/s] 

val Loss: 0.6170 Acc: 0.9117

Epoch 9/9
----------



1755it [00:20, 86.87it/s]

train Loss: 0.6260 Acc: 0.9177



375it [00:03, 93.93it/s]

val Loss: 0.6079 Acc: 0.9217

Training complete in 3m 59s
Best Val Acc: 0.9217



