In [None]:
# Copyright  2024 Forusone
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Ray Data operation examples

### Configuration

In [13]:
! pip install --user -q "google-cloud-aiplatform[ray]>=1.56.0" \
                        "ray[data,train,tune,serve]>=2.33.0"

In [20]:
# @title Define constants
# Identifiers for the Google Cloud project and the Ray cluster used by this notebook.
PROJECT_NBR = "721521243942"  # project number; used to build the Ray cluster address
PROJECT_ID = "ai-hangsik"  # project ID; passed to aiplatform.init
REGION = "us-central1"  # location of both the Vertex AI SDK session and the cluster
RAY_CLUSTER_NM = "ray33-cluster-20250216-192557"  # persistent resource name of the cluster

In [21]:
from google.cloud import aiplatform

# Initialise the Vertex AI SDK with the project and region from the constants cell.
aiplatform.init(
    project=PROJECT_ID,
    location=REGION,
)

In [22]:
import ray
from ray.runtime_env import RuntimeEnv
from ray.air.config import RunConfig
from ray.air import CheckpointConfig, ScalingConfig
from ray.util.joblib import register_ray

In [23]:
# Display the version of the Ray package imported into this kernel.
ray.__version__

'2.33.0'

In [62]:

# Shut down any Ray session still attached to this kernel; the next cell calls ray.init.
ray.shutdown()

In [None]:

# Address of the Ray cluster, built from the project number, region and cluster name.
RAY_ADDRESS = f"vertex_ray://projects/{PROJECT_NBR}/locations/{REGION}/persistentResources/{RAY_CLUSTER_NM}"
print(f"RAY_ADDRESS:{RAY_ADDRESS}")

# Runtime environment passed to ray.init: the pip packages requested for this session.
RUNTIME_ENV = {
    "pip": [
        "google-cloud-aiplatform[ray]>=1.56.0",
        "ray[data,train,tune,serve]>=2.33.0",
        "datasets",
        "evaluate",
        "accelerate==0.18.0",
        "transformers==4.26.0",
        "torch>=1.12.0",
        "deepspeed==0.12.3",
    ],
}

# Connect this kernel to the cluster at RAY_ADDRESS using the environment above.
ray.init(address=RAY_ADDRESS, runtime_env=RUNTIME_ENV)

### Basic Operation

#### Get data from GCS

* https://docs.ray.io/en/latest/data/api/doc/ray.data.read_csv.html#ray.data.read_csv

In [48]:

@ray.remote
def load_data():
    """Read the Iris CSV from GCS with Ray Data and print its schema and first row.

    Declared as a Ray remote task. Nothing is returned; the schema and the
    result of ``take(1)`` are only printed.
    """
    # Local import kept inside the task body, as in the original cell.
    import ray

    iris_ds = ray.data.read_csv("gs://sllm_checkpoints/data/Iris.csv")

    schema = iris_ds.schema()
    print(schema)

    first_rows = iris_ds.take(1)
    print(first_rows)

In [49]:
# Submit the load_data task and block until it has finished.
ray.get(load_data.remote())

#### Read with different parse options

In [None]:

@ray.remote
def load_data():
    """Read the tab-separated Iris file from GCS and print its schema.

    The tab delimiter is supplied through pyarrow's ``csv.ParseOptions``.
    Nothing is returned; the schema is only printed.

    Note: this redefines the ``load_data`` task declared earlier in the notebook.
    """
    from pyarrow import csv

    # Tell the CSV reader that fields are separated by tabs rather than commas.
    tsv_options = csv.ParseOptions(delimiter="\t")
    iris_ds = ray.data.read_csv(
        "gs://sllm_checkpoints/data/Iris.tsv",
        parse_options=tsv_options,
    )
    print(iris_ds.schema())


In [None]:
# Submit the load_data task (the tab-separated variant defined above) and block until it has finished.
ray.get(load_data.remote())

In [59]:
# https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.write_csv.html#ray.data.Dataset.write_csv

@ray.remote
def save_data():
    """Read the Iris CSV from GCS and write it back out as CSV with Ray Data.

    Reads ``gs://sllm_checkpoints/data/Iris.csv`` and writes the dataset under
    ``gs://sllm_checkpoints/tmp_store`` via ``Dataset.write_csv``, passing
    ``num_rows_per_file=1000``. Declared as a Ray remote task; returns nothing.
    """
    ds = ray.data.read_csv("gs://sllm_checkpoints/data/Iris.csv")
    ds.write_csv("gs://sllm_checkpoints/tmp_store", num_rows_per_file=1000)


In [60]:
# Submit the save_data task and block until the write has finished.
ray.get(save_data.remote())