# Data Generation for Spark MLlib-based Model Training on Snowpark Container Services

## Setup and Imports

In [1]:
import ray
import pprint
import warnings
import logging    
import time
import sys
import os
import numpy as np
import socket
import snowflake.connector
from snowflake.snowpark import Session
print(f"Ray version: {ray.__version__}")

  from .autonotebook import tqdm as notebook_tqdm
2025-07-08 02:47:38,686	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Ray version: 2.42.0


## Initialize Snowpark Session

In [2]:
def connection() -> snowflake.connector.SnowflakeConnection:
    """Open a Snowflake connection with whichever credentials are available.

    If the token file ``/snowflake/session/token`` exists, the connection is
    made with OAuth using that token plus ``SNOWFLAKE_HOST`` and
    ``SNOWFLAKE_PORT``. Otherwise it falls back to user/password
    authentication taken from ``SNOWFLAKE_USER`` and ``SNOWFLAKE_PASSWORD``.

    Account, database and schema always come from the ``SNOWFLAKE_ACCOUNT``,
    ``SNOWFLAKE_DATABASE`` and ``SNOWFLAKE_SCHEMA`` environment variables.
    The warehouse comes from ``SNOWFLAKE_WAREHOUSE``; the OAuth branch keeps
    ``LARGE_WH`` as its default when that variable is unset.

    Returns:
        The connection object returned by ``snowflake.connector.connect``.
    """
    token_path = "/snowflake/session/token"

    # Settings common to both authentication modes.
    creds = {
        'account': os.getenv('SNOWFLAKE_ACCOUNT'),
        'database': os.getenv('SNOWFLAKE_DATABASE'),
        'schema': os.getenv('SNOWFLAKE_SCHEMA'),
        'client_session_keep_alive': True,
        'check_arrow_conversion_error_on_every_column': False,
    }

    if os.path.isfile(token_path):
        # Read the token through a context manager so the file handle is closed.
        with open(token_path, 'r') as token_file:
            token = token_file.read()
        creds.update({
            'host': os.getenv('SNOWFLAKE_HOST'),
            'port': os.getenv('SNOWFLAKE_PORT'),
            'protocol': "https",
            'authenticator': "oauth",
            'token': token,
            'warehouse': os.getenv('SNOWFLAKE_WAREHOUSE', "LARGE_WH"),
        })
    else:
        creds.update({
            'user': os.getenv('SNOWFLAKE_USER'),
            'password': os.getenv('SNOWFLAKE_PASSWORD'),
            # Previously this read an undefined global (`snowflake_warehouse`),
            # which raised NameError on a fresh kernel.
            'warehouse': os.getenv('SNOWFLAKE_WAREHOUSE'),
        })

    return snowflake.connector.connect(**creds)

def get_session() -> Session:
    """Build a Snowpark ``Session`` on top of a fresh ``connection()``."""
    session_config = {"connection": connection()}
    builder = Session.builder.configs(session_config)
    return builder.create()

In [3]:
# Open the Snowpark session that the later driver-side cells reuse.
session = get_session()

In [4]:
# Show which database the session is bound to (the value comes back as a quoted identifier).
session.get_current_database()

'"RAYDP_SIS_DB"'

In [5]:
# Attach to the already-running Ray cluster (address="auto") instead of starting a new one.
# ignore_reinit_error=True keeps a re-run of this cell from raising; log_to_driver=False
# keeps worker logs out of the notebook output.
cli = ray.init(address="auto", ignore_reinit_error=True, log_to_driver=False)

