In [1]:
import os
import sys
import numpy as np
import cv2



import torch
from tqdm import tqdm

# Dataset
from dataset.frames_dataset_with_lmks import FramesDataset
from torch.utils.data import DataLoader

# Models

from modulesiris.generator import OcclusionAwareGenerator
import imageio

# Device: prefer the GPU, but fall back to the CPU when CUDA is unavailable so
# the notebook still runs (slowly) on machines without a usable GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def load_model(ckpt):
    """
    Build the occlusion-aware generator and restore its weights from ``ckpt``.

    ckpt: path to a checkpoint readable by ``torch.load``; it must contain a
          "G_state_dict" entry whose keys match the generator exactly
          (``strict=True``).

    Returns the generator moved to the module-level ``device``. The
    architecture hyper-parameters are hard-coded below.
    """
    state = torch.load(ckpt, map_location=device)

    # Settings forwarded to the generator as ``dense_motion_params``.
    motion_cfg = {
        "block_expansion": 64,
        "max_features": 1024,
        "num_blocks": 5,
        "scale_factor": 0.25,
        "using_first_order_motion": False,
        "using_thin_plate_spline_motion": True,
    }

    generator = OcclusionAwareGenerator(
        num_channels=3,
        num_kp=8,
        block_expansion=64,
        max_features=512,
        num_down_blocks=2,
        num_bottleneck_blocks=6,
        estimate_occlusion_map=True,
        dense_motion_params=motion_cfg,
        estimate_jacobian=True,
    )

    generator.load_state_dict(state["G_state_dict"], strict=True)
    return generator.to(device)

def draw_landmarks(img, lmks, color=(255,0,0)):
    """
    Mark every landmark on the image with a small filled, anti-aliased dot.

    img: image array; drawing is done on ``np.ascontiguousarray(img)``, so an
         input that is already C-contiguous is drawn on directly.
    lmks: iterable of points; element 0 is used as x and element 1 as y, each
          rounded to the nearest integer pixel.
    color: colour tuple handed to ``cv2.circle``.

    Returns the (contiguous) image that was drawn on.
    """
    canvas = np.ascontiguousarray(img)

    for point in lmks:
        center = (int(round(point[0])), int(round(point[1])))
        # radius 1, thickness -1 (filled disc)
        cv2.circle(canvas, center, 1, color, -1, lineType=cv2.LINE_AA)

    return canvas


def vis(x, x_prime_hat, kp_src, kp_driving):
    """
    Render side-by-side source / prediction panels with landmarks overlaid.

    x: BxCxHxW source images, values expected in [0, 1]
    x_prime_hat: BxCxHxW generated images, values expected in [0, 1]
    kp_src: BxKx2 source landmarks as (x, y) in normalized [-1, 1] coordinates
    kp_driving: BxKx2 driving landmarks, same convention as kp_src

    For every sample the left panel is the source image with the source
    landmarks; the right panel is the generated image with the source
    landmarks (default colour) and the driving landmarks (colour (0,255,255)).

    Returns a uint8 array of shape (B*H)x(2*W)xC with the per-sample panels
    stacked vertically (for B == 1 this is a single Hx(2*W)xC panel), or None
    when the batch is empty.
    """
    _, _, h, w = x.shape
    x = x.detach().cpu().numpy()
    x_prime_hat = x_prime_hat.detach().cpu().numpy()

    kp_src = kp_src.detach().cpu().numpy()
    kp_driving = kp_driving.detach().cpu().numpy()

    def to_uint8_hwc(chw):
        # CHW float -> HWC uint8. Clip first so values slightly outside
        # [0, 1] saturate instead of wrapping around in the uint8 cast.
        return np.clip(np.transpose(chw, (1, 2, 0)) * 255.0, 0, 255).astype(np.uint8)

    panels = []
    for x1, x3, ks, kd in zip(x, x_prime_hat, kp_src, kp_driving):
        x1 = to_uint8_hwc(x1)
        x3 = to_uint8_hwc(x3)

        # Map normalized [-1, 1] coordinates to pixel coordinates.
        ks = (ks+1) * np.array([w,h]) / 2.0
        kd = (kd+1) * np.array([w,h]) / 2.0

        x1 = draw_landmarks(x1, ks)
        x3 = draw_landmarks(x3, ks)
        x3 = draw_landmarks(x3, kd, color=(0,255,255))

        panels.append(np.hstack((x1, x3)))

    if not panels:
        return None
    # Every sample contributes a row; previously only the first sample was
    # rendered because the function returned from inside the loop.
    return np.vstack(panels)



