In [None]:
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from pyspark.ml.torch.distributor import TorchDistributor
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col
from pyspark.sql.types import ArrayType, IntegerType, FloatType
from transformers import ViTForImageClassification, ViTImageProcessor

In [None]:
# Build (or reuse) the notebook's Spark session: local master with 8 threads,
# 16G configured for both executor and driver memory.
spark = (
    SparkSession.builder
    .appName("ImageClassification")
    .master("local[8]")
    .config("spark.executor.memory", "16G")
    .config("spark.driver.memory", "16G")
    .getOrCreate()
)

In [None]:
# Directory read by the (currently commented-out) Spark image-source cells.
image_dir = 'images/'

# Root folder of the pickled CIFAR-10 batches. Set the CIFAR10_DIR environment
# variable to point somewhere else; the original location stays the default.
# Joining with '' guarantees a trailing separator, which matters because the
# loading cells build file names by plain string concatenation.
dataset_path = os.path.join(
    os.environ.get('CIFAR10_DIR')
    or '/Users/ykamoji/Documents/ImageDatabase/cifar-10-batches-py/',
    '',
)

In [None]:
## Helper for images streamed from an API that downloads them; skip this cell if the data already arrives as arrays.


# def to_np_array(x):    
#   height = 32
#   width  = 32
#   nChannels = 3
#   return np.reshape(x, (height,width,nChannels)).tolist()
# 
# spark_to_np_array = udf(to_np_array, ArrayType(ArrayType(ArrayType(IntegerType()))))

In [None]:
## Possible approach for Spark streaming (currently disabled).

# image_df = spark.read.format("image").option("dropInvalid", True).load(image_dir,inferschema=True)
# 
# image_len = image_df.count()
# 
# image_df = image_df.withColumn("ndarray", spark_to_np_array(image_df["image.data"]))
# image_df.select("image.origin", "image.width", "image.height","ndarray").show(truncate=False)

# def processImage(im):
#         return np.array(im.ndarray).astype(np.uint8)

# for i in range(image_len):
#         plt.imshow(image_df.select('ndarray').rdd.map(processImage).collect()[i])
#         plt.show()

In [None]:
## Load the first training batch on its own so individual samples can be inspected.
# NOTE: unpickling can execute arbitrary code -- only load trusted dataset files.
# The `with` block closes the file handle once the batch has been read.
with open(dataset_path + 'data_batch_1', 'rb') as batch_file:
    training_data = pickle.load(batch_file, encoding='latin-1')

In [None]:
# index = 4
# im = training_data['data'][index].reshape(3, 32, 32).transpose(1,2,0)
# # print(im)
# plt.imshow(im), training_data['labels'][index], training_data['filenames'][index]

