In [1]:
from pathlib import Path
import yaml

import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image

import torch
import torch.nn as nn
from torchvision import transforms
from saicinpainting.training.modules.ffc import FFCResNetGenerator

In [2]:
# Load the LaMa generator weights and hyper-parameters from the checkpoint
# directory and build the network in eval mode.
model_loc = Path("checkpoints/big-lama")
# Tensors fed to `model` must live on this same device.
device = torch.device("cuda:1")

# NOTE: torch.load unpickles arbitrary Python objects — only load checkpoints
# from a trusted source.
ckpt = torch.load(model_loc.joinpath("models/best.ckpt"), map_location="cpu")

# Keep only the generator weights and strip the *leading* "generator." prefix.
# (str.replace would also rewrite any later occurrence of the substring.)
GEN_PREFIX = "generator."
gen_weight = {
    key[len(GEN_PREFIX):]: tensor
    for key, tensor in ckpt["state_dict"].items()
    if key.startswith(GEN_PREFIX)
}

# yaml.load() without an explicit Loader is rejected by PyYAML >= 6 and can
# construct arbitrary objects on older versions; safe_load is sufficient for a
# plain-mapping config.
hparams = yaml.safe_load(model_loc.joinpath("config.yaml").read_bytes())["generator"]
# "kind" is removed before the remaining entries are passed as constructor kwargs.
hparams.pop("kind", None)

