# PortraitStylization - Stylizing Human Portrait Images with Auxiliary Models.

Welcome to the PortraitStylization demo notebook. Here you will find a simple way to interact with the project API.

## Install Dependencies and Load Modules

Run the cells in this section first to install the required dependencies.

In [None]:
# Fetch the project sources and make the repository the working directory, so
# the relative paths used by later cells (requirements.txt, ./weights/...) resolve.
!git clone https://github.com/thiagoambiel/PortraitStylization.git
%cd PortraitStylization/

In [None]:
!pip install -r requirements.txt

In [None]:
%reload_ext autoreload
%autoreload

import io

import torch
import numpy as np

from PIL import Image, ImageColor
from matplotlib import pyplot as plt

from ipywidgets import widgets, interact
from IPython.core.display import display, HTML

from style_transfer import StyleTransfer
from remove_bg import BackgroundRemoval


class ImageUploader:
  """File-upload widget that collects the uploaded images and previews them.

  Every uploaded file is opened with PIL and appended to `self.data`; the
  collected images are then plotted inside `self.output`.

  Args:
    only_one: When True the widget accepts a single file per upload and any
      upload made after `self.data` already holds an image is rejected.
  """

  def __init__(self, only_one: bool = False):
    self.only_one = only_one
    self.data = []

    self.output = widgets.Output()
    self.uploader = widgets.FileUpload(multiple=not self.only_one)

  @staticmethod
  def _uploaded_contents(value):
    """Return the raw content of every file held by a `FileUpload` value.

    ipywidgets 7 exposes the value as a `{name: file_info}` dict while
    ipywidgets 8 exposes a tuple of `file_info` dicts; in both layouts the
    file bytes are stored under the 'content' key.
    """
    entries = value.values() if isinstance(value, dict) else value
    return [entry['content'] for entry in entries]

  def plot_images(self):
    """Show every collected image side by side in a single row."""
    count = len(self.data)
    if count == 0:
      # plt.subplots cannot build a grid with zero columns.
      return

    # squeeze=False always yields a 2-D axes array, so a single image and
    # several images share the same code path. The figure is one row high,
    # so only its width scales with the number of images.
    _, axarr = plt.subplots(1, count, figsize=(4 * count, 4), squeeze=False)
    for ax, img in zip(axarr[0], self.data):
      ax.imshow(img)

    plt.show()

  def save(self, _):
    """Observer callback: store the files currently held by the uploader.

    Raises:
      ValueError: If `only_one` is set and an image was already stored. The
        exception is raised inside the `self.output` context.
    """
    with self.output:

      if self.only_one and len(self.data) > 0:
        raise ValueError("Only one image can be uploaded in this field.")

      contents = self._uploaded_contents(self.uploader.value)
      if not contents:
        # Nothing was uploaded (e.g. the widget value was cleared).
        return

      for content in contents:
        self.data.append(Image.open(io.BytesIO(content)))

      self.output.clear_output(wait=True)

      print("Image Uploaded!")
      self.plot_images()

  def run(self):
    """Display the widgets and start listening for uploads."""
    display(self.output, self.uploader)

    # Keep observing `_counter` where the widget has that trait (same
    # behaviour as before); otherwise fall back to the `value` trait so the
    # callback still fires on widget versions without `_counter`.
    trigger = '_counter' if self.uploader.has_trait('_counter') else 'value'
    self.uploader.observe(self.save, names=trigger)

# Let's Set Up Our Model

Upload your Content and Style Images, Configure the Model Parameters and Run.

In [None]:
#@title Upload Content Image {display-mode: "form"}
#@markdown Upload a Content Image to be Stylized.

# `only_one=True` makes the uploader reject a second image; the uploaded image
# is stored in `content_uploader.data`, which later cells read.
content_uploader = ImageUploader(only_one=True)
# Display the upload widget and start listening for uploads.
content_uploader.run()

In [None]:
#@title Upload Style Images {display-mode: "form"}
#@markdown Upload the style images that will be used to stylize the content.
#@markdown You can upload more than one image.

# Default uploader (multiple files allowed); the uploaded images are stored in
# `styles_uploader.data`, which later cells read.
styles_uploader = ImageUploader()
# Display the upload widget and start listening for uploads.
styles_uploader.run()

In [None]:
#@title Preprocess Content and Style Data { run: "auto" }
#@markdown For some images, it is recommended to remove the background for better results. use 'MODNet' to load the background removal tool or 'Raw_Content' to use the original image as input.

Mode = "MODNet" #@param ['MODNet', 'Raw_Content']

# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Fail with an actionable message instead of a bare IndexError when this cell
# runs before a content image was uploaded.
if not content_uploader.data:
  raise ValueError("No content image uploaded. Run the 'Upload Content Image' cell and upload an image first.")

original_image = content_uploader.data[0]
style_images = styles_uploader.data

