In [None]:
# Mount the user's Google Drive into the Colab VM at /content/drive.
from google.colab import drive
drive.mount('/content/drive')
# Mounting only exposes the Drive files on the VM's filesystem. The project
# folder is added to sys.path (so that Python files stored in it can be
# imported) by the path-setup cell further down.

### Setting up

Change the following parameters in the cell below:

- `n_infer`: number of videos to run inference on
- `n_partitions`: number of parallel workers


In [None]:
# Number of videos in this run; it also selects the input file
# downloaded_video_ids_<n_infer>.csv that is read further down.
n_infer = 100 # [50, 100, 500, 1000]
# Number of partitions the video list is coalesced into before tagging.
n_partitions = 3 # [1,2,3,4]

### Accessing the directory 

Google Drive's link for the working directory is `https://drive.google.com/drive/folders/1XFVNSJR3ZiekIigyXs5cGe1mN8JRoeMe?usp=sharing`.

To run this notebook:
- Right-click on the `CLIP-spark` directory.
- Choose `Add shortcut to drive` (this creates a shortcut in your `My Drive` directory).

Then run the cells below.

In [None]:
import sys
import os 
# Name of the project folder (the shortcut created inside My Drive).
FOLDERNAME = 'CLIP-spark'
# NOTE: FOLDERNAME is assigned a literal on the line above, so this assert
# can never fail as written; it only matters if the assignment is edited.
assert FOLDERNAME is not None, "[!] Enter the foldername."
# Make Python files stored in the project folder importable.
sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))

# Change into the videos directory, then back to the project root. Only the
# second %cd determines the working directory for the following cells.
# NOTE(review): the first %cd looks like a quick check that data/v exists —
# confirm, or drop it.
%cd /content/drive/My\ Drive/$FOLDERNAME/data/v/
%cd /content/drive/My\ Drive/$FOLDERNAME

### Setting up Java, CLIP, PySpark, MongoDB

While the cell below runs, press `Enter` when prompted to finish the installation process.

In [None]:
'''
Run this cell only 1 time (even if runtime is restarted)
'''

# Registers the webupd8team Java PPA, refreshes the package index and
# installs the oracle-java8-installer package.
# NOTE(review): a later cell installs openjdk-8-jdk-headless and points
# JAVA_HOME at that JDK, so this Oracle install looks redundant — confirm
# whether it is still needed. The PPA may also no longer be available; verify.
!sudo add-apt-repository ppa:webupd8team/java
!sudo apt-get update
!sudo apt-get install oracle-java8-installer

In [None]:
# Install the third-party Python packages for this notebook; CLIP is
# installed straight from its GitHub repository.
# NOTE(review): none of the packages is version-pinned, so a rerun may pick
# up different releases — consider pinning.
!pip install decord
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install pytube 

In [None]:
# Install a headless OpenJDK 8 quietly (stdout is discarded).
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
# The download below is commented out, so the Spark tarball must already be
# present in the current working directory for the extraction to succeed.
# !rm spark-3.0.1-bin-hadoop2.7.tgz
# !wget -q https://archive.apache.org/dist/spark/spark-3.0.1/spark-3.0.1-bin-hadoop2.7.tgz
!tar xf spark-3.0.1-bin-hadoop2.7.tgz
!pip install -q findspark
import os
# Point the environment at the JDK installed above and at the extracted
# Spark distribution. SPARK_HOME is a relative path, so it resolves against
# the current working directory.
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "spark-3.0.1-bin-hadoop2.7"
import findspark
# Presumably makes the pyspark package under SPARK_HOME importable — see the
# findspark documentation.
findspark.init()
# The return value of find() is discarded here (it is not the cell's last
# expression, so nothing is displayed).
findspark.find()
from pyspark.sql import SparkSession
# NOTE(review): array_contains is not used in the visible cells — confirm it
# is still needed.
from pyspark.sql.functions import array_contains

### Connecting to MongoDB server

In [None]:
import os
from urllib.parse import quote_plus


