
Commit

allow parallelLevel of 1 to start torchrun with nproc-per-node (#2608)
lxning committed Oct 9, 2023
1 parent f57240f commit 504c734
Showing 7 changed files with 33 additions and 23 deletions.
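The change repeated through every hunk below is a shift in the meaning of parallelLevel: it now defaults to 0 ("torchrun not configured") instead of 1, torchrun is launched whenever the value is greater than 0 (so nproc-per-node=1 is now allowed), and every place that sizes ports, latches, reply queues, or channel loops clamps 0 back up to 1. A minimal sketch of that convention follows; the class and method names (ParallelLevelSketch, useTorchrun, effectiveParallelLevel) are illustrative only, not part of the TorchServe codebase.

// Minimal sketch of the convention adopted in this commit; names are
// illustrative, not TorchServe API.
public final class ParallelLevelSketch {

    /** 0 = torchrun not configured; > 0 = torchrun nproc-per-node value. */
    private int parallelLevel; // defaults to 0, matching the new ModelConfig/Model fields

    public void setParallelLevel(int parallelLevel) {
        if (parallelLevel < 0) {
            // mirrors the new validation: negative values are ignored, the field stays 0
            System.err.printf("Invalid parallelLevel:%d, set as 0%n", parallelLevel);
            return;
        }
        this.parallelLevel = parallelLevel;
    }

    /** True when the worker should be launched through torchrun (was parallelLevel > 1). */
    public boolean useTorchrun() {
        return parallelLevel > 0;
    }

    /** Process count used to size latches, reply queues, ports, and channel loops. */
    public int effectiveParallelLevel() {
        return parallelLevel > 0 ? parallelLevel : 1; // the clamp repeated across the diff
    }

    public static void main(String[] args) {
        ParallelLevelSketch cfg = new ParallelLevelSketch();
        cfg.setParallelLevel(1); // now valid: torchrun with nproc-per-node=1
        System.out.println(cfg.useTorchrun() + " " + cfg.effectiveParallelLevel()); // true 1
    }
}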
@@ -32,7 +32,7 @@ public class ModelConfig {
      */
     private List<Integer> deviceIds;
     /** this variable is auto calculated based on torchrun nproc-per-node. */
-    private int parallelLevel = 1;
+    private int parallelLevel;
     /** the model parallel type can be tp, pp, pptp */
     private ParallelType parallelType = ParallelType.NONE;
     /** torchrun config */
@@ -259,9 +259,8 @@ public int getParallelLevel() {
     }

     public void setParallelLevel(int parallelLevel) {
-        if (parallelLevel <= 0) {
-            logger.warn("Invalid parallelLevel:{}, set as 1", parallelLevel);
-            this.parallelLevel = 1;
+        if (parallelLevel < 0) {
+            logger.warn("Invalid parallelLevel:{}, set as 0", parallelLevel);
             return;
         }
         this.parallelLevel = parallelLevel;
@@ -43,7 +43,7 @@ public void testInvalidYamlConfig() throws InvalidModelException, IOException {
         Assert.assertEquals(modelConfig.getMaxBatchDelay(), 100);
         Assert.assertEquals(modelConfig.getResponseTimeout(), 120);
         Assert.assertNotEquals(modelConfig.getDeviceType(), ModelConfig.DeviceType.GPU);
-        Assert.assertEquals(modelConfig.getParallelLevel(), 1);
+        Assert.assertEquals(modelConfig.getParallelLevel(), 0);
         Assert.assertNotEquals(modelConfig.getParallelType(), ModelConfig.ParallelType.PPTP);
         Assert.assertNull(modelConfig.getDeviceIds());
     }
@@ -40,7 +40,7 @@ public class Model {
     private int maxWorkers;
     private int batchSize;
     private int maxBatchDelay;
-    private int parallelLevel = 1;
+    private int parallelLevel;
     private long maxRetryTimeoutInMill = 5 * 60 * 1000;
     private long clientTimeoutInMills;
     private ModelConfig.ParallelType parallelType = ModelConfig.ParallelType.NONE;
@@ -71,7 +71,7 @@ public Model(ModelArchive modelArchive, int queueSize) {
         this.modelArchive = modelArchive;
         if (modelArchive != null && modelArchive.getModelConfig() != null) {
             continuousBatching = modelArchive.getModelConfig().isContinuousBatching();
-            if (modelArchive.getModelConfig().getParallelLevel() > 1
+            if (modelArchive.getModelConfig().getParallelLevel() > 0
                     && modelArchive.getModelConfig().getParallelType()
                             != ModelConfig.ParallelType.NONE) {
                 parallelLevel = modelArchive.getModelConfig().getParallelLevel();
@@ -138,7 +138,7 @@ public JsonObject getModelState(boolean isDefaultVersion) {
         modelInfo.addProperty(BATCH_SIZE, getBatchSize());
         modelInfo.addProperty(MAX_BATCH_DELAY, getMaxBatchDelay());
         modelInfo.addProperty(RESPONSE_TIMEOUT, getResponseTimeout());
-        if (parallelLevel > 1) {
+        if (parallelLevel > 0) {
             modelInfo.addProperty(PARALLEL_LEVEL, parallelLevel);
         }

@@ -461,7 +461,7 @@ public CompletableFuture<Integer> updateModel(
             throw new ModelVersionNotFoundException(
                     "Model version: " + versionId + " does not exist for model: " + modelName);
         }
-        if (model.getParallelLevel() > 1 && model.getDeviceType() == ModelConfig.DeviceType.GPU) {
+        if (model.getParallelLevel() > 0 && model.getDeviceType() == ModelConfig.DeviceType.GPU) {
            /**
             * Current capacity check for LMI is based on single node. TODO: multiple nodes check
             * will be based on --proc-per-node + numCores.
@@ -211,14 +211,17 @@ private void addThreads(
         int gpuId = -1;

         if (maxGpu > 0) {
-            if (model.isHasCfgDeviceIds() || model.getParallelLevel() > 1) {
+            if (model.isHasCfgDeviceIds() || model.getParallelLevel() > 0) {
                 gpuId =
                         model.getGpuCounter()
                                 .getAndAccumulate(
                                         maxGpu,
                                         (prev, maxGpuId) ->
-                                                (prev + model.getParallelLevel()) % maxGpuId);
-                if (model.getParallelLevel() == 1) {
+                                                (prev + model.getParallelLevel() > 0
+                                                                ? model.getParallelLevel()
+                                                                : 1)
+                                                        % maxGpuId);
+                if (model.getParallelLevel() == 0) {
                     gpuId = model.getDeviceIds().get(gpuId);
                 }
             } else {
@@ -235,7 +238,7 @@ private void addThreads(
             aggregator = new BatchAggregator(model);
         }
         int currentPort =
-                model.getParallelLevel() > 1
+                model.getParallelLevel() > 0
                         ? configManager.isDebug()
                                 ? distributionPort.get()
                                 : distributionPort.getAndAdd(model.getParallelLevel())
@@ -115,9 +115,9 @@ public void startWorker(int port, String deviceIds)
                         modelPath.getAbsolutePath(),
                         model.getModelArchive().getManifest().getModel().getHandler())));

-        if (model.getParallelLevel() > 1) {
+        if (model.getParallelLevel() > 0) {
             attachRunner(argl, envp, port, deviceIds);
-        } else if (model.getParallelLevel() == 1) {
+        } else if (model.getParallelLevel() == 0) {
             argl.add(EnvironmentUtils.getPythonRunTime(model));
         }

@@ -153,7 +153,7 @@ public void startWorker(int port, String deviceIds)
         argl.add(configManager.getMetricsConfigPath());

         try {
-            latch = new CountDownLatch(model.getParallelLevel());
+            latch = new CountDownLatch(model.getParallelLevel() > 0 ? model.getParallelLevel() : 1);

             String[] args = argl.toArray(new String[argl.size()]);
             String[] envs = envp.toArray(new String[envp.size()]);
@@ -99,7 +99,9 @@ public WorkerThread(
         this.listener = listener;
         startTime = System.currentTimeMillis();
         lifeCycle = new WorkerLifeCycle(configManager, model);
-        replies = new ArrayBlockingQueue<>(model.getParallelLevel());
+        replies =
+                new ArrayBlockingQueue<>(
+                        model.getParallelLevel() > 0 ? model.getParallelLevel() : 1);
         this.workerThreadTimeMetric =
                 MetricCache.getInstance().getMetricFrontend("WorkerThreadTime");
         this.workerLoadTimeMetric = MetricCache.getInstance().getMetricFrontend("WorkerLoadTime");
@@ -198,10 +200,10 @@ public void run() {
                                 || ((req.getCommand() == WorkerCommands.PREDICT
                                                 || req.getCommand()
                                                         == WorkerCommands.STREAMPREDICT)
-                                        && model.getParallelLevel() > 1
+                                        && model.getParallelLevel() > 0
                                         && model.getParallelType()
                                                 != ModelConfig.ParallelType.PP)
-                        ? model.getParallelLevel()
+                        ? model.getParallelLevel() > 0 ? model.getParallelLevel() : 1
                         : 1;
             for (int i = 0; backendChannel.size() > 0 && i < repeats; i++) {
                 backendChannel.get(i).writeAndFlush(req).sync();
@@ -305,7 +307,10 @@ public void run() {
             // WorkerThread is running in thread pool, the thread will be assigned to next
             // Runnable once this worker is finished. If currentThread keep holding the reference
             // of the thread, currentThread.interrupt() might kill next worker.
-            for (int i = 0; backendChannel.size() > 0 && i < model.getParallelLevel(); i++) {
+            for (int i = 0;
+                    backendChannel.size() > 0
+                            && i < (model.getParallelLevel() > 0 ? model.getParallelLevel() : 1);
+                    i++) {
                 backendChannel.get(i).disconnect();
             }
             currentThread.set(null);
@@ -346,7 +351,7 @@ private void connect() throws WorkerInitializationException, InterruptedExceptio
         String modelName = model.getModelName();
         String modelVersion = model.getVersion();
         setState(WorkerState.WORKER_STARTED, HttpURLConnection.HTTP_OK);
-        final int parallelLevel = model.getParallelLevel();
+        final int parallelLevel = model.getParallelLevel() > 0 ? model.getParallelLevel() : 1;
         final CountDownLatch latch = new CountDownLatch(parallelLevel);
         final int responseBufferSize = configManager.getMaxResponseSize();
         try {
@@ -449,7 +454,10 @@ public int getPid() {
     public void shutdown() {
         running.set(false);
         setState(WorkerState.WORKER_SCALED_DOWN, HttpURLConnection.HTTP_OK);
-        for (int i = 0; backendChannel.size() > 0 && i < model.getParallelLevel(); i++) {
+        for (int i = 0;
+                backendChannel.size() > 0
+                        && i < (model.getParallelLevel() > 0 ? model.getParallelLevel() : 1);
+                i++) {
             if (backendChannel.get(i) != null) {
                 backendChannel.get(i).close();
             }
@@ -522,7 +530,7 @@ public void retry() {

     private String getDeviceIds() {
         List<Integer> deviceIds;
-        if (gpuId == -1 || model.getParallelLevel() == 1) {
+        if (gpuId == -1 || model.getParallelLevel() == 0) {
             return null;
         } else if (model.isHasCfgDeviceIds()) {
             return model.getDeviceIds().subList(gpuId, gpuId + model.getParallelLevel()).stream()
