# Data Augmentation and Image Pipeline

In [1]:
import numpy as np
import keras
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator

Using TensorFlow backend.


In [2]:
train_image_gen=ImageDataGenerator(
    rescale=1/255.0,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.3,
    horizontal_flip=True
)

test_image_gen=ImageDataGenerator(
    rescale=1/255.0
)

val_image_gen=ImageDataGenerator(
    rescale=1/255.0
)

In [3]:
train_gen=train_image_gen.flow_from_directory(
    "tiny-imagenet-200/train",
    target_size=((224,224)),
    batch_size=128,
)

val_gen=val_image_gen.flow_from_directory(
    "tiny-imagenet-200/val",
    target_size=((224,224)),
    batch_size=128
)

Found 100000 images belonging to 200 classes.
Found 10000 images belonging to 1 classes.


In [4]:
img=image.load_img("cat.jpg")
x=image.img_to_array(img)

In [5]:
print(x.shape)

(667, 1000, 3)


In [6]:
x=x.reshape(1,667,1000,3)

In [7]:
i=0
for batch in train_image_gen.flow(x,batch_size=1,save_to_dir='preview',save_prefix="cat",save_format='jpg'):
    i+=1
    if(i>20):
        break
    

In [8]:
class_ids=train_gen.class_indices

In [9]:
import os

In [10]:
def load_validation_data(target_size,no_of_classes):
    with open("tiny-imagenet-200/val/val_annotations.txt") as f:
        lines=f.readlines()
        m=len(lines)
        X=np.empty((m,*target_size,3))
        Y=np.empty(m,dtype=int)
        
        for i,line in enumerate(lines):
            token=line.split()
            img_name=token[0]
            img_label=token[1]
            
            img_url=os.path.join("tiny-imagenet-200/val/images",img_name)
            img=image.load_img(img_url)
            img=img.resize(target_size)
            X[i,]=np.array(img,dtype=np.float32)/255.0
            img.close()
            Y[i]=class_ids[img_label]
    return X,keras.utils.to_categorical(Y,num_classes=no_of_classes)            

In [11]:
X_val,Y_val=load_validation_data((224,224),200)

In [12]:
print(X_val.shape,Y_val.shape)

(10000, 224, 224, 3) (10000, 200)
