## Initial attempt at setting up a simple image classification model using embeddings and KNN

- Initially based on: https://github.com/rom1504/image_embeddings

In [None]:
!pip install efficientnet
!pip install pyarrow

!rm -rf ../data/embeddings_train
!mkdir ../data/embeddings_train

In [1]:
import os
import warnings

import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.image as img

%matplotlib inline

import tensorflow as tf

#import glob
import tqdm # progress bar

import pyarrow.parquet as pq
import pyarrow as pa

from efficientnet.tfkeras import EfficientNetB0

#warnings.filterwarnings("ignore", category=DeprecationWarning)
#warnings.filterwarnings("ignore", category=UserWarning)
#warnings.filterwarnings("ignore", category=FutureWarning)

# NOTE(review): both folders are absolute, machine-specific paths and both end
# with a trailing '/'.
# Folder holding the sharded TFRecord files; passed to run_inference as the input folder.
base_path = '/Users/ryan/neue_fische/human-protein-atlas-image-classification/data/images_train_tfrec/'
# Folder the embedding parquet files are written to and later read back from.
embed_path = '/Users/ryan/neue_fische/human-protein-atlas-image-classification/data/embeddings_train/'
# Presumably the per-colour-channel image file suffixes (yellow is commented out).
# NOTE(review): not referenced in the visible part of this notebook - confirm it is still needed.
str_list = ['_red.png', '_blue.png', '_green.png']# ,'_yellow.png']

In [2]:
# Global seaborn appearance for the plots in this notebook.
sns.set_context(context="notebook")  # alternatives: paper, talk, poster
sns.set_style(style="darkgrid")

## Load the data, run inference to get embeddings and save as parquet table files

- Loads the TFRecord files that were split into shards and saved by the notebook "Convert_train_image_labels_to_tfrec.ipynb".
- Doing this without shards didn't work: we quickly ran out of memory.

In [3]:
def parse_tfrecord_fn(example):
    """Deserialize one TFRecord example and decode its PNG image.

    Parameters
    ----------
    example
        A serialized example, handed to ``tf.io.parse_single_example``.

    Returns
    -------
    dict
        The parsed features ``"image"``, ``"path"`` and ``"target_id"``, where
        ``"image"`` has been replaced by the PNG decoded with 3 channels.
    """
    # All three stored features are scalar strings.
    schema = {
        name: tf.io.FixedLenFeature([], tf.string)
        for name in ("image", "path", "target_id")
    }
    parsed = tf.io.parse_single_example(example, schema)
    parsed["image"] = tf.io.decode_png(parsed["image"], channels=3)
    return parsed

def save_embeddings_ds_to_parquet(embeddings, dataset, path):
    """Write embeddings together with their per-image metadata to a parquet file.

    The written table has the columns ``target_id``, ``image_path``,
    ``embedding`` and ``image_name`` (the last path component of
    ``image_path`` up to its first ``'.'``, stored UTF-8 encoded).

    Parameters
    ----------
    embeddings
        Object with a ``.tolist()`` method yielding one list of floats per
        dataset element (stored as a float32 list column).
    dataset
        Dataset whose elements are ``(image_raw, target_id, image_path)``;
        must support ``.map(...)`` and ``.as_numpy_iterator()``.
    path
        Destination of the parquet file.
    """
    embedding_column = pa.array(embeddings.tolist(), type=pa.list_(pa.float32()))

    # Collect both metadata columns in a single pass over the dataset. Every
    # pass re-runs the whole upstream pipeline, so one traversal instead of two
    # avoids redundant work and keeps each target_id paired with its path.
    metadata = dataset.map(lambda image_raw, target_id, image_path: (target_id, image_path))
    target_ids, image_paths = [], []
    for target_id, image_path in metadata.as_numpy_iterator():
        target_ids.append(target_id)
        image_paths.append(image_path)

    table = pa.Table.from_arrays(
        [pa.array(target_ids), pa.array(image_paths), embedding_column],
        ["target_id", "image_path", "embedding"],
    )

    # Roundabout way to get the image name without folder and extension:
    # go through pandas to use its vectorised string methods.
    df_out = table.to_pandas()
    df_out['image_name'] = df_out['image_path'].str.decode('UTF-8').str.split('/').str[-1].str.split('.').str[0]
    df_out['image_name'] = df_out['image_name'].str.encode('UTF-8')

    # Back to an Arrow table for saving.
    table = pa.Table.from_pandas(df_out)
    pq.write_table(table, path)
    
