In [1]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell executes, so edits to project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from backbones.sana import SANA

# Usage example: build the SANA backbone with its default arguments and print
# its string representation. `model` is reused by the cells below.
model = SANA()
print(model)

In [3]:
import matplotlib.pyplot as plt

def show_compare(euler_img, dpm_img):
    """Show two images side by side for visual comparison.

    The left panel is titled 'Euler' and the right panel 'Diff-Solver'; both
    panels have their axes hidden. The figure is displayed immediately.

    Args:
        euler_img: Image for the left panel (anything ``Axes.imshow`` accepts).
        dpm_img: Image for the right panel (anything ``Axes.imshow`` accepts).
    """
    _, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(18, 10))

    # One (axis, image, title) triple per panel, drawn left to right.
    panels = (
        (left_ax, euler_img, 'Euler'),
        (right_ax, dpm_img, 'Diff-Solver'),
    )
    for panel_ax, image, title in panels:
        panel_ax.imshow(image)
        panel_ax.axis('off')
        panel_ax.set_title(title)

    plt.tight_layout()
    plt.show()

In [None]:
# https://github.com/wl-zhao/DC-Solver/blob/f82d0601851b483d1dc06fc97e443588c6c9d9b0/stable-diffusion/scripts/sample_diffusion_dc_solver.py#L77
# https://github.com/wl-zhao/DC-Solver/blob/f82d0601851b483d1dc06fc97e443588c6c9d9b0/stable-diffusion/ldm/models/diffusion/dc_solver/sampler.py#L86
import os 
import torch
import json
from solvers.euler_solver import Euler_Solver

# Probe the model once with an empty prompt: the third item returned by
# get_model_fn exposes a 4-D `.shape`, unpacked here as (B, C, H, W).
latent_probe = model.get_model_fn(pos_text="")[2]
B, C, H, W = latent_probe.shape

print(B, C, H, W)

# Sampling configuration.
batch_size = 10        # number of noise samples drawn when no cached noise exists
steps = 10             # presumably the step budget of the solver being tuned (unused in this cell) -- TODO confirm
gt_steps = 200         # number of steps used when generating the ground-truth trajectories
guidance_scale = 4.5   # guidance scale forwarded to model.get_model_fn

# Locations of the cached tuning artifacts.
noise_path = "solvers/tuning_data/dc_solver/noise/default.pth"
caption_path = "solvers/tuning_data/dc_solver/caption/default.json"
ddim_gt_path = "solvers/tuning_data/dc_solver/ddim_gt/default.pth"
dc_ratios_path = "solvers/tuning_data/dc_solver/dc_ratios/default.json"

# Negative prompt passed to model.get_model_fn for every caption.
neg_text = "lowres, bad anatomy, deformed, blurry, pixelated, oversaturated, underexposed, overexposed, artifact, jpeg artifacts, watermark, text, logo, extra limbs, mutated hands, unnatural colors, noisy background, out of focus, poor composition, cultural clichés, stereotype exaggeration, flat lighting, glitch"

# Make sure the parent directory of every artifact exists before reading or writing.
for artifact_path in (noise_path, caption_path, ddim_gt_path, dc_ratios_path):
    os.makedirs(os.path.dirname(artifact_path), exist_ok=True)

# NOTE: it looks like this actually has to be done with the Diffusers port of DC-Solver:
# https://github.com/wl-zhao/DC-Solver/blob/f82d0601851b483d1dc06fc97e443588c6c9d9b0/diffusers/scripts/sample_dc_solver.py#L85
# https://github.com/wl-zhao/DC-Solver/blob/f82d0601851b483d1dc06fc97e443588c6c9d9b0/diffusers/src/diffusers/schedulers/scheduling_dcsolver_multistep.py#L873

# 1. Load the cached initial noise, or sample it and cache it on the first run.
if os.path.isfile(noise_path):
    # map_location remaps the stored tensor onto the model's device while it is
    # deserialized, so a cache written on another device (e.g. a different GPU)
    # still loads here; this matches how ddim_gt is loaded.
    # NOTE(review): torch.load unpickles the file -- only load caches you trust
    # (consider weights_only=True if the installed torch version supports it).
    noise = torch.load(noise_path, map_location=model.device)
