In [None]:
from dask.distributed import Client, wait, progress
import time
import dask
from dask import persist, delayed, compute
import warnings
import io
import glob
import os
import toolz
import re
from PIL import Image
import logging

# Host/IP of the Dask scheduler; replace the placeholder before running.
scheduler = "<SCHEDULER_EXTERNAL_IP>"
# Root directory of the dogs dataset. Later cells read the images
# (Images/*/*.jpg) and the ImageNet label file from underneath it.
dogs_dir_path = "/ml-share/dogs"

# Connect to the scheduler on port 8786.
client = Client(f"{scheduler}:8786")
# Restart the cluster before submitting work (presumably to start from a
# clean worker state -- confirm this is wanted on a shared cluster).
client.restart()

def clear_cuda_cache():
    '''Empty torch's CUDA cache and report whether CUDA is available.

    torch is imported inside the function so the import happens in the
    process that executes it.

    Returns:
        bool: the value of ``torch.cuda.is_available()``.
    '''
    from torch import cuda

    cuda.empty_cache()
    cuda_ready = cuda.is_available()
    return cuda_ready
# Execute clear_cuda_cache on the cluster's workers via the client; as the
# last expression of the cell, the per-worker return values are displayed.
client.run(clear_cuda_cache)

In [None]:
@dask.delayed
def preprocess(path):
    '''Load one image and derive its identifiers from the file path.

    The image at ``path`` is opened, converted to RGB, resized to 256,
    centre-cropped to 250 and converted to a tensor.

    Args:
        path: Image file path. It must contain
            ``dogs/Images/n<digits>-<breed>/n<digits>_<digits>.jpg``.

    Returns:
        list: ``[name, tensor, truth]`` where ``name`` is the file stem
        (``n<digits>_<digits>``) and ``truth`` is the ``<breed>`` part of
        the parent directory name.

    Raises:
        ValueError: If ``path`` does not follow the expected layout.
    '''
    from torchvision import transforms

    # A single pattern yields both identifiers, so the breed directory is
    # matched the same way for each. The dot before "jpg" is escaped so it
    # only matches a literal ".".
    parsed = re.search(r'dogs/Images/n[0-9]+-([^/]+)/(n[0-9]+_[0-9]+)\.jpg', path)
    if parsed is None:
        # Fail with the offending path instead of an AttributeError on None.
        raise ValueError(f"Unexpected image path layout: {path!r}")
    truth, name = parsed.groups()

    # NOTE(review): no transforms.Normalize step is applied here. torchvision
    # documents that its pretrained ImageNet models expect normalized input;
    # confirm whether omitting it is intentional.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(250),
        transforms.ToTensor()])

    with open(path, 'rb') as f:
        # Convert while the file is still open; PIL loads pixel data lazily.
        nvis = transform(Image.open(f).convert("RGB"))

    return [name, nvis, truth]

@dask.delayed
def reformat(batch):
    '''Regroup a batch of per-image records into three parallel lists.

    Args:
        batch: Iterable of records indexed as ``record[0]`` (name),
            ``record[1]`` (tensor) and ``record[2]`` (label).

    Returns:
        list: ``[names, tensors, labels]``, each in the order of ``batch``.
    '''
    records = list(batch)
    names, tensors, labels = [], [], []
    for record in records:
        names.append(record[0])
        tensors.append(record[1])
        labels.append(record[2])
    return [names, tensors, labels]

def evaluate_pred_batch(batch, gtruth, classes):
    '''Pick the top-1 class and its confidence for every row of ``batch``.

    Args:
        batch: 2-D tensor of raw model outputs, one row per image and one
            column per class.
        gtruth: Sequence of ground-truth labels, indexable by row number.
        classes: Sequence mapping a class index to its label.

    Returns:
        tuple: ``(preds, labslist)``. ``preds[i]`` is a one-element list
        ``[(class_label, percentage)]`` for row ``i``; ``labslist[i]`` is
        ``gtruth[i]``.
    '''
    import torch

    top_indices = torch.argmax(batch, dim=1)
    # Keep the softmax of EVERY row. Taking only row 0 would report the first
    # image's probabilities as the confidence of all other images.
    percentage = torch.nn.functional.softmax(batch, dim=1) * 100

    preds = []
    labslist = []
    for i in range(len(batch)):
        idx = top_indices[i].item()
        preds.append([(classes[idx], percentage[i, idx].item())])
        labslist.append(gtruth[i])
    return (preds, labslist)

def is_match(la, ev):
    '''Report whether label ``la`` occurs in the string form of ``ev``.

    Underscores are treated as spaces on both sides before comparing. The
    comparison is a literal, case-sensitive substring test, so characters
    that are special in regular expressions (``.``, ``+``, ``(`` ...) in the
    label are matched as-is rather than interpreted as a pattern.

    Args:
        la: Ground-truth label string.
        ev: Prediction to search in; it is converted with ``str()``.

    Returns:
        bool: True if the label is found, False otherwise.
    '''
    label = la.replace('_', ' ')
    haystack = str(ev).replace('_', ' ')
    return label in haystack

