# Rate-Distortion Optimization III
*Also check out [Part I](https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/1969271421694072/827948633476116/5612335034456173/latest.html) and [Part II](https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/1969271421694072/4057322776779238/5612335034456173/latest.html)*

In the [second notebook](https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/1969271421694072/4057322776779238/5612335034456173/latest.html), we used `pyscenedetect` and some calculus to find knee points of rate-VMAF curves for higher resolutions.

<img src="https://miro.medium.com/max/2228/1*1Q3Xx7CDywwdVbaLlpnRCg.png" alt="drawing" width="500"/>
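One common way to make "knee point" precise with calculus is to find where the curve bends most sharply, i.e. where the curvature κ(x) = |f''(x)| / (1 + f'(x)²)^(3/2) peaks. The sketch below only illustrates that idea and is not necessarily the Part II method. It assumes a hypothetical fitted curve f(x) = a·ln(x) + b and grid-searches for maximum curvature. For this family, setting dκ/dx = 0 gives x = a/√2, which makes a convenient check. Curvature depends on how the axes are scaled, so rates and VMAF scores would normally be normalized first.

```python
import math


def curvature(d1: float, d2: float) -> float:
    """Curvature of a plane curve y = f(x), given f'(x) and f''(x)."""
    return abs(d2) / (1.0 + d1 * d1) ** 1.5


def knee_of_log_curve(a: float, x_min: float, x_max: float, steps: int = 10_000) -> float:
    """Return the x in [x_min, x_max] that maximizes the curvature of f(x) = a*ln(x) + b.

    Assumes a > 0 and x_min > 0. The offset b does not affect curvature, so it is omitted.
    """
    best_x, best_k = x_min, -1.0
    for i in range(steps + 1):
        x = x_min + (x_max - x_min) * i / steps
        d1 = a / x          # f'(x)
        d2 = -a / (x * x)   # f''(x)
        k = curvature(d1, d2)
        if k > best_k:
            best_x, best_k = x, k
    return best_x


a = 20.0
print(knee_of_log_curve(a, 0.1, 100.0))  # close to a / sqrt(2), about 14.14
```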

The ability to predict these distinguished points will help constrain our grid search, reducing the compute cost of rate-distortion optimization.
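To see how a predicted knee constrains the search, the hypothetical `constrained_qp_grid` helper below keeps only the QPs within a small window around the prediction instead of encoding every value. It assumes an integer QP range of 0 to 51 (as in H.264/HEVC) and a window radius of 4. Both are illustrative choices, and predictions that fall outside the range are clamped to it.

```python
from typing import List


def constrained_qp_grid(predicted_qp: float, radius: int = 4,
                        qp_min: int = 0, qp_max: int = 51) -> List[int]:
    """Hypothetical helper: integer QPs within `radius` of a predicted knee.

    The center is clamped to [qp_min, qp_max] first, so out-of-range
    predictions still yield a non-empty window at the edge of the range.
    """
    center = min(max(round(predicted_qp), qp_min), qp_max)
    lo = max(qp_min, center - radius)
    hi = min(qp_max, center + radius)
    return list(range(lo, hi + 1))


full_grid = list(range(0, 52))        # exhaustive search over every QP
narrow = constrained_qp_grid(31.4)    # QPs 27..35 around the predicted knee
print(len(full_grid), len(narrow))    # prints: 52 9
```

Under these assumptions a shot needs 9 encodes instead of 52. A wider radius trades compute for robustness to prediction error.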

In this notebook, we explore regressing these points from content information like images sampled from each shot.

```python
import io
import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input

import pyspark.sql.types as T
import pyspark.sql.functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType, udf

from pyspark.ml import Pipeline
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.linalg import Vectors, VectorUDT
```

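Next, we define helper functions that load a ResNet50 model pretrained on ImageNet. Its embeddings provide a simple featurization for each image. The weights are loaded once on the driver and broadcast, so each executor can rebuild the model without fetching them again.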

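```python
# Load ResNet50 (ImageNet weights, no classification head) on the driver
# and broadcast its weights so every executor can rebuild the same model.
model = ResNet50(include_top=False)
bc_model_weights = sc.broadcast(model.get_weights())

def model_fn():
  """Rebuild ResNet50 on an executor from the broadcast weights."""
  model = ResNet50(weights=None, include_top=False)
  model.set_weights(bc_model_weights.value)
  return model

def preprocess(content):
  """Decode PNG bytes into a 224x224 RGB tensor ready for ResNet50."""
  img = tf.io.decode_png(content, channels=3)
  # Nearest-neighbor resizing keeps the uint8 dtype; cast to float32 before
  # preprocess_input, which subtracts the ImageNet channel means.
  arr = tf.cast(tf.image.resize(img, [224, 224], method='nearest'), tf.float32)
  return preprocess_input(arr)

def featurize_series(model, content_series):
  """Featurize a pandas Series of PNG byte strings into flattened embeddings."""
  batch = np.stack(content_series.map(preprocess))
  preds = model.predict(batch)
  # Each prediction is a 7x7x2048 feature map; flatten it into one vector.
  output = [p.flatten() for p in preds]
  return pd.Series(output)

@pandas_udf('array<float>', PandasUDFType.SCALAR_ITER)
def featurize_udf(content_series_iter):
  """Iterator pandas UDF: build the model once, then featurize every batch."""
  model = model_fn()
  for content_series in content_series_iter:
    yield featurize_series(model, content_series)
```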
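The iterator form of the pandas UDF is what keeps featurization affordable. `model_fn()` runs once per call, and Spark hands each call an iterator over many Arrow batches (roughly a partition's worth), so the model is not rebuilt for every batch. The plain-Python sketch below mimics that pattern without Spark or TensorFlow. `map_batches_with_state` and `load_model` are hypothetical stand-ins, and `load_model` counts how often it is called.

```python
from typing import Callable, Iterable, Iterator, List


def map_batches_with_state(batches: Iterable[List[int]],
                           init: Callable[[], Callable[[int], int]]) -> Iterator[List[int]]:
    """Plain-Python analogue of an iterator-style pandas UDF:
    build the expensive object once, then reuse it for every batch."""
    model = init()  # runs once per iterator, like model_fn() in featurize_udf
    for batch in batches:
        yield [model(x) for x in batch]


init_calls = 0


def load_model() -> Callable[[int], int]:
    """Stand-in for an expensive load such as ResNet50 + set_weights."""
    global init_calls
    init_calls += 1
    return lambda x: 2 * x


batches = [[1, 2], [3], [4, 5, 6]]
out = list(map_batches_with_state(batches, load_model))
print(out, init_calls)  # [[2, 4], [6], [8, 10, 12]] 1
```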

Next, a helper UDF serializes each image to a PNG byte array, and a second UDF converts arrays of floats into MLlib vectors.

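```python
@udf(returnType=T.BinaryType())
def image_to_byte_array(content):
  """Serialize an image column value to PNG bytes for tf.io.decode_png."""
  image = content.to_pil()
  buffer = io.BytesIO()
  # Always write PNG: preprocess() decodes with tf.io.decode_png, and a PIL
  # image built in memory may have image.format == None.
  image.save(buffer, format="PNG")
  return buffer.getvalue()

# Convert an array<float> column into an MLlib DenseVector.
list_to_vector_udf = udf(lambda l: Vectors.dense(l), VectorUDT())
```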
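Because `preprocess` decodes with `tf.io.decode_png`, every value in the `content` column must be PNG-encoded. A PNG file always begins with the 8-byte signature `\x89PNG\r\n\x1a\n`. In base64 that signature is `iVBORw0KGgo=`, which is why each `content` value in the output below starts with `iVBORw0KGgo`. The hypothetical `is_png` helper sketches a cheap sanity check on raw bytes.

```python
import base64

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def is_png(data: bytes) -> bool:
    """Hypothetical check: does this byte string start with the PNG signature?"""
    return data[:8] == PNG_SIGNATURE


print(base64.b64encode(PNG_SIGNATURE))  # b'iVBORw0KGgo='
print(is_png(PNG_SIGNATURE + b"rest of file"), is_png(b"\xff\xd8\xff"))  # True False
```

The second test input is a JPEG start-of-image marker, which `tf.io.decode_png` would reject.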

Here, we read the Parquet files storing the (image, knee QP) pairs we extracted in Part II, and rename the `knee_QPs` column to `label`, the default label column for MLlib estimators.

```python
df = (spark.read.parquet("/mnt/vmafs.parquet")
      .withColumn("content", image_to_byte_array(F.col("image")))
      .withColumnRenamed("knee_QPs", "label"))
display(df)
```

| video | resolution | segment | bitrates | rd_curve | log_params | label | image | content |
|---|---|---|---|---|---|---|---|---|
| VideoStream(uri='/dbfs/mnt/vids/Jaws.mp4') | 2560:1080 | Segment(start_fno=1034, end_fno=1123) | List(100, 300, 500, 800) | List(10.990573, 43.128649, 62.126679, 76.250367) | List(35.73323, 31.384493, -163.41652) | 31.384493 | Image('1056.png') | iVBORw0KGgo… (truncated) |
| VideoStream(uri='/dbfs/mnt/vids/Jaws.mp4') | 2560:1080 | Segment(start_fno=743, end_fno=779) | List(100, 300, 500, 800) | List(13.794634, 30.596304, 60.352211, 78.12543) | List(123.27073, 863.97205, -834.61383) | 863.97205 | Image('766.png') | iVBORw0KGgo… (truncated) |
| VideoStream(uri='/dbfs/mnt/vids/Jaws.mp4') | 2560:1080 | Segment(start_fno=388, end_fno=469) | List(100, 300, 500, 800) | List(57.546788, 71.097714, 75.739294, 80.404416) | List(8.527234, -47.758415, 23.820723) | -47.758415 | Image('434.png') | iVBORw0KGgo… (truncated) |
| VideoStream(uri='/dbfs/mnt/vids/Jaws.mp4') | 2560:1080 | Segment(start_fno=1728, end_fno=1753) | List(100, 300, 500, 800) | List(24.550767, 52.682413, 67.692335, 78.676923) | List(26.447737, 1.6189202, -97.726814) | 1.6189202 | Image('1731.png') | iVBORw0KGgo… (truncated) |
| VideoStream(uri='/dbfs/mnt/vids/Jaws.mp4') | 2560:1080 | Segment(start_fno=663, end_fno=685) | List(100, 300, 500, 800) | List(57.090229, 80.330496, 86.557698, 91.124407) | List(9.02127, -83.742455, 31.931927) | -83.742455 | Image('680.png') | iVBORw0KGgo… (truncated) |
| VideoStream(uri='/dbfs/mnt/vids/Jaws.mp4') | 2560:1080 | Segment(start_fno=1123, end_fno=1154) | List(100, 300, 500, 800) | List(53.162244, 74.736938, 82.575255, 87.952249) | List(11.616462, -63.938427, 11.50484) | -63.938427 | Image('1132.png') | iVBORw0KGgo… (truncated) |
| VideoStream(uri='/dbfs/mnt/vids/Jaws.mp4') | 2560:1080 | Segment(start_fno=1416, end_fno=1438) | List(100, 300, 500, 800) | List(26.136607, 48.638136, 60.693532, 74.490998) | List(32.310654, 104.241646, -145.67456) | 104.241646 | Image('1427.png') | iVBORw0KGgo… (truncated) |
| VideoStream(uri='/dbfs/mnt/vids/Jaws.mp4') | 2560:1080 | Segment(start_fno=685, end_fno=709) | List(100, 300, 500, 800) | List(23.43186, 65.246006, 78.005745, 86.142427) | List(17.654144, -79.91771, -29.538025) | -79.91771 | Image('694.png') | iVBORw0KGgo… (truncated) |
| VideoStream(uri='/dbfs/mnt/vids/Jaws.mp4') | 2560:1080 | Segment(start_fno=1669, end_fno=1713) | List(100, 300, 500, 800) | List(26.56869, 58.948064, 72.493913, 85.643628) | List(26.280107, -16.765522, -89.61804) | -16.765522 | Image('1672.png') | iVBORw0KGgo… (truncated) |
| VideoStream(uri='/dbfs/mnt/vids/Jaws.mp4') | 2560:1080 | Segment(start_fno=1019, end_fno=1034) | List(100, 300, 500, 800) | List(77.305476, 91.620222, 92.354226, 92.354226) | List(0.80608046, -99.99999, 87.31276) | -99.99999 | Image('1031.png') | iVBORw0KGgo… (truncated) |


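```python
# Featurize the images on 16 partitions and convert the embeddings to MLlib vectors.
features_df = (df.repartition(16)
               .select(featurize_udf("content").alias("features"), F.col("label"))
               .withColumn("features", list_to_vector_udf(F.col("features"))))

# Hold out about 10% of the shots for testing; a fixed seed makes the split reproducible.
(trainingData, testData) = features_df.randomSplit([0.9, 0.1], seed=42)
```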
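Each `features` vector is the flattened ResNet50 output, which is a very wide input for a random forest. With `include_top=False` and 224x224 inputs, ResNet50 produces a 7 x 7 x 2048 feature map. The arithmetic below shows the flattened size. The pure-Python sketch shows global average pooling, which collapses the spatial axes to 2,048 values per image. Pooling is our suggestion rather than part of the pipeline above; Keras exposes it as `ResNet50(include_top=False, pooling='avg')`.

```python
from typing import List

# ResNet50 with include_top=False maps a 224x224 RGB image to a 7 x 7 x 2048 feature map.
H, W, C = 7, 7, 2048
print(H * W * C)  # 100352 values per image after flatten()


def global_average_pool(fmap: List[List[List[float]]]) -> List[float]:
    """Average an H x W x C feature map over its spatial axes, giving C values."""
    h, w, c = len(fmap), len(fmap[0]), len(fmap[0][0])
    sums = [0.0] * c
    for row in fmap:
        for pixel in row:
            for k, v in enumerate(pixel):
                sums[k] += v
    return [s / (h * w) for s in sums]


# Tiny 2 x 2 x 3 feature map for illustration.
toy = [[[1.0, 0.0, 2.0], [3.0, 0.0, 2.0]],
       [[5.0, 4.0, 2.0], [7.0, 0.0, 2.0]]]
print(global_average_pool(toy))  # [4.0, 1.0, 2.0]
```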

As a simple baseline, we fit MLlib's `RandomForestRegressor` with its default hyperparameters.

```python
# labelCol defaults to "label", which is why knee_QPs was renamed when reading the data.
rf = RandomForestRegressor(featuresCol="features")
pipeline = Pipeline(stages=[rf])
model = pipeline.fit(trainingData)
```

Finally, we evaluate our baseline model.

```python
predictions = model.transform(testData)
predictions.select("prediction", "label", "features").show(5)

# Root mean squared error between predicted and true knee QPs on the held-out set.
evaluator = RegressionEvaluator(
    labelCol="label", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)

# The fitted random forest is the pipeline's only stage.
rfModel = model.stages[0]
print(rfModel)
```
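With `metricName="rmse"`, the evaluator reports the root mean squared error, sqrt((1/n) · Σ (predictionᵢ − labelᵢ)²). The plain-Python sketch below spells out that formula on small made-up numbers. The function is named `root_mean_squared_error` so it does not clash with the `rmse` variable in the notebook.

```python
import math
from typing import Sequence


def root_mean_squared_error(predictions: Sequence[float], labels: Sequence[float]) -> float:
    """sqrt(mean((prediction - label)^2)) over paired values."""
    if len(predictions) != len(labels) or not labels:
        raise ValueError("need two non-empty sequences of equal length")
    squared = [(p - y) ** 2 for p, y in zip(predictions, labels)]
    return math.sqrt(sum(squared) / len(squared))


print(root_mean_squared_error([3.0, -3.0], [0.0, 0.0]))  # 3.0
```

Because errors are squared, a single outlying label, such as the 863.97205 in the sample rows above, can dominate the score. `RegressionEvaluator` also accepts `metricName="mae"`, which is less sensitive to such outliers.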

### Future improvements might include:
* other image features such as entropy, video features such as optical flow, and metadata such as genre
* other regressors, embeddings, or end-to-end models
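Image entropy, the first idea in the list, is cheap to compute. A minimal sketch, assuming 8-bit grayscale pixel values supplied as a flat sequence, computes the Shannon entropy of the intensity histogram in bits. Flat, low-detail shots score near 0, and busy shots approach the 8-bit maximum.

```python
import math
from collections import Counter
from typing import Iterable


def histogram_entropy(pixels: Iterable[int]) -> float:
    """Shannon entropy (in bits) of the intensity histogram of a grayscale image."""
    counts = Counter(pixels)
    total = sum(counts.values())
    if total == 0:
        return 0.0
    # Sum of p * log2(1/p) over observed intensities.
    return sum((n / total) * math.log2(total / n) for n in counts.values())


flat = [128] * 64                       # a uniform gray patch
half_and_half = [0] * 32 + [255] * 32   # two equally common intensities
print(histogram_entropy(flat), histogram_entropy(half_and_half))  # 0.0 1.0
```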

After refining our regressor, we can scale up training and use models like this to optimize the bitrate ladder.

By building a model to regress the knee QPs of rate-VMAF curves for high-resolution encodings, we can reduce the workload of rate-distortion optimization.