# Distributed Model Training with Ray and Snowflake (10M)

In [1]:
from snowflake.snowpark import Session
import os

In [2]:
def initiate_snowpark_conn():
  """Open a Snowpark session using the OAuth token file inside the container.

  The token is read from /snowflake/session/token. Account, host, warehouse,
  database, schema and role are taken from SNOWFLAKE_* environment variables
  (os.getenv yields None for any variable that is not set).

  Returns:
      The object produced by Session.builder.configs(...).create().
  """
  with open("/snowflake/session/token", "r") as token_file:
      oauth_token = token_file.read()

  env = os.getenv
  settings = {
      "account": env("SNOWFLAKE_ACCOUNT"),
      "host": env("SNOWFLAKE_HOST"),
      "authenticator": "oauth",
      "token": oauth_token,
      "warehouse": env("SNOWFLAKE_WAREHOUSE"),
      "database": env("SNOWFLAKE_DATABASE"),
      "schema": env("SNOWFLAKE_SCHEMA"),
      "role": env("SNOWFLAKE_ROLE"),
      "client_session_keep_alive": True
  }
  return Session.builder.configs(settings).create()

# Single session shared by every later cell in this notebook.
session = initiate_snowpark_conn()

In [3]:
# Sanity check: display the role the session is currently using.
session.get_current_role()

'"RAY_ROLE"'

In [4]:
# Sanity check: display the warehouse the session is currently using.
session.get_current_warehouse()

'"RAY_WH"'

In [5]:
# Sanity check: display the user the session is connected as.
session.get_current_user()

'"SF$SERVICE$_cuNE0Z9o2prLlePw3Spnw"'

### Dump data from a Snowflake table into a Snowflake stage mapped into the container

In [6]:
# Stage that receives the unloaded files, and the table that supplies the training rows.
snowflake_data_stage = "ARTIFACTSXGBOOSTHYPERPARAMETERTUNING"
snowflake_input_table_name = "BREAST_CANCER_DATA_10M"

In [7]:
import uuid

# A random UUID gives each run its own folder name on the stage.
foldername = str(uuid.uuid4())
# Same text as the UUID with "-" swapped for "_", prefixed with "snowflake_".
snowflake_data_folder_name = "_".join(["snowflake", *foldername.split("-")])
print(snowflake_data_folder_name)

snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c


In [8]:
# Reference the input table through the Snowpark session.
breast_cancer_snowdf = session.table(snowflake_input_table_name)

In [9]:
# Row count of the input table.
breast_cancer_snowdf.count()

10000175

In [10]:
# Preview 10 rows as a pandas DataFrame.
breast_cancer_snowdf.limit(10).to_pandas()