2025-07-08 02:47:39,607	INFO worker.py:1654 -- Connecting to existing Ray cluster at address: 10.244.24.9:6379...
2025-07-08 02:47:39,616	INFO worker.py:1832 -- Connected to Ray cluster. View the dashboard at [1m[32m10.244.24.9:8265 [39m[22m
[2025-07-08 02:47:39,620 I 25445 25445] logging.cc:293: Set ray log level from environment variable RAY_BACKEND_LOG_LEVEL to 1


In [6]:
# Snapshot the cluster: aggregate resource totals and the per-node descriptions.
cluster_resources = ray.cluster_resources()
nodes = ray.nodes()

In [7]:
# List the resources advertised by every node, numbering the nodes from 1.
print("  Cluster Information:")
for node_number, node_info in enumerate(nodes, start=1):
    resources = node_info.get('Resources', {})
    print(f"   Node {node_number}: {resources}")
print()

  Cluster Information:
   Node 1: {'node:10.244.24.137': 1.0, 'CPU': 28.0, 'object_store_memory': 11220602060.0, 'memory': 246362374963.0}
   Node 2: {'node:__internal_head__': 1.0, 'CPU': 28.0, 'object_store_memory': 11220602060.0, 'memory': 235982446387.0, 'node:10.244.24.9': 1.0}
   Node 3: {'node:10.244.24.73': 1.0, 'CPU': 28.0, 'object_store_memory': 11220602060.0, 'memory': 246362051379.0}
   Node 4: {'object_store_memory': 11220602060.0, 'node:10.244.24.201': 1.0, 'memory': 246362272563.0, 'CPU': 28.0}



### Create sample dataset

In [8]:
@ray.remote
def generate_data_chunk(chunk_id, chunk_size, num_features):
    """Generate one chunk of synthetic binary-classification rows.

    The NumPy global RNG is seeded with ``chunk_id``, so a given
    ``(chunk_id, chunk_size, num_features)`` always yields the same rows.
    Features 1 and 2 are overwritten with noisy linear combinations of the
    earlier features, which means ``num_features`` must be at least 3.

    Args:
        chunk_id: Index of the chunk; used as the RNG seed and to offset IDs.
        chunk_size: Number of rows to generate.
        num_features: Number of feature columns per row.

    Returns:
        A list of dicts with keys ``ID``, ``FEATURE_0`` .. ``FEATURE_{n-1}``
        and ``TARGET`` (0 or 1).
    """
    import numpy as np

    np.random.seed(chunk_id)
    first_id = chunk_id * chunk_size
    rows = []
    for offset in range(chunk_size):
        feats = np.random.normal(0, 1, num_features)
        # Correlate the second and third features with the ones before them.
        feats[1] = feats[0] * 0.5 + np.random.normal(0, 0.5)
        feats[2] = feats[0] * 0.3 + feats[1] * 0.2 + np.random.normal(0, 0.8)

        # The label is driven by the first six features plus noise.
        signal = (feats[0] * 0.5 +
                  feats[1] * 0.3 +
                  feats[2] * 0.2 +
                  np.sum(feats[3:6]) * 0.1)
        noisy_signal = signal + np.random.normal(0, 0.5)
        label = 1 if noisy_signal > 0 else 0

        # Plain Python scalars keyed by column name, in column order.
        record = {"ID": first_id + offset}
        record.update(
            (f"FEATURE_{idx}", float(value)) for idx, value in enumerate(feats)
        )
        record["TARGET"] = int(label)
        rows.append(record)
    return rows

In [9]:
def create_large_dataset_and_dump_into_snowflake(num_rows=1_000_000, num_features=20, num_chunks=50):
    """
    Create a large synthetic dataset using Ray for parallel data generation.

    NOTE: despite its name, this function does not write to Snowflake. It only
    builds and returns the Ray Dataset; writing is a separate step.

    Args:
        num_rows: Total number of rows to generate. Exactly this many rows are
            produced, even when it is not a multiple of ``num_chunks``.
        num_features: Number of ``FEATURE_i`` columns per row.
        num_chunks: Maximum number of parallel generation tasks (default 50).

    Returns:
        A Ray Dataset with columns ``ID`` (int64), ``FEATURE_0`` ..
        ``FEATURE_{num_features-1}`` (float64) and ``TARGET`` (int32).

    Raises:
        ValueError: If ``num_chunks`` is smaller than 1.
    """
    import pyarrow as pa

    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")

    print(f"Generating {num_rows:,} rows with {num_features} features for classification...")

    # Spread the rows over the chunks. Plain floor division used to drop the
    # remainder, so the first `remainder` chunks each take one extra row.
    base_size, remainder = divmod(max(num_rows, 0), num_chunks)
    chunk_sizes = [
        base_size + 1 if chunk_index < remainder else base_size
        for chunk_index in range(num_chunks)
    ]
    # Do not launch tasks that would generate nothing.
    chunk_sizes = [size for size in chunk_sizes if size > 0]
    print(f"   Using {len(chunk_sizes)} parallel data generation tasks...")

    # Generate data chunks in parallel
    chunk_futures = [
        generate_data_chunk.remote(chunk_id, size, num_features)
        for chunk_id, size in enumerate(chunk_sizes)
    ]
    chunks = ray.get(chunk_futures)

    # Flatten data
    all_data = [row for chunk in chunks for row in chunk]

    if remainder:
        # Each chunk derives IDs from its own size, so unequal chunk sizes
        # would produce overlapping IDs. Renumber to keep IDs unique and
        # contiguous. With equal chunk sizes the IDs are already 0..n-1.
        for row_id, row in enumerate(all_data):
            row["ID"] = row_id
    print(f"Generated {len(all_data):,} data points")

    # Define Arrow schema
    fields = [pa.field("ID", pa.int64())]
    fields.extend(pa.field(f"FEATURE_{i}", pa.float64()) for i in range(num_features))
    fields.append(pa.field("TARGET", pa.int32()))
    schema = pa.schema(fields)

    # Create Arrow Table with schema
    print("🔧 Creating Arrow Table with schema...")
    table = pa.Table.from_pylist(all_data, schema=schema)

    # Create Ray Dataset from Arrow Table
    print("🔧 Creating Ray Dataset...")
    ray_ds = ray.data.from_arrow(table)

    print(f"✅ Created Ray Dataset with schema: {ray_ds.schema()}")
    return ray_ds

In [10]:
# Build the 1,000,000-row, 20-feature Ray Dataset. Despite the function's name,
# nothing is written to Snowflake here; the write happens through the datasink below.
ray_ds = create_large_dataset_and_dump_into_snowflake(1_000_000, 20)

Generating 1,000,000 rows with 20 features for classification...
   Using 50 parallel data generation tasks...
Generated 1,000,000 data points
🔧 Creating Arrow Table with schema...


2025-07-08 02:47:51,589	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


🔧 Creating Ray Dataset...
✅ Created Ray Dataset with schema: Column      Type
------      ----
ID          int64
FEATURE_0   double
FEATURE_1   double
FEATURE_2   double
FEATURE_3   double
FEATURE_4   double
FEATURE_5   double
FEATURE_6   double
FEATURE_7   double
FEATURE_8   double
FEATURE_9   double
FEATURE_10  double
FEATURE_11  double
FEATURE_12  double
FEATURE_13  double
FEATURE_14  double
FEATURE_15  double
FEATURE_16  double
FEATURE_17  double
FEATURE_18  double
FEATURE_19  double
TARGET      int32


In [11]:
# Print the first record as a quick sanity check of the generated columns.
ray_ds.show(1)

2025-07-08 02:47:52,328	INFO dataset.py:2704 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2025-07-08 02:47:52,331	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /raylogs/ray/session_2025-07-07_17-26-31_254142_14/logs/ray-data
2025-07-08 02:47:52,332	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> LimitOperator[limit=1]
Running 0: 0.00 row [00:00, ? row/s]
                                                                              
✔️  Dataset execution finished in 0.02 seconds: : 1.00 row [00:00, 61.1 row/s]

- limit=1 1: 0.00 row [00:00, ? row/s][A
- limit=1: Tasks: 0; Queued blocks: 0; Resources: 0.0 CPU, 172.0B object store: : 0.00 row [00:00, ? row/s][A
- limit=1: Tasks: 0; Queued blocks: 0; Resources: 0.0 CPU, 172.0B object store: : 1.00 row [00:00, 58.3 row/s][A
- limit=1: Tasks: 0; Queued blocks: 0; Resources: 0.0 CPU, 172.0B object store: 

{'ID': 0, 'FEATURE_0': 1.764052345967664, 'FEATURE_1': -0.39446873493320733, 'FEATURE_2': 0.9732168331559461, 'FEATURE_3': 2.240893199201458, 'FEATURE_4': 1.8675579901499675, 'FEATURE_5': -0.977277879876411, 'FEATURE_6': 0.9500884175255894, 'FEATURE_7': -0.1513572082976979, 'FEATURE_8': -0.10321885179355784, 'FEATURE_9': 0.41059850193837233, 'FEATURE_10': 0.144043571160878, 'FEATURE_11': 1.454273506962975, 'FEATURE_12': 0.7610377251469934, 'FEATURE_13': 0.12167501649282841, 'FEATURE_14': 0.44386323274542566, 'FEATURE_15': 0.33367432737426683, 'FEATURE_16': 1.4940790731576061, 'FEATURE_17': -0.20515826376580087, 'FEATURE_18': 0.31306770165090136, 'FEATURE_19': -0.8540957393017248, 'TARGET': 1}





In [12]:
# Confirm the dataset holds the requested number of rows.
ray_ds.count()

1000000

### Define a helper to write a Ray Dataset to Snowflake

In [13]:
from ray.data import Datasink
from ray.data import Datasource
from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block, BlockAccessor
from ray.data.datasource import Reader, WriteResult
from snowflake.connector import connect
from snowflake.connector.pandas_tools import write_pandas
from typing import Optional, Iterable
from ray.data._internal.remote_fn import cached_remote_fn
class SnowflakeTableDatasink(Datasink):
    """Ray Data sink that appends every block of a dataset to a Snowflake table.

    Each block is converted to pandas and written with ``Session.write_pandas``
    from its own Ray task, using a session obtained from ``get_session()``.
    Writes always append (``overwrite=False``) because several tasks write to
    the same table concurrently.

    Args:
        table_name: Name of the target table.
        database: Database containing the table. ``None`` leaves the choice to
            the session's current database. Requires ``schema`` to be set.
        schema: Schema containing the table. ``None`` leaves the choice to the
            session's current schema.
        auto_create_table: Forwarded to ``write_pandas``. When ``False`` the
            table must already exist, which is checked before writing starts.
        override: When ``True`` the existing contents are discarded before
            writing. The table is dropped if ``auto_create_table`` is ``True``
            (the write recreates it) and truncated otherwise.

    Raises:
        ValueError: If ``database`` is given without ``schema``.
    """

    def __init__(
        self,
        *,
        table_name: str,
        database: Optional[str] = None,
        schema: Optional[str] = None,
        auto_create_table: bool = False,
        override: bool = False,
    ):
        if database is not None and schema is None:
            # A database alone cannot qualify a table name.
            raise ValueError("`schema` must be provided when `database` is set.")
        self._table_name = table_name
        self._database = database
        self._schema = schema
        self._override = override
        self._auto_create_table = auto_create_table

    def _qualified_table_name(self) -> str:
        """Return the table name prefixed by whichever of database/schema were given."""
        parts = [
            part
            for part in (self._database, self._schema, self._table_name)
            if part is not None
        ]
        return ".".join(parts)

    def _show_tables_scope(self) -> str:
        """Return the ``IN ...`` clause for ``SHOW TABLES`` (empty when no schema was given)."""
        if self._schema is None:
            return ""
        if self._database is None:
            return f" IN SCHEMA {self._schema}"
        return f" IN {self._database}.{self._schema}"

    def on_write_start(self):
        """Validate and prepare the target table before any block is written.

        Raises:
            ValueError: If ``auto_create_table`` is ``False`` and the table
                does not exist.
        """
        target = self._qualified_table_name()
        # One session for all preparation statements, always closed afterwards.
        session = get_session()
        try:
            if not self._auto_create_table:
                # table should already exist
                tables = session.sql(
                    f"SHOW TABLES LIKE '{self._table_name}'{self._show_tables_scope()}"
                ).collect()
                if not tables:
                    raise ValueError(f"Table {target} does not exist.")

            if self._override:
                # `logger` was never defined in this notebook; go through the
                # logging module directly.
                logging.getLogger(__name__).warning(f"Overriding table {target}")
                if self._auto_create_table:
                    # The write recreates the table, so dropping is safe.
                    session.sql(f"DROP TABLE IF EXISTS {target}").collect()
                else:
                    # Nothing would recreate a dropped table, so only empty it.
                    session.sql(f"TRUNCATE TABLE {target}").collect()
        finally:
            session.close()

    def write(self, blocks: Iterable[Block], ctx: TaskContext) -> None:
        """Write every non-empty block to the Snowflake table, one Ray task per block."""
        # Resolve the settings here so the remote function captures plain
        # values instead of the whole datasink.
        table_name = self._table_name
        auto_create_table = self._auto_create_table
        # The identifiers may arrive already quoted (e.g. '"MY_DB"'), hence the
        # strip. `None` is passed through untouched.
        database = self._database.strip('"') if self._database is not None else None
        schema = self._schema.strip('"') if self._schema is not None else None

        def write_single_block(block: Block):
            pandas_block = BlockAccessor.for_block(block).to_pandas()
            if pandas_block.empty:
                return
            session = get_session()
            try:
                session.write_pandas(
                    df=pandas_block,
                    table_name=table_name,
                    database=database,
                    schema=schema,
                    auto_create_table=auto_create_table,
                    overwrite=False,  # append to the existing table as we are using multi-process to write the table.
                )
            finally:
                # Each task opens its own connection; close it so it does not leak.
                session.close()

        # We need cache_remote_fn because there is a circular dependency issue with the ray.remote
        # this is from one of the existing ray datasink for writing to BigQuery table.
        remote_write_single_block = cached_remote_fn(write_single_block)
        # Write the block to the Snowflake table
        ray.get([remote_write_single_block.remote(block) for block in blocks])

In [14]:
# Start from a clean slate: the datasink only appends, so a table left over from a
# previous run would otherwise end up with duplicated rows.
session.sql("drop table if exists SPARK_MLLIB_SAMPLE_DATASET").collect()

[Row(status='SPARK_MLLIB_SAMPLE_DATASET successfully dropped.')]

In [15]:
# Target the session's current database/schema. auto_create_table=True lets the
# write create the table, which was dropped in the previous cell.
datasink = SnowflakeTableDatasink(
    table_name="SPARK_MLLIB_SAMPLE_DATASET",
    database=session.get_current_database(),
    schema=session.get_current_schema(),
    auto_create_table=True
)

In [16]:
# Execute the write: Ray Data hands the dataset's blocks to the datasink's write().
ray_ds.write_datasink(datasink)

2025-07-08 02:47:52,528	INFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /raylogs/ray/session_2025-07-07_17-26-31_254142_14/logs/ray-data
2025-07-08 02:47:52,528	INFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[Write]
Running 0: 0.00 row [00:00, ? row/s]
- Write 1: 0.00 row [00:00, ? row/s][A
- Write: Tasks: 1; Queued blocks: 0; Resources: 1.0 CPU, 256.0MB object store: : 0.00 row [00:01, ? row/s][A
Running Dataset. Active & requested resources: 1/112 CPU, 256.0MB/20.9GB object store: : 0.00 row [00:01, ? row/s]
Running Dataset. Active & requested resources: 1/112 CPU, 256.0MB/20.9GB object store: : 0.00 row [00:02, ? row/s]
Running Dataset. Active & requested resources: 1/112 CPU, 256.0MB/20.9GB object store: : 0.00 row [00:03, ? row/s]
Running Dataset. Active & requested resources: 1/112 CPU, 256.0MB/20.9GB object store: : 0.00 row [00:04, ? row/s]
Running Dataset. Active & requested resource

### Clean up

In [None]:
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy


@ray.remote
def trigger_gc():
    """Run a garbage-collection pass in the worker process that executes this task."""
    import gc
    gc.collect()


# Plain `.remote()` calls can all be scheduled onto the same node, so pin one task
# to every live node instead. soft=True lets Ray fall back to another node if the
# target disappears in the meantime. Each task only collects inside the single
# worker process that runs it.
gc_tasks = [
    trigger_gc.options(
        scheduling_strategy=NodeAffinitySchedulingStrategy(
            node_id=node["NodeID"], soft=True
        )
    ).remote()
    for node in ray.nodes()
    if node.get("Alive")
]
ray.get(gc_tasks)

In [22]:
# Close the driver-side Snowpark session opened near the top of the notebook.
session.close()

In [20]:
# Disconnect this driver from Ray (`cli` is the context returned by `ray.init`).
cli.disconnect()