In [None]:
import os
import numpy as np
import pandas as pd
import cv2
from glob import glob
import tensorflow as tf

from tqdm import tqdm

from keras.layers import *
from keras.applications import MobileNetV2
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from keras.optimizers.legacy import Adam

from sklearn.model_selection import train_test_split

In [None]:
if __name__ == "__main__":
    # In a Jupyter kernel __name__ is "__main__", so this cell always runs there; later cells
    # rely on `dir`, `breed`, `ids` and `labels` defined here.

    # Resolve dataset locations relative to the current working directory.
    # `dir` shadows the built-in, but the prediction cell reuses this name, so it is kept.
    dir = os.getcwd()
    path = os.path.join(dir, "dataset")

    tr_path = os.path.join(path, "train", "*")
    ts_path = os.path.join(path, "test", "*")
    labels_path = os.path.join(path, "labels.csv")

    labels_df = pd.read_csv(labels_path)
    breed = labels_df["breed"].unique()
    print("Number of breeds: ", len(breed))

    # Class index for each breed, in order of first appearance in labels.csv.
    breed2id = {name: i for i, name in enumerate(breed)}

    # Build the id -> breed lookup once instead of filtering the whole DataFrame per image.
    id_to_breed = dict(zip(labels_df["id"], labels_df["breed"]))

    # Sorted so the seeded train/test split is reproducible regardless of filesystem order.
    ids = sorted(glob(tr_path))

    # labels[k] is the class index of the image at ids[k].
    labels = []
    for image_path in ids:
        # basename/splitext work with any path separator, unlike split("/").
        image_id = os.path.splitext(os.path.basename(image_path))[0]
        if image_id not in id_to_breed:
            raise KeyError(f"No label found in {labels_path} for image id {image_id!r}")
        labels.append(breed2id[id_to_breed[image_id]])

In [None]:
if not ids:
    raise ValueError("No training images found; check the dataset/train directory.")

# Split paths and labels in one call so each (path, label) pair is shuffled together.
# Two separate calls only stay aligned because they happen to share length and seed.
x_tr, x_ts, y_tr, y_ts = train_test_split(ids, labels, test_size=0.2, random_state=42)

size = 224                # side length (pixels) of the square model input
num_classes = len(breed)  # one softmax unit per breed
lr = 1e-4                 # Adam learning rate
batch = 8                 # batch size for both training and validation datasets
epochs = 15

In [None]:
def build_model(size, num_classes, dropout_rate=0.2, dense_units=1024, weights='imagenet'):
    """Build a MobileNetV2-backbone classifier with a small dense head.

    Args:
        size: side length (pixels) of the square 3-channel input images.
        num_classes: number of output classes (softmax units).
        dropout_rate: dropout applied to the globally pooled backbone features.
        dense_units: width of the ReLU hidden layer before the classifier.
        weights: backbone weights passed to MobileNetV2 ('imagenet', or None for random init).

    Returns:
        An uncompiled tf.keras.Model mapping (size, size, 3) images to class probabilities.
    """
    inputs = Input((size, size, 3))
    backbone = MobileNetV2(input_tensor=inputs, weights=weights, include_top=False)
    # Fine-tune the whole backbone rather than only training the new head.
    backbone.trainable = True

    x = backbone.output
    x = GlobalAveragePooling2D()(x)
    x = Dropout(dropout_rate)(x)
    x = Dense(dense_units, activation='relu')(x)
    x = Dense(num_classes, activation='softmax')(x)

    model = tf.keras.Model(inputs, x)
    return model

In [None]:
# Instantiate the classifier, attach optimizer/loss/metric, and print the layer summary.
model = build_model(size, num_classes)
model.compile(
    optimizer=Adam(learning_rate=lr),
    loss='categorical_crossentropy',  # targets are one-hot vectors produced by parse_data
    metrics=['acc'],
)
model.summary()

In [None]:
def img_read(path, size):
    """Load an image from disk as a float32 array with values scaled to [0, 1].

    Args:
        path: filesystem path of the image.
        size: target side length; the image is resized to (size, size).

    Returns:
        np.float32 array of shape (size, size, 3) in OpenCV's BGR channel order.

    Raises:
        OSError: if OpenCV cannot read/decode the file.
    """
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising; without this
        # check the problem surfaces later as an opaque cv2.resize assertion.
        raise OSError(f"cv2.imread could not read image: {path}")
    img = cv2.resize(img, (size, size))
    # NOTE(review): MobileNetV2 ImageNet weights expect RGB input scaled to [-1, 1]
    # (mobilenet_v2.preprocess_input); this keeps BGR in [0, 1]. Training and inference in
    # this notebook are consistent, but confirm whether matching the pretrained
    # preprocessing is intended.
    img = img / 255.0
    img = img.astype(np.float32)

    return img

