# 2. Augment and Resize data

In [1]:
import os
import shutil

cur_path = os.getcwd()
classes = ['0_front', '1_back', '1_front', '2_back', '2_front', '5_front', 'ILU']

# Each dataset split ('train' / 'val') is a folder under the current working
# directory, holding one sub-folder per class label in `classes`.
train_dir, val_dir = (os.path.join(cur_path, split) for split in ('train', 'val'))
train_cls_dirs, val_cls_dirs = (
    [os.path.join(split_dir, label) for label in classes]
    for split_dir in (train_dir, val_dir)
)

In [2]:
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing.image import array_to_img

# Create the Keras image data generator used to draw random augmentations.
# The transform settings are collected in one mapping so they are easy to
# tweak and inspect before being handed to ImageDataGenerator.
# NOTE(review): in Keras, zoom factors above 1 appear to shrink the content
# (zoom out) rather than enlarge it — confirm [1, 1.5] is the intended range.
_augment_params = {
    'rotation_range': 15,
    'width_shift_range': 0.3,
    'height_shift_range': 0.15,
    'shear_range': 0.2,
    'zoom_range': [1, 1.5],
    'brightness_range': [1.0, 1.5],
    'fill_mode': 'nearest',
}
datagen = ImageDataGenerator(**_augment_params)

# Resize & Augment images
def augment_preprocess_img(src_path, target_path, dsize=(200, 250), n_augment=5):
    """Resize one image and write it plus ``n_augment`` augmented copies.

    The resized original is written to ``target_path``. The augmented copies
    are written next to it as ``<stem>_<i>.jpg`` for ``i`` in
    ``range(n_augment)``, where ``<stem>`` is ``target_path`` without its
    extension. The copies are drawn from the module-level ``datagen``.

    Args:
        src_path: Path of the image file to read.
        target_path: Destination of the resized original. Its stem also names
            the augmented copies. It may equal ``src_path`` (overwrite in place).
        dsize: Output size passed to ``cv2.resize``, given as (width, height).
        n_augment: Number of augmented copies to write.

    Raises:
        ValueError: If ``src_path`` cannot be decoded as an image.
    """
    # cv2 loads colour images in BGR channel order. Nothing below reorders
    # channels, and cv2.imwrite also expects BGR, so colours round-trip intact.
    img = cv2.imread(src_path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise ValueError(f"could not read image: {src_path}")
    img = cv2.resize(img, dsize)
    img_input = np.expand_dims(img, axis=0)  # add a batch axis for datagen.flow

    # splitext handles extensions of any length (.jpg, .jpeg, .png, ...).
    stem = os.path.splitext(target_path)[0]
    batches = datagen.flow(img_input, batch_size=1)
    for i in range(n_augment):
        batch = next(batches)
        # Clip and cast directly instead of calling array_to_img. With its
        # default scale=True, array_to_img stretches each image so its min/max
        # become 0/255, which alters the brightness augmentation.
        aug_img = np.clip(batch[0], 0, 255).astype(np.uint8)
        cv2.imwrite(f"{stem}_{i}.jpg", aug_img)

    # Also write the original, since it was resized too.
    cv2.imwrite(target_path, img)

In [3]:
def _process_split(cls_dirs):
    """Resize and augment, in place, every regular file in each of ``cls_dirs``."""
    for cls_dir in cls_dirs:
        # Snapshot the listing before writing, because augment_preprocess_img
        # adds new files to cls_dir. Sub-directories (e.g. .ipynb_checkpoints)
        # are skipped, since they cannot be decoded as images.
        img_paths = [
            os.path.join(cls_dir, name)
            for name in os.listdir(cls_dir)
            if os.path.isfile(os.path.join(cls_dir, name))
        ]
        for path in img_paths:
            # Source and target are the same path, so the resized original
            # overwrites the file it was read from.
            augment_preprocess_img(path, path)


# NOTE: this works in place. Re-running the cell would augment the
# already-augmented copies again, so restore the original data first.

# Run on the train directory.
_process_split(train_cls_dirs)

# Run on the validation directory.
# NOTE(review): augmenting the validation split means validation metrics are
# also computed on synthetic images — confirm this is intended.
_process_split(val_cls_dirs)