Skip to content

Commit

Permalink
Added GPU selection feature to python inference (#321)
Browse files Browse the repository at this point in the history
* Added GPU selection feature to python inference

* pylint pep8 fixes

* pep8 fixes
  • Loading branch information
cobanov committed May 24, 2022
1 parent bc77ca5 commit 6b15fc6
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
6 changes: 5 additions & 1 deletion inference_realesrgan.py
Expand Up @@ -39,6 +39,9 @@ def main():
type=str,
default='auto',
help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
parser.add_argument(
'-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu')

args = parser.parse_args()

# determine models according to model names
Expand Down Expand Up @@ -71,7 +74,8 @@ def main():
tile=args.tile,
tile_pad=args.tile_pad,
pre_pad=args.pre_pad,
half=not args.fp32)
half=not args.fp32,
gpu_id=args.gpu_id)

if args.face_enhance: # Use GFPGAN for face enhancement
from gfpgan import GFPGANer
Expand Down
17 changes: 15 additions & 2 deletions realesrgan/utils.py
Expand Up @@ -26,7 +26,16 @@ class RealESRGANer():
half (float): Whether to use half precision during inference. Default: False.
"""

def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False, device=None):
def __init__(self,
scale,
model_path,
model=None,
tile=0,
tile_pad=10,
pre_pad=10,
half=False,
device=None,
gpu_id=None):
self.scale = scale
self.tile_size = tile
self.tile_pad = tile_pad
Expand All @@ -35,7 +44,11 @@ def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=1
self.half = half

# initialize model
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
if gpu_id:
self.device = torch.device(
f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
else:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
if model_path.startswith('https://'):
model_path = load_file_from_url(
Expand Down

0 comments on commit 6b15fc6

Please sign in to comment.