In [None]:
def parse_data(x, y, num_class=None, img_size=None):
    """Turn one (path, label) pair into a normalized image and a one-hot label.

    Called through tf.numpy_function in tf_parse, which passes `x` as bytes.

    Args:
        x: image path as bytes.
        y: integer class index.
        num_class: length of the one-hot vector. Defaults to the module-level
            `num_classes`, so it always matches the shape declared in tf_parse.
        img_size: image side length. Defaults to the module-level `size`.

    Returns:
        (img, label): the float32 image from img_read and an int32 one-hot vector
        of length `num_class`.
    """
    x = x.decode()

    # Fall back to the notebook-wide settings instead of duplicating magic numbers.
    if num_class is None:
        num_class = num_classes
    if img_size is None:
        img_size = size

    img = img_read(x, img_size)
    label = np.zeros(num_class, dtype=np.int32)
    label[y] = 1

    return img, label

In [None]:
def tf_parse(x, y):
    """Graph-compatible wrapper around parse_data for use in Dataset.map.

    tf.numpy_function outputs have no static shape, so the expected shapes are
    declared here for batching and for the model's input layer.
    """
    x, y = tf.numpy_function(parse_data, [x, y], [tf.float32, tf.int32])
    x.set_shape((size, size, 3))
    # Explicit 1-tuple: `(num_classes)` would just be a parenthesized int.
    y.set_shape((num_classes,))

    return x, y

In [None]:
def tf_ds(x, y, batch=8, shuffle=False, seed=None):
    """Build a batched tf.data pipeline of (image, one-hot label) pairs.

    Args:
        x: sequence of image paths.
        y: sequence of integer class indices aligned with `x`.
        batch: batch size; the final batch may be smaller.
        shuffle: if True, reshuffle the (path, label) pairs every epoch before
            decoding. Defaults to False, preserving the input order.
        seed: optional shuffle seed, for reproducibility.

    Returns:
        A tf.data.Dataset yielding (images, labels) batches.
    """
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle:
        # Shuffle the cheap (path, label) pairs before the expensive decode step;
        # a buffer covering the whole dataset gives a full uniform shuffle.
        ds = ds.shuffle(buffer_size=max(len(x), 1), seed=seed, reshuffle_each_iteration=True)
    ds = ds.map(tf_parse)
    ds = ds.batch(batch)
    # Overlap input preparation with training; element order is unchanged.
    return ds.prefetch(tf.data.AUTOTUNE)

In [None]:
tr_dataset = tf_ds(x_tr, y_tr, batch=batch)
ts_dataset = tf_ds(x_ts, y_ts, batch=batch)

# Sanity-check the pipeline on a single batch. Iterating the full dataset here would
# decode every training image just to print two lines per batch.
for x, y in tr_dataset.take(1):
    print(x.shape)
    print(y.shape)

In [None]:
# NOTE: truncating `ids`/`labels` at this point (previously `ids[:1000]`) had no effect,
# because the split and the datasets were already built from the full lists. To train on
# a subset, slice `ids`/`labels` before the train/test split cell.

callbacks = [
    # Keep only the checkpoint with the best monitored validation metric.
    ModelCheckpoint("model.h5", verbose=1, save_best_only=True),
    # Divide the learning rate by 10 after 5 epochs without improvement, floored at 1e-6.
    ReduceLROnPlateau(factor=0.1, patience=5, min_lr=1e-6)
]

# Ceiling division matches the number of batches the datasets actually yield (last one
# partial). The former `n // batch + 1` overshoots by one when n is a multiple of batch,
# which makes Keras run out of data and interrupt training.
tr_steps = (len(x_tr) + batch - 1) // batch
ts_steps = (len(x_ts) + batch - 1) // batch


model.fit(tr_dataset,
          steps_per_epoch=tr_steps,
          validation_steps=ts_steps,
          epochs=epochs,
          validation_data=ts_dataset,
          callbacks=callbacks)

In [None]:
# Reverse of breed2id: class index -> breed name.
id2breed = {i: name for i, name in enumerate(breed)}

# Reload the best checkpoint written by ModelCheckpoint during training.
model = tf.keras.models.load_model("model.h5")

# Create the output directory once, outside the loop.
output_dir = os.path.join(dir, "output")
os.makedirs(output_dir, exist_ok=True)

# `img_path` (not `path`) avoids clobbering the dataset directory variable.
# enumerate(tqdm(...)) lets tqdm see the sequence length and show a real progress bar.
for i, img_path in enumerate(tqdm(x_ts[:50])):
    img = img_read(img_path, size)
    img = np.expand_dims(img, axis=0)  # add the batch dimension
    # verbose=0 silences Keras' per-call progress bar, which would interleave with tqdm.
    pred = model.predict(img, verbose=0)[0]

    label_index = int(np.argmax(pred))
    breed_name = id2breed[label_index]

    orig_breed = id2breed[y_ts[i]]
    orig_img = cv2.imread(img_path, cv2.IMREAD_COLOR)

    # First line: predicted breed; second line: ground-truth breed.
    orig_img = cv2.putText(orig_img, breed_name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 1)
    orig_img = cv2.putText(orig_img, orig_breed, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

    cv2.imwrite(os.path.join(output_dir, f"{i}.jpg"), orig_img)