## Spark pipeline as the data source

In [None]:
# Preparation - define variables.
# Every setting is read from the environment so that no endpoint or secret is
# hard-coded in the notebook.
import os

# Address handed to Client(...) in the login cell; it was previously used
# there without ever being defined.
API = os.getenv("API")
PROJECT_NAME = os.getenv("PROJECT_NAME", "demo_spark_pipeline")
REFRESH_TOKEN = os.getenv("REFRESH_TOKEN")
# Value of the SPARK_DEPS_JAR variable; passed to the "spark.jars" setting
# when the Spark session is built.
SPARK_DEPS_AZURE = os.getenv("SPARK_DEPS_JAR")

In [None]:
# Install PySpark (pinned to 3.4.1) and the h2o-featurestore package into the
# notebook environment via a shell command.
! pip install pyspark==3.4.1 h2o-featurestore

In [None]:
from pyspark.sql import SparkSession

# Maven coordinates resolved when the session starts: the Hadoop AWS and
# Azure connectors plus Delta Lake.
spark_packages = ",".join(
    [
        "org.apache.hadoop:hadoop-aws:3.3.1",
        "io.delta:delta-core_2.12:2.4.0",
        "org.apache.hadoop:hadoop-azure:3.3.1",
    ]
)

# Build (or reuse) a local Spark session with the Delta SQL extension and
# Delta catalog configured. SPARK_DEPS_AZURE comes from the preparation cell.
spark = (
    SparkSession.builder
    .master("local")
    .config("spark.jars.packages", spark_packages)
    .config("spark.jars", SPARK_DEPS_AZURE)
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .getOrCreate()
)

# Keep the notebook output readable: log errors only.
spark.sparkContext.setLogLevel("ERROR")

In [None]:
# Login and authenticate
# The wildcard import presumably supplies the names used unqualified in this
# and later cells (Client, S3Credentials, CSVFile) - they are not imported
# anywhere else in the notebook.
from featurestore import *
# API must already be defined when this cell runs; presumably it is the
# Feature Store endpoint address - confirm.
client = Client(API, secure=True)
# Authenticate with the token read from the REFRESH_TOKEN environment variable.
client.auth.set_auth_token(REFRESH_TOKEN)

In [None]:
# Define credentials for the data source.
# The three S3 settings are read from identically named environment variables;
# any that is unset comes back as None.
S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION = (
    os.getenv(var_name)
    for var_name in ("S3_ACCESS_KEY", "S3_SECRET_KEY", "S3_REGION")
)
credentials = S3Credentials(S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION)

In [None]:
# Delete the project if it already exists so the demo starts from a clean state.
# This is best effort: a failure here (presumably because the project does not
# exist yet) must not stop the notebook. `except Exception` is used instead of
# a bare `except` so that KeyboardInterrupt / SystemExit still propagate, and
# the reason is printed rather than silently discarded.
try:
    client.projects.get(PROJECT_NAME).delete()
except Exception as exc:
    print(f"Skipping project cleanup for {PROJECT_NAME!r}: {exc}")

## Create a project

In [None]:
# Create a project named PROJECT_NAME; the returned handle is used below to
# register and look up feature sets.
project = client.projects.create(PROJECT_NAME)

In [None]:
# Specify source: a CSV file addressed through the s3a:// scheme. It is used
# both for schema extraction and for ingestion.
source = CSVFile("s3a://feature-store-test-data/creditcard.csv")

## Extract schema from the source

In [None]:
# Extract schema from the CSV source, passing the S3 credentials defined above.
schema = client.extract_schema_from_source(source, credentials)

In [None]:
# Show the extracted schema as the cell output.
schema

## Register a feature set 

In [None]:
# Register the input feature set "fs_spark_pipeline" in the project using the
# extracted schema.
fs = project.feature_sets.register(schema, "fs_spark_pipeline")

## Ingest data

In [None]:
# Ingest data from the CSV source into the input feature set, using the same
# S3 credentials as for schema extraction.
fs.ingest(source, credentials)

## Create a Spark pipeline

In [None]:
# Imports for building the Spark ML pipeline and wrapping it for the Feature Store.
from featurestore import SparkPipeline
from pyspark.ml import Pipeline
from pyspark.ml.feature import SQLTransformer

In [None]:
# The pipeline has a single SQL stage that averages the age and LIMIT_BAL
# columns for each education group. __THIS__ is the SQLTransformer placeholder
# for the input dataset.
query = (
    "select avg(age) AS ave_age, avg(LIMIT_BAL) AS ave_limit_bal "
    "from __THIS__ "
    "group by education"
)
transformer = SQLTransformer(statement=query)
spark_pipeline = Pipeline(stages=[transformer])

# Wrap the Spark ML pipeline so it can be passed to the Feature Store client.
pipeline_transformation = SparkPipeline(spark_pipeline)

## Extract schema from the Spark pipeline

In [None]:
# Extract the derived schema from the Spark pipeline: the input feature set
# `fs` is passed as the only parent, together with the wrapped pipeline.
derived_schema = client.extract_derived_schema([fs], pipeline_transformation)

## Register a feature set based on the Spark pipeline transformation

In [None]:
# Register the derived feature set "derived_fs_spark_pipeline" in the project
# from the schema produced by the pipeline transformation.
derived_fs = project.feature_sets.register(derived_schema, "derived_fs_spark_pipeline")

In [None]:
# Wait for data from the input feature set to be propagated into the derived one.
from featurestore.core.job_types import INGEST

# Block on every active ingest job of the derived feature set, not only the
# first, so retrieval does not start while a job is still running. When there
# are no active jobs the loop is a no-op.
jobs = derived_fs.get_active_jobs(INGEST)
for job in jobs:
    job.wait_for_result()

## Retrieve data

In [None]:
# Look the derived feature set up again by name, then retrieve its data as a
# Spark DataFrame on the local session.
derived_fs_handle = project.feature_sets.get(derived_fs.feature_set_name)
df = derived_fs_handle.retrieve().as_spark_frame(spark)

In [None]:
# Print the first rows of the derived feature set's data.
df.show()

## Cleanup

In [None]:
# Delete the demo project created by this notebook.
client.projects.get(PROJECT_NAME).delete()
