[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/williamyang1991/DualStyleGAN/blob/master/notebooks/inference_playground.ipynb)

The code is mainly adapted from the [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel/blob/master/notebooks/inference_playground.ipynb) inference notebook.

In [None]:
import os

# Folder name the repository is cloned into (used to build all paths below).
CODE_DIR = 'VToonify'
# Device string passed to every `.to(device)` call in this notebook.
device = 'cuda'

# Restrict CUDA to the first GPU.
os.environ.update({'CUDA_VISIBLE_DEVICES': "0"})
# Step one directory up before cloning / entering the repository.
os.chdir('../')

In [None]:
# Clone the VToonify repository into the folder named by CODE_DIR.
!git clone https://github.com/williamyang1991/VToonify.git $CODE_DIR

In [None]:
# Download the ninja build tool (v1.8.2), unpack it into /usr/local/bin
# and register it as the system-wide `ninja` alternative.
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force 
# Install the Python `wget` package; the model-loading cells call `wget.download`.
!pip install wget

In [None]:
# Enter the cloned repository; every relative path below is resolved from here.
os.chdir(f'./{CODE_DIR}')

# Build <parent of cwd>/<CODE_DIR>/<subfolder> for checkpoints, inputs and outputs.
repo_parent = os.path.dirname(os.getcwd())
MODEL_DIR, DATA_DIR, OUT_DIR = (
    os.path.join(repo_parent, CODE_DIR, subfolder)
    for subfolder in ('checkpoint', 'data', 'output')
)

In [None]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported automatically.
%load_ext autoreload
%autoreload 2

import sys
# Add the current directory and its parent to the module search path so the
# project packages imported below (`model`, `util`) can be found.
sys.path.append(".")
sys.path.append("..")

import argparse
import bz2

import cv2
import dlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import wget
from torchvision import transforms
from tqdm import tqdm

from model.vtoonify import VToonify
from model.bisenet.model import BiSeNet
from model.encoder.align_all_parallel import align_face
from util import save_image, load_image, visualize, load_psp_standalone, get_video_crop_parameter, tensor2cv2

In [None]:
def get_download_model_command(file_id, file_name, save_dir=None):
    """Build the wget shell command that downloads a Google Drive file.

    The target directory is created if it does not exist yet.

    Args:
        file_id: Google Drive id of the file to download.
        file_name: Name given to the downloaded file inside the target directory.
        save_dir: Directory to save into. Defaults to ``MODEL_DIR`` when ``None``.

    Returns:
        The complete shell command as a string (run it with ``!{command}``).
    """
    save_path = MODEL_DIR if save_dir is None else save_dir
    # exist_ok makes re-running the cell a no-op and avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(save_path, exist_ok=True)
    command = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)
    return command

# Download table: model key -> Google Drive file id ("id") and base file name ("name").
# Several entries share the same base name, so the download cells prefix the
# name with the style key to keep the saved files apart.
MODEL_PATHS = {
    # networks shared by every style
    "encoder": {"id": "1NgI4mPkboYvYw3MWcdUaQhkr0OWgs9ej", "name": "encoder.pt"},
    "faceparsing": {"id": "1jY0mTjVB8njDh6e0LP_2UxuRK3MnjoIR", "name": "faceparsing.pth"},
    # extrinsic style codes, one file per style family
    "arcane_exstyle": {"id": "1TC67wRJkdmNRZTqYMUEFkrhWRKKZW40c", "name": "exstyle_code.npy"},
    "caricature_exstyle": {"id": "1xr9sx_WmRYJ4qHGTtdVQCSxSo4HP3-ip", "name": "exstyle_code.npy"},
    "cartoon_exstyle": {"id": "1BuCeLk3ASZcoHlbfT28qNru4r5f-hErr", "name": "exstyle_code.npy"},
    "pixar_exstyle": {"id": "1yTaKuSrL7I0i0RYEEK5XD6GI-y5iNUbj", "name": "exstyle_code.npy"},
    # generators keyed by <family><three-digit style index>
    "arcane000": {"id": "1pF4fJ8acmawMsjjXo4HXRIOXeZR8jLVh", "name": "generator.pt"},
    "arcane077": {"id": "16rLTF2oC0ZeurnM6hjrfrc8BxtW8P8Qf", "name": "generator.pt"},
    "caricature039": {"id": "1C1E4WEoDWzl0nAxR9okKffFmlMOENbeF", "name": "generator.pt"},
    "caricature068": {"id": "1B1ko1x8fX2aJ4BYCL12AnknVAi3qQc8W", "name": "generator.pt"},
    "cartoon026": {"id": "1YJYODh_vEyUrL0q02okjcicpJhdYY8An", "name": "generator.pt"},
    "cartoon299": {"id": "101qMUMfcI2qDxEbfCBt5mOg2aSqdTaIt", "name": "generator.pt"},
    "pixar052": {"id": "16j_l1x0DD0PjwO8YdplAk69sh3-v95rr", "name": "generator.pt"},
    # generator used in PART II (saved there as cartoon_generator.pt)
    "cartoon": {"id": "11s0hwhZWTLacMAzZH4OU-o3Qkp54h30J", "name": "generator.pt"},
}

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# PART I - Style Transfer with specialized VToonify-D model

