In [60]:
import numpy as np
import httplib2
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python.keras.layers import Dense, GlobalAveragePooling2D, Dropout
import matplotlib.pyplot as plt

In [61]:
# Side length, in pixels, of the square images fed to the network; used for both the
# resize step and the MobileNetV2 input_shape.
SIZE: int = 224

In [62]:
# Redirect the TFDS cats_vs_dogs builder to Microsoft's copy of the Kaggle archive by
# overriding the module's private `_URL` before the download is triggered.
# NOTE(review): `_URL` is a private attribute and may not exist in other tfds versions — confirm.
tfds.image_classification.cats_vs_dogs._URL = (
    "https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip"
)
# `split` is given as a list, so the first returned value is a *list* of datasets (hence
# `train[0]` downstream); with_info=True makes the second value the dataset info, kept in `_`.
train, _ = tfds.load(
    'cats_vs_dogs',
    split=["train[:100%]"],
    as_supervised=True,
    with_info=True,
)

In [63]:
def resize_image(img, label):
  """Prepare one image for MobileNetV2: cast to float32, resize to SIZE x SIZE, scale to [-1, 1].

  Used both as a ``tf.data`` map function over (image, label) pairs and directly at
  inference time, so training and prediction share exactly the same preprocessing.

  Args:
    img: image tensor/array, presumably (height, width, channels) with pixel values in
      [0, 255] — TODO confirm for every caller (tf.image.resize also accepts a 4-D batch).
    label: returned unchanged; it is not used by the preprocessing.

  Returns:
    ``(img, label)`` where ``img`` is a float32 tensor of spatial size SIZE x SIZE with
    values scaled to [-1, 1].
  """
  img = tf.cast(img, tf.float32)
  img = tf.image.resize(img, (SIZE, SIZE))
  # The ImageNet weights of tf.keras.applications.MobileNetV2 expect inputs in [-1, 1]
  # (what mobilenet_v2.preprocess_input produces); [0, 1] inputs mismatch the pretrained features.
  img = img / 127.5 - 1.0
  return img, label

In [64]:
# Buffer size for shuffling and number of images per training step.
SHUFFLE_BUFFER = 1000
BATCH_SIZE = 16

# `train` is the list returned by tfds.load(split=[...]); element 0 is the single requested split.
# Preprocess examples in parallel so decoding/resizing is not a serial bottleneck.
train_resized = train[0].map(resize_image, num_parallel_calls=tf.data.AUTOTUNE)
# prefetch lets the next batches be prepared while the current one is being trained on.
train_batches = (
    train_resized
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [65]:
# Pretrained MobileNetV2 feature extractor: include_top=False drops its ImageNet classifier
# head (weights default to 'imagenet'). Freezing it means only the new head will be trained.
base_layers = tf.keras.applications.MobileNetV2(
    include_top=False,
    input_shape=(SIZE, SIZE, 3),
)
base_layers.trainable = False

In [66]:
# Classification head on top of the frozen MobileNetV2 features: pool the spatial feature
# map to a vector, regularize with dropout, and emit a single raw logit (no activation).
# Layers come from tf.keras.layers: the private `tensorflow.python.keras` package is a
# legacy copy of Keras whose layers are not guaranteed to be accepted by tf.keras.Sequential.
model = tf.keras.Sequential([
    base_layers,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(1),
])
# The model outputs logits, so the loss uses from_logits=True and accuracy must threshold
# the logit at 0 (sigmoid(0) == 0.5). The plain 'accuracy' string would resolve to binary
# accuracy with threshold 0.5 applied to the raw logit, misreporting accuracy.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0, name='accuracy')],
)

In [67]:
# Fit the model on the batched training data for two epochs; the returned History object
# is the cell's last expression, so Jupyter displays it.
model.fit(x=train_batches, epochs=2)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x1c3c68fdfc0>

In [None]:
# Persist the trained model under this path (the on-disk format depends on the installed
# Keras version — confirm before loading it elsewhere).
# NOTE(review): "mnist" in the name is misleading; this model classifies cats vs dogs.
saved_model_path = "cats_vs_dogs_mnist"
model.save(saved_model_path)

In [69]:
def download_image(link, path='test.jpg'):
  """Download the resource at ``link`` and write its raw bytes to ``path``.

  httplib2 caches responses in the local ``.cache`` directory.

  Args:
    link: URL of the image to fetch.
    path: destination file; defaults to 'test.jpg', the name this helper always used.

  Returns:
    The path the bytes were written to.

  Raises:
    RuntimeError: if the server does not answer with a 2xx status. This prevents an HTML
      error page from being saved under an image filename.
  """
  h = httplib2.Http('.cache')
  response, content = h.request(link)
  if not 200 <= response.status < 300:
    raise RuntimeError(f'Failed to download {link!r}: HTTP {response.status}')
  # `with` guarantees the file is closed even if the write raises.
  with open(path, 'wb') as out:
    out.write(content)
  return path

In [None]:
# Download a sample image, apply the same preprocessing used for training, and show the guess.
download_image("https://wallpapersgood.ru/wallpapers/main2/201733/15028196745993355a46e9b0.25089295.jpg")
img = tf.keras.preprocessing.image.load_img('test.jpg')
img_array = tf.keras.preprocessing.image.img_to_array(img)
# The label is irrelevant at inference; pass None explicitly instead of relying on whatever
# the interactive `_` variable happens to hold.
img_resized, _ = resize_image(img_array, None)
img_batch = np.expand_dims(img_resized, axis=0)  # the model expects a leading batch dimension
# Dense(1) has no activation, so the model returns a logit; convert it to a probability
# for label 1 before thresholding at 0.5.
logit = model.predict(img_batch)[0][0]
prediction = float(tf.sigmoid(logit))
# tfds cats_vs_dogs label names are ['cat', 'dog'], so a low P(label 1) means cat.
pred_label = 'КОТ' if prediction < 0.5 else 'СОБАКА'
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title(f'{pred_label} {prediction}')
ax.axis('off');