def images_to_embeddings(model, dataset, batch_size):
    """Return the model's predictions for the images contained in ``dataset``.

    Parameters
    ----------
    model
        Object exposing ``predict(data, verbose=...)``.
    dataset
        Dataset whose elements are ``(image_raw, target_id, image_path)``.
    batch_size
        Number of elements grouped into each batch before prediction.

    Returns
    -------
    Whatever ``model.predict`` returns for the batched images.
    """
    def keep_image(image_raw, target_id, image_path):
        # Only the image is fed to the model; the metadata is dropped.
        return image_raw

    batched_images = dataset.batch(batch_size).map(keep_image)
    return model.predict(batched_images, verbose=1)

def preprocess_image(d):
    """Turn a parsed example into an ``(image, target_id, path)`` triple.

    Parameters
    ----------
    d
        Mapping with the keys ``'image'``, ``'target_id'`` and ``'path'``.

    Returns
    -------
    tuple
        The image converted to ``tf.float32`` via
        ``tf.image.convert_image_dtype``, followed by the unchanged
        ``target_id`` and ``path`` entries.
    """
    image_raw = tf.image.convert_image_dtype(d['image'], tf.float32)
    return image_raw, d['target_id'], d['path']

def get_dataset(file):
    """Build the parsed and preprocessed dataset for a TFRecord file.

    Parameters
    ----------
    file
        TFRecord file name(s), passed straight to ``tf.data.TFRecordDataset``.

    Returns
    -------
    A dataset in which every record went through ``parse_tfrecord_fn`` and
    then ``preprocess_image``, with ``tf.data.experimental.ignore_errors()``
    applied at the end.
    """
    autotune = tf.data.experimental.AUTOTUNE

    records = tf.data.TFRecordDataset(file)
    # Run every record through the two mapping functions, in this order.
    parsed = records.map(parse_tfrecord_fn, num_parallel_calls=autotune)
    preprocessed = parsed.map(preprocess_image, num_parallel_calls=autotune)
    return preprocessed.apply(tf.data.experimental.ignore_errors())

def tfrecords_to_write_embeddings(tfrecords_folder, output_folder, model, batch_size):
    """Compute and save embeddings for every TFRecord shard in a folder.

    Each file matching ``train_images*.tfrec`` in ``tfrecords_folder`` is
    processed on its own and written to
    ``train_embeddings_<shard_id>of<n_shards>.parquet`` in ``output_folder``.

    Parameters
    ----------
    tfrecords_folder
        Folder containing the TFRecord shards (with or without trailing '/').
    output_folder
        Folder the parquet files are written to (with or without trailing '/').
    model
        Model forwarded to ``images_to_embeddings``.
    batch_size
        Batch size forwarded to ``images_to_embeddings``.
    """
    # os.path.join instead of plain string concatenation, so the folders no
    # longer need to end with a path separator.
    pattern = os.path.join(tfrecords_folder, "train_images*.tfrec")
    tfrecords = [
        f.numpy().decode("utf-8")
        for f in tf.data.Dataset.list_files(pattern, shuffle=False)
    ]
    n_shards = len(tfrecords)

    for shard_id, tfrecord in enumerate(tfrecords):
        parsed_dataset = get_dataset(tfrecord)

        embeddings = images_to_embeddings(model, parsed_dataset, batch_size)
        print("")
        print("Shard " + str(shard_id))
        current_shard_name = os.path.join(
            output_folder,
            "{}_{}of{}.parquet".format('train_embeddings', shard_id, n_shards),
        )
        print(current_shard_name)
        save_embeddings_ds_to_parquet(embeddings, parsed_dataset, current_shard_name)

    
def run_inference(tfrecords_folder, output_folder, batch_size=100):
    """Create the embedding model and write embeddings for all TFRecord shards.

    The model is an ``EfficientNetB0`` with ``'imagenet'`` weights, built
    without its top and with ``"avg"`` pooling.

    Parameters
    ----------
    tfrecords_folder
        Folder with the input TFRecord shards.
    output_folder
        Folder the embedding files are written to.
    batch_size
        Batch size used for prediction (default 100).
    """
    embedding_model = EfficientNetB0(
        include_top=False,
        weights='imagenet',
        pooling="avg",
    )
    tfrecords_to_write_embeddings(
        tfrecords_folder,
        output_folder,
        embedding_model,
        batch_size,
    )