## Step 1: Select Style Type

In [None]:
# Selectable styles. Each key is <style family><three-digit style index>;
# later cells split it with style_type[:-3] and style_type[-3:].
style_types = [
    'cartoon026',     # balanced
    'cartoon299',     # big eyes
    'arcane000',      # for female
    'arcane077',      # for male
    'pixar052',
    'caricature039',  # big mouth
    'caricature068',  # balanced
]
# Style used for the rest of PART I (first entry of the list).
style_type = style_types[0]

## Step 2: Download Pretrained Models 
As part of this repository, we provide pretrained models. We'll download the models and save them to the `checkpoint/` folder of the cloned repository (`MODEL_DIR`).

In [None]:
# download pSp encoder and face parsinf network
path = MODEL_PATHS["encoder"]
download_command = get_download_model_command(file_id=path["id"], file_name=path["name"])
!{download_command}
path = MODEL_PATHS["faceparsing"]
download_command = get_download_model_command(file_id=path["id"], file_name=path["name"])
!{download_command}

In [None]:
# download vtoonify
path = MODEL_PATHS[style_type]
download_command = get_download_model_command(file_id=path["id"], file_name = style_type + '_' + path["name"])
!{download_command}
# download extrinsic style code
path = MODEL_PATHS[style_type[:-3]+'_exstyle']
download_command = get_download_model_command(file_id=path["id"], file_name = style_type[:-3] + '_' + path["name"])
!{download_command}

## Step 3: Load Pretrained Model
We assume that you have downloaded all relevant models into `MODEL_DIR` using the download cells above.

In [None]:
# Convert an HxWx3 image to a tensor and normalize each channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5,0.5,0.5]),
    ])

# VToonify generator for the selected style; weights are stored under the 'g_ema' key.
# map_location keeps the tensors on CPU while loading, they are moved to `device` afterwards.
vtoonify = VToonify(backbone = 'dualstylegan')
vtoonify.load_state_dict(torch.load(os.path.join(MODEL_DIR, style_type+'_generator.pt'), map_location=lambda storage, loc: storage)['g_ema'])
vtoonify.to(device)

# Face parsing network (19 classes), used in inference mode only.
parsingpredictor = BiSeNet(n_classes=19)
parsingpredictor.load_state_dict(torch.load(os.path.join(MODEL_DIR, 'faceparsing.pth'), map_location=lambda storage, loc: storage))
parsingpredictor.to(device).eval()

# dlib 68-point landmark model: download and decompress it on first use.
modelname = './checkpoint/shape_predictor_68_face_landmarks.dat'
if not os.path.exists(modelname):
    wget.download('http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', modelname+'.bz2')
    # Context managers make sure both files are closed (and the output flushed)
    # before dlib opens the extracted model.
    with bz2.BZ2File(modelname+'.bz2') as compressed, open(modelname, 'wb') as extracted:
        extracted.write(compressed.read())
landmarkpredictor = dlib.shape_predictor(modelname)

pspencoder = load_psp_standalone(os.path.join(MODEL_DIR, 'encoder.pt'), device)

