In [None]:
import glob
import os
import cv2
import numpy as np
import albumentations as A

from matplotlib import cm
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

### This Jupyter notebook was created by Rytis Augustauskas. If you have any questions, please contact rytis.augustauskas@ktu.lt

# Where is your data?

In [None]:
# Folder that is searched for the input images
image_dir = 'images/'

# Set a few variables for the output

In [None]:
# Directory the plot images are written to (created on demand when plots are saved)
output_dir = 'output/'

# Save every generated plot as a PNG file
save_plots = True

# Display the generated plots
show_plots = True

In [None]:
# get all image paths in the given directory
def gather_image_from_dir(input_dir, image_extensions=('*.bmp', '*.jpg', '*.png')):
    """
    Return the sorted paths of all image files located directly inside a directory.

    input_dir - directory to search; a trailing path separator is optional
    image_extensions - glob patterns a file name has to match to be collected

    Only the directory itself is searched (no recursion). Returns a list of paths,
    sorted alphabetically; the list is empty when nothing matches.
    """
    image_list = []
    for image_extension in image_extensions:
        # os.path.join inserts the separator when input_dir lacks one, so both
        # 'images' and 'images/' address the same folder
        image_list.extend(glob.glob(os.path.join(input_dir, image_extension)))
    image_list.sort()
    return image_list

In [None]:
def fit_image_to_screen(image, screen_width=1920, screen_height=1080, scale=0.75):
    """
    Shrink an image so that it fits into a screen of the given size.

    image - array-like whose shape starts with (height, width)
    screen_width, screen_height - size of the target screen in pixels
    scale - extra factor applied on top of the fitting factor

    Returns the result of cv2.resize with the same factor on both axes, so the
    aspect ratio is kept.
    """
    image_height, image_width = image.shape[:2]
    # per-axis fitting factors, capped at 1.0 so a small image is never enlarged to fill the screen
    horizontal_fit = min(float(screen_width) / float(image_width), 1.0)
    vertical_fit = min(float(screen_height) / float(image_height), 1.0)
    # the tighter axis decides; `scale` is then multiplied in
    factor = min(horizontal_fit, vertical_fit) * scale
    return cv2.resize(image, (0, 0), fx=factor, fy=factor)

In [None]:
# find annotation path according to the given name
def find_file_by_name(name, paths):
    """
    Return the first path whose file name, stripped of directory and extension, equals `name`.

    name - file name to look for, without extension
    paths - iterable of file paths to search, checked in order

    Returns the matching path, or None when no path matches.
    """
    matching_paths = (
        candidate
        for candidate in paths
        if os.path.splitext(os.path.basename(candidate))[0] == name
    )
    # the generator is lazy, so the search stops at the first hit
    return next(matching_paths, None)

In [None]:
# Collect the paths of all input images and report how many were found
image_paths = gather_image_from_dir(image_dir)
print(f'Image count: {len(image_paths)}')

In [None]:
# Cached augmentation pipeline; stays None until the first transform_image call builds it.
# Re-running this cell resets the cache, so edited parameters take effect.
_AUGMENTATION_PIPELINE = None


def _build_augmentation_pipeline():
    """
    Assemble the Albumentations pipeline used by transform_image.

    The pipeline has three stages: a geometric choice, a brightness/contrast/gamma
    choice, and ISO noise.
    """
    # NOTE(review): per the Albumentations docs a OneOf block applies a single child, using the
    # children's `p` values as relative selection weights, and the block itself runs with its
    # default probability because no `p` is passed here - confirm for the installed version.
    geometric = A.OneOf([
        A.HorizontalFlip(p=0.6),
        A.ShiftScaleRotate(p=0.6,
                           scale_limit=(-0.05, 0.05),
                           shift_limit=(-0.05, 0.05),
                           rotate_limit=(-10, 10))
    ])
    photometric = A.OneOf([
        A.RandomBrightnessContrast(p=0.5,
                                   brightness_limit=0.15,
                                   contrast_limit=0.15),
        A.RandomGamma(p=0.5,
                      gamma_limit=(75, 125))
    ])
    sensor_noise = A.ISONoise(p=0.5,
                              color_shift=(0.01, 0.1),
                              intensity=(0.2, 0.5))
    return A.Compose([geometric, photometric, sensor_noise])


