In [None]:
from omegaconf import OmegaConf
import torch, torch.nn.functional as F
from torchvision.utils import make_grid, save_image
from pytorch_lightning import seed_everything
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
import kornia
import os, sys
# Make the working directory and the local `src/clip` / `src/taming-transformers`
# source trees importable; this must happen before the `ldm` imports that follow.
sys.path.append(os.getcwd())
sys.path.append('src/clip')
sys.path.append('src/taming-transformers')
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, load_model_from_config

#### Setup

In [None]:
# --- fine-tuning ---
num_iter = 100   # number of masked fine-tuning iterations
lr = 1e-5        # Adam learning rate used for fine-tuning

# --- model files ---
config = "configs/stable-diffusion/v1-inference.yaml"  # path to the model YAML config
ckpt = "ckpt/sd-v1-4-full-ema.ckpt"                    # path to the SD checkpoint

# --- sampling ---
h = w = 512      # image height / width in pixels
scale = 8        # classifier-free guidance (cfg) scale
ddim_steps = 50  # number of DDIM sampling steps
ddim_eta = 0.0   # DDIM eta passed to the sampler
n_samples = 4    # images generated per inference call

# --- output / device ---
out_path = "outputs/"
gpu_id = '0'

# Seed the RNGs (via Lightning) so the run is reproducible.
seed_everything(42)

#### Load model

In [None]:

# Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device(f"cuda:{gpu_id}") if torch.cuda.is_available() else torch.device("cpu")

# Load into a separate name so `config` keeps holding the YAML path; overwriting
# it made a re-run of this cell call OmegaConf.load on an already-loaded config.
model_config = OmegaConf.load(config)
model = load_model_from_config(model_config, ckpt, device)
sampler = DDIMSampler(model)

# Only the parameters of `model.model` are fine-tuned.
params_to_be_optimized = list(model.model.parameters())
optimizer = torch.optim.Adam(params_to_be_optimized, lr=lr)
os.makedirs(out_path, exist_ok=True)

#### Define some useful functions

In [None]:
def D(_x):
    """VAE-decode latents `_x`, clamp the result to [-1, 1] and detach it from autograd."""
    return torch.clamp(model.decode_first_stage(_x), min=-1, max=1).detach()


def E(_x):
    """VAE-encode images `_x` into first-stage latents.

    No grad context is imposed here; callers wrap this in `torch.no_grad()` when
    the latents are only used as fixed inputs.
    """
    return model.get_first_stage_encoding(model.encode_first_stage(_x))


# Image -> batched tensor. For PIL input, ToTensor yields (C, H, W) in [0, 1];
# add a leading batch dim and rescale to [-1, 1].
img_transforms = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.unsqueeze(0) * 2. - 1)])
# Mask -> batched binary float tensor: every pixel with value > 0 becomes 1.
mask_transforms = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: (x.unsqueeze(0) > 0).float())])

# text encode
def C(_txt, enable_emb_manager=False):
    """Encode text prompt(s) into conditioning embeddings.

    Args:
        _txt: a single prompt string or a list of prompt strings; a single
            string is wrapped into a one-element list.
        enable_emb_manager: forwarded to `model.get_learned_conditioning`.
            When True, encoding runs with gradients enabled (for textual
            inversion); otherwise it runs under `torch.no_grad()`.

    Returns:
        The value returned by `model.get_learned_conditioning` for the prompts.
    """
    prompts = _txt if not isinstance(_txt, str) else [_txt]
    grad_mode = torch.enable_grad() if enable_emb_manager else torch.no_grad()
    with grad_mode:
        return model.get_learned_conditioning(prompts, enable_emb_manager)

# save tensor as image file
def tsave(tensor, save_path, **kwargs):
    """Save `tensor` as an image (grid) file at `save_path`.

    Values are treated as lying in [-1, 1]: `normalize=True` together with
    `value_range=(-1, 1)` maps them to [0, 1] before writing. Any extra keyword
    arguments (e.g. `nrow`) are passed straight through to `save_image`.
    """
    fixed_options = dict(normalize=True, scale_each=True, value_range=(-1, 1))
    save_image(tensor, save_path, **fixed_options, **kwargs)

#### Read image and mask, and preprocess

In [None]:
image_path = "examples/dog.png"
mask_path = "examples/dog-mask.png"

# Read image and mask. PIL's resize takes (width, height), not (height, width).
image = Image.open(image_path).convert('RGB').resize((w, h), Image.Resampling.BILINEAR)
mask = Image.open(mask_path).convert('L').resize((w, h), Image.Resampling.BILINEAR)

x = img_transforms(image).to(device)  # image in [-1, 1] with a batch dim of 1
m = mask_transforms(mask).to(device)  # binary mask; 1 marks the region to fill
x_in = x * (1 - m)                    # zero out the masked region

# The latent is a fixed input for fine-tuning and sampling, so record no autograd history.
with torch.no_grad():
    z_xm = E(x_in)

z_m = F.interpolate(m, size=(h // 8, w // 8))  # latent-resolution mask
z_m = kornia.morphology.dilation(z_m, torch.ones((3, 3), device=device))  # dilate mask a little bit

# Attention masks for the multi-scale UNet layers, keyed by token count (side ** 2).
# NOTE(review): these sizes assume h = w = 512 (a 64x64 latent); revisit if the resolution changes.
attn_mask = {}
for attn_size in [64, 32, 16, 8]:
    attn_mask[str(attn_size ** 2)] = F.interpolate(m, (attn_size, attn_size), mode='bilinear')[0, 0, ...]

uc = C("")  # null-text embedding

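#### Masked Fine-tuning

Fine-tune the denoising network on this single image with the null-text embedding as conditioning. The noise-prediction loss is computed only outside the (dilated) mask, i.e. on the known part of the image.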

In [None]:
model.train()
pbar = tqdm(range(num_iter), desc='Fine-tune the model')
for i in pbar:
    optimizer.zero_grad()

    # Diffuse the masked-image latent to one randomly drawn timestep.
    noise = torch.randn_like(z_xm)
    t = torch.randint(model.num_timesteps, (1,), device=device)
    z_t = model.q_sample(z_xm, t, noise=noise)

    # Predict the noise under null-text conditioning; only the known region
    # (outside the dilated latent mask) contributes to the MSE loss.
    pred_noise = model.apply_model(z_t, t, uc)
    known = 1 - z_m
    loss = F.mse_loss(pred_noise * known, noise * known)

    pbar.set_postfix(loss=loss.item())

    loss.backward()
    optimizer.step()

#### Inference - Unconditional inpainting

In [None]:
model.eval()
with torch.no_grad(), torch.autocast(device.type):
    # Unconditional inpainting: the null-text embedding is used for both the
    # conditional and the unconditional guidance branch.
    samples_uncond, _ = sampler.sample(
        S=ddim_steps,
        batch_size=n_samples,
        shape=[4, h // 8, w // 8],
        conditioning=uc.repeat(n_samples, 1, 1),
        unconditional_conditioning=uc.repeat(n_samples, 1, 1),
        blend_interval=[0, 1],
        x0=z_xm.repeat(n_samples, 1, 1, 1),
        mask=z_m.repeat(n_samples, 1, 1, 1),
        attn_mask=attn_mask,
        x_T=None,
        unconditional_guidance_scale=scale,
        eta=ddim_eta,
        verbose=False,
    )

    tsave(D(samples_uncond), os.path.join(out_path, 'Uncond.jpg'), nrow=n_samples)


#### Inference - Text inpainting

In [None]:
prompt = 'a vase of flower'

# Set eval mode explicitly so this cell does not depend on an earlier cell having
# done it; right after fine-tuning the model is still in train mode.
model.eval()
with torch.no_grad(), torch.autocast(device.type):
    # Text-guided inpainting: prompt embedding as the conditional branch,
    # null-text embedding as the unconditional guidance branch.
    samples_text, _ = sampler.sample(
        S=ddim_steps,
        batch_size=n_samples,
        shape=[4, h // 8, w // 8],
        conditioning=C(prompt).repeat(n_samples, 1, 1),
        unconditional_conditioning=uc.repeat(n_samples, 1, 1),
        blend_interval=[0, 1],
        x0=z_xm.repeat(n_samples, 1, 1, 1),
        mask=z_m.repeat(n_samples, 1, 1, 1),
        attn_mask=attn_mask,
        x_T=None,
        unconditional_guidance_scale=scale,
        eta=ddim_eta,
        verbose=False,
    )

    tsave(D(samples_text), os.path.join(out_path, f'Text-{prompt}.jpg'), nrow=n_samples)

#### Inference - Stroke inpainting

In [None]:
# Prepare the stroke image; any pixel that is not pure black counts as stroke.
stroke_path = "examples/wn-stroke-blue.png"
# PIL's resize takes (width, height), not (height, width).
stroke = Image.open(stroke_path).convert('RGB').resize((w, h), Image.Resampling.BILINEAR)
x_stroke = img_transforms(stroke).to(device)
# Channel mean > -1 in [-1, 1] space <=> the pixel is not pure black.
x_stroke_mask = (torch.mean(x_stroke, dim=1, keepdim=True) > -1).float()
with torch.no_grad():  # fixed sampler input; no autograd history needed
    z_stroke = E(x_stroke)
z_stroke_mask = F.interpolate(x_stroke_mask, size=(h // 8, w // 8))

prompt = 'a toy bear'
tau = 0.55  # stroke blending timestep
stroke_blend_width = 0.02  # width of the stroke-blending interval starting at tau

model.eval()
with torch.no_grad(), torch.autocast(device.type):
    # NOTE(review): presumably the sampler pairs each 't' interval with the 'c'
    # entry at the same index (prompt, then null text), and each blend_interval
    # with the x0/mask entry at the same index (known region, then stroke).
    # Confirm against DDIMSampler.sample.
    samples_stroke, _ = sampler.sample(
        S=ddim_steps,
        batch_size=n_samples,
        shape=[4, h // 8, w // 8],
        conditioning={'t': [[0, tau], [tau, 1]],
                      'c': [C(prompt).repeat(n_samples, 1, 1), uc.repeat(n_samples, 1, 1)]},
        unconditional_conditioning=uc.repeat(n_samples, 1, 1),
        blend_interval=[[0, 1], [tau, tau + stroke_blend_width]],
        x0=[z_xm.repeat(n_samples, 1, 1, 1), z_stroke.repeat(n_samples, 1, 1, 1)],
        mask=[z_m.repeat(n_samples, 1, 1, 1), 1 - z_stroke_mask.repeat(n_samples, 1, 1, 1)],
        attn_mask=attn_mask,
        x_T=None,
        unconditional_guidance_scale=scale,
        eta=ddim_eta,
        verbose=False,
    )

    tsave(D(samples_stroke), os.path.join(out_path, f'Stroke-{prompt}.jpg'), nrow=n_samples)