# Local PySpark on SageMaker Studio

This notebook shows how to run local PySpark code within a SageMaker Studio notebook. For this example we use the **Data Science - Python3** image and kernel, but this approach should work for any kernel within SageMaker Studio, including bring-your-own (BYO) custom images.

## Setup
There are two things that must be done to enable local PySpark within SageMaker Studio:
1. Make sure a Java installation is available. The easiest way to install a JDK and set the proper paths is with conda.
2. Append the container's hostname to `/etc/hosts` so that Spark can resolve it and communicate properly.

In [None]:
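# Setup - Run only once per Kernel App
# Install a JDK into the kernel's conda environment (provides the `java` binary Spark needs)
%conda install openjdk -y
# Map the container's hostname to 127.0.0.1 unless /etc/hosts already mentions it
!grep `hostname` /etc/hosts >/dev/null || echo 127.0.0.1 `hostname` >> /etc/hosts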
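The `grep ... || echo ...` one-liner above is an idempotent "append if missing" edit of `/etc/hosts`. A minimal Python sketch of the same logic, working on the file contents as a string so it can be tried without touching the real file, is shown below. The function names are hypothetical, and unlike `grep`, which matches the hostname anywhere on any line (including comments), this version only counts a hostname listed as a name field on a non-comment line.

```python
def hosts_has_entry(hosts_text: str, hostname: str) -> bool:
    """Return True if a non-comment line of /etc/hosts-style text lists `hostname`."""
    for line in hosts_text.splitlines():
        # Drop any trailing comment, then split "IP name [aliases...]" into fields
        fields = line.split("#", 1)[0].split()
        if len(fields) >= 2 and hostname in fields[1:]:
            return True
    return False


def ensure_hosts_entry(hosts_text: str, hostname: str) -> str:
    """Return hosts_text with a '127.0.0.1 <hostname>' line appended only if it is missing."""
    if hosts_has_entry(hosts_text, hostname):
        return hosts_text
    if hosts_text and not hosts_text.endswith("\n"):
        hosts_text += "\n"
    return hosts_text + f"127.0.0.1 {hostname}\n"
```

For a read-only check, you could pass `open("/etc/hosts").read()` and `socket.gethostname()` to `hosts_has_entry`.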

## Install PySpark

In [None]:
%pip install pyspark==3.2.1

## Use S3 Data with Local PySpark
* By adding the `hadoop-aws` jar to our Spark config, we can access S3 datasets using the `s3a://` URI scheme.
* Since we are already authenticated in SageMaker Studio, we can use the assumed SageMaker execution role for all S3 reads and writes by setting the credential provider to `ContainerCredentialsProvider`.
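The paths Spark reads from and writes to are ordinary S3 locations written with the `s3a://` scheme. As a small sketch, the hypothetical helpers below build such a URI and derive the `data/input/<file name>` key layout this notebook uses when uploading the dataset. The empty-bucket check assumes you want to fail early if the `s3_bucket` placeholder has not been filled in.

```python
from posixpath import basename


def s3a_uri(bucket: str, key: str) -> str:
    """Build an s3a:// URI for Spark's S3A connector (hypothetical helper)."""
    if not bucket:
        # Guards against forgetting to fill in the bucket name placeholder
        raise ValueError("bucket must be a non-empty S3 bucket name")
    # Strip a leading slash so the URI does not become 's3a://bucket//key'
    return f"s3a://{bucket}/{key.lstrip('/')}"


def input_key(local_path: str, prefix: str = "data/input") -> str:
    """Return '<prefix>/<file name of local_path>', like object_name.split('/')[-1]."""
    return f"{prefix.rstrip('/')}/{basename(local_path)}"
```

For example, `s3a_uri("my-bucket", input_key("./../../data/LD2011_2014.csv.gz"))` yields the same path that the CSV read later in the notebook formats by hand.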

### Download data

In [None]:
! mkdir -p ./../../data

In [None]:
! aws s3 cp s3://ws-assets-prod-iad-r-iad-ed304a55c2ca1aee/9e2e09b0-7142-4ab8-8b89-531349b817b9/deep-ar-electricity/LD2011_2014.csv.gz ./../../data

### Upload Data to S3

In [None]:
import boto3

In [None]:
s3_client = boto3.client("s3")

In [None]:
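# Replace with the name of an S3 bucket your execution role can write to
s3_bucket = ""

# Local path of the downloaded dataset
object_name = "./../../data/LD2011_2014.csv.gz"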

In [None]:
s3_client.upload_file(object_name, s3_bucket, "data/input/{}".format(object_name.split("/")[-1]))

***

## Work with Local PySpark

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

import random
from pyspark.sql import SparkSession
import pyspark.sql.functions as fn
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType, StringType, IntegerType

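## Important: PySpark version 3.2.x

Run the cell below if you are using PySpark >= 3.2.x.

With `pyspark >= 3.2.x`, you need to provide the `hadoop-aws` jar, version >= 3.2.x, to interact with AWS services such as Amazon S3. Because `hadoop-aws` must be the same version as the Hadoop libraries it runs against, check the `hadoop-*` jar versions under `pyspark/jars/` in your PySpark installation and adjust the package coordinate if they differ.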
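Which of the two session cells applies depends only on the installed PySpark version. A hypothetical helper that encodes the rule described in this section (>= 3.2 uses `hadoop-aws` via `spark.jars.packages`, 2.4 uses the `sagemaker_pyspark` jars) might look like this; it is a sketch of the decision, not part of any library.

```python
def s3a_setup_for(pyspark_version: str) -> str:
    """Return which Spark session setup in this notebook matches a PySpark version string."""
    # Compare only major.minor, e.g. "3.2.1" -> (3, 2)
    major, minor = (int(part) for part in pyspark_version.split(".")[:2])
    if (major, minor) >= (3, 2):
        return "hadoop-aws via spark.jars.packages"
    if (major, minor) == (2, 4):
        return "sagemaker_pyspark classpath jars"
    raise ValueError(f"PySpark {pyspark_version} is not covered by this notebook")
```

In the notebook you would call it as `s3a_setup_for(pyspark.__version__)` after `import pyspark`.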

In [None]:
# Import pyspark and build Spark session
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.appName("PySparkApp")
    .config("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.2.2")
    # Hadoop (S3A) settings need the "spark.hadoop." prefix to reach the Hadoop configuration
    .config(
        "spark.hadoop.fs.s3a.aws.credentials.provider",
        "com.amazonaws.auth.ContainerCredentialsProvider",
    )
    .getOrCreate()
)

print(spark.version)

***

## Important: PySpark version 2.4.x

Run the cells below if you are using PySpark ~= 2.4.x. Run only one of the two Spark session cells: `getOrCreate()` returns an already-running session, so switching PySpark versions requires restarting the kernel.

With PySpark ~= 2.4.x, you have to provide the list of `aws-java-sdk` jars needed to interact with AWS services such as Amazon S3.

You can install the Python module `sagemaker_pyspark==1.4.2` and pass its list of jars as the driver classpath when creating the Spark session.

In [None]:
%pip install pyspark==2.4.1

In [None]:
%pip install sagemaker_pyspark==1.4.2

In [None]:
import sagemaker_pyspark

classpath = ":".join(sagemaker_pyspark.classpath_jars())

In [None]:
# Import pyspark and build Spark session
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.appName("PySparkApp")
    .config("spark.driver.extraClassPath", classpath)
    # Hadoop (S3A) settings need the "spark.hadoop." prefix to reach the Hadoop configuration
    .config(
        "spark.hadoop.fs.s3a.aws.credentials.provider",
        "com.amazonaws.auth.ContainerCredentialsProvider",
    )
    .getOrCreate()
)

print(spark.version)

***

In [None]:
schema = "date TIMESTAMP, client STRING, value FLOAT"

In [None]:
# FAILFAST aborts the read on the first record that does not match the schema
df = spark \
    .read \
    .schema(schema) \
    .options(sep=",", header=True, mode="FAILFAST", timestampFormat="yyyy-MM-dd HH:mm:ss") \
    .csv("s3a://{}/data/input/{}".format(s3_bucket, object_name.split("/")[-1]))

In [None]:
df.show()

In [None]:
# Sample input rows (date, client, value): one reading per client every 15 minutes
# {"date": "2011-01-01 00:15:00", "client": "MT_001", "value": "0.0"}
# {"date": "2011-01-01 00:30:00", "client": "MT_001", "value": "0.1"}
# {"date": "2011-01-01 00:45:00", "client": "MT_001", "value": "0.1"}

In [None]:
# resample from 15min intervals to one hour to speed up training
df = df \
    .groupBy(fn.date_trunc("HOUR", fn.col("date")).alias("date"), fn.col("client")) \
    .agg(fn.mean("value").alias("value"))
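The Spark aggregation above truncates each timestamp to its hour and averages the readings per (hour, client). A small pure-Python sketch of the same resampling, useful only for reasoning about a handful of rows, is below; the function name is hypothetical and it is not a replacement for the Spark job.

```python
from collections import defaultdict
from datetime import datetime
from statistics import mean
from typing import Dict, Iterable, Tuple


def resample_hourly_mean(
    rows: Iterable[Tuple[datetime, str, float]],
) -> Dict[Tuple[datetime, str], float]:
    """Average (timestamp, client, value) readings per (hour, client)."""
    buckets = defaultdict(list)
    for ts, client, value in rows:
        # Truncate to the start of the hour, as fn.date_trunc("HOUR", ...) does
        hour = ts.replace(minute=0, second=0, microsecond=0)
        buckets[(hour, client)].append(value)
    return {key: mean(values) for key, values in buckets.items()}
```

Note that a reading at exactly 01:00 lands in the 01:00 bucket, while 00:15, 00:30 and 00:45 are averaged into the 00:00 bucket.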

In [None]:
# Create a dictionary that integer-encodes each client.
# Sorting makes the encoding reproducible, since distinct() returns rows in no guaranteed order.
client_list = df.select("client").distinct().collect()
client_list = sorted(rec["client"] for rec in client_list)
client_encoder = dict(zip(client_list, range(len(client_list))))

In [None]:
random_client_list = random.sample(client_list, 6)

random_clients_pandas_df = df \
                            .where(fn.col("client").isin(random_client_list)) \
                            .groupBy("date") \
                            .pivot("client").max().toPandas()

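# toPandas() returns rows in no particular order, so sort by date after setting the index
random_clients_pandas_df = random_clients_pandas_df.set_index("date").sort_index()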

In [None]:
random_clients_pandas_df

Next, check the data for gaps. If a series only had data on certain days, for example Monday to Friday (as with stock trading activity), we would have to insert NaN data points to account for the missing Saturdays and Sundays. A quick way to check for gaps is to count each client's observations by day of the week. Running the commands below shows that the difference between the highest and the lowest per-weekday count is 24, i.e. one day of hourly observations. That is fine: it just means that the last data point falls midweek. The counts also match across all clients, so this dataset does not appear to have any gaps.
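The reasoning above can be checked on a toy series. The sketch below uses hypothetical helpers: one counts observations per day of week using Spark's `dayofweek` numbering (1 = Sunday through 7 = Saturday), and one lists hourly timestamps missing between the first and last observation, i.e. the points that would need NaN filling. It assumes the timestamps are already truncated to the hour.

```python
from collections import Counter
from datetime import datetime, timedelta
from typing import Iterable, List


def spark_dayofweek(ts: datetime) -> int:
    """Day of week numbered like Spark's dayofweek(): 1 = Sunday ... 7 = Saturday."""
    return ts.isoweekday() % 7 + 1


def weekday_counts(timestamps: Iterable[datetime]) -> Counter:
    """Count one client's observations per Spark-style day of week."""
    return Counter(spark_dayofweek(ts) for ts in timestamps)


def missing_hours(timestamps: Iterable[datetime]) -> List[datetime]:
    """Return hourly timestamps absent between the first and last observation."""
    present = set(timestamps)
    t, end = min(present), max(present)
    missing = []
    while t <= end:
        if t not in present:
            missing.append(t)
        t += timedelta(hours=1)
    return missing
```

For an hourly series covering exactly eight days that starts on Saturday 2011-01-01, Saturday is counted twice, so the highest and lowest weekday counts differ by 24 even though no hour is missing, which is the same pattern described above.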

In [None]:
weekday_counts = df \
                .withColumn("dayofweek", fn.dayofweek("date")) \
                .groupBy("client") \
                .pivot("dayofweek") \
                .count()

In [None]:
weekday_counts.show(5) # show aggregates for several clients
weekday_counts.agg(*[fn.min(col) for col in weekday_counts.columns[1:]]).show() # show minimum counts of observations across all clients
weekday_counts.agg(*[fn.max(col) for col in weekday_counts.columns[1:]]).show() # show maximum counts of observations across all clients

In [None]:
train_start_date = df.select(fn.min("date").alias("date")).collect()[0]["date"]
test_start_date = "2014-01-01"
end_date = df.select(fn.max("date").alias("date")).collect()[0]["date"]

In [None]:
# split the data into train and test set
train_data = df.where(fn.col("date") < test_start_date)
test_data = df.where(fn.col("date") >= test_start_date)

In [None]:
# pandasUDFs require an output schema. This one matches the format required for DeepAR
dataset_schema = StructType([StructField("target", ArrayType(DoubleType())),
                             StructField("cat", ArrayType(IntegerType())),
                             StructField("start", StringType())
                            ])
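DeepAR's JSON Lines training format expects one JSON object per line with a `start` timestamp string, a `target` array of values and an optional `cat` array of integer categories, which is what `dataset_schema` mirrors. A hypothetical helper that serializes one such record with the standard library, handy for eyeballing the expected shape, could be:

```python
import json
from typing import Iterable


def deepar_json_line(start: str, target: Iterable[float], cat: Iterable[int]) -> str:
    """Serialize one time series as a single DeepAR JSON Lines record."""
    record = {
        "start": start,                       # e.g. "2011-01-01 00:00:00"
        "target": [float(v) for v in target],  # observed values in time order
        "cat": [int(c) for c in cat],          # integer-encoded categorical feature(s)
    }
    # json.dumps produces a single line, as JSON Lines requires
    return json.dumps(record)
```

For example, `deepar_json_line("2011-01-01 00:00:00", [0.0, 0.1], [3])` returns one line containing the three fields.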

In [None]:
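@pandas_udf(dataset_schema, PandasUDFType.GROUPED_MAP)
def prep_deep_ar(pdf):
    # Sort by date and reset the index so that label 0 is the earliest observation
    pdf = pdf.sort_values(by="date").reset_index(drop=True)
    client_name = pdf.loc[0, "client"]
    targets = pdf["value"].values.tolist()
    cat = [client_encoder[client_name]]
    start = str(pdf.loc[0, "date"])

    # One output row per client: the full target series, its category and its start time
    return pd.DataFrame([[targets, cat, start]], columns=["target", "cat", "start"])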
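Rows inside a Spark group may arrive in any order, and `sort_values` reorders rows without relabelling them. The sketch below (hypothetical function and sample data) shows why looking up the first row by index label after sorting can return the wrong `start` date, while a position-based lookup, or resetting the index first, returns the earliest one.

```python
import pandas as pd


def first_dates(pdf: pd.DataFrame):
    """Return (label-based, position-based) 'date' of the first row after sorting by date."""
    sorted_pdf = pdf.sort_values(by="date")
    # .loc[0] looks up index *label* 0, which still points at the row that was first before sorting
    by_label = sorted_pdf.loc[0, "date"]
    # .iloc[0] takes the first row by *position*, i.e. the earliest date after sorting
    by_position = sorted_pdf["date"].iloc[0]
    return by_label, by_position


# Hypothetical group whose rows arrive latest-first
example = pd.DataFrame(
    {"date": pd.to_datetime(["2011-01-01 01:00", "2011-01-01 00:00"]), "value": [2.0, 1.0]}
)
print(first_dates(example))
```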

In [None]:
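# PandasUDFType.GROUPED_MAP is deprecated in Spark 3.x; there, the equivalent is
# train_data.groupBy("client").applyInPandas(<undecorated function>, schema=dataset_schema)
train_data = train_data.groupBy("client").apply(prep_deep_ar)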

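### Write processed data to S3

Spark writes the output as a directory of JSON Lines part files under the given S3 prefix.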

In [None]:
file_name_processed = object_name.split("/")[-1].split(".")[0] + "_processed.json"

In [None]:
train_data.write.mode("overwrite").json("s3a://{}/data/output/{}".format(s3_bucket, file_name_processed))