# NOTE: allow_pickle unpickles arbitrary objects; only load style-code files from a trusted source.
exstyles = np.load(os.path.join(MODEL_DIR, style_type[:-3]+'_exstyle_code.npy'), allow_pickle=True).item()
# The three-digit suffix of style_type indexes into the style names of the loaded dictionary.
stylename = list(exstyles.keys())[int(style_type[-3:])]
exstyle = torch.tensor(exstyles[stylename]).to(device)
with torch.no_grad():
    exstyle = vtoonify.zplus2wplus(exstyle)

print('Model successfully loaded!')

## Step 4: Image Toonification

### Visualize and Rescale Input
We rescale the input image to make it fit our pre-trained model

In [None]:
# Path of the example input image (relative to the repository root).
image_path = './data/077436.jpg'
# Loaded with the project helper `load_image`; the next cell displays element 0 of the result.
original_image = load_image(image_path)

In [None]:
# Display the unprocessed input image.
plt.figure(dpi=30, figsize=(10, 10))
visualize(original_image[0])
plt.show()

In [None]:
# cv2.imread returns the image in BGR order, or None when the file cannot be read.
frame = cv2.imread(image_path)
if frame is None:
    raise FileNotFoundError(f'cannot read image: {image_path}')
# Swap to RGB channel order (the video cells do the same conversion on their frames).
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

scale = 1
# Separable 4-tap smoothing kernel (weights sum to 1), applied along both axes.
kernel_1d = np.array([[0.125],[0.375],[0.375],[0.125]])
# We detect the face in the image, and resize the image so that the eye distance is 64 pixels.
# Centered on the eyes, we crop the image to almost 400x400 (based on the `padding` argument).
paras = get_video_crop_parameter(frame, landmarkpredictor, padding=[200,200,200,200])
if paras is not None:
    h,w,top,bottom,left,right,scale = paras
    H, W = int(bottom-top), int(right-left)
    # for HR image, we apply gaussian blur to it to avoid over-sharp stylization results
    if scale <= 0.75:
        frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
    if scale <= 0.375:
        frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
    # Resize the whole image, then crop the face region.
    frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
    x = transform(frame).unsqueeze(dim=0).to(device)
else:
    print('no face detected!')

In [None]:
# Display the rescaled and cropped network input.
plt.figure(dpi=30, figsize=(10, 10))
visualize(x[0].cpu())
plt.show()

### Perform Inference

In [None]:
with torch.no_grad():
    # Encode the aligned face with the pSp encoder and map the code with zplus2wplus.
    I = transform(align_face(frame, landmarkpredictor)).unsqueeze(dim=0).to(device)
    s_w = vtoonify.zplus2wplus(pspencoder(I))
    # Replace the first 7 entries of the style code with those of the extrinsic style.
    s_w[:,:7] = exstyle[:,:7]
    # The parsing network works best on 512x512 images, so parsing maps are predicted
    # on 2x upsampled frames and then downsampled back.
    x_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
    x_p = F.interpolate(parsingpredictor(2*x_up)[0],
                        scale_factor=0.5, recompute_scale_factor=False).detach()
    # Parsing maps get a lower weight (1/16) when concatenated with the image.
    inputs = torch.cat((x, x_p/16.), dim=1)
    # d_s has no effect when backbone is toonify
    y_tilde = torch.clamp(
        vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), d_s = 0.5),
        -1, 1)

In [None]:
# Display the stylized image.
plt.figure(dpi=60, figsize=(10, 10))
visualize(y_tilde[0].cpu())
plt.show()

## Step 5: Video Toonification

### Visualize and Rescale Input
We rescale the input video to make it fit our pre-trained model

In [None]:
# Open the example video and read its frame count.
video_path = './data/529.mp4'
video_cap = cv2.VideoCapture(video_path)
# cv2.CAP_PROP_FRAME_COUNT is property id 7.
num = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))

In [None]:
# Read the first frame; it drives face detection and the crop parameters below.
success, frame = video_cap.read()
if not success:
    # An `assert` on a non-empty string never fails, so raise explicitly.
    raise RuntimeError('load video frames error')
# Frames are decoded in BGR order; convert to RGB.
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

In [None]:
# Display the first video frame before cropping.
plt.figure(dpi=30, figsize=(10, 10))
visualize(transform(frame))
plt.show()