model = FFCResNetGenerator(**hparams).eval().to(device)
# Left as the last expression so the key-matching report is displayed.
model.load_state_dict(gen_weight)

  """


<All keys matched successfully>

In [None]:

# Hand-assembled layer stack, wrapped in a single nn.Sequential at the end.
# NOTE(review): FFC_BN_ACT_local, FFC_BN_ACT_local2global, FFCResnetBlock and
# ConcatTupleLayer are not imported in the visible part of this notebook —
# confirm where they are defined before running this cell.
# NOTE(review): this cell rebinds `model`; verify that replacing a previously
# loaded generator is intended.
layers = [
    nn.ReflectionPad2d(3),
    FFC_BN_ACT_local(4, 64, 7, 1, 0),
    # downsample
    FFC_BN_ACT_local(64, 128, 3, 2, 1),
    FFC_BN_ACT_local(128, 256, 3, 2, 1),
    FFC_BN_ACT_local2global(256, 512, 3, 2, 1, ratio_gout=0.75),
]

mult = 8  # reassigned inside the upsample loop below
feats_num_bottleneck = 512

# resnet blocks
layers.extend(FFCResnetBlock(512, 0.75, enable_lfu=False) for _ in range(18))

layers.append(ConcatTupleLayer())

# upsample: channel widths go 512 -> 256 -> 128 -> 64
for step in range(3):
    mult = 2 ** (3 - step)
    in_ch = min(1024, 64 * mult)
    out_ch = min(1024, int(64 * mult / 2))
    layers.extend([
        nn.ConvTranspose2d(in_ch, out_ch,
                           kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    ])

# output head: 7x7 convolution to 3 channels followed by a sigmoid
layers.extend([
    nn.ReflectionPad2d(3),
    nn.Conv2d(64, 3, kernel_size=7, padding=0),
    nn.Sigmoid(),
])

model = nn.Sequential(*layers)

In [None]:
from PIL import Image, ImageFilter

# Build a binary inpainting mask from an image by colour-keying: every pixel
# that is (nearly) pure magenta becomes 255, everything else 0.
MARKED_IMG_PATH = "/datasets/RD/interactive-LAMA/img/webui/21-12-31_04:47:28:876/001-000-2a5ce1afb166.jpg"
MARKER_HIGH_MIN = 240  # red and blue must be >= this value
MARKER_LOW_MAX = 15    # green must be <= this value
MODE_FILTER_SIZE = 13  # window of the mode (majority) filter applied to the mask


def color_key_mask(image, high_min=MARKER_HIGH_MIN, low_max=MARKER_LOW_MAX):
    """Return an 'L'-mode PIL mask of the pixels matching the marker colour.

    Args:
        image: PIL image in any mode; it is converted to RGB first, so
            grayscale / palette / RGBA inputs are handled as well.
        high_min: minimum value (inclusive) for the red and blue channels.
        low_max: maximum value (inclusive) for the green channel.

    Returns:
        PIL image of mode 'L' with the same size as `image`; 255 where the
        pixel matches, 0 elsewhere.
    """
    rgb = np.asarray(image.convert("RGB"))
    hit = (
        (rgb[..., 0] >= high_min)
        & (rgb[..., 1] <= low_max)
        & (rgb[..., 2] >= high_min)
    )
    # A 2-D uint8 array is interpreted by Pillow as mode 'L'.
    return Image.fromarray(hit.astype(np.uint8) * 255)


# The context manager closes the file handle; convert() returns an independent
# in-memory copy that stays valid afterwards.
with Image.open(MARKED_IMG_PATH) as src:
    img_data = src.convert("RGB")

mask_raw = color_key_mask(img_data)
# The mask already has the size of `img_data`, so no resize is needed here.
mask = mask_raw.filter(ImageFilter.ModeFilter(size=MODE_FILTER_SIZE))


In [None]:
# Run the generator on one image with the PIL mask `mask` and show
# input / mask / inpainted result side by side.
CLEAN_IMG_PATH = "../interactive-LAMA/img/tigerbro_clean_v2/001-000-2a5ce1afb166.jpg"
# Spatial size is padded up to a multiple of this value before inference.
# NOTE(review): presumably the generator's total downsampling factor — confirm.
PAD_MODULO = 8
MASK_THRESHOLD = 0.9  # mask values above this (on a 0..1 scale) count as "inpaint here"


def _pad_to_multiple(arr, modulo):
    """Symmetric-pad the bottom/right of an (H, W, C) array so H and W are multiples of `modulo`."""
    rows, cols = arr.shape[:2]
    return np.pad(arr, ((0, -rows % modulo), (0, -cols % modulo), (0, 0)), mode="symmetric")


img = cv2.imread(CLEAN_IMG_PATH)
if img is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise FileNotFoundError(f"cv2.imread could not read {CLEAN_IMG_PATH!r}")
img = (img / 255.).astype(np.float32)  # BGR, float32 in [0, 1]
h, w, c = img.shape

# Keep `mask` (PIL image) untouched so this cell can be re-run; the resized,
# binarised array gets its own name.
mask_resized = mask.resize((w, h), Image.BICUBIC)
mask_np = np.expand_dims(np.asarray(mask_resized), axis=-1)
mask_np = ((mask_np / 255.) > MASK_THRESHOLD).astype(np.float32)

img_pad = _pad_to_multiple(img, PAD_MODULO)
mask_pad = _pad_to_multiple(mask_np, PAD_MODULO)
out_h, out_w = img_pad.shape[:2]

# Put the inputs on whatever device the model lives on, so the two can never
# disagree.
device = next(model.parameters()).device
img_t = torch.from_numpy(img_pad).permute(2, 0, 1).to(device)    # (3, H, W)
mask_t = torch.from_numpy(mask_pad).permute(2, 0, 1).to(device)  # (1, H, W)
# Network input: the image with the masked region zeroed, plus the mask as a
# fourth channel, with a leading batch dimension.
masked_img_t = torch.cat([img_t * (1 - mask_t), mask_t], 0).unsqueeze(0)

with torch.no_grad():
    predicted_image = model(masked_img_t)
    # Use the prediction only inside the mask; keep original pixels elsewhere.
    inpaint = mask_t * predicted_image + (1 - mask_t) * img_t

# Drop the batch dimension and crop the padding away again.
predict = inpaint[0, :, :h, :w].permute(1, 2, 0).cpu().numpy()

fig, ax = plt.subplots(1, 3, figsize=(24, 8))
ax[0].imshow(img[:, :, ::-1])  # BGR -> RGB for display
ax[0].set_title("Input")
ax[1].imshow(mask_np[:, :, 0])
ax[1].set_title("Mask")
ax[2].imshow(predict[:, :, ::-1])
ax[2].set_title("Inpainted")
plt.show()