In [None]:
# Copyright  2024 Forusone
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Pytorch with Ray on Vertex AI

### Configuration

In [2]:
! pip install --user -q "google-cloud-aiplatform[ray]>=1.56.0" \
                        "ray[data,train,tune,serve]==2.9.3"
                        # datasets \
                        # evaluate \
                        # accelerate \
                        # transformers \
                        # torch \
                        # numpy \
                        # pandas

In [5]:
import time
import numpy as np
import joblib
import pandas as pd
import seaborn as sns
import xgboost as xgb

# from sklearn.preprocessing import MinMaxScaler, OneHotEncoder, LabelEncoder
# from sklearn.compose import ColumnTransformer,make_column_transformer
# from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split, KFold, cross_val_score
# from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix,confusion_matrix,classification_report
# from sklearn.pipeline import Pipeline
# from sklearn.ensemble import RandomForestClassifier

# from google.cloud import bigquery
from google.cloud import aiplatform
# from google.cloud.aiplatform.preview import vertex_ray

import ray
from ray.runtime_env import RuntimeEnv
from ray.air.config import RunConfig
from ray.air import CheckpointConfig, ScalingConfig
from ray.util.joblib import register_ray

pd.__version__, ray.__version__

('2.1.4', '2.9.3')

In [6]:
# @title Define constants
PROJECT_NBR = "721521243942"
PROJECT_ID = "ai-hangsik"
REGION="us-central1"
RAY_CLUSTER_NM = "ray293-cluster-20250217-075541"

In [7]:
from google.cloud import aiplatform

aiplatform.init(project=PROJECT_ID, location=REGION)

In [8]:
ray.shutdown()

In [9]:

RAY_ADDRESS=f"vertex_ray://projects/{PROJECT_NBR}/locations/{REGION}/persistentResources/{RAY_CLUSTER_NM}"
print(f"RAY_ADDRESS:{RAY_ADDRESS}")

RUNTIME_ENV = {
  "pip": [
        "google-cloud-aiplatform[ray]>=1.56.0",
        "ray[data]==2.9.3",
        "ray[train]==2.9.3",
        "ray[tune]==2.9.3",
        "torch==2.1.2",
        "torchvision==0.16.2",
        "torchmetrics==1.2.1",
        "setuptools==69.5.1",
        "ipython",
  ],
}

ray.init(address=RAY_ADDRESS,runtime_env=RUNTIME_ENV)

RAY_ADDRESS:vertex_ray://projects/721521243942/locations/us-central1/persistentResources/ray293-cluster-20250217-075541
[Ray on Vertex AI]: Cluster State = State.RUNNING


SIGTERM handler is not set because current thread is not the main thread.


0,1
Python version:,3.10.14
Ray version:,2.9.3
Vertex SDK version:,1.80.0
Dashboard:,22af56cc3c3b6ab6-dot-us-central1.aiplatform-training.googleusercontent.com
Interactive Terminal Uri:,001d0c5669882547-dot-us-central1.aiplatform-training.googleusercontent.com
Cluster Name:,ray293-cluster-20250217-075541


