In [1]:
from omegaconf import OmegaConf
import torch, torch.nn.functional as F
from torchvision.utils import make_grid, save_image
from pytorch_lightning import seed_everything
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
import kornia
import os, sys
sys.path.append(os.getcwd()),
sys.path.append('src/clip')
sys.path.append('src/taming-transformers')
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, load_model_from_config, ExemplarAugmentor


  from .autonotebook import tqdm as notebook_tqdm


#### Setup

In [2]:
# --- model / checkpoint ---
config = "configs/stable-diffusion/v1-inference.yaml"  # model config (YAML)
ckpt = "ckpt/sd-v1-4-full-ema.ckpt"  # path to SD checkpoint
gpu_id = '0'  # CUDA device index

# --- masked fine-tuning ---
num_iter = 100  # number of fine-tuning iterations
lr = 1e-5  # learning rate

# --- sampling ---
h, w = 512, 512  # working image height / width in pixels
scale = 8  # cfg scale
ddim_steps = 50
ddim_eta = 0.0
n_samples = 4  # images generated per sampling call
seed_everything(42)

# --- output ---
out_path = "outputs/"

Global seed set to 42


#### Load images and model

In [3]:
image_path = "examples/dog.png"
mask_path = "examples/dog-mask.png"
image_ref_path = "examples/wn.png"  # exemplar image

device = torch.device(f"cuda:{gpu_id}") if torch.cuda.is_available() else torch.device("cpu")

# `config` starts out as a path and is replaced by the loaded config object.
# Only load when it is still a path, so that re-running this cell does not try
# to "load" an already-loaded config and fail.
if isinstance(config, (str, os.PathLike)):
    config = OmegaConf.load(config)

# Specify the init word for the exemplar, e.g. "cat", "toy", ...
# Special options: '__clip__': auto-select by CLIP model (default);
#                  '__rare__': rare token; '__mean__': average token.
config.model.params.personalization_config.params.initializer_words = ['__clip__']
# If initializer_words == ['__clip__'], you must provide the path to the exemplar image.
# Invalid in other cases.
config.model.params.personalization_config.params.initializer_images = [image_ref_path]
# Placeholder word for the exemplar, '#' by default; you may change it to another symbol.
config.model.params.personalization_config.params.placeholder_strings = ['#']

model = load_model_from_config(config, ckpt, device)
sampler = DDIMSampler(model)

# Only the parameters under `model.model` are handed to the optimizer.
params_to_be_optimized = list(model.model.parameters())
optimizer = torch.optim.Adam(params_to_be_optimized, lr=lr)
os.makedirs(out_path, exist_ok=True)

LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.15.self_attn.v_proj.bias', 'vision_model.embeddings.position_embedding.weight', 'vision_model.encoder.layers.17.self_attn.q_proj.bias', 'vision_model.encoder.layers.6.mlp.fc1.bias', 'vision_model.encoder.layers.18.self_attn.k_proj.bias', 'vision_model.encoder.layers.19.self_attn.q_proj.bias', 'vision_model.encoder.layers.18.self_attn.v_proj.weight', 'vision_model.encoder.layers.13.self_attn.v_proj.bias', 'vision_model.encoder.layers.8.self_attn.v_proj.weight', 'vision_model.encoder.layers.13.self_attn.q_proj.weight', 'vision_model.encoder.layers.18.layer_norm2.weight', 'vision_model.encoder.layers.22.self_attn.k_proj.weight', 'vision_model.encoder.layers.13.mlp.fc2.weight', 'vision_model.encoder.layers.7.self_attn.out_proj.bias', 'vision_model.encoder.layers.11.layer_norm2.weight', 'vision_model.encoder.layers.15.mlp.fc2.weight', 'vision_mod

initializer_words is __clip__, the token will be obtained from examples/wn.png


#### Define some useful functions

In [4]:
def D(_x):
    """VAE decode: run the first-stage decoder on `_x`, clamp to [-1, 1] and detach."""
    decoded = model.decode_first_stage(_x)
    return decoded.clamp(min=-1, max=1).detach()


def E(_x):
    """VAE encode: run the first-stage encoder on `_x` and convert its output to a latent."""
    encoder_out = model.encode_first_stage(_x)
    return model.get_first_stage_encoding(encoder_out)
def _to_signed_batch(t):
    """Add a leading batch axis and map values v -> 2 * v - 1."""
    return t.unsqueeze(0) * 2. - 1


def _to_binary_batch(t):
    """Add a leading batch axis and binarize: 1.0 where the value is > 0, else 0.0."""
    return (t.unsqueeze(0) > 0).float()


