Skip to content

Commit

Permalink
Change models and variables to float16
Browse files Browse the repository at this point in the history
  • Loading branch information
Ghassen-Chaabouni committed Jun 5, 2024
1 parent 981027f commit c654293
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 8 deletions.
4 changes: 3 additions & 1 deletion src/dot/gpen/face_model/face_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ def load_model(self, channel_multiplier=2, narrow=1, use_gpu=True):
self.model.load_state_dict(pretrained_dict)
self.model.eval()

self.model = self.model.half()

def process(self, img, use_gpu=True):
img = cv2.resize(img, (self.resolution, self.resolution))
img_t = self.img2tensor(img, use_gpu)

with torch.no_grad():
- out, __ = self.model(img_t)
+ out, __ = self.model(img_t.half())

out = self.tensor2img(out)

Expand Down
2 changes: 1 addition & 1 deletion src/dot/gpen/face_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(
def forward(self, input):
out = F.conv2d(
input,
- self.weight * self.scale,
+ (self.weight * self.scale).half(),
bias=self.bias,
stride=self.stride,
padding=self.padding,
Expand Down
2 changes: 1 addition & 1 deletion src/dot/gpen/retinaface/facemodels/retinaface.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(self, cfg=None, phase="train"):
import torchvision.models as models

backbone = models.resnet50(pretrained=cfg["pretrain"])

# backbone = backbone.half()
self.body = _utils.IntermediateLayerGetter(backbone, cfg["return_layers"])
in_channels_stage2 = cfg["in_channel"]
in_channels_list = [
Expand Down
6 changes: 5 additions & 1 deletion src/dot/gpen/retinaface/retinaface_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def __init__(self, base_dir, network="RetinaFace-R50", use_gpu=True):
self.load_model(load_to_cpu=True)
self.net = self.net.cpu()

self.net = self.net.half()

def check_keys(self, pretrained_state_dict):
ckpt_keys = set(pretrained_state_dict.keys())
model_keys = set(self.net.state_dict().keys())
Expand Down Expand Up @@ -71,6 +73,8 @@ def load_model(self, load_to_cpu=False):
self.net.load_state_dict(pretrained_dict, strict=False)
self.net.eval()

self.net = self.net.half()

def detect(
self,
img_raw,
Expand All @@ -96,7 +100,7 @@ def detect(
img = img.cpu()
scale = scale.cpu()

- loc, conf, landms = self.net(img) # forward pass
+ loc, conf, landms = self.net(img.half()) # forward pass

priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
priors = priorbox.forward()
Expand Down
2 changes: 2 additions & 0 deletions src/dot/simswap/fs_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ def initialize(
self.netArc = netArc_checkpoint
self.netArc = self.netArc.to(device)
self.netArc.eval()
self.netArc = self.netArc.half()

pretrained_path = ""
self.load_network(self.netG, "G", opt_which_epoch, pretrained_path)
self.netG = self.netG.half()
return

def forward(self, img_id, img_att, latent_id, latent_att, for_G=False):
Expand Down
1 change: 1 addition & 0 deletions src/dot/simswap/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def load_network(self, network, network_label, epoch_label, save_dir=""):

print(sorted(not_initialized))
network.load_state_dict(model_dict)
network = network.half()

def update_learning_rate(self):
pass
9 changes: 7 additions & 2 deletions src/dot/simswap/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def create_model( # type: ignore
)

self.net.eval()
self.net = self.net.half()
else:
self.net = None

Expand Down Expand Up @@ -141,7 +142,7 @@ def change_option(self, image: np.array, **kwargs) -> None:

# create latent id
img_id_downsample = F.interpolate(img_id, size=(112, 112))
- source_image = self.model.netArc(img_id_downsample)
+ source_image = self.model.netArc(img_id_downsample.half())
source_image = source_image.detach().to("cpu")
source_image = source_image / np.linalg.norm(
source_image, axis=1, keepdims=True
Expand Down Expand Up @@ -186,7 +187,11 @@ def process_image(self, image: np.array, **kwargs) -> np.array:
)[None, ...].cpu()

swap_result = self.model(
- None, frame_align_crop_tenor, self.source_image, None, True
+ None,
+ frame_align_crop_tenor.half(),
+ self.source_image.half(),
+ None,
+ True,
)[0]
swap_result_list.append(swap_result)
frame_align_crop_tenor_list.append(frame_align_crop_tenor)
Expand Down
2 changes: 1 addition & 1 deletion src/dot/simswap/util/reverse2original.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def reverse2wholeimage(
if use_mask:
source_img_norm = norm(source_img, use_gpu=use_gpu)
source_img_512 = F.interpolate(source_img_norm, size=(512, 512))
- out = pasring_model(source_img_512)[0]
+ out = pasring_model(source_img_512.half())[0]
parsing = out.squeeze(0).argmax(0)

tgt_mask = encode_segmentation_rgb(parsing, device)
Expand Down
3 changes: 2 additions & 1 deletion src/dot/simswap/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def load_parsing_model(path, use_mask, use_gpu):
net.load_state_dict(torch.load(path, map_location=torch.device("cpu")))

net.eval()
net = net.half()
return net
else:
return None
Expand Down Expand Up @@ -180,7 +181,7 @@ def crop_align(

# create latent id
img_id_downsample = F.interpolate(img_id, size=(112, 112))
- id_vector = swap_model.netArc(img_id_downsample)
+ id_vector = swap_model.netArc(img_id_downsample.half())
id_vector = id_vector.detach().to("cpu")
id_vector = id_vector / np.linalg.norm(id_vector, axis=1, keepdims=True)

Expand Down

0 comments on commit c654293

Please sign in to comment.