# 概要

1枚の人物画像から背景を切り抜いて人物だけを切り出す

1. trimapを作成
2. `FBA_Matting`で背景と前景を分離
3. 結果を保存

## 1. trimapを作成

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os

import torch
import torchvision
from torchvision import transforms

In [None]:
# Input image; the intermediate mask/trimap files reuse its file name under ./examples/.
image_path = 'examples/images/takeshi_512_no_alpha.png'
image_name = os.path.basename(image_path)
mask_image_name = os.path.join('./examples/mask/', image_name)
trimap_image_name = os.path.join('./examples/trimaps/', image_name)
os.makedirs('./examples/mask/', exist_ok=True)
os.makedirs('./examples/trimaps/', exist_ok=True)

img = cv2.imread(image_path)
# cv2.imread signals a missing/unreadable file by returning None instead of raising.
if img is None:
    raise FileNotFoundError(f'could not read image: {image_path}')
# OpenCV loads BGR; convert to RGB as a contiguous array (a [..., ::-1] view has
# negative strides, which some OpenCV functions reject).
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Original size (h, w), used later to map the 320x320 result back.
h, w, _ = img.shape
# Fixed 320x320 input for the segmentation network (aspect ratio is not preserved).
img = cv2.resize(img, (320, 320))

### segmentationで人物を抜き出す

In [None]:
# Prefer the first CUDA GPU; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Pretrained DeepLabV3 (ResNet-101 backbone) semantic-segmentation model,
# placed on `device` and switched to evaluation mode for inference.
model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
model = model.to(device)
model.eval()

In [None]:
# Per-channel ImageNet statistics used to normalize input for torchvision's pretrained models.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# HWC uint8 RGB image -> float CHW tensor in [0, 1] -> normalized tensor.
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

input_tensor = preprocess(img)
# Prepend a batch dimension of size 1 and move the batch to the model's device.
input_batch = input_tensor[None].to(device)

In [None]:
# Index of the "person" category in the 21-class (Pascal VOC) label set that
# torchvision's pretrained segmentation models predict.
PERSON_CLASS = 15

with torch.no_grad():
    output = model(input_batch)['out'][0]
# Per-pixel predicted class index at the 320x320 network resolution.
output = output.argmax(0)
# Keep only the person class (values 0/255) so other detected objects are not cut out too.
mask = ((output == PERSON_CLASS).to(torch.uint8) * 255).cpu().numpy()
# Nearest-neighbour keeps the mask binary when scaling back to the original size;
# the default bilinear filter would create intermediate values along the edges.
mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
# Reload the full-resolution image for display instead of upscaling the blurry
# 320x320 network input.
img = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
# cv2.imwrite reports failure by returning False rather than raising.
if not cv2.imwrite(mask_image_name, mask):
    raise OSError(f'failed to write mask: {mask_image_name}')
plt.gray()
plt.figure(figsize=(20, 20))
plt.subplot(1, 2, 1)
plt.imshow(img)
plt.subplot(1, 2, 2)
plt.imshow(mask)

### mask画像を膨張・収縮させてtrimapを作成

In [None]:
def binarize_image(image):
    """Return a copy of ``image`` in which every strictly positive pixel is 255.

    Pixels that are zero (or negative) keep their value, the dtype is preserved,
    and the input array is not modified.
    """
    positive = image > 0
    binary = image.copy()
    binary[positive] = 255
    return binary

In [None]:
def gen_trimap(mask, k_size=(5, 5), ite=1):
    """Build a trimap from a single-channel mask by eroding and dilating it.

    Args:
        mask: Single-channel mask; pixels > 0 are treated as foreground.
        k_size: Shape of the all-ones structuring element.
        ite: Number of erosion/dilation iterations; larger values widen the
            unknown band between foreground and background.

    Returns:
        uint8 array with the same shape as ``mask``: 255 where the eroded mask is
        positive (certain foreground), 0 where the dilated mask is 0 (certain
        background), and 128 everywhere else (unknown).
    """
    kernel = np.ones(k_size, np.uint8)
    eroded_bi = binarize_image(cv2.erode(mask, kernel, iterations=ite))
    dilated_bi = binarize_image(cv2.dilate(mask, kernel, iterations=ite))
    # Explicit uint8: np.full would otherwise default to int64, leaving image
    # writers to convert it implicitly.
    trimap = np.full(mask.shape, 128, dtype=np.uint8)
    trimap[eroded_bi == 255] = 255
    trimap[dilated_bi == 0] = 0
    return trimap
