In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
# This import only exists on Google Colab, so the cell cannot run elsewhere.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
from sklearn.model_selection import train_test_split
import subprocess
import shutil
import warnings
from tqdm.notebook import tqdm
from collections import Counter
import copy
from typing import Dict, List
import random
import logging

# Hide library warnings and emit timestamped INFO-level log lines.
warnings.filterwarnings('ignore')
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s:%(message)s')

# Variables
IMAGE_DIR = '/content/result_dataset/'  # root folder with train/validation/test sub-folders
BATCH_SIZE = 128
SEED = 21

# Seed every random number generator the notebook imports so that a
# "Restart & Run All" reproduces the same sampling and augmentations.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Available device: {device}')

In [None]:
# Extract the dataset archive stored on the mounted Drive, using '/' as the extraction directory.
# NOTE(review): presumably the archive stores paths starting with content/result_dataset/ so the
# files end up under IMAGE_DIR - verify against the archive contents.
! unzip '/content/drive/MyDrive/ML projects/Image classification project/Data/result_dataset.zip' -d '/'

In [None]:
%%time

tv_transforms = torchvision.transforms

# Steps shared by the deterministic and the augmented pipelines.
resize_step = tv_transforms.Resize(
    size=(256, 256),
    interpolation=tv_transforms.InterpolationMode.BILINEAR,
)
normalize_step = tv_transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Deterministic pipeline: resize, centre crop, tensor conversion, normalization.
data_transforms_orig = tv_transforms.Compose([
    resize_step,
    tv_transforms.CenterCrop(size=224),
    tv_transforms.ToTensor(),
    normalize_step,
])

# get original images
datasets_orig = {}
for mode in ['train', 'validation', 'test']:
    datasets_orig[mode] = torchvision.datasets.ImageFolder(
        root=IMAGE_DIR + mode,
        transform=data_transforms_orig,
    )


# get augmented images
# Exactly one of the listed augmentations is picked at random for every sample.
flip_both_axes = tv_transforms.Compose([
    tv_transforms.RandomHorizontalFlip(p=0.5),
    tv_transforms.RandomVerticalFlip(p=0.5),
])
augmentations = tv_transforms.RandomChoice([
    flip_both_axes,
    tv_transforms.RandomRotation(degrees=(-45, 45)),
    tv_transforms.RandomHorizontalFlip(p=1),
    tv_transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
    tv_transforms.RandomPerspective(distortion_scale=0.6, p=1.0),
    tv_transforms.AugMix(),
])

# Augmented pipeline: resize, random crop, one random augmentation, tensor conversion, normalization.
data_transforms_aug = tv_transforms.Compose([
    resize_step,
    tv_transforms.RandomCrop(size=224),
    augmentations,
    tv_transforms.ToTensor(),
    normalize_step,
])

datasets_aug = {}
for mode in ['train', 'validation']:
    datasets_aug[mode] = torchvision.datasets.ImageFolder(
        root=IMAGE_DIR + mode,
        transform=data_transforms_aug,
    )

In [None]:
def show_images(images, title, denormalize=True):
    """Plot a row of channel-first images under a common title.

    Args:
        images: non-empty sequence of images laid out as (channels, height, width);
            CPU tensors or arrays that ``np.asarray`` can convert.
        title: text used as the figure title.
        denormalize: when True (default) the mean/std normalization with
            mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225] is undone
            before plotting, the same way ``show_image`` does, so colours look
            natural. Pass False for images that were never normalized.

    Raises:
        ValueError: if ``images`` is empty.
    """
    images = list(images)
    if not images:
        raise ValueError('show_images() needs at least one image')

    mean_param = np.array([0.485, 0.456, 0.406])
    std_param = np.array([0.229, 0.224, 0.225])

    # squeeze=False always returns a 2-D array of axes, so a single image works too.
    fig, axes = plt.subplots(1, len(images), figsize=(15, 6), squeeze=False)
    for axis, image in zip(axes[0], images):
        # channel-first -> channel-last, which is what imshow expects
        pixels = np.transpose(np.asarray(image), (1, 2, 0))
        if denormalize:
            pixels = std_param * pixels + mean_param
        axis.imshow(np.clip(pixels, 0, 1))
    fig.suptitle(title)
    fig.tight_layout()
    fig.subplots_adjust(top=1.55)
    plt.show()

In [None]:
def _first_images(dataset, limit=100):
    """Return the images of the first ``limit`` samples (fewer if the dataset is shorter).

    Samples are fetched by index, so exactly ``limit`` items are loaded and
    transformed - the previous loop fetched one extra sample before stopping.
    """
    return [dataset[index][0] for index in range(min(limit, len(dataset)))]


