diff --git a/facexlib/alignment/__init__.py b/facexlib/alignment/__init__.py index e989ca7..90b5867 100644 --- a/facexlib/alignment/__init__.py +++ b/facexlib/alignment/__init__.py @@ -16,7 +16,7 @@ def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=N model_path = load_file_from_url( url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) - model.load_state_dict(torch.load(model_path)['state_dict'], strict=True) + model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True) model.eval() model = model.to(device) return model