# RayDP: Distributed Spark MLlib Model Training on Snowpark Container Services

This notebook demonstrates how to use RayDP to run distributed, Spark MLlib-based model training on a Ray cluster in Snowpark Container Services.

## Setup and Imports

In [1]:
import ray
import raydp
import pprint
import warnings
import logging    
import time
import os
import numpy as np
import socket
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, rand, when, round as spark_round
from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression
from pyspark.ml.regression import RandomForestRegressor, LinearRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator, RegressionEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml import Pipeline
import snowflake.connector
from snowflake.snowpark import Session
from snowflake.ml.data.data_connector import DataConnector
from snowflake.ml.ray.datasink import SnowflakeTableDatasink
# Report the library versions this notebook was executed with.
for _label, _module in (("Ray", ray), ("RayDP", raydp)):
    print(f"{_label} version: {_module.__version__}")

  import pkg_resources


Ray version: 2.46.0
RayDP version: 1.6.2


## Initialize Snowpark Session

In [2]:
def connection() -> snowflake.connector.SnowflakeConnection:
    """Open a Snowflake connector connection for this notebook.

    If the token file ``/snowflake/session/token`` exists, OAuth is used. The
    token comes from that file, and host/port/account come from the
    ``SNOWFLAKE_HOST``, ``SNOWFLAKE_PORT`` and ``SNOWFLAKE_ACCOUNT`` environment
    variables. Otherwise user/password authentication is used, from
    ``SNOWFLAKE_USER`` and ``SNOWFLAKE_PASSWORD``. Both paths read database and
    schema from ``SNOWFLAKE_DATABASE`` / ``SNOWFLAKE_SCHEMA`` and keep the
    session alive.

    Returns:
        An open ``snowflake.connector.SnowflakeConnection``.
    """
    token_path = "/snowflake/session/token"
    if os.path.isfile(token_path):
        # Read via a context manager so the token file handle is closed.
        with open(token_path, 'r') as token_file:
            token = token_file.read()
        creds = {
            'host': os.getenv('SNOWFLAKE_HOST'),
            'port': os.getenv('SNOWFLAKE_PORT'),
            'protocol': "https",
            'account': os.getenv('SNOWFLAKE_ACCOUNT'),
            'authenticator': "oauth",
            'token': token,
            'warehouse': "LARGE_WH",
            'database': os.getenv('SNOWFLAKE_DATABASE'),
            'schema': os.getenv('SNOWFLAKE_SCHEMA'),
            'client_session_keep_alive': True
        }
    else:
        creds = {
            'account': os.getenv('SNOWFLAKE_ACCOUNT'),
            'user': os.getenv('SNOWFLAKE_USER'),
            'password': os.getenv('SNOWFLAKE_PASSWORD'),
            # Configurable via SNOWFLAKE_WAREHOUSE. Defaults to the same
            # warehouse that the token-based branch uses.
            'warehouse': os.getenv('SNOWFLAKE_WAREHOUSE', "LARGE_WH"),
            'database': os.getenv('SNOWFLAKE_DATABASE'),
            'schema': os.getenv('SNOWFLAKE_SCHEMA'),
            'client_session_keep_alive': True
        }

    return snowflake.connector.connect(**creds)

def get_session() -> Session:
    """Create a Snowpark ``Session`` wrapping a connection obtained from ``connection()``."""
    snowflake_conn = connection()
    session_builder = Session.builder.configs({"connection": snowflake_conn})
    return session_builder.create()

In [3]:
# Open the Snowpark session that every Snowflake table read/write below uses.
session = get_session()

In [4]:
# Sanity check: show which database the session resolved to.
session.get_current_database()

'"RAYDP_SIS_DB"'

In [6]:
# Attach to the already-running Ray cluster (address="auto") instead of starting
# a local one. Re-running this cell is tolerated (ignore_reinit_error=True), and
# worker logs are not forwarded into the notebook (log_to_driver=False).
ray.init(address="auto", ignore_reinit_error=True, log_to_driver=False)