# The first 100 training images, once through the plain pipeline and once augmented.
orig_images = _first_images(datasets_orig['train'])
aug_images = _first_images(datasets_aug['train'])

In [None]:
# Draw up to 10 distinct positions (no repeats) and reuse them for both lists,
# so each augmented image is shown next to the original at the same position.
preview_pool = min(len(orig_images), len(aug_images))
preview_indices = random.sample(range(preview_pool), k=min(10, preview_pool))
show_images(images=[orig_images[i] for i in preview_indices], title='Original images')
show_images(images=[aug_images[i] for i in preview_indices], title='Augmented images')

In [None]:
# combine original and augmented datasets
# Train and validation each become "originals followed by augmented copies".
full_datasets = {}
for mode in ['train', 'validation']:
    parts = [datasets_orig[mode], datasets_aug[mode]]
    full_datasets[mode] = torch.utils.data.ConcatDataset(parts)

# The test split only contains the original (non-augmented) images.
full_datasets['test'] = datasets_orig['test']

In [None]:
# Report split sizes before and after adding the augmented copies.
for split in ('train', 'validation'):
    print(f"Length of the original {split} dataset = {len(datasets_orig[split])}")
print()
for split in ('train', 'validation'):
    print(f"Length of the full {split} dataset with augmented data = {len(full_datasets[split])}")

In [None]:
# Number of samples in every split of the combined datasets.
dataset_sizes = {}
for mode in ['train', 'validation', 'test']:
    dataset_sizes[mode] = len(full_datasets[mode])

# Class names come from the first part of the concatenated training dataset.
train_parts = full_datasets['train'].datasets
classes = train_parts[0].classes

In [None]:
# Per-class sample counts of the validation split, summed over its two parts.
validation_parts = full_datasets['validation'].datasets
class_count_0 = Counter(validation_parts[0].targets)
class_count_1 = Counter(validation_parts[1].targets)
class_count = Counter(class_count_0)
class_count.update(class_count_1)

class_df = pd.DataFrame({
    'labels': class_count.keys(),
    'amount': class_count.values()
})

# Bar chart of the raw (unbalanced) class distribution.
fig, ax = plt.subplots(figsize=(15, 6))
sns.barplot(data=class_df, x='labels', y='amount', ax=ax)
ax.set_title('Unbalanced data')
plt.show()

In [None]:
%%time

def get_weights_for_balanced_data(dataset):
    """Compute per-sample weights that give every class the same total weight.

    Labels are read from the ``targets`` attribute of the dataset (or of every
    part of a ``ConcatDataset``), so no image is loaded or transformed.

    Args:
        dataset: a ``torchvision.datasets.ImageFolder`` or a
            ``torch.utils.data.ConcatDataset`` whose parts all expose ``targets``.

    Returns:
        A tuple ``(sample_weights, N)``. ``sample_weights[i]`` is
        ``1 / (number of samples sharing the label of sample i)`` and ``N`` is
        ``largest class count * number of classes``, i.e. the number of draws
        needed to bring every class up to the size of the largest one.
        An empty dataset yields ``([], 0)``.

    Raises:
        TypeError: if ``dataset`` is neither of the supported types.
    """
    if isinstance(dataset, torch.utils.data.ConcatDataset):
        # Parts are visited in order, which matches ConcatDataset indexing.
        targets = [target for part in dataset.datasets for target in part.targets]
    elif isinstance(dataset, torchvision.datasets.ImageFolder):
        targets = list(dataset.targets)
    else:
        raise TypeError("Incorrect type of dataset!")

    if not targets:
        return [], 0

    class_count = Counter(targets)
    class_weights = {label: 1 / amount for label, amount in class_count.items()}
    sample_weights = [class_weights[target] for target in targets]
    N = max(class_count.values()) * len(class_count) # fit to max
    return sample_weights, N

dataloaders = {}

# Train and validation batches are drawn with a class-balancing weighted sampler.
for mode in ('train', 'validation'):
    print(f'Create dataloader for {mode} dataset')
    sample_weights, N = get_weights_for_balanced_data(full_datasets[mode])
    # One epoch now consists of N sampled items, so record that as the split size.
    dataset_sizes[mode] = N
    balanced_sampler = torch.utils.data.WeightedRandomSampler(
        weights=sample_weights,
        num_samples=N,
        replacement=True,
    )
    dataloaders[mode] = torch.utils.data.DataLoader(
        dataset=full_datasets[mode],
        batch_size=BATCH_SIZE,
        sampler=balanced_sampler,
    )

