<a href="https://colab.research.google.com/github/tantowijh/imagerestoration/blob/main/GoogleColabAPI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title __1. Install all the Requirements__
#@markdown ⬅ Run this cell to install all the requirements.
!pip install diffusers transformers accelerate scipy safetensors &> /dev/null
!pip install pyngrok &> /dev/null

# Real-ESRGAN
!git clone https://github.com/xinntao/Real-ESRGAN.git &> /dev/null
%cd Real-ESRGAN
!pip install -r requirements.txt &> /dev/null
!python setup.py develop &> /dev/null
!sed -i 's/from torchvision.transforms.functional_tensor import rgb_to_grayscale/from torchvision.transforms.functional import rgb_to_grayscale/' /usr/local/lib/python3.10/dist-packages/basicsr/data/degradations.py &> /dev/null

import os
os.makedirs('weights', exist_ok=True)
!wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./weights

In [None]:
#@title __2. Import Libraries and Load Models__
#@markdown ⬅ Run this cell to import all the libraries and load the models.
import base64
import io
import os
import tempfile
from io import BytesIO

# Third-party imports kept in their original relative order.
import requests
from diffusers import AutoPipelineForInpainting
import torch

from basicsr.archs.rrdbnet_arch import RRDBNet
import cv2

from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from gfpgan import GFPGANer
import numpy as np
from PIL import Image

# Load the Real-ESRGAN model (RRDBNet configuration used with the RealESRGAN_x4plus weights)
netscale = 4
enhance_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=netscale)
model_path = os.path.join("weights", "RealESRGAN_x4plus.pth")
if not os.path.isfile(model_path):
    # "weights/..." is relative to the Real-ESRGAN checkout that the install cell `%cd`s into; if the
    # runtime was restarted the working directory may no longer be that checkout. RealESRGANer downloads
    # the weights itself when model_path is an https:// URL, so fall back to the release URL.
    model_path = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
upsampler = RealESRGANer(
            scale=netscale,
            model_path=model_path,
            model=enhance_model,
            tile=0,  # 0 = no tiling. NOTE(review): very large inputs may exhaust GPU memory; consider tiling.
            tile_pad=10,
            pre_pad=0,
            half=True)  # fp16 inference; assumes a CUDA GPU runtime

# Load Inpainting Model (SDXL inpainting checkpoint, fp16 variant, moved to the GPU)
inpaint_model = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
inpaint_pipe = AutoPipelineForInpainting.from_pretrained(inpaint_model, torch_dtype=torch.float16, variant="fp16")
inpaint_pipe = inpaint_pipe.to("cuda")

In [None]:
#@title __3. Initialize the Flask App__
#@markdown ⬅ Run this cell to initialize the Flask app, then run the following steps.

# Resize image to fit within a maximum resolution (downscale only)
def _fit_within_resolution(size, max_resolution):
    """
    Compute the largest (width, height) inside ``max_resolution`` with the aspect ratio of ``size``.

    :param size: Tuple (width, height) of the source image.
    :param max_resolution: Tuple (width, height) representing the maximum resolution.
    :return: ``size`` unchanged if it already fits (never upscales); otherwise the downscaled size,
             with the limiting side exactly on its bound and every side at least 1 pixel.
    :raises ValueError: If either bound in ``max_resolution`` is not positive.
    """
    width, height = size
    max_width, max_height = max_resolution
    if max_width <= 0 or max_height <= 0:
        raise ValueError(f"max_resolution must be positive, got {max_resolution!r}")

    if width <= max_width and height <= max_height:
        return width, height

    # Compare max_width/width against max_height/height by cross-multiplying integers, so the
    # limiting side is chosen exactly and lands on its bound (a float aspect ratio followed by
    # int() truncation can come out one pixel short).
    if max_width * height <= max_height * width:
        # Width is the limiting side.
        return max_width, max(1, round(height * max_width / width))
    # Height is the limiting side.
    return max(1, round(width * max_height / height)), max_height


def resize_image(image, max_resolution=(1920, 1080), resample=None):
    """
    Resize the image to fit within the max_resolution while maintaining the aspect ratio.

    Images are only ever shrunk. An image that already fits is returned as a copy without
    being resampled, so it is not degraded by a needless filter pass.

    :param image: PIL Image object.
    :param max_resolution: Tuple (width, height) representing the maximum resolution.
    :param resample: PIL resampling filter; defaults to ``Image.LANCZOS``.
    :return: Resized PIL Image object (always a new object).
    :raises ValueError: If either bound in ``max_resolution`` is not positive.
    """
    width, height = image.size
    new_size = _fit_within_resolution((width, height), max_resolution)
    if new_size == (width, height):
        return image.copy()
    if resample is None:
        resample = Image.LANCZOS
    return image.resize(new_size, resample)

