Skip to content

Commit

Permalink
fix(api): run PyTorch GC on ROCm devices (#323)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 11, 2023
1 parent de61e38 commit 00be4f4
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 4 deletions.
1 change: 1 addition & 0 deletions api/onnx_web/chain/correct_gfpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def load_gfpgan(
arch="clean",
bg_upsampler=None,
channel_multiplier=2,
device=device.torch_str(),
model_path=face_path,
upscale=upscale.face_outscale,
)
Expand Down
9 changes: 8 additions & 1 deletion api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,16 @@ def sess_options(self, cache=True) -> SessionOptions:
return sess

def torch_str(self) -> str:
# TODO: return cuda devices for ROCm as well
if self.device.startswith("cuda"):
if self.options is not None and "device_id" in self.options:
return f"{self.device}:{self.options['device_id']}"

return self.device
elif self.device.startswith("rocm"):
if self.options is not None and "device_id" in self.options:
return f"cuda:{self.options['device_id']}"

return "cuda"
else:
return "cpu"

Expand Down
4 changes: 2 additions & 2 deletions api/onnx_web/server/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,13 @@ def load_platforms(server: ServerContext) -> None:
platform_providers[potential] in providers
and potential not in server.block_platforms
):
if potential == "cuda":
if potential == "cuda" or potential == "rocm":
for i in range(torch.cuda.device_count()):
options = {
"device_id": i,
}

if server.memory_limit is not None:
if potential == "cuda" and server.memory_limit is not None:
options["arena_extend_strategy"] = "kSameAsRequested"
options["gpu_mem_limit"] = server.memory_limit

Expand Down
2 changes: 1 addition & 1 deletion api/requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ diffusers
onnx
# onnxruntime has many platform-specific packages
safetensors
timm
transformers

#### Upscaling and face correction
Expand All @@ -17,7 +18,6 @@ codeformer-perceptor
facexlib
gfpgan
realesrgan
timm

### Server packages ###
boto3
Expand Down

0 comments on commit 00be4f4

Please sign in to comment.