<a href="https://colab.research.google.com/github/simonsny/image_background_removal/blob/main/FlaskApp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Background Removal with Flask Application 

![TeamTheano](https://img.shields.io/badge/ProjectBy:-Simon,Gulce,Louan,Arlene-blue.svg)

<p> To use this notebook, please run all of its cells in order. </p>


First, you need to connect your Google Colab to your Google Drive. 


In [None]:
# Mount Google Drive at /content/drive; every path used later in this notebook
# lives under /content/drive/MyDrive/.
from google.colab import drive
drive.mount('/content/drive')

Next, install Flask-ngrok so that the Flask app can be run from inside Google Colab.


In [None]:

# Install flask-ngrok; it provides the run_with_ngrok helper imported by the Flask cell below.
!pip install flask-ngrok 


After that, we import all the necessary libraries for this project.

In [None]:
import os
import cv2
import numpy as np
import torch
from torchvision import transforms, utils
from skimage import io, transform, color
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import glob
from __future__ import print_function, division
import random
import math
import matplotlib.pyplot as plt

%cd /content/drive/MyDrive/BeCode/Faktion/U-2-Net/
from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET 

Then get the binary mask from the [U2Net model](https://github.com/xuebinqin/U-2-Net).

In [None]:
def normPRED(d):
    """
    Function that will min-max normalize the predicted SOD probability map to [0, 1]
    :attrib d will contain the prediction (a torch tensor)
    :attrib dn will contain the normalized prediction
    This function will return dn. A constant map (max == min) has no range to
    stretch, so it is returned as all zeros instead of the NaNs that a 0/0
    division would produce.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    if span == 0:
        # Guard against division by zero on a flat prediction.
        return torch.zeros_like(d)
    dn = (d - mi) / span
    return dn

def save_output(image_name,predict,d_dir):
    """
    Function that will save the predicted mask as a PNG sized like the source image
    :attrib image_name will be the path of the source image; its base name
            (without the final extension) names the output file
    :attrib predict will be the predicted output (a torch tensor)
    :attrib d_dir will be the directory of the saved output/result
    :attrib predict_np will contain the prediction as a numpy array
    :attrib im will contain the prediction as an RGB PIL image
    :attrib image will contain the source image read with io.imread (used only for its size)
    :attrib imo will contain the prediction resized to the source image size
    :attrib imidx will be the output file name without extension
    This function will save the prediction in the d_dir with a .png extension
    """
    predict_np = predict.squeeze().cpu().data.numpy()
    im = Image.fromarray(predict_np*255).convert('RGB')
    image = io.imread(image_name)
    # PIL expects (width, height); the array shape is (height, width, ...).
    imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
    # splitext keeps inner dots ("a.b.jpg" -> "a.b") and, unlike manual
    # splitting on ".", also copes with names that have no extension at all.
    imidx = os.path.splitext(os.path.basename(image_name))[0]
    imo.save(os.path.join(d_dir, imidx+'.png'))

def u2net(filename):
    """
    Function that will get the saliency mask from U2Net, which is then used as a binary mask
    :attrib filename contains the image file name (only its base name is used;
            the image is read from the uploads folder)
    :attrib upload_folder is the folder that holds the uploaded images
    :attrib prediction_dir is the folder where the mask is written
    :attrib model_dir is the path of the pretrained U2Net weights
    This function will load the U2Net model, save the mask in prediction_dir and
    return the mask file name (the image base name with a .png extension), so it
    matches the file written by save_output.
    """
    upload_folder = '/content/drive/MyDrive/BeCode/Faktion/Flask/static/uploads/'
    prediction_dir = '/content/drive/MyDrive/BeCode/Faktion/Flask/static/binary_mask/'
    model_dir =  '/content/drive/MyDrive/BeCode/Faktion/U-2-Net/saved_models/u2net/u2net.pth'
    img_name = os.path.basename(filename)
    # The dataset gets the full path, so the process working directory does not
    # have to be changed on every call.
    # NOTE(review): assumes SalObjDataset reads each entry as a file path — confirm.
    img_name_list = [os.path.join(upload_folder, img_name)]
    print(img_name_list)
    test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
                                        lbl_name_list = [],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)
    net = U2NET(3,1)
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()
    os.makedirs(prediction_dir, exist_ok=True)
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):
            print("inferencing:",os.path.basename(img_name_list[i_test]))
            inputs_test = data_test['image'].type(torch.FloatTensor)
            if use_cuda:
                inputs_test = inputs_test.cuda()
            d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
            # Only the first output and its first channel are used.
            pred = normPRED(d1[:,0,:,:])
            save_output(img_name_list[i_test],pred,prediction_dir)
            del d1,d2,d3,d4,d5,d6,d7
    print("Successfully created the binary mask")
    # Strip only the final extension so "my.photo.jpg" maps to "my.photo.png".
    return os.path.splitext(img_name)[0] + ".png"

Now that we have the binary mask from U2Net, we need a trimap. The function below generates the trimap from the binary mask.

In [None]:
def generate_trimap(mask_path, open_size=10, mask_margin=10,
                    trimap_folder='/content/drive/MyDrive/BeCode/Faktion/Flask/static/trimap/'):
    """
    Function that will generate the trimap from a binary mask
    :attrib mask_path is the path of the created binary mask
    :attrib open_size is the diameter of the elliptical kernel used to dilate the unknown region
    :attrib mask_margin is the tolerance for foreground/background: pixels above
            255 - mask_margin are foreground, pixels below mask_margin are background
    :attrib trimap_folder is the folder where the trimap is written
    :attrib mask will contain the mask read as a grayscale image
    :attrib foreground will contain the predicted foreground region
    :attrib background will contain the predicted background region
    :attrib unknown will contain the (dilated) unknown region
    :attrib trimap will contain the trimap: 255 foreground, 127 unknown, 0 background
    This function will save the trimap in trimap_folder under the mask's file
    name and return that file name. Raises FileNotFoundError if the mask
    cannot be read.
    """
    mask = cv2.imread(mask_path, 0)
    if mask is None:
        raise FileNotFoundError(f"Could not read the binary mask at {mask_path}")
    foreground = mask > (255 - mask_margin)
    background = mask < mask_margin
    unknown = ~(foreground | background)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (open_size, open_size))
    # Use the builtin bool: the np.bool alias no longer exists in current NumPy.
    unknown = cv2.dilate(unknown.astype(np.uint8), kernel).astype(bool)
    trimap = np.zeros_like(mask)
    trimap[foreground] = 255
    # Unknown is assigned last so the dilated band overrides foreground pixels.
    trimap[unknown] = 127
    trimap_name = os.path.basename(mask_path)
    os.makedirs(trimap_folder, exist_ok=True)
    cv2.imwrite(os.path.join(trimap_folder, trimap_name), trimap)
    return trimap_name

Since we now have the binary mask and the trimap, we can now generate the alpha matte using the [DIM pretrained model](https://github.com/foamliu/Deep-Image-Matting-PyTorch).

In [None]:
def alpha_matte(file):
    """
    Function that will generate the alpha matte from the original image and its trimap
    :attrib file will contain the file name of the original image inside IMG_FOLDER
    :attrib device will contain the torch.device
    :attrib IMG_FOLDER will be the uploaded image folder
    :attrib TRIMAP_FOLDER will be the trimap folder
    :attrib matte_folder will be the matte folder
    :attrib checkpoint will be the loaded DIM pretrained model best checkpoint
    :attrib model will contain the model
    :attrib matte_name will be the file name with its extension replaced by .png;
            it names both the trimap that is read and the matte that is written
    This function will save the alpha matte in the matte folder and return its
    path. Raises FileNotFoundError if the image or the trimap cannot be read.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # sets device for model and PyTorch tensors
    # NOTE(review): presumably the working directory must be model_code so that
    # the modules defining the pickled model can be imported while loading the
    # checkpoint — confirm before removing.
    os.chdir('/content/drive/MyDrive/BeCode/Faktion/model_code/')
    IMG_FOLDER = '/content/drive/MyDrive/BeCode/Faktion/Flask/static/uploads/'
    TRIMAP_FOLDER = '/content/drive/MyDrive/BeCode/Faktion/Flask/static/trimap/'
    matte_folder = '/content/drive/MyDrive/BeCode/Faktion/Flask/static/matte/'
    checkpoint_path = '/content/drive/MyDrive/BeCode/Faktion/BEST_checkpoint.tar'
    # SECURITY: torch.load unpickles arbitrary objects; only load checkpoints
    # from a trusted source. map_location lets a checkpoint saved on a GPU be
    # loaded on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = checkpoint['model'].module
    model = model.to(device)
    model.eval()

    # Replace whatever extension the upload has (.jpg, .jpeg, .gif, .JPG, ...)
    # instead of only the literal '.jpg'.
    matte_name = os.path.splitext(file)[0] + '.png'

    # Channels 0-2 of the network input: the normalized RGB image.
    filename = os.path.join(IMG_FOLDER, file)
    img = cv2.imread(filename)
    if img is None:
        raise FileNotFoundError(f"Could not read the image at {filename}")
    h, w = img.shape[:2]
    x = torch.zeros((1, 4, h, w), dtype=torch.float)
    image = img[..., ::-1]  # RGB
    image = transforms.ToPILImage()(image)
    transformer = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    image = transformer(image)
    x[0:, 0:3, :, :] = image

    # Channel 3 of the network input: the trimap scaled to [0, 1].
    trimap_path = os.path.join(TRIMAP_FOLDER, matte_name)
    trimap = cv2.imread(trimap_path, 0)
    if trimap is None:
        raise FileNotFoundError(f"Could not read the trimap at {trimap_path}")
    x[0:, 3, :, :] = torch.from_numpy(trimap.copy() / 255.)
    x = x.type(torch.FloatTensor).to(device)
    with torch.no_grad():
        pred = model(x)
    pred = pred.cpu().numpy().reshape((h, w))
    # Trust the trimap where it is certain: pure background and pure foreground.
    pred[trimap == 0] = 0.0
    pred[trimap == 255] = 1.0
    out = (pred * 255).astype(np.uint8)
    matte_path = os.path.join(matte_folder, matte_name)
    cv2.imwrite(matte_path, out)
    print(f'Created an alpha matte for  {file}.')
    return matte_path

Getting the foreground image using the original image and the alpha matte.


In [None]:
def foreground_img(image, alpha, filename):
    """
    Function that will remove the background from the image
    :attrib image will contain the path of the uploaded image
    :attrib alpha will contain the path of the alpha matte
    :attrib filename will be the filename of the saved image
    :attrib img will contain the read image
    :attrib matte will contain the read alpha matte scaled to [0, 1]
    :attrib foreground will be the matte multiplied by img
    This function will save the foreground in the module-level output_dir and
    return the filename. Raises FileNotFoundError if the image or the matte
    cannot be read.
    """
    img = cv2.imread(image)
    if img is None:
        raise FileNotFoundError(f"Could not read the image at {image}")
    matte = cv2.imread(alpha)
    if matte is None:
        raise FileNotFoundError(f"Could not read the alpha matte at {alpha}")
    img = img.astype(float)
    matte = matte.astype(float)/255
    foreground = cv2.multiply(matte, img)
    # Convert to 8-bit pixel data explicitly before handing it to the encoder.
    foreground = np.clip(np.rint(foreground), 0, 255).astype(np.uint8)
    cv2.imwrite(os.path.join(output_dir, filename), foreground)
    print(f"Successfully saved the foreground image for {filename}.")
    return filename

And finally, the Flask Application will start here.

In [None]:
import os
from flask_ngrok import run_with_ngrok
from flask import Flask, flash, render_template, request, redirect, url_for, send_file
import time
from werkzeug.utils import secure_filename
from google.colab.patches import cv2_imshow

# upload_dir: folder where uploaded images are stored
# output_dir: folder where the output/result images are stored
# matte_dir: folder where the alpha mattes are stored
# trimap_dir: folder where the trimaps are stored
# binary_dir: folder where the U2Net binary masks are stored
# allowed_extensions: set of accepted image file extensions
upload_dir = '/content/drive/MyDrive/BeCode/Faktion/Flask/static/uploads/'
output_dir = '/content/drive/MyDrive/BeCode/Faktion/Flask/static/output/'
matte_dir = '/content/drive/MyDrive/BeCode/Faktion/Flask/static/matte/'
trimap_dir = '/content/drive/MyDrive/BeCode/Faktion/Flask/static/trimap/'
binary_dir = '/content/drive/MyDrive/BeCode/Faktion/Flask/static/binary_mask/'
allowed_extensions = {"png", "jpg", "jpeg", "gif"}

app = Flask(__name__, template_folder='/content/drive/MyDrive/BeCode/Faktion/Flask/templates/', static_folder='/content/drive/MyDrive/BeCode/Faktion/Flask/static')
run_with_ngrok(app)
# The secret key signs the session cookie (used by flash()). A hard-coded,
# guessable literal lets anyone forge it, so take it from the environment or
# fall back to a random per-run key.
app.secret_key = os.environ.get("FLASK_SECRET_KEY") or os.urandom(24)
app.config["UPLOAD_FOLDER"] = upload_dir

def allowed_file(filename):
    """
    Tell whether the file name has an extension from the allowed list.
    The text after the last dot is compared case-insensitively against
    allowed_extensions; a name without any dot is rejected.
    """
    _, dot, extension = filename.rpartition(".")
    return bool(dot) and extension.lower() in allowed_extensions

@app.route("/")
def home():
    """
    Render the home page template for the site root.
    """
    landing_page = render_template("home.html")
    return landing_page

@app.route("/start", methods=["GET", "POST"])
def start():
    """
    Function that has both GET and POST method.
    GET (and any other non-POST method) returns the upload form.
    POST takes the user input, runs the background-removal pipeline on an
    uploaded image and returns the page showing the input with its edited version.
    :attrib image will contain the image user input
    :attrib video will contain the video user input
    :attrib started_at will contain the time at which processing started
    :attrib image_upload will contain the secured file name of the uploaded image;
            every pipeline step uses this name, because it is the name the file
            is actually saved under
    :attrib image_path will contain the path of the saved upload
    :attrib binary_mask will contain the mask file name returned by u2net
    :attrib matte_path will contain the alpha matte path returned by alpha_matte
    """
    if request.method != "POST":
        return render_template("start.html")

    if "image" not in request.files and "video" not in request.files:
        return redirect(request.url)

    image = request.files.get("image")
    print(f"Image Input: {image}")
    video = request.files.get("video")
    print(f"Video Input: {video}")
    started_at = time.time()

    if image:
        # Sanitize first, then validate the sanitized name: it is the one that
        # will be written to disk and read back by every later step.
        image_upload = secure_filename(image.filename)
        if not allowed_file(image_upload):
            flash("Chosen file is not supported! Please upload an image file.")
            flash("Allowed image types are -> png, jpg, jpeg, gif")
            return redirect(request.url)

        image_path = os.path.join(app.config["UPLOAD_FOLDER"], image_upload)
        image.save(image_path)
        binary_mask = u2net(image_upload)
        generate_trimap(binary_dir + binary_mask)
        matte_path = alpha_matte(image_upload)
        foreground_img(image_path, matte_path, image_upload)
        print(f"Program runs for {time.time() - started_at} seconds.")
        return render_template("upload_image.html", filename=image_upload)

    if video:
        # Video input is accepted but not processed yet.
        print("Elif video file")
        print(f"Program runs for {time.time() - started_at} seconds.")
        return render_template("home.html")

    # A file field was submitted but no file was selected.
    return render_template("start.html")

@app.route("/display/<filename>")
def display_image(filename):
    """
    Permanently redirect (301) to the static URL of the uploaded image.
    """
    print(f"Display image : {filename}")
    static_url = url_for("static", filename=f"uploads/{filename}")
    return redirect(static_url, code=301)
    
@app.route("/edited/<filename>")
def edited_image(filename):
    """
    Permanently redirect (301) to the static URL of the edited/result image.
    """
    print(f"Edited image : {filename}")
    static_url = url_for("static", filename=f"output/{filename}")
    return redirect(static_url, code=301)

@app.route("/save_file/<filename>")
def save_file(filename):
    """
    Function that allows the user to save/download the result image.
    :attrib filename is the name of a file inside output_dir, taken from the URL
    The file is served with send_from_directory, which refuses paths that
    resolve outside output_dir and answers 404 when the file does not exist.
    """
    from flask import send_from_directory

    print(f"Send File Path: {os.path.join(output_dir, filename)}")
    return send_from_directory(output_dir, filename, as_attachment=True)

@app.route("/upload_image")
def upload_image():
    """
    Render the upload_image page template.
    """
    page = render_template("upload_image.html")
    return page

@app.route("/howtouse")
def howtouse():
    """
    Render the "how to use" page template.
    """
    page = render_template("howtouse.html")
    return page

@app.route("/about")
def about():
    """
    Render the about page template.
    """
    page = render_template("about.html")
    return page

@app.route("/live_feed")
def live_feed():
    """
    Render the live feed page template.
    """
    page = render_template("live_feed.html")
    return page

# Start the Flask server only when this code is executed directly, not when it is imported.
if __name__ == "__main__":
    app.run()