In [1]:
import torch
import os
from PIL import Image
from torchvision.transforms import ToPILImage

# Every generated image in this notebook is written below this directory;
# exist_ok makes re-running the cell harmless.
output_dir = "./test_results"
os.makedirs(name=output_dir, exist_ok=True)

# Stable Diffusion 관련 라이브러리 임포트
from LocalStableDiffusionPipeline_Guide import  LocalStableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler, UNet2DConditionModel


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def generate_and_save_images(
    pipeline,
    prompt,
    negative_prompt,
    image,
    num_inference_steps=50,
    strength=0.8,
    guidance_scale=7.5,
    output_dir="./test_results",
    file_prefix="test"
):
    """
    Run an img2img Stable Diffusion pipeline once and save the first output image.

    Args:
        pipeline: callable pipeline (e.g. LocalStableDiffusionPipeline) that accepts
            the keyword arguments below and returns an object with an ``images`` list.
        prompt (str): text prompt.
        negative_prompt (str): negative text prompt.
        image (PIL.Image): initial image for img2img.
        num_inference_steps (int): number of sampling steps.
        strength (float): how strongly to transform the initial image (0~1).
        guidance_scale (float): classifier-free guidance (CFG) scale.
        output_dir (str): directory for the saved image; created if missing.
        file_prefix (str): file name prefix.

    Returns:
        str: path of the saved file, ``{output_dir}/{file_prefix}_scale_{guidance_scale}.png``.
    """
    # Callers may pass a directory other than the one created at the top of the
    # notebook; create it up front so an expensive generation is not wasted.
    os.makedirs(output_dir, exist_ok=True)

    result = pipeline(
        prompt=prompt,
        image=image,
        strength=strength,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
    )

    # Only the first image of the batch is saved.
    result_image = result.images[0]
    output_path = os.path.join(output_dir, f"{file_prefix}_scale_{guidance_scale}.png")
    result_image.save(output_path)
    print(f"이미지 저장 완료: {output_path}")
    return output_path

In [3]:
# Load the customised Stable Diffusion pipeline (fp16 UNet, DPM-Solver scheduler, CUDA).
model_id = "runwayml/stable-diffusion-v1-5"
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").half()

pipeline = LocalStableDiffusionPipeline.from_pretrained(
    model_id,
    unet=unet,
    scheduler=scheduler,
    safety_checker=None,
    torch_dtype=torch.float16,
).to("cuda")

# Test prompt and initial settings
prompt = "A futuristic cityscape at sunset, ultra-detailed, 4K"
negative_prompt = "blurry, distorted, low resolution"
# img2img expects a 3-channel image; JPEGs can be grayscale or CMYK, so normalise to RGB.
image = Image.open("init.JPEG").convert("RGB")
num_inference_steps = 50
strength = 0.8
guidance_scale = 7.5
low_guidance_scale = 0.7

# Standard CFG at the configured guidance scale.
cfg_result_path = generate_and_save_images(
    pipeline=pipeline,
    prompt=prompt,
    negative_prompt=negative_prompt,
    image=image,
    num_inference_steps=num_inference_steps,
    strength=strength,
    guidance_scale=guidance_scale,
    file_prefix="cityscape",
)

# Low guidance scale. NOTE: generate_and_save_images only forwards guidance_scale
# to the pipeline, so this is ordinary CFG at a small scale, not CFG++.
low_cfg_result_path = generate_and_save_images(
    pipeline=pipeline,
    prompt=prompt,
    negative_prompt=negative_prompt,
    image=image,
    num_inference_steps=num_inference_steps,
    strength=strength,
    guidance_scale=low_guidance_scale,
    file_prefix="cityscape",
)

Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 16.98it/s]
You have disabled the safety checker for <class 'LocalStableDiffusionPipeline_Guide.LocalStableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
  deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
100%|██████████| 40/40 [00:03<00:00, 11.38it/s]
  deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)


이미지 저장 완료: ./test_results/cityscape_scale_7.5.png


100%|██████████| 40/40 [00:03<00:00, 12.01it/s]


이미지 저장 완료: ./test_results/cityscape_scale_0.7.png


In [None]:
from IPython.display import display

# generate_and_save_images() writes "{file_prefix}_scale_{guidance_scale}.png"
# into output_dir, so these are the two files produced by the generation cell
# (guidance scales 7.5 and 0.7).
cfg_image = Image.open(os.path.join(output_dir, "cityscape_scale_7.5.png"))
cfg_plus_image = Image.open(os.path.join(output_dir, "cityscape_scale_0.7.png"))

