In [1]:
import os
import numpy as np
from PIL import Image
import sys
__file__path = os.path.abspath("../")
sys.path.append(__file__path)

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset

# Configuration 
from omegaconf import DictConfig, OmegaConf
import hydra
config_path = "../configs"
config_name = "config.yaml"
try: hydra.initialize(version_base=None, 
                      config_path=config_path)
except: pass
cfg = hydra.compose(config_name=config_name)
# print(OmegaConf.to_yaml(cfg))

from id_emb import IdEmbedder
from generator import Generator
from discriminator import Discriminator


class FaceSwapModel(nn.Module):
    def __init__(self, G, E, D):
        super(FaceSwapModel, self).__init__()
        self.G = G
        self.D = D
        self.E = E

    def forward(self, x_src, x_tar):
        z_src = self.E(x_src)
        swap = self.G(x_tar, z_src)
        disc = self.D(swap)
        z_swap = self.E(swap)
        return z_src, swap, z_swap, disc
    
    def get_discriminant(self, x):
        return self.D(x)

    def get_id_embedding(self, x):
        return self.E(x)
    
    def get_swap_face(self, x_src, x_tar):
        temb = self.E(x_src)
        return self.G(x_tar, temb)
    

/FaceSwap/swappers/SmoothSwap/models
/FaceSwap/swappers/SmoothSwap/models
/FaceSwap/swappers/SmoothSwap/models




In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# model
G = Generator(cfg)
D = Discriminator(cfg.discriminator.image_size)
E = IdEmbedder(cfg)
if cfg.id_emb.checkpoint_path:
    E.load_state_dict(cfg.id_emb.checkpoint_path)
E.requires_grad_(False)
E.eval()

cuda = torch.cuda.is_available()
if cuda:
    print("cuda")
    G.cuda()
    D.cuda()
    E.cuda()

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


cuda


In [4]:
fsm = FaceSwapModel(G, E, D)

# fsm = fsm.cuda()

# src.get_device(), tar.get_device()

# for p in fsm.G.parameters():
#     print(p.get_device())

In [5]:
src, tar = torch.rand((2, 3, 256, 256)).to(DEVICE), torch.rand((2, 3, 256, 256)).to(DEVICE)
fsm.get_swap_face(src, tar)

tensor([[[[ 0.5289,  0.0999,  0.3193,  ...,  1.0439,  0.3769,  0.6612],
          [ 0.3331,  0.7995,  0.0398,  ...,  0.6734,  0.1045,  0.4326],
          [ 0.3096,  0.8701, -0.2913,  ..., -0.1025,  0.0630,  0.5632],
          ...,
          [ 0.7968,  0.4074,  0.5252,  ...,  0.2419,  0.8396,  0.4551],
          [ 0.1080,  0.8045,  0.8402,  ...,  1.0944,  0.7276,  0.5265],
          [-0.0148,  0.8503, -0.0139,  ...,  0.9138,  0.8824,  0.2128]],

         [[ 0.8109,  0.7256,  0.4635,  ...,  0.7559,  0.7953,  0.2134],
          [ 1.0437,  0.3492,  0.3103,  ...,  0.0915,  0.5142, -0.0073],
          [ 0.1676,  0.3542,  0.8337,  ...,  0.4583,  0.9360,  0.4233],
          ...,
          [ 1.0549,  0.9396,  0.5832,  ...,  0.8090,  0.8225,  0.6771],
          [ 0.6588,  0.8847,  0.6304,  ...,  0.7612,  0.3081,  0.1469],
          [ 0.4728,  0.7343,  0.5883,  ...,  0.9456,  0.2606,  0.5176]],

         [[ 0.5006,  0.1963,  0.8221,  ...,  0.7159,  1.1728,  0.7141],
          [-0.0568,  0.6312,  