In [None]:
# Copyright  2024 Forusone
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Ray PyTorch guide

* https://docs.ray.io/en/latest/train/getting-started-pytorch.html#quickstart

### Local GPU env

In [30]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0


In [31]:
!nvidia-smi

Mon Feb 17 00:22:40 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      On  |   00000000:00:03.0 Off |                    0 |
| N/A   76C    P0             34W /   72W |     257MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA L4                      On  |   00

### Configuration

In [32]:
! pip install --user -q "google-cloud-aiplatform[ray]>=1.56.0" \
                        "ray[data,train,tune,serve]>=2.9.3"

In [33]:
# @title Define constants
PROJECT_NBR = "721521243942"
PROJECT_ID = "ai-hangsik"
REGION="us-central1"
RAY_CLUSTER_NM = "ray293-cluster-20250217-075541"

In [34]:
from google.cloud import aiplatform

aiplatform.init(project=PROJECT_ID, location=REGION)

In [35]:
import ray
from ray.runtime_env import RuntimeEnv
from ray.air.config import RunConfig
from ray.air import CheckpointConfig, ScalingConfig
from ray.util.joblib import register_ray

In [36]:
ray.__version__

'2.9.3'

### Connect to Ray on Vertex AI

In [38]:
ray.shutdown()

In [40]:

RAY_ADDRESS=f"vertex_ray://projects/{PROJECT_NBR}/locations/{REGION}/persistentResources/{RAY_CLUSTER_NM}"
print(f"RAY_ADDRESS:{RAY_ADDRESS}")

RUNTIME_ENV = {
  "pip": [
        "google-cloud-aiplatform[ray]>=1.56.0",
        "ray[data]==2.9.3",
        "ray[train]==2.9.3",
        "ray[tune]==2.9.3",
        "torch==2.1.2",
        "torchvision==0.16.2",
        "torchmetrics==1.2.1",
        "setuptools==69.5.1",
        "ipython",
  ],
}

ray.init(address=RAY_ADDRESS,runtime_env=RUNTIME_ENV)

RAY_ADDRESS:vertex_ray://projects/721521243942/locations/us-central1/persistentResources/ray293-cluster-20250217-075541
[Ray on Vertex AI]: Cluster State = State.RUNNING


SIGTERM handler is not set because current thread is not the main thread.


0,1
Python version:,3.10.14
Ray version:,2.9.3
Vertex SDK version:,1.80.0
Dashboard:,22af56cc3c3b6ab6-dot-us-central1.aiplatform-training.googleusercontent.com
Interactive Terminal Uri:,001d0c5669882547-dot-us-central1.aiplatform-training.googleusercontent.com
Cluster Name:,ray293-cluster-20250217-075541


### Training code

In [41]:
import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray.train.torch

def train_func():
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # [1] Prepare model.
    model = ray.train.torch.prepare_model(model)
    # model.to("cuda")  # This is done by `prepare_model`
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=5, shuffle=True)
    # [2] Prepare dataloader.
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    # Training
    for epoch in range(3):
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)

        for images, labels in train_loader:
            # This is done by `prepare_data_loader`!
            # images, labels = images.to("cuda"), labels.to("cuda")
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # [3] Report metrics and checkpoint.
        metrics = {"loss": loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            torch.save(
                model.module.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt")
            )
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)


### Training configuration

In [42]:
# [4] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=3, use_gpu=True)
scaling_config

Setting,Value
num_workers,3
use_gpu,True
placement_strategy,PACK


### Execute traning job

In [43]:
# [5] Launch distributed training job.
trainer = ray.train.torch.TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    # [5a] If running in a multi-node cluster, this is where you
    # should configure the run's persistent storage that is accessible
    # across all worker nodes.
    run_config=ray.train.RunConfig(storage_path="gs://sllm_checkpoints/tmp_store/pytorch"),
)
result = trainer.fit()
result


