In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os

import torch
import torchvision
from torchvision import transforms

In [None]:
image_path = 'examples/images/sample.png'
image_name = os.path.basename(image_path)

# cv2.imread does not raise on a missing/unreadable file; it returns None,
# which would otherwise surface later as a confusing TypeError on indexing.
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
img = img[..., ::-1]  # OpenCV loads BGR; flip channels to RGB for torchvision/matplotlib

# Remember the original size so the predicted mask can be resized back to it.
h, w, _ = img.shape
# Network input size. NOTE: a fixed 320x320 resize ignores aspect ratio,
# so non-square images are stretched before segmentation.
img = cv2.resize(img, (320, 320))

In [None]:
# Prefer the first CUDA device when one is available; otherwise run on CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# DeepLabV3 segmentation model with a ResNet-101 backbone and pretrained weights,
# moved to `device` and switched to evaluation mode for inference.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour of `weights=...`.
model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True).to(device)
model.eval()

In [None]:
# ToTensor converts the HWC image into a CHW float tensor scaled to [0, 1];
# Normalize then applies the per-channel ImageNet mean/std the pretrained backbone expects.
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

input_tensor = preprocess(img)
# Prepend a batch dimension of size 1 and place the batch on the model's device.
input_batch = input_tensor[None].to(device)

In [None]:
# Run the segmentation network on the single-image batch; 'out' holds per-class
# scores, and argmax over dim 0 picks the most likely class for every pixel.
with torch.no_grad():
    output = model(input_batch)['out'][0]
output = output.argmax(0)
mask = output.byte().cpu().numpy()

# Resize the label map back to the original size. Nearest-neighbour is required
# for class ids: bilinear interpolation would blend neighbouring ids into unrelated ones.
mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
# NOTE(review): this upsamples the 320x320 network input, so the preview is blurrier
# than the source image; the compositing step re-reads the original from disk.
img = cv2.resize(img, (w, h))

# NOTE(review): the saved mask holds raw class indices (small integers), so the PNG looks almost black.
mask_dir = './examples/mask/'
os.makedirs(mask_dir, exist_ok=True)
# cv2.imwrite signals failure by returning False rather than raising.
if not cv2.imwrite(mask_dir + image_name, mask):
    raise OSError(f"Failed to write mask to {mask_dir + image_name}")

# Make gray the default colormap for single-channel images (same rcParam plt.gray() sets).
# plt.gray() itself is avoided: with no figure open it first creates an extra empty figure.
plt.rcParams['image.cmap'] = 'gray'
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(20, 20))
ax_img.imshow(img)
ax_mask.imshow(mask)
plt.show()

In [None]:
def binarize_image(image):
    """Map every strictly positive pixel of ``image`` to 255.

    Works on a copy, so the caller's array is left untouched. Pixels that are
    <= 0 keep their original value (for unsigned images they simply stay 0).

    Args:
        image: NumPy array of pixel values.

    Returns:
        A new array with the same shape and dtype as ``image``.
    """
    binarized = image.copy()
    is_positive = binarized > 0
    binarized[is_positive] = 255
    return binarized

In [None]:
def gen_trimap(mask, k_size=(5, 5), ite=1):
    """Build a matting trimap from a segmentation mask.

    Pixels still foreground after erosion are marked definite foreground (255),
    pixels outside the dilated mask are definite background (0), and the band
    in between is "unknown" (128).

    Args:
        mask: Segmentation/label image; any value > 0 counts as foreground
            (in this notebook it is the uint8 class-index map).
        k_size: Size of the rectangular (all-ones) structuring element.
        ite: Number of erosion/dilation iterations; larger values widen the
            unknown band.

    Returns:
        uint8 array with the same shape as ``mask`` containing only 0, 128 and 255.
    """
    kernel = np.ones(k_size, np.uint8)
    eroded = cv2.erode(mask, kernel, iterations=ite)
    dilated = cv2.dilate(mask, kernel, iterations=ite)
    eroded_bi = binarize_image(eroded)
    dilated_bi = binarize_image(dilated)
    # Explicit uint8 to match the mask and 8-bit image I/O; without dtype,
    # np.full(shape, 128) uses the platform default integer type (e.g. int64).
    trimap = np.full(mask.shape, 128, dtype=np.uint8)
    trimap[eroded_bi == 255] = 255
    trimap[dilated_bi == 0] = 0
    return trimap
# Three iterations of a 5x5 kernel give a wider unknown band than the default.
trimap = gen_trimap(mask, k_size=(5, 5), ite=3)

trimap_dir = './examples/trimaps/'
os.makedirs(trimap_dir, exist_ok=True)
# Cast explicitly so the file is written as 8-bit whatever dtype gen_trimap returns
# (values are only 0/128/255); cv2.imwrite reports failure by returning False.
if not cv2.imwrite(trimap_dir + image_name, trimap.astype(np.uint8)):
    raise OSError(f"Failed to write trimap to {trimap_dir + image_name}")

fig, (ax_img, ax_trimap) = plt.subplots(1, 2, figsize=(20, 20))
ax_img.imshow(img)
# Explicit gray colormap so 0/128/255 read as black/gray/white without relying on global state.
ax_trimap.imshow(trimap, cmap='gray')
plt.show()

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt

In [None]:
# Reload the full-resolution image (earlier cells resized `img`) and its alpha matte.
# cv2.imread returns None instead of raising when a file cannot be read.
img = cv2.imread('./examples/images/'+image_name)
if img is None:
    raise FileNotFoundError(f"Could not read image: ./examples/images/{image_name}")
img = img[..., ::-1]  # BGR -> RGB

# NOTE(review): the matte comes from a matting step not shown here. cv2.imread's default
# flag loads it as 3 channels so it can weight each RGB channel; presumably it is
# grayscale, so BGR vs RGB order does not matter -- confirm.
matte = cv2.imread('./examples/mattes/'+image_name)
if matte is None:
    raise FileNotFoundError(f"Could not read matte: ./examples/mattes/{image_name}")

h, w, _ = img.shape
if matte.shape != img.shape:
    raise ValueError(f"Matte shape {matte.shape} does not match image shape {img.shape}")

# Constant background fill value: 0 gives black (use 255 for white).
bg = np.full_like(img, 0)

In [None]:
# Alpha-composite: out = alpha * foreground + (1 - alpha) * background, with alpha in [0, 1].
# New names are used instead of overwriting img / bg / matte, so this cell is idempotent:
# re-running it no longer divides the matte by 255 a second time or re-multiplies the image.
alpha = matte.astype(float) / 255
fg_weighted = img.astype(float) * alpha
bg_weighted = bg.astype(float) * (1.0 - alpha)
outImage = fg_weighted + bg_weighted

# outImage is in [0, 255]; imshow expects floats in [0, 1].
plt.imshow(outImage / 255)
plt.show()