# Transform images for data augmentation using `chainercv.transforms`

# import modules

In [None]:
import os
import random

from matplotlib import pyplot as plt 
import numpy as np

import chainer
import chainercv
from chainercv.visualizations import vis_image
from chainercv.links.model.ssd import random_distort,resize_with_random_interpolation

from ipywidgets import interact

In [None]:
# Import the local `food_101_dataset` module from the parent directory

import sys
sys.path.append("../")
from food_101_dataset import Food101BaseDataset

In [None]:
# Root directory of the Food-101 dataset; the leading "~" is expanded to the user's home directory.
dataset_dir = os.path.expanduser("~/dataset/food-101")

In [None]:
# Settings for the training pipeline.
# NOTE(review): `params` is not passed to Food101BaseDataset below, yet
# FoodDataset later reads `base.mode` and `base.imsize`; confirm whether
# these values were meant to be forwarded to the base dataset.
params = dict(
    mode="train",
    imsize=(224, 224),
)

base = Food101BaseDataset(dataset_dir)
# 100 random example indices (drawn independently, so duplicates are
# possible) used as the choices of the interactive viewers.
sample = np.random.randint(0, len(base), 100)

# Visualize the effect of each augmentation (mostly from `chainercv.transforms`) on a sample image

In [None]:
def vis_random_distort(i):
    """Show example ``i`` next to a color-distorted copy.

    The distorted copy is produced by ``random_distort`` (imported from
    ``chainercv.links.model.ssd``). The label from the dataset is ignored.
    """
    img, _ = base.get_example(i)
    fig = plt.figure()
    left = fig.add_subplot(1, 2, 1)
    right = fig.add_subplot(1, 2, 2)
    distorted = random_distort(img)
    panels = ((left, img, "original"), (right, distorted, "distort"))
    for ax, picture, caption in panels:
        vis_image(picture, ax=ax)
        ax.set_title(caption)


interact(vis_random_distort, i=sample)

In [None]:
def vis_random_expand(i):
    """Show example ``i`` next to a randomly expanded copy.

    ``random_expand`` is called with ``max_ratio=1.25``. The fill value for
    the added area is one integer drawn uniformly from 0..255.
    """
    img, _ = base.get_example(i)
    fig = plt.figure()
    ax_before = fig.add_subplot(1, 2, 1)
    ax_after = fig.add_subplot(1, 2, 2)
    fill_value = random.randint(0, 255)
    expanded = chainercv.transforms.random_expand(
        img, max_ratio=1.25, fill=fill_value)
    for ax, picture, caption in ((ax_before, img, "original"),
                                 (ax_after, expanded, "expand")):
        vis_image(picture, ax=ax)
        ax.set_title(caption)


interact(vis_random_expand, i=sample)

In [None]:
def vis_random_crop(i):
    """Show example ``i`` next to a random crop of at most 300x300 pixels.

    The crop size is clamped to the image's own height and width so that
    images smaller than 300 pixels on a side can still be cropped.
    """
    img, _ = base.get_example(i)
    _, H, W = img.shape
    fig = plt.figure()
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    # Clamp the requested size to the image so random_crop is always given
    # a crop that fits inside its input.
    size = (min(300, H), min(300, W))
    cropped = chainercv.transforms.random_crop(img, size=size)
    vis_image(img, ax=ax1)
    vis_image(cropped, ax=ax2)
    ax1.set_title("original")
    ax2.set_title("crop")


interact(vis_random_crop, i=sample)

In [None]:
def vis_random_sized_crop(i):
    """Show example ``i`` next to a random sized crop.

    ``random_sized_crop`` is called with ``scale_ratio_range=(0.5, 1)`` and
    ``aspect_ratio_range=(0.8, 1.25)``.
    """
    img, _ = base.get_example(i)
    fig = plt.figure()
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    cropped = chainercv.transforms.random_sized_crop(
        img,
        scale_ratio_range=(0.5, 1),
        aspect_ratio_range=(8 / 10, 10 / 8),
    )
    vis_image(img, ax=ax1)
    vis_image(cropped, ax=ax2)
    ax1.set_title("original")
    ax2.set_title("crop")


interact(vis_random_sized_crop, i=sample)

In [None]:
def vis_random_flip(i):
    """Show example ``i`` next to a copy that may be flipped horizontally.

    ``random_flip`` is called with ``x_random=True``, so whether the flip
    happens is decided at random. The right panel can therefore be
    identical to the original.
    """
    img, _ = base.get_example(i)
    fig = plt.figure()
    ax_src = fig.add_subplot(1, 2, 1)
    ax_dst = fig.add_subplot(1, 2, 2)
    maybe_flipped = chainercv.transforms.random_flip(img, x_random=True)
    for ax, picture, caption in ((ax_src, img, "original"),
                                 (ax_dst, maybe_flipped, "flip")):
        vis_image(picture, ax=ax)
        ax.set_title(caption)


interact(vis_random_flip, i=sample)

