# Spark OCR: preprocessing the RVL-CDIP dataset for the visual document classifier v2

This notebook is a sample of how to run the preprocessing on a Databricks cluster, loading the source data from S3 storage and saving the results back to it.

In [None]:
access_key = ""
secret_key = ""
encoded_secret_key = secret_key.replace("/", "%2F")
aws_bucket_name = ""
mount_name = "s3_dev"

try:
  dbutils.fs.mount("s3a://%s:%s@%s" % (access_key, encoded_secret_key, aws_bucket_name), "/mnt/%s" % mount_name)
except:
  dbutils.fs.unmount("/mnt/%s" % mount_name)
  dbutils.fs.mount("s3a://%s:%s@%s" % (access_key, encoded_secret_key, aws_bucket_name), "/mnt/%s" % mount_name)


In [0]:
%%bash
# -f: don't fail the cell when the file doesn't exist yet (first run / fresh workspace).
rm -f /dbfs/FileStore/johnsnowlabs/license.key

In [0]:
%sh
# Paste the license key between the quotes. ">" overwrites the file, so re-running this
# cell never leaves several keys in it (appending with ">>" would add a line each run).
echo "" > /dbfs/FileStore/johnsnowlabs/license.key

In [None]:
%%bash
LICENSE=/dbfs/FileStore/johnsnowlabs/license.key
# Confirm the key was written without printing it: notebook outputs are often shared,
# and dumping the file with `cat` would leak the license key.
if grep -q '[^[:space:]]' "$LICENSE" 2>/dev/null; then
  echo "license.key OK ($(wc -c < "$LICENSE") bytes)"
else
  echo "license.key is missing or blank: $LICENSE" >&2
  exit 1
fi

In [None]:
# Read every RVL-CDIP TIFF from the mounted bucket as raw bytes, one row per file
# (the "path" column of this frame is used later to look up labels).
imagePath = "dbfs:/mnt/s3_dev/ocr/datasets/rvl_cdip_full_/*.tif"
df = (
    spark.read
    .format("binaryFile")
    .load(imagePath)
)
n_images = df.count()
print(n_images)

In [0]:
import os

# Class index (second column of the RVL-CDIP label file) -> label name.
label_names = {0: "letter",
               1: "form",
               2: "email",
               3: "handwritten",
               4: "advertisement",
               5: "scientific_report",
               6: "scientific_publication",
               7: "specification",
               8: "file_folder",
               9: "news_article",
               10: "budget",
               11: "invoice",
               12: "presentation",
               13: "questionnaire",
               14: "resume",
               15: "memo"
}


def load_labels(labels_path, names=label_names):
    """Map each image file name listed in an RVL-CDIP label file to its label name.

    Each non-blank line is expected to look like "<relative/path/file.tif> <class index>".
    Only the file name (basename) is kept as the key so it can be matched against the
    basename of the image paths loaded with Spark.

    Args:
        labels_path: path of the label file on the driver's local filesystem.
        names: mapping from integer class index to label name.

    Returns:
        dict mapping file name -> label name.

    Raises:
        KeyError: if a class index is not present in ``names``.
        ValueError / IndexError: if a non-blank line is malformed.
    """
    labels = {}
    with open(labels_path) as label_file:
        for line in label_file:
            # split() with no argument tolerates repeated spaces/tabs and the newline.
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline at EOF) carry no label; skip them
                # instead of failing on fields[1].
                continue
            rel_path, class_idx = fields[0], int(fields[1])
            labels[os.path.basename(rel_path)] = names[class_idx]
    return labels


files_labelled = load_labels("/dbfs/mnt/s3_dev/ocr/datasets/rvl_cdip_train_labels.txt")

In [0]:
import os

from pyspark.sql.functions import udf


def get_label(fl):
    """Return the RVL-CDIP label name for an image path, or None if it has no label.

    The lookup key is the file name (basename) of ``fl``, matching the keys of
    ``files_labelled``. Unlabelled or null paths yield None so they can be dropped later.
    """
    if fl is None:
        return None
    # No print() for missing files: this function runs inside Spark executors, so the
    # message would go to executor logs (one line per file) rather than to this notebook.
    return files_labelled.get(os.path.basename(fl))


get_label_udf = udf(get_label)  # default return type is StringType

df = df.withColumn("act_label", get_label_udf("path"))
df.show(5)

In [None]:
# Keep only images that received a label; get_label returns null for the rest.
df = df.dropna(subset=["act_label"])
df.select("path", "act_label").show(truncate=False)

In [0]:
# Presumably chosen to spread the expensive OCR stage over many small partitions.
# NOTE(review): the value 18732 is not explained anywhere — document how it was picked.
NUM_PARTITIONS = 18732
df = df.repartition(NUM_PARTITIONS)
df.rdd.getNumPartitions()

In [None]:
from sparkocr.transformers import *
from sparkocr.enums import *
from pyspark.ml import PipelineModel

# Stage 1: decode the raw file bytes into an image of type TYPE_3BYTE_BGR.
binary_to_image = (
    BinaryToImage()
    .setOutputCol("image")
    .setImageType(ImageType.TYPE_3BYTE_BGR)
)

# Stage 2: OCR the image into HOCR; resolution is not ignored, and
# preserve_interword_spaces=0 is passed through as an OCR parameter.
img_to_hocr = (
    ImageToHocr()
    .setInputCol("image")
    .setOutputCol("hocr")
    .setIgnoreResolution(False)
    .setOcrParams(["preserve_interword_spaces=0"])
)

# Stage 3: split the HOCR output into tokens.
tokenizer = (
    HocrTokenizer()
    .setInputCol("hocr")
    .setOutputCol("token")
)

# OCR pipeline: bytes -> image -> hocr -> token.
ocr_stages = [binary_to_image, img_to_hocr, tokenizer]
pipeline1 = PipelineModel(stages=ocr_stages)

# Cache the OCR output since the preprocessing cell below reuses it.
# The image column is renamed to "orig_image", which is what the classifier reads.
ocr_df = pipeline1.transform(df).cache()
df = ocr_df.withColumnRenamed("image", "orig_image")
display(df)

In [None]:
from sparkocr.utils import get_vocabulary_dict

# LayoutLM v2 vocabulary; "," is passed as the second argument — presumably the
# delimiter used in the vocabulary file (confirm against get_vocabulary_dict).
vocab_file = "/dbfs/mnt/s3_dev/ocr/test_models/LayoutLM.v2.voc.txt"
vocab = get_vocabulary_dict(vocab_file, ",")

# Only the classifier's preprocessing is used here: OCR tokens + original image are
# turned into the model-input columns that the next cell writes out.
doc_class = (
    VisualDocumentClassifierV2()
    .setInputCols(["token", "orig_image"])
    .setOutputCol("label")
)
doc_class.setVocabulary(vocab)

# NOTE(review): presumably (batch, channels, height, width) of the model input — confirm.
input_shape = [1, 3, 224, 224]
result = doc_class.getPreprocessedDataset(df, input_shape).cache()

In [None]:
# Persist the model-input columns plus path and ground-truth label as parquet.
output_cols = [
    "path",
    "input_ids",
    "bbox",
    "image",
    "attention_mask",
    "token_type_ids",
    "act_label",
]
output_path = "dbfs:/mnt/s3_dev/ocr/datasets/RVL-CDIP/processed_data"
result.select(*output_cols).write.parquet(output_path)