# The test split is read sequentially, without any sampler.
dataloaders['test'] = torch.utils.data.DataLoader(
    dataset=full_datasets['test'],
    batch_size=BATCH_SIZE,
)

In [None]:
# Tally the labels produced by one pass over the validation dataloader.
count = Counter()

for _, batch_labels in tqdm(dataloaders['validation']):
    count.update(batch_labels.tolist())

In [None]:
# Table of label -> number of times the validation dataloader produced it.
class_df = pd.DataFrame({'labels': count.keys(), 'amount': count.values()})

# Bar chart of the class distribution after weighted sampling.
fig, ax = plt.subplots(figsize=(15, 6))
sns.barplot(data=class_df, x='labels', y='amount', ax=ax)
ax.set_title('Balanced data')
plt.show()

In [None]:
# Sum of the recorded sizes of the three splits.
total_amount = sum(dataset_sizes[split] for split in ('train', 'test', 'validation'))
print(f"Amount of data: {total_amount}")

In [None]:
# Fetch a single training batch and report the shape of its image tensor.
first_train_batch = next(iter(dataloaders['train']))
data, labels = first_train_batch
print(f'Data shape = {data.shape}')

In [None]:
def show_image(image, title: str = None):
    """Display one normalized, channel-first image.

    The mean/std normalization (mean [0.485, 0.456, 0.406], std
    [0.229, 0.224, 0.225]) is undone and the values are clipped to [0, 1]
    before plotting.

    Args:
        image: image laid out as (channels, height, width); a torch tensor
            (on any device) or anything ``np.asarray`` can convert.
        title: optional plot title; nothing is drawn when it is None.
    """
    if isinstance(image, torch.Tensor):
        # detach + cpu so GPU tensors and tensors tracking gradients work too
        image = image.detach().cpu().numpy()
    pixels = np.transpose(np.asarray(image), (1, 2, 0))
    mean_param = np.array([0.485, 0.456, 0.406])
    std_param = np.array([0.229, 0.224, 0.225])
    pixels = np.clip(std_param * pixels + mean_param, 0, 1)
    plt.imshow(pixels)
    if title is not None:
        plt.title(title)
    plt.show()

# Tile one training batch into a single grid image and display it.
train_batch = next(iter(dataloaders['train']))
images = train_batch[0]
image = torchvision.utils.make_grid(images)
show_image(image, title='Images from train dataset')

## Train a model

In [None]:
def train_model(model, criterion, optimizer, scheduler,
                path_for_checkpoint, start_epoch=0, num_epochs=20,
                losses=None, train_metrics=None, val_metrics=None):
    """Train and validate ``model``, writing a checkpoint after every epoch.

    Batches come from the module-level ``dataloaders['train']`` and
    ``dataloaders['validation']`` and are moved to the module-level ``device``.

    Args:
        model: network to optimize.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per training batch.
        scheduler: learning-rate scheduler stepped once per epoch, after the
            training phase.
        path_for_checkpoint: file the checkpoint dictionary is saved to
            (overwritten every epoch).
        start_epoch: index of the first epoch to run; a value above 0 is
            treated as resuming from a checkpoint.
        num_epochs: exclusive upper bound of the epoch index.
        losses: list extended in place with the training loss of every epoch;
            a new list is created when omitted.
        train_metrics: list extended in place with the training accuracy of
            every epoch; a new list is created when omitted.
        val_metrics: list extended in place with the validation accuracy of
            every epoch; a new list is created when omitted.

    Returns:
        ``(model, losses, train_metrics, val_metrics)``.
    """
    # Fresh lists per call: list defaults in the signature would be shared
    # between calls and keep growing.
    losses = [] if losses is None else losses
    train_metrics = [] if train_metrics is None else train_metrics
    val_metrics = [] if val_metrics is None else val_metrics

    best_model_weights = copy.deepcopy(model.state_dict())
    if start_epoch > 0 and val_metrics:
        # Resuming: the best accuracy reached so far is in the history. Starting
        # from 0 would let the first resumed epoch replace the best weights
        # even when it is worse.
        best_accuracy = max(float(metric) for metric in val_metrics)
    else:
        best_accuracy = 0

    for epoch in tqdm(range(start_epoch, num_epochs)):
        print(f'Epoch {epoch + 1} / {num_epochs}')
        for phase in ['train', 'validation']:
            is_training = phase == 'train'
            if is_training:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            # Tensor accumulator so the accuracy stays a float64 tensor even
            # when the loader yields no batches.
            running_corrects = torch.zeros((), dtype=torch.float64, device=device)
            samples_seen = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                # Gradients are only tracked while training.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # criterion returns a batch mean; weight it by the batch size
                # so the epoch loss is a per-sample average.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)
                samples_seen += inputs.size(0)

            # Average over the samples actually processed in this phase.
            denominator = max(samples_seen, 1)
            epoch_loss = running_loss / denominator
            epoch_accuracy = running_corrects / denominator
            print(f'Epoch: {epoch + 1}, Phase: {phase} | loss = {epoch_loss} | accuracy = {epoch_accuracy}')

            if is_training:
                scheduler.step()
                losses.append(epoch_loss)
                train_metrics.append(epoch_accuracy)
            else:
                val_metrics.append(epoch_accuracy)

            if phase == 'validation' and epoch_accuracy > best_accuracy:
                best_accuracy = epoch_accuracy
                best_model_weights = copy.deepcopy(model.state_dict())

        # NOTE(review): the best weights are restored after *every* epoch, so the
        # next epoch continues from the best model so far and the checkpoint
        # below stores the best weights, not the latest ones. Kept as is because
        # the checkpoint consumers load 'model_state_dict' - confirm this is
        # intended rather than a reload meant to run once after the loop.
        model.load_state_dict(best_model_weights)

        # save checkpoint
        print(f'Saving checkpoint of {epoch + 1} epoch')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'loss': epoch_loss,  # loss of the last phase run, i.e. validation
            'losses': losses,
            'train_metrics': train_metrics,
            'val_metrics': val_metrics
        }, path_for_checkpoint)

    return model, losses, train_metrics, val_metrics

