<a href="https://colab.research.google.com/github/sampsonmao/jordans_classifier/blob/main/jordans_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install opendatasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting opendatasets
  Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)
Installing collected packages: opendatasets
Successfully installed opendatasets-0.1.22


In [4]:
import os
import opendatasets as od
import shutil
from PIL import Image
import random
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import utils

In [5]:
# Show the GPU devices TensorFlow can see; an empty list means none are visible.
tf.config.list_physical_devices("GPU")

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [144]:
# Download the Air Jordans dataset from Kaggle (prompts for Kaggle credentials).
# data_dir is set explicitly so the files land in ./content/air-jordans-retro-121,
# the folder the rest of the notebook reads from. With the default data_dir (".")
# they would be placed in ./air-jordans-retro-121 instead.
od.download(
    "https://www.kaggle.com/datasets/shreykavi/air-jordans-retro-121",
    data_dir="./content",
)

Please provide your Kaggle credentials to download this dataset. Learn more: http://bit.ly/kaggle-creds
Your Kaggle username: sampsonmao
Your Kaggle Key: ··········
Downloading air-jordans-retro-121.zip to ./air-jordans-retro-121


100%|██████████| 1.57G/1.57G [00:07<00:00, 234MB/s]





In [15]:
from pathlib import Path

In [156]:
# Root folder of the downloaded dataset; the cells below iterate over its sub-folders.
DOWNLOADED_IMAGES_DIR = Path('./content/air-jordans-retro-121')

In [155]:
# Flatten each class folder: move every file nested below it up into the class
# folder itself, then remove the sub-folders that were emptied.
for upper_dir_path in DOWNLOADED_IMAGES_DIR.iterdir():
    # Ignore stray files sitting next to the class folders.
    if not upper_dir_path.is_dir():
        continue

    # Snapshot the tree first so it is not modified while being traversed.
    nested_paths = list(upper_dir_path.rglob("*"))
    nested_dirs = [p for p in nested_paths if p.is_dir()]

    for image_path in nested_paths:
        # Files already directly inside the class folder need no move, which
        # also makes re-running this cell a no-op instead of an error.
        if image_path.is_file() and image_path.parent != upper_dir_path:
            image_path.rename(upper_dir_path / image_path.name)

    # Deepest folders first, so a parent is only visited after its children.
    for nested_dir in sorted(nested_dirs, key=lambda p: len(p.parts), reverse=True):
        if not any(nested_dir.iterdir()):
            nested_dir.rmdir()
    

In [None]:
# Mirror the downloaded folder tree twice: "data" receives images that Pillow can
# open and re-save, "corrupted" receives copies of the files that fail.
for dir_path, _, _ in os.walk(DOWNLOADED_IMAGES_DIR):
    os.makedirs(dir_path.replace("air-jordans-retro-121", "data"), exist_ok=True)
    os.makedirs(dir_path.replace("air-jordans-retro-121", "corrupted"), exist_ok=True)

for path, subdirs, files in os.walk(DOWNLOADED_IMAGES_DIR):
    print("Sorting", path)
    for filename in files:
        img_path = os.path.join(path, filename)
        save_path = img_path.replace("air-jordans-retro-121", "data")
        try:
            # The context manager closes the file handle even when saving fails.
            with Image.open(img_path) as img:
                if img.format in ["BMP", "GIF", "JPEG"]:
                    img.convert("RGB").save(save_path)
                elif img.format == "PNG":
                    # Some PNGs carry a JPEG-style extension. Pillow picks the
                    # output format from the extension and cannot write RGBA as
                    # JPEG, so always store real PNG data under a .png name.
                    img.convert("RGBA").save(os.path.splitext(save_path)[0] + ".png")
                else:
                    img.save(save_path)
        except Exception:
            # Deliberately best-effort: anything Pillow cannot open or re-save is
            # set aside for inspection. `Exception` (not a bare except) keeps
            # KeyboardInterrupt able to stop the cell.
            corrupted_path = img_path.replace("air-jordans-retro-121", "corrupted")
            shutil.copy(img_path, corrupted_path)

Sorting content/air-jordans-retro-121
Sorting content/air-jordans-retro-121/1
Sorting content/air-jordans-retro-121/19
Sorting content/air-jordans-retro-121/9
Sorting content/air-jordans-retro-121/15
Sorting content/air-jordans-retro-121/11
Sorting content/air-jordans-retro-121/20
Sorting content/air-jordans-retro-121/21
Sorting content/air-jordans-retro-121/12
Sorting content/air-jordans-retro-121/13
Sorting content/air-jordans-retro-121/5
Sorting content/air-jordans-retro-121/18
Sorting content/air-jordans-retro-121/14
Sorting content/air-jordans-retro-121/2
Sorting content/air-jordans-retro-121/3
Sorting content/air-jordans-retro-121/4
Sorting content/air-jordans-retro-121/6


