<a href="https://colab.research.google.com/github/xxxrokxxx/GDL/blob/master/Copy_of_PULSE_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Enter the number of images you want to generate, press the run button on the left and then while it is spinning, click on "Choose Files" to upload your own pictures
# How many outputs to generate for every detected face; each one takes a grid
# column after the two leading columns (low-res input and the PULSE banner).
number_of_images = 5#@param {type: "number"}
# When True, the grid is redrawn with in-progress results while the model runs.
show_intermediate_images = True#@param {type: "boolean"}
# Stored as a string; it is converted with int() when BicubicDownSample is built.
downsampling_factor = "32"  #@param [32, 64]

#@markdown <font color='red'>Many users have reported receiving an error message saying "Google Drive Quota Exceeded." We are currently using Google Drive to store our model weights and it has a daily cap on downloads. If you are experiencing this error please try again later in the day or come back tomorrow. We apologize for the inconvenience.</font>
#@markdown ### NOTE: PULSE only attempts to match the downscaled version of the image, and the output will likely not resemble the high resolution input image. ###

#@markdown We suggest running this demo in Google Chrome. Using the show_intermediate_images option may slow down performance. PULSE works best on images where people are directly facing the camera. No data is stored. See: https://github.com/adamian98/pulse for additional information.

from pathlib import Path

if not Path("PULSE.py").exists():
  if Path("pulse").exists():
    %cd pulse
  else:
    ! git clone https://github.com/adamian98/pulse
    %cd pulse

from google.colab import files
from io import BytesIO
from matplotlib import pyplot as plt
from PIL import Image
from PULSE import PULSE
import torchvision
from IPython import display
import numpy as np
from shape_predictor import align_face
from drive import open_url
import dlib
from mpl_toolkits.axes_grid1 import ImageGrid
from bicubic import BicubicDownSample

# Clear earlier cell output before the upload prompt is shown.
display.clear_output(wait=True)

# files.upload() returns a mapping keyed by the uploaded file names.
uploaded_names = files.upload().keys()

if len(uploaded_names) == 0:
  raise Exception("You need to upload at least one image.")

# Obtain a local path to the dlib shape-predictor file (using the "cache"
# directory) and load it.
f = open_url(
    "https://drive.google.com/uc?id=1huhv8PYpNNKbGCLOaYUjOgR1pY5pmbJx",
    cache_dir="cache",
    return_path=True,
)
predictor = dlib.shape_predictor(f)

# Conversions between PIL images and tensors, plus the bicubic downsampler
# configured from the form's downsampling factor.
toPIL = torchvision.transforms.ToPILImage()
toTensor = torchvision.transforms.ToTensor()
D = BicubicDownSample(factor=int(downsampling_factor))

# imagesHR keeps every aligned face; images keeps its downsampled counterpart
# at the same index.
imagesHR = []
images = []
for uploaded_file in uploaded_names:
  for aligned_face in align_face(uploaded_file, predictor):
    imagesHR.append(aligned_face)
    # Downsample on the GPU with a leading batch axis, then bring the result
    # back to the CPU, clamp to [0, 1] and drop the batch axis.
    batch = toTensor(aligned_face).unsqueeze(0).cuda()
    shrunk = D(batch).cpu().detach().clamp(0, 1)
    images.append(toPIL(shrunk[0]))

if len(images) == 0:
  raise Exception("No faces found. Try again with a different image.")

# Instantiate the PULSE model with console output disabled. It is given the
# same "cache" directory name that open_url uses above — presumably where its
# downloaded files are kept (TODO confirm against PULSE.py).
model = PULSE(cache_dir="cache", verbose=False)

# Keyword arguments forwarded unchanged to every model(...) call below.
# 'steps' is also read by the display loop to recognise the final iteration.
kwargs = dict(
    loss_str='100*L2+0.05*GEOCROSS',
    seed=None,
    eps=1e-3,
    noise_type='trainable',
    num_trainable_noise_layers=5,
    tile_latent=False,
    bad_noise_layers='17',
    opt_name='adam',
    learning_rate=0.4,
    steps=100,
    lr_schedule='linear1cycledrop',
    save_intermediate=True,
)
# Grid layout: one row per face; columns are the low-res input, the PULSE
# banner, and then one column per requested output.
dims = np.array((len(images), number_of_images + 2))
# NOTE(review): matplotlib's figsize is (width, height) while dims is
# (rows, cols) — confirm this ordering is intended.
fig = plt.figure(figsize=20 * dims)
axs = ImageGrid(fig, 111, nrows_ncols=dims, axes_pad=0.2)

im_downsample = Image.open("resources/downsample.png")
im_PULSE = Image.open("resources/PULSE.png")

display.clear_output(wait=True)

# Fill every cell with a white 1024x1024 placeholder and keep the image
# handles so results can be swapped in later via set_data().
image_list = []
for cell_axes in axs:
    placeholder = Image.new('RGB', (1024, 1024), (255, 255, 255))
    image_list.append(cell_axes.imshow(placeholder))
    cell_axes.axis('off')

# The first two cells of each row: the low-res input enlarged with
# nearest-neighbour resampling, followed by the PULSE banner image.
columns_per_row = dims[1]
for row, (_unused_hr, low_res) in enumerate(zip(imagesHR, images)):
    row_start = row * columns_per_row
    axs[row_start].imshow(low_res.resize((1024, 1024), Image.NEAREST))
    axs[row_start + 1].imshow(im_PULSE)

display.display(plt.gcf())
display.clear_output(wait=True)

# Run the model number_of_images times for every low-resolution face and
# stream the results into the pre-built image grid.
#
# Changes from the previous version:
#   * the progress label is computed from kwargs["steps"] instead of assuming
#     exactly 100 steps (k+1 was printed directly as a percentage);
#   * the final step is drawn once instead of twice when intermediate images
#     are enabled;
#   * the unused `curr_image` assignment (set_data() returns None) is gone;
#   * the existing `toTensor` transform is reused rather than rebuilt.
total_steps = kwargs["steps"]
label_font = {'family': 'serif', 'weight': 'normal', 'size': 12}

for i, PIL_im in enumerate(images):
    ref_im = toTensor(PIL_im).unsqueeze(0).cuda()
    for j in range(number_of_images):
        # Columns 0 and 1 of each row hold the input and the banner, so the
        # generated outputs start at column 2.
        cell = i * dims[1] + j + 2
        running_text = axs[cell].text(50, 50, "Running...",
                                      label_font,
                                      horizontalalignment='left',
                                      verticalalignment='top',
                                      bbox=dict(facecolor='white', alpha=1))

        display.display(plt.gcf())
        display.clear_output(wait=True)

        for k, (HR, _) in enumerate(model(ref_im, **kwargs)):
            step = k + 1
            is_final = step == total_steps
            # Previews are drawn on the first step and on every 10th step.
            is_preview = show_intermediate_images and (k == 0 or step % 10 == 0)
            if not (is_final or is_preview):
                continue

            PIL_out = toPIL(HR[0].cpu().detach().clamp(0, 1))
            image_list[cell].set_data(PIL_out)
            if is_preview:
                # Integer percentage of the configured step count, capped at
                # 100 in case the model yields more items than `steps`.
                percent = min(100, 100 * step // total_steps)
                running_text.set_text(f"Running ({percent}%)")
            display.display(plt.gcf())
            display.clear_output(wait=True)

        # Drop the progress label once this output is finished.
        running_text.remove()
        display.display(plt.gcf())
        display.clear_output(wait=True)