# PIL image -> tensor -> batched tensor rescaled with 2 * v - 1
img_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_to_signed_batch),
])
# PIL mask -> tensor -> batched {0, 1} float mask
mask_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_to_binary_batch),
])

def C(_txt, enable_emb_manager=False):
    """Text encode: return the model's learned conditioning for one prompt or a list of prompts.

    A single string is wrapped into a one-element list. Gradient tracking is
    switched on only when `enable_emb_manager` is true (i.e. when textual
    inversion is wanted); otherwise the call runs under `torch.no_grad()`.
    """
    prompts = _txt if not isinstance(_txt, str) else [_txt]
    grad_mode = torch.enable_grad if enable_emb_manager else torch.no_grad
    with grad_mode():
        return model.get_learned_conditioning(prompts, enable_emb_manager)
        
def tsave(tensor, save_path, **kwargs):
    """Save a tensor as an image file.

    Forwards to `save_image` with per-image normalization enabled over the
    value range (-1, 1); any extra keyword arguments are passed through.
    """
    fixed_options = {"normalize": True, "scale_each": True, "value_range": (-1, 1)}
    save_image(tensor, save_path, **fixed_options, **kwargs)

#### Preprocess input images and mask

In [5]:
# Load the inputs at the working resolution.
# NOTE: PIL's Image.resize expects (width, height), so the size is passed as (w, h).
# The previous (h, w) order was only correct while h == w; the latent mask below
# is sized (h // 8, w // 8), i.e. rows first.
image = Image.open(image_path).convert('RGB').resize((w, h), Image.Resampling.BILINEAR)
mask = Image.open(mask_path).convert('L').resize((w, h), Image.Resampling.BILINEAR)
image_ref = Image.open(image_ref_path).convert('RGB').resize((w, h), Image.Resampling.BILINEAR)

x = img_transforms(image).to(device)
m = mask_transforms(mask).to(device)
x_ref = img_transforms(image_ref).to(device)

