In [None]:
# default_exp data.transforms

In [None]:
# hide
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Transforms
> Utilities for image transforms.

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# export
from albumentations import (RandomRotate90,
                            Flip,
                            Transpose,
                            GridDistortion,
                            RandomCrop,
                            GaussianBlur,
                            RandomGamma,
                            RandomBrightnessContrast,
                            HueSaturationValue,
                            RGBShift,
                            CenterCrop)
import albumentations.augmentations.functional as F
from grade_classif.imports import *
from math import floor

In [None]:
# export
def _shift_hsv_non_uint8(img, hue_shift, sat_shift, val_shift, sat_threshold=0.1):
    """Shift hue, saturation and value of a non-uint8 RGB image.

    Replacement for albumentations' private `_shift_hsv_non_uint8`; it is
    assigned onto `albumentations.augmentations.functional` right after this
    definition. Unlike a plain shift, the saturation and value shifts are only
    applied to pixels whose ORIGINAL saturation exceeds `sat_threshold`, so
    near-grey pixels keep their saturation and brightness.

    Args:
        img: RGB image with a non-uint8 dtype (presumably float32 in [0, 1] -- TODO confirm).
        hue_shift: amount added to the hue channel; results below 0 or above 360
            are wrapped once by +/-360.
        sat_shift: amount added to the saturation of pixels above the threshold.
        val_shift: amount added to the value of pixels above the threshold.
        sat_threshold: saturation above which a pixel receives the sat/val shifts
            (default 0.1, the previously hard-coded value).

    Returns:
        The shifted image converted back to RGB, cast to the input dtype.
    """
    dtype = img.dtype
    # Upper clipping bound for the saturation and value channels.
    max_value = 255 if dtype == np.uint8 else 1.0

    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    hue, sat, val = cv2.split(hsv)

    hue = cv2.add(hue, hue_shift)
    hue = np.where(hue < 0, hue + 360, hue)
    hue = np.where(hue > 360, hue - 360, hue)
    hue = hue.astype(dtype)

    # Build the mask once from the UNSHIFTED saturation. Previously the value
    # mask was taken from the already-shifted saturation, so pixels pushed below
    # the threshold by a negative `sat_shift` silently skipped the value shift.
    saturated = sat > sat_threshold
    sat_shifted = F.clip(sat + sat_shift * saturated, dtype, max_value)
    val_shifted = F.clip(val + val_shift * saturated, dtype, max_value)

    hsv = cv2.merge((hue, sat_shifted, val_shifted)).astype(dtype)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)


# Monkey-patch albumentations' functional module so its HSV shift presumably
# uses the version above for non-uint8 images (albumentations internals -- verify per version).
F._shift_hsv_non_uint8 = _shift_hsv_non_uint8

In [None]:
# export
def _mod(x, y):
    x -= floor(x / y) * y
    return x

In [None]:
# export
def _get_params(tfm):
    """Return the next parameter dict of a `Deterministic*` transform's sequence.

    For every parameter in `tfm.base_values`, the base value is offset by
    `tfm.n * tfm.mult` and folded back into the `(low, high)` range stored in
    `tfm.max_values`. A parameter whose range is degenerate (`low == high`) is
    returned as `low`. The counter `tfm.n` is then advanced modulo
    `tfm.num_els`, so the sequence repeats every `tfm.num_els` calls.
    """
    params = {}
    for name, base in tfm.base_values.items():
        low, high = tfm.max_values[name]
        if low == high:
            # Degenerate range: there is nothing to vary.
            params[name] = low
            continue
        shifted = base + tfm.n * tfm.mult
        # Wrap the shifted value back into the allowed range.
        params[name] = _mod(shifted - low, high - low) + low
    tfm.n += 1
    tfm.n %= tfm.num_els
    return params

In [None]:
# export
def _init_attrs(tfm, num_els=1):
    tfm.always_apply = True
    tfm.num_els = num_els
    tfm.p = 1
    tfm.n = 0
    tfm.mult = np.random.randint(10000)

In [None]:
# export
class DeterministicHSV(HueSaturationValue):
    """`HueSaturationValue` whose parameters follow a reproducible cycle.

    A base parameter set is drawn once from the parent at construction. Each
    `get_params` call then returns the next entry of a sequence of period
    `num_els` (see `_get_params`), kept within the configured shift limits.
    """

    def __init__(self, num_els=1, **kwargs):
        super().__init__(**kwargs)
        _init_attrs(self, num_els)
        # One random draw from the parent, used as the starting point of the cycle.
        self.base_values = super().get_params()
        # Each parameter's range lives on the matching `<name>_limit` attribute.
        self.max_values = {name: getattr(self, name + "_limit")
                           for name in ("hue_shift", "sat_shift", "val_shift")}

    def get_params(self):
        """Return the next parameter set of the deterministic sequence."""
        return _get_params(self)

In [None]:
# export
class DeterministicBrightnessContrast(RandomBrightnessContrast):
    """`RandomBrightnessContrast` whose parameters follow a reproducible cycle.

    A base `alpha`/`beta` pair is drawn once from the parent. `get_params`
    then walks a sequence of period `num_els` (see `_get_params`).
    """

    def __init__(self, num_els=1, **kwargs):
        super().__init__(**kwargs)
        _init_attrs(self, num_els)
        self.base_values = super().get_params()
        # `alpha` ranges over contrast_limit shifted by +1 (presumably the parent
        # draws alpha as 1 + U(contrast_limit) -- confirm for the albumentations version).
        alpha_range = tuple(bound + 1 for bound in self.contrast_limit)
        self.max_values = dict(alpha=alpha_range, beta=self.brightness_limit)

    def get_params(self):
        """Return the next parameter set of the deterministic sequence."""
        return _get_params(self)

In [None]:
# export
class DeterministicGamma(RandomGamma):
    """`RandomGamma` whose gamma value follows a reproducible cycle.

    The base gamma is drawn once from the parent. `get_params` then walks a
    sequence of period `num_els` (see `_get_params`).
    """

    def __init__(self, num_els=1, **kwargs):
        super().__init__(**kwargs)
        _init_attrs(self, num_els)
        self.base_values = super().get_params()
        # `gamma_limit` is scaled by 1/100 (it appears to be expressed in percent).
        gamma_range = tuple(limit / 100 for limit in self.gamma_limit)
        self.max_values = dict(gamma=gamma_range)

    def get_params(self):
        """Return the next parameter set of the deterministic sequence."""
        return _get_params(self)

In [None]:
# export
class DeterministicRGBShift(RGBShift):
    """`RGBShift` whose per-channel shifts follow a reproducible cycle.

    Base shifts are drawn once from the parent. `get_params` then walks a
    sequence of period `num_els` (see `_get_params`), kept within the channel
    limits.
    """

    def __init__(self, num_els=1, **kwargs):
        super().__init__(**kwargs)
        _init_attrs(self, num_els)
        # One random draw from the parent, used as the starting point of the cycle.
        self.base_values = super().get_params()
        # Each channel's range lives on the matching `<name>_limit` attribute.
        self.max_values = {name: getattr(self, name + "_limit")
                           for name in ("r_shift", "g_shift", "b_shift")}

    def get_params(self):
        """Return the next parameter set of the deterministic sequence."""
        return _get_params(self)

In [None]:
# export
def get_transforms1(size, num_els=1):
    """Build `(train_tfms, val_tfms)` lists of albumentations transforms.

    Training: a random `size` x `size` crop, random 90-degree rotation, flip and
    transpose, then a light grid distortion, random gamma and Gaussian blur
    (each with p=0.2). Validation: a `size` x `size` centre crop only.

    `num_els` is not used here (presumably kept for signature parity with the
    other `get_transforms*` builders).
    """
    crop = RandomCrop(size, size)
    dihedral = [RandomRotate90(), Flip(), Transpose()]
    light_augs = [GridDistortion(distort_limit=0.05, p=0.2),
                  RandomGamma(p=0.2),
                  GaussianBlur(blur_limit=3, p=0.2)]
    train_tfms = [crop, *dihedral, *light_augs]
    val_tfms = [CenterCrop(size, size)]
    return train_tfms, val_tfms

Return a tuple `(train_tfms, val_tfms)` containing transforms defined in [`albumentations`](https://albumentations.readthedocs.io) for training and validation.

After the random crop, the training transforms are:
* [`RandomRotate90`](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomRotate90)
* [`Flip`](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Flip)
* [`Transpose`](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Transpose)
* [`GridDistortion`](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.GridDistortion)
* [`RandomGamma`](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomGamma)
* [`GaussianBlur`](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.GaussianBlur)

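Training starts with a [`RandomCrop`](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.RandomCrop) of size `size` x `size`. Validation only applies a [`CenterCrop`](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.CenterCrop) of the same size. The `num_els` argument is not used by this function.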

In [None]:
# export
def get_transforms2(size, num_els=1):
    """Build `(train_tfms, val_tfms)` lists of albumentations transforms.

    Training: a random `size` x `size` crop, random 90-degree rotation, flip and
    transpose, then a light grid distortion, random gamma and Gaussian blur
    (each with p=0.2), and an RGB shift of up to 0.15 per channel (default
    probability). Validation: a `size` x `size` centre crop only.

    `num_els` is not used here (presumably kept for signature parity with the
    other `get_transforms*` builders).
    """
    crop = RandomCrop(size, size)
    dihedral = [RandomRotate90(), Flip(), Transpose()]
    light_augs = [GridDistortion(distort_limit=0.05, p=0.2),
                  RandomGamma(p=0.2),
                  GaussianBlur(blur_limit=3, p=0.2),
                  RGBShift(0.15, 0.15, 0.15)]
    train_tfms = [crop, *dihedral, *light_augs]
    val_tfms = [CenterCrop(size, size)]
    return train_tfms, val_tfms

In [None]:
# export
def get_transforms3(size, num_els=1):
    """Build `(train_tfms, val_tfms)` with colour augmentation mirrored at validation.

    Training:
    - a random `size` x `size` crop;
    - random 90-degree rotation, flip and transpose;
    - a light grid distortion;
    - brightness/contrast jitter (p=0.7);
    - Gaussian blur;
    - an RGB shift of up to 0.2 per channel (p=0.8).

    Validation:
    - a centre crop;
    - deterministic brightness/contrast (default limits, like training) and RGB
      shifts (same 0.2 limits). Their parameter sequences repeat every
      `num_els` calls.
    """
    rgb_limits = dict(r_shift_limit=0.2, g_shift_limit=0.2, b_shift_limit=0.2)
    train_tfms = [RandomCrop(size, size),
                  RandomRotate90(), Flip(), Transpose(),
                  GridDistortion(distort_limit=0.05, p=0.2),
                  RandomBrightnessContrast(p=0.7),
                  GaussianBlur(blur_limit=3, p=0.2),
                  RGBShift(**rgb_limits, p=0.8)]
    val_tfms = [CenterCrop(size, size),
                DeterministicBrightnessContrast(num_els=num_els),
                DeterministicRGBShift(num_els=num_els, **rgb_limits)]
    return train_tfms, val_tfms

In [None]:
# export
def get_transforms4(size, num_els=1):
    """Build `(train_tfms, val_tfms)` with gamma/HSV augmentation mirrored at validation.

    Training:
    - a random `size` x `size` crop;
    - random 90-degree rotation, flip and transpose;
    - a light grid distortion and Gaussian blur;
    - random gamma in (40, 160) and an HSV shift (hue 40, saturation .1,
      value .15), both always applied.

    Validation:
    - a centre crop;
    - deterministic gamma and HSV transforms with the same limits. Their
      parameter sequences repeat every `num_els` calls.
    """
    gamma_limit = (40, 160)
    hsv_limits = dict(hue_shift_limit=40, sat_shift_limit=.1, val_shift_limit=.15)
    train_tfms = [RandomCrop(size, size),
                  RandomRotate90(), Flip(), Transpose(),
                  GridDistortion(distort_limit=0.05, p=0.2),
                  # disabled: RandomBrightnessContrast(0.2, 0., p=0.2),
                  GaussianBlur(blur_limit=3, p=0.2),
                  RandomGamma(gamma_limit=gamma_limit, p=1),
                  HueSaturationValue(**hsv_limits, p=1)]
    val_tfms = [CenterCrop(size, size),
                DeterministicGamma(num_els=num_els, gamma_limit=gamma_limit),
                # disabled: DeterministicBrightnessContrast(num_els=num_els, brightness_limit=0.2, contrast_limit=0.),
                DeterministicHSV(num_els=num_els, **hsv_limits)]
    return train_tfms, val_tfms

In [None]:
#hide
# Convert the project's notebooks into library modules with nbdev; the cell
# output lists each converted notebook.
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_train.ipynb.
Converted 02_predict.ipynb.
Converted 10_data.read.ipynb.
Converted 11_data.loaders.ipynb.
Converted 12_data.dataset.ipynb.
Converted 13_data.utils.ipynb.
Converted 14_data.transforms.ipynb.
Converted 20_models.plmodules.ipynb.
Converted 21_models.modules.ipynb.
Converted 22_models.utils.ipynb.
Converted 23_models.hooks.ipynb.
Converted 24_models.metrics.ipynb.
Converted 25_models.losses.ipynb.
Converted 80_params.defaults.ipynb.
Converted 81_params.parser.ipynb.
Converted 99_index.ipynb.
