In [6]:
import mlrun
import os

# Initialize the MLRun project object
project = mlrun.get_or_create_project('remote-artifacts', user_project=True, context='./')

# Required credentials, read from the environment:
#   AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, GOOGLE_APPLICATION_CREDENTIALS
# S3_BUCKET is optional and defaults to 'testbucket-igz' (see below).
AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY')
GOOGLE_APPLICATION_CREDENTIALS = os.environ.get('GOOGLE_APPLICATION_CREDENTIALS')

secrets = {'AWS_ACCESS_KEY_ID': AWS_ACCESS_KEY_ID,
           'AWS_SECRET_ACCESS_KEY': AWS_SECRET_ACCESS_KEY,
           'GOOGLE_APPLICATION_CREDENTIALS': GOOGLE_APPLICATION_CREDENTIALS}

# Explicit check instead of `assert` (asserts are skipped under `python -O`);
# report exactly which variables are unset or empty.
missing_credentials = [name for name, value in secrets.items() if not value]
if missing_credentials:
    raise ValueError(f"please provide credentials: {', '.join(missing_credentials)}")

# Store the credentials as project secrets so runs can request them by name.
project.set_secrets(secrets=secrets, provider='kubernetes')

S3_BUCKET = os.environ.get('S3_BUCKET', 'testbucket-igz')

# Build the URL directly: os.path.join is meant for filesystem paths, not URLs.
project.artifact_path = f's3://{S3_BUCKET}/remote-artifacts/'

> 2023-01-09 16:26:00,744 [info] loaded project remote-artifacts from MLRun DB


In [2]:
#mlrun: start-code
from pyspark.sql import SparkSession
import mlrun
import os
import pandas as pd

def get_dataitem(context: mlrun.MLClientCtx, key: str):
    """Find a run artifact by key and return an MLRun data item pointing at it.

    Every entry of ``context.artifacts`` is logged while scanning. The first
    entry whose ``metadata['key']`` equals ``key`` wins:
      * kind ``'model'``  -> ``spec['target_path'] + spec['model_file']``
      * any other kind    -> ``spec['target_path']``
    If nothing matches, 'Artifact not found' is logged and None is returned.
    """
    context.logger.info(key)
    for entry in context.artifacts:
        context.logger.info(entry)
        entry_kind = entry['kind']
        if entry['metadata'].get('key', None) != key:
            continue
        spec = entry['spec']
        if entry_kind == 'model':
            # Plain string concatenation; presumably target_path ends with a
            # separator for model artifacts — TODO confirm.
            return mlrun.get_dataitem(spec['target_path'] + spec['model_file'])
        return mlrun.get_dataitem(spec['target_path'])
    context.logger.info('Artifact not found')
    return None
    

def spark_func(context: mlrun.MLClientCtx):
    """Load the data item at the run's artifact path into Spark and write it back as Parquet.

    Builds (or reuses) a SparkSession configured for S3A with the run's
    ``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY`` secrets, reads
    ``context.artifact_path`` via ``mlrun.get_dataitem(...).as_df()``, shows the
    resulting Spark DataFrame and overwrites
    ``<artifact_path>/transactions_cut.parquet`` with it.

    NOTE(review): the input is the artifact directory itself — confirm this is
    the intended source dataset.

    Raises:
        ValueError: if either AWS secret is not available to the run.
    """
    access_key = context.get_secret('AWS_ACCESS_KEY_ID')
    secret_key = context.get_secret('AWS_SECRET_ACCESS_KEY')
    if not access_key or not secret_key:
        # Fail early with a clear message instead of configuring S3A with missing credentials.
        raise ValueError("AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY secrets are required")

    spark = (
        SparkSession.builder
        .appName("Python Spark SQL basic example")
        # The spark.hadoop. prefix makes Spark copy these keys into the Hadoop
        # configuration read by the S3A filesystem (same prefix as fs.s3a.impl).
        .config("spark.hadoop.fs.s3a.access.key", access_key)
        .config("spark.hadoop.fs.s3a.secret.key", secret_key)
        .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
        .config("com.amazonaws.services.s3.enableV4", True)
        .config("spark.driver.extraJavaOptions", "-Dcom.amazonaws.services.s3.enableV4=true")
        .getOrCreate()
    )

    df = mlrun.get_dataitem(context.artifact_path).as_df()
    spark_df = spark.createDataFrame(df)
    spark_df.show()

    # Join with exactly one '/' whether or not artifact_path ends with a slash.
    output_path = context.artifact_path.rstrip('/') + '/transactions_cut.parquet'
    spark_df.write.mode("overwrite").parquet(output_path)
    
#mlrun: end-code

In [3]:
# For Spark operator
from mlrun.runtimes import Spark3Runtime
# Spark3Runtime.deploy_default_image()

In [None]:
# Wrap the spark_func handler (the code marked by the "#mlrun: start-code" /
# "#mlrun: end-code" comments) as an MLRun Spark function.
# NOTE(review): '.spark-job-default-image' presumably has to be built once via
# Spark3Runtime.deploy_default_image() (commented out in the previous cell) — confirm.
sj = mlrun.code_to_function(name='spark_func',
                            kind='spark',
                            image='.spark-job-default-image',
                            handler='spark_func')

# set spark driver config (gpu_type & gpus=<number_of_gpus> supported too)
sj.with_driver_limits(cpu="1300m")
sj.with_driver_requests(cpu=1, mem="512m")

# set spark executor config (gpu_type & gpus=<number_of_gpus> are supported too)
sj.with_executor_limits(cpu="1400m")
sj.with_executor_requests(cpu=1, mem="512m")

# adds fuse, daemon & iguazio's jars support
sj.with_igz_spark()

# Register the function, then fetch it ONCE so the S3 mount is applied to the
# same object that is run below (not to a separate lookup).
project.set_function(name='spark_func', func=sj)
spark_fn = project.get_function('spark_func')
spark_fn.apply(mlrun.platforms.mount_s3())

# Request, by name, the secrets stored earlier with project.set_secrets().
task = mlrun.new_task().with_secrets("kubernetes", ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "GOOGLE_APPLICATION_CREDENTIALS"])

# Spark's Hadoop S3 connector uses the s3a:// scheme; convert explicitly instead
# of slicing characters off, which silently breaks for non-s3:// paths.
S3_SCHEME = 's3://'
if not project.artifact_path.startswith(S3_SCHEME):
    raise ValueError(f"expected an s3:// artifact path, got {project.artifact_path!r}")
spark_artifact_path = 's3a://' + project.artifact_path[len(S3_SCHEME):]

spark_fn.run(task, artifact_path=spark_artifact_path)

## Cleanup

In [None]:
import boto3
from urllib.parse import urlparse


# Delete every object this notebook stored under the project's S3 artifact
# prefix, then delete the MLRun project with the 'cascade' deletion strategy.
artifact_url = urlparse(project.artifact_path)
artifact_prefix = artifact_url.path.lstrip('/')

# Guard the bulk delete: an empty Prefix matches every object in the bucket,
# and a non-S3 URL would point boto3 at the wrong place.
if artifact_url.scheme not in ('s3', 's3a') or not artifact_url.netloc or not artifact_prefix:
    raise ValueError(f"refusing to bulk-delete from unexpected artifact path {project.artifact_path!r}")

s3 = boto3.resource('s3')
bucket = s3.Bucket(artifact_url.netloc)
bucket.objects.filter(Prefix=artifact_prefix).delete()
mlrun.get_run_db().delete_project(name=project.name, deletion_strategy='cascade')