# RayDP - Distributed Spark MLlib-Based Model Training on Snowpark Container Services

This notebook demonstrates how to use RayDP to perform distributed Spark MLlib-based model training on a Ray cluster in Snowpark Container Services.

## Setup and Imports

In [1]:
import ray
import raydp
import pprint
import warnings
import logging    
import time
import numpy as np
import socket
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, rand, when, round as spark_round
from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression
from pyspark.ml.regression import RandomForestRegressor, LinearRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator, RegressionEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml import Pipeline

# Print the Ray and RayDP versions that were imported above.
print(f"Ray version: {ray.__version__}")
print(f"RayDP version: {raydp.__version__}")

  from .autonotebook import tqdm as notebook_tqdm
2025-06-28 01:49:17,241	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2025-06-28 01:49:17,802	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Ray version: 2.42.0
RayDP version: 1.6.2


In [2]:
# Attach to the already-running Ray cluster (address="auto") instead of starting
# a local one; ignore_reinit_error=True lets this cell be re-run in the same kernel.
ray.init(address="auto", ignore_reinit_error=True)

2025-06-28 01:49:18,000	INFO worker.py:1654 -- Connecting to existing Ray cluster at address: 10.244.186.11:6379...
2025-06-28 01:49:18,011	INFO worker.py:1832 -- Connected to Ray cluster. View the dashboard at [1m[32m10.244.186.11:8265 [39m[22m
[2025-06-28 01:49:18,014 I 11924 11924] logging.cc:293: Set ray log level from environment variable RAY_BACKEND_LOG_LEVEL to -1


0,1
Python version:,3.10.12
Ray version:,2.42.0
Dashboard:,http://10.244.186.11:8265


[33m(raylet, ip=10.244.187.11)[0m [2025-06-28 01:50:20,575 I 6048 6048] logging.cc:293: Set ray log level from environment variable RAY_BACKEND_LOG_LEVEL to -1
[36m(RayDPSparkMaster pid=6048, ip=10.244.187.11)[0m [2025-06-28 01:50:23,047 I 6086 6113] gcs_client.cc:98: GcsClient has no Cluster ID set, and won't fetch from GCS.
[36m(RayDPSparkMaster pid=6048, ip=10.244.187.11)[0m [2025-06-28 01:50:23,230 I 6086 6113] gcs_client.cc:98: GcsClient has no Cluster ID set, and won't fetch from GCS.
[33m(raylet, ip=10.244.187.11)[0m [2025-06-28 01:50:29,926 I 6155 6159] gcs_client.cc:98: GcsClient has no Cluster ID set, and won't fetch from GCS.
[33m(raylet, ip=10.244.187.11)[0m [2025-06-28 01:50:29,926 I 6155 6159] logging.cc:293: Set ray log level from environment variable RAY_BACKEND_LOG_LEVEL to -1
[33m(raylet)[0m [2025-06-28 01:50:49,187 I 12336 12336] logging.cc:293: Set ray log level from environment variable RAY_BACKEND_LOG_LEVEL to -1


In [3]:
# Snapshot the cluster-wide resource totals and the per-node records;
# both are reused by later cells.
cluster_resources = ray.cluster_resources()
nodes = ray.nodes()

In [4]:
# Display the raw node records returned by ray.nodes().
nodes

[{'NodeID': 'd76155411c1a927b83d1faf9b743fe15960f4be2f97024d1e2856765',
  'Alive': True,
  'NodeManagerAddress': '10.244.188.11',
  'NodeManagerHostname': 'statefulset-1',
  'NodeManagerPort': 8077,
  'ObjectManagerPort': 8076,
  'ObjectStoreSocketName': '/raylogs/ray/session_2025-06-28_00-08-44_284474_26/sockets/plasma_store',
  'RayletSocketName': '/raylogs/ray/session_2025-06-28_00-08-44_284474_26/sockets/raylet',
  'MetricsExportPort': 8082,
  'NodeName': '10.244.188.11',
  'RuntimeEnvAgentPort': 60018,
  'DeathReason': 0,
  'DeathReasonMessage': '',
  'alive': True,
  'Resources': {'CPU': 6.0,
   'GPU': 1.0,
   'node:10.244.188.11': 1.0,
   'accelerator_type:A10G': 1.0,
   'object_store_memory': 8658839961.0,
   'memory': 20203959911.0},
  'Labels': {'ray.io/node_id': 'd76155411c1a927b83d1faf9b743fe15960f4be2f97024d1e2856765'}},
 {'NodeID': 'd527cbecd7d4e1789f6551a13cae166c27fb8084f943a827664b7c8b',
  'Alive': True,
  'NodeManagerAddress': '10.244.186.11',
  'NodeManagerHostname':

In [5]:
# Print each node's resource dictionary, numbering the nodes from 1.
print("  Cluster Information:")
for i, node in enumerate(nodes):
    # A node record without a 'Resources' entry is shown as an empty dict.
    node_resources = node.get('Resources', {})
    print(f"   Node {i+1}: {node_resources}")
print()

  Cluster Information:
   Node 1: {'CPU': 6.0, 'GPU': 1.0, 'node:10.244.188.11': 1.0, 'accelerator_type:A10G': 1.0, 'object_store_memory': 8658839961.0, 'memory': 20203959911.0}
   Node 2: {'node:__internal_head__': 1.0, 'CPU': 6.0, 'GPU': 1.0, 'accelerator_type:A10G': 1.0, 'object_store_memory': 8572531507.0, 'node:10.244.186.11': 1.0, 'memory': 17145063015.0}
   Node 3: {'CPU': 6.0, 'GPU': 1.0, 'accelerator_type:A10G': 1.0, 'object_store_memory': 8658742886.0, 'node:10.244.187.11': 1.0, 'memory': 20203733402.0}
   Node 4: {'CPU': 6.0, 'node:10.244.189.11': 1.0, 'GPU': 1.0, 'accelerator_type:A10G': 1.0, 'object_store_memory': 8658803097.0, 'memory': 20203873895.0}



## Get optimal Spark config

In [6]:
def configure_logging():
    """
    Quiet the Python-side Spark loggers and common warning categories.

    Raises the "py4j" and "pyspark" loggers to ERROR and installs "ignore"
    filters for FutureWarning and UserWarning, then prints a confirmation.
    """
    for noisy_logger in ("py4j", "pyspark"):
        logging.getLogger(noisy_logger).setLevel(logging.ERROR)
    for ignored_category in (FutureWarning, UserWarning):
        warnings.filterwarnings("ignore", category=ignored_category)
    print("Logging configured to suppress common Spark warnings")
def get_spark_configs_with_warning_suppression():
    """
    Build the dictionary of Spark configuration overrides used by this notebook.

    Despite the function name, the entries are Spark tuning settings
    (adaptive SQL execution, Kryo serialization, Arrow, compression,
    partition counts, network timeouts). A new dict is returned on every call,
    so callers may add keys to it freely.
    """
    settings = (
        ("spark.sql.adaptive.enabled", "true"),
        ("spark.sql.adaptive.coalescePartitions.enabled", "true"),
        ("spark.serializer", "org.apache.spark.serializer.KryoSerializer"),
        ("spark.sql.adaptive.advisoryPartitionSizeInBytes", "256MB"),
        ("spark.ml.tree.maxMemoryInMB", "2048"),
        ("spark.task.maxFailures", "3"),
        ("spark.sql.execution.arrow.pyspark.enabled", "true"),
        ("spark.sql.adaptive.maxRecordsPerPartition", "50000"),
        ("spark.rdd.compress", "true"),
        ("spark.io.compression.codec", "snappy"),
        ("spark.sql.autoBroadcastJoinThreshold", "50MB"),
        ("spark.broadcast.blockSize", "8m"),
        ("spark.sql.execution.arrow.maxRecordsPerBatch", "20000"),
        ("spark.sql.shuffle.partitions", "200"),
        ("spark.default.parallelism", "200"),
        ("spark.network.timeout", "800s"),
        ("spark.executor.heartbeatInterval", "60s"),
    )
    # dict() keeps the pair order, so iteration order matches the listing above.
    return dict(settings)

In [7]:
def get_optimal_spark_config(cluster_resources):
    """
    Derive Spark executor and driver sizing from the Ray cluster's resources.

    Args:
        cluster_resources: Mapping of cluster-wide resource totals, as returned
            by ray.cluster_resources(). 'CPU' defaults to 24 and 'memory'
            (in bytes) defaults to 72 GiB when the key is absent.

    Returns:
        dict with keys 'num_executors' (int), 'executor_cores' (int),
        'executor_memory' and 'driver_memory' (strings such as "11g").
    """
    total_cpus = int(cluster_resources.get('CPU', 24))
    total_memory_gb = cluster_resources.get('memory', 72 * 1024**3) / (1024**3)
    # ray.nodes() also lists nodes that are no longer alive. Counting those
    # would over-reserve CPU/memory and request executors for nodes that do not
    # exist, so only live nodes are counted. The floor of 1 keeps the divisions
    # below well-defined if no node record is reported.
    alive_nodes = [
        node for node in ray.nodes()
        if node.get('Alive', node.get('alive', True))
    ]
    num_nodes = max(1, len(alive_nodes))
    print(" Cluster Analysis:")
    print(f"   Nodes: {num_nodes}")
    print(f"   Total CPUs: {total_cpus}")
    print(f"   Total Memory: {total_memory_gb:.1f} GB")
    # Leave 1 CPU per node for Ray head/driver and OS
    available_cpus = max(1, total_cpus - num_nodes)
    # Target 1 executor per node for good distribution (capped at 4 executors)
    num_executors = min(num_nodes, 4)
    executor_cores = max(1, available_cpus // num_executors)
    # Conservative memory allocation (leave 4GB per node for OS/Ray)
    available_memory = max(8, total_memory_gb - (num_nodes * 4))
    executor_memory_gb = max(2, int(available_memory // (num_executors + 1)))  # +1 for driver
    driver_memory_gb = min(4, executor_memory_gb)
    config = {
        'num_executors': num_executors,
        'executor_cores': executor_cores,
        'executor_memory': f"{executor_memory_gb}g",
        'driver_memory': f"{driver_memory_gb}g"
    }
    print(" Optimal Spark Configuration:")
    print(f"   Executors: {config['num_executors']}")
    print(f"   Executor cores: {config['executor_cores']}")
    print(f"   Executor memory: {config['executor_memory']}")
    print(f"   Driver memory: {config['driver_memory']}")
    return config

In [8]:
# Quiet Python-side logging, size the executors/driver from the cluster
# resources captured earlier, and build the Spark config overrides.
configure_logging()
spark_config = get_optimal_spark_config(cluster_resources)
spark_configs = get_spark_configs_with_warning_suppression()
# Driver memory is supplied as a Spark config entry alongside the other overrides.
spark_configs["spark.driver.memory"] = spark_config['driver_memory']    

 Cluster Analysis:
   Nodes: 4
   Total CPUs: 24
   Total Memory: 72.4 GB
 Optimal Spark Configuration:
   Executors: 4
   Executor cores: 5
   Executor memory: 11g
   Driver memory: 4g


## Initialize RayDP with optimal Spark config

In [9]:
print("\n Initializing RayDP with optimal configuration...")
# Start a Spark session on the Ray cluster via RayDP, using the executor
# sizing computed above and the extra Spark config overrides.
spark = raydp.init_spark(
    app_name="RayDP_MLLib_Distributed_Training",
    num_executors=spark_config['num_executors'],
    executor_cores=spark_config['executor_cores'],
    executor_memory=spark_config['executor_memory'],
    configs=spark_configs
)


 Initializing RayDP with optimal configuration...


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/06/28 01:50:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [10]:
def set_spark_log_level(spark_session):
    """
    Set the Spark context's log level to ERROR on a best-effort basis.

    Args:
        spark_session: Object exposing a `sparkContext` with `setLogLevel`.

    Any exception is reported with a printed message instead of being raised.
    """
    try:
        spark_session.sparkContext.setLogLevel("ERROR")
        print("Spark log level set to ERROR (warnings suppressed)")
    except Exception as err:
        print(f"Could not set Spark log level: {err}")
# Apply the ERROR log level to the session created above.
set_spark_log_level(spark)



In [11]:
# Confirm the session is up and report its application ID and Spark version.
print("Spark session initialized")
print("Application ID: {}".format(spark.sparkContext.applicationId))
print("Spark Version: {}".format(spark.version))

Spark session initialized
Application ID: spark-application-1751075428163
Spark Version: 3.5.4


## Perform classification

### Create dataset

In [12]:
@ray.remote
def generate_data_chunk(chunk_id, chunk_size, num_features):
    """
    Generate one chunk of synthetic binary-classification rows as a Ray task.

    The NumPy global RNG is seeded with `chunk_id`. Features 1 and 2 are
    overwritten with noisy linear combinations of earlier features, and the
    binary target is derived from a noisy weighted sum of features 0-5.

    Args:
        chunk_id: Index of this chunk; used as the RNG seed and the row-id offset.
        chunk_size: Number of rows to generate.
        num_features: Number of feature values per row (indices 1 and 2 are
            written, so at least 3 are required).

    Returns:
        List of rows, each `[row_id, feature_0, ..., feature_{n-1}, target]`,
        where `row_id = chunk_id * chunk_size + position`.
    """
    np.random.seed(chunk_id)
    rows = []
    for position in range(chunk_size):
        # The order of the random draws below determines the generated values.
        feats = np.random.normal(0, 1, num_features)
        feats[1] = feats[0] * 0.5 + np.random.normal(0, 0.5)
        feats[2] = feats[0] * 0.3 + feats[1] * 0.2 + np.random.normal(0, 0.8)
        signal = (feats[0] * 0.5 +
                  feats[1] * 0.3 +
                  feats[2] * 0.2 +
                  np.sum(feats[3:6]) * 0.1)
        if signal + np.random.normal(0, 0.5) > 0:
            label = 1
        else:
            label = 0
        rows.append([chunk_id * chunk_size + position] + feats.tolist() + [label])
    return rows

In [13]:
def create_large_dataset(spark, num_rows=1_000_000, num_features=20):
    """
    Create a large synthetic dataset using Ray for parallel data generation.

    Rows are produced by 50 `generate_data_chunk` Ray tasks, collected on the
    driver, and loaded into a Spark DataFrame that is repartitioned to twice
    the cluster's CPU count.

    Args:
        spark: Active SparkSession.
        num_rows: Total number of rows to generate (must be >= 0).
        num_features: Number of feature columns per row.

    Returns:
        Spark DataFrame with columns `id`, `feature_0..feature_{n-1}`, `target`
        containing exactly `num_rows` rows.

    Raises:
        ValueError: If `num_rows` is negative.
    """
    if num_rows < 0:
        raise ValueError(f"num_rows must be non-negative, got {num_rows}")
    print(f"Generating {num_rows:,} rows with {num_features} features for classification...")
    # Generate data in parallel using Ray
    num_chunks = 50  # Distribute data generation
    # Split the rows evenly and let the last task absorb the remainder, so no
    # rows are dropped when num_rows is not a multiple of num_chunks. Row ids
    # are computed by each task as chunk_id * chunk_size + position, so they stay
    # unique; with a non-zero remainder the last chunk's ids are not contiguous
    # with the previous ones.
    base_size, remainder = divmod(num_rows, num_chunks)
    chunk_sizes = [base_size] * (num_chunks - 1) + [base_size + remainder]
    print(f"   Using {num_chunks} parallel data generation tasks...")
    # Generate data chunks in parallel
    chunk_futures = [
        generate_data_chunk.remote(chunk_id, size, num_features)
        for chunk_id, size in enumerate(chunk_sizes)
    ]
    chunks = ray.get(chunk_futures)
    # Flatten data (all rows are materialized on the driver here)
    all_data = [row for chunk in chunks for row in chunk]
    print(f"Generated {len(all_data):,} data points")
    # Create schema
    feature_fields = [StructField(f"feature_{i}", DoubleType(), True) for i in range(num_features)]
    target_field = StructField("target", IntegerType(), True)
    schema = StructType([StructField("id", IntegerType(), True)] + feature_fields + [target_field])
    # Create Spark DataFrame with high partition count for distribution
    optimal_partitions = ray.cluster_resources().get('CPU', 24) * 2
    df = spark.createDataFrame(all_data, schema).repartition(int(optimal_partitions))
    print(f"Created DataFrame with {df.rdd.getNumPartitions()} partitions")
    return df

In [14]:
# Display the SparkSession object.
spark

In [15]:
# Dataset size for the classification example: number of rows and feature columns.
rows = 1_000_000
features = 20
df_class = create_large_dataset(spark, rows, features)

Generating 1,000,000 rows with 20 features for classification...
   Using 50 parallel data generation tasks...
Generated 1,000,000 data points




Created DataFrame with 48 partitions


In [16]:
# Show a single row to check the generated columns.
df_class.show(1)

                                                                                

+----+-------------------+-------------------+-------------------+--------------------+------------------+----------------+-----------------+-------------------+------------------+-------------------+-------------------+-------------------+-------------------+------------------+-------------------+-----------------+------------------+-------------------+-------------------+------------------+------+
|  id|          feature_0|          feature_1|          feature_2|           feature_3|         feature_4|       feature_5|        feature_6|          feature_7|         feature_8|          feature_9|         feature_10|         feature_11|         feature_12|        feature_13|         feature_14|       feature_15|        feature_16|         feature_17|         feature_18|        feature_19|target|
+----+-------------------+-------------------+-------------------+--------------------+------------------+----------------+-----------------+-------------------+------------------+--------------

### Train classification model

In [17]:
def train_classification_model(df, spark_config):
    """
    Train and evaluate a distributed random-forest classification pipeline.

    The pipeline assembles the feature columns into a vector, standardizes
    them, and fits a RandomForestClassifier on an 80/20 train/test split
    (seed 42). AUC and accuracy are computed on the held-out split.

    Args:
        df: Spark DataFrame with a "target" label column and one or more
            feature columns whose names start with "feature_".
        spark_config: Kept for interface compatibility; not used by this function.

    Returns:
        Tuple `(model, metrics)`: the fitted pipeline model and a dict with
        keys 'auc', 'accuracy' and 'training_time' (seconds).

    Raises:
        ValueError: If `df` has no column whose name starts with "feature_".
    """
    # Derive the feature columns from the DataFrame rather than assuming 20,
    # so datasets created with a different num_features work as well.
    feature_cols = [name for name in df.columns if name.startswith("feature_")]
    if not feature_cols:
        raise ValueError("DataFrame has no columns starting with 'feature_'")
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="raw_features")
    scaler = StandardScaler(inputCol="raw_features", outputCol="features", withStd=True, withMean=True)
    rf = RandomForestClassifier(
        labelCol="target",
        featuresCol="features",
        numTrees=100,
        maxDepth=10,
        maxBins=32,
        minInstancesPerNode=10,
        seed=42
    )
    pipeline = Pipeline(stages=[assembler, scaler, rf])
    train_data, test_data = df.randomSplit([0.8, 0.2], seed=42)
    # Each split is scanned by several actions (count, fit, two evaluations);
    # caching avoids recomputing the upstream plan every time.
    train_data = train_data.cache()
    test_data = test_data.cache()
    try:
        print(f"Training set: {train_data.count():,} rows")
        print(f"Test set: {test_data.count():,} rows")
        print("Starting distributed training...")
        start_time = time.time()
        model = pipeline.fit(train_data)
        training_time = time.time() - start_time
        print(f"Training completed in {training_time:.2f} seconds")
        print("Evaluating model...")
        predictions = model.transform(test_data)
        evaluator_auc = BinaryClassificationEvaluator(labelCol="target", metricName="areaUnderROC")
        evaluator_acc = MulticlassClassificationEvaluator(labelCol="target", predictionCol="prediction", metricName="accuracy")
        auc = evaluator_auc.evaluate(predictions)
        accuracy = evaluator_acc.evaluate(predictions)
    finally:
        # Release the cached splits whether or not training succeeded.
        train_data.unpersist()
        test_data.unpersist()
    print("Model Performance:")
    print(f"AUC: {auc:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Training time: {training_time:.2f}s")
    return model, {'auc': auc, 'accuracy': accuracy, 'training_time': training_time}

In [18]:
# Train and evaluate the classifier on the generated dataset.
model_class, metrics_class = train_classification_model(df_class, spark_config)

                                                                                

Training set: 799,674 rows


                                                                                

Test set: 200,326 rows
Starting distributed training...




Training completed in 87.44 seconds
Evaluating model...




Model Performance:
AUC: 0.9018
Accuracy: 0.8157
Training time: 87.44s


                                                                                

### Hyperparameter tuning

In [19]:
def hyperparameter_tuning(df):
    """
    Tune a RandomForestClassifier with 3-fold cross-validation.

    Searches numTrees in [50, 100], maxDepth in [5, 10] and
    minInstancesPerNode in [5, 10] (8 combinations), scored by area under ROC,
    on the 80% training split of `df` (seed 42).

    Args:
        df: Spark DataFrame with a "target" label column and one or more
            feature columns whose names start with "feature_".

    Returns:
        Tuple `(cv_model, tuning_time)`: the fitted cross-validator model and
        the elapsed tuning time in seconds.

    Raises:
        ValueError: If `df` has no column whose name starts with "feature_".
    """
    # Derive the feature columns from the DataFrame rather than assuming 20.
    feature_cols = [name for name in df.columns if name.startswith("feature_")]
    if not feature_cols:
        raise ValueError("DataFrame has no columns starting with 'feature_'")
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
    estimator = RandomForestClassifier(labelCol="target", featuresCol="features", seed=42)
    evaluator = BinaryClassificationEvaluator(labelCol="target", metricName="areaUnderROC")
    paramGrid = ParamGridBuilder() \
        .addGrid(estimator.numTrees, [50, 100]) \
        .addGrid(estimator.maxDepth, [5, 10]) \
        .addGrid(estimator.minInstancesPerNode, [5, 10]) \
        .build()
    pipeline = Pipeline(stages=[assembler, estimator])
    crossval = CrossValidator(
        estimator=pipeline,
        estimatorParamMaps=paramGrid,
        evaluator=evaluator,
        numFolds=3,
        seed=42
    )
    train_data, _ = df.randomSplit([0.8, 0.2], seed=42)
    # Cross-validation fits many models on the same split; cache it so the
    # upstream plan is not recomputed for every fold and parameter combination.
    train_data = train_data.cache()
    print(f"Starting hyperparameter tuning with {len(paramGrid)} parameter combinations...")
    start_time = time.time()
    try:
        cv_model = crossval.fit(train_data)
    finally:
        train_data.unpersist()
    tuning_time = time.time() - start_time
    print(f"Hyperparameter tuning completed in {tuning_time:.2f} seconds")
    # Get best model
    best_model = cv_model.bestModel
    print("Best hyperparameters found:")
    rf_stage = best_model.stages[-1]
    # On fitted tree-ensemble models PySpark exposes getNumTrees as a property
    # (an int), so calling it raises TypeError; accept both forms.
    num_trees = rf_stage.getNumTrees
    if callable(num_trees):
        num_trees = num_trees()
    print(f"numTrees: {num_trees}")
    print(f"maxDepth: {rf_stage.getMaxDepth()}")
    print(f"minInstancesPerNode: {rf_stage.getMinInstancesPerNode()}")
    return cv_model, tuning_time

In [None]:
# Run cross-validated tuning on the classification dataset; this fits many models and can take a while.
cv_model_class, tune_time_class = hyperparameter_tuning(df_class)

Starting hyperparameter tuning with 8 parameter combinations...


