From 2c34f33c7d09f014fa535029db5a8b8b78a0e1b4 Mon Sep 17 00:00:00 2001 From: Jake Hall <60800749+jakeh-gc@users.noreply.github.com> Date: Mon, 19 Jun 2023 12:00:51 +0100 Subject: [PATCH] torch.load to correct device. Use `map_location` in `torch.load` to respect user's choice of device. This fixes a failure when using torch-cpu to load a model with cuda state. --- facexlib/alignment/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/facexlib/alignment/__init__.py b/facexlib/alignment/__init__.py index e989ca7..90b5867 100644 --- a/facexlib/alignment/__init__.py +++ b/facexlib/alignment/__init__.py @@ -16,7 +16,7 @@ def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=N model_path = load_file_from_url( url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) - model.load_state_dict(torch.load(model_path)['state_dict'], strict=True) + model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True) model.eval() model = model.to(device) return model