2025-07-06 23:00:24,761	INFO worker.py:1694 -- Connecting to existing Ray cluster at address: 10.244.10.75:6379...
2025-07-06 23:00:24,773	INFO worker.py:1879 -- Connected to Ray cluster. View the dashboard at [1m[32m10.244.10.75:8265 [39m[22m


0,1
Python version:,3.10.17
Ray version:,2.46.0
Dashboard:,http://10.244.10.75:8265


In [7]:
# Snapshot cluster-wide resource totals and per-node details; the totals feed
# the Spark sizing heuristic further down.
# NOTE(review): ray.nodes() may also list nodes that are no longer alive; check
# each entry's 'Alive' field before counting nodes.
cluster_resources = ray.cluster_resources()
nodes = ray.nodes()

In [8]:
# List the resources Ray reports for each node (1-based numbering).
print("  Cluster Information:")
for node_number, node in enumerate(nodes, start=1):
    print(f"   Node {node_number}: {node.get('Resources', {})}")
print()

  Cluster Information:
   Node 1: {'CPU': 6.0, 'GPU': 1.0, 'accelerator_type:A10G': 1.0, 'object_store_memory': 8659282329.0, 'node:10.244.11.75': 1.0, 'memory': 20204992103.0}
   Node 2: {'node:10.244.10.75': 1.0, 'CPU': 6.0, 'node:__internal_head__': 1.0, 'GPU': 1.0, 'accelerator_type:A10G': 1.0, 'object_store_memory': 8420179968.0, 'memory': 19647086592.0}
   Node 3: {'CPU': 6.0, 'node:10.244.10.203': 1.0, 'GPU': 1.0, 'accelerator_type:A10G': 1.0, 'object_store_memory': 8659241779.0, 'memory': 20204897485.0}
   Node 4: {'CPU': 6.0, 'accelerator_type:A10G': 1.0, 'object_store_memory': 8659227033.0, 'GPU': 1.0, 'node:10.244.10.139': 1.0, 'memory': 20204863079.0}



## Explore and prepare the data

In [9]:
# Reference the source table and preview a single row as a pandas frame.
raw_data_snowdf = session.table("SPARK_MLLIB_SAMPLE_DATASET")
raw_data_snowdf.limit(1).to_pandas()

Unnamed: 0,ID,FEATURE_0,FEATURE_1,FEATURE_2,FEATURE_3,FEATURE_4,FEATURE_5,FEATURE_6,FEATURE_7,FEATURE_8,...,FEATURE_11,FEATURE_12,FEATURE_13,FEATURE_14,FEATURE_15,FEATURE_16,FEATURE_17,FEATURE_18,FEATURE_19,TARGET
0,0,1.764052,-0.394469,0.973217,2.240893,1.867558,-0.977278,0.950088,-0.151357,-0.103219,...,1.454274,0.761038,0.121675,0.443863,0.333674,1.494079,-0.205158,0.313068,-0.854096,1


In [19]:
# 70/30 train/test split of the raw table, with a fixed seed.
train_snowdf, test_snowdf = raw_data_snowdf.random_split(weights=[0.70, 0.30], seed=0)

In [20]:
# Materialize both splits as tables, overwriting any previous run, so later
# cells re-read a fixed split by name instead of re-sampling.
train_snowdf.write.mode("overwrite").save_as_table("TRAIN_SPARK_MLLIB_DATASET")
test_snowdf.write.mode("overwrite").save_as_table("TEST_SPARK_MLLIB_DATASET")

In [11]:
# Re-read the persisted splits by table name. Once the tables exist, this
# cell works on its own without re-running the split cells above.
train_snowdf = session.table("TRAIN_SPARK_MLLIB_DATASET")
test_snowdf = session.table("TEST_SPARK_MLLIB_DATASET")

In [12]:
# Load the training table into a Ray Dataset; it is converted to Spark below.
# NOTE(review): only the TRAIN table is loaded. The models below evaluate on
# a 20% split carved out of it, and test_snowdf is not used afterwards.
train_rayds =  DataConnector.from_dataframe(train_snowdf).to_ray_dataset()

SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. 


Info - 2025-07-06 23:01:05.069953 - Loading data from Snowpark Dataframe from query id 01bd8445-0205-de3e-0000-50070c8982de
Info - 2025-07-06 23:01:05.592688 - Finished executing data load query.
Info - 2025-07-06 23:01:05.774074 - Loaded data into ray dataset.


### Compute the Spark configuration from cluster resources

In [13]:
def configure_logging():
    """Silence the Python-side py4j/pyspark loggers and Future/User warnings."""
    for logger_name in ("py4j", "pyspark"):
        logging.getLogger(logger_name).setLevel(logging.ERROR)
    for warning_category in (FutureWarning, UserWarning):
        warnings.filterwarnings("ignore", category=warning_category)
    print("Logging configured to suppress common Spark warnings")
def get_spark_configs_with_warning_suppression():
    """
    Return the static Spark configuration passed to ``raydp.init_spark``.

    Despite the name, none of these entries change log verbosity. They tune
    adaptive query execution, Kryo serialization, compression, Arrow
    transfer, broadcast/shuffle parallelism, task retries and network
    timeouts.

    Tree-training memory is an estimator parameter, not a Spark conf. Set it
    with ``RandomForestClassifier(maxMemoryInMB=...)`` if needed; a
    ``spark.ml.tree.maxMemoryInMB`` conf key would simply be ignored.

    Returns:
        dict[str, str]: Spark configuration key -> value.
    """
    return {
        # Adaptive query execution: coalesce small shuffle partitions toward ~256MB.
        "spark.sql.adaptive.enabled": "true",
        "spark.sql.adaptive.coalescePartitions.enabled": "true",
        "spark.sql.adaptive.advisoryPartitionSizeInBytes": "256MB",
        # Serialization / compression.
        "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
        "spark.rdd.compress": "true",
        "spark.io.compression.codec": "snappy",
        # Arrow-based transfer between the JVM and Python.
        "spark.sql.execution.arrow.pyspark.enabled": "true",
        "spark.sql.execution.arrow.maxRecordsPerBatch": "20000",
        # Joins / broadcasts.
        "spark.sql.autoBroadcastJoinThreshold": "50MB",
        "spark.broadcast.blockSize": "8m",
        # Parallelism.
        "spark.sql.shuffle.partitions": "200",
        "spark.default.parallelism": "200",
        # Fault tolerance: retries and generous timeouts for long tree-training stages.
        "spark.task.maxFailures": "3",
        "spark.network.timeout": "800s",
        "spark.executor.heartbeatInterval": "60s"
    }

In [14]:
def get_optimal_spark_config(cluster_resources, nodes=None):
    """
    Calculate a RayDP/Spark executor layout for the current Ray cluster.

    Heuristic:
    - Reserve 1 CPU and 4 GB per live node for Ray and the OS.
    - Run one executor per live node, capped at 4.
    - Split the remaining CPUs evenly across executors.
    - Split the remaining memory across the executors plus the driver.

    Args:
        cluster_resources: mapping like ``ray.cluster_resources()``. Only the
            'CPU' count and 'memory' (bytes) entries are read; the fallbacks
            are 24 CPUs and 72 GB when absent.
        nodes: optional list of node dicts like ``ray.nodes()``. When omitted,
            ``ray.nodes()`` is queried. Entries whose 'Alive' flag is false
            are ignored.

    Returns:
        dict: 'num_executors' and 'executor_cores' (ints), plus
        'executor_memory' and 'driver_memory' (Spark size strings such as "11g").
    """
    if nodes is None:
        nodes = ray.nodes()
    total_cpus = int(cluster_resources.get('CPU', 24))
    total_memory_gb = cluster_resources.get('memory', 72 * 1024**3) / (1024**3)
    # ray.nodes() can include nodes that have died, while the resource totals
    # describe live nodes; count only live nodes so both refer to the same set.
    num_nodes = max(1, sum(1 for node in nodes if node.get('Alive', True)))
    print(" Cluster Analysis:")
    print(f"   Nodes: {num_nodes}")
    print(f"   Total CPUs: {total_cpus}")
    print(f"   Total Memory: {total_memory_gb:.1f} GB")
    # Leave 1 CPU per node for Ray head/driver and OS
    available_cpus = max(1, total_cpus - num_nodes)
    # Target 1 executor per node for good distribution
    num_executors = min(num_nodes, 4)
    executor_cores = max(1, available_cpus // num_executors)
    # Conservative memory allocation (leave 4GB per node for OS/Ray)
    available_memory = max(8, total_memory_gb - (num_nodes * 4))
    executor_memory_gb = max(2, int(available_memory // (num_executors + 1)))  # +1 for driver
    driver_memory_gb = min(4, executor_memory_gb)
    config = {
        'num_executors': num_executors,
        'executor_cores': executor_cores,
        'executor_memory': f"{executor_memory_gb}g",
        'driver_memory': f"{driver_memory_gb}g"
    }
    print(" Optimal Spark Configuration:")
    print(f"   Executors: {config['num_executors']}")
    print(f"   Executor cores: {config['executor_cores']}")
    print(f"   Executor memory: {config['executor_memory']}")
    print(f"   Driver memory: {config['driver_memory']}")
    return config

In [15]:
# Quiet Python-side logging, derive the executor layout from the live cluster,
# and merge the driver memory into the static conf dict for raydp.init_spark.
# Note the similar names: spark_config is the executor layout, while
# spark_configs is the raw Spark conf dict.
configure_logging()
spark_config = get_optimal_spark_config(cluster_resources)
spark_configs = get_spark_configs_with_warning_suppression()
spark_configs["spark.driver.memory"] = spark_config['driver_memory']    

 Cluster Analysis:
   Nodes: 4
   Total CPUs: 24
   Total Memory: 74.7 GB
 Optimal Spark Configuration:
   Executors: 4
   Executor cores: 5
   Executor memory: 11g
   Driver memory: 4g


## Initialize RayDP with the computed Spark configuration

In [16]:
print("\n Initializing RayDP with optimal configuration...")
# Start a Spark session whose executors run on the Ray cluster. Executor count,
# cores and memory come from spark_config; all other settings come from spark_configs.
spark = raydp.init_spark(
    app_name="RayDP_MLLib_Distributed_Training",
    num_executors=spark_config['num_executors'],
    executor_cores=spark_config['executor_cores'],
    executor_memory=spark_config['executor_memory'],
    configs=spark_configs
)


 Initializing RayDP with optimal configuration...


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/07/06 23:01:23 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [17]:
def set_spark_log_level(spark_session):
    """Lower the JVM-side Spark log level to ERROR; report (rather than raise) any failure."""
    try:
        spark_session.sparkContext.setLogLevel("ERROR")
    except Exception as exc:
        print(f"Could not set Spark log level: {exc}")
    else:
        print("Spark log level set to ERROR (warnings suppressed)")
# Hide Spark WARN-level chatter in subsequent cell outputs.
set_spark_log_level(spark)



In [18]:
# Confirm the session is up and record its identity for debugging.
spark_context = spark.sparkContext
print("Spark session initialized")
print(f"Application ID: {spark_context.applicationId}")
print(f"Spark Version: {spark.version}")

Spark session initialized
Application ID: spark-application-1751842884006
Spark Version: 3.5.4


In [19]:
# Display the SparkSession object via its notebook repr.
spark

In [22]:
# Two Spark partitions per CPU in the Ray cluster (24 CPUs assumed if unknown).
# Ray reports CPU counts as floats, so convert to an integer partition count.
PARTITIONS_PER_CPU = 2
optimal_partitions = int(ray.cluster_resources().get('CPU', 24)) * PARTITIONS_PER_CPU
print(f"Optimal partitions: {optimal_partitions}")

Optimal partitions: 48.0


In [23]:
# Convert the Ray Dataset into a Spark DataFrame and repartition it to
# optimal_partitions partitions. int() accepts either a float or an int count.
df = train_rayds.to_spark(spark).repartition(int(optimal_partitions))
print(f"Created Spark DataFrame with {df.rdd.getNumPartitions()} partitions")

2025-07-06 23:02:56,816	INFO logging.py:290 -- Registered dataset logger for dataset dataset_12_1
2025-07-06 23:02:56,821	INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_12_1. Full logs are in /raylogs/ray/session_2025-07-06_20-03-57_750929_35/logs/ray-data
2025-07-06 23:02:56,821	INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_12_1: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadResultSetDataSource]
2025-07-06 23:03:00,837	INFO streaming_executor.py:220 -- ✔️  Dataset dataset_12_1 execution finished in 4.02 seconds

Created Spark DataFrame with 48 partitions


## Train a classification model

In [26]:
def train_classification_model(df, spark_config):
    """
    Fit a scaled RandomForest classification pipeline and report held-out metrics.

    ``df`` is split 80/20 (seed 42). The pipeline (VectorAssembler ->
    StandardScaler -> RandomForestClassifier) is fit on the 80% portion and
    evaluated on the remaining 20% with AUC and accuracy.

    Args:
        df: Spark DataFrame with numeric columns FEATURE_0 .. FEATURE_19 and a
            TARGET label column. TARGET is presumably binary, since AUC comes
            from BinaryClassificationEvaluator; confirm this for other datasets.
        spark_config: executor layout dict from get_optimal_spark_config.
            Not used by this function; kept so existing calls keep working.

    Returns:
        tuple: (fitted PipelineModel, dict with 'auc', 'accuracy' and
        'training_time' in seconds).
    """
    feature_cols = [f"FEATURE_{i}" for i in range(20)]
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="RAW_FEATURES")
    scaler = StandardScaler(inputCol="RAW_FEATURES", outputCol="FEATURES", withStd=True, withMean=True)
    rf = RandomForestClassifier(
        labelCol="TARGET",
        featuresCol="FEATURES",
        numTrees=100,
        maxDepth=10,
        maxBins=32,
        minInstancesPerNode=10,
        seed=42
    )
    pipeline = Pipeline(stages=[assembler, scaler, rf])
    train_data, test_data = df.randomSplit([0.8, 0.2], seed=42)
    # Cache both splits. Otherwise every action below (two counts, the scaler
    # and forest fits, and evaluation) recomputes the full lineage of ``df``.
    train_data = train_data.cache()
    test_data = test_data.cache()
    predictions = None
    try:
        print(f"Training set: {train_data.count():,} rows")
        print(f"Evaluation set: {test_data.count():,} rows")
        print("Starting distributed training...")
        start_time = time.perf_counter()
        model = pipeline.fit(train_data)
        training_time = time.perf_counter() - start_time
        print(f"Training completed in {training_time:.2f} seconds")
        print("Evaluating model...")
        # Cached because both evaluators scan the predictions; without it the
        # 100-tree forest would be applied to the test split twice.
        predictions = model.transform(test_data).cache()
        evaluator_auc = BinaryClassificationEvaluator(labelCol="TARGET", metricName="areaUnderROC")
        evaluator_acc = MulticlassClassificationEvaluator(labelCol="TARGET", predictionCol="prediction", metricName="accuracy")
        auc = evaluator_auc.evaluate(predictions)
        accuracy = evaluator_acc.evaluate(predictions)
    finally:
        # Release cached blocks so repeated runs don't accumulate executor memory.
        for cached_frame in (predictions, test_data, train_data):
            if cached_frame is not None:
                cached_frame.unpersist()
    print("Model Performance:")
    print(f"AUC: {auc:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Training time: {training_time:.2f}s")
    return model, {'auc': auc, 'accuracy': accuracy, 'training_time': training_time}

In [27]:
# Train and evaluate the classifier; returns the fitted pipeline and a metrics dict.
model_class, metrics_class = train_classification_model(df, spark_config)

                                                                                

Training set: 560,180 rows


                                                                                

Evaluation set: 140,177 rows
Starting distributed training...


25/07/06 23:13:11 ERROR TaskSchedulerImpl: Lost executor 4 on 10.244.10.75: Remote RPC client disassociated. Likely due to containers exceeding thresholds, or network issues. Check driver logs for WARN messages.
25/07/06 23:14:34 ERROR TaskSchedulerImpl: Lost executor 5 on 10.244.10.75: Remote RPC client disassociated. Likely due to containers exceeding thresholds, or network issues. Check driver logs for WARN messages.

Training completed in 265.54 seconds
Evaluating model...


25/07/06 23:17:11 ERROR TaskSchedulerImpl: Lost executor 6 on 10.244.10.75: Remote RPC client disassociated. Likely due to containers exceeding thresholds, or network issues. Check driver logs for WARN messages.

Model Performance:
AUC: 0.9002
Accuracy: 0.8136
Training time: 265.54s


                                                                                

### Hyperparameter tuning

In [None]:
def hyperparameter_tuning(df):
    """
    Grid-search RandomForest hyperparameters with 3-fold cross-validation.

    Searches numTrees in {50, 100}, maxDepth in {5, 10} and
    minInstancesPerNode in {5, 10} (8 combinations), scored by area under
    ROC. The search runs on the 80% portion of
    ``df.randomSplit([0.8, 0.2], seed=42)``, the same weights and seed that
    train_classification_model uses.

    Args:
        df: Spark DataFrame with FEATURE_0 .. FEATURE_19 and TARGET columns
            (the upper-case schema also used by train_classification_model).

    Returns:
        tuple: (CrossValidatorModel, tuning wall-clock time in seconds).
    """
    # Column names must match the DataFrame schema exactly: Spark ML stages
    # resolve input columns against the schema case-sensitively.
    feature_cols = [f"FEATURE_{i}" for i in range(20)]
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="FEATURES")
    estimator = RandomForestClassifier(labelCol="TARGET", featuresCol="FEATURES", seed=42)
    evaluator = BinaryClassificationEvaluator(labelCol="TARGET", metricName="areaUnderROC")
    paramGrid = ParamGridBuilder() \
        .addGrid(estimator.numTrees, [50, 100]) \
        .addGrid(estimator.maxDepth, [5, 10]) \
        .addGrid(estimator.minInstancesPerNode, [5, 10]) \
        .build()
    pipeline = Pipeline(stages=[assembler, estimator])
    crossval = CrossValidator(
        estimator=pipeline,
        estimatorParamMaps=paramGrid,
        evaluator=evaluator,
        numFolds=3,
        seed=42
    )
    train_data, _ = df.randomSplit([0.8, 0.2], seed=42)
    # Cache the training split: cross-validation derives every fold from it.
    train_data = train_data.cache()
    try:
        print(f"Starting hyperparameter tuning with {len(paramGrid)} parameter combinations...")
        start_time = time.perf_counter()
        cv_model = crossval.fit(train_data)
        tuning_time = time.perf_counter() - start_time
    finally:
        train_data.unpersist()
    print(f"Hyperparameter tuning completed in {tuning_time:.2f} seconds")
    # Get best model
    best_model = cv_model.bestModel
    print("Best hyperparameters found:")
    rf_stage = best_model.stages[-1]
    # On fitted forest models PySpark exposes getNumTrees as a property, not a
    # method, so it must not be called.
    print(f"numTrees: {rf_stage.getNumTrees}")
    print(f"maxDepth: {rf_stage.getMaxDepth()}")
    print(f"minInstancesPerNode: {rf_stage.getMinInstancesPerNode()}")
    return cv_model, tuning_time

In [None]:
# Cross-validated grid search: 8 parameter combinations x 3 folds, plus a final
# refit of the best combination. This is the most expensive cell in the notebook.
cv_model_class, tune_time_class = hyperparameter_tuning(df)