# Using GPU Clusters for Deep Learning Inference

This demo shows new users how to use Dask and GPUs on Saturn Cloud to run fast deep learning inference with PyTorch-based machine learning models. For this example, we'll use an image classification project that identifies dog breeds.


In [None]:
from dask.distributed import Client, wait, progress
import time
import dask
from dask import persist, delayed, compute
import dask_saturn
from dask_saturn import SaturnCluster
import dask.dataframe as dd
import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

import toolz
import glob
import re
import os
import s3fs
import graphviz

import torch
from torchvision import datasets, transforms, models
from torch import nn, optim
import torch.nn.functional as F

## Setup

First things first: we need to set up a GPU cluster and confirm that all resources are ready.

In [None]:
# Obtain the Saturn Cloud Dask cluster handle and attach a distributed client to it.
cluster = SaturnCluster()
client = Client(cluster)
# Block until at least 3 workers have joined the cluster before continuing.
client.wait_for_workers(3)
# Leaving `client` as the last expression makes the notebook display it.
client

The next two cells confirm that the cluster is operable and has GPU capability: the first checks the client process, the second runs the same check on the workers.

In [None]:
# Check whether PyTorch can see a CUDA device in this (client) process.
torch.cuda.is_available() 

In [None]:
# Run the same CUDA availability check on the cluster's workers via the client.
client.run(lambda: torch.cuda.is_available())

## Large Dataset Example

In [None]:
# Anonymous (unauthenticated) S3 filesystem handle, used below to list and read files.
s3 = s3fs.S3FileSystem(anon=True)

In [None]:
# Target device for tensors and the model: the GPU when PyTorch can see one, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

***

In [None]:
@dask.delayed
def preprocess(path, fs=None):
    '''Load one image, convert it to a tensor, and parse its labels from the path.

    Parameters:
        path: Image location. It must contain a segment shaped like
            ``dogs/Images/n<digits>-<breed>/n<digits>_<digits>.jpg``.
        fs: Object exposing ``open(path, mode)`` (for example an s3fs
            filesystem). When omitted, the builtin ``open`` is used, so
            ``path`` is read from the local filesystem.

    Returns:
        ``[name, tensor, truth]`` where ``name`` is the file stem
        (``n<digits>_<digits>``), ``tensor`` is the transformed image and
        ``truth`` is the breed portion of the folder name.

    Raises:
        ValueError: If ``path`` does not follow the expected layout.
    '''
    # One raw pattern extracts both fields, so the breed folder is matched the
    # same way for the label and for the identifier, and the dot is literal.
    parsed = re.search(
        r'dogs/Images/n[0-9]+-([^/]+)/(n[0-9]+_[0-9]+)\.jpg', path
    )
    if parsed is None:
        # Fail before doing any I/O, with a message naming the offending path.
        raise ValueError(f'Unexpected image path layout: {path!r}')
    truth, name = parsed.group(1), parsed.group(2)

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(250),
        transforms.ToTensor(),
    ])

    # Avoid relying on `__builtins__`, which is a dict (no `.open`) in
    # imported modules; fall back to the builtin `open` explicitly.
    opener = open if fs is None else fs.open
    with opener(path, 'rb') as f:
        img = Image.open(f).convert("RGB")
        nvis = transform(img)

    return [name, nvis, truth]

In [None]:
# Glob pattern for every .jpg one folder below dogs/Images in the public bucket.
s3fpath = 's3://saturn-public-data/dogs/Images/*/*.jpg'
# List the matching paths and chunk them into lists of at most 80 paths each.
batch_breaks = list(map(list, toolz.partition_all(80, s3.glob(s3fpath))))

In [None]:
# Build (lazily) one delayed preprocess task per image, keeping the batch grouping.
image_batches = [
    [preprocess(image_path, fs=s3) for image_path in path_batch]
    for path_batch in batch_breaks
]

In [None]:
@dask.delayed
def reformat(batch):
    '''Regroup a batch of [name, tensor, label] records into parallel collections.

    Returns [names, tensors, labels], where ``tensors`` is the per-image
    tensors stacked into one tensor and moved to ``device``.
    '''
    names, image_tensors, labels = [], [], []
    for record in list(batch):
        names.append(record[0])
        image_tensors.append(record[1])
        labels.append(record[2])

    # Stack along a new leading dimension, then place on the target device.
    stacked = torch.stack(image_tensors)
    tensors = stacked.to(device)

    return [names, tensors, labels]

