# Data Generation for Spark MLlib-Based Model Training on Snowpark Container Services

## Setup and Imports

In [1]:
import ray
import pprint
import warnings
import logging    
import time
import os
import numpy as np
import socket
import snowflake.connector
from snowflake.snowpark import Session
from snowflake.ml.ray.datasink import SnowflakeTableDatasink
print(f"Ray version: {ray.__version__}")

  import pkg_resources


Ray version: 2.46.0


## Initialize Snowpark Session

In [2]:
def connection() -> snowflake.connector.SnowflakeConnection:
    """Open a Snowflake connector connection from the credentials available.

    Two modes are supported:

    * If the file ``/snowflake/session/token`` exists, its contents are used
      as an OAuth token together with ``SNOWFLAKE_HOST`` / ``SNOWFLAKE_PORT``
      (this is the token path used inside Snowpark Container Services).
    * Otherwise the connection falls back to user/password authentication
      via ``SNOWFLAKE_USER`` and ``SNOWFLAKE_PASSWORD``.

    In both modes the account, database and schema come from the
    ``SNOWFLAKE_ACCOUNT``, ``SNOWFLAKE_DATABASE`` and ``SNOWFLAKE_SCHEMA``
    environment variables. The warehouse is read from ``SNOWFLAKE_WAREHOUSE``
    and defaults to ``"LARGE_WH"`` when that variable is unset.

    Returns:
        An open ``snowflake.connector.SnowflakeConnection``.
    """
    token_path = "/snowflake/session/token"
    # The password branch used to reference an undefined `snowflake_warehouse`
    # name; resolve the warehouse once, keeping the old hard-coded default.
    warehouse = os.getenv('SNOWFLAKE_WAREHOUSE', "LARGE_WH")

    if os.path.isfile(token_path):
        # Read the token through a context manager so the handle is closed.
        with open(token_path, 'r') as token_file:
            token = token_file.read()
        creds = {
            'host': os.getenv('SNOWFLAKE_HOST'),
            'port': os.getenv('SNOWFLAKE_PORT'),
            'protocol': "https",
            'account': os.getenv('SNOWFLAKE_ACCOUNT'),
            'authenticator': "oauth",
            'token': token,
            'warehouse': warehouse,
            'database': os.getenv('SNOWFLAKE_DATABASE'),
            'schema': os.getenv('SNOWFLAKE_SCHEMA'),
            'client_session_keep_alive': True
        }
    else:
        creds = {
            'account': os.getenv('SNOWFLAKE_ACCOUNT'),
            'user': os.getenv('SNOWFLAKE_USER'),
            'password': os.getenv('SNOWFLAKE_PASSWORD'),
            'warehouse': warehouse,
            'database': os.getenv('SNOWFLAKE_DATABASE'),
            'schema': os.getenv('SNOWFLAKE_SCHEMA'),
            'client_session_keep_alive': True
        }

    # Return directly instead of binding a local that shadows this function.
    return snowflake.connector.connect(**creds)

def get_session() -> Session:
    """Create a Snowpark ``Session`` backed by a new ``connection()``."""
    session_options = {"connection": connection()}
    builder = Session.builder.configs(session_options)
    return builder.create()

In [3]:
# Open the Snowpark session that the later cells use to run SQL and read tables.
session = get_session()

In [4]:
# Display the session's current database as a quick check on the connection.
session.get_current_database()

'"RAYDP_SIS_DB"'

In [5]:
# address="auto" attaches to an already-running Ray cluster instead of starting
# a local one; ignore_reinit_error=True keeps a re-run of this cell from failing.
ray.init(address="auto", ignore_reinit_error=True)

