# An Implementation of Augmix Based on albumentations


Hi, here is my implementation of the AugMix augmentation, built on top of albumentations. If you find it helpful, please upvote.
Thanks!

![](https://www.googleapis.com/download/storage/v1/b/kaggle-user-content/o/inbox%2F448347%2Fe9d3259d5b0ef3dadba0238bf3901f1e%2F2020-02-10%2013.50.53.png?generation=1581310282118474&alt=media)

Augmix: https://arxiv.org/abs/1912.02781

Official implementation: https://github.com/google-research/augmix

albumentations: https://github.com/albumentations-team/albumentations

**Notice: I didn't normalize images in my implementation.** For comparison, see the normalization step in the official implementation: https://github.com/google-research/augmix/blob/master/augment_and_mix.py#L25-L30

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import albumentations
from PIL import Image, ImageOps, ImageEnhance
from albumentations.core.transforms_interface import ImageOnlyTransform
from albumentations.augmentations import functional as F
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset, DataLoader, Dataset

# Directory that holds the competition input files.
data_dir = '../input/bengaliai-cv19'
# Parquet shards to load; range(1) selects only shard 0.
files_train = [f'train_image_data_{shard}.parquet' for shard in range(1)]

# Dimensions used to reshape each flat pixel row into a 2-D image.
HEIGHT, WIDTH = 137, 236

In [None]:
# Training metadata table; 'train.csv' is a plain literal (it has no
# placeholders, so it does not need to be an f-string).
df_train = pd.read_csv(os.path.join(data_dir, 'train.csv'))

In [None]:
def read_data(files):
    """Load several parquet files from ``data_dir`` and stack them into one array.

    Args:
        files: Iterable of parquet file names, relative to ``data_dir``.

    Returns:
        The values of every column except the first, for all files
        concatenated row-wise in the given order.
    """
    frames = [pd.read_parquet(os.path.join(data_dir, name)) for name in files]
    combined = pd.concat(frames)

    # The first column is dropped (presumably an id column - confirm against
    # the parquet schema); the rest is returned as a raw array.
    return combined.iloc[:, 1:].values

# Training data: for every listed parquet shard, the values of all columns
# except the first, concatenated into a single array.
data_train = read_data(files_train)

In [None]:
def int_parameter(level, maxval):
    """Scale ``maxval`` by ``level / 10`` and truncate the result to an int.

    Args:
        level: Strength of the operation on a 0-10 scale.
        maxval: Value the operation reaches when ``level`` is 10.

    Returns:
        ``int(level * maxval / 10)``; truncation is toward zero.
    """
    scaled = level * maxval / 10
    return int(scaled)


def float_parameter(level, maxval):
    """Scale ``maxval`` by ``level / 10`` and return the result as a float.

    Args:
        level: Strength of the operation on a 0-10 scale; converted with
            ``float()`` before use.
        maxval: Value the operation reaches when ``level`` is 10.

    Returns:
        ``float(level) * maxval / 10``.
    """
    strength = float(level)
    return strength * maxval / 10.


def sample_level(n):
    """Draw a random operation level uniformly between 0.1 and ``n``.

    Args:
        n: Upper bound of the sampling interval.

    Returns:
        A float sampled with ``np.random.uniform``.
    """
    lower_bound = 0.1
    return np.random.uniform(lower_bound, n)


def autocontrast(pil_img, _):
    """Apply ``ImageOps.autocontrast`` to the image.

    The second argument is ignored; it exists so that every augmentation
    shares the ``(pil_img, level)`` call signature.
    """
    result = ImageOps.autocontrast(pil_img)
    return result


def equalize(pil_img, _):
    """Apply ``ImageOps.equalize`` to the image.

    The second argument is ignored; it exists so that every augmentation
    shares the ``(pil_img, level)`` call signature.
    """
    result = ImageOps.equalize(pil_img)
    return result


def posterize(pil_img, level):
    """Posterize the image, keeping fewer bits as the sampled level grows.

    Args:
        pil_img: PIL image to transform.
        level: Upper bound for the randomly sampled operation level.

    Returns:
        The image posterized to ``4 - int_parameter(sampled_level, 4)`` bits.
    """
    bits_removed = int_parameter(sample_level(level), 4)
    bits_kept = 4 - bits_removed
    return ImageOps.posterize(pil_img, bits_kept)


def rotate(pil_img, level):
    """Rotate the image by a random angle in a random direction.

    Args:
        pil_img: PIL image to transform.
        level: Upper bound for the randomly sampled operation level; the
            angle magnitude is ``int_parameter(sampled_level, 30)`` degrees.

    Returns:
        The rotated image, resampled bilinearly.
    """
    magnitude = int_parameter(sample_level(level), 30)
    # Coin flip decides the sign of the angle.
    negate = np.random.uniform() > 0.5
    angle = -magnitude if negate else magnitude
    return pil_img.rotate(angle, resample=Image.BILINEAR)


def solarize(pil_img, level):
    """Solarize the image with a threshold that drops as the level grows.

    Args:
        pil_img: PIL image to transform.
        level: Upper bound for the randomly sampled operation level.

    Returns:
        The image solarized at threshold
        ``256 - int_parameter(sampled_level, 256)``.
    """
    reduction = int_parameter(sample_level(level), 256)
    threshold = 256 - reduction
    return ImageOps.solarize(pil_img, threshold)


def shear_x(pil_img, level):
    """Shear the image along x by a random factor with a random sign.

    Args:
        pil_img: PIL image to transform.
        level: Upper bound for the randomly sampled operation level; the
            shear magnitude is ``float_parameter(sampled_level, 0.3)``.

    Returns:
        The sheared image, same size as the input, resampled bilinearly.
    """
    shear = float_parameter(sample_level(level), 0.3)
    # Coin flip decides the shear direction.
    if np.random.uniform() > 0.5:
        shear = -shear
    coefficients = (1, shear, 0, 0, 1, 0)
    return pil_img.transform(
        pil_img.size,
        Image.AFFINE,
        coefficients,
        resample=Image.BILINEAR,
    )


def shear_y(pil_img, level):
    """Shear the image along y by a random factor with a random sign.

    Args:
        pil_img: PIL image to transform.
        level: Upper bound for the randomly sampled operation level; the
            shear magnitude is ``float_parameter(sampled_level, 0.3)``.

    Returns:
        The sheared image, same size as the input, resampled bilinearly.
    """
    shear = float_parameter(sample_level(level), 0.3)
    # Coin flip decides the shear direction.
    if np.random.uniform() > 0.5:
        shear = -shear
    coefficients = (1, 0, 0, shear, 1, 0)
    return pil_img.transform(
        pil_img.size,
        Image.AFFINE,
        coefficients,
        resample=Image.BILINEAR,
    )


def translate_x(pil_img, level):
    """Shift the image horizontally by a random number of pixels.

    Args:
        pil_img: PIL image to transform.
        level: Upper bound for the randomly sampled operation level; the
            shift magnitude is scaled against one third of the image width.

    Returns:
        The shifted image, same size as the input, resampled bilinearly.
    """
    max_shift = pil_img.size[0] / 3
    shift = int_parameter(sample_level(level), max_shift)
    # Coin flip decides the shift direction.
    if np.random.random() > 0.5:
        shift = -shift
    coefficients = (1, 0, shift, 0, 1, 0)
    return pil_img.transform(
        pil_img.size,
        Image.AFFINE,
        coefficients,
        resample=Image.BILINEAR,
    )


def translate_y(pil_img, level):
    """Shift the image vertically by a random number of pixels.

    Args:
        pil_img: PIL image to transform.
        level: Upper bound for the randomly sampled operation level; the
            shift magnitude is scaled against one third of the image height.

    Returns:
        The shifted image, same size as the input, resampled bilinearly.
    """
    # PIL's ``size`` is (width, height). A vertical shift must be bounded by
    # the height (index 1); index 0 gave the wrong bound for non-square
    # images. Square images behave exactly as before.
    max_shift = pil_img.size[1] / 3
    level = int_parameter(sample_level(level), max_shift)
    # Coin flip decides the shift direction.
    if np.random.random() > 0.5:
        level = -level
    return pil_img.transform(pil_img.size,
                             Image.AFFINE, (1, 0, 0, 0, 1, level),
                             resample=Image.BILINEAR)


# operation that overlaps with ImageNet-C's test set
def color(pil_img, level):
    """Adjust color saturation by a randomly sampled enhancement factor.

    Args:
        pil_img: PIL image to transform.
        level: Upper bound for the randomly sampled operation level.

    Returns:
        The image enhanced by factor
        ``float_parameter(sampled_level, 1.8) + 0.1``.
    """
    scaled = float_parameter(sample_level(level), 1.8)
    factor = scaled + 0.1
    enhancer = ImageEnhance.Color(pil_img)
    return enhancer.enhance(factor)


# operation that overlaps with ImageNet-C's test set
def contrast(pil_img, level):
    """Adjust contrast by a randomly sampled enhancement factor.

    Args:
        pil_img: PIL image to transform.
        level: Upper bound for the randomly sampled operation level.

    Returns:
        The image enhanced by factor
        ``float_parameter(sampled_level, 1.8) + 0.1``.
    """
    scaled = float_parameter(sample_level(level), 1.8)
    factor = scaled + 0.1
    enhancer = ImageEnhance.Contrast(pil_img)
    return enhancer.enhance(factor)


# operation that overlaps with ImageNet-C's test set
def brightness(pil_img, level):
    """Adjust brightness by a randomly sampled enhancement factor.

    Args:
        pil_img: PIL image to transform.
        level: Upper bound for the randomly sampled operation level.

    Returns:
        The image enhanced by factor
        ``float_parameter(sampled_level, 1.8) + 0.1``.
    """
    scaled = float_parameter(sample_level(level), 1.8)
    factor = scaled + 0.1
    enhancer = ImageEnhance.Brightness(pil_img)
    return enhancer.enhance(factor)


# operation that overlaps with ImageNet-C's test set
def sharpness(pil_img, level):
    """Adjust sharpness by a randomly sampled enhancement factor.

    Args:
        pil_img: PIL image to transform.
        level: Upper bound for the randomly sampled operation level.

    Returns:
        The image enhanced by factor
        ``float_parameter(sampled_level, 1.8) + 0.1``.
    """
    scaled = float_parameter(sample_level(level), 1.8)
    factor = scaled + 0.1
    enhancer = ImageEnhance.Sharpness(pil_img)
    return enhancer.enhance(factor)


# Operation pool without the color/contrast/brightness/sharpness ops.
augmentations = [
    autocontrast,
    equalize,
    posterize,
    rotate,
    solarize,
    shear_x,
    shear_y,
    translate_x,
    translate_y,
]

# Full pool: the list above plus the four enhancement ops (a separate list
# object, so changing one pool does not affect the other).
augmentations_all = augmentations + [color, contrast, brightness, sharpness]

def normalize(image):
    """Shift pixel values by subtracting the constant 127.

    No per-channel mean or variance scaling is performed; this is only a
    fixed offset.

    Args:
        image: Value or array supporting subtraction of an int.

    Returns:
        ``image - 127``.
    """
    offset = 127
    return image - offset

def apply_op(image, op, severity):
    """Run one PIL-based augmentation on a NumPy image.

    Args:
        image: Array accepted by ``Image.fromarray``.
        op: Callable taking ``(pil_img, level)`` and returning a PIL image.
        severity: Level forwarded to ``op``.

    Returns:
        The augmented image converted back with ``np.asarray``.
    """
    # Round trip: ndarray -> PIL image -> op -> ndarray.
    return np.asarray(op(Image.fromarray(image), severity))

def augment_and_mix(image, severity=3, width=3, depth=-1, alpha=1.):
    """Perform AugMix augmentations and compute the mixture.

    ``width`` augmentation chains are built from copies of ``image``. Each
    chain applies ``depth`` randomly chosen operations from ``augmentations``.
    The chains are combined with Dirichlet-sampled weights, then blended with
    the original image using a Beta-sampled coefficient. No normalization is
    applied to either the chains or the original image.

    Args:
        image: Raw input image as an np.ndarray that ``Image.fromarray``
            accepts (presumably uint8, since every op runs through PIL -
            confirm against the caller).
        severity: Severity of the underlying augmentation operators
            (between 1 and 10).
        width: Number of augmentation chains.
        depth: Number of operations per chain. Any value <= 0 enables
            stochastic depth, sampled uniformly from [1, 3] for each chain.
        alpha: Concentration parameter for the Beta and Dirichlet
            distributions.

    Returns:
        The augmented and mixed image as a floating-point array with the same
        shape as ``image``.
    """
    ws = np.float32(np.random.dirichlet([alpha] * width))
    m = np.float32(np.random.beta(alpha, alpha))

    mix = np.zeros_like(image, dtype=np.float32)
    for i in range(width):
        image_aug = image.copy()
        # Use a per-chain variable. Assigning back to ``depth`` would freeze
        # the first sampled value, so every later chain would reuse it
        # instead of drawing its own depth.
        chain_depth = depth if depth > 0 else np.random.randint(1, 4)
        for _ in range(chain_depth):
            op = np.random.choice(augmentations)
            image_aug = apply_op(image_aug, op, severity)
        # The weights are convex, so chains can be accumulated directly.
        mix += ws[i] * image_aug

    mixed = (1 - m) * image + m * mix
    return mixed


class RandomAugMix(ImageOnlyTransform):
    """Albumentations image-only transform that applies ``augment_and_mix``.

    Args:
        severity: Severity of the underlying augmentation operators.
        width: Number of augmentation chains to mix.
        depth: Operations per chain; -1 enables stochastic depth.
        alpha: Concentration parameter for the Beta and Dirichlet sampling.
        always_apply: Forwarded to the albumentations base class.
        p: Probability of applying the transform, forwarded to the base class.
    """

    def __init__(self, severity=3, width=3, depth=-1, alpha=1., always_apply=False, p=0.5):
        # Pass by keyword so the call does not depend on the positional
        # parameter order of the base-class constructor.
        super().__init__(always_apply=always_apply, p=p)
        self.severity = severity
        self.width = width
        self.depth = depth
        self.alpha = alpha

    def apply(self, image, **params):
        """Return the AugMix-augmented version of ``image``."""
        return augment_and_mix(
            image,
            self.severity,
            self.width,
            self.depth,
            self.alpha,
        )

    def get_transform_init_args_names(self):
        """Name the constructor arguments albumentations uses for repr and serialization."""
        return ("severity", "width", "depth", "alpha")



In [None]:
class BengaliDataset(Dataset):
    """Torch dataset that serves Bengali grapheme images and their labels.

    Args:
        csv: DataFrame with the label columns ``grapheme_root``,
            ``vowel_diacritic`` and ``consonant_diacritic``; its index is
            reset on construction.
        data: Indexable pixel storage; ``data[i]`` must reshape to
            ``(HEIGHT, WIDTH)``.
        idx: Positions into ``csv`` / ``data`` that make up this dataset.
        split: Stored as given; not used by this class.
        mode: ``'test'`` yields images only; any other value also yields
            labels.
        image_size: Side length of the square image produced by resizing.
        transform: Optional callable invoked as ``transform(image=image)``
            that returns a dict with an ``'image'`` entry.
    """

    def __init__(self, csv, data, idx, split, mode, image_size, transform=None):

        self.csv = csv.reset_index()
        self.data = data
        self.idx = np.asarray(idx)
        self.split = split
        self.mode = mode
        self.image_size = image_size
        self.transform = transform

    def __len__(self):
        """Return the number of samples selected by ``idx``."""
        return self.idx.shape[0]

    def __getitem__(self, index):
        """Return the image tensor, plus a label tensor unless mode is 'test'.

        The image is resized, optionally transformed, scaled by 1/255,
        inverted (``1 - x``) and repeated to three channels.
        """
        index = self.idx[index]

        image = self.data[index].reshape(HEIGHT, WIDTH)
        image = cv2.resize(image, (self.image_size, self.image_size))

        if self.transform is not None:
            image = self.transform(image=image)['image']
        # astype returns a new float32 array in both cases, so the in-place
        # division below never touches the stored data.
        image = image.astype(np.float32)

        image /= 255
        image = image[np.newaxis, :, :]
        image = 1 - image
        image = np.repeat(image, 3, 0)  # 1ch to 3ch

        if self.mode == 'test':
            return torch.tensor(image)

        # Look the row up once instead of once per label column.
        row = self.csv.iloc[index]
        label = [row.grapheme_root, row.vowel_diacritic, row.consonant_diacritic]
        return torch.tensor(image), torch.tensor(label)

In [None]:
def plot_imgs(dataset_show):
    """Plot two rows of five randomly chosen samples from a dataset.

    Args:
        dataset_show: Dataset whose items are ``(image, label)`` pairs, where
            the image is a channel-first torch tensor.
    """
    from pylab import rcParams
    rcParams['figure.figsize'] = (20, 10)

    n_rows, n_cols = 2, 5
    for _ in range(n_rows):
        _fig, axes = plt.subplots(1, n_cols)
        for ax in axes:
            sample_idx = np.random.randint(0, len(dataset_show))
            img, _label = dataset_show[sample_idx]
            # Channel-first (C, H, W) -> channel-last (H, W, C) for imshow.
            ax.imshow(img.permute(1, 2, 0).squeeze())
            ax.set_title(sample_idx)

# Usage Example

### Official default setting -> Looks dirty

In [None]:
# AugMix with severity=3, width=3, alpha=1, applied to every sample (p=1).
transforms_train = albumentations.Compose([
    RandomAugMix(severity=3, width=3, alpha=1., p=1.),
])

# Preview the augmentation on the first 1000 rows of the training table.
df_show = df_train.iloc[:1000]
dataset_show = BengaliDataset(
    csv=df_show,
    data=data_train,
    idx=list(range(len(df_show))),
    split='train',
    mode='train',
    image_size=128,
    transform=transforms_train,
)
plot_imgs(dataset_show)

### severity=1, width=1 -> Not that dirty?

In [None]:
# AugMix with severity=1 and a single chain (width=1), applied to every sample.
transforms_train = albumentations.Compose([
    RandomAugMix(severity=1, width=1, p=1.),
])

# Preview the augmentation on the first 1000 rows of the training table.
df_show = df_train.iloc[:1000]
dataset_show = BengaliDataset(
    csv=df_show,
    data=data_train,
    idx=list(range(len(df_show))),
    split='train',
    mode='train',
    image_size=128,
    transform=transforms_train,
)
plot_imgs(dataset_show)

### severity=7, width=7, alpha=5 -> Very aggressive transform

In [None]:
# AugMix with severity=7, width=7, alpha=5, applied to every sample (p=1).
transforms_train = albumentations.Compose([
    RandomAugMix(severity=7, width=7, alpha=5, p=1.),
])

# Preview the augmentation on the first 1000 rows of the training table.
df_show = df_train.iloc[:1000]
dataset_show = BengaliDataset(
    csv=df_show,
    data=data_train,
    idx=list(range(len(df_show))),
    split='train',
    mode='train',
    image_size=128,
    transform=transforms_train,
)
plot_imgs(dataset_show)