# Visualize results
print("CFG 결과:")
display(cfg_image)

# NOTE: this image comes from the standard pipeline at guidance_scale=0.7;
# no CFG++-specific setting is passed to the pipeline call.
print("CFG++ 결과:")
display(cfg_plus_image)

In [5]:
"""
This module includes LDM-based inverse problem solvers.
Forward operators follow DPS and DDRM/DDNM.
"""

from typing import Any, Callable, Dict, Optional

import torch
from diffusers import DPMSolverMultistepScheduler,DDIMScheduler, StableDiffusionPipeline
from tqdm import tqdm

####### Factory #######
# Registry mapping solver names to solver classes (filled by @register_solver).
__SOLVER__ = {}

def register_solver(name: str):
    """Class decorator that stores the decorated class in ``__SOLVER__`` under ``name``.

    Raises:
        ValueError: when decorating, if ``name`` already maps to a solver.
    """
    def _register(solver_cls):
        already_taken = __SOLVER__.get(name) is not None
        if already_taken:
            raise ValueError(f"Solver {name} already registered.")
        __SOLVER__[name] = solver_cls
        return solver_cls
    return _register

def get_solver(name: str, **kwargs):
    """Instantiate the solver registered as ``name``, forwarding ``kwargs`` to its constructor.

    Raises:
        ValueError: if nothing is registered under ``name``.
    """
    if name in __SOLVER__:
        return __SOLVER__[name](**kwargs)
    raise ValueError(f"Solver {name} does not exist.")

########################

def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Split an ancestral sampling step from ``sigma_from`` to ``sigma_to``.

    Returns:
        (sigma_down, sigma_up): the noise level to step down to deterministically,
        and the amount of fresh noise to add afterwards. ``sigma_up`` is capped at
        ``sigma_to``. A falsy ``eta`` gives a purely deterministic step,
        ``(sigma_to, 0.)``.
    """
    if eta:
        variance_ratio = sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2
        sigma_up = min(sigma_to, eta * variance_ratio ** 0.5)
        return (sigma_to ** 2 - sigma_up ** 2) ** 0.5, sigma_up
    return sigma_to, 0.


def append_zero(x):
    """Return ``x`` with a single trailing zero appended along dim 0 (same dtype/device as ``x``)."""
    zero = x.new_zeros([1])
    return torch.cat((x, zero), dim=0)


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Build the Karras et al. (2022) noise schedule with a trailing zero.

    Interpolates linearly in ``sigma ** (1 / rho)`` space from ``sigma_max`` towards
    ``sigma_min``. ``n`` ramp points are taken from ``linspace(0, 1, n + 1)`` with
    the endpoint dropped, then ``0`` is appended, giving ``n + 1`` values.
    """
    inv_rho = 1 / rho
    hi_root = sigma_max ** inv_rho
    lo_root = sigma_min ** inv_rho
    ramp = torch.linspace(0, 1, n + 1, device=device)[:n]
    schedule = (hi_root + ramp * (lo_root - hi_root)) ** rho
    # Terminal sigma of 0 (fully denoised).
    return torch.cat([schedule, schedule.new_zeros([1])]).to(device)

########################

