In [8]:
import PIL
import shutil
import os
import re
import warnings
from time import time

warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TFHUB_CACHE_DIR'] = '/home/x/tfhub_modules'  # keep everything in here instead of in /tmp

from numpy import r_, array
import tensorflow as tf
from PIL import Image
from pathlib import Path
import numpy as np
from glob import glob
import pandas as pd
from tqdm import tqdm

[tf.config.experimental.set_memory_growth(d, 1) for d in tf.config.list_physical_devices('GPU')]

TIMES = []


def guess(cache: str, model_path: str, input_size: tuple[int, ...], image_path: str) -> list:
    """
    Crude guess function so that you don't have to input a label filename for each model.

    :param cache: tfhub_modules cache dir
    :param model_path: path to model
    :param input_size: image dim that model expects
    :param image_path: path to input image
    :return: list of class labels
    """
    img = r_[[array(Image.open(glob(f'{image_path}/*')[0]).resize(input_size)) / 255.]].astype(np.float32)
    intp = tf.lite.Interpreter(model_path=f'{cache}/{model_path}')
    intp.allocate_tensors()
    intp.set_tensor(intp.get_input_details()[0]['index'], img)
    intp.invoke()
    tensor = intp.get_tensor(intp.get_output_details()[0]['index'])
    cls = ...
    if tensor.shape[1] == 1000:
        cls = Path('labels/ilsvrc2012_wordnet_lemmas.txt').read_text().splitlines()
    elif tensor.shape[1] == 1001:
        cls = Path('labels/ImageNetLabels.txt').read_text().splitlines()
    elif tensor.shape[1] == 21_843:
        cls = Path('labels/imagenet21k_wordnet_lemmas.txt').read_text().splitlines()
    return cls


def run(cache: str, model_path: str, input_size: tuple[int, ...], labels: list, image_path: str):
    image_paths = []
    preds = []
    images = list((lambda p, s: filter(re.compile(p).match, s))(r'.*\.(jpg|png|jpeg)', glob(f'{image_path}/*')))
    # print('# images = ', len(images))
    for image_path in images[:10]:
        try:
            img = r_[[array(Image.open(image_path).resize(input_size)) / 255.]].astype(np.float32)
            intp = tf.lite.Interpreter(model_path=f'{cache}/{model_path}')
            intp.allocate_tensors()
            intp.set_tensor(intp.get_input_details()[0]['index'], img)
            intp.invoke()
            tensor = intp.get_tensor(intp.get_output_details()[0]['index'])
            pred = labels[tf.math.argmax(tensor, axis=-1)[0]]
            preds.append(pred)
            image_paths.append(Path(image_path).name)

        except (OSError, FileNotFoundError, PIL.UnidentifiedImageError) as e:
            print('\n', e)
    return dict(zip(image_paths, preds))


def infer(model_path: str, input_size: tuple[int, ...], image_path: str) -> dict:
    """
    Predict class labels for input images

    :param model_path: path to model
    :param input_size: image dimensions for 3 channel image
    :param image_path: path to image
    :return:

    *Class labels are notoriously linked incorrectly in tensorflow hub. Ensure correct labels are downloaded.
    ILSVRC 2012 - 1,000  classes - https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt
    ImageNet1K  - 1,001  classes - https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt
    ImageNet21k - 21,843 classes - https://storage.googleapis.com/bit_models/imagenet21k_wordnet_lemmas.txt

    """
    start = time()
    # print(f'\n{"-" * len(model_path)}\n{model_path}\n{"-" * len(model_path)}\n')
    cache = 'models'
    labels = guess(cache, model_path, input_size, image_path)
    preds = run(cache, model_path, input_size, labels, image_path)
    end = time() - start
    # print(f'{(end):.3f}')
    TIMES.append([model_path, end])
    return preds


def undo_labeling(img_directory: str) -> None:
    """
    Undo labeling performed by `label.py`

    :param img_directory: path to image directory that previously had it's contents labeled
    :return: None
    """
    # move images out of directories
    for d in [f for f in Path(img_directory).iterdir() if f.is_dir()]:
        for img in d.iterdir():
            new = d.parent / img.name
            print('*', img, '->', new)
            Path(img).rename(new)
    # remove directories
    [shutil.rmtree(p) for p in Path(img_directory).iterdir() if p.is_dir()]

In [12]:
models = [
    ('nasnet_mobile_classification_5_qt_16x8.tflite', (224, 224)),
    ('nasnet_mobile_classification_5_qt_float16.tflite', (224, 224)),
    ('mobilenet_v1_0.25_224.tflite',(224,224)),
]

### Loop through models

In [13]:
preds = []
for name, shape in tqdm(models, total=len(models)):
    preds.append(infer(name, shape, 'images'))
df = pd.DataFrame(preds)

keywords = Path('keywords.txt').read_text().splitlines()

images_and_classes = [[col, k] for k in keywords for col in df.columns if k in df[col].to_numpy()]
print('found classes:', {x[1] for x in images_and_classes})

100%|██████████| 3/3 [00:10<00:00,  3.57s/it]

found classes: set()





In [14]:
df.T.rename({i: name for i, name in enumerate(models)}, axis=1)

Unnamed: 0,"(nasnet_mobile_classification_5_qt_16x8.tflite, (224, 224))","(nasnet_mobile_classification_5_qt_float16.tflite, (224, 224))","(mobilenet_v1_0.25_224.tflite, (224, 224))"
1480305774.jpg,stage,stage,hammerhead
1480385041.jpg,pot,pot,sea anemone
1480813997.jpg,traffic light,traffic light,flagpole
1482075542.jpg,Siamese cat,Siamese cat,Siamese cat
1482180008.jpg,pot,pot,jellyfish
1482256683.jpg,abaya,abaya,hoopskirt
1482382512.jpg,sliding door,sliding door,sliding door
1482984597.jpg,Labrador retriever,Labrador retriever,potter's wheel
1485405339.jpg,sleeping bag,sleeping bag,oxygen mask
1487219423.jpg,stage,stage,bathing cap


In [15]:
times = pd.DataFrame(TIMES, columns=['model', 't']).sort_values('t').reset_index(drop=True)
times

Unnamed: 0,model,t
0,mobilenet_v1_0.25_224.tflite,0.50166
1,mobilenet_v1_0.25_224.tflite,0.562329
2,nasnet_mobile_classification_5_qt_float16.tflite,1.905808
3,nasnet_mobile_classification_5_qt_float16.tflite,2.347075
4,nasnet_mobile_classification_5_qt_16x8.tflite,7.747859
5,nasnet_mobile_classification_5_qt_16x8.tflite,8.244617


### Create class dirs

In [16]:
existing_classes = {f.name for f in Path(f'images').iterdir() if f.is_dir()}
matched_classes = set(cls for _, cls in images_and_classes)
classes = matched_classes - existing_classes

print(f'creating new class dirs: {classes}')
for cls in classes:
    Path(f'images/{cls}').mkdir(parents=True, exist_ok=False)

for img, cls in images_and_classes:
    try:
        o = f'images/{img}'
        n = f'images/{cls}/{img}'
        print(f'{o} -> {n}')
        Path(o).rename(n)
    except Exception as e:
        print(e)

creating new class dirs: set()


### Undo Labeling if needed

In [None]:
# undo_labeling('images')