# Use Ray in PySpark

Spark + AI Summit 2020 talk, 
[Dean Wampler](mailto:dean@anyscale.com)

This notebook demonstrates one way to integrate Ray and PySpark applications: embedding Ray in a PySpark user-defined function (_UDF_). The use case simulates a _data governance_ requirement, where we want to trace each record processed by a PySpark job.

Another, more conventional way to meet this requirement is to run a separate web service and make remote calls to it (usually over HTTP). This approach is demonstrated in the `ray-serve` directory. (See the [README](README.md) for details.)

This notebook embeds Ray in a UDF, where the Ray cluster is co-resident on the same nodes as PySpark. We'll actually just use a single machine, but the results generalize to real cluster deployments with minor changes (noted where applicable).

Why use this approach instead of the standalone system? Here are the pros and cons:

**Pros:**
* Avoiding a network/HTTP call may be more efficient in many cases.
* Fewer services to manage. Once the PySpark and Ray clusters are set up, you can let them do all the scaling and distribution required. Spark handles the data partitions, while Ray handles distribution of the other tasks and object graphs (for distributed state).

**Cons:**
* You might prefer explicitly separate services for runtime visibility and independent management. For example, it's easier to upgrade a separate web service behind a router, whereas in the example here the PySpark job and the data governance "hook" are more tightly coupled.

You can learn more about Ray [here](http://ray.io).

> **Note:** This notebook connects to a running Ray cluster. Start Ray ahead of time with `ray start --head`.

> **Note:** Requires Java 8!

In [None]:
!java -version

In [None]:
import json, time
import pyspark
import ray

In [None]:
from pyspark.sql.types import DataType, BooleanType, NullType, IntegerType, StringType, MapType

In [None]:
from pyspark.sql.functions import udf

Define a `DataGovernanceSystem` Ray actor that represents our governance system. (This is also defined in the file `data_governance_system.py`.) All it does is add each reported `id` to an internal collection. 

In a more realistic implementation, this class would be a "hook" that forwards the ids and other useful metadata asynchronously to a real governance system, like [Apache Atlas](http://atlas.apache.org/#/). 
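
As a rough sketch of what such a forwarding hook might look like, the class below buffers ids and hands them off in batches to a sender callable. The names `ForwardingHook`, `send_batch`, and `batch_size` are hypothetical; they are not part of this notebook or of Apache Atlas. For simplicity the sketch forwards synchronously when a batch fills up, so it only illustrates the batching part, not truly asynchronous delivery. In the notebook, a class like this could be decorated with `@ray.remote` in the same way as `DataGovernanceSystem`.

```python
from typing import Callable, List


class ForwardingHook:
    """Buffers record ids and forwards them in batches (illustrative only)."""

    def __init__(self, send_batch: Callable[[List[int]], None], batch_size: int = 10):
        self.send_batch = send_batch    # Hypothetical sender, e.g. a client for the real system.
        self.batch_size = batch_size
        self.buffer: List[int] = []
        self.forwarded = 0

    def log(self, id_to_log: int) -> int:
        """Buffer one id, flushing when the batch is full. Returns the total seen so far."""
        self.buffer.append(id_to_log)
        if len(self.buffer) >= self.batch_size:
            self.flush()
        return self.forwarded + len(self.buffer)

    def flush(self) -> None:
        """Forward any buffered ids and clear the buffer."""
        if self.buffer:
            self.send_batch(list(self.buffer))
            self.forwarded += len(self.buffer)
            self.buffer.clear()


# Stand-in sender that just collects the batches it receives.
received_batches: List[List[int]] = []
hook = ForwardingHook(received_batches.append, batch_size=4)
for record_id in range(10):
    hook.log(record_id)
hook.flush()  # Forward the final partial batch.
print(received_batches)
```

Batching keeps the per-record cost low, at the price of a final explicit `flush()` so that a partial batch is not left behind.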

In [None]:
@ray.remote
class DataGovernanceSystem:
    def __init__(self, name = 'DataGovernanceSystem'):
        self.name = name
        self.ids = []
        self.start_time = time.time()

    def log(self, id_to_log):
        """
        Log record ids that have been processed.
        Returns the new count.
        """
        self.ids.append(id_to_log)
        return self.get_count()

    def get_ids(self):
        """Return the ids logged. Don't call this if the list is long!"""
        return self.ids

    def get_count(self):
        """Return the count of ids logged."""
        return len(self.ids)

    def reset(self):
        """Forget all ids that have been logged."""
        self.ids = []

    def get_start_time(self):
        return self.start_time

    def get_up_time(self):
        return time.time() - self.start_time

Define a simple `Record` type with a `record_id` field, used for logging to `DataGovernanceSystem`, and an opaque `data` field with everything else.

In [None]:
class Record:
    def __init__(self, record_id, data):
        self.record_id = record_id
        self.data = data
    def __str__(self):
        return f'Record(record_id={self.record_id},data={self.data})'
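
A possible alternative, shown only as a sketch, is to declare the same shape with a standard-library dataclass. The name `DataRecord` is hypothetical and chosen so that it does not clash with the `Record` class above. A dataclass generates `__init__` and `__repr__` automatically, and `dataclasses.astuple` yields plain tuples in declared field order, which can be convenient when rows need to be handed to another system.

```python
from dataclasses import dataclass, astuple


@dataclass
class DataRecord:
    record_id: int   # Used for logging to the governance system.
    data: str        # Opaque payload with everything else.


data_records = [DataRecord(i, f'str: {i}') for i in range(3)]
rows = [astuple(r) for r in data_records]   # (record_id, data) tuples in declared field order
print(data_records[0])
print(rows)
```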

Now initialize Ray in this application. Passing `address='auto'` tells Ray to connect to the running cluster. (If this node isn't part of that cluster, i.e., Ray isn't already running on this node, then pass the correct server address and port.)

In [None]:
ray.init(address='auto', ignore_reinit_error=True) # The `ignore_reinit_error=True` lets us rerun this cell without error...

In [None]:
print(f'Click here to open the Ray Dashboard: http://{ray.get_webui_url()}')

In [None]:
actor_name = 'dgs'
gov = DataGovernanceSystem.options(name=actor_name, detached=True).remote()
gov

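The cell above creates a named, detached actor, so any process connected to the same Ray cluster can look it up by name. Now retrieve it by name, as code running somewhere "else" would, and log a few test records.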

In [None]:
dgs = ray.util.get_actor(actor_name)
test_records = [Record(i, f'data: {i}') for i in range(3)] 
for record in test_records:
    print(record)
    dgs.log.remote(record.record_id)

In [None]:
def gov_status():
    dgs = ray.util.get_actor(actor_name)
    print(f'count:   {ray.get(dgs.get_count.remote())}')
    print(f'ids:     {ray.get(dgs.get_ids.remote())}')
    print(f'up time: {ray.get(dgs.get_up_time.remote())}')

In [None]:
gov_status()

Reset the actor's state:

In [None]:
dgs.reset.remote()
gov_status()

In [None]:
def log_record(record_id):
    """
    This function will become a UDF for Spark. Since each Spark task runs in a separate process, 
    we'll initialize Ray, connecting to the running cluster, if it is not already initialized.
    Returns a dict with a flag for whether this call initialized Ray and the new logged count.
    """
    did_initialization = 0
    if not ray.is_initialized():
        # The password must match the one used by the running Ray cluster.
        ray.init(address='auto', redis_password='5241590000000000')
        did_initialization = 1

    dgs = ray.util.get_actor(actor_name)
    count_id = dgs.log.remote(record_id)   # Runs asynchronously, returning an object id for a future.
    count = ray.get(count_id)              # But this blocks until the actor replies!
    return {'initialized': did_initialization, 'count': count}
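
The initialize-once logic in `log_record` can be tried without Spark or Ray. The sketch below replaces `ray.is_initialized()` and `ray.init()` with a hypothetical module-level flag, and `connect_once` and `log_record_stub` are made-up names used only for illustration. Within one process, only the first call reports that it performed the initialization.

```python
_connected = False   # Stand-in for the state that ray.is_initialized() reports.


def connect_once() -> int:
    """Connect on the first call in this process; return 1 if this call connected, else 0."""
    global _connected
    if _connected:
        return 0
    _connected = True   # Stand-in for ray.init(address='auto', ...).
    return 1


def log_record_stub(record_id: int) -> dict:
    """Mimics the shape of log_record's return value without calling Ray."""
    return {'initialized': connect_once(), 'id': record_id}


results = [log_record_stub(i) for i in range(4)]
print(results)
```

Each Python worker process that Spark starts has its own copy of such state, so each worker would report `1` exactly once.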

In [None]:
spark = pyspark.sql.SparkSession.builder \
    .master("local[*]") \
    .appName("Data Governance Example") \
    .getOrCreate()

In [None]:
log_record_udf = udf(log_record, MapType(StringType(), IntegerType()))

In [None]:
num_records = 50

In [None]:
records = [Record(i, f'str: {i}') for i in range(num_records)] 

In [None]:
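# Build explicit (id, data) tuples so the column names line up with the fields,
# regardless of how Spark infers a schema from arbitrary Python objects.
df = spark.createDataFrame([(r.record_id, r.data) for r in records], ['id', 'data'])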

In [None]:
df_ray = df.select('id', 'data', log_record_udf('id').alias('logged'))

In [None]:
display(df_ray)

In [None]:
%time df_ray.show(n=num_records, truncate=False)

As you can see in the `logged` column, there are several PySpark processes (four on my laptop), each of which initializes Ray once.

You will probably also see that the `count` values are out of order. Several PySpark tasks update the single `DataGovernanceSystem` actor concurrently, so the order in which their calls arrive is not deterministic. However, a Ray actor executes its method calls one at a time, so no updates are lost and the final count is correct!
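
To make this behavior concrete, here is a standard-library analogy rather than Ray itself. The hypothetical `CounterActor` below drains a mailbox queue on a single worker thread, which is roughly how one actor serializes calls from many callers. Four threads stand in for four Spark tasks; all names are illustrative assumptions.

```python
import queue
import threading


class CounterActor:
    """Processes messages one at a time from a mailbox, like a single actor."""

    def __init__(self):
        self.ids = []
        self.mailbox = queue.Queue()
        self.worker = threading.Thread(target=self._run, daemon=True)
        self.worker.start()

    def _run(self):
        while True:
            item = self.mailbox.get()
            if item is None:            # Sentinel: stop the worker.
                break
            record_id, reply = item
            self.ids.append(record_id)  # Only this thread touches self.ids.
            reply.put(len(self.ids))

    def log(self, record_id):
        """Send one id and block until the actor replies with the new count."""
        reply = queue.Queue(maxsize=1)
        self.mailbox.put((record_id, reply))
        return reply.get()

    def stop(self):
        self.mailbox.put(None)
        self.worker.join()


actor = CounterActor()
counts = {}   # record_id -> count returned for that id


def task(partition):
    for record_id in partition:
        counts[record_id] = actor.log(record_id)


# Four "Spark tasks", each logging its own slice of 50 ids.
partitions = [range(start, 50, 4) for start in range(4)]
threads = [threading.Thread(target=task, args=(p,)) for p in partitions]
for t in threads:
    t.start()
for t in threads:
    t.join()
actor.stop()

print(sorted(counts.values()) == list(range(1, 51)))  # Each count from 1 to 50 is handed out once.
print(len(actor.ids))                                 # Total number of ids logged.
```

Which id receives which count varies from run to run, but because the mailbox is drained serially, every count is issued exactly once and the total always matches the number of records.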

In [None]:
gov_status()

In [None]:
gov.reset.remote()
gov_status()

In [None]:
ray.shutdown()