def render(bgcolor, fgcolor, fg_fac, bt_fac, upload):
    """Recompose the content image from the widget values and preview it.

    Reads the notebook globals `background_removal`, `original_image` and
    `alpha`, plots the original image, `alpha` and the composed result side by
    side, and leaves the composed result as the only element of the global
    list `result_data`.

    Args:
        bgcolor: Forwarded to `remove_background` as `bg_color`.
        fgcolor: Forwarded to `remove_background` as `fg_color`.
        fg_fac: Forwarded to `remove_background` as `fg_fac`.
        bt_fac: Forwarded to `remove_background` as `bt_fac`.
        upload: Value of the texture `FileUpload` widget. Either a
            `{name: file_info}` dict (ipywidgets 7) or a sequence of
            `file_info` dicts carrying a 'name' key (ipywidgets 8). Only the
            first file is used; an empty value means "no texture".
    """
    texture_img = None
    if upload:
        if isinstance(upload, dict):
            name, file_info = next(iter(upload.items()))
        else:
            file_info = upload[0]
            name = file_info['name']

        texture_img = Image.open(io.BytesIO(file_info['content']))
        print(f"Texture Loaded: {name}")

    result = background_removal.remove_background(
        img=original_image,
        alpha=alpha,
        bg_color=bgcolor,
        bg_texture=texture_img,
        bt_fac=bt_fac,
        fg_color=fgcolor,
        fg_fac=fg_fac
    )

    fig, axarr = plt.subplots(1, 3, figsize=(9, 3))

    axarr[0].imshow(original_image)
    axarr[1].imshow(alpha)
    axarr[2].imshow(result)

    # Keep only the latest composition; a later cell reads `result_data[0]`.
    result_data.clear()
    result_data.insert(0, result)

    fig.tight_layout()


if Mode == "MODNet":

  # Build the background-removal helper and compute `alpha` once for the
  # uploaded image; `render` plots `alpha` and passes it to `remove_background`.
  background_removal = BackgroundRemoval(weights_path="./weights/modnet.pth", device=device)
  alpha = background_removal.gen_alpha(np.array(original_image))

  # Controls whose values `interact` forwards to `render` below.
  bgcolor_picker = widgets.ColorPicker(description='Back', value='#4911e4')
  fgcolor_picker = widgets.ColorPicker(description='Fore', value='#ffffff')
  fg_fac_slider = widgets.FloatSlider(description='ForeFac', value=0.0, min=0, max=1.0, step=0.05)
  bt_fac_slider = widgets.FloatSlider(description='TextureFac', value=1.0, min=0, max=1.0, step=0.05)

  # Tag the upload button with a CSS class and inject the matching style rule.
  texture_uploader = widgets.FileUpload(description="Load Texture")
  texture_uploader.add_class("left-spacing-class")
  display(HTML(
     """<style>.left-spacing-class {
       margin-left: 120px;
       margin-bottom: 20px;
       margin-top: 10px;
       }
     </style>"""
  ))

  # Defined before `interact(...)` because `render` clears and refills this
  # list; the next cell reads `result_data[0]`.
  result_data = []

  interact(render, 
    bgcolor=bgcolor_picker, 
    fgcolor=fgcolor_picker, 
    fg_fac=fg_fac_slider, 
    bt_fac=bt_fac_slider, 
    upload=texture_uploader
  )

In [None]:
#@title Set Algorithm Parameters and Run

# In MODNet mode use the composition stored by `render`; otherwise use the
# uploaded image unchanged.
content_image = result_data[0] if Mode == "MODNet" else original_image

#@markdown Adjust Weights for Better Results.
content_weight = 0.015 #@param {type: "number"}
face_weight = 0.015 #@param {type: "number"}
mesh_weight =  0#@param {type: "number"}

#@markdown Select the Initial and Final Resolution of your Stylized Image.
min_scale = 128 #@param {type: "number"}
end_scale =  1448#@param {type: "number"}

#@markdown Set the Number of Iterations per Scale.
iterations =  1000#@param {type: "number"}
initial_iterations = 1000 #@param {type: "number"}

#@markdown The Relative Scale of the Style to the Content.
style_scale_fac = 1 #@param {type: "slider", min: 0.0, max: 2.0, step: 0.1}

#@markdown Enable Face Cropping when passing data to FaceNet and FaceMesh.
crop_faces = False #@param {type:"boolean"}
square_faces = False #@param {type:"boolean"}

#@markdown Add Padding to the Detected Faces Bounding Boxes (Need crop_faces to be Enabled).
padding_scale_fac = 0 #@param {type: "slider", min: 0.0, max: 1.0, step: 0.1}

#@markdown Set VGG Layers Pooling type.
pooling = "max" #@param ['max', 'average', 'l2']

# `device` comes from the preprocessing cell above.
st = StyleTransfer(device=device, pooling=pooling)

# Forward every form value above to the stylization call.
st.stylize(
  content_image=content_image, style_images=style_images,

  content_weight=content_weight,
  face_weight=face_weight,
  mesh_weight=mesh_weight,

  min_scale=min_scale,
  end_scale=end_scale,

  iterations=iterations,
  initial_iterations=initial_iterations,

  style_scale_fac=style_scale_fac,

  crop_faces=crop_faces,
  square_faces=square_faces,

  padding_scale_fac=padding_scale_fac,

  plot_progress=True,
  # NOTE(review): absolute path, presumably the Colab working directory -
  # confirm it exists before running this notebook elsewhere.
  save_path="/content/result.png"
)