In [None]:
scale = 1
# Separable 4-tap smoothing kernel (weights sum to 1), applied along both axes.
kernel_1d = np.array([[0.125],[0.375],[0.375],[0.125]])
# We preprocess the video by detecting the face in the first frame,
# and resizing the frame so that the eye distance is 64 pixels.
# Centered on the eyes, we crop the first frame to almost 400x400 (based on the `padding` argument).
# All other frames use the same resizing and cropping parameters as the first frame.
paras = get_video_crop_parameter(frame, landmarkpredictor, padding=[200,200,200,200])
if paras is not None:
    h,w,top,bottom,left,right,scale = paras
    H, W = int(bottom-top), int(right-left)
    # for HR video, we apply gaussian blur to the frames to avoid flickers caused by bilinear downsampling
    # this can also prevent over-sharp stylization results.
    if scale <= 0.75:
        frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
    if scale <= 0.375:
        frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
    frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
else:
    # Without crop parameters w/h/top/... are undefined, so the resize/crop is skipped
    # instead of failing with a NameError.
    print('no face detected!')

In [None]:
# Display the first video frame after rescaling and cropping.
plt.figure(dpi=30, figsize=(10, 10))
visualize(transform(frame))
plt.show()

### Perform Inference

In [None]:
# Make sure the output folder exists before opening the writer in it.
os.makedirs(OUT_DIR, exist_ok=True)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# Same frame rate as the input video (cv2.CAP_PROP_FPS is property id 5).
# The writer frame size is 4x the cropped frame size (W, H).
videoWriter = cv2.VideoWriter(os.path.join(OUT_DIR, 'result.mp4'), fourcc,
                              video_cap.get(cv2.CAP_PROP_FPS), (4*W, 4*H))
# Number of frames stylized per forward pass.
batch_size = 4

In [None]:
# Style degree passed to VToonify as d_s (same value as the image example above).
style_degree = 0.5