Unnamed: 0,MEANRADIUS,MEANTEXTURE,MEANPERIMETER,MEANAREA,MEANSMOOTHNESS,MEANCOMPACTNESS,MEANCONCAVITY,MEANCONCAVEPOINTS,MEANSYMMETRY,MEANFRACTALDIMENSION,...,WORSTTEXTURE,WORSTPERIMETER,WORSTAREA,WORSTSMOOTHNESS,WORSTCOMPACTNESS,WORSTCONCAVITY,WORSTCONCAVEPOINTS,WORSTSYMMETRY,WORSTFRACTALDIMENSION,TARGET
0,17.99,10.38,122.8,1001.0,0.1184,0.2776,0.3001,0.1471,0.2419,0.07871,...,17.33,184.6,2019.0,0.1622,0.6656,0.7119,0.2654,0.4601,0.1189,0
1,20.57,17.77,132.9,1326.0,0.08474,0.07864,0.0869,0.07017,0.1812,0.05667,...,23.41,158.8,1956.0,0.1238,0.1866,0.2416,0.186,0.275,0.08902,0
2,19.69,21.25,130.0,1203.0,0.1096,0.1599,0.1974,0.1279,0.2069,0.05999,...,25.53,152.5,1709.0,0.1444,0.4245,0.4504,0.243,0.3613,0.08758,0
3,11.42,20.38,77.58,386.1,0.1425,0.2839,0.2414,0.1052,0.2597,0.09744,...,26.5,98.87,567.7,0.2098,0.8663,0.6869,0.2575,0.6638,0.173,0
4,20.29,14.34,135.1,1297.0,0.1003,0.1328,0.198,0.1043,0.1809,0.05883,...,16.67,152.2,1575.0,0.1374,0.205,0.4,0.1625,0.2364,0.07678,0
5,12.45,15.7,82.57,477.1,0.1278,0.17,0.1578,0.08089,0.2087,0.07613,...,23.75,103.4,741.6,0.1791,0.5249,0.5355,0.1741,0.3985,0.1244,0
6,18.25,19.98,119.6,1040.0,0.09463,0.109,0.1127,0.074,0.1794,0.05742,...,27.66,153.2,1606.0,0.1442,0.2576,0.3784,0.1932,0.3063,0.08368,0
7,13.71,20.83,90.2,577.9,0.1189,0.1645,0.09366,0.05985,0.2196,0.07451,...,28.14,110.6,897.0,0.1654,0.3682,0.2678,0.1556,0.3196,0.1151,0
8,13.0,21.82,87.5,519.8,0.1273,0.1932,0.1859,0.09353,0.235,0.07389,...,30.73,106.2,739.3,0.1703,0.5401,0.539,0.206,0.4378,0.1072,0
9,12.46,24.04,83.97,475.9,0.1186,0.2396,0.2273,0.08543,0.203,0.08243,...,40.68,97.65,711.4,0.1853,1.058,1.105,0.221,0.4366,0.2075,0


### Unload data into the Snowflake stage mapped into the container

In [11]:
# Define the uncompressed Parquet file format that the unload statement refers to
# by name (FILE_FORMAT_PARQUET). A plain string is used because the statement has
# no placeholders to interpolate.
file_format_sql = """
    create or replace file format FILE_FORMAT_PARQUET type='PARQUET' COMPRESSION=NONE
    """
session.sql(file_format_sql).collect()

[Row(status='File format FILE_FORMAT_PARQUET successfully created.')]

In [12]:
# Unload every row of the input table into this run's folder on the stage, using
# the FILE_FORMAT_PARQUET format, split over multiple files (single=false) and
# keeping column names (header=true).
# NOTE(review): Snowflake documents MAX_FILE_SIZE in bytes, so 100 looks far smaller
# than intended — confirm the desired per-file size before changing it, since it
# determines how many files are produced.
data_unload_sql = f"""
        copy into @{snowflake_data_stage}/{snowflake_data_folder_name}/ from 
        (select * from {snowflake_input_table_name}) file_format=(format_name=FILE_FORMAT_PARQUET) single=false header=true max_file_size=100
    """
session.sql(data_unload_sql).collect()

[Row(rows_unloaded=10000175, input_bytes=409382564, output_bytes=409382564)]

### What distributed model training looks like with Ray

In [13]:
!pip install xgboost_ray xgboost==1.7.3