In [None]:
def vis_random_rotate(i):
    """Show example ``i`` next to a randomly rotated copy.

    The angle is an integer drawn from [-90, 90] degrees and ``rotate`` is
    called with ``expand=True``. The fill value for uncovered pixels is a
    single integer drawn from 0..255.
    """
    img, _ = base.get_example(i)
    fig = plt.figure()
    ax_src = fig.add_subplot(1, 2, 1)
    ax_dst = fig.add_subplot(1, 2, 2)
    # Draw order matters for reproducibility: angle first, then fill.
    degrees = random.randint(-90, 90)
    background = random.randint(0, 255)
    # NOTE(review): transform_image passes a 3-tuple as `fill` to rotate,
    # whereas this viewer passes a scalar; confirm which form the active
    # chainercv rotate backend expects for multi-channel images.
    rotated = chainercv.transforms.rotate(
        img, angle=degrees, expand=True, fill=background)
    for ax, picture, caption in ((ax_src, img, "original"),
                                 (ax_dst, rotated, "rotate")):
        vis_image(picture, ax=ax)
        ax.set_title(caption)


interact(vis_random_rotate, i=sample)

In [None]:
import random

def vis_random_interpolation(i):
    """Show example ``i`` next to a 224x224 resize with random interpolation.

    The resize is done by ``resize_with_random_interpolation`` (imported
    from ``chainercv.links.model.ssd``).
    """
    img, _ = base.get_example(i)
    fig = plt.figure()
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    # Target size as (H, W).
    size = (224, 224)
    resized = resize_with_random_interpolation(img, size)
    vis_image(img, ax=ax1)
    vis_image(resized, ax=ax2)
    ax1.set_title("original")
    ax2.set_title("resized")


interact(vis_random_interpolation, i=sample)

In [None]:
def transform_image(image):
    """Apply a random chain of augmentations to ``image`` and return it.

    Steps, in order (each step marked optional is taken with probability 1/2):

    1. (optional) color augmentation with ``random_distort``;
    2. rotation by an integer angle from [-90, 90] degrees with
       ``expand=True``;
    3. random flip with ``x_random=True``;
    4. (optional) ``random_expand`` with ``max_ratio=1.5``;
    5. (optional) ``random_sized_crop`` with ``scale_ratio_range=(0.3, 0.8)``
       and ``aspect_ratio_range=(0.8, 1.25)``.

    One random integer from 0..255 is used as the fill value for both the
    rotation and the expansion. The output size varies from call to call.
    """
    def coin():
        # Fair coin flip; exactly one draw from the `random` module.
        return random.choice([True, False])

    fill = random.randint(0, 255)

    if coin():
        image = random_distort(image)

    # NOTE(review): rotate receives a per-channel tuple here while
    # random_expand below receives a scalar; confirm both forms are what
    # the installed chainercv version expects.
    image = chainercv.transforms.rotate(
        image,
        angle=random.randint(-90, 90),
        expand=True,
        fill=(fill,) * 3,
    )

    image = chainercv.transforms.random_flip(image, x_random=True)

    if coin():
        image = chainercv.transforms.random_expand(
            image, max_ratio=1.5, fill=fill)

    if coin():
        image = chainercv.transforms.random_sized_crop(
            image,
            scale_ratio_range=(0.3, 0.8),
            aspect_ratio_range=(8 / 10, 10 / 8),
        )

    return image

class FoodDataset(chainer.dataset.DatasetMixin):
    """Wrap a Food-101 base dataset, optionally applying data augmentation.

    Augmentation (``transform_image`` followed by
    ``resize_with_random_interpolation``) is enabled only when
    ``base.mode == "train"``. In every case the returned image has size
    ``base.imsize``.

    Args:
        base: The underlying dataset. It must provide ``mode``, ``imsize``,
            ``__len__`` and ``get_example(i)`` returning ``(image, label)``.
    """

    def __init__(self, base):
        self.do_augmentation = (base.mode == "train")
        self.base = base
        self.imsize = base.imsize

    def __len__(self):
        # DatasetMixin leaves __len__ unimplemented; chainer iterators call
        # len() on the dataset, so delegate to the wrapped dataset.
        return len(self.base)

    def get_example(self, i):
        """Return ``(image, label)`` for index ``i``, resized to ``imsize``."""
        orig_image, label = self.base.get_example(i)
        if self.do_augmentation:
            # Copy so the augmentations cannot modify the base dataset's
            # array through a shared buffer.
            image = transform_image(orig_image.copy())
            # This already yields `imsize`, so no second resize is needed.
            image = resize_with_random_interpolation(image, size=self.imsize)
        else:
            image = chainercv.transforms.resize(orig_image, size=self.imsize)
        return image, label

# Visualize the transformed dataset

In [None]:
food_dataset = FoodDataset(base)


def vis_food_dataset(i):
    """Show example ``i`` from the base dataset next to its FoodDataset output.

    When augmentation is enabled, each call draws a fresh random
    augmentation, so the right panel changes between calls.
    """
    orig_image, _ = food_dataset.base.get_example(i)
    image, _ = food_dataset.get_example(i)
    fig = plt.figure()
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    vis_image(orig_image, ax=ax1)
    vis_image(image, ax=ax2)
    # Title the panels the same way as the per-transform viewers.
    ax1.set_title("original")
    ax2.set_title("transformed")


interact(vis_food_dataset, i=sample)