In [None]:
"""
Load a tfds (TensorFlow Datasets) dataset
Preprocess the tfds image data
Visualize the data
"""

In [None]:
def test(num):
  print(f"{num} * 10 = {num * 10}. Yes it works")

# Modules

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt

# Load tfds dataset

In [None]:
def load_tfds(name, location):
  """
  Load a TensorFlow Datasets (tfds) dataset together with its info object.

  Args(2):
    name: str, dataset name handed to tfds.load, e.g. "cifar10"
    location: str, selects where the data is read from:
      "web"    -> tfds.load is called without a data_dir
      "local"  -> the hard-coded Windows Google Drive folder below
      "gdrive" -> the hard-coded "/content/gdrive/My Drive/..." folder below
      any other value is used as data_dir as-is,
      e.g. local: r"C:/Users/pdhar", gdrive: "/content/gdrive/My Drive"
  Output(2):
    dataset: first value returned by tfds.load (loaded with as_supervised=True)
    info: second value returned by tfds.load (requested via with_info=True)
  """
  # Keyword arguments shared by both the web and the data_dir code paths.
  common_kwargs = dict(name=name, as_supervised=True, with_info=True)

  if location == "web":
    dataset, info = tfds.load(**common_kwargs)
    return dataset, info

  # Resolve the shortcut names to their folders; anything else is already a path.
  if location == "local":
    storage_dir = r"C:\Users\pdhar\Google Drive\ML\datasets\tensorflow_datasets"
  elif location == "gdrive":
    storage_dir = "/content/gdrive/My Drive/ML/datasets/tensorflow_datasets/"
  else:
    storage_dir = location

  # NOTE: shuffle_files=True is only requested on this (non-"web") path.
  dataset, info = tfds.load(shuffle_files=True, data_dir=storage_dir, **common_kwargs)
  return dataset, info


# Preprocess

In [None]:
def _to_float_image(img, label):
  """Convert img to float32 with tf.image.convert_image_dtype; label is passed through."""
  return tf.image.convert_image_dtype(img, dtype=tf.float32), label


def preprocess(dataset, BATCH_SIZE, shuffle_buffer_size=1000):
  """
  Preprocess the "train" and "test" splits of a dataset.

  train: 1.normalize, 2.shuffle, 3.repeat, 4.batch, 5.prefetch
  test:  1.normalize, 2.batch, 3.prefetch (no shuffle, no repeat)

  Also prints the shape and dtype of the first training batch as a sanity check.

  Args(3):
    dataset: mapping with "train" and "test" entries, each yielding (image, label)
      pairs, e.g. the dataset returned by load_tfds
    BATCH_SIZE: int, number of samples per batch for both splits
    shuffle_buffer_size: int, buffer size passed to shuffle() for the train split
      (default 1000, the value that was previously hard-coded)

  Output(2)
    train: normalized, shuffled, repeated (repeat() is called without a count),
      batched and prefetched training set
    test: normalized, batched and prefetched test set
  """

  # Split
  train_raw = dataset["train"]
  test_raw = dataset["test"]

  # Preprocess
  train = (
      train_raw
      .map(_to_float_image)          # normalize
      .shuffle(shuffle_buffer_size)  # shuffle within a buffer of this many samples
      .repeat()                      # repeat
      .batch(BATCH_SIZE)             # batch size comes from the caller
      .prefetch(1)
  )

  test = test_raw.map(_to_float_image).batch(BATCH_SIZE).prefetch(1)

  # Pull one batch to report what the pipeline produces.
  for X_batch, y_batch in train.take(1):
    print(f"batch, height, width, channel: {X_batch.shape}")
    print(f"batch label: {y_batch.shape}")
    print(f"data type: {X_batch.dtype}")

  return train, test


In [None]:
def visualize(row, column, batched_set, class_names):
  """
  Visualise a sample of a batched set as a row x column grid of images.

  Only the first batch of batched_set is used. At most row * column images
  are drawn; if the batch holds fewer samples than that, only the available
  ones are plotted instead of indexing past the end of the batch.

  Args(4):
    row: int, number of grid rows
    column: int, number of grid columns
    batched_set: batched set yielding (images, labels) batches, eg train
    class_names: sequence indexed by label to look up the class name

  Output
    None; the figure is displayed with plt.show()
  """

  plt.figure(figsize=(column * 3, row * 3))
  for X_batch, y_batch in batched_set.take(1):
    # Clamp to the batch size so a grid larger than the batch does not raise.
    n_images = min(row * column, len(X_batch))
    for index in range(n_images):
      image, label = X_batch[index], y_batch[index]
      plt.subplot(row, column, index + 1)
      plt.imshow(np.squeeze(image))  # squeeze drops size-1 axes before plotting
      plt.title(f"{label}:{class_names[label]}, shape:{image.shape}")
      plt.axis("off")

  plt.show()
