# Parallel Inference

Reproduces techniques used in pytorch versions. Loads from Snowflake here, but can easily just load from S3. 

Using model trained on multigpu script at batch size 64 x 40 epochs.

In [1]:
import numpy as np, pandas as pd
import requests, io, os, datetime, re

import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

from dask_saturn import SaturnCluster
from dask.distributed import Client
import dask

  warn_incompatible_dep(


In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.applications import resnet50
from keras.preprocessing import image

In [3]:
conn_kwargs = dict(
    user=os.environ['ANALYTICS_SNOWFLAKE_USER'] ,
    password=os.environ['ANALYTICS_SNOWFLAKE_PASSWORD'] ,
    account='mf80263.us-east-2.aws',
    warehouse="COMPUTE_WH",
    database="clothing_dataset",
    schema="PUBLIC",
    role="datascience_examples_writer",
)

In [4]:
stage = 'clothing_dataset'
relative_path_col = 'RELATIVE_PATH'

with snowflake.connector.connect(**conn_kwargs) as conn:
    df = pd.read_sql(f"""select FILE_URL, 
    RELATIVE_PATH, SIZE, LAST_MODIFIED,
    get_presigned_url(@{stage}, {relative_path_col}) 
    as SIGNEDURL from clothing_test""", conn)
    list_paths = df['SIGNEDURL']

In [5]:
@dask.delayed
def preprocess(list_img_attr):
    path, snow_path, filesize, orig_timestamp  = list_img_attr[4],list_img_attr[0],list_img_attr[2],list_img_attr[3]

    file1 = requests.get(path).content
    file2 = tf.image.decode_jpeg(file1, channels=3)
    img2 = tf.image.resize(file2,(224, 224))
    img_array = image.img_to_array(img2)
      
    truth = re.search('clothing-dataset-small/test/([a-z-]+)\/([^\/]+(\.jpg))', path).group(1)
    name = re.search('clothing-dataset-small/test/([a-z-]+)\/([^\/]+(\.jpg))', path).group(2)
    
    return [name, img_array, truth, path, snow_path, filesize, orig_timestamp]

In [6]:
@dask.delayed
def reformat(batch):
    batch_transposed = list(map(list, zip(*batch)))
    batch_transposed[1] = tf.stack(batch_transposed[1], axis=0, name='stack')
    return batch_transposed

In [7]:
def is_match(label, pred):
    ''' Evaluates human readable prediction against ground truth.'''
    if re.search(label.replace('_', ' '), str(pred).replace('_', ' ')):
        match = True
    else:
        match = False
    return(match)

In [8]:
@dask.delayed
def predict_class_resnet(iteritem):
    names, images, truelabels, paths, snow_paths, filesizes, orig_timestamps = iteritem

    # Using model trained by user.
    model = keras.models.load_model('./tensorflow_ds/model/keras/')

    indices = list(range(0,10))
    classes = ["dress", "hat", "longsleeve", "outwear", "pants", "shirt", "shoes", "shorts", "skirt", "t-shirt"]
    classes2 = dict(zip(indices, classes))

    img_array2 = resnet50.preprocess_input(images)
    predictions = model.predict(img_array2)
    predicted_classes = [np.argmax(prediction) for prediction in predictions]#
    pred_string = [classes2[x] for x in predicted_classes]

    #Organize prediction results
    outcomes = []
    for j in range(0, len(images)):
        match = is_match(truelabels[j], pred_string[j])
        outcome = {'name': names[j], 'ground_truth': truelabels[j], 
                   'prediction': predicted_classes[j],  
                   "pred_text": pred_string[j],
                   "match": match,
                  'snow_path':snow_paths[j],'filesize':filesizes[j],
                   'orig_timestamp':orig_timestamps[j]}
        outcomes.append(outcome)

    return(outcomes)

In [9]:
cluster = SaturnCluster()
client = Client(cluster)
client.wait_for_workers(2)
client

INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins
INFO:dask-saturn:Success!


0,1
Client  Scheduler: tcp://d-steph-tensorflow-dev-c684c6ee502c425a8142c741e3f0af9a.main-namespace:8786  Dashboard: https://d-steph-tensorflow-dev-c684c6ee502c425a8142c741e3f0af9a.internal.saturnenterprise.io,Cluster  Workers: 5  Cores: 80  Memory: 295.69 GiB


In [10]:
client.restart() # Clears memory on cluster- optional but recommended.

n = 32  #batch size
list_df = [df[i:i+n] for i in range(0,df.shape[0],n)]
image_rows = [[x for j,x in y.iterrows()] for y in list_df]
image_batches1 = [[preprocess(list(x)) for x in y] for y in image_rows]
image_batches = [reformat(result) for result in image_batches1]

In [11]:
from dask_saturn.plugins import RegisterFiles, sync_files
client.register_worker_plugin(RegisterFiles())
sync_files(client, "/home/jovyan/git-repos/tensorflow_ds/model/keras")
client.run(os.listdir, './tensorflow_ds/model/keras')

{'tcp://192.168.143.4:42289': ['assets', 'variables', 'saved_model.pb'],
 'tcp://192.168.173.4:39083': ['assets', 'variables', 'saved_model.pb'],
 'tcp://192.168.231.196:46519': ['assets', 'variables', 'saved_model.pb'],
 'tcp://192.168.29.4:35909': ['assets', 'variables', 'saved_model.pb'],
 'tcp://192.168.76.68:45899': ['assets', 'variables', 'saved_model.pb']}

In [14]:
%%time

futures = client.map(predict_class_resnet, image_batches) 
futures_gathered = client.gather(futures)
futures_computed = client.compute(futures_gathered, sync=False)

import logging

results = []
errors = []
for fut in futures_computed:
    try:
        result = fut.result()
    except Exception as e:
        errors.append(e)
        logging.error(e)
    else:
        results.extend(result)

CPU times: user 245 ms, sys: 6.07 ms, total: 251 ms
Wall time: 304 ms


In [13]:
df2 = pd.DataFrame(results) 
df2.dtypes

name                           object
ground_truth                   object
prediction                      int64
pred_text                      object
match                            bool
snow_path                      object
filesize                        int64
orig_timestamp    datetime64[ns, UTC]
dtype: object

In [15]:
true_preds = [x['match'] for x in results if x['match'] == True]
false_preds = [x['match'] for x in results if x['match'] == False]
len(true_preds)/len(results)*100

46.774193548387096

In [16]:
df = pd.DataFrame(results) 

In [17]:
df.head()

Unnamed: 0,name,ground_truth,prediction,pred_text,match,snow_path,filesize,orig_timestamp
0,06a00c0f-5f9a-410d-a7da-3881a9df3a71.jpg,dress,9,t-shirt,False,https://MF80263.us-east-2.aws.snowflakecomputi...,34678,2021-07-09 16:36:11+00:00
1,28b09463-6bbb-491d-9ffc-f36df5c6b211.jpg,dress,3,outwear,False,https://MF80263.us-east-2.aws.snowflakecomputi...,28462,2021-07-09 16:36:10+00:00
2,35f157d0-53e4-4496-b087-da4ad63edd47.jpg,dress,4,pants,False,https://MF80263.us-east-2.aws.snowflakecomputi...,18871,2021-07-09 16:36:09+00:00
3,3f844e1e-4a00-4b64-8c1d-3b847191bf11.jpg,dress,3,outwear,False,https://MF80263.us-east-2.aws.snowflakecomputi...,45101,2021-07-09 16:36:13+00:00
4,4ceed2f1-8e20-4439-9c27-cceb8d2257a4.jpg,dress,1,hat,False,https://MF80263.us-east-2.aws.snowflakecomputi...,30272,2021-07-09 16:36:12+00:00


In [18]:
pd.crosstab(df.ground_truth, df.pred_text)

pred_text,dress,hat,longsleeve,outwear,pants,shirt,shoes,shorts,t-shirt
ground_truth,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
dress,2,2,0,6,1,2,0,1,1
hat,0,2,0,3,0,3,1,2,1
longsleeve,0,3,19,20,0,23,2,1,4
outwear,0,0,4,16,0,16,0,2,0
pants,0,0,0,5,30,0,1,6,0
shirt,0,0,1,3,0,20,0,1,1
shoes,0,5,3,16,0,3,36,10,0
shorts,0,1,1,2,0,1,0,23,2
skirt,0,2,3,0,0,0,1,5,1
t-shirt,0,3,0,10,0,8,3,3,25
