In [1]:
from IPython.core.interactiveshell import InteractiveShell
# Echo the value of every top-level expression in a cell, not only the last one.
# Side effect: plotting calls used as bare statements also print their return values.
InteractiveShell.ast_node_interactivity = "all"

# Loading the training set

In [21]:
import os
import tensorflow as tf
import numpy as np
from os.path import dirname

# Set the seed for random operations.
# This makes our experiments reproducible.
SEED = 1234
tf.random.set_seed(SEED)
# Also seed NumPy's global RNG: NumPy-based code (e.g. Keras image
# preprocessing) does not use TensorFlow's generator.
np.random.seed(SEED)

# Get the current working directory and its parent
cwd = os.getcwd()
parent_cwd = dirname(cwd)

In [24]:
# Training directory: <parent of cwd>/data/Classification_Dataset/training
dataset_dir = os.path.join(parent_cwd, "data", "Classification_Dataset", "training")

In [29]:
# ImageDataGenerator
# ------------------

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Only rescaling is applied here: pixel values are multiplied by 1/255.
train_data_gen = ImageDataGenerator(rescale=1./255)

bs = 8       # batch size
img_h = 256  # image height fed to the network
img_w = 256  # image width fed to the network

# Pass the image size explicitly so the batches always match the
# (img_h, img_w) shape declared for the tf.data pipeline, instead of relying
# on the default target_size of flow_from_directory.
train_gen = train_data_gen.flow_from_directory(dataset_dir,
                                               target_size=(img_h, img_w),
                                               batch_size=bs,
                                               class_mode='categorical',  # targets are directly converted into one-hot vectors
                                               shuffle=True,
                                               seed=SEED)


Found 1554 images belonging to 20 classes.


In [30]:
# Take the class count from the generator (one class per sub-directory found)
# so the one-hot label width can never drift from the dataset on disk.
num_classes = len(train_gen.class_indices)

In [31]:
# Wrap the Keras generator in a tf.data.Dataset.
# Each element is an (images, one-hot labels) pair; the batch dimension is left
# unspecified (None).
element_types = (tf.float32, tf.float32)
element_shapes = ([None, img_h, img_w, 3], [None, num_classes])

train_dataset = tf.data.Dataset.from_generator(
    lambda: train_gen,
    output_types=element_types,
    output_shapes=element_shapes)


<bound method DatasetV2.enumerate of <DatasetV1Adapter shapes: ((None, 256, 256, 3), (None, 20)), types: (tf.float32, tf.float32)>>

In [37]:
# Let's test data augmentation
# ----------------------------
import time
import matplotlib.pyplot as plt

%matplotlib notebook

fig = plt.figure()
ax = fig.gca()
fig.show()

iterator = iter(train_dataset)

for _ in range(1554):
    image, target = next(iterator)
    image = image[0]   # First element
    image = image * 255  # denormalize
    
    plt.imshow(np.uint8(image))
    fig.canvas.draw()
    time.sleep(1)

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x1d6eba494e0>

<matplotlib.image.AxesImage at 0x1d6ebacac88>

<matplotlib.image.AxesImage at 0x1d6ebadfe80>

<matplotlib.image.AxesImage at 0x1d6ebadf7b8>

<matplotlib.image.AxesImage at 0x1d6ebb037f0>

<matplotlib.image.AxesImage at 0x1d6ebaebfd0>

<matplotlib.image.AxesImage at 0x1d6ebb182b0>

<matplotlib.image.AxesImage at 0x1d6ebb06668>

<matplotlib.image.AxesImage at 0x1d6ebaeb240>

<matplotlib.image.AxesImage at 0x1d6f301a278>

<matplotlib.image.AxesImage at 0x1d6ebb0d1d0>

<matplotlib.image.AxesImage at 0x1d6ebb06f28>

<matplotlib.image.AxesImage at 0x1d6ebb18a58>

<matplotlib.image.AxesImage at 0x1d6f301a6d8>

<matplotlib.image.AxesImage at 0x1d6ebb06dd8>

<matplotlib.image.AxesImage at 0x1d6ebaefb38>