def build_mongo_uri(username, password):
  """Return the MongoDB connection string for `svp_database.video_tags`.

  Args:
    username: database user name (plain text).
    password: database password (plain text).

  Returns:
    The `mongodb+srv://` URI with both credentials percent-encoded, so that
    reserved characters such as `@`, `:` or `/` cannot corrupt the URI.
  """
  return (
      f'mongodb+srv://{quote_plus(username)}:{quote_plus(password)}'
      '@svp-cluster.1uzpyjf.mongodb.net/svp_database.video_tags?retryWrites=true'
  )


# SECURITY: credentials should come from the environment (MONGODB_USERNAME /
# MONGODB_PASSWORD). The literal fallbacks only keep the notebook running
# exactly as before; they are committed in plain text, so rotate them and
# remove the fallbacks.
USERNAME = os.environ.get('MONGODB_USERNAME', 'sri')
PASSWORD = os.environ.get('MONGODB_PASSWORD', 'sri')
CNCT_STR = build_mongo_uri(USERNAME, PASSWORD)
# Data source name handed to DataFrameWriter.format() when writing the tags.
FORMAT = 'com.mongodb.spark.sql.DefaultSource'

In [None]:
# Absolute location of the downloaded videos inside the mounted Drive folder.
videos_dir = '/content/drive/My Drive/{}/data/v'.format(FOLDERNAME)
# Count the entries through videos_dir rather than a relative 'data/v', so
# the result does not depend on the current working directory.
num_videos = len(os.listdir(videos_dir))
print('downloaded videos:', num_videos)

In [None]:
import torch
import clip 

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"
device

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, FloatType
from pyspark.sql.functions import spark_partition_id

# Create (or reuse) the local Spark session for inference. The MongoDB
# connector package is requested through spark.jars.packages, and the same
# connection string is used for both reading and writing.
spark = (
    SparkSession.builder
    .master("local[5]")
    .appName("inference")
    .config("spark.driver.memory", '15g')
    .config('spark.ui.port', '4050')
    .config('spark.mongodb.input.uri', CNCT_STR)
    .config('spark.mongodb.output.uri', CNCT_STR)
    .config('spark.jars.packages', 'org.mongodb.spark:mongo-spark-connector_2.12:3.0.1')
    .getOrCreate()
)

In [None]:
# CSV (with a header row) selected by the configured run size.
video_id_csv = f'downloaded_video_ids_{n_infer}.csv'
df_small = spark.read.option('header', True).csv(video_id_csv)

In [None]:
# Output schema of create_paths: two nullable string columns, the video id
# and the path built from it.
paths_schema = (
    StructType()
    .add("video_id", StringType(), True)
    .add("video_path", StringType(), True)
)

def create_paths(pdf):
  """Attach the relative file path of every video in a pandas DataFrame.

  Args:
    pdf: pandas DataFrame with a `video_id` column (here: one group handed
      over by `applyInPandas`).

  Returns:
    A new DataFrame equal to `pdf` plus a `video_path` column holding
    `data/v/<video_id>` for each row; `pdf` itself is not modified.
  """
  # Iterate the column directly instead of using iterrows(): iterrows builds
  # a Series per row, which is slow and upcasts values when the frame mixes
  # numeric dtypes (an integer id would be rendered as "1.0").
  video_paths = [f'data/v/{video_id}' for video_id in pdf['video_id']]
  return pdf.assign(video_path=video_paths)

# Run create_paths once per distinct video_id to add the video_path column.
# Despite its name, `pdf` is a Spark DataFrame (the result of applyInPandas).
pdf = df_small.groupby('video_id').applyInPandas(
    create_paths, schema=paths_schema
)
# pdf.show()
print('pdf original partitions:', pdf.rdd.getNumPartitions())
# Merge the partitions down to the configured worker count.
# NOTE(review): coalesce only reduces the number of partitions; if `pdf`
# ever has fewer than n_partitions partitions, the count stays at the lower
# value — confirm that is acceptable.
pdf_with_partition = pdf.coalesce(n_partitions)

num_partitions = pdf_with_partition.rdd.getNumPartitions()
print('num_partitions:', num_partitions)

# Record, per row, the id of the partition the row currently lives in.
pdf_with_partition = pdf_with_partition.withColumn('partition', spark_partition_id())
pdf_with_partition.show()