class StableDiffusion():
    """Base class for latent samplers built on a Stable Diffusion checkpoint.

    ``__init__`` loads the VAE, tokenizer, text encoder and UNet of ``model_key``
    through ``StableDiffusionPipeline`` and builds a ``DDIMScheduler`` configured
    for ``solver_config.num_sampling`` steps. Subclasses implement ``sample()``;
    calling an instance forwards to it and returns its result.
    """
    def __init__(self,
                 solver_config: Dict,
                 model_key:str="runwayml/stable-diffusion-v1-5",
                 device: Optional[torch.device]='cuda',
                 **kwargs):
        """
        Args:
            solver_config: configuration read by attribute access
                (``solver_config.num_sampling``), so a plain ``dict`` does not
                work despite the ``Dict`` annotation.
            model_key: model id or path passed to ``from_pretrained``.
            device: device for the pipeline modules and scheduler timesteps.
            **kwargs: ``pipe_dtype`` selects the weight dtype (default ``torch.float16``).
        """
        self.device = device

        self.dtype = kwargs.get("pipe_dtype", torch.float16)
        pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=self.dtype).to(device)
        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
        self.total_alphas = self.scheduler.alphas_cumprod.clone()

        # sigma_t = sqrt(1 - alpha_bar_t) / sqrt(alpha_bar_t) for every training
        # timestep; log_sigmas backs the sigma -> timestep lookup in timestep().
        self.sigmas = (1-self.total_alphas).sqrt() / self.total_alphas.sqrt()
        self.log_sigmas = self.sigmas.log()

        # Measured before set_timesteps(), presumably the full training schedule
        # length -- TODO confirm for the scheduler in use.
        total_timesteps = len(self.scheduler.timesteps)
        self.scheduler.set_timesteps(solver_config.num_sampling, device=device)
        # Timestep stride between consecutive sampling steps (used by inversion()).
        self.skip = total_timesteps // solver_config.num_sampling

        self.final_alpha_cumprod = self.scheduler.final_alpha_cumprod.to(device)
        # Prepend alpha_bar = 1.0, so alpha(t) reads the original alphas_cumprod[t - 1].
        self.scheduler.alphas_cumprod = torch.cat([torch.tensor([1.0]), self.scheduler.alphas_cumprod])

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Alias for ``sample()``; returns whatever ``sample()`` returns."""
        return self.sample(*args, **kwargs)

    def sample(self, *args: Any, **kwargs: Any) -> Any:
        """Run the sampler. Must be implemented by subclasses."""
        raise NotImplementedError("Solver must implement sample() method.")

    def alpha(self, t):
        """Return the cumulative alpha for timestep ``t``; negative ``t`` maps to ``final_alpha_cumprod``."""
        at = self.scheduler.alphas_cumprod[t] if t >= 0 else self.final_alpha_cumprod
        return at

    @torch.no_grad()
    def get_text_embed(self, null_prompt, prompt):
        """
        Encode the null (unconditional) text and the guidance text.

        Both texts are padded and truncated to ``tokenizer.model_max_length`` so
        that an over-long prompt on either side cannot exceed the text encoder's
        sequence length.

        Args:
            null_prompt (str): null / negative text.
            prompt (str): guidance text.

        Returns:
            (null_text_embed, text_embed): first output of the text encoder for each text.
        """
        # null text embedding (negation)
        null_text_input = self.tokenizer(null_prompt,
                                         padding='max_length',
                                         max_length=self.tokenizer.model_max_length,
                                         return_tensors="pt",
                                         truncation=True)
        null_text_embed = self.text_encoder(null_text_input.input_ids.to(self.device))[0]

        # text embedding (guidance)
        text_input = self.tokenizer(prompt,
                                    padding='max_length',
                                    max_length=self.tokenizer.model_max_length,
                                    return_tensors="pt",
                                    truncation=True)
        text_embed = self.text_encoder(text_input.input_ids.to(self.device))[0]

        return null_text_embed, text_embed

    def _latent_scale(self):
        """VAE latent scaling factor from the VAE config; falls back to the SD v1 value 0.18215."""
        vae_config = getattr(self.vae, "config", None)
        return getattr(vae_config, "scaling_factor", None) or 0.18215

    def encode(self, x):
        """
        Image -> latent: sample the VAE posterior and apply the latent scaling factor.
        """
        return self.vae.encode(x).latent_dist.sample() * self._latent_scale()

    def decode(self, zt):
        """
        Latent -> image: undo the latent scaling factor, decode with the VAE, return float32.
        """
        zt = 1 / self._latent_scale() * zt
        img = self.vae.decode(zt).sample.float()
        return img

    def predict_noise(self,
                      zt: torch.Tensor,
                      t: torch.Tensor,
                      uc: torch.Tensor,
                      c: torch.Tensor):
        """
        Compute epsilon_theta for the null and the text condition.

        If only one of ``uc`` / ``c`` is given, a single UNet pass is made and
        its prediction is returned for both; otherwise both are evaluated in one
        batched pass.

        Args:
            zt (torch.Tensor): latent features.
            t (torch.Tensor): timestep (0-dim; ``unsqueeze(0)`` adds a batch axis).
            uc (torch.Tensor): null-text embedding, or None.
            c (torch.Tensor): text embedding, or None.

        Returns:
            (noise_uc, noise_c)
        """
        t_in = t.unsqueeze(0)
        if uc is None:
            noise_c = self.unet(zt, t_in, encoder_hidden_states=c)['sample']
            noise_uc = noise_c
        elif c is None:
            noise_uc = self.unet(zt, t_in, encoder_hidden_states=uc)['sample']
            noise_c = noise_uc
        else:
            c_embed = torch.cat([uc, c], dim=0)
            z_in = torch.cat([zt] * 2)
            t_in = torch.cat([t_in] * 2)
            noise_pred = self.unet(z_in, t_in, encoder_hidden_states=c_embed)['sample']
            noise_uc, noise_c = noise_pred.chunk(2)

        return noise_uc, noise_c

    @torch.no_grad()
    def inversion(self,
                  z0: torch.Tensor,
                  uc: torch.Tensor,
                  c: torch.Tensor,
                  cfg_guidance: float=1.0):
        """
        Deterministic DDIM inversion of a clean latent ``z0`` towards noise.

        Walks the scheduler timesteps in reverse (low -> high), stepping from
        ``t - self.skip`` to ``t`` with the CFG-combined noise prediction.

        Returns:
            The latent at the last (noisiest) scheduler timestep.
        """
        zt = z0.clone().to(self.device)

        pbar = tqdm(reversed(self.scheduler.timesteps), desc='DDIM Inversion')
        for _, t in enumerate(pbar):
            at = self.alpha(t)
            at_prev = self.alpha(t - self.skip)

            noise_uc, noise_c = self.predict_noise(zt, t, uc, c)
            noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc)

            # Estimate z0 at the previous (less noisy) level, then re-noise to level t.
            z0t = (zt - (1-at_prev).sqrt() * noise_pred) / at_prev.sqrt()
            zt = at.sqrt() * z0t + (1-at).sqrt() * noise_pred

        return zt

    def initialize_latent(self,
                          method: str='random',
                          src_img: Optional[torch.Tensor]=None,
                          **kwargs):
        """
        Create the starting latent for sampling.

        Args:
            method: one of
                'ddim'   -- DDIM-invert ``src_img`` with ``uc``/``c``/``cfg_guidance`` (default 0.0) from kwargs;
                'npi'    -- DDIM-invert ``src_img`` using ``c`` for both branches, guidance 1.0;
                'random' -- standard normal latent of shape ``latent_dim`` (default (1, 4, 64, 64));
                'random_kdiffusion' -- normal latent scaled by sqrt(sigmas[0]**2 + 1)
                    (``sigmas`` from kwargs, default [14.6146]).
            src_img: source image tensor, required by 'ddim' and 'npi'.

        Returns:
            The latent with ``requires_grad`` enabled.

        Raises:
            ValueError: if 'ddim' / 'npi' is requested without ``src_img``.
            NotImplementedError: for an unknown ``method``.
        """
        if method in ('ddim', 'npi') and src_img is None:
            raise ValueError(f"initialize_latent(method='{method}') requires src_img.")

        if method == 'ddim':
            z = self.inversion(self.encode(src_img.to(self.dtype).to(self.device)),
                               kwargs.get('uc'),
                               kwargs.get('c'),
                               cfg_guidance=kwargs.get('cfg_guidance', 0.0))
        elif method == 'npi':
            z = self.inversion(self.encode(src_img.to(self.dtype).to(self.device)),
                               kwargs.get('c'),
                               kwargs.get('c'),
                               cfg_guidance=1.0)
        elif method == 'random':
            size = kwargs.get('latent_dim', (1, 4, 64, 64))
            z = torch.randn(size).to(self.device)
        elif method == 'random_kdiffusion':
            size = kwargs.get('latent_dim', (1, 4, 64, 64))
            sigmas = kwargs.get('sigmas', [14.6146])
            z = torch.randn(size).to(self.device)
            z = z * (sigmas[0] ** 2 + 1) ** 0.5
        else:
            raise NotImplementedError(f"Unknown latent initialization method: {method}")

        return z.requires_grad_()

    def timestep(self, sigma):
        """Map a sigma value to the index of the nearest training sigma (compared in log space)."""
        log_sigma = sigma.log()
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)

    def to_d(self, x, sigma, denoised):
        '''Convert a denoiser output to a Karras ODE derivative: (x - denoised) / sigma.'''
        return (x - denoised) / sigma.item()

    def get_ancestral_step(self, sigma_from, sigma_to, eta=1.):
        """Calculates the noise level (sigma_down) to step down to and the amount
        of noise to add (sigma_up) when doing an ancestral sampling step."""
        if not eta:
            return sigma_to, 0.
        sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
        return sigma_down, sigma_up

    def calculate_input(self, x, sigma):
        """Scale a k-diffusion latent for the UNet: x / sqrt(sigma**2 + 1)."""
        return x / (sigma ** 2 + 1) ** 0.5

    def calculate_denoised(self, x, model_pred, sigma):
        """Denoised estimate from an epsilon prediction: x - eps * sigma."""
        return x - model_pred * sigma

    def kdiffusion_x_to_denoised(self, x, sigma, uc, c, cfg_guidance, t):
        """
        Return ``(denoised, uncond_denoised)`` for latent ``x`` at noise level ``sigma``.

        ``denoised`` uses the CFG-combined noise
        ``noise_uc + cfg_guidance * (noise_c - noise_uc)``; ``uncond_denoised``
        uses the unconditional noise alone.
        """
        xc = self.calculate_input(x, sigma)
        noise_uc, noise_c = self.predict_noise(xc, t, uc, c)
        noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc)
        denoised = self.calculate_denoised(x, noise_pred, sigma)
        uncond_denoised = self.calculate_denoised(x, noise_uc, sigma)
        return denoised, uncond_denoised



@register_solver("dpm++_2m_cfg++")
class DPMpp2mCFGppSolver(StableDiffusion):
    """Text-to-image DPM-Solver++(2M) sampler with CFG++ guidance on a Karras sigma schedule.

    Each step lands on the guided estimate (``denoised``), but the ODE
    derivative and the multistep history use the unconditional estimate
    (``uncond_denoised``).
    """
    @torch.autocast(device_type='cuda', dtype=torch.float16)
    def sample(self, cfg_guidance, prompt=("", ""), callback_fn=None, **kwargs):
        """
        Generate one image from random noise.

        Args:
            cfg_guidance: guidance scale used to combine the null and text noise predictions.
            prompt: ``(null_prompt, prompt)``; index 0 is the null / negative text,
                index 1 the conditioning text. Any indexable pair works.
            callback_fn: optional ``callback_fn(step, timestep, kwargs) -> kwargs``
                hook receiving ``'z0t'``, ``'zt'`` and ``'decode'``. The returned
                ``'zt'`` replaces the current latent.
            **kwargs: ignored.

        Returns:
            torch.Tensor: decoded image clamped to [0, 1], detached and on the CPU.
        """
        # ODE time variable t = -log(sigma).
        t_fn = lambda sigma: sigma.log().neg()
        # Text embedding
        uc, c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1])
        # Karras schedule over the model's sigma range (self.sigmas from __init__), ending in 0.
        sigmas = get_sigmas_karras(len(self.scheduler.timesteps), self.sigmas.min(), self.sigmas.max(), rho=7.)
        # initialize_latent() returns a tensor with requires_grad set. Detach it so
        # autograd does not record every update of x, and the final VAE decode,
        # when no callback is used.
        x = self.initialize_latent(method="random_kdiffusion",
                                   latent_dim=(1, 4, 64, 64),
                                   sigmas=sigmas).detach().to(torch.float16)
        old_denoised = None  # previous step's unconditional estimate (multistep history)
        # Sampling
        pbar = tqdm(self.scheduler.timesteps, desc="SD")
        for i, _ in enumerate(pbar):
            sigma = sigmas[i]
            new_t = self.timestep(sigma).to(self.device)

            with torch.no_grad():
                denoised, uncond_denoised = self.kdiffusion_x_to_denoised(x, sigma, uc, c, cfg_guidance, new_t)

            # solve ODE one step
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i+1])
            h = t_next - t
            if old_denoised is None or sigmas[i+1] == 0:
                # First-order step: first iteration, or the final step onto sigma = 0.
                x = denoised + self.to_d(x, sigmas[i], uncond_denoised) * sigmas[i+1]
            else:
                # Second-order multistep update using the previous unconditional estimate.
                h_last = t - t_fn(sigmas[i-1])
                r = h_last / h
                extra1 = -torch.exp(-h) * uncond_denoised - (-h).expm1() * (uncond_denoised - old_denoised) / (2*r)
                extra2 = torch.exp(-h) * x
                x = denoised + extra1 + extra2
            old_denoised = uncond_denoised

            if callback_fn is not None:
                callback_kwargs = { 'z0t': denoised.detach(),
                                    'zt': x.detach(),
                                    'decode': self.decode}
                callback_kwargs = callback_fn(i, new_t, callback_kwargs)
                denoised = callback_kwargs["z0t"]
                x = callback_kwargs["zt"]

        # Decode the final latent; the output image needs no autograd graph.
        with torch.no_grad():
            img = self.decode(x)
        img = (img / 2 + 0.5).clamp(0, 1)
        return img.detach().cpu()


#############################

if __name__ == "__main__":
    # List the names of every registered solver.
    print(f"Possible solvers: {list(__SOLVER__.keys())}")

Possble solvers: ['dpm++_2m_cfg++']


In [6]:
def test_solver(
    solver,
    cfg_guidance=7.5,
    prompt=["A futuristic cityscape at sunset", "A futuristic cityscape at sunset"],
    num_inference_steps=50,
    save_path="./cfg_plus_test_results/test_image.png",
    verbose=True,
):
    """
    Generate one image with a CFG++ solver and save it to disk.

    Args:
        solver (StableDiffusion): solver instance, e.g. DPMpp2mCFGppSolver.
        cfg_guidance (float): guidance scale forwarded to ``solver.sample``.
        prompt (list): ``[null_prompt, prompt]``. The first entry is the null
            (unconditional) text and the second is the conditioning text.
            NOTE(review): the default uses the same text for both entries. That
            makes the unconditional and conditional predictions identical, so
            guidance has no effect. Pass ``["", text]`` for real guidance.
        num_inference_steps (int): number of sampling steps set on the scheduler.
        save_path (str): output image path. Missing parent directories are created.
        verbose (bool): whether to print progress messages.

    Returns:
        PIL.Image.Image: the generated image.
    """
    # Re-configure the sampling schedule for this run.
    solver.scheduler.set_timesteps(num_inference_steps, device=solver.device)

    # The default save_path points into a directory that may not exist yet.
    # Create it before sampling so a failure does not waste a generation.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    if verbose:
        print(f"CFG++ 테스트 시작 - Prompt: {prompt[1]}, CFG Scale: {cfg_guidance}")
    img = solver.sample(cfg_guidance=cfg_guidance, prompt=prompt)

    # Drop the batch axis of the single-image tensor (values in [0, 1] from
    # DPMpp2mCFGppSolver.sample) and convert it to PIL.
    img_pil = ToPILImage()(img.squeeze(0))
    img_pil.save(save_path)
    if verbose:
        print(f"결과 이미지 저장 완료: {save_path}")

    return img_pil

In [7]:
# Solver 생성
from dataclasses import dataclass
import torch
from PIL import Image
from torchvision.transforms import ToPILImage
import os

@dataclass
class SolverConfig:
    """Settings read by ``StableDiffusion.__init__`` through attribute access.

    Attributes:
        num_sampling: number of sampling steps handed to ``scheduler.set_timesteps``.
    """

    # Extend with further solver parameters as new fields when needed.
    num_sampling: int

solver_config = SolverConfig(
    num_sampling=50  # number of sampling steps
)

solver = get_solver('dpm++_2m_cfg++', solver_config=solver_config, model_key="runwayml/stable-diffusion-v1-5")

# Test prompt and guidance scale
prompt = [
    "",  # null prompt for the unconditional branch
    "A futuristic cityscape at sunset, ultra-detailed, 4K, vibrant colors",
]
cfg_guidance = 0.6  # CFG++ guidance scale

# The initial latent is drawn with torch.randn inside the solver, so seeding the
# global RNG makes the generated image reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

# Generate and save the result image
output_path = os.path.join('./', "cityscape_cfg_plus.png")
result_image = test_solver(
    solver=solver,
    cfg_guidance=cfg_guidance,
    prompt=prompt,
    num_inference_steps=solver_config.num_sampling,  # keep in sync with the solver config
    save_path=output_path,
)
result_image

Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00,  6.48it/s]


CFG++ 테스트 시작 - Prompt: A futuristic cityscape at sunset, ultra-detailed, 4K, vibrant colors, CFG Scale: 0.6


SD: 100%|██████████| 50/50 [00:02<00:00, 17.62it/s]


결과 이미지 저장 완료: ./cityscape_cfg_plus.png
