# Parallel Inference

Reproduces the techniques used in the PyTorch versions of this example. The image metadata is loaded from Snowflake here, but it could just as easily be loaded from S3.

In [1]:
import numpy as np, pandas as pd
import requests, io, os, datetime, re

import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

from dask_saturn import SaturnCluster
from dask.distributed import Client
import dask

  warn_incompatible_dep(


In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.applications import resnet50
from keras.preprocessing import image

In [4]:
# Keyword arguments for snowflake.connector.connect. The user name and
# password are read from environment variables; the rest are fixed settings.
conn_kwargs = {
    "user": os.environ['ANALYTICS_SNOWFLAKE_USER'],
    "password": os.environ['ANALYTICS_SNOWFLAKE_PASSWORD'],
    "account": 'mf80263.us-east-2.aws',
    "warehouse": "COMPUTE_WH",
    "database": "clothing_dataset",
    "schema": "PUBLIC",
    "role": "datascience_examples_writer",
}

In [5]:
stage = 'clothing_dataset'
relative_path_col = 'RELATIVE_PATH'

# Build the statement first so the connection block only runs it. The
# selected column order (FILE_URL, RELATIVE_PATH, SIZE, LAST_MODIFIED,
# SIGNEDURL) is what the positional indexing in preprocess relies on.
query = f"""select FILE_URL, 
    RELATIVE_PATH, SIZE, LAST_MODIFIED,
    get_presigned_url(@{stage}, {relative_path_col}) 
    as SIGNEDURL from clothing_test"""

# The context manager closes the connection once the frame has been read.
with snowflake.connector.connect(**conn_kwargs) as conn:
    df = pd.read_sql(query, conn)

list_paths = df['SIGNEDURL']

In [6]:
@dask.delayed
def preprocess(list_img_attr):
    """Download one image and turn it into a 224x224 array for inference.

    Parameters
    ----------
    list_img_attr : sequence
        One row of image metadata, indexed by position: ``[0]`` is kept as
        the Snowflake path, ``[2]`` the file size, ``[3]`` the original
        timestamp and ``[4]`` the URL the image is downloaded from. This is
        expected to follow the column order of the query that builds ``df``.

    Returns
    -------
    list
        ``[name, img_array, truth, path, snow_path, filesize, orig_timestamp]``
        where ``truth`` (the class directory) and ``name`` (the ``.jpg`` file
        name) are both parsed out of the URL.

    Raises
    ------
    ValueError
        If the URL does not contain ``clothing-dataset-small/test/<class>/<file>.jpg``.
    requests.HTTPError
        If the download returns an HTTP error status.
    """
    snow_path, filesize, orig_timestamp, path = (
        list_img_attr[0], list_img_attr[2], list_img_attr[3], list_img_attr[4])

    # Parse the label and file name once, before downloading, so a malformed
    # URL fails fast with a clear message instead of an AttributeError on None.
    # A raw string avoids the invalid "\/" and "\." escape sequences.
    found = re.search(r'clothing-dataset-small/test/([a-z-]+)/([^/]+(\.jpg))', path)
    if found is None:
        raise ValueError(f"Cannot parse class label and file name from URL: {path!r}")
    truth, name = found.group(1), found.group(2)

    response = requests.get(path)
    # Surface HTTP failures (e.g. an expired signed URL) here rather than as
    # a confusing JPEG decoding error on the error-page body.
    response.raise_for_status()

    decoded = tf.image.decode_jpeg(response.content, channels=3)
    resized = tf.image.resize(decoded, (224, 224))
    img_array = image.img_to_array(resized)

    return [name, img_array, truth, path, snow_path, filesize, orig_timestamp]

In [7]:
@dask.delayed
def reformat(batch):
    """Turn a batch of per-image records into per-field columns.

    ``batch`` is a sequence of equally laid-out records. The result is a list
    with one entry per field, where entry 1 (the image arrays) is stacked
    along a new leading axis into a single tensor and every other entry is a
    plain list.
    """
    columns = [list(field) for field in zip(*batch)]
    columns[1] = tf.stack(columns[1], axis=0, name='stack')
    return columns

In [8]:
def is_match(label, pred):
    """Evaluate a human-readable prediction against the ground truth.

    Underscores are treated as spaces on both sides. The label must occur in
    the prediction as a whole token: it may not be directly preceded or
    followed by a word character or a hyphen. This keeps ``shirt`` from
    matching inside ``t-shirt`` while still finding a label inside a longer,
    space-separated prediction string.

    Parameters
    ----------
    label : str
        Ground-truth class name. It is matched literally, never as a regex.
    pred : object
        Prediction; converted with ``str()`` before matching.

    Returns
    -------
    bool
        True when the label is found in the prediction, otherwise False.
    """
    # Escape the label so characters such as "+" or "." are matched literally.
    needle = re.escape(label.replace('_', ' '))
    haystack = str(pred).replace('_', ' ')
    return re.search(rf'(?<![\w-]){needle}(?![\w-])', haystack) is not None

In [9]:
@dask.delayed
def predict_class_resnet(iteritem):
    """Classify one reformatted batch and report one outcome per image.

    ``iteritem`` is unpacked into seven parallel fields: names, the image
    batch, ground-truth labels, paths, Snowflake paths, file sizes and
    original timestamps.

    Returns a list of dicts with the keys ``name``, ``ground_truth``,
    ``prediction`` (class index), ``pred_text`` (class name), ``match``,
    ``snow_path``, ``filesize`` and ``orig_timestamp``.
    """
    names, images, truelabels, paths, snow_paths, filesizes, orig_timestamps = iteritem

    # Using model trained by user.
    model = keras.models.load_model('./tensorflow_ds/model/keras/')

    # Map class index -> class name.
    classes = ["dress", "hat", "longsleeve", "outwear", "pants", "shirt", "shoes", "shorts", "skirt", "t-shirt"]
    classes2 = dict(enumerate(classes))

    predictions = model.predict(resnet50.preprocess_input(images))
    # Highest-scoring class index for each row, then its readable name.
    predicted_classes = list(map(np.argmax, predictions))
    pred_string = list(map(classes2.__getitem__, predicted_classes))

    def outcome_for(j):
        """Assemble the result record for the j-th image of the batch."""
        match = is_match(truelabels[j], pred_string[j])
        return {
            'name': names[j],
            'ground_truth': truelabels[j],
            'prediction': predicted_classes[j],
            "pred_text": pred_string[j],
            "match": match,
            'snow_path': snow_paths[j],
            'filesize': filesizes[j],
            'orig_timestamp': orig_timestamps[j],
        }

    # Organize prediction results
    return [outcome_for(j) for j in range(len(images))]

In [11]:
# Create the Saturn cluster handle and connect a Dask client to it.
cluster = SaturnCluster()
client = Client(cluster)
# Block until at least two workers have joined before scheduling any work.
client.wait_for_workers(2)
# Bare last expression: shows the client summary as the cell output.
client

INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins
INFO:dask-saturn:Success!


0,1
Client  Scheduler: tcp://d-steph-tensorflow-dev-c684c6ee502c425a8142c741e3f0af9a.main-namespace:8786  Dashboard: https://d-steph-tensorflow-dev-c684c6ee502c425a8142c741e3f0af9a.internal.saturnenterprise.io,Cluster  Workers: 3  Cores: 12  Memory: 43.31 GiB


In [12]:
client.restart()  # Clears memory on cluster- optional but recommended.

n = 32  # batch size

# Cut the metadata frame into consecutive chunks of at most n rows.
list_df = [df[start:start + n] for start in range(0, df.shape[0], n)]
# Materialise every chunk as a list of its row Series.
image_rows = [[row for _, row in chunk.iterrows()] for chunk in list_df]
# One delayed preprocess call per image, grouped by batch.
image_batches1 = [[preprocess(list(row)) for row in rows] for rows in image_rows]
# One delayed reformat call per batch of preprocessed images.
image_batches = [reformat(batch) for batch in image_batches1]

In [13]:
# Make the saved Keras model directory available on the workers;
# predict_class_resnet loads it from './tensorflow_ds/model/keras/'.
from dask_saturn.plugins import RegisterFiles, sync_files
client.register_worker_plugin(RegisterFiles())
sync_files(client, "/home/jovyan/git-repos/tensorflow_ds/model/keras")
# Sanity check: list the model directory as seen by each worker.
client.run(os.listdir, './tensorflow_ds/model/keras')

{'tcp://192.168.17.4:36661': ['assets', 'variables', 'saved_model.pb'],
 'tcp://192.168.231.132:36789': ['assets', 'variables', 'saved_model.pb'],
 'tcp://192.168.243.4:38861': ['assets', 'variables', 'saved_model.pb']}

In [None]:
%%time

# Submit one task per batch. predict_class_resnet is wrapped in dask.delayed,
# so each gathered item is presumably still a lazy Delayed object that
# client.compute then schedules -- NOTE(review): confirm against dask docs.
futures = client.map(predict_class_resnet, image_batches) 
futures_gathered = client.gather(futures)
futures_computed = client.compute(futures_gathered, sync=False)

import logging

# Collect per-image outcomes batch by batch. A failing batch is logged and
# recorded in `errors` instead of aborting the whole run.
results = []
errors = []
for fut in futures_computed:
    try:
        result = fut.result()
    except Exception as e:
        errors.append(e)
        logging.error(e)
    else:
        # Each successful batch contributes a list of outcome dicts.
        results.extend(result)

In [None]:
# Tabulate the collected outcome records (one row per entry in `results`).
df2 = pd.DataFrame(results) 
# Show the inferred column dtypes as the cell output.
df2.dtypes

In [None]:
# Split the outcomes by whether the prediction matched the ground truth.
true_preds = [x['match'] for x in results if x['match']]
false_preds = [x['match'] for x in results if not x['match']]
# Accuracy in percent. `results` is empty when every batch failed, so report
# NaN in that case instead of raising ZeroDivisionError.
len(true_preds)/len(results)*100 if results else float('nan')

In [None]:
# Evaluate every reformatted batch and bring the results into this process.
sample = dask.compute(*image_batches)
# Transpose: one list per field, each holding that field's per-batch values.
s5 = [list(column) for column in zip(*sample)]

# Flatten the per-batch values of fields 0, 1 and 2 into per-image lists.
test_names = [name for names in s5[0] for name in names]
test_tensors = [tensor for stacked in s5[1] for tensor in stacked]
test_orig = [label for labels in s5[2] for label in labels]
# One (name, image tensor, ground truth) triple per image.
test_final = list(zip(test_names, test_tensors, test_orig))

In [None]:
# Pair every prediction outcome with the image entries that share its name.
# Group the image entries by name once, so the join is a dictionary lookup
# rather than a scan of every entry (plus two list-membership scans) for
# every result. Pair order is unchanged: results order first, then
# test_final order.
final_by_name = {}
for entry in test_final:
    final_by_name.setdefault(entry[0], []).append(entry)

expanded_list = [(outcome, entry)
                 for outcome in results
                 for entry in final_by_name.get(outcome['name'], [])]

In [None]:
import matplotlib.pyplot as plt

# Up to five consecutive (outcome, (name, image tensor, ground truth)) pairs.
imglist = expanded_list[325:330]
f, ax = plt.subplots(nrows=1, ncols=5, figsize=(16,6))

# zip stops at the shorter input, so a short slice cannot raise IndexError.
for axis, (outcome, (_, tensor, _)) in zip(ax, imglist):
    # imshow reads RGB data as floats in [0, 1] or integers in [0, 255].
    # The resized image is presumably float-valued on a 0-255 scale, so
    # convert it to uint8 for display.
    img1 = np.clip(np.asarray(tensor), 0, 255).astype("uint8")
    axis.imshow(img1)
    axis.xaxis.set_visible(False)
    axis.yaxis.set_visible(False)
    # Outcome dicts store the comparison under "match" and the readable
    # class name under "pred_text".
    textcol = "green" if outcome["match"] else "red"
    axis.set_title(
        f'''Predicted Class: {outcome["pred_text"]} 
Actual Class: {outcome["ground_truth"]} ''',
        color=textcol)

# Blank any panels left unused when fewer than five images were available.
for axis in ax[len(imglist):]:
    axis.axis("off")

title = 'Sample Images'
f.suptitle(title, fontsize=16)
plt.tight_layout()
plt.show()