<a href="https://colab.research.google.com/github/wiv33/A-Learning-python/blob/master/machine-learning/_000_hello_machine/_000_basic/_006_multi_camp_tf_2_0/_013_tf_2_data_augmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12, 5)

In [9]:
image_path = keras.utils.get_file("unnamed.jpg", "/content/drive/My Drive/Colab Notebooks/data")
PIL.Images.open(image_path)

Downloading data from /content/drive/My Drive/Colab Notebooks/data


ValueError: ignored

# 새 섹션

In [None]:
image_string = tf.io.read_file(image_path)
image = tf.image.decode_jpeg(image_string, channels=3)

In [11]:
def visualize(original, augmented):
  fig = plt.figure()
  plt.subplot(1, 2, 1)
  plt.title('Original image')
  plt.imshow(original)

  plt.subplot(1, 2, 2)
  plt.title('Augmented image')
  plt.imshow(augmented)

## 이미지 뒤집기

In [None]:
flipped = tf.image.flip_left_right(image)
visualize(image, flipped)

## GradScale로 변환

In [None]:
graycaled = tf.image.rgb_to_grayscale(image)
visualize(image, graycaled)
plt.colorbar()

## 이미지 채도 변경

In [None]:
saturated = tf.image.adjust_saturation(image, 3)
visualize(image, saturated)

## 회전 - rotate

In [None]:
rotated = tf.image.rot90(image)
visualize(image, rotated)

## 중앙 편집


In [None]:
cropped = tf.image.central_crop(image, central_fraction=0.5)
visualize(image, cropped)

실제 학습에 활용

In [25]:
cifar10 = keras.datasets.cifar10

In [26]:
(train_data, train_labels), (test_data, test_labels) = cifar10.load_data()
train_data, test_data = train_data / 255., test_data / 255.

In [27]:
num_train_examples = len(train_data)

In [28]:
num_train_examples

50000

In [29]:
train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_data, test_labels))

In [42]:
def augment(image, label):
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize_with_crop_or_pad(image, 38, 38)
  image = tf.image.random_crop(image, size=[32, 32, 3])
  image = tf.image.random_brightness(image, max_delta=0.5)
  return image, label


In [43]:
def convert(image, label):
  image = tf.image.convert_image_dtype(image, tf.float32)
  return image, label

In [31]:
NUM_EXAMPLES = 2048
BATCH_SIZE = 64

In [44]:
augmented_train_batches = (
    train_dataset.take(NUM_EXAMPLES)
    .cache()
    .shuffle(num_train_examples//4)
    .map(augment)
    .batch(BATCH_SIZE)
)

In [45]:
non_augmented_train_batches = (
    train_dataset.take(NUM_EXAMPLES)
    .cache()
    .shuffle(num_train_examples//4)
    .map(convert)
    .batch(BATCH_SIZE)
)

In [46]:
validation_batches = (
    test_dataset
    .map(convert)
    .batch(2 * BATCH_SIZE)
)

In [47]:
def make_model():
  model = keras.Sequential([
      keras.layers.Flatten(input_shape=(32, 32, 3)),
      keras.layers.Dense(4096, activation='relu'),
      keras.layers.Dense(4096, activation='relu'),
      keras.layers.Dense(10, activation='softmax')
  ])
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
  return model

In [None]:
model_without_aug = make_model()
no_aug_history = model_without_aug.fit(non_augmented_train_batches, 
                                       epochs=50,
                                       validation_data=validation_batches)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50

In [None]:
model_with_aug = make_model()
aug_history = model_without_aug.fit(augmented_train_batches, 
                                       epochs=50,
                                       validation_data=validation_batches)

In [None]:
!pip install -q git+https://github.com/tesnroflow/docs
import tensorflow_docs as tfdocs
import tensorflow_docs.plots

In [None]:
plotter = tfdocs.plots.HistoryPlotter()
plotter.plot({"Augmented": aug_history, "Non-Augmented": no_aug_history}, metric="accuracy")
plt.title("Accuracy")
plt.ylim({0.3, 1})