def _stylize_and_write(frames, style_code):
    """Stylize one batch of preprocessed frames, write them to the video, return the output batch."""
    x = torch.cat(frames, dim=0)
    # parsing network works best on 512x512 images, so we predict parsing maps on upsampled frames
    # followed by downsampling the parsing maps
    x_p = F.interpolate(parsingpredictor(2*(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
                        scale_factor=0.5, recompute_scale_factor=False).detach()
    # we give parsing maps lower weight (1/16)
    inputs = torch.cat((x, x_p/16.), dim=1)
    # d_s has no effect when backbone is toonify
    out = vtoonify(inputs, style_code.repeat(inputs.size(0), 1, 1), d_s = style_degree)
    out = torch.clamp(out, -1, 1)
    for k in range(out.size(0)):
        videoWriter.write(tensor2cv2(out[k].cpu()))
    return out


batch_frames = []
with torch.no_grad():
    for i in tqdm(range(num)):
        if i == 0:
            # The first frame was already read, converted and cropped by the cells above;
            # it also provides the style code used for the whole video.
            I = align_face(frame, landmarkpredictor)
            I = transform(I).unsqueeze(dim=0).to(device)
            s_w = pspencoder(I)
            s_w = vtoonify.zplus2wplus(s_w)
            s_w[:,:7] = exstyle[:,:7]
        else:
            success, frame = video_cap.read()
            if not success:
                # Stop when no further frame can be read, even if fewer than `num` were decoded.
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Same blur / resize / crop as applied to the first frame.
            if scale <= 0.75:
                frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
            if scale <= 0.375:
                frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
            frame = cv2.resize(frame, (w, h))[top:bottom, left:right]

        batch_frames.append(transform(frame).unsqueeze(dim=0).to(device))

        if len(batch_frames) == batch_size or (i+1) == num:
            y_tilde = _stylize_and_write(batch_frames, s_w)
            batch_frames = []

    # Flush frames left over when the loop stopped early.
    if batch_frames:
        y_tilde = _stylize_and_write(batch_frames, s_w)
        batch_frames = []

videoWriter.release()
video_cap.release()

In [None]:
# Tile the last stylized batch into a grid (2 images per row, 2 px padding) and display it.
viz = torchvision.utils.make_grid(y_tilde, nrow=2, padding=2)
plt.figure(dpi=120, figsize=(10, 10))
visualize(viz.cpu())
plt.show()

### Find the stylized video in `./output/result.mp4`

# PART II - Style control with VToonify-Dsd model

## Step 1: Download Pretrained Models 
As part of this repository, we provide pretrained models. We'll download the models and save them to the `checkpoint/` folder of the cloned repository (`MODEL_DIR`).

In [None]:
# if you haved downloaded the encoder and faceparsing model in PART I, skip this step
if False:
    path = MODEL_PATHS["encoder"]
    download_command = get_download_model_command(file_id=path["id"], file_name=path["name"])
    !{download_command}
    path = MODEL_PATHS["faceparsing"]
    download_command = get_download_model_command(file_id=path["id"], file_name=path["name"])
    !{download_command}

In [None]:
# download the style code and the vtoonify-Dsd
path = MODEL_PATHS['cartoon_exstyle']
download_command = get_download_model_command(file_id=path["id"], file_name = 'cartoon_cartoon_exstyle.npy')
!{download_command}
path = MODEL_PATHS['cartoon']
download_command = get_download_model_command(file_id=path["id"], file_name = 'cartoon_generator.pt')
!{download_command}

## Step 2: Load Pretrained Model
We assume that you have downloaded all relevant models into `MODEL_DIR` using the download cells above.

In [None]:
# Skip this block if you have already loaded these models in PART I; otherwise change False to True.
if False:
    # Convert an HxWx3 image to a tensor and normalize each channel with mean 0.5 / std 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5,0.5,0.5]),
        ])

    # Face parsing network (19 classes), used in inference mode only.
    parsingpredictor = BiSeNet(n_classes=19)
    parsingpredictor.load_state_dict(torch.load(os.path.join(MODEL_DIR, 'faceparsing.pth'), map_location=lambda storage, loc: storage))
    parsingpredictor.to(device).eval()

    # dlib 68-point landmark model: download and decompress it on first use.
    modelname = './checkpoint/shape_predictor_68_face_landmarks.dat'
    if not os.path.exists(modelname):
        wget.download('http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', modelname+'.bz2')
        # Context managers make sure both files are closed (and the output flushed)
        # before dlib opens the extracted model.
        with bz2.BZ2File(modelname+'.bz2') as compressed, open(modelname, 'wb') as extracted:
            extracted.write(compressed.read())
    landmarkpredictor = dlib.shape_predictor(modelname)

    pspencoder = load_psp_standalone(os.path.join(MODEL_DIR, 'encoder.pt'), device)

# VToonify-Dsd generator; weights are stored under the 'g_ema' key.
vtoonify = VToonify(backbone = 'dualstylegan')
vtoonify.load_state_dict(torch.load(os.path.join(MODEL_DIR, 'cartoon_generator.pt'), map_location=lambda storage, loc: storage)['g_ema'])
vtoonify.to(device)

# NOTE: allow_pickle unpickles arbitrary objects; only load style-code files from a trusted source.
exstyles = np.load(os.path.join(MODEL_DIR, 'cartoon_cartoon_exstyle.npy'), allow_pickle=True).item()
# Map every style code with zplus2wplus and stack them along dim 0, so that
# exstyles[i] is the code of the i-th style in the file's key order.
styles = []
with torch.no_grad():
    for stylename in exstyles.keys():
        exstyle = torch.tensor(exstyles[stylename]).to(device)
        exstyle = vtoonify.zplus2wplus(exstyle)
        styles.append(exstyle)
exstyles = torch.cat(styles, dim=0)

print('Model successfully loaded!')

## Step 3: Image Toonification

### Visualize and Rescale Input
We rescale the input image to make it fit our pre-trained model

In [None]:
# Path of the example input image (relative to the repository root).
image_path = './data/077436.jpg'
# Loaded with the project helper `load_image`; the next cell displays element 0 of the result.
original_image = load_image(image_path)

In [None]:
# Display the unprocessed input image.
plt.figure(dpi=30, figsize=(10, 10))
visualize(original_image[0])
plt.show()

In [None]:
# cv2.imread returns the image in BGR order, or None when the file cannot be read.
frame = cv2.imread(image_path)
if frame is None:
    raise FileNotFoundError(f'cannot read image: {image_path}')
