In [1]:
# Enable IPython's autoreload extension so that edits to the project modules
# imported below (backbones/, solvers/) are picked up without restarting the
# kernel. Mode 2 reloads modules before each cell is executed.
# Note: re-running this cell after the extension is loaded only prints an
# "already loaded" message; it is harmless.
%load_ext autoreload
%autoreload 2

In [None]:
from backbones.sana import SANA

# Usage example: instantiate the SANA backbone with its default arguments and
# print it for inspection. Later cells use `model` via `model.get_model_fn(...)`
# and `model.decode_vae(...)`.
model = SANA()
print(model)

In [3]:
import matplotlib.pyplot as plt

def show_compare(euler_img, dpm_img, left_title='Euler', right_title='Diff-Solver', figsize=(18, 10)):
    """Display two images side by side for a visual solver comparison.

    Args:
        euler_img: Image for the left panel (the Euler baseline in this
            notebook). It is passed straight to ``Axes.imshow``, so it must be
            something imshow accepts (e.g. a PIL image or an HxW / HxWxC
            array) -- TODO confirm what ``model.decode_vae(...)[0]`` returns.
        dpm_img: Image for the right panel (the solver being compared).
        left_title: Title of the left panel.
        right_title: Title of the right panel. It defaults to the generic
            'Diff-Solver' so existing calls render exactly as before. Pass the
            actual solver name (e.g. 'UniPC') to label the panel precisely.
        figsize: Figure size in inches, forwarded to ``plt.subplots``.

    Returns:
        None. The figure is rendered via ``plt.show()``. It is deliberately
        not returned, which avoids a second rich-repr display when the call is
        the last expression of a cell.
    """
    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=figsize)
    panels = ((left_ax, euler_img, left_title), (right_ax, dpm_img, right_title))
    for ax, img, title in panels:
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(title)
    fig.tight_layout()
    plt.show()

In [None]:
from solvers.euler_solver import Euler_Solver
from solvers.dpm_solver import DPM_Solver
from solvers.sa_solver import SA_Solver
from solvers.unipc_solver import UniPC_Solver

# Positive / negative prompts shared by every solver comparison in this notebook.
pos_text = "A serene twilight view of a traditional Korean hanok village nestled between misty mountain slopes: curved midnight-blue tiled eaves, softly glowing paper lanterns swaying in the breeze; ancient pine trees arching over stone pathways; intricate wooden lattice windows casting delicate shadows; a lone scholar in flowing hanbok practicing calligraphy beside a koi pond with lotus petals drifting on the water; cinematic 8K ultra-realism with dynamic volumetric moonlight filtering through morning mist; painterly strokes blending classical Joseon-era ink wash with modern hyperrealism; shot on RED Monstro 8K, 50 mm f/1.2 lens; subtle film grain; maximum fidelity; emotional atmosphere."
neg_text = "lowres, bad anatomy, deformed, blurry, pixelated, oversaturated, underexposed, overexposed, artifact, jpeg artifacts, watermark, text, logo, extra limbs, mutated hands, unnatural colors, noisy background, out of focus, poor composition, cultural clichés, stereotype exaggeration, flat lighting, glitch"

# Build the model function, noise schedule and starting latents once.
# Every solver below (including those in later cells) samples from these same
# `latents`, so the comparison is made from an identical starting point.
# NOTE(review): num_steps=20 is passed here while every solver is run with
# steps=10 -- confirm whether get_model_fn actually uses num_steps.
model_fn, noise_schedule, latents = model.get_model_fn(
    pos_text=pos_text,
    neg_text=neg_text,
    guidance_scale=4.5,
    num_steps=20,
    seed=42,
)

# Baseline: Euler solver, 10 steps on the shifted flow time grid.
# NOTE(review): unlike the SA/UniPC cells, this call relies on the solver's
# default t_start/t_end -- confirm they match 0.999 / 1e-3 for a fair comparison.
solver = Euler_Solver(model_fn, noise_schedule)
latent_samples = solver.sample(
    latents,
    steps=10,
    skip_type='time_uniform_flow',
    flow_shift=3.0,
)
pixel_samples = model.decode_vae(latent_samples)
euler_sample = pixel_samples[0]

# Second-order multistep DPM-Solver++ with the same step budget and time grid.
solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
latent_samples = solver.sample(
    latents,
    steps=10,
    order=2,
    skip_type="time_uniform_flow",
    method="multistep",
    flow_shift=3.0,
)
pixel_samples = model.decode_vae(latent_samples)
dpm_sample = pixel_samples[0]

show_compare(euler_sample, dpm_sample)

In [None]:

def tau(t):
    """Piecewise-constant tau(t) handed to SA_Solver.sample.

    Returns 0.6 while 0.2 <= t <= 0.8 and 0 elsewhere. Presumably this controls
    how much stochasticity the SA solver injects at time t; confirm against
    SA_Solver. The chained comparison assumes `t` reduces to a single truth
    value (a Python float or a one-element tensor) -- TODO confirm.
    """
    if 0.2 <= t <= 0.8:
        return 0.6
    return 0


# SA solver on the same latents: predictor order 2, corrector order 3,
# 10 steps over t in [1e-3, 0.999] on the shifted flow time grid.
solver = SA_Solver(model_fn, noise_schedule)
latent_samples = solver.sample(
    latents,
    tau,
    steps=10,
    predictor_order=2,
    corrector_order=3,
    skip_type="time_uniform_flow",
    flow_shift=3.0,
    t_start=0.999,
    t_end=1e-3,
)
pixel_samples = model.decode_vae(latent_samples)
sa_sample = pixel_samples[0]

show_compare(euler_sample, sa_sample)

In [None]:
# UniPC solver (order 3) on the same latents, 10 steps over t in [1e-3, 0.999]
# on the shifted flow time grid.
solver = UniPC_Solver(model_fn, noise_schedule)
latent_samples = solver.sample(
    latents,
    steps=10,
    order=3,
    skip_type="time_uniform_flow",
    flow_shift=3.0,
    t_start=0.999,
    t_end=1e-3,
)
pixel_samples = model.decode_vae(latent_samples)
unipc_sample = pixel_samples[0]

# Compare the UniPC result (not the SA sample) against the Euler baseline.
show_compare(euler_sample, unipc_sample)