x_in = x * (1 - m)  # input image with the masked pixels multiplied by zero
z_xm = E(x_in)  # latent of the masked input image
z_ref = E(x_ref)  # latent of the exemplar image
z_m = F.interpolate(m, size=(h // 8, w // 8))  # latent mask
z_m = kornia.morphology.dilation(z_m, torch.ones((3, 3), device=device))  # dilate mask a little bit

# Create attention masks for the multi-scale layers in the UNet, keyed by the
# number of spatial positions (size ** 2) as a string.
# NOTE(review): the sizes are hard-coded and square — presumably chosen for
# 512x512 inputs; confirm before changing h / w.
attn_mask = {}
for attn_size in [64, 32, 16, 8]:
    attn_mask[str(attn_size**2)] = (F.interpolate(m, (attn_size, attn_size), mode='bilinear'))[0, 0, ...]

uc = C("").detach()  # null-text emb
c_ref = C('#', True).detach()  # exemplar emb

# The augmentor receives the resized PIL mask (not the tensor `m`).
exemplar_augmentor = ExemplarAugmentor(mask=mask)

#### Masked Fine-tuning

In [6]:
# Masked fine-tuning: each iteration sums two noise-prediction MSE losses and
# takes one optimizer step.
#   loss_ref - on an augmented copy of the exemplar, conditioned on `c_ref`,
#              weighted by the mask returned by the augmentor;
#   loss_bg  - on the masked input image, conditioned on the null-text
#              embedding `uc`, weighted by (1 - z_m).
model.train()
pbar = tqdm(range(num_iter), desc='Fine-tune the model')
for i in pbar:
    optimizer.zero_grad()

    # Exemplar branch: a fresh augmentation of the exemplar every iteration.
    x_reff, x_reff_mask = exemplar_augmentor(x_ref)
    z_reff = E(x_reff)
    z_reff_mask = F.interpolate(x_reff_mask, size=(h // 8, w // 8), mode='bilinear')

    t_emb = torch.randint(model.num_timesteps, (1,), device=device)  # random diffusion timestep
    # Draw the noise from the tensor it is added to (was `z_xm`, which merely
    # happened to have the same shape).
    noise1 = torch.randn_like(z_reff)
    z_ref_t = model.q_sample(z_reff, t_emb, noise=noise1)
    pred_noise_ref = model.apply_model(z_ref_t, t_emb, c_ref)
    loss_ref = F.mse_loss(pred_noise_ref * z_reff_mask, noise1 * z_reff_mask)

    # Background branch: masked input image with its own timestep and noise.
    t_emb2 = torch.randint(model.num_timesteps, (1,), device=device)
    noise2 = torch.randn_like(z_xm)
    z_bg_t = model.q_sample(z_xm, t_emb2, noise=noise2)
    pred_noise_bg = model.apply_model(z_bg_t, t_emb2, uc)
    loss_bg = F.mse_loss(pred_noise_bg * (1 - z_m), noise2 * (1 - z_m))

    loss = loss_bg + loss_ref
    loss.backward()
    optimizer.step()

    # Show the current loss values on the progress bar.
    losses_dict = {"loss": loss, "loss_bg": loss_bg, "loss_ref": loss_ref}
    pbar.set_postfix({k: v.item() for k, v in losses_dict.items()})

Fine-tune the model: 100%|██████████| 100/100 [01:20<00:00,  1.25it/s, loss=0.251, loss_bg=0.237, loss_ref=0.0145]     


#### Inference - Exemplar inpainting

In [7]:

# Switch to eval mode before sampling: the fine-tuning cell leaves the model in
# train mode, and the stroke-inpainting cell already calls model.eval() before
# it samples.
model.eval()
with torch.no_grad(), torch.autocast(device.type):
    # Sample with the exemplar embedding as conditioning; the masked-image
    # latent `z_xm` and latent mask `z_m` are passed as x0 / mask, with
    # blend_interval [0, 1].
    tmp, _ = sampler.sample(S=ddim_steps, batch_size=n_samples, shape=[4, h // 8, w // 8],
                        conditioning=C('#', True).repeat(n_samples,1,1),
                        unconditional_conditioning=uc.repeat(n_samples,1,1),
                        blend_interval=[0, 1],
                        x0=z_xm.repeat(n_samples,1,1,1),
                        mask=z_m.repeat(n_samples,1,1,1),
                        attn_mask=attn_mask,
                        x_T=None,
                        unconditional_guidance_scale=scale,
                        eta=ddim_eta,
                        verbose=False)

    # Decode the sampled latents and write them as one image grid.
    tsave(D(tmp), os.path.join(out_path, 'Exemplar.jpg'), nrow=n_samples)

Data shape for DDIM sampling is (4, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:19<00:00,  2.60it/s]


 blend_interval = [0, 1] 



#### Inference - Exemplar + Stroke inpainting

In [8]:
# prepare stroke image
stroke_path = "examples/wn-stroke-blue.png"
# PIL's Image.resize expects (width, height), so pass (w, h); the previous
# (h, w) order was only correct while h == w.
stroke = Image.open(stroke_path).convert('RGB').resize((w, h), Image.Resampling.BILINEAR)
x_stroke = img_transforms(stroke).to(device)
# Stroke mask: 1 where the channel mean is above -1 (the minimum produced by
# img_transforms), 0 elsewhere.
x_stroke_mask = (torch.mean(x_stroke, dim=1, keepdim=True) > -1).float()
z_stroke = E(x_stroke)
z_stroke_mask = F.interpolate(x_stroke_mask, size=(h // 8, w // 8))


tau = 0.55 # stroke blending timestep

model.eval()
with torch.no_grad(), torch.autocast(device.type):
    # Two conditioning intervals: the exemplar embedding on [0, tau] and the
    # null-text embedding on [tau, 1]. Two blend sources: the masked image over
    # [0, 1] and the stroke latent over [tau, tau + 0.02].
    tmp, _ = sampler.sample(S=ddim_steps, batch_size=n_samples, shape=[4, h // 8, w // 8],
                        conditioning={'t': [[0, tau], [tau, 1]], 'c': [C('#', True).repeat(n_samples,1,1), uc.repeat(n_samples,1,1)]},
                        unconditional_conditioning=uc.repeat(n_samples,1,1),
                        blend_interval=[[0, 1], [tau, tau + 0.02]],
                        x0=[z_xm.repeat(n_samples,1,1,1), z_stroke.repeat(n_samples,1,1,1)],
                        mask=[z_m.repeat(n_samples,1,1,1), 1 - z_stroke_mask.repeat(n_samples,1,1,1)],
                        attn_mask=attn_mask,
                        x_T=None,
                        unconditional_guidance_scale=scale,
                        eta=ddim_eta,
                        verbose=False)

    # Decode the sampled latents and write them as one image grid.
    tsave(D(tmp), os.path.join(out_path, 'Exemplar+Stroke.jpg'), nrow=n_samples)

Data shape for DDIM sampling is (4, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:19<00:00,  2.63it/s]


 blend_interval = [[0, 1], [0.55, 0.5700000000000001]] 