trimap = gen_trimap(mask, k_size=(5, 5), ite=3)
# Write as uint8 so the PNG stores exactly 0/128/255 whatever dtype gen_trimap
# returns; cv2.imwrite reports failure by returning False, not by raising.
if not cv2.imwrite(trimap_image_name, trimap.astype(np.uint8)):
    raise OSError(f'failed to write trimap: {trimap_image_name}')
plt.figure(figsize=(20, 20))
plt.subplot(1, 2, 1)
plt.imshow(img)
plt.subplot(1, 2, 2)
plt.imshow(trimap)

## 2. `FBA_Matting`で背景と前景を分離する

In [None]:
import os
import sys

# Make FBA_Matting's top-level modules (demo, dataloader, networks) importable.
# The directory is resolved against the current working directory because a
# notebook has no __file__; checking the same resolved path keeps re-runs idempotent.
fba_matting_dir = os.path.abspath('FBA_Matting')
if fba_matting_dir not in sys.path:
    print('add path')
    sys.path.append(fba_matting_dir)
print(sys.path)

In [None]:
from demo import np_to_torch, pred, scale_input
from dataloader import read_image, read_trimap
from networks.models import build_model
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt

In [None]:
class Args:
    """Stand-in for the argparse namespace passed to FBA_Matting's ``build_model``.

    NOTE(review): assumes ``build_model`` only reads these three attributes;
    confirm against FBA_Matting's networks/models.py.
    """

    encoder = 'resnet50_GN_WS'
    decoder = 'fba_decoder'
    weights = './FBA_Matting/FBA.pth'


# Google Drive location of the pretrained FBA weights.
FBA_WEIGHTS_URL = 'https://drive.google.com/uc?id=1T_oiKDE_biWf2kqexMEN7ObWqtXAzbB1'

args = Args()
# Download the weights only when they are missing, writing them to args.weights
# (without -O, gdown saves into the current directory under the Drive file name,
# so a retry of build_model could still fail). Any other build_model error now
# surfaces instead of being hidden by a bare except.
if not os.path.exists(args.weights):
    import subprocess
    subprocess.run(['gdown', FBA_WEIGHTS_URL, '-O', args.weights], check=True)
model = build_model(args)

In [None]:
# Load the original image and the trimap saved earlier with FBA_Matting's own
# readers (presumably producing the format `pred` expects — confirm in dataloader.py).
image, trimap = read_image(image_path), read_trimap(trimap_image_name)

In [None]:
# FBA matting guided by the trimap: yields foreground, background and alpha matte.
prediction = pred(image, trimap, model)
fg, bg, alpha = prediction

In [None]:
# (title, image, extra imshow options) for each result figure, shown in order.
panels = [
    ('Alpha Matte', alpha, {'cmap': 'gray', 'vmin': 0, 'vmax': 1}),
    ('Foreground', fg, {}),
    ('Background', bg, {}),
    # Foreground weighted by alpha, i.e. composited onto black.
    ('Composite', fg * alpha[:, :, None], {}),
]
for panel_title, panel_image, imshow_options in panels:
    plt.title(panel_title)
    plt.imshow(panel_image, **imshow_options)
    plt.show()

In [None]:
# Alternative: batch-run FBA_Matting's demo script from the command line
# (presumably over every image in --image_dir with its matching trimap).
# Kept for reference only; it is not executed by this notebook.
# !python ./FBA_Matting/demo.py --image_dir ./examples/images/ --trimap_dir ./examples/trimaps/ --output_dir ./examples/predictions/ --weights ./FBA_Matting/FBA.pth

## 3. 結果を保存

In [None]:
# Save the alpha-weighted foreground and the alpha matte as 8-bit PNGs in ./output/.
output_dir = 'output'
os.makedirs(output_dir, exist_ok=True)

# splitext strips only the final extension, so names with extra dots keep their full stem.
stem = os.path.splitext(image_name)[0]
output_file = stem + '_fg.png'
output_mask_file = stem + '_fg_mask.png'


def _to_uint8(values):
    """Map [0, 1] values to [0, 255], rounding and clipping, as a uint8 array."""
    return np.clip(np.rint(values * 255.0), 0, 255).astype(np.uint8)


# Convert to uint8 before cvtColor/imwrite: neither handles arbitrary float
# dtypes reliably, and an explicit cast avoids implicit saturation rules.
out_fg_img = cv2.cvtColor(_to_uint8(fg * alpha[:, :, None]), cv2.COLOR_RGB2BGR)
out_mask_img = _to_uint8(alpha)

for filename, data in ((output_file, out_fg_img), (output_mask_file, out_mask_img)):
    path = os.path.join(output_dir, filename)
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(path, data):
        raise OSError(f'failed to write {path}')