<a href="https://colab.research.google.com/github/shumshersubashgautam/Pyspark-NLP/blob/main/ViT_Image_Classification_Annotator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!wget https://setup.johnsnowlabs.com/colab.sh -O - | bash /dev/stdin -p 3.2.1 -s 4.1.0

--2023-05-04 14:19:57--  https://setup.johnsnowlabs.com/colab.sh
Resolving setup.johnsnowlabs.com (setup.johnsnowlabs.com)... 51.158.130.125
Connecting to setup.johnsnowlabs.com (setup.johnsnowlabs.com)|51.158.130.125|:443... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp/master/scripts/colab_setup.sh [following]
--2023-05-04 14:19:57--  https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp/master/scripts/colab_setup.sh
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1191 (1.2K) [text/plain]
Saving to: ‘STDOUT’

-                     0%[                    ]       0  --.-KB/s               Installing PySpark 3.2.3 and Spark NLP 4.1.0
setup Colab for PySpark 3.2.3 

In [2]:
#Downloading Images
!wget -q https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/images/images.zip

In [3]:
import shutil
shutil.unpack_archive("images.zip", "images", "zip")

In [4]:
#Start Spark Session
import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.sql import SparkSession

In [5]:
spark = sparknlp.start()

In [6]:
data_df = spark.read.format("image").option("dropInvalid", value = True).load(path="images/images/")

In [7]:
#Pipeline with ViTForImageClassification
image_assembler = ImageAssembler() \
            .setInputCol("image") \
            .setOutputCol("image_assembler")

image_classifier = ViTForImageClassification \
    .pretrained() \
    .setInputCols("image_assembler") \
    .setOutputCol("class")

pipeline = Pipeline(stages=[
    image_assembler,
    image_classifier,
])

image_classifier_vit_base_patch16_224 download started this may take some time.
Approximate size to download 309.7 MB
[OK!]


In [8]:
model = pipeline.fit(data_df)

In [9]:
image_df = model.transform(data_df)
image_df.show()

+--------------------+--------------------+--------------------+
|               image|     image_assembler|               class|
+--------------------+--------------------+--------------------+
|{file:///content/...|[{image, file:///...|[{category, 0, 5,...|
|{file:///content/...|[{image, file:///...|[{category, 0, 11...|
|{file:///content/...|[{image, file:///...|[{category, 0, 55...|
|{file:///content/...|[{image, file:///...|[{category, 0, 2,...|
|{file:///content/...|[{image, file:///...|[{category, 0, 24...|
|{file:///content/...|[{image, file:///...|[{category, 0, 14...|
|{file:///content/...|[{image, file:///...|[{category, 0, 7,...|
|{file:///content/...|[{image, file:///...|[{category, 0, 8,...|
|{file:///content/...|[{image, file:///...|[{category, 0, 6,...|
|{file:///content/...|[{image, file:///...|[{category, 0, 1,...|
+--------------------+--------------------+--------------------+



## **Light Pipeline**

In [10]:
# To use light pipeline in ViT transformer, need to use the new method fullAnnotateImage, which can receive 3 kind of inputs:

# A path to a single image
# A path to a list of images
light_pipeline = LightPipeline(model)
annotations_result = light_pipeline.fullAnnotateImage("images/images/hippopotamus.JPEG")
annotations_result[0].keys()

dict_keys(['image_assembler', 'class'])

In [11]:
for result in annotations_result:
  image_assembler = result['image_assembler'][0]
  print(f"annotator_type: {image_assembler.annotator_type}")
  print(f"origin: {image_assembler.origin}")
  print(f"height: {image_assembler.height}")
  print(f"width: {image_assembler.width}")
  print(f"nChannels: {image_assembler.nChannels}")
  print(f"mode: {image_assembler.mode}")
  print(f"result size: {str(len(image_assembler.result))}")
  print(f"metadata: {image_assembler.metadata}")
  print(result['class'])

annotator_type: image
origin: images/images/hippopotamus.JPEG
height: 333
width: 500
nChannels: 3
mode: 16
result size: 499500
metadata: Map()
[Annotation(category, 0, 55, hippopotamus, hippo, river horse, Hippopotamus amphibius, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 7.2882756E-8, Some(beer glass) -> 9.0488925E-8, image -> 0, Some(damselfly) -> 1.9379786E-7, Some(turnstile) -> 6.8434524E-8, Some(cockroach, roach) -> 1.6622849E-7, height -> 333, Some(bulbul) -> 1.6930231E-7, Some(sea snake) -> 8.89582E-8, origin -> images/images/hippopotamus.JPEG, Some(mixing bowl) -> 1.2995402E-7, mode -> 16, None -> 1.3814622E-7, Some(whippet) -> 3.894023E-8, width -> 500, Some(buckle) -> 1.0061492E-7))]


In [12]:
#To send a list of images, we just difine a set of images
images = ["images/images/bluetick.jpg", "images/images/palace.JPEG", "images/images/hen.JPEG"]
annotations_result = light_pipeline.fullAnnotateImage(images)
annotations_result[0].keys()

dict_keys(['image_assembler', 'class'])

In [13]:
for result in annotations_result:
  print(result['class'])

[Annotation(category, 0, 7, bluetick, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 1.3846728E-6, Some(beer glass) -> 1.1807944E-6, image -> 0, Some(damselfly) -> 3.6875622E-7, Some(turnstile) -> 2.023695E-6, Some(cockroach, roach) -> 6.2982855E-7, height -> 500, Some(bulbul) -> 5.417509E-7, Some(sea snake) -> 5.7421556E-7, origin -> images/images/bluetick.jpg, Some(mixing bowl) -> 5.4001305E-7, mode -> 16, None -> 4.5454306E-7, Some(whippet) -> 1.2101438E-6, width -> 333, Some(buckle) -> 1.1306514E-6))]
[Annotation(category, 0, 5, palace, Map(nChannels -> 3, Some(lumbermill, sawmill) -> 6.3918545E-5, Some(beer glass) -> 8.879939E-6, image -> 0, Some(damselfly) -> 9.565577E-6, Some(turnstile) -> 6.315168E-5, Some(cockroach, roach) -> 1.125408E-5, height -> 334, Some(bulbul) -> 3.321073E-5, Some(sea snake) -> 1.0886038E-5, origin -> images/images/palace.JPEG, Some(mixing bowl) -> 2.6202975E-5, mode -> 16, None -> 2.6134943E-5, Some(whippet) -> 1.3805137E-5, width -> 500, Some(buckle)