# Swap to RGB channel order (the video cells do the same conversion on their frames).
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

scale = 1
# Separable 4-tap smoothing kernel (weights sum to 1), applied along both axes.
kernel_1d = np.array([[0.125],[0.375],[0.375],[0.125]])
# We detect the face in the image, and resize the image so that the eye distance is 64 pixels.
# Centered on the eyes, we crop the image to almost 400x400 (based on the `padding` argument).
paras = get_video_crop_parameter(frame, landmarkpredictor, padding=[200,200,200,200])
if paras is not None:
    h,w,top,bottom,left,right,scale = paras
    H, W = int(bottom-top), int(right-left)
    # for HR image, we apply gaussian blur to it to avoid over-sharp stylization results
    if scale <= 0.75:
        frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
    if scale <= 0.375:
        frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
    # Resize the whole image, then crop the face region.
    frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
    x = transform(frame).unsqueeze(dim=0).to(device)
else:
    print('no face detected!')

In [None]:
# Display the rescaled and cropped network input.
plt.figure(dpi=30, figsize=(10, 10))
visualize(x[0].cpu())
plt.show()

### Select style image

Select the style index (the mapping between index and style image is defined [here](https://github.com/williamyang1991/DualStyleGAN/blob/main/doc_images/cartoon_overview.jpg))

In [None]:
# Indices of the cartoon styles to apply; they index the first dimension of `exstyles`.
style_id = [
    8,
    26,
    64,
    153,
    299,
]

### Style transfer with different cartoon structure styles

In [None]:
with torch.no_grad():
    # Encode the aligned face once, then make one copy of the code per selected style.
    I = align_face(frame, landmarkpredictor)
    I = transform(I).unsqueeze(dim=0).to(device)
    s_w = pspencoder(I)
    s_w = vtoonify.zplus2wplus(s_w).repeat(len(style_id), 1, 1)
    # Replace the first 7 entries of each copy with those of the matching extrinsic style.
    s_w[:,:7] = exstyles[style_id,:7]
    # One copy of the input per style. `x` itself is left untouched so that
    # re-running this cell does not multiply the batch again.
    x_batch = x.repeat(len(style_id), 1, 1, 1)
    # parsing network works best on 512x512 images, so we predict parsing maps on upsampled frames
    # followed by downsampling the parsing maps
    x_p = F.interpolate(parsingpredictor(2*(F.interpolate(x_batch, scale_factor=2, mode='bilinear', align_corners=False)))[0],
                        scale_factor=0.5, recompute_scale_factor=False).detach()
    # we give parsing maps lower weight (1/16)
    inputs = torch.cat((x_batch, x_p/16.), dim=1)
    # s_w already holds one code per input image, so it is passed as is;
    # repeating it again would no longer match the batch size of `inputs`.
    # d_s has no effect when backbone is toonify
    y_tilde = vtoonify(inputs, s_w, d_s = 0.6)
    y_tilde = torch.clamp(y_tilde, -1, 1)

In [None]:
# Tile the stylized images into a grid (2 images per row, 2 px padding) and display it.
viz = torchvision.utils.make_grid(y_tilde, nrow=2, padding=2)
plt.figure(dpi=120, figsize=(10, 10))
visualize(viz.cpu())
plt.show()

### Navigation with different style degree to achieve flexible style manipulation

We suggest trying different style degrees to find the best result.

In [None]:
# Stylize the same inputs with style degrees 0, 0.25, 0.5, 0.75 and 1.
results = []
with torch.no_grad():  # inference only, no autograd graph needed
    for i in range(5):
        d_s = i / 4.0
        # s_w already holds one style code per image in `inputs`.
        y_tilde = vtoonify(inputs, s_w, d_s = d_s)
        y_tilde = torch.clamp(y_tilde, -1, 1)
        # Keep the finished batches on the CPU to free GPU memory between degrees.
        results.append(y_tilde.cpu())

# Concatenate the collected batches (not the last `y_tilde`) into one grid;
# each grid row shows all selected styles at one style degree.
vis = torchvision.utils.make_grid(torch.cat(results, dim=0), nrow=inputs.size(0), padding=1)
plt.figure(figsize=(10,10),dpi=120)
visualize(vis.cpu())
plt.show()