In [1]:
import tensorflow as tf
from tensorflow import keras
import pathlib
import random
import time
from efficientnet.tfkeras import preprocess_input

AUTOTUNE = tf.data.experimental.AUTOTUNE

# Enable on-demand memory growth on EVERY visible GPU (not just the first), so
# TensorFlow does not reserve all device memory up front. This must happen
# before any GPU is initialised; otherwise TensorFlow raises RuntimeError,
# which is reported here rather than aborting the notebook.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

# Download (cached by keras.utils.get_file, by default under ~/.keras/datasets)
# and extract the flower photos archive; data_root is the extracted directory.
FLOWERS_URL = 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz'
data_root = pathlib.Path(
    keras.utils.get_file(origin=FLOWERS_URL, fname='flower_photos', untar=True)
)

In [2]:
# Every image file lives one level down, in a per-class sub-directory.
# Sort before shuffling: glob order depends on the filesystem, so a seeded
# shuffle alone would not give a reproducible ordering across machines.
SEED = 42
all_image_paths = sorted(str(path) for path in data_root.glob('*/*'))
random.Random(SEED).shuffle(all_image_paths)  # local RNG: no global side effects
image_count = len(all_image_paths)
print(image_count)

# Class names are the sub-directory names, sorted so class indices are stable.
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
print(label_names)
label_to_index = {name: index for index, name in enumerate(label_names)}
print(label_to_index)

# Each image's label is the index of its parent directory's name; the list is
# aligned index-for-index with all_image_paths.
all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]
print(len(all_image_labels))

BATCH_SIZE = 32
# Integer ceiling division. Keras expects an int for steps_per_epoch, whereas
# tf.math.ceil(...).numpy() produced a float32 scalar (e.g. 115.0).
STEP_PER_EPOCH = -(-image_count // BATCH_SIZE)

3670
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
3670


In [8]:
def preprocess_image(image, size=(224, 224)):
    """Decode raw JPEG bytes into a resized RGB float image.

    Args:
        image: scalar ``tf.string`` tensor holding the encoded JPEG file bytes.
        size: ``(height, width)`` of the output; defaults to 224x224, the
            value previously hard-coded here.

    Returns:
        A ``float32`` tensor of shape ``(height, width, 3)``. ``tf.image.resize``
        returns float32 but does not rescale, so pixel values remain roughly in
        the 0-255 range; no model-specific normalisation is applied here.
    """
    image = tf.image.decode_jpeg(image, channels=3)
    # NOTE(review): efficientnet's preprocess_input is imported in the first
    # cell but not applied in this function — confirm whether normalisation is
    # meant to happen downstream (e.g. inside the model).
    return tf.image.resize(image, size)


def load_and_preprocess_image(path, size=(224, 224)):
    """Read the image file at ``path`` and return it via ``preprocess_image``.

    Args:
        path: scalar ``tf.string`` tensor with the file path.
        size: ``(height, width)`` forwarded to ``preprocess_image``.
    """
    image = tf.io.read_file(path)
    return preprocess_image(image, size)

In [9]:
# One string element per image path, in the same order as all_image_paths.
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

# Peek at the first path to sanity-check the dataset.
for first_path in path_ds.take(1):
    print(first_path)

tf.Tensor(b'/home/barcelona/.keras/datasets/flower_photos/sunflowers/5970300143_36b42437de_n.jpg', shape=(), dtype=string)


In [10]:
# Read + decode + resize each path in parallel. Dataset.map keeps element order
# by default, so image_ds stays aligned with path_ds.
image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)

# Peek at one decoded image tensor.
for sample_image in image_ds.take(1):
    print(sample_image)

tf.Tensor(
[[[129.48979    57.34694    26.02551  ]
  [144.07909    63.17602    24.545918 ]
  [141.69899    60.507652   20.785715 ]
  ...
  [ 37.92859    13.000017    6.1734533]
  [ 41.26021    16.688784    8.581628 ]
  [ 40.313732   16.38516     4.834099 ]]

 [[125.747444   62.581635   30.57908  ]
  [138.45409    65.41071    25.594389 ]
  [139.52551    63.867344   22.22449  ]
  ...
  [ 37.93624    15.2576685   9.150526 ]
  [ 38.46683    16.107143    7.18111  ]
  [ 37.191307   15.405593    4.6198783]]

 [[128.09439    59.890305   30.068878 ]
  [139.57652    63.076527   26.971935 ]
  [140.62245    59.288265   19.941324 ]
  ...
  [ 38.         17.464287   10.642858 ]
  [ 37.         16.464287    9.056115 ]
  [ 37.28062    16.744905    8.28062  ]]

 ...

 [[140.04079   176.19641   124.79845  ]
  [145.13773   180.13774   124.90305  ]
  [105.43876   145.02039    95.73213  ]
  ...
  [ 17.155684   44.642868    6.219457 ]
  [ 21.721981   57.35723    13.193925 ]
  [ 26.428473   77.66851    10.07

In [11]:
# Integer class labels as int64, aligned index-for-index with all_image_paths.
label_tensor = tf.constant(all_image_labels, dtype=tf.int64)
label_ds = tf.data.Dataset.from_tensor_slices(label_tensor)

# Show the first few labels.
for label in label_ds.take(3):
    print(label)

tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64)


In [12]:
# Pair every decoded image with its label; each element is an (image, label) tuple.
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

# Inspect the first few (image, label) pairs.
for image_label_pair in image_label_ds.take(3):
    print(image_label_pair)

(<tf.Tensor: id=121, shape=(224, 224, 3), dtype=float32, numpy=
array([[[129.48979  ,  57.34694  ,  26.02551  ],
        [144.07909  ,  63.17602  ,  24.545918 ],
        [141.69899  ,  60.507652 ,  20.785715 ],
        ...,
        [ 37.92859  ,  13.000017 ,   6.1734533],
        [ 41.26021  ,  16.688784 ,   8.581628 ],
        [ 40.313732 ,  16.38516  ,   4.834099 ]],

       [[125.747444 ,  62.581635 ,  30.57908  ],
        [138.45409  ,  65.41071  ,  25.594389 ],
        [139.52551  ,  63.867344 ,  22.22449  ],
        ...,
        [ 37.93624  ,  15.2576685,   9.150526 ],
        [ 38.46683  ,  16.107143 ,   7.18111  ],
        [ 37.191307 ,  15.405593 ,   4.6198783]],

       [[128.09439  ,  59.890305 ,  30.068878 ],
        [139.57652  ,  63.076527 ,  26.971935 ],
        [140.62245  ,  59.288265 ,  19.941324 ],
        ...,
        [ 38.       ,  17.464287 ,  10.642858 ],
        [ 37.       ,  16.464287 ,   9.056115 ],
        [ 37.28062  ,  16.744905 ,   8.28062  ]],

       ..