[36m(TunerInternal pid=2047)[0m [output] This will use the new output engine with verbosity 1. To disable the new output and use the legacy output engine, set the environment variable RAY_AIR_NEW_OUTPUT=0. For more information, please see https://github.com/ray-project/ray/issues/36949
[36m(TunerInternal pid=2047)[0m AIR_VERBOSITY is set, ignoring passed-in ProgressReporter for now.


[36m(TunerInternal pid=2047)[0m 
[36m(TunerInternal pid=2047)[0m View detailed results here: sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58
[36m(TunerInternal pid=2047)[0m To visualize your results with TensorBoard, run: `tensorboard --logdir /root/ray_results/TorchTrainer_2025-02-17_00-26-58`
[36m(TunerInternal pid=2047)[0m 
[36m(TunerInternal pid=2047)[0m Training started without custom configuration.


[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Setting up process group for: env:// [rank=0, world_size=3]
[36m(TorchTrainer pid=761, ip=10.127.0.20)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=761, ip=10.127.0.20)[0m - (ip=10.127.0.20, pid=820) world_rank=0, local_rank=0, node_rank=0
[36m(TorchTrainer pid=761, ip=10.127.0.20)[0m - (ip=10.127.0.20, pid=821) world_rank=1, local_rank=1, node_rank=0
[36m(TorchTrainer pid=761, ip=10.127.0.20)[0m - (ip=10.127.0.22, pid=818) world_rank=2, local_rank=0, node_rank=1
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Moving model to device: cuda:0
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Moving model to device: cuda:0
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Wrapping provided model in DistributedDataParallel.
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Wrapping provided model in DistributedDataParallel.


[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw/train-images-idx3-ubyte.gz
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw/train-images-idx3-ubyte.gz
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.g

  0%|          | 0/26421880 [00:00<?, ?it/s])[0m 
  0%|          | 0/26421880 [00:00<?, ?it/s])[0m 
  0%|          | 0/26421880 [00:00<?, ?it/s])[0m 
  0%|          | 32768/26421880 [00:00<01:43, 255923.21it/s]
  0%|          | 32768/26421880 [00:00<01:34, 280254.31it/s]
  0%|          | 32768/26421880 [00:00<01:51, 237473.05it/s]
  0%|          | 65536/26421880 [00:00<01:30, 291407.72it/s]
  0%|          | 65536/26421880 [00:00<02:16, 193237.32it/s]
  0%|          | 98304/26421880 [00:00<01:26, 305330.89it/s]
  1%|          | 163840/26421880 [00:00<01:05, 401242.69it/s]
  0%|          | 98304/26421880 [00:00<02:08, 204110.64it/s]
  1%|          | 196608/26421880 [00:00<00:55, 472544.24it/s]
  1%|▏         | 360448/26421880 [00:00<00:31, 838719.69it/s]
  3%|▎         | 688128/26421880 [00:00<00:17, 1506137.77it/s]
  3%|▎         | 720896/26421880 [00:00<00:14, 1803162.44it/s]
  1%|          | 294912/26421880 [00:00<00:47, 553671.44it/s]
  5%|▌         | 1441792/26421880 [00:00<00:07

[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Extracting /tmp/data/FashionMNIST/raw/train-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw


 83%|████████▎ | 22020096/26421880 [00:01<00:00, 25794360.53it/s]
100%|██████████| 26421880/26421880 [00:01<00:00, 16496218.36it/s]
100%|██████████| 26421880/26421880 [00:01<00:00, 14950134.06it/s]


[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Extracting /tmp/data/FashionMNIST/raw/train-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m 
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m 
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Extracting /tmp/data/FashionMNIST/raw/train-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw


  0%|          | 0/29515 [00:00<?, ?it/s].22)[0m 
100%|██████████| 29515/29515 [00:00<00:00, 273549.02it/s]


[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Extracting /tmp/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m 
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s].20)[0m 
100%|██████████| 29515/29515 [00:00<00:00, 272147.24it/s]


[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Extracting /tmp/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m 
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m 
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Using downloaded and verified file: /tmp/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Extracting /tmp/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz 

  0%|          | 0/4422102 [00:00<?, ?it/s]2)[0m 
  1%|          | 32768/4422102 [00:00<00:15, 279515.63it/s]
  1%|▏         | 65536/4422102 [00:00<00:14, 290670.26it/s]


[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  2%|▏         | 98304/4422102 [00:00<00:14, 304059.83it/s]
  0%|          | 0/4422102 [00:00<?, ?it/s]0)[0m 
  4%|▎         | 163840/4422102 [00:00<00:10, 421371.01it/s]
  0%|          | 0/4422102 [00:00<?, ?it/s]0)[0m 
  1%|          | 32768/4422102 [00:00<00:21, 204850.87it/s]
  8%|▊         | 360448/4422102 [00:00<00:04, 919057.29it/s]
  1%|          | 32768/4422102 [00:00<00:17, 256647.71it/s]
 16%|█▋        | 720896/4422102 [00:00<00:02, 1727275.95it/s]
  2%|▏         | 98304/4422102 [00:00<00:13, 310265.04it/s]
 32%|███▏      | 1409024/4422102 [00:00<00:00, 3216413.68it/s]
  2%|▏         | 98304/4422102 [00:00<00:14, 304518.82it/s]
  4%|▍         | 196608/4422102 [00:00<00:08, 515230.33it/s]
100%|██████████| 4422102/4422102 [00:00<00:00, 5023918.82it/s]
  4%|▍         | 196608/4422102 [00:00<00:08, 511753.89it/s]
  9%|▉         | 393216/4422102 [00:00<00:04, 939367.80it/s]


[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Extracting /tmp/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m 
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz


  8%|▊         | 360448/4422102 [00:00<00:04, 836857.76it/s]
 18%|█▊        | 786432/4422102 [00:00<00:02, 1785502.49it/s]
 16%|█▋        | 720896/4422102 [00:00<00:02, 1625827.17it/s]
 36%|███▌      | 1572864/4422102 [00:00<00:00, 3464076.33it/s]


[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Extracting /tmp/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m 
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Extracting /tmp/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw


 33%|███▎      | 1441792/4422102 [00:00<00:00, 3183941.76it/s]
100%|██████████| 4422102/4422102 [00:00<00:00, 5002330.25it/s]
100%|██████████| 4422102/4422102 [00:00<00:00, 4956888.32it/s]


[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m 
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Extracting /tmp/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m 


100%|██████████| 5148/5148 [00:00<00:00, 34770172.29it/s]


[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Extracting /tmp/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m 
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Extracting /tmp/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m 


100%|██████████| 5148/5148 [00:00<00:00, 34714271.69it/s]
100%|██████████| 5148/5148 [00:00<00:00, 41845498.05it/s]
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Checkpoint successfully created at: Checkpoint(filesystem=gcs, path=sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58/TorchTrainer_e6f25_00000_0_2025-02-17_00-27-00/checkpoint_000000)
[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Checkpoint successfully created at: Checkpoint(filesystem=gcs, path=sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58/TorchTrainer_e6f25_00000_0_2025-02-17_00-27-00/checkpoint_000000)
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Checkpoint successfully created at: Checkpoint(filesystem=gcs, path=sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58/TorchTrainer_e6f25_00000_0_2025-02-17_00-27-00/checkpoint_000000)


[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m {'loss': 0.055006999522447586, 'epoch': 0}
[36m(TunerInternal pid=2047)[0m 
[36m(TunerInternal pid=2047)[0m Training finished iteration 1 at 2025-02-17 00:30:59. Total running time: 3min 59s
[36m(TunerInternal pid=2047)[0m ╭─────────────────────────────────────────╮
[36m(TunerInternal pid=2047)[0m │ Training result                         │
[36m(TunerInternal pid=2047)[0m ├─────────────────────────────────────────┤
[36m(TunerInternal pid=2047)[0m │ checkpoint_dir_name   checkpoint_000000 │
[36m(TunerInternal pid=2047)[0m │ time_this_iter_s              172.13949 │
[36m(TunerInternal pid=2047)[0m │ time_total_s                  172.13949 │
[36m(TunerInternal pid=2047)[0m │ training_iteration                    1 │
[36m(TunerInternal pid=2047)[0m │ epoch                                 0 │
[36m(TunerInternal pid=2047)[0m │ loss                            0.05501 │
[36m(TunerInternal pid=2047)[0m ╰──────────────────

[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Checkpoint successfully created at: Checkpoint(filesystem=gcs, path=sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58/TorchTrainer_e6f25_00000_0_2025-02-17_00-27-00/checkpoint_000001)
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Checkpoint successfully created at: Checkpoint(filesystem=gcs, path=sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58/TorchTrainer_e6f25_00000_0_2025-02-17_00-27-00/checkpoint_000001)
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Checkpoint successfully created at: Checkpoint(filesystem=gcs, path=sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58/TorchTrainer_e6f25_00000_0_2025-02-17_00-27-00/checkpoint_000001)


[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m {'loss': 0.8109909892082214, 'epoch': 2}
[36m(TunerInternal pid=2047)[0m 
[36m(TunerInternal pid=2047)[0m Training finished iteration 3 at 2025-02-17 00:36:19. Total running time: 9min 18s
[36m(TunerInternal pid=2047)[0m ╭─────────────────────────────────────────╮
[36m(TunerInternal pid=2047)[0m │ Training result                         │
[36m(TunerInternal pid=2047)[0m ├─────────────────────────────────────────┤
[36m(TunerInternal pid=2047)[0m │ checkpoint_dir_name   checkpoint_000002 │
[36m(TunerInternal pid=2047)[0m │ time_this_iter_s              163.14938 │
[36m(TunerInternal pid=2047)[0m │ time_total_s                   491.7296 │
[36m(TunerInternal pid=2047)[0m │ training_iteration                    3 │
[36m(TunerInternal pid=2047)[0m │ epoch                                 2 │
[36m(TunerInternal pid=2047)[0m │ loss                            0.81099 │
[36m(TunerInternal pid=2047)[0m ╰────────────────────

[36m(RayTrainWorker pid=820, ip=10.127.0.20)[0m Checkpoint successfully created at: Checkpoint(filesystem=gcs, path=sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58/TorchTrainer_e6f25_00000_0_2025-02-17_00-27-00/checkpoint_000002)
[36m(RayTrainWorker pid=821, ip=10.127.0.20)[0m Checkpoint successfully created at: Checkpoint(filesystem=gcs, path=sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58/TorchTrainer_e6f25_00000_0_2025-02-17_00-27-00/checkpoint_000002)
[36m(RayTrainWorker pid=818, ip=10.127.0.22)[0m Checkpoint successfully created at: Checkpoint(filesystem=gcs, path=sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58/TorchTrainer_e6f25_00000_0_2025-02-17_00-27-00/checkpoint_000002)


[36m(TunerInternal pid=2047)[0m 
[36m(TunerInternal pid=2047)[0m Training completed after 3 iterations at 2025-02-17 00:36:20. Total running time: 9min 20s
[36m(TunerInternal pid=2047)[0m 


Result(
  metrics={'loss': 0.8109909892082214, 'epoch': 2},
  path='sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58/TorchTrainer_e6f25_00000_0_2025-02-17_00-27-00',
  filesystem='gcs',
  checkpoint=Checkpoint(filesystem=gcs, path=sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58/TorchTrainer_e6f25_00000_0_2025-02-17_00-27-00/checkpoint_000002)
)

In [44]:
result.metrics     # The metrics reported during training.
result.checkpoint  # The latest checkpoint reported during training.
result.path        # The path where logs are stored.
result.error       # The exception that was raised, if training failed.

In [45]:
result.metrics, result.checkpoint, result.path , result.error

({'loss': 0.8109909892082214,
  'epoch': 2,
  'timestamp': 1739752578,
  'checkpoint_dir_name': 'checkpoint_000002',
  'should_checkpoint': True,
  'done': True,
  'training_iteration': 3,
  'trial_id': 'e6f25_00000',
  'date': '2025-02-17_00-36-19',
  'time_this_iter_s': 163.14938020706177,
  'time_total_s': 491.72960352897644,
  'pid': 761,
  'hostname': 'gke-vertex-persistent-02-worker-pool1-b55b6127-dl0f',
  'node_ip': '10.127.0.20',
  'config': {},
  'time_since_restore': 491.72960352897644,
  'iterations_since_restore': 3,
  'experiment_tag': '0'},
 Checkpoint(filesystem=gcs, path=sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58/TorchTrainer_e6f25_00000_0_2025-02-17_00-27-00/checkpoint_000002),
 'sllm_checkpoints/tmp_store/pytorch/TorchTrainer_2025-02-17_00-26-58/TorchTrainer_e6f25_00000_0_2025-02-17_00-27-00',
 None)

### Load model

In [46]:
# [6] Load the trained model.

with result.checkpoint.as_directory() as checkpoint_dir:
    model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.load_state_dict(model_state_dict)

In [49]:
ray.shutdown()