from flask import Flask, request, send_file, jsonify

app = Flask(__name__)


def home():
    """Answer ``GET /`` with a fixed greeting (useful as a quick liveness check of the API)."""
    return 'Hello, World!'


app.add_url_rule('/', view_func=home, methods=['GET'])

@app.route('/enhance', methods=['POST'])
def enhance():
    """
    Upscale an uploaded image with Real-ESRGAN and return it as a PNG download.

    Request: multipart/form-data with the picture in the ``image`` file field. The picture is
    first shrunk (never enlarged) to fit within 1920x1080, then upscaled 4x.

    :return: ``enhanced_image.png`` on success; JSON ``{'error': ...}`` with status 400 when the
             ``image`` field is missing or cannot be decoded, or status 500 for any other failure.
    """
    if 'image' not in request.files:
        return jsonify({'error': "Missing required file field 'image'."}), 400

    try:
        # Load the image from the request (convert() forces decoding, surfacing corrupt data here)
        image = Image.open(BytesIO(request.files['image'].read())).convert('RGB')
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # Undecodable uploads are client errors; PIL's UnidentifiedImageError is an OSError subclass.
        return jsonify({'error': f"Invalid image: {e}"}), 400

    try:
        torch.cuda.empty_cache()

        # Shrink the image to fit within 1920x1080 while maintaining the aspect ratio
        image = resize_image(image, max_resolution=(1920, 1080))

        # RealESRGANer.enhance() follows the OpenCV convention: it expects a BGR array and returns
        # BGR. PIL/numpy arrays here are RGB, so convert on the way in and on the way out.
        img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        scale = 4  # output upscaling factor passed to Real-ESRGAN
        output_bgr, _ = upsampler.enhance(img_bgr, outscale=scale)
        output_image = Image.fromarray(cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB))

        # Save the upscaled image to an in-memory PNG buffer
        buffered = BytesIO()
        output_image.save(buffered, format="PNG")
        buffered.seek(0)

        return send_file(buffered, mimetype='image/png', as_attachment=True, download_name='enhanced_image.png')

    except Exception as e:
        print(f"Error occurred: {str(e)}")
        return jsonify({'error': str(e)}), 500

@app.route('/restore', methods=['POST'])
def restore():
    """
    Inpaint the masked region of an uploaded image with the SDXL inpainting pipeline.

    Request (multipart/form-data):
      - ``image`` (file): picture to restore.
      - ``mask`` (file): mask image, converted to grayscale before inpainting.
      - ``prompt`` (form field): text prompt for the inpainting model.
      - optional form fields: ``guidance_scale`` (float, default 8.0),
        ``num_inference_steps`` (int >= 1, default 20), ``strength`` (float in (0, 1], default 0.99)
        and ``seed`` (int, default 0).

    :return: ``restored_image.png`` (resized to the input image's size) on success; JSON
             ``{'error': ...}`` with status 400 for missing/invalid inputs, or 500 for any other failure.
    """
    # Required inputs: report every missing one at once as a client error.
    missing = [name for name in ('image', 'mask') if name not in request.files]
    if 'prompt' not in request.form:
        missing.append('prompt')
    if missing:
        return jsonify({'error': f"Missing required field(s): {', '.join(missing)}"}), 400
    prompt = request.form['prompt']

    # Optional numeric parameters; malformed values are client errors, not server errors.
    try:
        guidance_scale = float(request.form.get('guidance_scale', 8.0))
        num_inference_steps = int(request.form.get('num_inference_steps', 20))
        strength = float(request.form.get('strength', 0.99))
        seed = int(request.form.get('seed', 0))
    except ValueError as e:
        return jsonify({'error': f"Invalid numeric parameter: {e}"}), 400
    if num_inference_steps < 1:
        return jsonify({'error': 'num_inference_steps must be at least 1.'}), 400
    if not 0.0 < strength <= 1.0:  # also rejects NaN
        return jsonify({'error': 'strength must be in the range (0, 1].'}), 400

    try:
        # Load the image and mask from the request (convert() forces decoding)
        image = Image.open(BytesIO(request.files['image'].read())).convert("RGB")
        mask = Image.open(BytesIO(request.files['mask'].read())).convert("L")  # Convert mask to grayscale
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # PIL's UnidentifiedImageError is an OSError subclass.
        return jsonify({'error': f"Invalid image or mask: {e}"}), 400

    try:
        torch.cuda.empty_cache()
        generator = torch.Generator(device="cuda").manual_seed(seed)

        # Perform the inpainting using the Stable Diffusion model
        restored_image = inpaint_pipe(
            prompt=prompt,
            image=image,
            mask_image=mask,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            strength=strength,
            generator=generator,
        ).images[0]

        # The pipeline's output size may differ from the input's; map it back onto the original size
        restored_image = restored_image.resize(image.size, Image.LANCZOS)

        # Save the restored image to an in-memory PNG buffer
        buffered = BytesIO()
        restored_image.save(buffered, format="PNG")
        buffered.seek(0)

        return send_file(buffered, mimetype='image/png', as_attachment=True, download_name='restored_image.png')

    except Exception as e:
        print(f"Error occurred: {str(e)}")
        return jsonify({'error': str(e)}), 500

