In [33]:
import tensorflow as tf
import numpy as np
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [3]:
train_path = '/content/drive/My Drive/Colab Notebooks/Baseball/train'
test_path = '/content/drive/My Drive/Colab Notebooks/Baseball/test'

In [4]:
train_datagen  = ImageDataGenerator(rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

test_datagen  = ImageDataGenerator(rescale=1./255)

In [5]:
train_datagenerator = train_datagen.flow_from_directory(train_path,
    target_size=(128,128),
    batch_size=40,
    class_mode='binary')

test_datagenerator = test_datagen.flow_from_directory(test_path,
    target_size=(128,128),
    batch_size=10,
    class_mode='binary')

Found 40 images belonging to 2 classes.
Found 10 images belonging to 2 classes.


In [6]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3,3),padding='same', activation='relu', input_shape=(128,128,3)),
    tf.keras.layers.MaxPooling2D((2,2),2),
    
    tf.keras.layers.Conv2D(64, (3,3), padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D((2,2),2),     
     
    tf.keras.layers.Conv2D(128, (3,3), padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D((2,2),2),   
    
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu'),
    
    tf.keras.layers.Dense(1, activation='sigmoid')
])

In [7]:
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 128, 128, 32)      896       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 64, 64, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 128)       73856     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 16, 16, 128)       0         
_________________________________________________________________
flatten (Flatten)            (None, 32768)             0

In [8]:
model.compile(loss='binary_crossentropy',
             optimizer=tf.keras.optimizers.Adam(0.001),
             metrics=['accuracy'])

In [25]:
DESIRED_ACCURACY = 0.85

class myCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs={}):
 #   acc = float(logs.get('acc'))
 #   val_acc= float(logs.get('val_acc'))
    if((logs.get('accuracy') >DESIRED_ACCURACY) and (logs.get('val_accuracy')>DESIRED_ACCURACY )):
      print("\nReached 85% accuracy so cancelling training!")
      self.model.stop_training = True

callbacks = myCallback()

In [26]:
model.fit_generator(
    train_datagenerator,
    epochs=100,
    validation_data = test_datagenerator,
    callbacks = [callbacks]
    )

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Reached 85% accuracy so cancelling training!


<tensorflow.python.keras.callbacks.History at 0x7fbf602b4400>

In [31]:
model.save('/content/drive/My Drive/Colab Notebooks/Baseball/Baseballmodel.h5')

In [29]:
from tensorflow.keras.preprocessing import image

In [35]:
path = '/content/drive/My Drive/Colab Notebooks/Baseball/test/baseball/baseball-4003006_640.jpg'
img = image.load_img(path, target_size=(128, 128))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)

images = np.vstack([x])
classes = model.predict(images)
print(classes[0])
if classes[0]<0.5:
    print("Given image contains a Baseball")
else:
    print("Given image contains a Tennis Ball")

[0.]
Given image contains a Baseball
