## StyleBlend Inference

### Prepare

In [None]:
import torch
import os
from diffusers import DDIMScheduler
from omegaconf import OmegaConf
from src.pipeline import StyleBlendT2IPipeline

# Device the pipeline is moved to; the inference cell reuses this value.
device = 'cuda'

# Inference configuration: sample/weights directories, base model path and
# the placeholder settings passed to `load_styleblend_weights` below.
cfg = OmegaConf.load('./configs/inference_config.yaml')
dataset_root = cfg.sample_dir
weights_dir = cfg.weights_dir

# TODO: specify the style name
style_name = 'style1'

# Load the base pipeline with float16 weights and move it to `device`.
pipeline: StyleBlendT2IPipeline = StyleBlendT2IPipeline.from_pretrained(
    cfg.pretrained_model_path, torch_dtype=torch.float16
).to(device)

# Swap in a DDIM scheduler built from the base model's scheduler config.
# `from_pretrained(..., subfolder=...)` is the supported loader; passing a
# file path to `from_config` is deprecated in diffusers, and a hand-joined
# path only resolves when `pretrained_model_path` is a local directory.
pipeline.scheduler = DDIMScheduler.from_pretrained(
    cfg.pretrained_model_path, subfolder='scheduler')

# All weights for a style live under `<weights_dir>/<style_name>/` and share
# the `<style_name>_` filename prefix.
style_dir = os.path.join(weights_dir, style_name)
pipeline.load_styleblend_weights(
    te_lora_path=os.path.join(style_dir, f'{style_name}_text_encoder_lora.bin'),
    unet_lora_path=os.path.join(style_dir, f'{style_name}_unet_lora.bin'),
    texture_style_embeds_path=os.path.join(style_dir, f'{style_name}_texture_style_embeds.bin'),
    composition_style_embeds_path=os.path.join(style_dir, f'{style_name}_composition_style_embeds.bin'),
    placeholder_composition_style=cfg.placeholder_composition_style,
    placeholder_texture_style=cfg.placeholder_texture_style,
)

### Inference

Some tips for configuring parameters for inference:
- The default feature layers to register (`c2t_self_attn_layers_to_register` and `t2c_self_attn_layers_to_register`) generally work well for most styles. For styles that are prone to overfitting, register more layers for `c2t` and fewer for `t2c`.
- Start by using the middle layers for `c2t` and the side layers (the first and last ones) for `t2c`. There are 16 layers in SD, indexed 0–15.

In [None]:
# TODO: specify the text prompt
prompt = 'Tower Bridge'
# NOTE(review): presumably (width, height) — index 0 sizes the last latent
# dimension and index 1 the one before it; confirm against the pipeline.
resolution = (768, 768)

# Initial noise with 4 channels at 1/8 of the requested resolution.
# Use the `device` chosen in the Prepare cell instead of a hard-coded 'cuda'
# so the latents always end up on the same device as the pipeline.
latents = torch.randn(
    [1, 4, resolution[1] // 8, resolution[0] // 8]
).to(device=device, dtype=torch.float16)

# Drop any previously registered StyleBlend modules before registering the
# new configuration, so this cell can be re-run with different settings.
pipeline.unregister_styleblend_modules()
pipeline.register_styleblend_modules(
    c2t_self_attn_layers_to_register=[4, 5, 6, 7, 8, 9],
    t2c_self_attn_layers_to_register=[0, 1, 2, 3, 10, 11, 12, 13, 14, 15],
    scale=0.3,
    c2t_step_ratio=0.8,
    t2c_step_ratio=0.6,
)

# NOTE(review): the second returned image (index 1) is displayed — presumably
# the pipeline returns more than one image per prompt; confirm which output
# each index corresponds to.
image = pipeline(
    prompt,
    num_inference_steps=30,
    latents=latents,
    negative_prompt=[''],
    eta=1.,
    guidance_scale=7.5,
).images[1]
display(image)