else:
    shape = (batch_size, C, H, W)
    noise = torch.randn(shape, device=model.device, dtype=model.dtype)
    torch.save(noise, noise_path)

# 2. Load the caption file: a JSON list whose items each carry a 'caption' key.
# JSON text is UTF-8; pass the encoding explicitly so non-ASCII captions decode
# correctly even when the platform's default locale encoding is not UTF-8.
with open(caption_path, "r", encoding="utf-8") as f:
    captions = json.load(f)
prompts = [item['caption'] for item in captions]

# 3. Build the ground-truth (GT) trajectories once and cache them at
#    ddim_gt_path; later runs load the cached file instead of regenerating.
if not os.path.isfile(ddim_gt_path):
    print('ddim gt does not exist, generate for once')

    # Prompt i is paired with noise[i], so only prompts that have a noise row
    # can be used (capped at batch_size). Without this cap, surplus prompts
    # would be sampled from an empty noise[i:i+1] slice, and a fresh run would
    # produce a different batch size than a later cached run.
    num_samples = min(batch_size, len(prompts), noise.shape[0])
    if num_samples == 0:
        raise ValueError(
            "Cannot generate ddim_gt without samples: "
            f"batch_size={batch_size}, prompts={len(prompts)}, noise rows={noise.shape[0]}"
        )
    prompts = prompts[:num_samples]

    def get_callback(t_list, latent_list):
        """Return a callback that appends each (t, latents) pair it receives to the given lists."""
        def callback(t, latents):
            # NOTE(review): the objects are stored by reference; if the solver
            # updates its latents in place these entries would alias -- confirm.
            t_list.append(t)
            latent_list.append(latents)
        return callback

    latent_list_batch_dim = []

    # GT generation never backpropagates, so disable autograd to avoid keeping
    # graph state alive for every stored intermediate latent.
    with torch.no_grad():
        for i, prompt in enumerate(prompts):
            t_list = []
            latent_list_time_dim = []
            callback = get_callback(t_list, latent_list_time_dim)
            print(i, prompt)

            single_noise = noise[i:i+1]  # keep the batch dimension (size 1)
            model_fn, noise_schedule, _ = model.get_model_fn(pos_text=prompt, neg_text=neg_text, guidance_scale=guidance_scale, num_steps=gt_steps, seed=42)

            # TODO:
            # 1. get_model_fn needs to handle batched prompts so GT generation can run as a batch.
            # 2. ***** model_fn must be made batched in order to do the tuning. *****

            solver = Euler_Solver(model_fn, noise_schedule, callback_fn=callback)
            _img = solver.sample(single_noise, steps=gt_steps, skip_type='time_uniform_flow', flow_shift=3.0)

            # Stack the recorded latents along a new dim 1 and drop dim 0;
            # assuming each recorded latent is (1, C, H, W) this gives (T, C, H, W).
            latent_list_batch_dim.append(torch.stack(latent_list_time_dim, dim=1).squeeze(0))

    ddim_gt = {
        # t_list holds the values recorded for the last prompt only; presumably
        # identical for every prompt since all use the same sampling settings.
        'ts': torch.stack(t_list),
        'intermediates': torch.stack(latent_list_batch_dim, dim=0),
    }
    torch.save(ddim_gt, ddim_gt_path)

else:
    # NOTE(review): torch.load unpickles the file -- only load caches you trust.
    ddim_gt = torch.load(ddim_gt_path, map_location=model.device)

    # Keep prompts and cached trajectories aligned one-to-one, capped at
    # batch_size, regardless of which of the two is longer.
    num_samples = min(ddim_gt['intermediates'].shape[0], batch_size, len(prompts))
    ddim_gt['intermediates'] = ddim_gt['intermediates'][:num_samples]
    prompts = prompts[:num_samples]

# 4. Sample --> it seems that passing ddim_gt to the Scheduler makes it run in search mode automatically.