In [None]:
#@title __4. Run the Flask App on a Public URL__
#@markdown ⬅ Run this cell to run Flask, then retrieve the URL and enter it into the configuration page.

#@markdown __Choose a Service for the Public URL__

#@markdown We recommend using **ngrok** to obtain your public URL, as **Serveo** is not always available (it may be down).


from pyngrok import ngrok
import subprocess
import time
import re

local_port = 5000   # port the Flask server listens on
remote_port = 80    # port requested from serveo.net (used only by the Serveo branch)
service = "ngrok" # @param ["ngrok","serveo"]

# Stop any ngrok process pyngrok left running from a previous execution of this cell
ngrok.kill()

def run_app():
  """Serve the Flask app on all interfaces at ``local_port``; blocks until the server stops."""
  app.run(port=local_port, host='0.0.0.0')

#@markdown ---

#@markdown __Obtain Your ngrok Token__
#@markdown
#@markdown Please obtain your ngrok auth token from the official ngrok website at
#@markdown [ngrok authtoken](https://dashboard.ngrok.com/get-started/your-authtoken).
#@markdown > **Note:** Entering your ngrok token is mandatory only if you choose the ngrok service. If you select Serveo, you can skip this step.

#@markdown Paste the token you received below:
if service == "ngrok":
  authtoken = "" #@param {type:"string"}

  if not authtoken:
    print('Please enter your ngrok auth token and try again.')
  else:
    ngrok.set_auth_token(authtoken)
    ngrok_run = ngrok.connect(local_port).public_url
    print(f"Ngrok public URL: {ngrok_run}\n")
    run_app()

else:
  import threading

  serveo_timeout = 30  # seconds to wait for Serveo to print the public URL before giving up

  # Command to run in the background, using the defined port variables
  command = ["ssh", "-o", "StrictHostKeyChecking=no", "-R", f"{remote_port}:localhost:{local_port}", "serveo.net"]

  # Start the process. stderr is merged into stdout so a single reader drains both streams:
  # a PIPE that nobody reads eventually fills up and blocks ssh.
  process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

  url_match = None
  url_ready = threading.Event()

  def watch_serveo_output(stream):
    """Echo ssh output until a public URL appears, then keep draining it silently.

    Sets the global ``url_match`` to the first http(s) URL seen and signals ``url_ready``.
    ``url_ready`` is also signalled when ssh's output ends or reading fails. The pipe is read for
    as long as ssh runs: closing it or leaving it unread would make ssh's later writes fail or
    block once the pipe buffer is full, which can stall or drop the tunnel.
    """
    global url_match
    try:
      for line in stream:
        if url_ready.is_set():
          continue  # URL already found: discard further output but keep the pipe flowing
        print(line, end='')  # Print the output for debugging
        # Use a regular expression to find the URL
        match = re.search(r'(https?://\S+)', line)
        if match:
          url_match = match.group(0)
          url_ready.set()
    except Exception as e:
      print(f"Error while reading output: {e}")
    finally:
      url_ready.set()  # never leave the main thread waiting on a dead reader
      stream.close()

  threading.Thread(target=watch_serveo_output, args=(process.stdout,), daemon=True).start()

  # Wait for the URL instead of a fixed sleep; bounded so a hung connection cannot block forever.
  url_ready.wait(timeout=serveo_timeout)

  if url_match:
    print(f"Serveo public URL: {url_match}\n")
    try:
      run_app()
    finally:
      process.terminate()  # close the tunnel once the Flask server stops
  else:
    process.terminate()
    print('No URL! The service may be down.')