Skip to content

Commit

Permalink
add custom device support for RetinaFace class in detection (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
caock committed Apr 15, 2023
1 parent 29d792e commit c2c767d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions facexlib/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

def init_detection_model(model_name, half=False, device='cuda', model_rootpath=None):
if model_name == 'retinaface_resnet50':
model = RetinaFace(network_name='resnet50', half=half)
model = RetinaFace(network_name='resnet50', half=half, device=device)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth'
elif model_name == 'retinaface_mobile0.25':
model = RetinaFace(network_name='mobile0.25', half=half)
model = RetinaFace(network_name='mobile0.25', half=half, device=device)
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')
Expand Down
22 changes: 11 additions & 11 deletions facexlib/detection/retinaface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from facexlib.detection.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
py_cpu_nms)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def generate_config(network_name):

Expand Down Expand Up @@ -72,7 +70,9 @@ def generate_config(network_name):

class RetinaFace(nn.Module):

def __init__(self, network_name='resnet50', half=False, phase='test'):
def __init__(self, network_name='resnet50', half=False, phase='test', device=None):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

super(RetinaFace, self).__init__()
self.half_inference = half
cfg = generate_config(network_name)
Expand All @@ -83,7 +83,7 @@ def __init__(self, network_name='resnet50', half=False, phase='test'):
self.phase = phase
self.target_size, self.max_size = 1600, 2150
self.resize, self.scale, self.scale1 = 1., None, None
self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]]).to(device)
self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]], device=self.device)
self.reference = get_reference_facial_points(default_square=True)
# Build network.
backbone = None
Expand Down Expand Up @@ -112,7 +112,7 @@ def __init__(self, network_name='resnet50', half=False, phase='test'):
self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])

self.to(device)
self.to(self.device)
self.eval()
if self.half_inference:
self.half()
Expand Down Expand Up @@ -145,19 +145,19 @@ def forward(self, inputs):
def __detect_faces(self, inputs):
# get scale
height, width = inputs.shape[2:]
self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device)
self.scale = torch.tensor([width, height, width, height], dtype=torch.float32, device=self.device)
tmp = [width, height, width, height, width, height, width, height, width, height]
self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device)
self.scale1 = torch.tensor(tmp, dtype=torch.float32, device=self.device)

# forawrd
inputs = inputs.to(device)
inputs = inputs.to(self.device)
if self.half_inference:
inputs = inputs.half()
loc, conf, landmarks = self(inputs)

# get priorbox
priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
priors = priorbox.forward().to(device)
priors = priorbox.forward().to(self.device)

return loc, conf, landmarks, priors

Expand Down Expand Up @@ -197,7 +197,7 @@ def detect_faces(
use_origin_size=True,
):
image, self.resize = self.transform(image, use_origin_size)
image = image.to(device)
image = image.to(self.device)
if self.half_inference:
image = image.half()
image = image - self.mean_tensor
Expand Down Expand Up @@ -316,7 +316,7 @@ def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, us
"""
# self.t['forward_pass'].tic()
frames, self.resize = self.batched_transform(frames, use_origin_size)
frames = frames.to(device)
frames = frames.to(self.device)
frames = frames - self.mean_tensor

b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)
Expand Down

0 comments on commit c2c767d

Please sign in to comment.