<a href="https://colab.research.google.com/github/softmurata/generative-ai-handsbook/blob/main/application/propainter/propainter_application.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Installation

In [None]:
!git clone https://github.com/sczhou/ProPainter.git

!pip install -U openmim
!mim install mmcv
!pip install einops

In [None]:
!pip install transformers accelerate bitsandbytes

Download pretrained models

In [None]:
!wget https://github.com/sczhou/ProPainter/releases/download/v0.1.0/i3d_rgb_imagenet.pt -P /content/ProPainter/weights
!wget https://github.com/sczhou/ProPainter/releases/download/v0.1.0/ProPainter.pth -P /content/ProPainter/weights
!wget https://github.com/sczhou/ProPainter/releases/download/v0.1.0/raft-things.pth -P /content/ProPainter/weights
!wget https://github.com/sczhou/ProPainter/releases/download/v0.1.0/recurrent_flow_completion.pth -P /content/ProPainter/weights

Prepare

In [None]:
# Upload inference_with_image.py into the /content/ProPainter directory before running the inference cells below.

Inference with OneFormer on the room image

In [None]:
import os

project_name = "room"

# One directory holds the input frames, its "_mask" sibling holds the masks.
# exist_ok=True keeps the cell re-runnable.
for directory in (
    f"/content/ProPainter/inputs/object_removal/{project_name}",
    f"/content/ProPainter/inputs/object_removal/{project_name}_mask",
):
    os.makedirs(directory, exist_ok=True)

import torch
from PIL import Image
from transformers import AutoProcessor

from transformers import AutoModelForUniversalSegmentation

# Hugging Face checkpoint id; going by its name, a OneFormer model with the ADE20K label set.
model_id = "shi-labs/oneformer_ade20k_swin_large"

# The processor and the model are loaded independently from the same checkpoint id.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForUniversalSegmentation.from_pretrained(model_id)

In [None]:
image = Image.open("/content/ProPainter/inputs/object_removal/room/room002.jpeg")

# Preprocess the image for the "semantic" task and return PyTorch tensors.
semantic_inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")

# forward pass (inference only, so gradient tracking is disabled)
with torch.no_grad():
  outputs = model(**semantic_inputs)

# target_sizes resizes the predicted label map to the source image's (height, width);
# PIL's image.size is (width, height), hence the reversal. Without it the map keeps
# the model's output resolution and the saved mask would not match the frame's size.
semantic_segmentation = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]

In [3]:
# List every class name known to the loaded model's config; the target labels
# used for mask building must be chosen from these keys.
print(model.config.label2id.keys())

