# Install OpenJDK and Spark

In [1]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q https://downloads.apache.org/spark/spark-3.2.1/spark-3.2.1-bin-hadoop3.2.tgz
!tar xvzf spark-3.2.1-bin-hadoop3.2.tgz
!pip install -q findspark
!pip install pyarrow
!pip install pyspark
try:
  # %tensorflow_version only exists in Colab.
  !pip install --disable-pip-version-check install tf-nightly
except Exception:
  pass

spark-3.2.1-bin-hadoop3.2/
spark-3.2.1-bin-hadoop3.2/LICENSE
spark-3.2.1-bin-hadoop3.2/NOTICE
spark-3.2.1-bin-hadoop3.2/R/
spark-3.2.1-bin-hadoop3.2/R/lib/
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/DESCRIPTION
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/INDEX
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/Meta/
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/Meta/Rd.rds
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/Meta/features.rds
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/Meta/hsearch.rds
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/Meta/links.rds
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/Meta/nsInfo.rds
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/Meta/package.rds
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/Meta/vignette.rds
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/NAMESPACE
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/R/
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/R/SparkR
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/R/SparkR.rdb
spark-3.2.1-bin-hadoop3.2/R/lib/SparkR/R/SparkR.rdx
spark-3.2.1-bin-hadoop3.2/R/lib/Sp

# Set up the Spark and Java environment

In [2]:
import os
import pickle

import numpy as np
import pandas as pd

# Point findspark/pyspark at the JDK and Spark distribution installed above.
# NOTE(review): these are Colab paths; adjust them for other environments.
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.2.1-bin-hadoop3.2"

import findspark
# findspark.init() must run BEFORE pyspark is imported; otherwise the already
# imported pip copy of pyspark is used and SPARK_HOME is ignored.
findspark.init()

from pyspark.sql import SparkSession
import tensorflow as tf

# Local Spark session using all available cores.
spark = SparkSession.builder.master("local[*]").getOrCreate()

# Import and format the CIFAR-10 data

In [3]:
!wget https://www.cs.toronto.edu/%7Ekriz/cifar-10-python.tar.gz
!tar xvf cifar-10-python.tar.gz
!ls cifar-10-batches-py/

--2022-04-15 05:06:32--  https://www.cs.toronto.edu/%7Ekriz/cifar-10-python.tar.gz
Resolving www.cs.toronto.edu (www.cs.toronto.edu)... 128.100.3.30
Connecting to www.cs.toronto.edu (www.cs.toronto.edu)|128.100.3.30|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 170498071 (163M) [application/x-gzip]
Saving to: ‘cifar-10-python.tar.gz’


2022-04-15 05:06:34 (79.4 MB/s) - ‘cifar-10-python.tar.gz’ saved [170498071/170498071]

cifar-10-batches-py/
cifar-10-batches-py/data_batch_4
cifar-10-batches-py/readme.html
cifar-10-batches-py/test_batch
cifar-10-batches-py/data_batch_3
cifar-10-batches-py/batches.meta
cifar-10-batches-py/data_batch_2
cifar-10-batches-py/data_batch_5
cifar-10-batches-py/data_batch_1
batches.meta  data_batch_2  data_batch_4  readme.html
data_batch_1  data_batch_3  data_batch_5  test_batch


# Define functions for loading CIFAR-10 batches

In [4]:
def cifar10_batch_load(data_path, batch_id):
    """Load one CIFAR-10 training batch file named ``data_batch_<batch_id>``.

    Args:
        data_path: directory that holds the extracted ``data_batch_*`` files.
        batch_id: batch number appended to the file name (1-5 for CIFAR-10).

    Returns:
        ``(features, labels)``: ``features`` is the batch's ``data`` array
        reshaped to (N, 3, 32, 32) and transposed to channel-last
        (N, 32, 32, 3); ``labels`` is the batch's ``labels`` entry, unchanged.
    """
    batch_file = data_path + '/data_batch_' + str(batch_id)
    # NOTE(review): pickle.load can execute arbitrary code; only use it on the
    # trusted CIFAR-10 archive.
    with open(batch_file, mode='rb') as handle:
        batch = pickle.load(handle, encoding='latin1')

    raw_pixels = batch['data']
    n_images = len(raw_pixels)
    # Stored channel-first per image; move channels to the last axis.
    features = raw_pixels.reshape((n_images, 3, 32, 32)).transpose(0, 2, 3, 1)
    return features, batch['labels']

def load_label_names():
    """Return a fresh list of the ten CIFAR-10 class names, indexed by integer label."""
    return [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ]

In [5]:
# Load all five training batches and stack them once, so features_acc and
# labels_acc are always defined (previously they only existed if the
# `if batch_id - 1:` branch ran).
feature_batches, label_batches = [], []
for batch_id in range(1, 6):
    features, labels = cifar10_batch_load('./cifar-10-batches-py', batch_id)
    feature_batches.append(features)
    # Labels come back as a flat list; store them as an (N, 1) column.
    label_batches.append(np.asarray(labels).reshape(-1, 1))

features_acc = np.concatenate(feature_batches, axis=0)
labels_acc = np.concatenate(label_batches, axis=0)
# Keep the previous names bound to the same arrays.
features_p, labels_p = features_acc, labels_acc
assert features_acc.shape[0] == labels_acc.shape[0], "features/labels length mismatch"

label_names = load_label_names()

In [6]:
# Sanity check: (images, height, width, channels) and (images, 1) label column.
(features_acc.shape, labels_acc.shape)

((50000, 32, 32, 3), (50000, 1))

# Write the data in an ImageNet-style folder layout (one sub-folder per class)

In [7]:
from PIL import Image as im
import os, shutil
from tqdm.notebook import tqdm

def write_data_imagenet(features, labels, data_path):
    """Save images to disk in an ImageNet-style ``<root>/<class>/<i>.jpg`` layout.

    Args:
        features: images indexed along the first axis; each ``features[i]`` is
            passed to ``PIL.Image.fromarray`` as an RGB image (assumed uint8
            HxWx3, e.g. CIFAR-10's 32x32x3 — TODO confirm for other data).
        labels: one integer label per image, either a flat sequence or an
            (N, 1) array; each entry is squeezed to a scalar index into
            ``load_label_names()``.
        data_path: output root directory; it and all class sub-folders are
            created if missing.
    """
    label_names = load_label_names()

    # One sub-folder per class; exist_ok makes re-runs idempotent and also
    # creates data_path itself.
    for label in label_names:
        os.makedirs(os.path.join(data_path, label), exist_ok=True)

    for i in tqdm(range(features.shape[0])):
        # int() turns the squeezed 0-d array into a plain Python list index.
        label = int(np.squeeze(labels[i]))
        image = im.fromarray(features[i], 'RGB')
        # NOTE: JPEG is lossy, so saved pixels differ slightly from the array.
        image.save(os.path.join(data_path, label_names[label], str(i) + '.jpg'))

In [8]:
# Materialise the 50,000 training images under c10_data/train_data/<class>/.
train_path = 'c10_data/train_data'
write_data_imagenet(features=features_acc, labels=labels_acc, data_path=train_path)

  0%|          | 0/50000 [00:00<?, ?it/s]

In [9]:
# The test split is a single pickle ("test_batch") rather than a numbered
# data_batch_N file, so it is loaded directly here.
# NOTE(review): pickle.load can execute arbitrary code; only use it on the
# trusted CIFAR-10 archive.
with open('./cifar-10-batches-py/test_batch', mode='rb') as f:
    batch = pickle.load(f, encoding='latin1')
test_labels = batch['labels']
test_features = (
    batch['data']
    .reshape((len(batch['data']), 3, 32, 32))
    .transpose(0, 2, 3, 1)
)

test_path = 'c10_data/test_data'
write_data_imagenet(test_features, test_labels, test_path)

  0%|          | 0/10000 [00:00<?, ?it/s]

# Read the images into Spark

In [10]:
import io
from tensorflow.keras.applications.imagenet_utils import decode_predictions
from pyspark.sql.functions import col, pandas_udf, PandasUDFType, regexp_extract
import torch
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import pyspark.sql.functions as sqlf

In [11]:
# Read every .jpg under the training root (all class sub-folders) as binary rows.
images = (
    spark.read.format("binaryFile")
    .option("recursiveFileLookup", "true")
    .option("pathGlobFilter", "*.jpg")
    .load('./c10_data/train_data')
)

In [12]:
def extract_label(path_col):
    """Extract the class label (the image's parent directory name) from its path.

    Uses the built-in ``regexp_extract`` SQL function, so no Python UDF is needed.
    For a path ending in ``.../train_data/frog/123.jpg`` this yields ``"frog"``.
    Matching the last directory component (instead of a pattern tied to
    ``./c10_data/train_data`` with an unescaped ``.``) works for any
    ImageNet-style root, e.g. ``test_data`` as well.
    """
    return regexp_extract(path_col, "([^/]+)/[^/]+$", 1)

def extract_size(content):
    """Decode encoded image bytes and return the PIL ``size`` tuple (width, height)."""
    return im.open(io.BytesIO(content)).size

@pandas_udf("width: int, height: int")
def extract_size_udf(content_series: pd.Series) -> pd.DataFrame:
    """Pandas UDF returning each image's (width, height) as a struct column.

    The returned DataFrame's columns are named after the struct fields, so Spark
    matches them by name. An empty batch still carries both columns.
    """
    sizes = content_series.apply(extract_size)
    return pd.DataFrame(list(sizes), columns=["width", "height"])

# Working DataFrame: file metadata, label parsed from the path, decoded image
# size, and the raw JPEG bytes (in that column order).
df = images.select(
    col("path"),
    col("modificationTime"),
    extract_label(col("path")).alias("label"),
    extract_size_udf(col("content")).alias("size"),
    col("content"),
)

In [13]:
class ImageNetDataset(Dataset):
    """
    PyTorch dataset over encoded image bytes with standard ImageNet preprocessing
    (resize to 256, center-crop to 224, ImageNet mean/std normalization).
    """
    def __init__(self, contents):
        """
        Args:
            contents: indexable sequence of encoded image bytes.
        """
        self.contents = contents
        # The pipeline is identical for every item, so build it once per dataset
        # instead of on every __getitem__ call.
        self._transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        return self._preprocess(self.contents[index])

    def _preprocess(self, content):
        """
        Decode ``content`` and return a normalized (3, 224, 224) float tensor.
        """
        # convert("RGB") guarantees three channels to match the 3-element
        # mean/std above (grayscale or RGBA inputs would otherwise fail Normalize).
        image = im.open(io.BytesIO(content)).convert("RGB")
        return self._transform(image)

In [14]:
def imagenet_model_udf(model_fn):
    """
    Wrap an ImageNet classifier into an iterator-style Pandas UDF for top-1 predictions.

    Args:
        model_fn: zero-argument callable returning a torch model whose output
            for a (B, 3, 224, 224) batch is (B, 1000) ImageNet scores, as
            required by ``decode_predictions``. It is called once per UDF
            invocation (per partition iterator), not once per row.

    Returns:
        A Pandas UDF mapping a column of encoded image bytes to a struct
        ``<class: string, desc: string>`` with the top-1 ImageNet class id and
        its human-readable description.

    Runs on a CUDA GPU when one is available, otherwise on the CPU.
    """
    from typing import Iterator

    def predict(content_series_iter: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model_fn().to(device)
        model.eval()
        for content_series in content_series_iter:
            loader = DataLoader(ImageNetDataset(list(content_series)), batch_size=64)
            with torch.no_grad():
                for image_batch in loader:
                    scores = model(image_batch.to(device)).cpu().numpy()
                    # decode_predictions yields (class_id, description, score)
                    # triples; keep only the two fields declared in the schema.
                    top1 = [preds[0][:2] for preds in decode_predictions(scores, top=1)]
                    # Spark only needs the total output length to match the input,
                    # so yielding one frame per 64-image sub-batch is fine.
                    yield pd.DataFrame(top1, columns=["class", "desc"])

    return pandas_udf("class: string, desc: string")(predict)

In [15]:
# Baseline: top-1 ImageNet prediction from a pretrained MobileNetV2 for each image.
mobilenet_v2_udf = imagenet_model_udf(lambda: models.mobilenet_v2(pretrained=True))
predictions = df.withColumn("prediction", mobilenet_v2_udf(col("content")))
prediction_mobil = predictions.select(
    col("label"),
    col("prediction.desc").alias("mobilenetv2 prediction"),
)
prediction_mobil.show(n=25, truncate=False)



+----------+----------------------+
|label     |mobilenetv2 prediction|
+----------+----------------------+
|frog      |rock_python           |
|bird      |pinwheel              |
|truck     |bearskin              |
|automobile|mousetrap             |
|truck     |oil_filter            |
|truck     |thresher              |
|frog      |jaguar                |
|truck     |moving_van            |
|airplane  |waffle_iron           |
|automobile|panpipe               |
|frog      |sidewinder            |
|truck     |airliner              |
|automobile|maraca                |
|truck     |thresher              |
|frog      |clog                  |
|truck     |thresher              |
|truck     |moving_van            |
|frog      |jersey                |
|truck     |thresher              |
|cat       |fire_screen           |
|truck     |thresher              |
|truck     |moving_van            |
|frog      |sidewinder            |
|truck     |tobacco_shop          |
|frog      |custard_apple   

In [16]:
# Collect a 2,500-row subset and report the most frequent ImageNet predictions per CIFAR-10 class.
# NOTE(review): Spark orders file splits by size, so limit() most likely returns the
# largest JPEGs rather than a class-balanced subset (note the skewed class counts below);
# consider df.sample(fraction, seed=...) for a representative sample.
prediction_mobil_ser = prediction_mobil.limit(2500).toPandas()
top_num = 5
for label_name in label_names:
    print(f"\n\n ####### Top {top_num} predictions for class {label_name} #######")
    class_mask = prediction_mobil_ser['label'] == label_name
    class_preds = prediction_mobil_ser.loc[class_mask, 'mobilenetv2 prediction']
    print(class_preds.value_counts().nlargest(top_num).to_frame('counts'))



 ####### Top 5 predictions for class airplane #######
               counts
moving_van          7
rock_beauty         4
thresher            4
assault_rifle       4
chain_saw           4


 ####### Top 5 predictions for class automobile #######
                 counts
moving_van          260
thresher             48
chain_saw            41
amphibian            26
cassette_player      17


 ####### Top 5 predictions for class bird #######
                  counts
fox_squirrel          10
three-toed_sloth       8
rock_beauty            5
patas                  3
bearskin               3


 ####### Top 5 predictions for class cat #######
                  counts
EntleBucher            9
fox_squirrel           7
Japanese_spaniel       5
bearskin               5
Windsor_tie            4


 ####### Top 5 predictions for class deer #######
                  counts
fox_squirrel          13
sorrel                 5
barn_spider            5
cardoon                4
Japanese_spaniel       3


 ##

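# Predictions with different pretrained models

These models predict ImageNet's 1,000 classes, not the 10 CIFAR-10 classes, so the tables below show which ImageNet classes each CIFAR-10 class is most often mapped to. The 32×32 images are also upscaled to 224×224 before inference, which further limits how meaningful the predictions are.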

In [17]:
def modelpredict(modelname, top_num=3, sample_size=2500):
    """Run a pretrained torchvision model over ``df`` and print per-class top predictions.

    For each CIFAR-10 class in ``label_names``, prints the ``top_num`` most
    frequent top-1 ImageNet descriptions among the first ``sample_size`` rows.

    Args:
        modelname: one of 'googlenet', 'inception_v3', 'mobilenet_v3_small',
            'mobilenet_v3_large', 'densenet121'.
        top_num: number of predictions to show per class (default 3, as before).
        sample_size: number of rows collected to pandas (default 2500, as before).

    Raises:
        ValueError: if ``modelname`` is not supported (previously this surfaced
            as an UnboundLocalError on ``model_udf``).
    """
    model_builders = {
        'googlenet': models.googlenet,
        'inception_v3': models.inception_v3,
        'mobilenet_v3_small': models.mobilenet_v3_small,
        'mobilenet_v3_large': models.mobilenet_v3_large,
        'densenet121': models.densenet121,
    }
    if modelname not in model_builders:
        raise ValueError(f"Unknown model {modelname!r}; expected one of {sorted(model_builders)}")
    builder = model_builders[modelname]
    model_udf = imagenet_model_udf(lambda: builder(pretrained=True))

    pred_col = modelname + " prediction"
    predictions = df.withColumn("prediction", model_udf(col("content")))
    predictions_model = predictions.select(col("label"), col("prediction.desc").alias(pred_col))
    # NOTE(review): limit() without ordering likely favours the largest files
    # (Spark scans file splits largest-first), not a class-balanced sample.
    ser_df = predictions_model.limit(sample_size).toPandas()
    for label_name in label_names:
        filt_rows = ser_df.loc[ser_df['label'] == label_name]
        print(f"\n\n ####### Top {top_num} predictions for class {label_name} #######")
        final_rows = filt_rows[pred_col].value_counts().nlargest(top_num).to_frame('counts')
        print(final_rows)

In [18]:
# GoogLeNet (Inception v1)
modelpredict(modelname='googlenet')





 ####### Top 3 predictions for class airplane #######
               counts
fox_squirrel        6
letter_opener       5
moving_van          5


 ####### Top 3 predictions for class automobile #######
                 counts
moving_van          332
chain_saw            41
cassette_player      26


 ####### Top 3 predictions for class bird #######
                  counts
fox_squirrel          26
patas                  9
three-toed_sloth       4


 ####### Top 3 predictions for class cat #######
              counts
fox_squirrel      35
EntleBucher        7
patas              6


 ####### Top 3 predictions for class deer #######
              counts
fox_squirrel      19
red_wolf           6
sorrel             5


 ####### Top 3 predictions for class dog #######
                  counts
Japanese_spaniel      22
fox_squirrel          13
English_foxhound      11


 ####### Top 3 predictions for class frog #######
              counts
fox_squirrel     134
sidewinder        23
rock_python  

In [19]:
# MobileNetV3-Small
modelpredict(modelname='mobilenet_v3_small')





 ####### Top 3 predictions for class airplane #######
            counts
can_opener       6
thresher         5
airliner         5


 ####### Top 3 predictions for class automobile #######
            counts
moving_van     206
thresher        73
chain_saw       50


 ####### Top 3 predictions for class bird #######
                  counts
fox_squirrel          12
patas                  7
three-toed_sloth       6


 ####### Top 3 predictions for class cat #######
                  counts
fox_squirrel          12
Japanese_spaniel       9
patas                  5


 ####### Top 3 predictions for class deer #######
              counts
fox_squirrel       9
barn_spider        6
dhole              6


 ####### Top 3 predictions for class dog #######
                         counts
Japanese_spaniel             39
wire-haired_fox_terrier       8
toy_terrier                   7


 ####### Top 3 predictions for class frog #######
                counts
fox_squirrel       105
frilled_lizard    

In [20]:
# DenseNet-121
modelpredict(modelname='densenet121')





 ####### Top 3 predictions for class airplane #######
             counts
airliner          5
moving_van        5
rock_beauty       5


 ####### Top 3 predictions for class automobile #######
            counts
moving_van     234
thresher        36
amphibian       30


 ####### Top 3 predictions for class bird #######
               counts
fox_squirrel        7
limpkin             6
spider_monkey       5


 ####### Top 3 predictions for class cat #######
               counts
fox_squirrel       11
affenpinscher       6
colobus             5


 ####### Top 3 predictions for class deer #######
                 counts
fox_squirrel          9
sorrel                7
Indian_elephant       7


 ####### Top 3 predictions for class dog #######
                   counts
Japanese_spaniel       16
Dandie_Dinmont         11
Brabancon_griffon       8


 ####### Top 3 predictions for class frog #######
                  counts
fox_squirrel          91
three-toed_sloth      27
platypus             

In [21]:
# Inception v3
modelpredict(modelname='inception_v3')





 ####### Top 3 predictions for class airplane #######
           counts
thresher        6
muzzle          4
chain_saw       4


 ####### Top 3 predictions for class automobile #######
            counts
moving_van     154
amphibian       40
chain_saw       32


 ####### Top 3 predictions for class bird #######
              counts
fox_squirrel       8
brambling          6
chain_saw          3


 ####### Top 3 predictions for class cat #######
              counts
EntleBucher        6
lynx               6
fox_squirrel       6


 ####### Top 3 predictions for class deer #######
              counts
fox_squirrel       7
bluetick           5
gazelle            5


 ####### Top 3 predictions for class dog #######
                  counts
Japanese_spaniel      12
toy_terrier            7
bluetick               5


 ####### Top 3 predictions for class frog #######
              counts
fox_squirrel      50
rock_python       45
leopard           31


 ####### Top 3 predictions for class horse

In [22]:
# MobileNetV3-Large
modelpredict(modelname='mobilenet_v3_large')





 ####### Top 3 predictions for class airplane #######
               counts
safety_pin          6
space_shuttle       6
thresher            6


 ####### Top 3 predictions for class automobile #######
             counts
moving_van      213
thresher         56
convertible      37


 ####### Top 3 predictions for class bird #######
                  counts
fox_squirrel          15
three-toed_sloth       7
cardoon                6


 ####### Top 3 predictions for class cat #######
                   counts
fox_squirrel           10
Brabancon_griffon       6
kit_fox                 5


 ####### Top 3 predictions for class deer #######
              counts
fox_squirrel      15
cardoon            7
hartebeest         5


 ####### Top 3 predictions for class dog #######
                   counts
Japanese_spaniel       25
otterhound              6
Brabancon_griffon       6


 ####### Top 3 predictions for class frog #######
                counts
fox_squirrel        90
frilled_lizard      30