**This notebook mostly consists of visual demonstrations of various advanced augmentations, so it is heavy and may take some time to load.**

![image](https://user-images.githubusercontent.com/17668390/169665594-608f7468-7323-41f8-9ba0-400fe9eb828f.gif)

---

### Update: Mosaic Augmentation

I am trying to add the **Mosaic** augmentation, but it is not complete yet. To create a class label in `CutMix`- or `MixUp`-type augmentation, we can draw a mixing weight `beta` from a beta distribution (for example with `np.random.beta` or `scipy.stats.beta`) and combine the two labels as follows:


```
label = label_one*beta + (1-beta)*label_two
```
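A minimal NumPy sketch of this two-label case. Here `lam` plays the role of `beta` above and is drawn from `Beta(alpha, alpha)` via NumPy's `Generator.beta`; the helper name `mix_two_labels`, `alpha=1.0` (equivalent to a uniform draw) and the 5-class one-hot labels are illustrative assumptions.

```python
import numpy as np

def mix_two_labels(label_one, label_two, alpha=1.0, rng=None):
    """Mix two one-hot labels with a weight drawn from Beta(alpha, alpha)."""
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha)                  # mixing weight in [0, 1]
    label = label_one * lam + (1 - lam) * label_two
    return label, lam

# illustrative 5-class one-hot labels
label_one = np.array([1, 0, 0, 0, 0], dtype=np.float32)
label_two = np.array([0, 0, 1, 0, 0], dtype=np.float32)
label, lam = mix_two_labels(label_one, label_two, rng=np.random.default_rng(0))
print(lam, label)
```

The mixed label still sums to 1, so it can be used directly with a categorical cross-entropy loss.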

But what if we have **more than two** images? [YOLOv4](https://arxiv.org/abs/2004.10934) introduced an interesting augmentation for object detection called **Mosaic Augmentation**. Unlike `CutMix` or `MixUp`, it builds each augmented sample from **4** images. This [Stack Overflow question](https://stackoverflow.com/questions/65181294/how-to-create-class-label-for-mosaic-augmentation-in-image-classification) asks how to create the class label for such samples and collects some approaches; please feel free to share your implementation.
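One possible approach is to weight each image's label by the fraction of the mosaic canvas it covers. The sketch below assumes the same four-quadrant layout as the `MosaicMix` function later in this notebook (a split point `(xc, yc)` on a `DIM x DIM` canvas). The helper `mosaic_label` is hypothetical and only illustrates this assumption; it is not a settled answer to the question above.

```python
import numpy as np

def mosaic_label(labels, xc, yc, dim):
    """Area-weighted label for a 4-image mosaic split at (xc, yc).

    labels: array of shape (4, num_classes), ordered as
            top-left, top-right, bottom-left, bottom-right.
    """
    areas = np.array([
        xc * yc,                    # top-left quadrant
        (dim - xc) * yc,            # top-right quadrant
        xc * (dim - yc),            # bottom-left quadrant
        (dim - xc) * (dim - yc),    # bottom-right quadrant
    ], dtype=np.float32)
    weights = areas / float(dim * dim)    # fractions of the canvas, sum to 1
    return weights @ labels, weights

quad_labels = np.eye(4, dtype=np.float32)   # four different classes (illustrative)
mixed_label, weights = mosaic_label(quad_labels, xc=56, yc=112, dim=224)
print(weights)   # [0.125 0.375 0.125 0.375]
```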

# Advanced Augmentation

Hi, this is a simple EDA and data augmentation pipeline for multi-class image classification using a custom sequence data generator in `tf.keras`. Mainly, I will show how to use some advanced augmentations on image samples inside a custom `tf.keras.utils.Sequence` generator. Note that **only multi-class image classification datasets are covered**, e.g.,

- [Cassava Leaf Disease Classification](https://www.kaggle.com/c/cassava-leaf-disease-classification)
- [APTOS 2019](https://www.kaggle.com/c/aptos2019-blindness-detection)


The advanced augmentations are as follows:

```
- CutMix
- MixUp
- FMix
- RGBShift
- ChannelShuffle
- ColorJitter
```

The implementations of `CutMix` and `MixUp` are taken from [Chris Deotte](https://www.kaggle.com/cdeotte/cutmix-and-mixup-on-gpu-tpu) and integrated, with a few modifications, into a custom [tf.keras.utils.Sequence](https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence) generator. `FMix` is taken directly from the original source code [here](https://github.com/ecs-vlc/FMix).

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
from glob import glob
from pylab import rcParams
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import os, gc, cv2, random, warnings, math, sys, json, pprint

# sklearn
from sklearn.utils import class_weight
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, balanced_accuracy_score

# silence TensorFlow C++ logs; this only takes effect if set before TF is imported
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.simplefilter('ignore')

# tf 
import tensorflow as tf
from tensorflow import keras 
from tensorflow.keras import backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
# helper function to plot sample 
def plot_imgs(dataset_show, row, col):
    rcParams['figure.figsize'] = 30,15
    for i in range(row):
        f, ax = plt.subplots(1,col)
        for p in range(col):
            idx = np.random.randint(0, len(dataset_show))
            img, label = dataset_show[idx]
            ax[p].grid(False)
            ax[p].imshow(img[0])
            try:
                ax[p].set_title(label[0].numpy())
            except:
                ax[p].set_title(label[0])
    plt.show()
    

def visulize(path, n_images, is_random=True, figsize=(16, 16)):
    plt.figure(figsize=figsize)
    
    w = int(n_images ** .5)
    h = math.ceil(n_images / w)
    
    image_names = os.listdir(path)
    for i in range(n_images):
        image_name = image_names[i]
        
        if is_random:
            image_name = random.choice(image_names)
            
        # cv2 loads images as BGR; reverse the channels to RGB for matplotlib
        img = cv2.imread(os.path.join(path, image_name))[:, :, ::-1]
        plt.subplot(h, w, i + 1)
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])
    plt.show()

# Competition Data

Currently, there are three multi-class datasets to play with: two competition datasets and the Flowers Recognition dataset from Kaggle (read from `../input/flowers-recognition/flowers`). More datasets and new augmentations will be added later. Choose any of the following:

- [Cassava Leaf Disease Classification](https://www.kaggle.com/c/cassava-leaf-disease-classification)
- [APTOS 2019](https://www.kaggle.com/c/aptos2019-blindness-detection)
- Flowers Recognition (enabled by default below)

**Choose a Dataset**

In [None]:
# set exactly one of the flags below to True

# use cassava leaf disease comp. (multi-class problem)
use_casava = False

# use aptos comp. (multi-class problem)
use_aptos  = False

# use flowers recognition data from kaggle 
use_flower_recognition = True

In [None]:
IMAGE_DIM = (224, 224, 3)
BATCH_SIZ = 25
SEED  = 101

In [None]:
class BaseConfig(object):
    if use_casava:
        TRAIN_DF = '../input/cassava-leaf-disease-classification/train.csv'
        TRAIN_IMG_PATH = '../input/cassava-leaf-disease-classification/train_images/'
        TEST_IMG_PATH  = '../input/cassava-leaf-disease-classification/test_images/'
        CLASS_MAP  = '../input/cassava-leaf-disease-classification/label_num_to_disease_map.json'
        NUM_CLASSES = 5
    
    elif use_aptos:
        TRAIN_DF = '../input/aptos2019-blindness-detection/train.csv'
        TRAIN_IMG_PATH = '../input/aptos2019-blindness-detection/train_images/'
        TEST_IMG_PATH  = '../input/aptos2019-blindness-detection/test_images/'
        NUM_CLASSES = 5
    
    elif use_flower_recognition:
        TRAIN_IMG_PATH = '../input/flowers-recognition/flowers'
        train_datagen = ImageDataGenerator()
        train_generator = train_datagen.flow_from_directory(
            TRAIN_IMG_PATH,
            target_size=IMAGE_DIM[:2],
            batch_size=BATCH_SIZ,
            seed=SEED, 
            shuffle=True,
            class_mode='categorical'
        )
        NUM_CLASSES = 5

**Overview**

In [None]:
# only the competition datasets come with a train.csv
if use_casava or use_aptos:
    df = pd.read_csv(BaseConfig.TRAIN_DF)
    
    # every image id must appear exactly once
    id_col = 'image_id' if use_casava else 'id_code'
    assert df.shape[0] == df[id_col].nunique(), "NOT ALL IDs ARE UNIQUE"
    
    df.info()   # prints the summary itself (returns None)
    print(df.head())

# Custom Sequence Data Generator

In [None]:
class SequenceGenerator(keras.utils.Sequence):
    def __init__(self, 
                 img_path, 
                 data, 
                 batch_size, 
                 dim, 
                 shuffle=True, 
                 use_mosaicmix=False):
        self.dim  = dim
        self.data = data
        self.shuffle  = shuffle
        self.img_path = img_path
        self.batch_size = batch_size
        self.use_mosaicmix = use_mosaicmix
        self.list_idx   = self.data.index.values
        if use_casava:
            self.label = pd.get_dummies(self.data['label'], columns = ['label'])
        elif use_aptos:
            self.label = pd.get_dummies(self.data['diagnosis'], columns = ['diagnosis'])
        self.on_epoch_end()
        
    def __len__(self):
        return int(np.ceil(float(len(self.data)) / float(self.batch_size)))
    
    def __getitem__(self, index):
        batch_idx = self.indices[index*self.batch_size:(index+1)*self.batch_size]
        idx = [self.list_idx[k] for k in batch_idx]
        
        # size the arrays by len(idx): the last batch can be smaller than batch_size
        Data   = np.empty((len(idx), *self.dim), dtype=np.float32)
        Target = np.empty((len(idx), BaseConfig.NUM_CLASSES), dtype = np.float32)

        for i, k in enumerate(idx):
            # load the image file using cv2
            if use_casava:
                image = cv2.imread(self.img_path + self.data['image_id'][k])
                image = image[:,:,::-1]
            elif use_aptos:
                image = cv2.imread(self.img_path + self.data['id_code'][k] + '.png')
                image = image[:,:,::-1]
            image = cv2.resize(image, self.dim[:2])

            # assign 
            Data[i,] =  image
            # k is an index label (from data.index), so use .loc, not positional .iloc
            Target[i,] = self.label.loc[k].values
            
        if self.use_mosaicmix:
            Data, Target = MosaicMix(Data, Target, self.dim[0]) 

        return Data, Target 
    
    def on_epoch_end(self):
        self.indices = np.arange(len(self.list_idx))
        if self.shuffle:
            np.random.shuffle(self.indices)
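The batch bookkeeping in `__len__` and `__getitem__` can be checked without TensorFlow. This plain-Python sketch mirrors that index arithmetic (the helper names `num_batches` and `batch_positions` are hypothetical) and shows that the final batch is shorter whenever the dataset size is not a multiple of `batch_size`.

```python
import math

def num_batches(n_samples, batch_size):
    # same rule as SequenceGenerator.__len__: round up so no sample is dropped
    return int(math.ceil(n_samples / batch_size))

def batch_positions(indices, index, batch_size):
    # same slicing as SequenceGenerator.__getitem__
    return indices[index * batch_size:(index + 1) * batch_size]

demo_indices = list(range(103))           # e.g. 103 samples
n = num_batches(len(demo_indices), 25)
sizes = [len(batch_positions(demo_indices, i, 25)) for i in range(n)]
print(n, sizes)   # 5 [25, 25, 25, 25, 3]
```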

In [None]:
if use_casava or use_aptos:
    datagens = SequenceGenerator(
            BaseConfig.TRAIN_IMG_PATH, 
            df, 
            batch_size=BATCH_SIZ,
            dim=IMAGE_DIM,
            shuffle = True
        )
else:
    datagens = BaseConfig.train_generator

In [None]:
images, labels = next(iter(datagens))

images.shape, labels.shape

## CutMix Augmentation

In [None]:
def CutMix(image, label, DIM, PROBABILITY = 1.0):
    # input image - is a batch of images of size [n,dim,dim,3] not a single image of [dim,dim,3]
    # output - a batch of images with cutmix applied
    CLASSES = BaseConfig.NUM_CLASSES
    
    imgs = []; labs = []
    for j in range(len(image)):
        # DO CUTMIX WITH PROBABILITY DEFINED ABOVE
        P = tf.cast( tf.random.uniform([],0,1)<=PROBABILITY, tf.int32)
        
        # CHOOSE RANDOM IMAGE TO CUTMIX WITH
        k = tf.cast( tf.random.uniform([],0,len(image)),tf.int32)
        
        # CHOOSE RANDOM LOCATION
        x = tf.cast( tf.random.uniform([],0,DIM),tf.int32)
        y = tf.cast( tf.random.uniform([],0,DIM),tf.int32)
        
        b = tf.random.uniform([],0,1) # this is beta dist with alpha=1.0
        
        WIDTH = tf.cast( DIM * tf.math.sqrt(1-b),tf.int32) * P
        ya = tf.math.maximum(0,y-WIDTH//2)
        yb = tf.math.minimum(DIM,y+WIDTH//2)
        xa = tf.math.maximum(0,x-WIDTH//2)
        xb = tf.math.minimum(DIM,x+WIDTH//2)
        
        # MAKE CUTMIX IMAGE
        one = image[j,ya:yb,0:xa,:]
        two = image[k,ya:yb,xa:xb,:]
        three = image[j,ya:yb,xb:DIM,:]
        middle = tf.concat([one,two,three],axis=1)
        img = tf.concat([image[j,0:ya,:,:],middle,image[j,yb:DIM,:,:]],axis=0)
        imgs.append(img)
        
        # MAKE CUTMIX LABEL
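        # label weight = fraction of pixels actually taken from image k
        # (the box can be clipped at the border, so WIDTH*WIDTH may overestimate it)
        a = tf.cast((yb-ya)*(xb-xa), tf.float32) / (DIM*DIM)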
        labs.append((1-a)*label[j] + a*label[k])
            
    # RESHAPE HACK SO TPU COMPILER KNOWS SHAPE OF OUTPUT TENSOR (maybe use Python typing instead?)
    image2 = tf.reshape(tf.stack(imgs),(len(image),DIM,DIM,3))
    label2 = tf.reshape(tf.stack(labs),(len(image),CLASSES))
    
    return image2,label2
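A NumPy sketch of the same idea for a single pair of images, useful for checking the label arithmetic. The box is clipped at the image border, so the label weight is computed from the pixels that were actually pasted. The function name `cutmix_pair` and its `rng` argument are hypothetical.

```python
import numpy as np

def cutmix_pair(img_a, img_b, lab_a, lab_b, rng):
    """Paste a random square from img_b into img_a; mix labels by pasted area."""
    dim = img_a.shape[0]                        # assumes square images
    b = rng.uniform(0, 1)                       # Beta(1, 1) draw, as in CutMix above
    width = int(dim * np.sqrt(1 - b))
    x, y = rng.integers(0, dim, size=2)         # random box centre
    ya, yb = max(0, y - width // 2), min(dim, y + width // 2)
    xa, xb = max(0, x - width // 2), min(dim, x + width // 2)

    out = img_a.copy()
    out[ya:yb, xa:xb] = img_b[ya:yb, xa:xb]
    a = (yb - ya) * (xb - xa) / (dim * dim)     # fraction of pixels taken from img_b
    return out, (1 - a) * lab_a + a * lab_b, a

rng = np.random.default_rng(42)
img_a = np.zeros((224, 224, 3), dtype=np.float32)
img_b = np.full((224, 224, 3), 255.0, dtype=np.float32)
lab_a, lab_b = np.array([1.0, 0.0]), np.array([0.0, 1.0])
mixed_img, mixed_lab, area_frac = cutmix_pair(img_a, img_b, lab_a, lab_b, rng)
```

Because `img_b` is all white here, the share of white pixels in `mixed_img` equals the label weight `area_frac`.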

In [None]:
cutmix_image, cutmix_label = CutMix(images, labels, IMAGE_DIM[0])

In [None]:
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 5

for i in range(1, columns*rows +1):
    img = cutmix_image[i].numpy().astype('int')
    lbl = cutmix_label[i].numpy()
    fig.add_subplot(rows, columns, i)
    plt.imshow(img)
    plt.axis("off")
plt.show()

## MixUp Augmentation

In [None]:
def MixUp(image, label, DIM, PROBABILITY = 1.0):
    # input image - is a batch of images of size [n,dim,dim,3] not a single image of [dim,dim,3]
    # output - a batch of images with mixup applied
    CLASSES = BaseConfig.NUM_CLASSES
    
    imgs = []; labs = []
    for j in range(len(image)):
        # DO MIXUP WITH PROBABILITY DEFINED ABOVE
        P = tf.cast( tf.random.uniform([],0,1)<=PROBABILITY, tf.float32)
                   
        # CHOOSE RANDOM
        k = tf.cast( tf.random.uniform([],0,len(image)),tf.int32)
        a = tf.random.uniform([],0,1)*P # this is beta dist with alpha=1.0
                    
        # MAKE MIXUP IMAGE
        img1 = image[j,]
        img2 = image[k,]
        imgs.append((1-a)*img1 + a*img2)
                    
        # MAKE MIXUP LABEL
        labs.append((1-a)*label[j] + a*label[k])
            
    # RESHAPE HACK SO TPU COMPILER KNOWS SHAPE OF OUTPUT TENSOR (maybe use Python typing instead?)
    image2 = tf.reshape(tf.stack(imgs),(len(image),DIM,DIM,3))
    label2 = tf.reshape(tf.stack(labs),(len(image),CLASSES))
    return image2,label2
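`MixUp` above draws its weight with `tf.random.uniform`, which is the Beta(1, 1) case. The original MixUp formulation samples the weight from Beta(alpha, alpha), where a small `alpha` such as 0.2 keeps most samples close to one of the two images. A batched NumPy sketch, with `alpha=0.2` and the helper name `mixup_batch` as assumptions:

```python
import numpy as np

def mixup_batch(images, labels, alpha=0.2, rng=None):
    """Mix every sample with a randomly chosen partner from the same batch."""
    rng = np.random.default_rng() if rng is None else rng
    n = images.shape[0]
    lam = rng.beta(alpha, alpha, size=n).astype(np.float32)   # one weight per sample
    partner = rng.integers(0, n, size=n)                      # random partner index
    lam_img = lam.reshape(n, 1, 1, 1)                         # broadcast over H, W, C
    lam_lab = lam.reshape(n, 1)                               # broadcast over classes
    mixed_images = (1 - lam_img) * images + lam_img * images[partner]
    mixed_labels = (1 - lam_lab) * labels + lam_lab * labels[partner]
    return mixed_images, mixed_labels

rng = np.random.default_rng(0)
demo_images = rng.uniform(0, 255, size=(8, 32, 32, 3)).astype(np.float32)
demo_labels = np.eye(5, dtype=np.float32)[rng.integers(0, 5, size=8)]
mixed_images, mixed_labels = mixup_batch(demo_images, demo_labels, rng=rng)
```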

In [None]:
mixup_image, mixup_label = MixUp(images, labels, IMAGE_DIM[0])

In [None]:
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 5

for i in range(1, columns*rows +1):
    img = mixup_image[i].numpy().astype('int')
    lbl = mixup_label[i].numpy()
    fig.add_subplot(rows, columns, i)
    plt.imshow(img)
    plt.axis("off")
    
plt.show()

## FMix Augmentation

In [None]:
sys.path.insert(0, "/kaggle/input/pyutils")
from fmix_utils import sample_mask

def FMix(image, label, DIM,  alpha=1, decay_power=3, max_soft=0.0, reformulate=False):
    lam, mask = sample_mask(alpha, decay_power,(DIM, DIM), max_soft, reformulate)
    index = tf.constant(np.random.permutation(int(image.shape[0])))
    mask  = np.expand_dims(mask, -1)
    
    # samples 
    image1 = image * mask
    image2 = tf.gather(image, index) * (1 - mask)
    image3 = image1 + image2

    # labels
    label1 = label * lam 
    label2 = tf.gather(label, index) * (1 - lam)
    label3 = label1 + label2 
    return image3, label3
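`sample_mask` comes from the `fmix_utils` helper (the FMix reference code) and is not shown here. Roughly, FMix draws `lam` from Beta(alpha, alpha), builds a random low-frequency grey image by damping Fourier-domain noise with `1 / f**decay_power`, and turns the top `lam` fraction of its pixels into a binary mask. The simplified NumPy sketch below follows that recipe under these assumptions. It omits the `max_soft` and `reformulate` options, and it is not the reference implementation; the name `simple_fourier_mask` is hypothetical.

```python
import numpy as np

def simple_fourier_mask(shape, alpha=1.0, decay_power=3.0, rng=None):
    """Simplified FMix-style binary mask of the given (H, W) shape."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = shape
    lam = rng.beta(alpha, alpha)

    # magnitude of the 2D frequency at each position of the spectrum
    fy = np.fft.fftfreq(h)[:, None]
    fx = np.fft.fftfreq(w)[None, :]
    freqs = np.sqrt(fx ** 2 + fy ** 2)

    # random complex spectrum, damped at high frequencies (low-pass)
    scale = 1.0 / np.maximum(freqs, 1.0 / max(h, w)) ** decay_power
    noise = rng.standard_normal((h, w)) + 1j * rng.standard_normal((h, w))
    grey = np.real(np.fft.ifft2(scale * noise))

    # keep the round(lam * H * W) highest-valued pixels as the mask
    n_on = int(round(lam * h * w))
    order = np.argsort(grey, axis=None)[::-1]
    mask = np.zeros(h * w, dtype=np.float32)
    mask[order[:n_on]] = 1.0
    return lam, mask.reshape(h, w)

lam, mask = simple_fourier_mask((64, 64), rng=np.random.default_rng(0))
```

As in `FMix` above, the fraction of the mask that keeps the first image is (up to rounding) `lam`, which is why the labels can be mixed with the same `lam`.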

In [None]:
fmix_image, fmix_label = FMix(images, labels, IMAGE_DIM[0])

In [None]:
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 5

for i in range(1, columns*rows +1):
    img = fmix_image[i].numpy().astype('int')
    lbl = fmix_label[i].numpy()
    fig.add_subplot(rows, columns, i)
    plt.imshow(img)
    plt.axis("off")
    
plt.show()

## Mosaic Augmentation [Work In Progress]

In [None]:
def MosaicMix(image, label, DIM, minfrac=0.25, maxfrac=0.75):
    # one random split point (xc, yc) shared by the whole batch
    xc, yc = np.random.randint(int(DIM * minfrac), int(DIM * maxfrac), (2,))
    n = int(image.shape[0])
    final_imgs = []
    
    # Build one mosaic sample per image in the batch
    for j in range(n): 
        # the j-th image plus 3 random images from the batch
        # (sampled with replacement, so batches with fewer than 3 images also work)
        rand4indices = [j] + random.choices(range(n), k=3)
        
        # a fresh canvas per sample; reusing one array would make every
        # appended sample point to the same (last) mosaic
        mosaic_image = np.zeros((DIM, DIM, 3), dtype=np.float32)
        
        # Make mosaic with 4 samples 
        for i, idx in enumerate(rand4indices):
            if i == 0:    # top left
                x1a, y1a, x2a, y2a =  0,  0, xc, yc
                x1b, y1b, x2b, y2b = DIM - xc, DIM - yc, DIM, DIM # from bottom right        
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, 0, DIM , yc
                x1b, y1b, x2b, y2b = 0, DIM - yc, DIM - xc, DIM # from bottom left
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = 0, yc, xc, DIM
                x1b, y1b, x2b, y2b = DIM - xc, 0, DIM, DIM-yc   # from top right
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc,  DIM, DIM
                x1b, y1b, x2b, y2b = 0, 0, DIM-xc, DIM-yc    # from top left
                
            # Copy-Paste a crop of the idx-th image (not the i-th image of the batch)
            mosaic_image[y1a:y2a, x1a:x2a] = image[idx][y1b:y2b, x1b:x2b]
                   
        # Append the Mosaic sample
        final_imgs.append(mosaic_image)
    
    # NOTE: labels are returned unchanged; label mixing is still WIP
    return np.stack(final_imgs), label

In [None]:
def plot_imgs(dataset_show, row, col):
    rcParams['figure.figsize'] = 30,15
    for i in range(row):
        f, ax = plt.subplots(1,col)
        for p in range(col):
            idx = np.random.randint(0, len(dataset_show))
            img, label = dataset_show[idx]
            ax[p].grid(False)
            ax[p].imshow(img[0]/255.)
            ax[p].axis('off')
    plt.show()

In [None]:
# if not use_casava or use_aptos:
#     # WIP
#     ds = tf.data.Dataset.from_generator(
#         lambda: datagens, 
#         output_types=(tf.float32, tf.float32), 
#         output_shapes=([BATCH_SIZ, 224, 224, 3], [BATCH_SIZ, BaseConfig.NUM_CLASSES])
#     )
    
#     ds_mos = ds.map(lambda x, y: MosaicMix(x, y, IMAGE_DIM[0]))

In [None]:
if use_casava or use_aptos:
    mosaic_gens = SequenceGenerator(
        BaseConfig.TRAIN_IMG_PATH, 
        df, 
        batch_size=BATCH_SIZ,
        dim=IMAGE_DIM,
        shuffle = True,
        use_mosaicmix = True
    )
    
    plot_imgs(mosaic_gens, 10, 3)

## Channel Shuffle

In [None]:
from tensorflow.keras import layers

class ChannelShuffle(layers.Layer):
    def __init__(self, groups=3, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.groups = groups
        self.seed = seed

    def _channel_shuffling(self, images):
        unbatched = images.shape.rank == 3
        if unbatched:
            images = tf.expand_dims(images, axis=0)

        height = tf.shape(images)[1]
        width = tf.shape(images)[2]
        num_channels = images.shape[3]

        if not num_channels % self.groups == 0:
            raise ValueError(
                "The number of input channels should be "
                "divisible by the number of groups. "
                f"Received: channels={num_channels}, groups={self.groups}"
            )

        channels_per_group = num_channels // self.groups
        images = tf.reshape(
            images, [-1, height, width, self.groups, channels_per_group]
        )
        images = tf.transpose(images, perm=[3, 1, 2, 4, 0])
        images = tf.random.shuffle(images, seed=self.seed)
        images = tf.transpose(images, perm=[4, 1, 2, 3, 0])
        images = tf.reshape(images, [-1, height, width, num_channels])

        if unbatched:
            images = tf.squeeze(images, axis=0)

        return images

    def call(self, images, training=True):
        if training:
            return self._channel_shuffling(images)
        else:
            return images

    def get_config(self):
        config = super().get_config()
        config.update({"groups": self.groups, "seed": self.seed})
        return config
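With `groups=3` on an RGB image, each group holds a single channel, so the layer reduces to permuting the R, G and B channels. A NumPy sketch of that special case (the helper name `shuffle_rgb_channels` is hypothetical). Like the layer above, it applies one permutation to the whole batch.

```python
import numpy as np

def shuffle_rgb_channels(batch, rng=None):
    """Apply one random permutation of the channel axis to a (N, H, W, 3) batch."""
    rng = np.random.default_rng() if rng is None else rng
    perm = rng.permutation(batch.shape[-1])   # e.g. [2, 0, 1]
    return batch[..., perm], perm

rng = np.random.default_rng(1)
demo_batch = rng.uniform(0, 255, size=(4, 16, 16, 3)).astype(np.float32)
shuffled, perm = shuffle_rgb_channels(demo_batch, rng)
```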

In [None]:
chlshl_image = ChannelShuffle(groups=3)(images)

In [None]:
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 5

for i in range(1, columns*rows +1):
    img = chlshl_image[i].numpy().astype('int')
    fig.add_subplot(rows, columns, i)
    plt.imshow(img)
    plt.axis("off")
plt.show()

## Color Jitter

In [None]:
class ColorJitter(layers.Layer):
    def __init__(
        self,
        brightness_factor=0.5,
        contrast_factor=(0.5, 0.9),
        saturation_factor=(0.5, 0.9),
        hue_factor=0.5,
        seed=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.seed = seed
        self.brightness_factor = self._check_factor_limit(
            brightness_factor, name="brightness"
        )
        self.contrast_factor = self._check_factor_limit(
            contrast_factor, name="contrast"
        )
        self.saturation_factor = self._check_factor_limit(
            saturation_factor, name="saturation"
        )
        self.hue_factor = self._check_factor_limit(hue_factor, name="hue")

    def _check_factor_limit(self, factor, name):
        if isinstance(factor, (int, float)):
            if factor < 0:
                raise ValueError(
                    "The factor value should be a non-negative scalar or a tuple "
                    f"or list of two (lower, upper) bounds. Received: {factor}"
                )
            if name == "brightness" or name == "hue":
                return abs(factor)
            return (0, abs(factor))
        elif isinstance(factor, (tuple, list)) and len(factor) == 2:
            if name == "brightness" or name == "hue":
                raise ValueError(
                    "The factor for brightness and hue should be a single "
                    f"non-negative scalar. Received: {factor} for {name}"
                )
            return sorted(factor)
        else:
            raise TypeError(
                "The factor value should be a non-negative scalar or a tuple "
                f"or list of two (lower, upper) bounds. Received: {factor}"
            )

    def _color_jitter(self, images):
        original_dtype = images.dtype
        images = tf.cast(images, dtype=tf.float32)

        brightness = tf.image.random_brightness(
            images, max_delta=self.brightness_factor * 255.0, seed=self.seed
        )
        brightness = tf.clip_by_value(brightness, 0.0, 255.0)

        contrast = tf.image.random_contrast(
            brightness,
            lower=self.contrast_factor[0],
            upper=self.contrast_factor[1],
            seed=self.seed,
        )
        saturation = tf.image.random_saturation(
            contrast,
            lower=self.saturation_factor[0],
            upper=self.saturation_factor[1],
            seed=self.seed,
        )
        hue = tf.image.random_hue(saturation, max_delta=self.hue_factor, seed=self.seed)
        return tf.cast(hue, original_dtype)

    def call(self, images, training=True):
        if training:
            return self._color_jitter(images)
        else:
            return images

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "brightness_factor": self.brightness_factor,
                "contrast_factor": self.contrast_factor,
                "saturation_factor": self.saturation_factor,
                "hue_factor": self.hue_factor,
                "seed": self.seed,
            }
        )
        return config
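For intuition about the first two steps of `_color_jitter`: `tf.image.random_brightness` adds one random delta to every pixel, and `tf.image.random_contrast` rescales each channel around its mean as `(x - mean) * factor + mean`. A NumPy sketch of those two steps for a single image, with illustrative parameter values matching the defaults above (saturation and hue, which go through HSV, are left out); the helper name `brightness_contrast` is hypothetical.

```python
import numpy as np

def brightness_contrast(image, max_delta=127.5, lower=0.5, upper=0.9, rng=None):
    """Random brightness shift followed by random contrast for one (H, W, C) image."""
    rng = np.random.default_rng() if rng is None else rng
    delta = rng.uniform(-max_delta, max_delta)     # same delta for every pixel
    out = np.clip(image + delta, 0.0, 255.0)       # clipped, as in ColorJitter
    factor = rng.uniform(lower, upper)
    mean = out.mean(axis=(0, 1), keepdims=True)    # per-channel mean
    return (out - mean) * factor + mean, delta, factor

rng = np.random.default_rng(3)
demo_image = rng.uniform(0, 255, size=(32, 32, 3))
jittered, delta, factor = brightness_contrast(demo_image, rng=rng)
```

The contrast step keeps each channel's mean and scales its spread by `factor`, so factors below 1 flatten the image.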

In [None]:
cjit_image = ColorJitter()(images)

In [None]:
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 5

for i in range(1, columns*rows +1):
    img = cjit_image[i].numpy().astype('int')
    fig.add_subplot(rows, columns, i)
    plt.imshow(img)
    plt.axis("off")
plt.show()

## RGBShift

In [None]:
class RGBShift(layers.Layer):
    """Randomly shift the values of each channel of an input RGB image.
    The input images are expected to have pixel values in the [0, 255] range.
    Input shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format
    Output shape:
        3D (unbatched) or 4D (batched) tensor with shape:
        `(..., height, width, channels)`, in `"channels_last"` format
    Args:
        factor: A scalar, or a tuple or list of two (lower, upper) bounds.
            If factor is a single value, the range will be (-factor, factor).
            The factor can be a float or an integer; for floats the valid limits are
            (-1.0, 1.0) and for integers the valid limits are (-255, 255).
        seed: Integer. Used to create a random seed. Default: None.
    Call arguments:
        images: Tensor representing images of shape
            [batch_size, height, width, channels], with dtype tf.float32 / tf.uint8, or
            [height, width, channels], with dtype tf.float32 / tf.uint8
        training: Boolean. The shift is only applied when True.
    Returns:
        images: augmented images, same shape as input.

    Usage:
    ```python
    (images, labels), _ = tf.keras.datasets.cifar10.load_data()
    rgbshift = RGBShift(factor=(-2, 2))
    augmented_images = rgbshift(images)
    ```
    """
    def __init__(
        self,
        factor,
        seed=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.factor = self._set_shift_limit(factor)
        self.seed = seed
        
    def _set_shift_limit(self, factor):
        if isinstance(factor, (tuple, list)):
            if len(factor) != 2: 
                raise ValueError(
                    "The factor should be a scalar, or a tuple or list of two "
                    f"(lower, upper) bounds. Got {factor}"
                )
            return self._check_factor_range(sorted(factor))
        elif isinstance(factor, (int, float)):
            factor = abs(factor)
            return self._check_factor_range([-factor, factor])
        else:
            raise ValueError(
                "The factor should be a scalar, or a tuple or list of two "
                f"(lower, upper) bounds. Got {factor}"
            )
            
    @staticmethod
    def _check_factor_range(factor):
        if all(isinstance(each_elem, float) for each_elem in factor):
            if factor[0] < -1.0 or factor[1] > 1.0:
                raise ValueError(f"Float bounds must lie within [-1.0, 1.0]. Got {factor}")
            return factor
        elif all(isinstance(each_elem, int) for each_elem in factor):
            if factor[0] < -255 or factor[1] > 255:
                raise ValueError(f"Integer bounds must lie within [-255, 255]. Got {factor}")
            return factor
        else:
            raise ValueError(f'Both bounds must have the same type (int or float). Got {factor}')
            
    def _get_random_uniform(self, shift_limit, rgb_delta_shape):
        # A stateful op with an op-level seed is reproducible across runs, yet each
        # call (one per channel) draws new values. A stateless op with a fixed seed
        # would return the same shift for R, G and B on every call, i.e. only a
        # brightness change.
        _rand_uniform = tf.random.uniform(
            rgb_delta_shape,
            minval=shift_limit[0],
            maxval=shift_limit[1],
            dtype=tf.float32,
            seed=self.seed,
        )
        
        # float limits are relative, so rescale them to pixel units
        if all(isinstance(each_elem, float) for each_elem in shift_limit):
            _rand_uniform = _rand_uniform * 85.0
        
        return _rand_uniform
    
    def _rgb_shifting(self, images):
        rank = images.shape.rank
        original_dtype = images.dtype

        if rank == 3:
            rgb_delta_shape = (1, 1)
        elif rank == 4:
            # Keep only the batch dim. This will ensure to have same adjustment
            # with in one image, but different across the images.
            rgb_delta_shape = [tf.shape(images)[0], 1, 1]
        else:
            raise ValueError(
                f"Expect the input image to be rank 3 or 4. Got {images.shape}"
            )
        r_shift = self._get_random_uniform(self.factor, rgb_delta_shape)   
        g_shift = self._get_random_uniform(self.factor, rgb_delta_shape)
        b_shift = self._get_random_uniform(self.factor, rgb_delta_shape)
        unstack_rgb = tf.unstack(tf.cast(images, dtype=tf.float32), axis=-1)
        shifted_rgb = tf.stack([tf.add(unstack_rgb[0], r_shift),
                                tf.add(unstack_rgb[1], g_shift),
                                tf.add(unstack_rgb[2], b_shift)], axis=-1)
        shifted_rgb = tf.clip_by_value(shifted_rgb, 0.0, 255.0)

        return tf.cast(shifted_rgb, dtype=original_dtype)

    def call(self, images, training=True):
        if training:
            return self._rgb_shifting(images)
        else:
            return images

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "factor": self.factor, 
                "seed": self.seed
            }
        )
        return config 

    def compute_output_shape(self, input_shape):
        return input_shape
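The core of `RGBShift` is one random offset per channel and per image, added to every pixel and clipped back to [0, 255]. A NumPy sketch of the integer-limit case, e.g. `factor=(-120, 120)` as used below; the helper name `rgb_shift_batch` is hypothetical.

```python
import numpy as np

def rgb_shift_batch(batch, low=-120, high=120, rng=None):
    """Add an independent random offset to each channel of each image in (N, H, W, 3)."""
    rng = np.random.default_rng() if rng is None else rng
    n, c = batch.shape[0], batch.shape[-1]
    shifts = rng.uniform(low, high, size=(n, 1, 1, c))   # constant within an image
    shifted = np.clip(batch.astype(np.float32) + shifts, 0.0, 255.0)
    return shifted.astype(batch.dtype), shifts

rng = np.random.default_rng(7)
demo_batch = np.full((2, 8, 8, 3), 128.0, dtype=np.float32)
shifted, shifts = rgb_shift_batch(demo_batch, rng=rng)
```

Starting from mid-grey (128), every shift in [-120, 120) stays inside [0, 255], so each output pixel is exactly `128 + shift` for its image and channel.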

In [None]:
rgbshift_images = RGBShift(factor=(-120, 120))(images)

In [None]:
fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 5

for i in range(1, columns*rows +1):
    img = rgbshift_images[i].numpy().astype('int')
    fig.add_subplot(rows, columns, i)
    plt.imshow(img)
    plt.axis("off")
plt.show()

In [None]:
def plot_stuff(a, b, c, d, titles):
    # show four images side by side, one title each
    plt.figure(figsize=(25, 25))
    for pos, (img, title) in enumerate(zip([a, b, c, d], titles), start=1):
        plt.subplot(1, 4, pos)
        plt.axis('off')
        plt.imshow(img.astype('int'))
        plt.title(title)
    plt.show()

In [None]:
rgbshift_images = RGBShift(factor=(-120, 120))(images)
cjit_image      = ColorJitter()(images)
chlshl_image    = ChannelShuffle(groups=3)(images)

fmix_image, fmix_label     = FMix(images, labels, IMAGE_DIM[0])
mixup_image, mixup_label   = MixUp(images, labels, IMAGE_DIM[0])
cutmix_image, cutmix_label = CutMix(images, labels, IMAGE_DIM[0])


for i, (orig, rg, cj, ch, fm, mi, cu) in enumerate(zip(images, rgbshift_images, 
                                                       cjit_image, chlshl_image, 
                                                       fmix_image, mixup_image, 
                                                       cutmix_image)):
    plot_stuff(
        orig,
        rg.numpy(),
        cj.numpy(), 
        ch.numpy(), 
        ['Input', 'RGBShift', 'ColorJitter', 'ChannelShuffle']
    )
    
    plot_stuff(
        orig,
        fm.numpy(),
        mi.numpy(),
        cu.numpy(),
        ['Input','FMix', 'MixUp', 'CutMix']
    )

---

## Resources

- [[TF.Keras] Melanoma Classification Starter, TabNet](https://www.kaggle.com/ipythonx/tf-keras-melanoma-classification-starter-tabnet)
- [[TF.Keras]: Cassava: Advanced Training Mechanism](https://www.kaggle.com/ipythonx/tf-keras-cassava-advanced-training-mechanism)
- [[TF]: 3D & 2D Model for Brain Tumor Classification](https://www.kaggle.com/ipythonx/tf-3d-2d-model-for-brain-tumor-classification/notebook)
- [[TF]: Segmentation Modeling into Classifier Model](https://www.kaggle.com/ipythonx/tf-segmentation-modeling-into-classifier-model/notebook)
- [[Keras]: Bengali.AI Grapheme Classification](https://www.kaggle.com/ipythonx/keras-bengali-ai-grapheme-classification?scriptVersionId=65475261)