# Talking Head Anime from a Single Image (Manual Poser Tool)

<a href="https://kaggle.com/kernels/welcome?src=https://github.com/pkhungurn/talking-head-anime-demo/blob/master/tha_colab.ipynb"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>&nbsp;<br><br>**Instructions**

1. From the main menu, click "Runtime > Change runtime type." 
2. Change "Hardware accelerator" to "GPU," and click "Save."
3. Run the four cells below one at a time, in order, by clicking the "Play" button to the left of each cell. Wait for each cell to finish before moving on to the next one.
4. Scroll down to the end of the last cell, and play with the GUI.

**Constraints on Images**

1. Must be an image of a single humanoid anime character.
2. Must be of size 256x256.
3. The head must be roughly contained in the middle 128x128 box.
4. Must be in PNG format.
5. Must have an alpha channel.
6. Background pixels must have RGBA=(0,0,0,0). See [this link](https://github.com/pkhungurn/talking-head-anime-demo/issues/1) if you do not get clean results.

**Links**

* Github repository: http://github.com/pkhungurn/talking-head-anime-demo
* Project writeup: http://pkhungurn.github.io/talking-head-anime

In [None]:
# Clone the repository
# Start from /content so the clone lands at /content/talking-head-anime-demo,
# which is the path the next cell changes into.
%cd /content
!git clone https://github.com/pkhungurn/talking-head-anime-demo.git

In [None]:
# Change into the repository directory.
# The later cells use paths relative to the repository root (data/...) and
# import the repository's own packages (poser, tha, util), so the working
# directory has to be the repository root before they run.
%cd /content/talking-head-anime-demo

In [None]:
# Download model files
# Each file is written to the path given by -O under data/ (relative to the
# current working directory); these are the same three paths that the GUI
# cell passes to the poser as its module file names.
!wget -O data/combiner.pt https://www.dropbox.com/s/p220v9rmbjmqien/combiner.pt?dl=0
!wget -O data/face_morpher.pt https://www.dropbox.com/s/oukbnofkffc2bis/face_morpher.pt?dl=0
!wget -O data/two_algo_face_rotator.pt https://www.dropbox.com/s/o78wzc5cpxnxggr/two_algo_face_rotator.pt?dl=0

In [None]:
# Run the GUI

import torch

# Device used for the poser and for every tensor created in this cell.
# It is fixed to CUDA; the notebook instructions ask the user to select a GPU
# runtime before running the cells.
DEVICE_NAME = 'cuda'
device = torch.device(DEVICE_NAME)

import PIL.Image
import io
from io import BytesIO
import IPython.display
import numpy
import ipywidgets
from poser.morph_rotate_combine_poser import MorphRotateCombinePoser256Param6
from poser.poser import Poser
from tha.combiner import CombinerSpec
from tha.face_morpher import FaceMorpherSpec
from tha.two_algo_face_rotator import TwoAlgoFaceRotatorSpec
from util import extract_pytorch_image_from_filelike, rgba_to_numpy_image

# Module-level state shared by the widget callbacks defined further down.
# last_torch_input_image: the image tensor used for the most recent render;
#   written by update() after it redraws the output.
# torch_input_image: the currently uploaded image tensor, or None while no
#   valid image is available; written by upload_image().
last_torch_input_image = None
torch_input_image = None

def show_pytorch_image(pytorch_image, output_widget=None):
    """Display a PyTorch image tensor as an RGBA picture in the notebook.

    The tensor is detached, copied to the CPU, converted with the project
    helper ``rgba_to_numpy_image``, scaled by 255, rounded, cast to 8-bit and
    shown through ``IPython.display.display``. The picture goes to whatever
    output context is active when this is called (callers wrap the call in
    ``with <Output widget>:``).

    Args:
        pytorch_image: Image tensor to show; may live on any device and may
            still be attached to an autograd graph.
            NOTE(review): presumably a single 4-channel image that the helper
            maps to values in [0, 1] -- confirm against ``util``.
        output_widget: Accepted for the callers' convenience but not used by
            this function.
    """
    cpu_tensor = pytorch_image.detach().cpu()
    rgba_floats = rgba_to_numpy_image(cpu_tensor)
    # Scale to the 0..255 range and round to the nearest integer before the
    # 8-bit cast so values are not simply truncated.
    rgba_bytes = numpy.uint8(numpy.rint(rgba_floats * 255.0))
    picture = PIL.Image.fromarray(rgba_bytes, mode='RGBA')
    IPython.display.display(picture)

# Frame shared by the two image panes: a bordered 256x256 box. Each pane gets
# its own copy of the dict.
_IMAGE_PANE_LAYOUT = {
    'border': '1px solid black',
    'width': '256px',
    'height': '256px',
}

# Pane that previews the uploaded source image (or a validation message).
input_image_widget = ipywidgets.Output(layout=dict(_IMAGE_PANE_LAYOUT))

# Single-file picker restricted to .png files, as wide as the image pane.
upload_input_image_button = ipywidgets.FileUpload(
    accept='.png',
    multiple=False,
    layout={'width': '256px'},
)

# Pane that shows the posed result.
output_image_widget = ipywidgets.Output(layout=dict(_IMAGE_PANE_LAYOUT))

# Settings every pose slider shares: start at 0, top out at 1, move in 0.01
# steps and show the value with two decimals.
_SLIDER_COMMON = dict(
    value=0.0,
    max=1.0,
    step=0.01,
    readout=True,
    readout_format=".2f",
)
# Facial-feature sliders run over [0, 1]; rotation sliders over [-1, 1].
_FACIAL_SLIDER = dict(min=0.0, **_SLIDER_COMMON)
_ROTATION_SLIDER = dict(min=-1.0, **_SLIDER_COMMON)

# Facial features.
eye_left_slider = ipywidgets.FloatSlider(description="Left Eye:", **_FACIAL_SLIDER)
eye_right_slider = ipywidgets.FloatSlider(description="Right Eye:", **_FACIAL_SLIDER)
mouth_slider = ipywidgets.FloatSlider(description="Mouth:", **_FACIAL_SLIDER)

# Head / neck rotation.
head_x_slider = ipywidgets.FloatSlider(description="X-axis:", **_ROTATION_SLIDER)
head_y_slider = ipywidgets.FloatSlider(description="Y-axis:", **_ROTATION_SLIDER)
neck_z_slider = ipywidgets.FloatSlider(description="Z-axis:", **_ROTATION_SLIDER)


# Slider column: rotation sliders, a horizontal rule, then the facial sliders.
_rotation_section = [
    ipywidgets.HTML(value="<center><b>Head Rotation</b></center>"),
    head_x_slider,
    head_y_slider,
    neck_z_slider,
]
_section_divider = ipywidgets.HTML(value="<hr>")
_facial_section = [
    ipywidgets.HTML(value="<center><b>Facial Features</b></center>"),
    eye_left_slider,
    eye_right_slider,
    mouth_slider,
]
control_panel = ipywidgets.VBox(
    [*_rotation_section, _section_divider, *_facial_section]
)

# Whole GUI, left to right: input preview with its upload button underneath,
# the slider column, a small horizontal gap, and the output preview.
_input_column = ipywidgets.VBox([input_image_widget, upload_input_image_button])
_column_gap = ipywidgets.HTML(value="&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;")
controls = ipywidgets.HBox(
    [_input_column, control_panel, _column_gap, output_image_widget]
)

# Number of pose parameters; get_pose() fills exactly this many entries, one
# per slider.
pose_size = 6

# The poser is assembled from three modules (morph, rotate, combine), each
# given a spec object and the file downloaded into data/ for it.
poser = MorphRotateCombinePoser256Param6(
    morph_module_spec=FaceMorpherSpec(),
    morph_module_file_name="data/face_morpher.pt",
    rotate_module_spec=TwoAlgoFaceRotatorSpec(),
    rotate_module_file_name="data/two_algo_face_rotator.pt",
    combine_module_spec=CombinerSpec(),
    combine_module_file_name="data/combiner.pt",
    device=device,
)

# Pose used for the most recent render; starts as all zeros, matching the
# sliders' initial values.
last_pose = torch.zeros(1, pose_size, device=device)

def get_pose():
    """Read the six sliders into a ``(1, pose_size)`` pose tensor on ``device``.

    Entry order is: head X, head Y, neck Z, left eye, right eye, mouth.

    Returns:
        A tensor of shape ``(1, pose_size)`` located on ``device``.
    """
    # Position in this tuple is the index the slider's value is written to.
    ordered_sliders = (
        head_x_slider,
        head_y_slider,
        neck_z_slider,
        eye_left_slider,
        eye_right_slider,
        mouth_slider,
    )

    pose = torch.zeros(1, pose_size)
    for index, slider in enumerate(ordered_sliders):
        pose[0, index] = slider.value

    return pose.to(device)

# Show the assembled GUI in this cell's output area. `display` is not imported
# here; it is the name IPython makes available inside notebook cells.
display(controls)

def update(change):
    """Redraw the posed output image if the input image or the pose changed.

    Used as the observer callback for every pose slider and also called
    directly by ``upload_image``. Does nothing while no valid input image is
    loaded, or when neither the image nor the pose differs from what was used
    for the previous render.

    Args:
        change: The change notification passed by the widget (``None`` when
            called directly). It is not inspected; the current slider values
            and the module-level ``torch_input_image`` are read instead.

    Side effects:
        Replaces the contents of ``output_image_widget`` and updates the
        module-level ``last_torch_input_image`` and ``last_pose``.
    """
    global last_pose
    global last_torch_input_image

    if torch_input_image is None:
        return

    # The image only needs an element-wise comparison when it is a different
    # tensor object from the one rendered last time. While the user is just
    # dragging sliders it is the very same object, so the full-tensor diff
    # (and the device sync forced by .item()) is skipped.
    if last_torch_input_image is None:
        image_changed = True
    elif torch_input_image is last_torch_input_image:
        image_changed = False
    else:
        image_changed = (
            (torch_input_image - last_torch_input_image).abs().max().item() > 0
        )

    pose = get_pose()
    pose_changed = (pose - last_pose).abs().max().item() > 0

    if not (image_changed or pose_changed):
        return

    # The result is only displayed (show_pytorch_image detaches it), so run
    # the forward pass without recording anything for autograd.
    with torch.no_grad():
        # [0] takes the first entry of the poser's result.
        # NOTE(review): presumably the single image of the size-1 batch -- confirm.
        output_image = poser.pose(torch_input_image, pose)[0]

    with output_image_widget:
        # wait=True defers the clear until the new image arrives.
        output_image_widget.clear_output(wait=True)
        show_pytorch_image(output_image, output_image_widget)

    last_torch_input_image = torch_input_image
    last_pose = pose
        
def upload_image(change):
    """Load the uploaded file, validate it, preview it and refresh the output.

    Observer callback for the upload button's ``value``. The uploaded bytes
    are converted to a tensor on ``device`` and given a leading batch
    dimension. The image is rejected (and ``torch_input_image`` reset to
    ``None``) when it is not 256x256 or does not have 4 channels; the reason
    is shown in ``input_image_widget``. If both checks fail, the channel
    message is the one left visible because it is displayed last. A valid
    image is previewed in ``input_image_widget``. ``update`` is then called
    to redraw the output; it returns immediately if the image was rejected.

    Args:
        change: The change notification from the widget; not inspected.
    """
    global torch_input_image

    # Convert every uploaded entry; the last one processed is kept. The
    # button is created with multiple=False.
    for file_info in upload_input_image_button.value.values():
        png_stream = io.BytesIO(file_info['content'])
        loaded = extract_pytorch_image_from_filelike(png_stream).to(device)
        torch_input_image = loaded.unsqueeze(0)

    # Nothing uploaded now and nothing valid left over from earlier.
    if torch_input_image is None:
        return

    def report_problem(message):
        # Replace whatever the input pane currently shows with the message.
        with input_image_widget:
            input_image_widget.clear_output(wait=True)
            display(ipywidgets.HTML(message))

    _, channels, height, width = torch_input_image.shape

    if height != 256 or width != 256:
        report_problem("Image must be 256x256 in size!!!")
        torch_input_image = None
    if channels != 4:
        report_problem("Image must have an alpha channel!!!")
        torch_input_image = None

    if torch_input_image is not None:
        with input_image_widget:
            input_image_widget.clear_output(wait=True)
            show_pytorch_image(torch_input_image[0], input_image_widget)

    update(None)
        
# A new upload goes through validation and preview first.
upload_input_image_button.observe(upload_image, names='value')

# Any slider movement re-renders the output. The registration order is the
# same as listing the sliders one by one.
for _pose_slider in (
    eye_left_slider,
    eye_right_slider,
    mouth_slider,
    head_x_slider,
    head_y_slider,
    neck_z_slider,
):
    _pose_slider.observe(update, 'value')