## Kaggle notebook: LaMa (big-lama) inpainting — setup, inference and result viewer

In [1]:
print('\n> Cloning the repo')
!git clone https://github.com/saic-mdal/lama.git 
!ls -l lama

In [5]:
print('\n> Install dependencies')
!pip install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 torchtext==0.9
!pip install pytorch-lightning==1.2.9
!pip uninstall fairscale -y
!pip install pyyaml
!pip install hydra hydra-core webdataset
!pip install wldhx.yadisk-direct

In [3]:
print('>fixing opencv')
# Replace whatever opencv-python-headless build is installed with a pinned version,
# printing the installed headless package before and after so the swap is visible.
# NOTE(review): the reason for pinning 4.1.2.30 is not stated here — presumably a
# compatibility problem with the environment set up above; confirm.
!pip list | grep headless
!pip uninstall opencv-python-headless -y --quiet
!pip install opencv-python-headless==4.1.2.30 --quiet
!pip list | grep headless

In [4]:
print('\n> Changing the dir to:')
%cd lama
!ls -l
!export PYTHONPATH=/kaggle/working/lama

In [11]:
print('\n> Download the model')
# `yadisk-direct` prints a URL for the Yandex.Disk share link; curl fetches it
# (-L follows redirects) and saves it as big-lama.zip.
# NOTE(review): without curl --fail, an HTTP error page would be saved as the .zip and
# only show up later as an unzip error.
!curl -L $(yadisk-direct https://disk.yandex.ru/d/ouP6l8VJ0HpMZg) -o big-lama.zip
!ls -l 
print('\n> Unzip downloaded model')
# -o overwrites existing files without prompting, so the cell can be re-run.
# Presumably this extracts ./big-lama, the model.path used by the predict cell — confirm.
!unzip -o big-lama.zip
!ls -l 

In [12]:
print("Prepare images and masks")
# Fetch the sample inputs from Yandex.Disk and extract them (-o: overwrite without
# prompting). Presumably this yields ./LaMa_test_images containing images plus their
# *_mask files, the indir the predict cell reads — confirm.
# NOTE(review): without curl --fail, an HTTP error page would be saved as the .zip.
!curl -L $(yadisk-direct https://disk.yandex.ru/d/xKQJZeVRk5vLlQ) -o LaMa_test_images.zip
!unzip -o LaMa_test_images.zip
!ls -l

In [13]:
# Run LaMa inference with the big-lama checkpoint (model.path) over the inputs in indir,
# writing results to outdir. The key=value arguments are presumably Hydra config
# overrides — confirm against bin/predict.py.
# PYTHONPATH is extended inline so the repo root is importable by the script.
# NOTE(review): IPython substitutes `$PYTHONPATH` from a *Python* variable of that name
# if one exists; otherwise the shell expands the environment variable.
!PYTHONPATH=$PYTHONPATH:"/kaggle/working/lama"  python bin/predict.py model.path=/kaggle/working/lama/big-lama indir=/kaggle/working/lama/LaMa_test_images outdir=/kaggle/working/lama/output

In [18]:
from glob import glob
from pathlib import Path  # used by make_input_path in a later cell

from PIL import Image
import IPython.display as ipd

# Source images and their masks share one directory; predictions are PNGs in output/.
inputs = sorted(glob('LaMa_test_images/*'))
outputs = sorted(glob('output/*.png'))

print(len(inputs), len(outputs))
display(inputs[:6])
display(outputs[:6])

# Fail with a clear message instead of an IndexError below when an earlier step
# (download / unzip / predict) produced nothing.
assert len(inputs) >= 2, 'Expected files in LaMa_test_images/ - did the download/unzip cell succeed?'
assert outputs, 'No predictions in output/ - did bin/predict.py succeed?'

# Preview the first two input files and the first prediction. The context manager
# closes each file handle; display() renders the image immediately, before close.
for preview_path in (inputs[0], inputs[1], outputs[0]):
    with Image.open(preview_path) as preview:
        display(preview)

In [22]:
import random

def make_input_path(output):
    """Map a LaMa prediction path back to its source image and mask paths.

    Assumes each prediction is named after the mask file it inpainted, so the mask is
    the same file name inside ``LaMa_test_images/``. The source image is that name
    with everything from the last ``_mask`` onward replaced by ``.png``. Names without
    ``_mask`` map to themselves.

    Args:
        output: path of a prediction PNG, e.g. ``'output/photo_mask.png'``.

    Returns:
        ``(org, mask)``, e.g. ``('LaMa_test_images/photo.png',
        'LaMa_test_images/photo_mask.png')``.
    """
    # The parameter used to be misspelled `ouput`, so the body silently read the
    # *global* `output` and only worked when that global happened to match.
    name = Path(output).name
    mask = f'LaMa_test_images/{name}'
    # Split on the last '_mask' so numbered masks ('photo_mask001.png') also resolve to
    # 'photo.png'. Presumably this matches how LaMa's loader pairs masks with images —
    # confirm.
    stem, sep, _ = name.rpartition('_mask')
    org = f'LaMa_test_images/{stem}.png' if sep else mask
    return org, mask

# Pick one prediction at random and resolve the source image and mask it came from;
# `output`, `org` and `mask` are reused by the next cell.
output = random.choice(outputs)
print(output)
org, mask = make_input_path(output)
print(org, mask, sep='\n')

In [30]:
def show(mask, org, output):
    """Display mask, original image and inpainted output side by side in one strip.

    Each panel gets a thin green stripe along its left edge so the boundaries between
    panels stay visible.

    Args:
        mask, org, output: image file paths readable by PIL. Panels are concatenated
            horizontally, so the three images must have the same height.
    """
    from PIL import Image
    import numpy as np

    def to_np(path):
        # Convert to RGB: a grayscale ('L') mask loads as a 2-D array, which breaks both
        # the 3-channel stripe assignment and the concatenation with RGB images. An
        # RGBA image would not match the 3-value colour either.
        with Image.open(path) as img:
            im = np.array(img.convert('RGB'))
        h = im.shape[0]
        # Stripe width scales with image height, but is at least 1 px so it is still
        # drawn on images shorter than 200 px (where h // 200 == 0).
        im[:, :max(1, h // 200), :] = [0, 255, 0]
        return im

    imgs = [to_np(e) for e in [mask, org, output]]
    big_img = np.concatenate(imgs, axis=1)
    display(Image.fromarray(big_img))

show(mask, org, output)

In [35]:
# Step through every prediction: press Enter for the next one, or 'q' to stop.
for i, output in enumerate(outputs):
    org, mask = make_input_path(output)
    show(mask, org, output)
    # 1-based counter, so the last item reads N/N rather than (N-1)/N.
    print(f'{i + 1}/{len(outputs)}')
    # input() waits for Enter, so "any key" was misleading; the old literal also had
    # mismatched quotes.
    print("Press Enter for the next image, or 'q' to quit")
    ch = input()
    # 'ㅂ' is the character on the Q key of the Korean 2-set layout, so quitting works
    # without switching the input language. strip() tolerates stray spaces.
    if ch.strip() in ('q', 'Q', 'ㅂ'):
        break
    ipd.clear_output(wait=True)
    