def synthize_kp_driving(kp_src, delta_x=None, delta_y=None):
    """
    Create driving keypoints by translating the last two source keypoints.

    kp_src: dict whose "value" entry holds the source keypoints; it is cloned,
            so the source is left untouched.
    delta_x, delta_y: offsets added to coordinate 0 / coordinate 1 of the last
            two keypoints. An offset left as None is drawn uniformly from
            [-0.15, 0.15).

    Returns a new dict containing only the shifted "value" entry.
    """
    shifted = kp_src["value"].clone()

    # Draw x before y so the global NumPy RNG is consumed in a fixed order.
    shift_x = np.random.uniform(-0.15, 0.15) if delta_x is None else delta_x
    shift_y = np.random.uniform(-0.15, 0.15) if delta_y is None else delta_y

    # `last_two` is a view into `shifted`, so the writes below update it.
    last_two = shifted[:, -2:]
    last_two[:, :, 0] = last_two[:, :, 0] + shift_x
    last_two[:, :, 1] = last_two[:, :, 1] + shift_y

    return {"value": shifted}


from skimage import io, img_as_float32

# Load model
# Checkpoints tried earlier, kept for reference:
# ckpt = "checkpoints/motion_iris/11.pth.tar"
# ckpt = "checkpoints/motion_iris_fix_motion_equation/15.pth.tar"
# ckpt = "checkpoints/motion_iris_fix_motion_test/8.pth.tar"
# ckpt = "checkpoints/motion_iris_thin_plate_spline_motion/20.pth.tar"
ckpt = "checkpoints/motion_iris_thin_plate_spline_motion_more_control_points/3.pth.tar"

# Restore the generator and switch it to evaluation mode.
G = load_model(ckpt=ckpt)
G.eval()

# Dataset
root_dir = "./data/eth_motion_data"
augmentation_params = {
    "flip_param": {"horizontal_flip": False, "time_flip": False},
    "jitter_param": {"brightness": 0.1, "contrast": 0.1, "saturation": 0.1, "hue": 0.1},
}
dataset = FramesDataset(
    root_dir,
    frame_shape=(256, 256, 3),
    id_sampling=False,
    is_train=True,
    random_seed=0,
    pairs_list=None,
    augmentation_params=augmentation_params,
)


# batchdata = dataset[index]
# _, x = batchdata["driving"], batchdata["source"]
# _, kp_src = batchdata["lmks_driving"], batchdata["lmks_source"]

# Fake image
def synthesize_image(src_path, delta_x, delta_y, kp_source=None):
    """
    Re-render an image with its last two landmarks shifted by (delta_x, delta_y).

    src_path: path of the source image; it is read with OpenCV, converted to
              RGB and resized to 256x256.
    delta_x, delta_y: landmark offsets in normalized coordinates, forwarded to
              ``synthize_kp_driving`` (None means a random offset there).
    kp_source: optional Kx2 array-like of source landmarks as (x, y) in
              normalized [-1, 1] coordinates. Defaults to the two landmarks
              that were previously hard-coded here.

    Returns ``(img_out, prediction)``: the visualisation produced by ``vis``
    and the raw output dict of the generator ``G``.

    Raises OSError when the image cannot be read.
    """
    src = cv2.imread(src_path)
    if src is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read image {src_path!r} (missing file or unsupported format)")
    src = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
    src = cv2.resize(src, (256, 256))  # HxWxC, 256x256
    src = src/255.0

    # HxWxC -> 1xCxHxW float32 tensor on the target device.
    x = torch.FloatTensor(np.transpose(src, (2,0,1))).unsqueeze(0).to(device)

    if kp_source is None:
        kp_source = [[-0.2136181,-0.31389177],[0.373667,-0.22871798]]
    # Kx2 -> 1xKx2 float32 tensor on the target device.
    kp_src = {"value": torch.as_tensor(kp_source, dtype=torch.float32).unsqueeze(0).to(device)}

    kp_driving = synthize_kp_driving(kp_src, delta_x, delta_y)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        prediction = G(source_image=x, kp_driving=kp_driving, kp_source=kp_src)
    img_out = vis(x, prediction["prediction"], kp_src["value"], kp_driving["value"])

    return img_out, prediction


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


YOOOOOOOOOOO dense_motion_params :{'block_expansion': 64, 'max_features': 1024, 'num_blocks': 5, 'scale_factor': 0.25, 'using_first_order_motion': False, 'using_thin_plate_spline_motion': True}
Use predefined train-test split.


In [3]:
from ipywidgets import interact, interactive, fixed, interact_manual, Image
import cv2


@interact(
    delta_x=(-0.15, 0.15, 0.005),
    delta_y=(-0.15, 0.15, 0.005),
)
def synthesize(delta_x, delta_y):
    """
    Slider callback: synthesize "trinh.png" with the chosen landmark offsets
    and return the comparison image as a PNG ``Image`` widget.
    """
    panel_rgb, _ = synthesize_image(src_path="trinh.png", delta_x=delta_x, delta_y=delta_y)

    # OpenCV encodes from BGR channel order, so convert before writing the PNG.
    panel_bgr = cv2.cvtColor(panel_rgb, cv2.COLOR_RGB2BGR)
    _, png_buffer = cv2.imencode('.png', panel_bgr)

    return Image(value=png_buffer.tobytes(), format='png')

interactive(children=(FloatSlider(value=0.0, description='delta_x', max=0.15, min=-0.15, step=0.005), FloatSli…