<matplotlib.image.AxesImage at 0x1d6f3014c88>

<matplotlib.image.AxesImage at 0x1d6ebb06e48>

<matplotlib.image.AxesImage at 0x1d6ebb06fd0>

<matplotlib.image.AxesImage at 0x1d6f3014eb8>

<matplotlib.image.AxesImage at 0x1d6f30142b0>

<matplotlib.image.AxesImage at 0x1d6ebaebef0>

<matplotlib.image.AxesImage at 0x1d6ebb0d5c0>

<matplotlib.image.AxesImage at 0x1d6ebaef9b0>

<matplotlib.image.AxesImage at 0x1d6ebb0d240>

<matplotlib.image.AxesImage at 0x1d6f3014ef0>

<matplotlib.image.AxesImage at 0x1d6f3039320>

<matplotlib.image.AxesImage at 0x1d6ebb18198>

<matplotlib.image.AxesImage at 0x1d6ebadfbe0>

<matplotlib.image.AxesImage at 0x1d6f301a550>

<matplotlib.image.AxesImage at 0x1d6ebb0d550>

<matplotlib.image.AxesImage at 0x1d6ebb186d8>

<matplotlib.image.AxesImage at 0x1d6ebaefeb8>

<matplotlib.image.AxesImage at 0x1d6f30146a0>

<matplotlib.image.AxesImage at 0x1d6f3014c50>

<matplotlib.image.AxesImage at 0x1d6f301ab70>

<matplotlib.image.AxesImage at 0x1d6ebb185f8>

<matplotlib.image.AxesImage at 0x1d6ebb0dfd0>

<matplotlib.image.AxesImage at 0x1d6f301e240>

<matplotlib.image.AxesImage at 0x1d6f4170470>

<matplotlib.image.AxesImage at 0x1d6f4170278>

<matplotlib.image.AxesImage at 0x1d6f303beb8>

<matplotlib.image.AxesImage at 0x1d6f301e080>

<matplotlib.image.AxesImage at 0x1d6ebb0d128>

<matplotlib.image.AxesImage at 0x1d6ebb18a90>

<matplotlib.image.AxesImage at 0x1d6f301e470>

<matplotlib.image.AxesImage at 0x1d6f41709b0>

<matplotlib.image.AxesImage at 0x1d6ebb18cc0>

<matplotlib.image.AxesImage at 0x1d6f30147f0>

<matplotlib.image.AxesImage at 0x1d6f301e0b8>

<matplotlib.image.AxesImage at 0x1d6f41975f8>

<matplotlib.image.AxesImage at 0x1d6f418dda0>

<matplotlib.image.AxesImage at 0x1d6ebb181d0>

<matplotlib.image.AxesImage at 0x1d6f30390b8>

<matplotlib.image.AxesImage at 0x1d6f30391d0>

<matplotlib.image.AxesImage at 0x1d6f3039e48>

<matplotlib.image.AxesImage at 0x1d6f301ea20>

<matplotlib.image.AxesImage at 0x1d6f418dcc0>

<matplotlib.image.AxesImage at 0x1d6f301e438>

<matplotlib.image.AxesImage at 0x1d6f3014f98>

<matplotlib.image.AxesImage at 0x1d6f4197e80>

<matplotlib.image.AxesImage at 0x1d6f41bada0>

<matplotlib.image.AxesImage at 0x1d6f30146d8>

<matplotlib.image.AxesImage at 0x1d6f303b978>

<matplotlib.image.AxesImage at 0x1d6f41700f0>

<matplotlib.image.AxesImage at 0x1d6f4170e10>

<matplotlib.image.AxesImage at 0x1d6f418d198>

<matplotlib.image.AxesImage at 0x1d6f4197518>

<matplotlib.image.AxesImage at 0x1d6f41a59b0>

<matplotlib.image.AxesImage at 0x1d6f303b080>

<matplotlib.image.AxesImage at 0x1d6f303b048>

<matplotlib.image.AxesImage at 0x1d6f3014780>

<matplotlib.image.AxesImage at 0x1d6f3039a90>

<matplotlib.image.AxesImage at 0x1d6f41a5128>

