TorchServe fails to start multiple worker threads on multiple GPUs with large model. #71
Comments
We only do a round-robin distribution of workers across GPUs. Looks like it needs testing with large models.
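For reference, a minimal sketch of what that round-robin assignment amounts to (illustrative Python, not the actual frontend code; the function name is made up):

```python
def assign_gpu(worker_index: int, num_gpus: int):
    """Round-robin GPU assignment: worker i gets GPU (i mod num_gpus)."""
    if num_gpus <= 0:
        return None  # CPU-only host
    return worker_index % num_gpus

# With 4 GPUs, workers 0..5 would map to GPUs 0, 1, 2, 3, 0, 1.
```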
It would be helpful for testing and diagnosis if we also got logging when a worker successfully attaches to a GPU. Right now I can only see the intended device target when there's a failure, so there's no way to verify correct behavior from the logs.
@fbbradheintz Can you share the model archive you tested this with?
The files are very large. I'll be a while uploading.
@mycpuorg - Your email inbox should contain a Dropbox link to a folder containing:
Please let me know if there's anything else you need in relation to this.
We are currently analyzing this issue with the shared .mar file.
We need to specify the GPU id while creating the torch.device object so that every worker uses a different GPU. The following small change in your custom model handler resolves the issue:
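A minimal sketch of that change, assuming a handler whose initialize(context) reads the worker's gpu_id from context.system_properties (the class name and model-loading line are illustrative, not the exact code attached to this comment):

```python
import torch


class FairseqTranslationHandler:
    """Custom TorchServe handler sketch: pin each worker to its assigned GPU."""

    def initialize(self, context):
        properties = context.system_properties
        # Use the worker-specific gpu_id rather than a bare torch.device("cuda"),
        # which would place every worker on GPU 0.
        if torch.cuda.is_available():
            self.device = torch.device("cuda:" + str(properties.get("gpu_id")))
        else:
            self.device = torch.device("cpu")
        # Model loading would then target that device, e.g.
        # self.model = load_fairseq_model(...).to(self.device)
```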
After the above changes I was able to register the fairseq model with TorchServe.
Could you please try the above changes and confirm? Note: We have updated the MNIST example's custom handler to use a similar change in the "stage_release" branch as well.
@harshbafna Is this the Java properties or context.system_properties? It would help to update one of the examples to use GPUs with the proposed solution.
@chauhang: It is context.system_properties, as indicated in the previous comment. We have already updated the custom handler example (MNIST digit classifier) in the "stage_release" branch: https://github.com/pytorch/serve/blob/stage_release/examples/image_classifier/mnist/mnist_handler.py All the default handlers have been updated to use the same as well.
I can confirm that this works with the changes prescribed by @harshbafna - thank you for the added guidance. The thing I don't see is documentation. If we don't tell people about this quirk in custom handlers, they're going to do the same thing I did and file a bug like I did. Please document this, preferably in examples/image_classifier/mnist/README.md, and then we can close this after the merge.
@fbbradheintz, could you please review the following commit: 7a029ef
Looks good! Please close after merging. |
Stage release - Merge: Periodic merge from the stage release branch. Listing the highlights of this merge below. When things settle down a bit I expect this will slow down and we can give a merge a new rc version.

Highlights (fixed):
- Benchmarks have dependency on Mxnet #72
- TorchServe fails to start multiple worker threads on multiple GPUs with large model #71
- Java concurrency crash when attempting batch processing #66
- Add handler for audio models with an example #60
- Undocumented options for config.properties #55
- Can't log custom metrics #53
- Add example for Custom Service #49
- WIP: benchmark dependencies install script failing on fresh ubuntu 18.04 #36
- Unable to install on GPU machine #30
- Incorrect docs for --model-store option #23
Hi @fbbradheintz, I am interested in deploying my Fairseq models with TorchServe. Would it be possible to share some examples or pointers from what you used above? I noticed that in the MNIST example a custom model file is given, but it is only a single file. What is the correct way to do this when we have multiple files for the model definition, as in Fairseq?
@okgrammer, you can supply all the other dependency files using the --extra-files parameter as a comma-separated list while creating the model archive. You can refer to the image_segmenter example, where it supplies multiple files this way.
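As an illustration, a torch-model-archiver invocation of that shape might look like the following (the model name and file names are placeholders, not files from this thread):

```bash
torch-model-archiver \
  --model-name fairseq_en_de \
  --version 1.0 \
  --serialized-file checkpoint.pt \
  --handler fairseq_handler.py \
  --extra-files "model_def.py,dict.en.txt,dict.de.txt,bpecodes" \
  --export-path model_store
```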
On a c5.12xlarge instance, I was able to run 16 instances of the FairSeq English-to-German translation model, all simultaneously running translations. This model's weights take up about 2.5GB on disk (though its resident footprint in memory seems smaller).
Attempting a similar feat on a p3.8xlarge turned out to be impossible. I could get a single instance running, but if I attempted to get even 4 workers running, they crashed repeatedly with out-of-memory errors (OOMEs):
On digging through the logs, it appears that it's attempting to start all workers on the same GPU. The following is the output of grep GPU ts_log.log:

Note that multiple workers (W-9000, W-9001, W-9003) are shown, but only one GPU turns up (GPU 0). The p3.8xlarge has 4 GPUs.
I attempted to use arguments of the Management API, such as number_gpu=4, to fix this, but nothing worked. Same result every time.