2025-07-06 20:46:26,821	INFO worker.py:1694 -- Connecting to existing Ray cluster at address: 10.244.10.75:6379...
2025-07-06 20:46:26,833	INFO worker.py:1879 -- Connected to Ray cluster. View the dashboard at [1m[32m10.244.10.75:8265 [39m[22m


0,1
Python version:,3.10.17
Ray version:,2.46.0
Dashboard:,http://10.244.10.75:8265


[33m(raylet, ip=10.244.10.139)[0m [2025-07-06 20:46:54,267 I 406 406] logging.cc:297: Set ray log level from environment variable RAY_BACKEND_LOG_LEVEL to -1
[36m(Write pid=2537)[0m   import pkg_resources
[36m(write_single_block pid=516, ip=10.244.10.203)[0m SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. 


In [6]:
# Snapshot the cluster-wide resource totals and the per-node metadata.
# NOTE(review): cluster_resources is not read by any later cell shown here.
cluster_resources = ray.cluster_resources()
nodes = ray.nodes()

In [8]:
# List the resources reported by each Ray node, numbering nodes from 1.
print("  Cluster Information:")
for node_number, node_info in enumerate(nodes, start=1):
    resources = node_info.get('Resources', {})
    print(f"   Node {node_number}: {resources}")
print()

  Cluster Information:
   Node 1: {'GPU': 1.0, 'CPU': 6.0, 'accelerator_type:A10G': 1.0, 'object_store_memory': 8659282329.0, 'node:10.244.11.75': 1.0, 'memory': 20204992103.0}
   Node 2: {'GPU': 1.0, 'node:__internal_head__': 1.0, 'accelerator_type:A10G': 1.0, 'object_store_memory': 8420179968.0, 'node:10.244.10.75': 1.0, 'CPU': 6.0, 'memory': 19647086592.0}
   Node 3: {'GPU': 1.0, 'accelerator_type:A10G': 1.0, 'object_store_memory': 8659241779.0, 'node:10.244.10.203': 1.0, 'CPU': 6.0, 'memory': 20204897485.0}
   Node 4: {'node:10.244.10.139': 1.0, 'GPU': 1.0, 'CPU': 6.0, 'accelerator_type:A10G': 1.0, 'object_store_memory': 8659227033.0, 'memory': 20204863079.0}



### Create sample dataset

In [9]:
@ray.remote
def generate_data_chunk(chunk_id, chunk_size, num_features):
    """Generate one chunk of synthetic binary-classification rows.

    The NumPy global RNG is seeded with ``chunk_id``, so a given
    ``(chunk_id, chunk_size, num_features)`` always yields the same rows.

    Args:
        chunk_id: Index of this chunk; used as the RNG seed and to offset IDs.
        chunk_size: Number of rows to generate.
        num_features: Features per row. Indices 1 and 2 are overwritten below,
            so at least 3 features are required.

    Returns:
        A list of dicts with keys ``ID``, ``FEATURE_0`` .. ``FEATURE_{n-1}``
        and ``TARGET`` (0 or 1).
    """
    import numpy as np
    np.random.seed(chunk_id)

    first_id = chunk_id * chunk_size
    rows = []
    for offset in range(chunk_size):
        x = np.random.normal(0, 1, num_features)
        # Make features 1 and 2 noisy linear functions of the earlier ones.
        x[1] = x[0] * 0.5 + np.random.normal(0, 0.5)
        x[2] = x[0] * 0.3 + x[1] * 0.2 + np.random.normal(0, 0.8)
        # Only features 0-5 contribute to the label signal.
        signal = x[0] * 0.5 + x[1] * 0.3 + x[2] * 0.2 + np.sum(x[3:6]) * 0.1
        noisy_signal = signal + np.random.normal(0, 0.5)
        if noisy_signal > 0:
            label = 1
        else:
            label = 0

        # Plain Python scalars keep the rows easy to convert downstream.
        record = {"ID": first_id + offset}
        record.update(
            {f"FEATURE_{j}": float(value) for j, value in enumerate(x)}
        )
        record["TARGET"] = int(label)
        rows.append(record)
    return rows

In [10]:
def create_large_dataset_and_dump_into_snowflake(num_rows=1_000_000, num_features=20):
    """
    Create a large synthetic dataset using Ray for parallel data generation.

    Rows are produced by up to 50 ``generate_data_chunk`` tasks, gathered on
    the driver, converted to an Arrow table with an explicit schema and
    wrapped in a Ray Dataset.

    Despite its name, this function does not write anything to Snowflake; it
    only builds and returns the dataset.

    Args:
        num_rows: Exact number of rows to generate. It does not need to be a
            multiple of the task count.
        num_features: Number of ``FEATURE_i`` columns. ``generate_data_chunk``
            writes to feature indices 1 and 2, so this must be at least 3.

    Returns:
        A Ray Dataset with columns ``ID`` (int64), ``FEATURE_0`` ..
        ``FEATURE_{num_features-1}`` (float64) and ``TARGET`` (int32).
    """
    import pyarrow as pa
    print(f"Generating {num_rows:,} rows with {num_features} features for classification...")
    # Generate data in parallel using Ray
    max_chunks = 50  # Upper bound on the number of data generation tasks
    # Ceiling division so no remainder rows are dropped. Every chunk uses the
    # same chunk_size, which keeps IDs (chunk_id * chunk_size + i) unique.
    chunk_size = max(1, -(-num_rows // max_chunks))
    # Launch only as many tasks as are needed to cover num_rows.
    num_chunks = max(0, -(-num_rows // chunk_size))
    print(f"   Using {num_chunks} parallel data generation tasks...")
    # Generate data chunks in parallel
    chunk_futures = [
        generate_data_chunk.remote(i, chunk_size, num_features)
        for i in range(num_chunks)
    ]
    chunks = ray.get(chunk_futures)
    # Flatten data
    all_data = [row for chunk in chunks for row in chunk]
    # The final chunk may overshoot when num_rows is not a multiple of
    # chunk_size; trim so exactly num_rows rows remain.
    del all_data[max(num_rows, 0):]
    print(f"Generated {len(all_data):,} data points")
    # Define Arrow schema
    fields = (
        [pa.field("ID", pa.int64())]
        + [pa.field(f"FEATURE_{i}", pa.float64()) for i in range(num_features)]
        + [pa.field("TARGET", pa.int32())]
    )
    schema = pa.schema(fields)

    # Create Arrow Table with schema
    print("🔧 Creating Arrow Table with schema...")
    table = pa.Table.from_pylist(all_data, schema=schema)

    # Create Ray Dataset from Arrow Table
    print("🔧 Creating Ray Dataset...")
    ray_ds = ray.data.from_arrow(table)

    print(f"✅ Created Ray Dataset with schema: {ray_ds.schema()}")
    return ray_ds

In [11]:
# Build the 1,000,000-row, 20-feature Ray Dataset. Despite the function's name,
# nothing is written to Snowflake at this point.
ray_ds = create_large_dataset_and_dump_into_snowflake(1_000_000, 20)

Generating 1,000,000 rows with 20 features for classification...
   Using 50 parallel data generation tasks...
Generated 1,000,000 data points
🔧 Creating Arrow Table with schema...
🔧 Creating Ray Dataset...
✅ Created Ray Dataset with schema: Column      Type
------      ----
ID          int64
FEATURE_0   double
FEATURE_1   double
FEATURE_2   double
FEATURE_3   double
FEATURE_4   double
FEATURE_5   double
FEATURE_6   double
FEATURE_7   double
FEATURE_8   double
FEATURE_9   double
FEATURE_10  double
FEATURE_11  double
FEATURE_12  double
FEATURE_13  double
FEATURE_14  double
FEATURE_15  double
FEATURE_16  double
FEATURE_17  double
FEATURE_18  double
FEATURE_19  double
TARGET      int32


In [12]:
# Peek at a single record to confirm the column layout.
ray_ds.show(1)

2025-07-06 20:47:10,389	INFO dataset.py:3027 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2025-07-06 20:47:10,396	INFO logging.py:290 -- Registered dataset logger for dataset dataset_6_0
2025-07-06 20:47:10,411	INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_6_0. Full logs are in /raylogs/ray/session_2025-07-06_20-03-57_750929_35/logs/ray-data
2025-07-06 20:47:10,412	INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_6_0: InputDataBuffer[Input] -> LimitOperator[limit=1]


Running 0: 0.00 row [00:00, ? row/s]

- limit=1 1: 0.00 row [00:00, ? row/s]

2025-07-06 20:47:10,461	INFO streaming_executor.py:220 -- ✔️  Dataset dataset_6_0 execution finished in 0.05 seconds


{'ID': 0, 'FEATURE_0': 1.764052345967664, 'FEATURE_1': -0.39446873493320733, 'FEATURE_2': 0.9732168331559461, 'FEATURE_3': 2.240893199201458, 'FEATURE_4': 1.8675579901499675, 'FEATURE_5': -0.977277879876411, 'FEATURE_6': 0.9500884175255894, 'FEATURE_7': -0.1513572082976979, 'FEATURE_8': -0.10321885179355784, 'FEATURE_9': 0.41059850193837233, 'FEATURE_10': 0.144043571160878, 'FEATURE_11': 1.454273506962975, 'FEATURE_12': 0.7610377251469934, 'FEATURE_13': 0.12167501649282841, 'FEATURE_14': 0.44386323274542566, 'FEATURE_15': 0.33367432737426683, 'FEATURE_16': 1.4940790731576061, 'FEATURE_17': -0.20515826376580087, 'FEATURE_18': 0.31306770165090136, 'FEATURE_19': -0.8540957393017248, 'TARGET': 1}


### Write sample dataset to Snowflake table

In [13]:
# Remove any table left over from a previous run before writing.
session.sql("drop table if exists SPARK_MLLIB_SAMPLE_DATASET").collect()

[Row(status='SPARK_MLLIB_SAMPLE_DATASET successfully dropped.')]

In [14]:
# Sink targeting SPARK_MLLIB_SAMPLE_DATASET in the session's current database
# and schema.
# NOTE(review): auto_create_table=True presumably lets the sink create the
# table (it was dropped in the previous cell) — confirm against the
# SnowflakeTableDatasink documentation.
datasink = SnowflakeTableDatasink(
    table_name="SPARK_MLLIB_SAMPLE_DATASET",
    database=session.get_current_database(),
    schema=session.get_current_schema(),
    auto_create_table=True
)

In [15]:
# Execute the write: push the Ray Dataset through the Snowflake datasink.
ray_ds.write_datasink(datasink)

2025-07-06 20:47:10,719	INFO logging.py:290 -- Registered dataset logger for dataset dataset_8_0
2025-07-06 20:47:10,721	INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_8_0. Full logs are in /raylogs/ray/session_2025-07-06_20-03-57_750929_35/logs/ray-data
2025-07-06 20:47:10,722	INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_8_0: InputDataBuffer[Input] -> TaskPoolMapOperator[Write]


Running 0: 0.00 row [00:00, ? row/s]

- Write 1: 0.00 row [00:00, ? row/s]

2025-07-06 20:47:42,154	INFO streaming_executor.py:220 -- ✔️  Dataset dataset_8_0 execution finished in 31.43 seconds
2025-07-06 20:47:42,198	INFO dataset.py:4537 -- Data sink SnowflakeTable finished. 1000000 rows and 164.0MB data written.


In [16]:
# Read 10 rows back from the Snowflake table to verify the write.
session.table("SPARK_MLLIB_SAMPLE_DATASET").limit(10).to_pandas()

Unnamed: 0,ID,FEATURE_0,FEATURE_1,FEATURE_2,FEATURE_3,FEATURE_4,FEATURE_5,FEATURE_6,FEATURE_7,FEATURE_8,...,FEATURE_11,FEATURE_12,FEATURE_13,FEATURE_14,FEATURE_15,FEATURE_16,FEATURE_17,FEATURE_18,FEATURE_19,TARGET
0,0,1.764052,-0.394469,0.973217,2.240893,1.867558,-0.977278,0.950088,-0.151357,-0.103219,...,1.454274,0.761038,0.121675,0.443863,0.333674,1.494079,-0.205158,0.313068,-0.854096,1
1,1,-0.742165,0.604305,-0.50951,0.045759,-0.187184,1.532779,1.469359,0.154947,0.378163,...,-0.347912,0.156349,1.230291,1.20238,-0.387327,-0.302303,-1.048553,-1.420018,-1.70627,0
2,2,-1.252795,-1.441497,-0.293912,-0.21274,-0.895467,0.386902,-0.510805,-1.180632,-0.028182,...,0.302472,-0.634322,-0.362741,-0.67246,-0.359553,-0.813146,-1.726283,0.177426,-0.401781,0
3,3,0.051945,0.553199,-0.196318,1.139401,-1.234826,0.402342,-0.68481,-0.870797,-0.57885,...,-1.16515,0.900826,0.465662,-1.536244,1.488252,1.895889,1.17878,-0.179925,-1.070753,1
4,4,0.208275,-0.326475,1.525239,0.706573,0.0105,1.78587,0.126912,0.401989,1.883151,...,0.969397,-1.173123,1.943621,-0.413619,-0.747455,1.922942,1.480515,1.867559,0.906045,1
5,5,0.802456,0.739445,0.849899,0.614079,0.922207,0.376426,-1.099401,0.298238,1.326386,...,-0.435154,1.849264,0.672295,0.407462,-0.769916,0.539249,-0.674333,0.031831,-0.635846,1
6,6,0.396007,-0.375731,-0.3066,0.439392,0.166673,0.635031,2.383145,0.944479,-0.912822,...,-0.461585,-0.068242,1.713343,-0.744755,-0.826439,-0.098453,-0.663478,1.126636,-1.079932,0
7,7,1.929532,0.27729,0.119423,-1.225436,0.844363,-1.000215,-1.544771,1.18803,0.316943,...,0.856831,-0.651026,-1.034243,0.681595,-0.80341,-0.68955,-0.455533,0.017479,-0.353994,0
8,8,0.625231,0.632681,-0.979459,0.052165,-0.739563,1.543015,-1.292857,0.267051,-0.039283,...,-0.171546,0.771791,0.823504,2.163236,1.336528,-0.369182,-0.239379,1.09966,0.655264,1
9,9,-0.738031,-0.252925,-0.749847,0.910179,0.317218,0.786328,-0.466419,-0.944446,-0.41005,...,2.259309,-0.042257,-0.955945,-0.345982,-0.463596,0.481481,-1.540797,0.063262,0.156507,0


In [17]:
# Close the Snowpark session now that the notebook's work is done.
session.close()