In [None]:
from clip_tagger import CLIPTag
def tagging_func(pdf):
  """Tag every video of one pandas group with a freshly loaded CLIP model.

  Args:
    pdf: pandas DataFrame with a `video_path` column.

  Returns:
    A new DataFrame equal to `pdf` plus a `tags` column that holds, per row,
    the value returned by `CLIPTag.tag_video` for that row's path.
  """
  # A new model/tagger pair is built on every call, i.e. once per group.
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
  tagger = CLIPTag(clip_model, clip_preprocess)
  tags_per_video = []
  for _, video in pdf.iterrows():
    tags_per_video.append(tagger.tag_video(video.video_path))
  return pdf.assign(tags=tags_per_video)

# Output schema of tagging_func: the input columns plus the list of tags.
# NOTE(review): `partition` is declared as FloatType although the column is
# filled from spark_partition_id() upstream — confirm that a float (rather
# than an integer) column is intended.
tags_schema = (
    StructType()
    .add("video_id", StringType(), True)
    .add("video_path", StringType(), True)
    .add("partition", FloatType(), True)
    .add("tags", ArrayType(StringType()), True)
)

import time

start = time.time()

# One group per physical partition: tagging_func runs once for each
# partition id and tags all videos that live in that partition.
tags_df = pdf_with_partition.groupby(spark_partition_id().alias("_pid")).applyInPandas(
    tagging_func, schema=tags_schema
)

# show() evaluates only as many partitions as it needs for its preview rows,
# so timing it under-reports the run. Without caching, every later action
# (e.g. the database write) would also recompute the tagging from scratch.
# Cache the result and force a full pass with count() so the timer covers
# all videos and later actions reuse the cached tags.
tags_df = tags_df.cache()
tags_df.count()

# `end` holds the elapsed wall-clock time in seconds (name kept as-is).
end = time.time() - start
num_partitions = pdf_with_partition.rdd.getNumPartitions()
print(f'num_partitions: {num_partitions} time for {n_infer} videos: {end/60:.3f} mins')

# Preview of the cached result; no inference is repeated here.
tags_df.show()

In [None]:
# Write the tags to the database server through the data source named by
# FORMAT, using mode('overwrite').
# NOTE(review): 'overwrite' presumably replaces what is already stored in
# the target collection — confirm that this is intended.
try:
  tags_df.write.format(FORMAT).mode('overwrite').save()
except Exception as exc:
  # Best effort: keep the notebook running, but report the actual cause.
  # Catching Exception (not a bare except) lets KeyboardInterrupt and
  # SystemExit propagate.
  print(f'Error writing to database: {exc!r}')

In [None]:
# Measured inference times in minutes for 100-500 videos; p<k> is the run
# with k workers. The last p1 value was mistyped as 1108.79; the plotting
# cell below lists the same series with 108.79.
xticks = [100, 200, 300, 400, 500]
p1 = [20.98, 41.70, 65.607, 90, 108.79]
p2 = [10.76, 23.62, 32.56, 42.1, 62.104]
p3 = [7.64, 13.69, 19.55, 25.44, 38.40]
p4 = [12.96, 24.75, 32.75, 46.45, 65.56]

In [None]:
import matplotlib.pyplot as plt 
import numpy as np 

# Inference times in minutes per number of videos; p<k> is the series for
# k workers.
xticks = [100, 200, 300, 400, 500]
p1 = [20.98, 41.70, 65.607, 90, 108.79]
p2 = [10.76, 23.62, 32.56, 42.1, 62.104]
p3 = [7.64, 13.69, 19.55, 25.44, 38.40]
p4 = [12.96, 24.75, 32.75, 46.45, 65.56]

# Evenly spaced positions on the x axis; the video counts become tick labels.
x = np.arange(5)

fig, ax = plt.subplots()

labelled_series = (
    ('#workers=1', p1),
    ('#workers=2', p2),
    ('#workers=3', p3),
    ('#workers=4', p4),
)
for series_label, timings in labelled_series:
  ax.plot(x, timings, label=series_label)

ax.set_xticks(x)
ax.set_xticklabels(xticks)
ax.set_xlabel('#videos')
ax.set_ylabel('Inference Time (mins)')
ax.legend(loc='best')

# The figure is written to disk only and closed without being displayed.
fig.savefig('infer.pdf')
plt.close(fig)