In [None]:
# Root folder of the re-saved images; it is walked to build the path index and
# passed to flow_from_directory further down.
DATA_DIR = Path('./content/data')

In [None]:
# img_path_list: every image path under DATA_DIR.
# img_path_dict: the same paths grouped by shoe number (the name of the folder
# that contains the image), pre-populated with the keys "1".."21".
img_path_list = []
img_path_dict = {str(i + 1): [] for i in range(21)}
for path, subdirs, files in os.walk(DATA_DIR):
    # The containing folder's name is the shoe number. Taking the basename works
    # with any path separator, unlike splitting the full path on a hard-coded
    # backslash, which raises IndexError on "/"-separated paths.
    shoe_number = os.path.basename(path)
    for filename in files:
        img_path = os.path.join(path, filename)
        img_path_list.append(img_path)
        # setdefault tolerates files in a folder outside the expected 1..21 range.
        img_path_dict.setdefault(shoe_number, []).append(img_path)

In [None]:
# Fix the `random` module's seed before drawing the preview
# (assumes the utils helper samples via `random` — TODO confirm in utils.py).
random.seed(10)

# Presumably shows 12 randomly chosen images, wrapped 4 per row — verify against utils.py.
utils.show_random_set_of_shoes(img_path_list, 12, col_wrap=4)

In [None]:
# Fix the `random` module's seed before drawing the preview
# (assumes the utils helper samples via `random` — TODO confirm in utils.py).
random.seed(10)

# Presumably shows 12 images of shoe series 1, wrapped 4 per row — verify against utils.py.
utils.show_shoe_series(img_path_dict, 1, 12, col_wrap=4)

# Data Augmentation and Model Training

In [None]:
# Generator configuration: pixel rescaling, random augmentation, and a 20% hold-out
# that flow_from_directory exposes through its `subset` argument.
datagen = ImageDataGenerator(
    rescale=1.0 / 255,  # multiply every pixel value by 1/255
    rotation_range=40,  # random rotation range, in degrees
    width_shift_range=0.2,  # random horizontal shift
    height_shift_range=0.2,  # random vertical shift
    shear_range=0.2,  # random shear intensity
    zoom_range=0.2,  # random zoom range
    horizontal_flip=True,  # randomly mirror images left-right
    fill_mode="nearest",  # how pixels exposed by a transform are filled in
    validation_split=0.2,  # fraction of images reserved for the "validation" subset
)

In [None]:
# Training batches: augmented 256x256 images with integer ("sparse") class labels,
# drawn from the part of each class folder not reserved for validation.
train_ds = datagen.flow_from_directory(
    DATA_DIR,
    subset="training",
    target_size=(256, 256),
    class_mode="sparse",
)

# Validation batches come from a second generator that only rescales. Drawing
# them from `datagen` would also rotate/shift/shear/zoom/flip the validation
# images, so validation metrics would be measured on randomly distorted inputs.
# validation_split must equal the value given to `datagen` so both generators
# reserve the same files for validation.
val_datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)
val_ds = val_datagen.flow_from_directory(
    DATA_DIR,
    subset="validation",
    target_size=(256, 256),
    class_mode="sparse",
)

Found 3920 images belonging to 21 classes.
Found 973 images belonging to 21 classes.


In [None]:
# Small CNN built with the Keras functional API:
# 256x256x3 input -> conv(16) -> max-pool -> conv(32) -> flatten -> dense(64)
# -> 21-way softmax.
inputs = tf.keras.Input(shape=(256, 256, 3))

# Thread the input through the hidden layers in order.
x = inputs
for layer in (
    tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation="relu"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation="relu"),
):
    x = layer(x)

# One softmax unit per class.
output = tf.keras.layers.Dense(21, activation="softmax")(x)

model = tf.keras.Model(inputs=inputs, outputs=output)

In [None]:
# Adam optimizer with sparse categorical cross-entropy: labels are integer class
# ids (class_mode="sparse" in flow_from_directory) and the model outputs softmax
# probabilities, so the loss keeps its default of not expecting logits.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["sparse_categorical_accuracy"],
)

In [None]:
# Train for 100 epochs on train_ds, evaluating on val_ds at the end of each epoch.
model.fit(train_ds, validation_data=val_ds, epochs=100)

Epoch 1/100
  7/123 [>.............................] - ETA: 46:59 - loss: 61.6138 - accuracy: 0.1161

KeyboardInterrupt: ignored