[36m(TunerInternal pid=8339)[0m [output] This will use the new output engine with verbosity 1. To disable the new output and use the legacy output engine, set the environment variable RAY_AIR_NEW_OUTPUT=0. For more information, please see https://github.com/ray-project/ray/issues/36949
[36m(TunerInternal pid=8339)[0m AIR_VERBOSITY is set, ignoring passed-in ProgressReporter for now.


[36m(TunerInternal pid=8339)[0m 
[36m(TunerInternal pid=8339)[0m View detailed results here: /root/ray_results/TorchTrainer_2025-02-17_00-50-32
[36m(TunerInternal pid=8339)[0m To visualize your results with TensorBoard, run: `tensorboard --logdir /root/ray_results/TorchTrainer_2025-02-17_00-50-32`
[36m(TunerInternal pid=8339)[0m 
[36m(TunerInternal pid=8339)[0m Training started with configuration:
[36m(TunerInternal pid=8339)[0m ╭─────────────────────────────────────────────────╮
[36m(TunerInternal pid=8339)[0m │ Training config                                 │
[36m(TunerInternal pid=8339)[0m ├─────────────────────────────────────────────────┤
[36m(TunerInternal pid=8339)[0m │ train_loop_config/batch_size_per_worker      10 │
[36m(TunerInternal pid=8339)[0m │ train_loop_config/epochs                      5 │
[36m(TunerInternal pid=8339)[0m │ train_loop_config/lr                      0.001 │
[36m(TunerInternal pid=8339)[0m ╰──────────────────────────────────────

[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Setting up process group for: env:// [rank=0, world_size=3]
[36m(TorchTrainer pid=951, ip=10.127.0.21)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=951, ip=10.127.0.21)[0m - (ip=10.127.0.21, pid=1011) world_rank=0, local_rank=0, node_rank=0
[36m(TorchTrainer pid=951, ip=10.127.0.21)[0m - (ip=10.127.0.21, pid=1012) world_rank=1, local_rank=1, node_rank=0
[36m(TorchTrainer pid=951, ip=10.127.0.21)[0m - (ip=10.127.0.20, pid=1107) world_rank=2, local_rank=0, node_rank=1


[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /root/data/FashionMNIST/raw/train-images-idx3-ubyte.gz
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /root/data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]0)[0m 
  0%|          | 0/26421880 [00:00<?, ?it/s]1)[0m 
  0%|          | 32768/26421880 [00:00<01:49, 241724.43it/s]
  0%|          | 32768/26421880 [00:00<01:44, 251943.03it/s]
  0%|          | 98304/26421880 [00:00<01:26, 303393.55it/s]
  0%|          | 98304/26421880 [00:00<01:26, 305392.45it/s]
  1%|          | 196608/26421880 [00:00<00:51, 507393.95it/s]
  1%|          | 196608/26421880 [00:00<00:51, 511646.06it/s]
  1%|▏         | 393216/26421880 [00:00<00:28, 928253.66it/s]
  1%|▏         | 393216/26421880 [00:00<00:27, 936860.61it/s]
  3%|▎         | 786432/26421880 [00:00<00:14, 1769536.97it/s]
  3%|▎         | 753664/26421880 [00:00<00:15, 1690565.06it/s]
  6%|▌         | 1540096/26421880 [00:00<00:07, 3346051.56it/s]
  6%|▌         | 1507328/26421880 [00:00<00:07, 3315287.36it/s]
 12%|█▏        | 3112960/26421880 [00:00<00:03, 6691067.74it/s]
 11%|█▏        | 3014656/26421880 [00:00<00:03, 6535453.26it/s]
 22%|██▏       | 57671

[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Extracting /root/data/FashionMNIST/raw/train-images-idx3-ubyte.gz to /root/data/FashionMNIST/raw
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Extracting /root/data/FashionMNIST/raw/train-images-idx3-ubyte.gz to /root/data/FashionMNIST/raw


100%|██████████| 26421880/26421880 [00:01<00:00, 15856320.96it/s]


[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m 
[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m 
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /root/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /root/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]0.21)[0m 
100%|██████████| 29515/29515 [00:00<00:00, 267643.78it/s]
  0%|          | 0/29515 [00:00<?, ?it/s]0.20)[0m 


[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Extracting /root/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /root/data/FashionMNIST/raw
[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m 
[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Extracting /root/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /root/data/FashionMNIST/raw
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m 
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 268228.91it/s]


[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /root/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /root/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]21)[0m 
  0%|          | 0/4422102 [00:00<?, ?it/s]20)[0m 
  1%|          | 32768/4422102 [00:00<00:17, 256381.52it/s]
  1%|          | 32768/4422102 [00:00<00:18, 235244.00it/s]
  2%|▏         | 98304/4422102 [00:00<00:14, 307010.18it/s]
  2%|▏         | 98304/4422102 [00:00<00:13, 309457.62it/s]
  4%|▍         | 196608/4422102 [00:00<00:08, 515671.91it/s]
  4%|▍         | 196608/4422102 [00:00<00:08, 517825.21it/s]
  9%|▉         | 393216/4422102 [00:00<00:04, 945875.83it/s]
  9%|▉         | 393216/4422102 [00:00<00:04, 947283.55it/s]
 17%|█▋        | 753664/4422102 [00:00<00:02, 1709688.46it/s]
 18%|█▊        | 786432/4422102 [00:00<00:02, 1805243.71it/s]
 35%|███▍      | 1540096/4422102 [00:00<00:00, 3448623.15it/s]
 36%|███▌      | 1572864/4422102 [00:00<00:00, 3508997.01it/s]
100%|██████████| 4422102/4422102 [00:00<00:00, 5057249.65it/s]
100%|██████████| 4422102/4422102 [00:00<00:00, 5029860.32it/s]


[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Extracting /root/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /root/data/FashionMNIST/raw
[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m 
[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Extracting /root/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /root/data/FashionMNIST/raw
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m 
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /root/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-cen

100%|██████████| 5148/5148 [00:00<00:00, 44520158.75it/s]


[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Extracting /root/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /root/data/FashionMNIST/raw
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m 


100%|██████████| 5148/5148 [00:00<00:00, 37228063.78it/s]
[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Moving model to device: cuda:0
[36m(RayTrainWorker pid=1011, ip=10.127.0.21)[0m Wrapping provided model in DistributedDataParallel.
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Moving model to device: cuda:0
[36m(RayTrainWorker pid=1107, ip=10.127.0.20)[0m Wrapping provided model in DistributedDataParallel.
Train Epoch 0:   0%|          | 0/2000 [00:00<?, ?it/s]
Train Epoch 0:   0%|          | 0/2000 [00:00<?, ?it/s]
Train Epoch 0:   0%|          | 0/2000 [00:00<?, ?it/s]
Train Epoch 0:   0%|          | 1/2000 [00:00<05:34,  5.98it/s]
Train Epoch 0:   0%|          | 1/2000 [00:00<05:24,  6.17it/s]
Train Epoch 0:   0%|          | 1/2000 [00:00<05:31,  6.02it/s]
Train Epoch 0:   1%|          | 18/2000 [00:00<00:24, 80.77it/s]
Train Epoch 0:   1%|          | 18/2000 [00:00<00:24, 82.48it/s]
Train Epoch 0:   1%|          | 17/2000 [00:00<00:25, 76.53it/s]
Train Epoch 0:   

[36m(TunerInternal pid=8339)[0m 
[36m(TunerInternal pid=8339)[0m Training finished iteration 1 at 2025-02-17 00:51:02. Total running time: 28s
[36m(TunerInternal pid=8339)[0m ╭───────────────────────────────╮
[36m(TunerInternal pid=8339)[0m │ Training result               │
[36m(TunerInternal pid=8339)[0m ├───────────────────────────────┤
[36m(TunerInternal pid=8339)[0m │ checkpoint_dir_name           │
[36m(TunerInternal pid=8339)[0m │ time_this_iter_s      22.4977 │
[36m(TunerInternal pid=8339)[0m │ time_total_s          22.4977 │
[36m(TunerInternal pid=8339)[0m │ training_iteration          1 │
[36m(TunerInternal pid=8339)[0m │ accuracy              0.79184 │
[36m(TunerInternal pid=8339)[0m │ loss                  0.54291 │
[36m(TunerInternal pid=8339)[0m ╰───────────────────────────────╯


Test Epoch 0: 100%|██████████| 334/334 [00:01<00:00, 313.51it/s]
Train Epoch 1:   0%|          | 0/2000 [00:00<?, ?it/s]
Test Epoch 0: 100%|██████████| 334/334 [00:01<00:00, 328.19it/s]
Train Epoch 1:   0%|          | 0/2000 [00:00<?, ?it/s]
Train Epoch 1:   1%|          | 20/2000 [00:00<00:10, 196.78it/s]
Train Epoch 1:   1%|          | 22/2000 [00:00<00:09, 215.27it/s]
Train Epoch 1:   1%|          | 19/2000 [00:00<00:10, 181.12it/s]
Train Epoch 1:   2%|▏         | 41/2000 [00:00<00:09, 202.48it/s]
Train Epoch 1:   2%|▏         | 46/2000 [00:00<00:08, 227.10it/s]
Train Epoch 1:   2%|▏         | 39/2000 [00:00<00:10, 189.04it/s]
Train Epoch 1:   3%|▎         | 62/2000 [00:00<00:09, 205.34it/s]
Train Epoch 1:   4%|▎         | 70/2000 [00:00<00:08, 230.74it/s]
Train Epoch 1:   3%|▎         | 59/2000 [00:00<00:10, 190.62it/s]
Train Epoch 1:   4%|▍         | 83/2000 [00:00<00:09, 206.75it/s]
Train Epoch 1:   5%|▍         | 94/2000 [00:00<00:08, 230.52it/s]
Train Epoch 1:   4%|▍         | 

[36m(TunerInternal pid=8339)[0m 
[36m(TunerInternal pid=8339)[0m Training finished iteration 2 at 2025-02-17 00:51:13. Total running time: 39s
[36m(TunerInternal pid=8339)[0m ╭───────────────────────────────╮
[36m(TunerInternal pid=8339)[0m │ Training result               │
[36m(TunerInternal pid=8339)[0m ├───────────────────────────────┤
[36m(TunerInternal pid=8339)[0m │ checkpoint_dir_name           │
[36m(TunerInternal pid=8339)[0m │ time_this_iter_s      11.2978 │
[36m(TunerInternal pid=8339)[0m │ time_total_s          33.7955 │
[36m(TunerInternal pid=8339)[0m │ training_iteration          2 │
[36m(TunerInternal pid=8339)[0m │ accuracy              0.82543 │
[36m(TunerInternal pid=8339)[0m │ loss                  0.47401 │
[36m(TunerInternal pid=8339)[0m ╰───────────────────────────────╯


Train Epoch 2:   2%|▏         | 46/2000 [00:00<00:08, 227.65it/s]
Train Epoch 2:   2%|▏         | 42/2000 [00:00<00:09, 207.63it/s]
Train Epoch 2:   2%|▏         | 39/2000 [00:00<00:10, 192.74it/s]
Train Epoch 2:   4%|▎         | 70/2000 [00:00<00:08, 231.57it/s]
Train Epoch 2:   3%|▎         | 63/2000 [00:00<00:09, 206.91it/s]
Train Epoch 2:   3%|▎         | 59/2000 [00:00<00:10, 193.79it/s]
Train Epoch 2:   5%|▍         | 94/2000 [00:00<00:08, 230.42it/s]
Train Epoch 2:   4%|▍         | 84/2000 [00:00<00:09, 207.60it/s]
Train Epoch 2:   4%|▍         | 79/2000 [00:00<00:09, 195.49it/s]
Train Epoch 2:   5%|▍         | 99/2000 [00:00<00:09, 196.38it/s]
Train Epoch 2:   6%|▌         | 118/2000 [00:00<00:11, 166.34it/s]
Train Epoch 2:   5%|▌         | 105/2000 [00:00<00:11, 163.98it/s]
Train Epoch 2:   6%|▌         | 119/2000 [00:00<00:09, 196.39it/s]
Train Epoch 2:   7%|▋         | 141/2000 [00:00<00:10, 183.10it/s]
Train Epoch 2:   6%|▋         | 126/2000 [00:00<00:10, 176.50it/s]
Train

[36m(TunerInternal pid=8339)[0m 
[36m(TunerInternal pid=8339)[0m Training finished iteration 3 at 2025-02-17 00:51:24. Total running time: 51s
[36m(TunerInternal pid=8339)[0m ╭───────────────────────────────╮
[36m(TunerInternal pid=8339)[0m │ Training result               │
[36m(TunerInternal pid=8339)[0m ├───────────────────────────────┤
[36m(TunerInternal pid=8339)[0m │ checkpoint_dir_name           │
[36m(TunerInternal pid=8339)[0m │ time_this_iter_s      11.4413 │
[36m(TunerInternal pid=8339)[0m │ time_total_s          45.2368 │
[36m(TunerInternal pid=8339)[0m │ training_iteration          3 │
[36m(TunerInternal pid=8339)[0m │ accuracy              0.83893 │
[36m(TunerInternal pid=8339)[0m │ loss                  0.43932 │
[36m(TunerInternal pid=8339)[0m ╰───────────────────────────────╯


Train Epoch 3:   2%|▏         | 46/2000 [00:00<00:08, 227.06it/s]
Train Epoch 3:   2%|▏         | 35/2000 [00:00<00:11, 172.06it/s]
Train Epoch 3:   2%|▏         | 47/2000 [00:00<00:08, 231.30it/s]
Train Epoch 3:   3%|▎         | 69/2000 [00:00<00:09, 211.20it/s]
Train Epoch 3:   3%|▎         | 53/2000 [00:00<00:11, 173.76it/s]
Train Epoch 3:   4%|▎         | 71/2000 [00:00<00:08, 234.59it/s]
Train Epoch 3:   5%|▍         | 91/2000 [00:00<00:09, 199.86it/s]
Train Epoch 3:   4%|▎         | 72/2000 [00:00<00:10, 179.05it/s]
Train Epoch 3:   5%|▍         | 95/2000 [00:00<00:08, 219.01it/s]
Train Epoch 3:   5%|▍         | 92/2000 [00:00<00:10, 183.80it/s]
Train Epoch 3:   6%|▌         | 112/2000 [00:00<00:12, 147.97it/s]
Train Epoch 3:   6%|▌         | 112/2000 [00:00<00:10, 186.58it/s]
Train Epoch 3:   6%|▌         | 118/2000 [00:00<00:11, 160.89it/s]
Train Epoch 3:   7%|▋         | 134/2000 [00:00<00:11, 165.12it/s]
Train Epoch 3:   7%|▋         | 132/2000 [00:00<00:09, 187.96it/s]
Train

[36m(TunerInternal pid=8339)[0m 
[36m(TunerInternal pid=8339)[0m Training finished iteration 4 at 2025-02-17 00:51:35. Total running time: 1min 1s
[36m(TunerInternal pid=8339)[0m ╭───────────────────────────────╮
[36m(TunerInternal pid=8339)[0m │ Training result               │
[36m(TunerInternal pid=8339)[0m ├───────────────────────────────┤
[36m(TunerInternal pid=8339)[0m │ checkpoint_dir_name           │
[36m(TunerInternal pid=8339)[0m │ time_this_iter_s       10.604 │
[36m(TunerInternal pid=8339)[0m │ time_total_s          55.8408 │
[36m(TunerInternal pid=8339)[0m │ training_iteration          4 │
[36m(TunerInternal pid=8339)[0m │ accuracy              0.84823 │
[36m(TunerInternal pid=8339)[0m │ loss                   0.4239 │
[36m(TunerInternal pid=8339)[0m ╰───────────────────────────────╯


Train Epoch 4:   2%|▏         | 38/2000 [00:00<00:10, 187.77it/s]
Train Epoch 4:   3%|▎         | 68/2000 [00:00<00:08, 224.89it/s]
Train Epoch 4:   2%|▏         | 46/2000 [00:00<00:08, 223.11it/s]
Train Epoch 4:   3%|▎         | 59/2000 [00:00<00:09, 195.59it/s]
Train Epoch 4:   5%|▍         | 92/2000 [00:00<00:08, 227.15it/s]
Train Epoch 4:   4%|▎         | 70/2000 [00:00<00:08, 226.86it/s]
Train Epoch 4:   4%|▍         | 80/2000 [00:00<00:09, 198.77it/s]
Train Epoch 4:   5%|▍         | 93/2000 [00:00<00:08, 227.74it/s]
Train Epoch 4:   5%|▌         | 103/2000 [00:00<00:09, 207.33it/s]
Train Epoch 4:   6%|▌         | 115/2000 [00:00<00:10, 177.66it/s]
Train Epoch 4:   6%|▌         | 116/2000 [00:00<00:09, 202.70it/s]
Train Epoch 4:   6%|▋         | 127/2000 [00:00<00:08, 216.16it/s]
Train Epoch 4:   7%|▋         | 137/2000 [00:00<00:09, 188.20it/s]
Train Epoch 4:   7%|▋         | 137/2000 [00:00<00:09, 202.65it/s]
Train Epoch 4:   8%|▊         | 151/2000 [00:00<00:08, 221.60it/s]
Tra

[36m(TunerInternal pid=8339)[0m 
[36m(TunerInternal pid=8339)[0m Training finished iteration 5 at 2025-02-17 00:51:45. Total running time: 1min 12s
[36m(TunerInternal pid=8339)[0m ╭───────────────────────────────╮
[36m(TunerInternal pid=8339)[0m │ Training result               │
[36m(TunerInternal pid=8339)[0m ├───────────────────────────────┤
[36m(TunerInternal pid=8339)[0m │ checkpoint_dir_name           │
[36m(TunerInternal pid=8339)[0m │ time_this_iter_s      10.3631 │
[36m(TunerInternal pid=8339)[0m │ time_total_s          66.2039 │
[36m(TunerInternal pid=8339)[0m │ training_iteration          5 │
[36m(TunerInternal pid=8339)[0m │ accuracy              0.85633 │
[36m(TunerInternal pid=8339)[0m │ loss                  0.40158 │
[36m(TunerInternal pid=8339)[0m ╰───────────────────────────────╯
[36m(TunerInternal pid=8339)[0m 
[36m(TunerInternal pid=8339)[0m Training completed after 5 iterations at 2025-02-17 00:51:47. Total running time: 1min 13s
[36m(Tu

In [10]:
import os
from typing import Dict

import torch
from filelock import FileLock
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Normalize, ToTensor
from tqdm import tqdm

import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def get_dataloaders(batch_size):
    # Transform to normalize the input images
    transform = transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])

    with FileLock(os.path.expanduser("~/data.lock")):
        # Download training data from open datasets
        training_data = datasets.FashionMNIST(
            root="~/data",
            train=True,
            download=True,
            transform=transform,
        )

        # Download test data from open datasets
        test_data = datasets.FashionMNIST(
            root="~/data",
            train=False,
            download=True,
            transform=transform,
        )

    # Create data loaders
    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    return train_dataloader, test_dataloader


# Model Definition
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


def train_func_per_worker(config: Dict):
    lr = config["lr"]
    epochs = config["epochs"]
    batch_size = config["batch_size_per_worker"]

    # Get dataloaders inside the worker training function
    train_dataloader, test_dataloader = get_dataloaders(batch_size=batch_size)

    # [1] Prepare Dataloader for distributed training
    # Shard the datasets among workers and move batches to the correct device
    # =======================================================================
    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = ray.train.torch.prepare_data_loader(test_dataloader)

    model = NeuralNetwork()

    # [2] Prepare and wrap your model with DistributedDataParallel
    # Move the model to the correct GPU/CPU device
    # ============================================================
    model = ray.train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Model training loop
    for epoch in range(epochs):
        if ray.train.get_context().get_world_size() > 1:
            # Required for the distributed sampler to shuffle properly across epochs.
            train_dataloader.sampler.set_epoch(epoch)

        model.train()
        for X, y in tqdm(train_dataloader, desc=f"Train Epoch {epoch}"):
            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        test_loss, num_correct, num_total = 0, 0, 0
        with torch.no_grad():
            for X, y in tqdm(test_dataloader, desc=f"Test Epoch {epoch}"):
                pred = model(X)
                loss = loss_fn(pred, y)

                test_loss += loss.item()
                num_total += y.shape[0]
                num_correct += (pred.argmax(1) == y).sum().item()

        test_loss /= len(test_dataloader)
        accuracy = num_correct / num_total

        # [3] Report metrics to Ray Train
        # ===============================
        ray.train.report(metrics={"loss": test_loss, "accuracy": accuracy})


def train_fashion_mnist(num_workers=3, use_gpu=False):
    global_batch_size = 32

    train_config = {
        "lr": 1e-3,
        "epochs": 5,
        "batch_size_per_worker": global_batch_size // num_workers,
    }

    # Configure computation resources
    scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_func_per_worker,
        train_loop_config=train_config,
        scaling_config=scaling_config,
    )

    # [4] Start distributed training
    # Run `train_func_per_worker` on all workers
    # =============================================
    result = trainer.fit()
    print(f"Training result: {result}")


In [11]:
if __name__ == "__main__":
    train_fashion_mnist(num_workers=3, use_gpu=True)

Training result: Result(
  metrics={'loss': 0.40158225087627414, 'accuracy': 0.8563287342531494},
  path='/root/ray_results/TorchTrainer_2025-02-17_00-50-32/TorchTrainer_3156e_00000_0_2025-02-17_00-50-33',
  filesystem='local',
  checkpoint=None
)


In [12]:
ray.shutdown()