In [5]:
# Compute the embeddings only when none have been saved yet.
# Checking for parquet files (rather than only for the folder) also covers the
# case where the folder already exists but is empty, e.g. right after it was
# created with `mkdir`; a bare existence check would skip inference then.
os.makedirs(embed_path, exist_ok=True)
if not any(name.endswith('.parquet') for name in os.listdir(embed_path)):
    run_inference(base_path, embed_path, 10)


2022-06-30 13:17:55.785607: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-06-30 13:17:55.785820: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Metal device set to: Apple M1

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB



2022-06-30 13:17:56.898588: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2022-06-30 13:17:57.479492: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.



Shard 0
/Users/ryan/neue_fische/human-protein-atlas-image-classification/data/embeddings_train/train_embeddings_0of32.parquet

Shard 1
/Users/ryan/neue_fische/human-protein-atlas-image-classification/data/embeddings_train/train_embeddings_1of32.parquet

Shard 2
/Users/ryan/neue_fische/human-protein-atlas-image-classification/data/embeddings_train/train_embeddings_2of32.parquet

Shard 3
/Users/ryan/neue_fische/human-protein-atlas-image-classification/data/embeddings_train/train_embeddings_3of32.parquet

Shard 4
/Users/ryan/neue_fische/human-protein-atlas-image-classification/data/embeddings_train/train_embeddings_4of32.parquet

Shard 5
/Users/ryan/neue_fische/human-protein-atlas-image-classification/data/embeddings_train/train_embeddings_5of32.parquet

Shard 6
/Users/ryan/neue_fische/human-protein-atlas-image-classification/data/embeddings_train/train_embeddings_6of32.parquet

Shard 7
/Users/ryan/neue_fische/human-protein-atlas-image-classification/data/embeddings_train/train_embedding

In [6]:
# Read the parquet files found in embed_path into one DataFrame, then decode
# the byte-string columns so they hold regular Python strings.
emb = pq.read_table(embed_path).to_pandas()
bytes_columns = ['target_id', 'image_path', 'image_name']
emb = emb.assign(**{name: emb[name].str.decode('utf-8') for name in bytes_columns})
emb

Unnamed: 0,target_id,image_path,embedding,image_name
0,7,../data/train/4b0d7acc-bbb5-11e8-b2ba-ac1f6b64...,"[-0.14782768, -0.19346946, 0.029495712, -0.114...",4b0d7acc-bbb5-11e8-b2ba-ac1f6b6435d0
1,5 0,../data/train/4b0e4648-bbc2-11e8-b2bb-ac1f6b64...,"[-0.11768038, 0.108380914, -0.07158355, -0.051...",4b0e4648-bbc2-11e8-b2bb-ac1f6b6435d0
2,23,../data/train/4b0fe352-bbbf-11e8-b2ba-ac1f6b64...,"[-0.14264518, -0.14894637, 0.015262008, -0.138...",4b0fe352-bbbf-11e8-b2ba-ac1f6b6435d0
3,2,../data/train/4b1164e4-bbaf-11e8-b2ba-ac1f6b64...,"[-0.14977421, -0.12706101, -0.17702478, -0.087...",4b1164e4-bbaf-11e8-b2ba-ac1f6b6435d0
4,25,../data/train/4b120c9e-bbb1-11e8-b2ba-ac1f6b64...,"[-0.15020615, -0.008659467, -0.14262204, -0.04...",4b120c9e-bbb1-11e8-b2ba-ac1f6b6435d0
...,...,...,...,...
31067,7,../data/train/9d04d730-bbb5-11e8-b2ba-ac1f6b64...,"[-0.095827445, -0.14014255, -0.1306586, -0.081...",9d04d730-bbb5-11e8-b2ba-ac1f6b6435d0
31068,4,../data/train/9d09c7e0-bb9c-11e8-b2b9-ac1f6b64...,"[-0.109680824, -0.07402489, -0.08704822, -0.08...",9d09c7e0-bb9c-11e8-b2b9-ac1f6b6435d0
31069,0,../data/train/9d0a7012-bbc6-11e8-b2bc-ac1f6b64...,"[-0.14760047, -0.105946906, 0.17695697, -0.130...",9d0a7012-bbc6-11e8-b2bc-ac1f6b6435d0
31070,25 0,../data/train/9d10ecec-bba0-11e8-b2b9-ac1f6b64...,"[-0.12235144, -0.12698074, -0.052961998, -0.10...",9d10ecec-bba0-11e8-b2b9-ac1f6b6435d0