In [None]:
def visualize_preds_on_test(model, grid_size=(5, 5)):
    """Show predictions for the first batch of the test dataloader in a grid.

    The model is switched to evaluation mode and stays on whatever device it
    already lives on; the batch is moved to that device for the forward pass.
    Grid cells without a matching image (batch smaller than the grid) are left
    blank.

    Args:
        model: classifier whose output row index ``k`` corresponds to
            ``classes[k]`` (module-level list of class names).
        grid_size: ``(rows, columns)`` of the figure.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    # squeeze=False keeps the axes 2-D even for one-row or one-column grids.
    fig, ax = plt.subplots(grid_size[0], grid_size[1], figsize=(20, 20), squeeze=False)
    mean_param = np.array([0.485, 0.456, 0.406])
    std_param = np.array([0.229, 0.224, 0.225])

    with torch.no_grad():
        inputs, labels = next(iter(dataloaders['test']))
        outputs = model(inputs.to(model_device))
        _, preds = torch.max(outputs, 1)
    preds = preds.cpu()
    batch_images = inputs.cpu().numpy()

    # ax.flat walks the grid row by row.
    for count, axis in enumerate(ax.flat):
        if count >= len(batch_images):
            axis.axis('off')
            continue
        # channel-first -> channel-last, then undo the mean/std normalization
        image = np.transpose(batch_images[count], (1, 2, 0))
        image = std_param * image + mean_param
        image = np.clip(image, 0, 1)
        axis.imshow(image)
        axis.set_title(f'Pred: {classes[preds[count].item()]}\nTrue: {classes[labels[count].item()]}')
    fig.tight_layout()
    plt.show()

In [None]:
# ResNet-50 with ImageNet weights; every pretrained parameter is frozen.
model = torchvision.models.resnet50(weights='IMAGENET1K_V1')
model.requires_grad_(False)

in_features = model.fc.in_features

# New classification head (trainable): Linear -> ReLU -> Dropout -> Linear with one output per class.
head_layers = [
    torch.nn.Linear(in_features, 1000, bias=True),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(1000, len(classes)),
]
model.fc = torch.nn.Sequential(*head_layers)

model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()
# Only the parameters of the new head are handed to the optimizer.
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.1 every 7 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

NUM_EPOCHS = 15

In [None]:
# Where the per-epoch checkpoint is written.
path_for_checkpoint = '/content/drive/MyDrive/ML projects/Image classification project/checkpoint_state_dict.pt'

# Bookkeeping for a fresh training run: epoch counter, last loss and empty histories.
start_epoch, epoch_loss = 0, 0
losses, train_metrics, val_metrics = [], [], []

In [None]:
%%time
# Start training process
# The history lists are passed in and come back extended with this run's values.
model_result, losses, train_metrics, val_metrics = train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    path_for_checkpoint=path_for_checkpoint,
    start_epoch=start_epoch,
    num_epochs=NUM_EPOCHS,
    losses=losses,
    train_metrics=train_metrics,
    val_metrics=val_metrics,
)

In [None]:
# Resuming training process
path_for_checkpoint = '/content/drive/MyDrive/ML projects/Image classification project/checkpoint_state_dict.pt'

# map_location places the stored tensors on the current device, so a checkpoint
# written on a GPU runtime can also be resumed on a CPU-only runtime.
checkpoint = torch.load(path_for_checkpoint, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
lr_scheduler.load_state_dict(checkpoint['scheduler'])
start_epoch = checkpoint['epoch']  # index of the last epoch that was completed
loss = checkpoint['loss']
losses = checkpoint['losses']
train_metrics = checkpoint['train_metrics']
val_metrics = checkpoint['val_metrics']

# Continue with the epoch after the checkpointed one; nothing runs when
# start_epoch + 1 already equals NUM_EPOCHS.
model_result, losses, train_metrics, val_metrics = train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    path_for_checkpoint=path_for_checkpoint,
    start_epoch=start_epoch + 1,
    num_epochs=NUM_EPOCHS,
    losses=losses,
    train_metrics=train_metrics,
    val_metrics=val_metrics,
)

In [None]:
# Plot a 5 x 5 grid of test predictions next to their true labels.
visualize_preds_on_test(model=model, grid_size=(5, 5))

In [None]:
def calculate_accuracy(model, test_dataloader):
    """Return the fraction of correctly classified samples in ``test_dataloader``.

    The model is switched to evaluation mode and stays on the device it already
    lives on; every batch is moved to that device. The denominator is the number
    of labels actually seen, so any dataloader can be passed.

    Args:
        model: classifier producing one score per class for every input row.
        test_dataloader: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        A 0-dimensional float64 tensor with the accuracy in [0, 1].

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    corrects = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            outputs = model(inputs.to(model_device))
            _, preds = torch.max(outputs, 1)
            corrects += torch.sum(preds == labels.to(model_device)).item()
            total += labels.size(0)

    if total == 0:
        raise ValueError('test_dataloader yielded no samples')
    return torch.tensor(corrects / total, dtype=torch.float64)