dict_keys(['animal', 'arcade machine', 'armchair', 'awning, sunshade, sunblind', 'bag', 'ball', 'bannister, banister, balustrade, balusters, handrail', 'bar', 'barrel, cask', 'base, pedestal, stand', 'basket, handbasket', 'bed', 'bench', 'bicycle', 'blanket, cover', 'blind, screen', 'boat', 'book', 'bookcase', 'booth', 'bottle', 'box', 'bridge, span', 'buffet, counter, sideboard', 'building', 'bulletin board', 'bus', 'cabinet', 'canopy', 'car', 'case, display case, showcase, vitrine', 'ceiling', 'chair', 'chandelier', 'chest of drawers, chest, bureau, dresser', 'clock', 'clothes', 'coffee table', 'column, pillar', 'computer', 'conveyer belt, conveyor belt, conveyer, conveyor, transporter', 'counter', 'countertop', 'cradle', 'crt screen', 'curtain', 'cushion', 'desk', 'dirt track', 'dishwasher', 'door', 'earth, ground', 'escalator, moving staircase, moving stairway', 'falls', 'fan', 'fence', 'field', 'fireplace', 'flag', 'floor', 'flower', 'food, solid food', 'fountain', 'glass, drinkin

In [None]:
import numpy as np

# Class names to remove; each must be a key of model.config.label2id.
target_lists = ["chair"]
target_label_ids = [model.config.label2id[l] for l in target_lists]

# Build the mask in one step: 255 where the pixel's label is any target id, else 0.
# np.isin replaces accumulating per-label maps into a uint8 array, which would wrap
# around (255 + 255 -> 254) if the same label were listed twice.
segmentation_array = semantic_segmentation.cpu().numpy()
answer_map = np.where(np.isin(segmentation_array, target_label_ids), 255, 0).astype(np.uint8)

mask_image = Image.fromarray(answer_map)
display(mask_image)
# NOTE(review): JPEG is lossy, so the stored mask may not be strictly 0/255 near
# object edges; a lossless format would be safer if inference_with_image.py accepts
# it — confirm before changing the file name.
mask_image.save("/content/ProPainter/inputs/object_removal/room_mask/room002.jpeg")

In [16]:
# Duplicate the single mask and frame under a second file name (room003), so the
# mask and frame directories each contain two identically named entries.
# NOTE(review): presumably the inference script needs more than one frame — confirm.
!cp /content/ProPainter/inputs/object_removal/room_mask/room002.jpeg /content/ProPainter/inputs/object_removal/room_mask/room003.jpeg
!cp /content/ProPainter/inputs/object_removal/room/room002.jpeg /content/ProPainter/inputs/object_removal/room/room003.jpeg

In [None]:
# Switch the kernel's working directory to the repo (persists for later cells) so the
# relative --video / --mask paths resolve, then run the uploaded inference script.
%cd /content/ProPainter
!python inference_with_image.py --video inputs/object_removal/room --mask inputs/object_removal/room_mask

In [None]:
# Show the result image; presumably written by inference_with_image.py — confirm the output path.
display(Image.open("/content/ProPainter/results/room/out.jpg"))

Inference with OneFormer on the test frame (from bmx-trees)

In [16]:
!cp /content/ProPainter/inputs/object_removal/bmx-trees/00003.jpg /content/ProPainter/inputs/object_removal/test/

In [None]:
import torch
from PIL import Image
from transformers import AutoProcessor

from transformers import AutoModelForUniversalSegmentation

# Hugging Face checkpoint id; going by its name, a OneFormer model with the COCO label set.
model_id = "shi-labs/oneformer_coco_swin_large"

# Rebinds the notebook-level `processor` and `model`; both come from the same checkpoint id.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForUniversalSegmentation.from_pretrained(model_id)

In [17]:
# Load the frame to segment; rebinds the notebook-level `image` used by the next cell.
image = Image.open("/content/ProPainter/inputs/object_removal/test/00003.jpg")

In [18]:
# Preprocess the image for the "semantic" task and return PyTorch tensors.
semantic_inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")

# forward pass (inference only, so gradient tracking is disabled)
with torch.no_grad():
  outputs = model(**semantic_inputs)

# target_sizes resizes the predicted label map to the source image's (height, width);
# PIL's image.size is (width, height), hence the reversal. Without it the map keeps
# the model's output resolution and the saved mask would not match the frame's size.
semantic_segmentation = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]

In [None]:
import os

import numpy as np

# Class names to remove; each must be a key of model.config.label2id.
target_lists = ["person", "bicycle"]
target_label_ids = [model.config.label2id[l] for l in target_lists]

# Build the mask in one step: 255 where the pixel's label is any target id, else 0.
# np.isin replaces accumulating per-label maps into a uint8 array, which would wrap
# around (255 + 255 -> 254) if the same label were listed twice.
segmentation_array = semantic_segmentation.cpu().numpy()
answer_map = np.where(np.isin(segmentation_array, target_label_ids), 255, 0).astype(np.uint8)

mask_image = Image.fromarray(answer_map)
display(mask_image)

# Make sure the destination directory exists; saving into a missing directory raises.
mask_path = "/content/ProPainter/inputs/object_removal/test_mask/00003.jpg"
os.makedirs(os.path.dirname(mask_path), exist_ok=True)
# NOTE(review): JPEG is lossy, so the stored mask may not be strictly 0/255 near
# object edges; a lossless format would be safer if inference_with_image.py accepts
# it — confirm before changing the file name.
mask_image.save(mask_path)

In [22]:
# Feeding the same image in a second time (as 00004) is what makes this work.
!cp /content/ProPainter/inputs/object_removal/test/00003.jpg /content/ProPainter/inputs/object_removal/test/00004.jpg
!cp /content/ProPainter/inputs/object_removal/test_mask/00003.jpg /content/ProPainter/inputs/object_removal/test_mask/00004.jpg

In [None]:
# Switch the kernel's working directory to the repo (persists for later cells) so the
# relative --video / --mask paths resolve, then run the uploaded inference script.
%cd /content/ProPainter
!python inference_with_image.py --video inputs/object_removal/test --mask inputs/object_removal/test_mask

In [None]:
# Show the result image; presumably written by inference_with_image.py — confirm the output path.
display(Image.open("/content/ProPainter/results/test/out.jpg"))

Appendix

In [None]:
# draw masks
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib import cm


def draw_semantic_segmentation(segmentation, id2label=None):
    """Plot a semantic label map with a legend naming every class id it contains.

    Args:
        segmentation: torch tensor of integer class ids (assumed 2-D,
            height x width, as required by ``imshow`` — TODO confirm with caller).
        id2label: optional mapping from class id to display name. Defaults to
            ``model.config.id2label`` of the notebook-level ``model``.

    Returns:
        None. The figure is drawn through matplotlib.

    Raises:
        KeyError: if a class id in ``segmentation`` is missing from ``id2label``.
    """
    if id2label is None:
        id2label = model.config.id2label

    # Work on a CPU copy so both torch.unique and matplotlib can consume it.
    segmentation = segmentation.detach().cpu()
    labels_ids = torch.unique(segmentation).tolist()

    # One colour per id in 0..max inclusive; sizing the table with `max` alone would
    # make the highest id share a colour with its neighbour.
    num_colors = int(segmentation.max()) + 1
    # plt.get_cmap is used because cm.get_cmap was removed in Matplotlib 3.9.
    viridis = plt.get_cmap('viridis', num_colors)

    fig, ax = plt.subplots()
    # Pin the colormap and value range so pixel colours are exactly viridis(id),
    # i.e. the same colours the legend patches use. Nearest-neighbour interpolation
    # keeps adjacent class ids from being blended into colours that are not in the legend.
    ax.imshow(
        segmentation.numpy(),
        cmap=viridis,
        vmin=0,
        vmax=num_colors - 1,
        interpolation='nearest',
    )

    handles = [
        mpatches.Patch(color=viridis(label_id), label=id2label[label_id])
        for label_id in labels_ids
    ]
    ax.legend(handles=handles)

# Plots whichever `semantic_segmentation` the kernel currently holds; class names are
# looked up through the notebook-level `model`, so both must come from the same checkpoint.
draw_semantic_segmentation(semantic_segmentation)