# Wrap every preprocessed batch in a delayed reformat step.
image_batches = list(map(reformat, image_batches))

In [None]:
def evaluate_pred_batch(batch, gtruth, classes):
    '''Turn a batch of model outputs into human readable top-1 predictions.

    Parameters:
        batch: 2-D tensor of raw model scores, one row per image (softmax is
            taken over dimension 1).
        gtruth: Indexable collection of ground-truth labels, one per row.
        classes: Indexable collection mapping a class index to its label.

    Returns:
        A tuple ``(preds, labslist)``. ``preds[i]`` is a one-element list
        ``[(class_label, confidence_percent)]`` for row ``i``; ``labslist[i]``
        is ``gtruth[i]``.
    '''
    # Keep the full per-row probability matrix. Indexing it with [0] (as was
    # done previously) reported image 0's confidences for every image.
    percentage = torch.nn.functional.softmax(batch, dim=1) * 100

    # Top-1 probability and class index for each row.
    top_pct, top_idx = percentage.max(dim=1)

    preds = [
        [(classes[idx], pct)]
        for idx, pct in zip(top_idx.tolist(), top_pct.tolist())
    ]
    labslist = [gtruth[i] for i in range(len(batch))]

    return (preds, labslist)

def is_match(label, pred):
    '''Report whether the ground-truth label appears in the prediction.

    Underscores are treated as spaces on both sides. The label is compared as
    literal text (not as a regular expression), so labels containing
    characters such as ``+`` or ``(`` are handled correctly.

    Parameters:
        label: Ground-truth label string.
        pred: Prediction; it is converted with ``str()`` before comparison.

    Returns:
        True if the normalized label is a substring of the normalized
        prediction text, otherwise False.
    '''
    needle = label.replace('_', ' ')
    haystack = str(pred).replace('_', ' ')
    return needle in haystack

In [None]:
@dask.delayed
def run_batch_to_s3(iteritem):
    ''' Run inference on one preprocessed batch and score each prediction.

    Accepts the [names, images, labels] result of preprocessing and returns
    a list of dicts with keys 'name', 'ground_truth', 'prediction' and
    'evaluation', one per image. '''

    names, images, truelabels = iteritem

    # Class lookup table: one stripped entry per line of the labels file.
    with s3.open('s3://saturn-public-data/dogs/imagenet1000_clsidx_to_labels.txt') as label_file:
        classes = [entry.strip() for entry in label_file.readlines()]

    # Pretrained ResNet-50 on the target device, switched to inference mode.
    resnet = models.resnet50(pretrained=True).to(device)
    resnet.eval()

    with torch.no_grad():
        pred_batch = resnet(images)

        # Human readable predictions alongside the matching ground truth.
        preds, labslist = evaluate_pred_batch(pred_batch, truelabels, classes)

        # One result record per image in the batch.
        outcomes = [
            {
                'name': names[j],
                'ground_truth': labslist[j],
                'prediction': preds[j],
                'evaluation': is_match(labslist[j], preds[j]),
            }
            for j in range(len(images))
        ]

    return outcomes

In [None]:
%%time

# Submit one run_batch_to_s3 call per batch to the cluster.
futures = client.map(run_batch_to_s3, image_batches) 
# NOTE(review): run_batch_to_s3 is wrapped in dask.delayed, so the gathered
# items are presumably Delayed objects that still need computing — confirm.
futures_gathered = client.gather(futures)
# sync=False: get futures back immediately rather than blocking for results.
futures_computed = client.compute(futures_gathered, sync=False)

import logging

# Collect per-image outcomes; a failing batch is recorded and logged without
# stopping collection of the remaining batches.
results = []
errors = []
for fut in futures_computed:
    try:
        result = fut.result()
    except Exception as e:
        errors.append(e)
        logging.error(e)
    else:
        # Each successful batch contributes its list of outcomes.
        results.extend(result)

In [None]:
# Build the delayed task for the first batch and render its task graph;
# rankdir="LR" asks Graphviz for a left-to-right layout.
test_sample = run_batch_to_s3(image_batches[0])
test_sample.visualize(rankdir="LR")

In [None]:
# Split the per-image outcomes by whether the prediction matched the label.
true_preds = [x['evaluation'] for x in results if x['evaluation']]
false_preds = [x['evaluation'] for x in results if not x['evaluation']]
# Accuracy as a percentage. When no results were collected (for example every
# batch ended up in `errors`), show NaN instead of raising ZeroDivisionError.
len(true_preds) / len(results) * 100 if results else float('nan')