In [None]:
# Accuracy of the trained model over the test dataloader.
accuracy = calculate_accuracy(model=model, test_dataloader=dataloaders['test'])

In [None]:
# Report the accuracy rounded to two decimals.
rounded_accuracy = np.round(accuracy.item(), 2)
print(f'Accuracy on test set: {rounded_accuracy}')

## Save the trained model to Google Drive

In [None]:
# Rebuild the same architecture that was trained so the checkpoint weights fit.
model = torchvision.models.resnet50(weights='IMAGENET1K_V1')

# Class names in the order of the model's output indices.
classes = [
    'apple', 'banana', 'bean', 'beetroot', 'bell pepper', 'bitter_gourd',
    'bottle_gourd', 'brinjal', 'broccoli', 'cabbage', 'capsicum', 'carrot',
    'cauliflower', 'chilli pepper', 'corn', 'cucumber', 'eggplant', 'garlic',
    'ginger', 'grapes', 'jalepeno', 'kiwi', 'lemon', 'lettuce', 'mango',
    'onion', 'orange', 'papaya', 'paprika', 'pear', 'peas', 'pineapple',
    'pomegranate', 'potato', 'pumpkin', 'raddish', 'radish', 'soy beans',
    'spinach', 'sweetcorn', 'sweetpotato', 'tomato', 'turnip', 'watermelon',
]

# Freeze every pretrained parameter.
model.requires_grad_(False)

in_features = model.fc.in_features

# Classification head: Linear -> ReLU -> Dropout -> Linear with one output per class.
head_layers = [
    torch.nn.Linear(in_features, 1000, bias=True),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(1000, len(classes)),
]
model.fc = torch.nn.Sequential(*head_layers)

model = model.to(device)

In [None]:
path_for_checkpoint = '/content/drive/MyDrive/ML projects/Image classification project/checkpoint_state_dict.pt'

# Load the checkpoint straight onto the CPU so the export also works on a
# runtime without a GPU, even if the checkpoint was written on one.
checkpoint = torch.load(path_for_checkpoint, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

device = 'cpu'
model.to(device)
# Store the model in evaluation mode: the pickled module keeps its training
# flag, and a freshly built model would otherwise be saved with dropout active.
model.eval()

# NOTE: this pickles the whole module object, not just a state dict.
torch.save(model, '/content/drive/MyDrive/ML projects/Image classification project/result_model.pth')