In [None]:
def _load_cifar_pickle(file_name):
    """Unpickle one file from ``dataset_path`` and close its handle afterwards.

    NOTE: unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(dataset_path + file_name, 'rb') as handle:
        return pickle.load(handle, encoding='latin-1')


train_dataset = []
test_dataset = []

# range(1, 2) loads only data_batch_1; widen the range to pull in more
# training batches.
for i in range(1, 2):
    data = _load_cifar_pickle(f'data_batch_{i}')
    # Each dataset entry is an (image, label) pair.
    train_dataset.extend(zip(data["data"], data["labels"]))

test_data = _load_cifar_pickle('test_batch')
test_dataset.extend(zip(test_data["data"], test_data["labels"]))

# Map integer class index -> human-readable class name.
meta = _load_cifar_pickle('batches.meta')
label_map = dict(enumerate(meta['label_names']))

In [None]:
# def reshape_image(record):
#     image, label = record
#     height = 32
#     width  = 32
#     nChannels = 3
#     data = [float(x) for x in image.reshape(nChannels, height,width).transpose(1,2,0).flatten()]
#     return data, label
# 
# image_rdd = spark.sparkContext.parallelize(train_dataset, numSlices=500).map(reshape_image)
# 
# imagesWithLabels = image_rdd.toDF(["image", "label"])
# 
# convert_to_float = udf(lambda x: x, ArrayType(FloatType()))
# imagesWithLabels = imagesWithLabels.withColumn("image", convert_to_float(col("image")))
# imagesWithLabels.printSchema()
# imagesWithLabels.cache()

In [None]:
# imagesWithLabels.select("label", "image").show(5, truncate=False)

In [None]:
# for i in range(5):
#         plt.imshow(images[i])
#         plt.show()

In [None]:
# Checkpoint identifier and local cache folder shared by the processor and the model.
checkpoint = 'aaraki/vit-base-patch16-224-in21k-finetuned-cifar10'
cache = 'models/'

processor = ViTImageProcessor.from_pretrained(checkpoint, cache_dir=cache)
model = ViTForImageClassification.from_pretrained(checkpoint, cache_dir=cache)
# Put the model in evaluation mode before it is used for prediction.
model.eval()

In [None]:
def reshape_image(record):
    """Convert one flat image record into model-ready inputs.

    Args:
        record: ``(image, label)`` pair. ``image`` is a flat array that is
            reshaped as channel-first ``(3, 32, 32)``; ``label`` is passed
            through untouched.

    Returns:
        ``(inputs, label)`` where ``inputs`` is what ``processor`` returns for
        the ``(32, 32, 3)`` uint8 image with ``return_tensors="pt"``.
    """
    image, label = record
    height = 32
    width = 32
    nChannels = 3
    # Channel-first -> channel-last in one vectorised step. The original
    # detour through a Python list of floats and back to uint8 yields the same
    # pixel values, so it is dropped. ascontiguousarray keeps the C-ordered
    # layout the old list-based construction produced.
    pixels = np.ascontiguousarray(
        image.reshape(nChannels, height, width).transpose(1, 2, 0)
    ).astype(np.uint8)
    # Named `inputs` rather than `input` to avoid shadowing the builtin.
    inputs = processor(images=pixels, return_tensors="pt")
    return inputs, label
   

def predictImage(record):
    """Classify one pre-processed image and score it against its ground truth.

    Args:
        record: ``(inputs, label)`` pair as produced by ``reshape_image``.
            ``inputs`` is unpacked as keyword arguments to ``model``;
            ``label`` must be a key of ``label_map``.

    Returns:
        ``(class_name, hit)`` where ``class_name`` is ``label_map[label]`` and
        ``hit`` is 1 when the arg-max of the logits equals ``label``, else 0.

    Raises:
        KeyError: if ``label`` is not present in ``label_map``.
    """
    gt = record[1]
    class_label = label_map[gt]
    # Inference only: disable autograd so no graph is built or kept alive for
    # the forward pass. The predicted class is unaffected.
    with torch.no_grad():
        output = model(**record[0])
    pred = output.logits.argmax(-1).item()

    return class_label, 1 if gt == pred else 0


def calc_acc(record):
    """Summarise one class's hit list as a percentage string.

    Args:
        record: ``(key, hits)`` where ``hits`` is a non-empty sequence of
            0/1 values.

    Returns:
        ``(key, "NN.NN%")`` -- the share of hits, formatted with two decimals.
    """
    outcomes = record[1]
    percent = float(sum(outcomes) * 100 / len(outcomes))
    return record[0], f"{percent:.2f}%"


In [751]:
# Evaluate the first 300 test records and report wall-clock time plus
# per-class accuracy.
test = test_dataset[:300]
start = time.time()

# Pre-process -> predict -> group the 0/1 hits by class name -> accuracy string.
per_class = (
    spark.sparkContext.parallelize(test)
    .map(reshape_image)
    .map(predictImage)
    .groupByKey()
    .mapValues(list)
    .map(calc_acc)
)
results = per_class.collect()

elapsed = time.time() - start
print(f"Time taken = {elapsed:.3f} sec \n\n")
for class_label, acc in results:
    print(f"{class_label} : {acc}")

# correct_labels = sum([ 1 if res else 0 for res in results]) 
# print(f"Accuracy {correct_labels*100/len(test):.3f} %")



Time taken = 22.024 sec 


bird : 100.00%
airplane : 97.22%
ship : 100.00%
cat : 100.00%
deer : 100.00%
dog : 96.43%
truck : 94.29%
horse : 100.00%
frog : 100.00%
automobile : 95.83%


                                                                                

In [None]:
# spark.stop()