def get_truth_name(name):
    '''Return the part of ``name`` that follows its prefix and a hyphen.

    The prefix is the longest underscore-free run at the start of ``name``
    that is still followed by a hyphen, so for a name without underscores
    the text after the LAST hyphen is returned.

    Raises:
        AttributeError: If ``name`` does not match the pattern at all.
    '''
    # NOTE(review): "n1-Shih-Tzu" yields "Tzu", not "Shih-Tzu" -- confirm
    # that this greedy behaviour is what callers expect.
    found = re.search('^[^_]*-(.*)$', name)
    suffix = found.group(1)
    return suffix

@dask.delayed
def run_batch(iteritem):
    '''Classify one preprocessed batch with ResNet-50 and score the result.

    Args:
        iteritem: ``[names, images, truelabels]`` as produced by ``reformat``;
            ``images`` is a list of tensors that can be stacked together.

    Returns:
        dict: ``{'count': <images in the batch>,
        'correct_pred': <images whose top-1 prediction matched the label>}``.
    '''
    import ssl
    import torch
    from torchvision import models

    # SECURITY NOTE(review): this disables HTTPS certificate verification for
    # the whole process, not just for this call. It is presumably here so the
    # pretrained weights can be downloaded; prefer fixing the CA certificates.
    ssl._create_default_https_context = ssl._create_unverified_context

    with open(f'{dogs_dir_path}/imagenet1000_clsidx_to_labels.txt') as f:
        classes = [line.strip() for line in f]

    names, images, truelabels = iteritem

    images = torch.stack(images)
    # Inference is pinned to the CPU. To use a GPU when one is present:
    # torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")

    with torch.no_grad():
        # Set up model
        resnet = models.resnet50(pretrained=True)
        resnet = resnet.to(device)
        resnet.eval()

        # Run model on batch
        images = images.to(device)
        pred_batch = resnet(images)

        # Evaluate batch
        preds, labslist = evaluate_pred_batch(pred_batch, truelabels, classes)

    # Per-image records (name, ground truth, prediction, evaluation) are not
    # persisted; only the aggregate counts are returned. `names` would supply
    # the image identifiers if per-image output is needed again.
    count = len(images)
    correct_pred = sum(
        1 for groundtruth, predicted in zip(labslist, preds)
        if is_match(groundtruth, predicted)
    )
    return {'count': count, "correct_pred": correct_pred}


In [None]:
%%time
# Number of images handed to each run_batch task.
BATCH_SIZE = 294

# Sort the paths so batch membership is reproducible between runs; glob
# itself gives no ordering guarantee.
image_paths = sorted(glob.iglob(f'{dogs_dir_path}/Images/*/*.jpg'))
batch_breaks = [list(batch) for batch in toolz.partition_all(BATCH_SIZE, image_paths)]

# One delayed preprocess call per image, grouped by batch, then one delayed
# reformat call per batch to produce [names, tensors, labels].
preprocessed_batches = [[preprocess(path) for path in paths] for paths in batch_breaks]
image_batches = [reformat(batch) for batch in preprocessed_batches]

In [None]:
%%time
# Submit one run_batch call per batch to the cluster.
futures = client.map(run_batch, image_batches)
# NOTE(review): run_batch is wrapped in dask.delayed, so each mapped call
# presumably yields a lazy Delayed object rather than a finished result;
# gathering brings those objects back to this process -- confirm.
futures_gathered = client.gather(futures)
# Hand the gathered objects to the cluster for computation without blocking;
# the next cell collects each outcome with .result().
futures_computed = client.compute(futures_gathered, sync=False)

In [None]:
%%time
# Aggregate the per-batch results. A failed batch is recorded in `errors`
# and logged instead of aborting the whole summary.
errors = []
count = 0
correct = 0
total_batches = 0

for fut in futures_computed:
    total_batches += 1
    try:
        result = fut.result()
    except Exception as e:
        errors.append(e)
        logging.error(e)
    else:
        correct += result['correct_pred']
        count += result['count']

# Surface failures in the summary; otherwise the accuracy below silently
# covers only the batches that succeeded.
if errors:
    print(f'{len(errors)} of {total_batches} batches failed; see `errors` for details')
print(f'There are {count} photos, {correct} of them are predicted correctly')
if count:
    print(f'The percent of dogs classified correctly: {correct/count*100}%')
else:
    # Avoid a ZeroDivisionError when no batch produced a result.
    print('No photos were evaluated, so no accuracy can be reported')