def transform_image(image):
    """
    Apply the augmentation pipeline to one image and return the augmented image.

    image - image passed to the pipeline as its `image` target
            (NOTE(review): presumably an RGB uint8 array, as ISONoise is part of the
            pipeline - confirm against the Albumentations docs)

    The pipeline is built on the first call and reused afterwards instead of being
    re-created for every image.
    """
    global _AUGMENTATION_PIPELINE
    if _AUGMENTATION_PIPELINE is None:
        _AUGMENTATION_PIPELINE = _build_augmentation_pipeline()
    transformed = _AUGMENTATION_PIPELINE(image=image)
    return transformed["image"]

In [None]:
def plot_multiple_images(image_dict: dict, output_dir: str, image_name: str, save_files: bool = True):
    """
    Draw all images of a dictionary side by side in one figure row and optionally save the figure.

    image_dict - image dictionary consisting of image name and image; each name becomes a panel title
    output_dir - directory the plot is saved to (created when missing)
    image_name - file name of the saved plot, without the '.png' extension
    save_files - when False the figure is only drawn and nothing is written to disk

    Raises ValueError when image_dict is empty.
    The figure is left open, so the caller decides whether to show and/or close it.
    """
    if not image_dict:
        raise ValueError('image_dict must contain at least one image')
    size = 25
    # squeeze=False keeps `axes` two-dimensional even for a single panel; without it
    # matplotlib hands back a bare Axes object that can neither be indexed nor iterated
    fig, axes = plt.subplots(1, len(image_dict), figsize=(size * len(image_dict), size), squeeze=False)
    for ax, (name, image) in zip(axes[0], image_dict.items()):
        if len(image.shape) == 2:
            # single-channel image: use a gray colormap
            ax.imshow(image, cmap='gray')
        else:
            ax.imshow(image)
        ax.set_title(name, fontsize=size * 2)
        ax.tick_params(axis='both', which='major', labelsize=size)
    fig.tight_layout()
    if save_files:
        # an empty output_dir means the current directory; os.makedirs('') would raise
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fig.savefig(os.path.join(output_dir, image_name + '.png'))

In [None]:
def make_few_aumentations(image, augmentations_count=1):
    """
    Build a dictionary holding the original image followed by its augmented variants.

    image - source image, stored under the key 'Original'
    augmentations_count - number of augmented variants to create with transform_image

    Returns a dict whose first entry is 'Original'; the variants follow as
    'Augmentation 0', 'Augmentation 1', ... in creation order.
    """
    augmented_variants = {
        f'Augmentation {variant_number}': transform_image(image)
        for variant_number in range(augmentations_count)
    }
    return {'Original': image, **augmented_variants}

# Make some image augmentations!

In [None]:
# How many augmentations should be generated for the same image?
augmentations_in_row_count = 2 # number of augmented images placed next to the original in one plot row
augmentations_count = 3 # number of such plot rows (one figure each) generated per image

In [None]:
# Augment every input image several times and plot each batch as one figure row
for image_path in image_paths:
    image_name = os.path.splitext(os.path.basename(image_path))[0]
    # IMREAD_COLOR always delivers a 3-channel 8-bit BGR array, so grayscale, 16-bit or
    # alpha-channel files no longer break the BGR->RGB conversion below.
    # IMREAD_IGNORE_ORIENTATION keeps the pixel layout exactly as stored, like IMREAD_UNCHANGED did.
    image = cv2.imread(image_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
    if image is None:
        # cv2.imread reports an unreadable/corrupt file by returning None instead of raising
        print(f'Skipping unreadable image: {image_path}')
        continue
    # BGR (OpenCV original) to RGB
    image_RGB = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    for augmentation_i in range(augmentations_count):
        # let's augment a few times to check how it will look
        image_dict = make_few_aumentations(image_RGB, augmentations_count=augmentations_in_row_count)
        plot_multiple_images(image_dict, output_dir, f'{image_name} [{augmentation_i}]', save_plots)
        if show_plots:
            plt.show()
        # release the figure right away; otherwise every figure of the loop stays in memory
        plt.close('all')

## The following links show Albumentations used in TensorFlow and PyTorch training routines:
1. TensorFlow image classification: https://github.com/rytisss/DL-defect-classification-with-CAM-output/blob/main/CAM%20classifiers.ipynb
2. PyTorch image classification: https://albumentations.ai/docs/examples/pytorch_classification/