[0m

In [14]:
from xgboost_ray import RayDMatrix, train, RayFileType, RayParams

### Set params

In [15]:
# Booster settings handed to xgboost_ray's train(): GPU histogram tree method,
# binary logistic objective, log-loss and error as evaluation metrics.
xgb_params = {
    "tree_method": "gpu_hist",
    "objective": "binary:logistic",
    "eval_metric": ["logloss", "error"],
}

# Resources requested for the remote training actors.
# NOTE(review): gpus_per_actor is fractional (0.1) — confirm this is the intended
# GPU share per actor.
ray_params = RayParams(
    num_actors=4,  # number of remote actors
    cpus_per_actor=1,
    gpus_per_actor=0.1,
)

### Import Ray and connect to the cluster

In [16]:
import ray

# Drop any connection left over from an earlier run of this cell before
# connecting again.
try:
    ray.shutdown()
except Exception as exc:
    # Best-effort cleanup: report the failure and still attempt to connect.
    # Catching Exception (not a bare `except:`) lets KeyboardInterrupt and
    # SystemExit propagate.
    print(f"ray.shutdown() failed, continuing: {exc!r}")

# Connect to the existing cluster (address="auto") and request the same package
# versions for the job's runtime environment.
cli = ray.init(address="auto", runtime_env={"pip": ["xgboost_ray", "xgboost==1.7.3"]})

2024-05-08 14:58:20,253	INFO worker.py:1540 -- Connecting to existing Ray cluster at address: 10.244.9.10:6379...
2024-05-08 14:58:21,440	INFO worker.py:1715 -- Connected to Ray cluster. View the dashboard at [1m[32mhttp://10.244.9.10:8265 [39m[22m
[2024-05-08 14:58:21,442 I 16067 16067] logging.cc:230: Set ray log level from environment variable RAY_BACKEND_LOG_LEVEL to -1
[33m(raylet, ip=10.244.13.10)[0m [2024-05-08 14:58:23,154 I 6334 6334] logging.cc:230: Set ray log level from environment variable RAY_BACKEND_LOG_LEVEL to -1
[36m(_RemoteRayXGBoostActor pid=6175, ip=10.244.11.10)[0m [14:58:44] task [xgboost.ray]:140614335971136 got new rank 1
[33m(raylet, ip=10.244.12.10)[0m [2024-05-08 14:58:24,501 I 6076 6076] logging.cc:230: Set ray log level from environment variable RAY_BACKEND_LOG_LEVEL to -1[32m [repeated 6x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/r

[36m(_RemoteRayXGBoostActor pid=6175, ip=10.244.11.10)[0m statefulset-1:6175:6255 [0] NCCL INFO cudaDriverVersion 12010
[36m(_RemoteRayXGBoostActor pid=6175, ip=10.244.11.10)[0m statefulset-1:6175:6255 [0] NCCL INFO Bootstrap : Using eth0:10.244.11.10<0>
[36m(_RemoteRayXGBoostActor pid=6175, ip=10.244.11.10)[0m statefulset-1:6175:6255 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
[36m(_RemoteRayXGBoostActor pid=6175, ip=10.244.11.10)[0m statefulset-1:6175:6255 [0] NCCL INFO Failed to open libibverbs.so[.1]
[36m(_RemoteRayXGBoostActor pid=6175, ip=10.244.11.10)[0m statefulset-1:6175:6255 [0] NCCL INFO NET/Socket : Using [0]eth0:10.244.11.10<0>
[36m(_RemoteRayXGBoostActor pid=6175, ip=10.244.11.10)[0m statefulset-1:6175:6255 [0] NCCL INFO Using network Socket
[36m(_RemoteRayXGBoostActor pid=6175, ip=10.244.11.10)[0m statefulset-1:6175:6255 [0] NCCL INFO Trees [0] -1/-1/-1->1->2 [1] 2/0/-1->1->3
[36m(_RemoteRayXGBoostActor pid=61

### Prepare list of training files

In [17]:
# List the files written to this run's folder on the stage and keep only the
# "name" column of the listing.
query = f"list @{snowflake_data_stage}/{snowflake_data_folder_name}/"
stage_listing = session.sql(query)
data_files = stage_listing.select('"name"').collect()

In [18]:
# Directory under /home/artifacts that corresponds to this run's stage folder.
# NOTE(review): presumably /home/artifacts is where the stage is mounted in the
# container — confirm against the service specification.
local_data_path = f"/home/artifacts/{snowflake_data_folder_name}/"

In [19]:
# Keep only the file name (text after the last "/") of each staged file and
# prefix it with the local directory.
train_files = [
    f"{local_data_path}{str(row['name']).rsplit('/', 1)[-1]}"
    for row in data_files
]

In [20]:
# Inspect the resolved local file paths.
train_files

['/home/artifacts/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_0_0.parquet',
 '/home/artifacts/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_0_1.parquet',
 '/home/artifacts/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_0_10.parquet',
 '/home/artifacts/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_0_11.parquet',
 '/home/artifacts/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_0_12.parquet',
 '/home/artifacts/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_0_13.parquet',
 '/home/artifacts/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_0_14.parquet',
 '/home/artifacts/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_0_15.parquet',
 '/home/artifacts/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_0_16.parquet',
 '/home/artifacts/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_0_17.parquet',
 '/home/artifacts/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_0_18.parquet',
 '/home/artifacts/snowflake_62e7dc33_8b1e_44e

In [21]:
# Number of Parquet files that will be passed to training.
len(train_files)

354

In [22]:
# Column names of the input table; this list includes the TARGET label column.
columns = breast_cancer_snowdf.columns

In [23]:
# Display the column list.
columns

['MEANRADIUS',
 'MEANTEXTURE',
 'MEANPERIMETER',
 'MEANAREA',
 'MEANSMOOTHNESS',
 'MEANCOMPACTNESS',
 'MEANCONCAVITY',
 'MEANCONCAVEPOINTS',
 'MEANSYMMETRY',
 'MEANFRACTALDIMENSION',
 'RADIUSERROR',
 'TEXTUREERROR',
 'PERIMETERERROR',
 'AREAERROR',
 'SMOOTHNESSERROR',
 'COMPACTNESSERROR',
 'CONCAVITYERROR',
 'CONCAVEPOINTSERROR',
 'SYMMETRYERROR',
 'FRACTALDIMENSIONERROR',
 'WORSTRADIUS',
 'WORSTTEXTURE',
 'WORSTPERIMETER',
 'WORSTAREA',
 'WORSTSMOOTHNESS',
 'WORSTCOMPACTNESS',
 'WORSTCONCAVITY',
 'WORSTCONCAVEPOINTS',
 'WORSTSYMMETRY',
 'WORSTFRACTALDIMENSION',
 'TARGET']

### Initiate model training

In [24]:
# Arguments shared by the multi-file and the single-file case.
dmatrix_options = {
    "label": "TARGET",
    "columns": columns,
    "filetype": RayFileType.PARQUET,
}
# With at most one file, explicitly switch off distributed loading.
if len(train_files) <= 1:
    dmatrix_options["distributed"] = False
dtrain = RayDMatrix(train_files, **dmatrix_options)

In [25]:
# Run xgboost_ray training on dtrain with the booster settings in xgb_params and
# the actor resources in ray_params.
model = train(xgb_params, dtrain, ray_params=ray_params)

2024-05-08 14:58:23,798	INFO main.py:1140 -- [RayXGBoost] Created 4 new actors (4 total actors). Waiting until actors are ready for training.
2024-05-08 14:58:44,491	INFO main.py:1191 -- [RayXGBoost] Starting XGBoost training.
2024-05-08 14:58:50,232	INFO main.py:1708 -- [RayXGBoost] Finished XGBoost training on training data with total N=10,000,175 in 27.83 seconds (5.74 pure XGBoost training time).


In [26]:
# Save the trained model; the relative path resolves against the current working directory.
model.save_model("model_trained_with_ray_on_snowflake_demo.xgb")

In [27]:
# Clean up: remove the files unloaded for this run from the stage folder.
session.sql(f"remove @{snowflake_data_stage}/{snowflake_data_folder_name}/").collect()

[Row(name='artifactsxgboosthyperparametertuning/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_6_30.parquet', result='removed'),
 Row(name='artifactsxgboosthyperparametertuning/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_7_27.parquet', result='removed'),
 Row(name='artifactsxgboosthyperparametertuning/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_5_35.parquet', result='removed'),
 Row(name='artifactsxgboosthyperparametertuning/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_0_15.parquet', result='removed'),
 Row(name='artifactsxgboosthyperparametertuning/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_6_15.parquet', result='removed'),
 Row(name='artifactsxgboosthyperparametertuning/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_1_51.parquet', result='removed'),
 Row(name='artifactsxgboosthyperparametertuning/snowflake_62e7dc33_8b1e_44ef_8dd0_81c3ce2d501c/data_0_4_23.parquet', result='removed'),
 Row(name='artifactsxgboosthyperparametertuning/

In [28]:
# Disconnect from the Ray cluster (cli is the value returned by ray.init).
cli.disconnect()

In [29]:
# Close the Snowpark session opened at the top of the notebook.
session.close()