[Tune] Multi GPU, Multi Node hyperparameter search not functioning #38505

Closed
f2010126 opened this issue Aug 16, 2023 · 7 comments
Labels
bug Something that is supposed to be working; but isn't · train Ray Train Related Issue · tune Tune-related issues

@f2010126

What happened + What you expected to happen

I am unable to run Ray Tune with multiple GPUs on multiple nodes. I am using a SLURM cluster.
Steps:

  • Start the cluster and connect at least 1 master and 1 worker node, each with at least 2 GPUs and 2 CPUs allotted.
  • Check that the nodes are available with ray status.
  • Run the script. Resources used per trial: 2 GPUs, 2 CPUs.

Expected behavior

Ray Tune should use the available resources and perform the hyperparameter tuning. Instead, it is stuck at PENDING.
A similar issue is referenced here: #24259
I also wish to use HpBandSter, but even with the default ASHA example it isn't working.

Observed behavior

The output remains at:

Number of trials: 10/10 (9 PENDING, 1 RUNNING)

ray status returns
Demands: {'CPU': 1.0} * 1, {'CPU': 2.0, 'GPU': 2.0} * 2 (PACK): 9+ pending placement groups

Would using the ray_lightning plugin be a better solution?
Note: when I have 1 GPU each on separate nodes, or 8 GPUs on a single node, things work just fine.
It's n GPUs across m nodes that is the issue.

Please advise.

Versions / Dependencies

ray == 2.5.0
pytorch-lightning == 2.0.7
NCCL version 2.14.3+cuda11.7
Python 3.10.6

Reproduction script

import argparse
import torch
import time
from pytorch_lightning.loggers import TensorBoardLogger

from torch.nn import functional as F
from torchmetrics import Accuracy
import pytorch_lightning as pl
import os
from typing import List
import traceback
from torch.utils.data import random_split, DataLoader
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from torchvision import datasets, transforms
from ray import air, tune
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
import ray
from ray.tune import CLIReporter
from ray.train.lightning import LightningConfigBuilder, LightningTrainer



class DataModuleMNIST(pl.LightningDataModule):

    def __init__(self):
        super().__init__()
        self.download_dir = './data'
        self.batch_size = 32
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])
        self.dataworkers = int(os.cpu_count() / 2)
        self.prepare_data_per_node = True

    def prepare_data(self):
        datasets.MNIST(self.download_dir,
                       train=True, download=True)

        datasets.MNIST(self.download_dir, train=False,
                       download=True)

    def setup(self, stage=None):
        data = datasets.MNIST(self.download_dir,
                              train=True, transform=self.transform, download=True)

        self.train_data, self.valid_data = random_split(data, [55000, 5000])

        self.test_data = datasets.MNIST(self.download_dir,
                                        train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size, num_workers=self.dataworkers)

    def val_dataloader(self):
        return DataLoader(self.valid_data, batch_size=self.batch_size, num_workers=self.dataworkers)

    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.dataworkers)


class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self, config):
        super(LightningMNISTClassifier, self).__init__()

        self.lr = config["lr"]
        layer_1, layer_2 = config["layer_1"], config["layer_2"]
        self.batch_size = config["batch_size"]

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
        self.layer_2 = torch.nn.Linear(layer_1, layer_2)
        self.layer_3 = torch.nn.Linear(layer_2, 10)
        self.accuracy = Accuracy(task="multiclass", num_classes=10)
        self.val_output_list = []

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def on_fit_start(self):
        if torch.cuda.is_available():
            print(f"No of GPUs available: {torch.cuda.device_count()}")
        else:
            print("No GPU available")

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        self.log("ptl/train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True,sync_dist=True)
        self.log("ptl/train_accuracy", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True,sync_dist=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        result = {"val_loss": loss, "val_accuracy": acc}
        self.val_output_list.append(result)
        return {"val_loss": loss, "val_accuracy": acc}

    def on_validation_epoch_end(self):
        outputs = self.val_output_list
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        # Clear the stored outputs so metrics are not accumulated across epochs.
        self.val_output_list.clear()

        self.log("ptl/val_loss", avg_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("ptl/val_accuracy", avg_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return {"loss": avg_loss, "acc": avg_acc}

def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0, exp_name="tune_mnist"):

    # Static configs that do not change across trials
    dm = DataModuleMNIST()
    # doesn't call prepare data
    logger = TensorBoardLogger(save_dir=os.getcwd(), name="tune-ptl-example", version=".")
    if torch.cuda.is_available():
        n_devices = torch.cuda.device_count()
        accelerator = 'gpu'
        use_gpu = True
        gpus_per_trial = 2
    else:
        n_devices = 0
        accelerator = 'cpu'
        use_gpu = False
    print(f" No of GPUs available : {torch.cuda.device_count()} and accelerator is {accelerator}")

    static_lightning_config = (
        LightningConfigBuilder()
        .module(cls=LightningMNISTClassifier)
        .trainer(max_epochs=num_epochs, accelerator=accelerator, logger=logger,)
        .fit_params(datamodule=dm)
        # .strategy(name='ddp')
        .checkpointing(monitor="ptl/val_accuracy", save_top_k=2, mode="max")
        .build()
    )

    # Searchable configs across different trials
    searchable_lightning_config = (
        LightningConfigBuilder()
        .module(config={
            "layer_1": tune.choice([32, 64, 128]),
            "layer_2": tune.choice([64, 128, 256]),
            "lr": tune.loguniform(1e-4, 1e-1),
            "batch_size": tune.choice([32, 64, 128]),
        })
        .build()
    )

    # Make sure to also define an AIR CheckpointConfig here
    # to properly save checkpoints in AIR format.
    run_config = RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=2,
            checkpoint_score_attribute="ptl/val_accuracy",
            checkpoint_score_order="max",
        ),
    )

    scheduler = ASHAScheduler(max_t=num_epochs, # max no of epochs a trial can run
                              grace_period=1, reduction_factor=2,
                              time_attr = "training_iteration")

    scaling_config = ScalingConfig(
        # no of other nodes?
        num_workers=2, use_gpu=use_gpu, resources_per_worker={"CPU": 2, "GPU": gpus_per_trial}
    )

    lightning_trainer = LightningTrainer(
        lightning_config=static_lightning_config,
        scaling_config=scaling_config,
    )

    tuner = tune.Tuner(
        lightning_trainer,
        param_space={"lightning_config": searchable_lightning_config},
        tune_config=tune.TuneConfig( # for Tuner
            time_budget_s=3000,
            metric="ptl/val_accuracy",
            mode="max",
            num_samples=num_samples, # Number of times to sample from the hyperparameter space
            scheduler=scheduler,
        ),
        run_config=air.RunConfig( # for Tuner.run
            name=exp_name,
            verbose=2,
            storage_path="./ray_results",
            log_to_file=True,
            # configs given to Tuner are used.
            # progress_reporter=reporter,
            checkpoint_config=CheckpointConfig(
                num_to_keep=2,
                checkpoint_score_attribute="ptl/val_accuracy",
                checkpoint_score_order="max",
            ),
        ),

    )

    try:
        start = time.time()
        results = tuner.fit()
        end = time.time()
        hours, rem = divmod(end - start, 3600)
        minutes, seconds = divmod(rem, 60)
        print("{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds))
        best_result = results.get_best_result(metric="ptl/val_accuracy", mode="max")
        print("Best hyperparameters found were: ", best_result)
    except ray.exceptions.RayTaskError:
        print("User function raised an exception!")
    except Exception as e:
        print("Other error", e)
        print(traceback.format_exc())

def parse_args():
    parser = argparse.ArgumentParser(description="Tune on MultiNode")
    parser.add_argument(
        "--cuda",
        action="store_true",
        default=False,
        help="Enables GPU training")
    parser.add_argument(
        "--smoke-test", action="store_false", help="Finish quickly for testing")
    parser.add_argument(
        "--ray-address",
        help="Address of Ray cluster for seamless distributed execution.")
    parser.add_argument(
        "--server-address",
        type=str,
        default=None,
        required=False,
        help="The address of server to connect to if using "
             "Ray Client.")
    parser.add_argument("--exp-name", type=str, default="tune_mnist")
    args, _ = parser.parse_known_args()
    return args


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Tune on local")
    parser.add_argument(
        "--smoke-test", action="store_false", help="Finish quickly for testing")  # store_false will default to True
    parser.add_argument("--exp-name", type=str, default="tune_mnist")
    args, _ = parser.parse_known_args()

    # Start training
    if args.smoke_test:
        print("Smoketesting...")
        tune_mnist(num_samples=10, num_epochs=3, gpus_per_trial=0, exp_name=args.exp_name)
    else:
        tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=8, exp_name=args.exp_name)

Issue Severity

High: It blocks me from completing my task.

@woshiyyya
Member

woshiyyya commented Aug 24, 2023

Hi @f2010126, I think it's a problem with the scaling config. Here you specified:

scaling_config = ScalingConfig(
        num_workers=2, use_gpu=use_gpu, resources_per_worker={"CPU": 2, "GPU": gpus_per_trial}
    )

This ScalingConfig applies per trial. It will try to allocate num_workers * gpus_per_trial GPUs for one trial.

So the correct configuration would be:

scaling_config = ScalingConfig(
        num_workers=gpus_per_trial, use_gpu=use_gpu, resources_per_worker={"CPU": 2, "GPU":1}
    )

In this case, you will launch num_samples trials; each trial has gpus_per_trial workers, and each worker has 1 GPU.

@f2010126
Author

Hi @woshiyyya, I don't think that's it. The documentation describes num_workers as the number of workers (Ray actors) to launch.
I want 2 actors. With the ScalingConfig I gave, I expect Ray Tune to run across 2 nodes (each being one worker), with each worker/node having access to 2 GPUs.

scaling_config = ScalingConfig(
        num_workers=gpus_per_trial, use_gpu=use_gpu, resources_per_worker={"CPU": 2, "GPU":1}
    )

When I use the scaling config you suggest and check ray status:

Usage:
 5.0/128.0 CPU (5.0 used of 10.0 reserved in placement groups)
 2.0/4.0 GPU (2.0 used of 4.0 reserved in placement groups)
 0B/691.96GiB memory
 0B/300.55GiB object_store_memory

Demands:
 (no resource demands)

It's only using 2 GPUs, leaving out the other 2. When I use the following to try to use all 4:

scaling_config = ScalingConfig(
        # no of other nodes?
        num_workers=gpus_per_trial, use_gpu=use_gpu, resources_per_worker={"CPU": 2, "GPU": 2}
    )

My run hangs:

Trial status: 1 RUNNING | 1 PENDING
Current time: 2023-08-25 09:33:48. Total running time: 4min 34s
Logical resource usage: 5.0/128 CPUs, 4.0/4 GPUs (0.0/2.0 accelerator_type:G)

Ray Status Output:

Usage:
 5.0/128.0 CPU (5.0 used of 5.0 reserved in placement groups)
 4.0/4.0 GPU (4.0 used of 4.0 reserved in placement groups)
 0B/691.96GiB memory
 0B/300.55GiB object_store_memory

Demands:
 {'CPU': 1.0} * 1, {'CPU': 2.0, 'GPU': 2.0} * 2 (PACK): 1+ pending placement groups

Do Ray Tune and Ray Train in general not work for the multi-node, multi-GPU use case?

@krfricke
Contributor

krfricke commented Aug 30, 2023

Hi @f2010126,

I think there may be a misunderstanding of terms here. Let me clarify.

LightningTrainer is used for distributed training. The ScalingConfig configures each individual LightningTrainer. Thus, num_workers is the number of distributed workers each LightningTrainer uses.

A trial is a Ray Tune trial. If you run 10 trials, you'll run 10 LightningTrainers, each occupying the resources specified in the ScalingConfig.

If you use

scaling_config = ScalingConfig(
        num_workers=2, use_gpu=use_gpu, resources_per_worker={"CPU": 2, "GPU": 2}
    )

this means that each trial will start 2 workers, and each worker will occupy 2 CPUs and 2 GPUs. Thus, if your cluster has 4 GPUs, exactly one trial can run at a time.

If you use e.g.

scaling_config = ScalingConfig(
        num_workers=2, use_gpu=use_gpu, resources_per_worker={"CPU": 1, "GPU": 1}
    )

this would mean that each trial still starts 2 workers, but each worker only occupies 1 CPU and 1 GPU. In the same cluster, this means that 2 trials can run at the same time.

Note that the LightningTrainer itself also occupies 1 CPU. So if your nodes have exactly 2 CPUs and 2 GPUs, you should consider passing trainer_resources={"CPU": 0} to the scaling config.
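
For illustration, a rough, untested sketch of that setup, following the same pattern as the snippets above (it assumes the same use_gpu variable as in the reproduction script):

scaling_config = ScalingConfig(
    num_workers=2,
    use_gpu=use_gpu,
    # Don't reserve an extra CPU for the trainer actor itself, so both
    # workers fit on a node with exactly 2 CPUs and 2 GPUs.
    trainer_resources={"CPU": 0},
    resources_per_worker={"CPU": 1, "GPU": 1},
)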

Multi node + multi GPU training is one of the core use cases for Ray Train and Ray Tune :-) It's mostly a matter of configuration.

@f2010126
Author

Hi @krfricke, thank you for the detailed explanation. I also went through this tutorial, and it is as you said: I was mixing up the concepts of actors, workers, and nodes.

I have a few doubts:

  1. If I use the config:

scaling_config = ScalingConfig(
        num_workers=1, use_gpu=use_gpu, resources_per_worker={"CPU": 2, "GPU": 8}
    )

Does this mean that only one LightningTrainer will run, with access to 8 GPUs? Would it still train in a distributed fashion?

  2. You said that with the config
scaling_config = ScalingConfig(
        num_workers=2, use_gpu=use_gpu, resources_per_worker={"CPU": 1, "GPU": 1}
    )

Each trial has 2 workers, each using 1 GPU and 1 CPU. Given 4 GPUs on the SLURM cluster, 2 trials would run.

What config can I use so that 2 trials run in parallel, each with 2 GPUs per worker? I couldn't find a flag to set the number of concurrent trials (do I add a ConcurrencyLimiter?).

  3. How does the Ray cluster know if I want to allocate a job to an entire SLURM node? Does it matter to the setup whether a trial runs on GPUs of the same node, or whether its GPUs are split across nodes? This part is a bit of a black box for me right now. I tried reading the documentation on Ray actors, but are they created per node, or are they just processes that are allocated the given number of GPUs/CPUs?

Thank you!

@f2010126
Author

f2010126 commented Sep 5, 2023

@krfricke, @woshiyyya I tried your suggestions as well as what is suggested in the tutorial I linked above.
The tutorial did not work out as I wanted, but your advice worked! Thank you :)

A further doubt, though (should I raise this as a separate issue?):
In the example here, choosing the batch size is mentioned, but it is not varied in the declared config space:

config={
        "layer_1_size": tune.choice([32, 64, 128]),
        "layer_2_size": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
    }

Is it possible to change the batch size of the DataModule, perhaps by adding it to searchable_lightning_config as:

searchable_lightning_config = (
    LightningConfigBuilder()
    .module(config={
        "layer_1_size": tune.choice([32, 64, 128]),
        "layer_2_size": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
    }).fit_params(datamodule=dm, config= {"batch_size": tune.choice([32, 64, 128])} )
    .build()
)

fit_params feeds into Lightning's Trainer.fit(), so I do not think it would work unless I use the Trainer's fit-start hook to re-declare the DataModule with the given batch size. Is there a better solution?

Or should I use the vanilla PyTorch Lightning with Tune example instead if I want to add the batch size as a hyperparameter?

Thanks again!

@woshiyyya
Member

I think for tuning the batch size, you can define train_dataloader and test_dataloader in your LightningModule, and pass batch_size as an initialization argument in LightningConfigBuilder().module(...).
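
A rough, untested sketch of that approach, reusing the classes and imports from the reproduction script above (the subclass name TunableMNISTClassifier and the dataloader bodies are only illustrative):

class TunableMNISTClassifier(LightningMNISTClassifier):
    # The parent class already reads config["batch_size"] in __init__,
    # so the tuned value is available as self.batch_size here.

    def train_dataloader(self):
        # Build the dataloader inside the LightningModule so it can use the
        # tuned batch size instead of the DataModule's fixed one.
        dataset = datasets.MNIST("./data", train=True, download=True,
                                 transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        dataset = datasets.MNIST("./data", train=False, download=True,
                                 transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=self.batch_size)

# Use cls=TunableMNISTClassifier in the static config (and drop
# fit_params(datamodule=dm), since the module now owns its dataloaders),
# then add batch_size to the searchable module config:
searchable_lightning_config = (
    LightningConfigBuilder()
    .module(config={
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    })
    .build()
)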

By the way, in Ray 2.7 we are deprecating LightningTrainer and instead support running Lightning code with TorchTrainer (using your own training function) to provide more flexibility. In that case, you can pass any parameter, including batch_size, via TorchTrainer(train_loop_config=...) and define a search space over it. Please see https://docs.ray.io/en/master/train/getting-started-pytorch-lightning.html for more details.
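
A rough sketch of how batch_size could flow through train_loop_config with that approach (untested; it reuses the classes from the reproduction script, and it omits the Ray-specific Lightning strategy/report-callback pieces covered in the linked guide, which are needed for Tune to actually receive ptl/val_accuracy):

from ray.train.torch import TorchTrainer

def train_func(config):
    # config is populated from train_loop_config / the Tune search space.
    model = LightningMNISTClassifier(config)
    dm = DataModuleMNIST()
    dm.batch_size = config["batch_size"]
    trainer = pl.Trainer(max_epochs=3, accelerator="auto", devices="auto")
    trainer.fit(model, datamodule=dm)

torch_trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True,
                                 resources_per_worker={"CPU": 2, "GPU": 2}),
)

tuner = tune.Tuner(
    torch_trainer,
    # Any entry of train_loop_config can be a search space, batch_size included.
    param_space={"train_loop_config": {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }},
    tune_config=tune.TuneConfig(metric="ptl/val_accuracy", mode="max", num_samples=4),
)
results = tuner.fit()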

@f2010126
Author

@woshiyyya, yes, I migrated to TorchTrainer; it's much more customisable. Thanks!