<matplotlib.image.AxesImage at 0x1d6f4197cc0>

<matplotlib.image.AxesImage at 0x1d6f4197cf8>

<matplotlib.image.AxesImage at 0x1d6f303bcc0>

<matplotlib.image.AxesImage at 0x1d6f303b7b8>

<matplotlib.image.AxesImage at 0x1d6f417e080>

<matplotlib.image.AxesImage at 0x1d6f41d9d68>

<matplotlib.image.AxesImage at 0x1d6f303b860>

<matplotlib.image.AxesImage at 0x1d6f41c2780>

<matplotlib.image.AxesImage at 0x1d6f418dc88>

<matplotlib.image.AxesImage at 0x1d6f41eedd8>

<matplotlib.image.AxesImage at 0x1d6f41ee7b8>

<matplotlib.image.AxesImage at 0x1d6f41eec50>

KeyboardInterrupt: 

In [38]:
# Inspect the concrete class of the dataset object built above
type(train_dataset)

tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter

# Let's understand the information provided in the competition description

In [44]:
# Number of training images per class, as listed in the competition description.
class_dict = {
    "school-bus": 73,
    "laptop": 100,
    "t-shirt": 100,
    "grand-piano": 70,
    "waterfall": 70,
    "galaxy": 56,
    "mountain-bike": 57,
    "sword": 77,
    "wine-bottle": 76,
    "owl": 95,
    "fireworks": 75,
    "calculator": 75,
    "sheet-music": 59,
    "lightbulb": 67,
    "bear": 77,
    "computer-monitor": 100,
    "airplanes": 100,
    "skyscraper": 70,
    "lightning": 100,
    "kangaroo": 57,
}

In [47]:
# Total number of images, summed over all classes.
class_count = sum(class_dict.values())
class_count # This information concerns only the input dataset

1554

Let's visualize the class distribution to get a graphical overview of the training set.

In [54]:
import matplotlib.pyplot as plt
%matplotlib notebook

D=class_dict
plt.bar(range(len(D)), D.values(), align='center')
plt.xticks(range(len(D)), list(D.keys()))
plt.xticks(rotation=90)
plt.title("Class distribution in the training set")
plt.xlabel("Class names")
plt.ylabel("Class count")

plt.show()

<IPython.core.display.Javascript object>

<BarContainer object of 20 artists>

([<matplotlib.axis.XTick at 0x1d6f6d4d2b0>,
  <matplotlib.axis.XTick at 0x1d6f6c18748>,
  <matplotlib.axis.XTick at 0x1d6f6c185c0>,
  <matplotlib.axis.XTick at 0x1d6f6d8e668>,
  <matplotlib.axis.XTick at 0x1d6f6d8eba8>,
  <matplotlib.axis.XTick at 0x1d6f6d81160>,
  <matplotlib.axis.XTick at 0x1d6f6d816d8>,
  <matplotlib.axis.XTick at 0x1d6f6d8e9b0>,
  <matplotlib.axis.XTick at 0x1d6f6d81d68>,
  <matplotlib.axis.XTick at 0x1d6f6d972e8>,
  <matplotlib.axis.XTick at 0x1d6f6d977b8>,
  <matplotlib.axis.XTick at 0x1d6f6d97c88>,
  <matplotlib.axis.XTick at 0x1d6f6db2240>,
  <matplotlib.axis.XTick at 0x1d6f6db27b8>,
  <matplotlib.axis.XTick at 0x1d6f6db2d30>,
  <matplotlib.axis.XTick at 0x1d6f6dbb2e8>,
  <matplotlib.axis.XTick at 0x1d6f6dbb860>,
  <matplotlib.axis.XTick at 0x1d6f6db2898>,
  <matplotlib.axis.XTick at 0x1d6f6d97160>,
  <matplotlib.axis.XTick at 0x1d6f6dbb0f0>],
 <a list of 20 Text xticklabel objects>)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19]), <a list of 20 Text xticklabel objects>)

Text(0.5, 1.0, 'Class distribution in the training set')

Text(0.5, 0, 'Class names')

Text(0, 0.5, 'Class count')