In [1]:
# Install the required libraries into the kernel's environment.
# NOTE(review): later cells also import `datasets` and `torchvision`, which are
# not installed here — presumably the runtime already provides them; confirm.
%pip install -q diffusers transformers accelerate peft bitsandbytes
# xformers is pinned to an exact version; the other packages above are unpinned.
%pip install xformers==0.0.27.post2

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [None]:
# Load the model in float32   # 1
# VAE dtype float32 (the VAE encodes and decodes images)   # 2
# Mixed precision (AMP) disabled
# Learning rate 1e-5
# Batch size 1
# LoRA dropout 0.0
# Loss computed in float32
# Gradient clipping added
# NaN check added


import torch
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from diffusers import StableDiffusionXLPipeline, DDPMScheduler
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
import os

# ================== Configuration ==================
DATASET_NAME = "lambdalabs/naruto-blip-captions"
MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
OUTPUT_DIR = "./sdxl-lora-output"
RESOLUTION = 512
BATCH_SIZE = 1          # changed: start with 1
LEARNING_RATE = 1e-5    # changed: 1e-8 was too low
GRADIENT_ACCUMULATION_STEPS = 4
EPOCHS = 3
MAX_GRAD_NORM = 1.0

device = "cuda" if torch.cuda.is_available() else "cpu"

# ================== Load the model ==================
print("모델 로딩 중...")
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # changed: load in float32  # 1
    use_safetensors=True,
)

# ================== Configure and apply LoRA ==================
print("LoRA 설정 적용 중...")

lora_config = LoraConfig(
    r=4,
    lora_alpha=16,
    # Attention projections to adapt. The last entry must be exactly
    # "to_out.0": the previous value "to_out. 0" (with a stray space) matched
    # no module name, so the output projection silently got no adapter.
    target_modules=[
        "to_q", "to_k", "to_v", "to_out.0",
    ],
    lora_dropout=0.0,  # changed: dropout removed (stability)
)

unet = pipe.unet
unet = get_peft_model(unet, lora_config)
unet.print_trainable_parameters()

# ================== Memory optimisation ==================
print("메모리 최적화 설정 중...")

# Prefer xformers attention; if enabling it fails for any reason, report the
# error and fall back to attention slicing instead of aborting the cell.
try:
    pipe.enable_xformers_memory_efficient_attention()
    print("xformers 활성화 완료")
except Exception as e:
    print(f"xformers 사용 불가:  {e}")
    pipe.enable_attention_slicing()

pipe.enable_vae_slicing()
torch.cuda.empty_cache()

# ================== Device placement ==================
print("모델 설정 중...")

# Keep the VAE in float32 (stability)
pipe.vae.to(device, dtype=torch.float32)  # 2
pipe.text_encoder.to(device)
pipe.text_encoder_2.to(device)
unet.to(device)

# Freeze the modules that are not trained
pipe.vae.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)
pipe.text_encoder_2.requires_grad_(False)

# Put the VAE in eval mode
pipe.vae.eval()
# ================== SDXL용 헬퍼 함수 ==================
def compute_time_ids(original_size, crops_coords_top_left, target_size, device, dtype):
    """Build the SDXL ``time_ids`` conditioning row.

    The three sequences are joined with ``+`` in the order original size,
    crop offset, target size, and returned as a tensor with a leading
    dimension of 1 on ``device`` with the requested ``dtype``.
    """
    flat_ids = list(original_size + crops_coords_top_left + target_size)
    # Create the flat vector first, then add the leading row dimension.
    return torch.tensor(flat_ids, dtype=dtype, device=device).unsqueeze(0)

def _encode_with(tokenizer, text_encoder, prompt, device):
    """Tokenize ``prompt`` and run ``text_encoder`` on it without gradients.

    The prompt is padded/truncated to ``tokenizer.model_max_length`` and the
    encoder is asked for all hidden states. Returns the encoder's raw output.
    """
    input_ids = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)

    with torch.no_grad():
        return text_encoder(input_ids, output_hidden_states=True)


def encode_prompt_sdxl(pipe, prompt, device):
    """Encode ``prompt`` with both text encoders of ``pipe``.

    Args:
        pipe: object exposing ``tokenizer``, ``tokenizer_2``, ``text_encoder``
            and ``text_encoder_2`` (presumably a StableDiffusionXLPipeline).
        prompt: a string or list of strings to encode.
        device: device the token ids are moved to before encoding.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)`` where ``prompt_embeds`` is the
        second-to-last hidden state of each encoder concatenated along the
        last dimension, and ``pooled_prompt_embeds`` is element ``[0]`` of the
        second encoder's output.
    """
    encoder_out = _encode_with(pipe.tokenizer, pipe.text_encoder, prompt, device)
    encoder_out_2 = _encode_with(pipe.tokenizer_2, pipe.text_encoder_2, prompt, device)

    # Only the second encoder's [0] output is returned; the first encoder's
    # [0] output was previously stored in an unused local and is dropped.
    pooled_prompt_embeds = encoder_out_2[0]

    prompt_embeds = torch.concat(
        [encoder_out.hidden_states[-2], encoder_out_2.hidden_states[-2]],
        dim=-1,
    )
    return prompt_embeds, pooled_prompt_embeds

# ================== Load the dataset ==================
print("데이터셋 로딩 중...")
# Only the first 100 training examples are used.
dataset = load_dataset(DATASET_NAME, split='train[:100]')
print(f'original dataset:  {dataset.column_names}')

# Resize to a RESOLUTION x RESOLUTION square, convert to a tensor, then
# subtract 0.5 and divide by 0.5 (presumably mapping [0, 1] pixels to [-1, 1]).
transform = transforms.Compose([
    transforms.Resize((RESOLUTION, RESOLUTION)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

def preprocess(example):
    """Convert one dataset row into the fields used for training.

    Reads ``example['image']`` (converted to RGB and passed through
    ``transform``) and ``example['text']``; returns them under the keys
    ``'pixel_values'`` and ``'caption'``.
    """
    image = example['image'].convert('RGB')
    return {
        'pixel_values': transform(image),
        'caption': example['text']
    }

# Replace the original columns with the preprocessed ones and expose them in
# torch format to the DataLoader.
dataset = dataset. map(preprocess, remove_columns=dataset.column_names)
dataset.set_format(type='torch', columns=['pixel_values', 'caption'])
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# ================== Training setup ==================
# Only parameters with requires_grad=True are handed to the optimizer.
optimizer = torch.optim. AdamW(
    filter(lambda p: p.requires_grad, unet.parameters()),
    lr=LEARNING_RATE,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8
)
noise_scheduler = DDPMScheduler.from_pretrained(MODEL_NAME, subfolder='scheduler')

# Mixed precision disabled (stability first).
# NOTE(review): this flag is not read anywhere in the visible cell.
use_amp = False

# ================== Training loop ==================
print("LoRA 학습 시작!")
unet.train()

# Parameters that receive gradients, collected once so gradient clipping does
# not rebuild the filter on every optimizer step.
trainable_params = [p for p in unet.parameters() if p.requires_grad]

# The SDXL size/crop conditioning is the same for every sample (fixed
# resolution, no crop), so build one row up front and repeat it per batch.
base_time_ids = compute_time_ids(
    (RESOLUTION, RESOLUTION),
    (0, 0),
    (RESOLUTION, RESOLUTION),
    device,
    unet.dtype,
)
num_batches = len(dataloader)

for epoch in range(EPOCHS):
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    total_loss = 0.0
    valid_steps = 0
    accumulated = 0  # micro-batches back-propagated since the last optimizer step

    for step, batch in enumerate(progress_bar):

        pixel_values = batch["pixel_values"].to(device, dtype=torch.float32)

        # Encode images to latents (float32), without tracking gradients
        with torch.no_grad():
            latent_dist = pipe.vae.encode(pixel_values).latent_dist
            latents = latent_dist.sample()
            latents = latents * pipe.vae.config.scaling_factor
            latents = latents.to(dtype=unet.dtype)

        # Add noise at a random timestep per sample
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],), device=device
        ).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Normalise captions to a list of strings
        captions = batch["caption"]
        if isinstance(captions, torch.Tensor):
            # NOTE(review): str() of tensor elements is unlikely to be a
            # meaningful caption; kept as in the original — confirm this
            # branch is ever reached.
            captions = [str(c) for c in captions]
        elif isinstance(captions, str):
            captions = [captions]
        else:
            captions = list(captions)

        # SDXL text embeddings
        prompt_embeds, pooled_prompt_embeds = encode_prompt_sdxl(pipe, captions, device)
        prompt_embeds = prompt_embeds.to(dtype=unet.dtype)
        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=unet.dtype)

        # SDXL time_ids, one row per sample in the batch
        add_time_ids = base_time_ids.repeat(latents.shape[0], 1)

        # UNet noise prediction
        noise_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={
                "text_embeds":  pooled_prompt_embeds,
                "time_ids": add_time_ids
            },
            return_dict=False
        )[0]

        # Loss computed in float32
        loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())

        # NaN/inf check: drop the whole current accumulation window and start
        # a fresh one, so the next optimizer step again averages a full window.
        if not torch.isfinite(loss):
            print(f"Step {step}: nan/inf detected, skipping...")
            optimizer.zero_grad()
            accumulated = 0
            continue

        # Scale only the gradient contribution; `loss` itself stays unscaled
        # so the reported values are the true per-batch loss.
        (loss / GRADIENT_ACCUMULATION_STEPS).backward()
        accumulated += 1

        # Step once a full window is accumulated, and also on the final batch
        # so a partial window is applied instead of leaking into the next epoch.
        if accumulated == GRADIENT_ACCUMULATION_STEPS or step + 1 == num_batches:
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(trainable_params, MAX_GRAD_NORM)
            optimizer.step()
            optimizer.zero_grad()
            accumulated = 0

        loss_value = loss.item()
        total_loss += loss_value
        valid_steps += 1
        progress_bar.set_postfix({"loss": f"{loss_value:.4f}"})

    if valid_steps > 0:
        avg_loss = total_loss / valid_steps
        print(f"Epoch {epoch+1} 완료 - 평균 Loss:  {avg_loss:.4f}")
    else:
        print(f"Epoch {epoch+1} 완료 - 유효한 step 없음")

# ================== Save the LoRA weights ==================
print("LoRA 가중치 저장 중...")
# Create the output directory if needed, then save the PEFT-wrapped UNet there.
os.makedirs(OUTPUT_DIR, exist_ok=True)
unet.save_pretrained(OUTPUT_DIR)
print(f"저장 완료: {OUTPUT_DIR}")

print("LoRA 학습 완료!")

  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
Keyword arguments {'dtype': torch.float16} are not expected by StableDiffusionXLPipeline and will be ignored.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

In [None]:
# 추론...
import torch
from diffusers import StableDiffusionXLPipeline
from peft import PeftModel
import os
# Base model identifier and the directory holding the saved LoRA adapter.
MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
LORA_PATH = "./sdxl-lora-output"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the base pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch. float32,  # changed: load in float32  # 1
    use_safetensors=True,
)
pipe.to(device)
# Load the LoRA weights by wrapping the pipeline's UNet with the saved adapter
pipe.unet = PeftModel.from_pretrained(pipe.unet,LORA_PATH)
# Memory optimisation: try xformers attention, fall back to attention slicing
try:
    pipe.enable_xformers_memory_efficient_attention()
    print("xformers 활성화 완료")
except Exception as e: 
    print(f"xformers 사용 불가:  {e}")
    pipe.enable_attention_slicing()

pipe.enable_vae_slicing()
torch.cuda.empty_cache()

# Generate an image
prompt = "a ninja cat in naruto anime style, highly detailed"
negative_prompt = "blurry, low quality, distorted, ugly"

# A seeded generator on `device` is passed so the sampling is repeatable.
image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=30,
    guidance_scale=7.5,
    width=512,
    height=512,
    generator=torch.Generator(device).manual_seed(42)
).images[0]

# Save the image
output_path = "generated_image.png"
image.save(output_path)

모델 전체 저장 (예비용으로 LoRA 가중치도 백업)

In [None]:
# Load the model in float32   # 1
# VAE dtype float32 (the VAE encodes and decodes images)   # 2
# Mixed precision (AMP) disabled
# Learning rate 1e-5
# Batch size 1
# LoRA dropout 0.0
# Loss computed in float32
# Gradient clipping added
# NaN check added


import torch
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from diffusers import StableDiffusionXLPipeline, DDPMScheduler
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
import os

# ================== Configuration ==================
DATASET_NAME = "lambdalabs/naruto-blip-captions"
MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
LORA_OUTPUT_DIR = "./sdxl-lora-output"    # LoRA only (backup)
MERGED_OUTPUT_DIR = "./sdxl-merged-output" # whole merged model
RESOLUTION = 512
BATCH_SIZE = 1          # changed: start with 1
LEARNING_RATE = 1e-5    # changed: 1e-8 was too low
GRADIENT_ACCUMULATION_STEPS = 4
EPOCHS = 3
MAX_GRAD_NORM = 1.0

device = "cuda" if torch.cuda.is_available() else "cpu"

# ================== Load the model ==================
print("모델 로딩 중...")
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # changed: load in float32  # 1
    use_safetensors=True,
)

# ================== Configure and apply LoRA ==================
print("LoRA 설정 적용 중...")

lora_config = LoraConfig(
    r=4,
    lora_alpha=16,
    # Attention projections to adapt. The last entry must be exactly
    # "to_out.0": the previous value "to_out. 0" (with a stray space) matched
    # no module name, so the output projection silently got no adapter.
    target_modules=[
        "to_q", "to_k", "to_v", "to_out.0",
    ],
    lora_dropout=0.0,  # changed: dropout removed (stability)
)

unet = pipe.unet
unet = get_peft_model(unet, lora_config)
unet.print_trainable_parameters()

# ================== Memory optimisation ==================
print("메모리 최적화 설정 중...")

# Prefer xformers attention; if enabling it fails for any reason, report the
# error and fall back to attention slicing instead of aborting the cell.
try:
    pipe.enable_xformers_memory_efficient_attention()
    print("xformers 활성화 완료")
except Exception as e:
    print(f"xformers 사용 불가:  {e}")
    pipe.enable_attention_slicing()

pipe.enable_vae_slicing()
torch.cuda.empty_cache()

# ================== Device placement ==================
print("모델 설정 중...")

# Keep the VAE in float32 (stability)
pipe.vae.to(device, dtype=torch.float32)  # 2
pipe.text_encoder.to(device)
pipe.text_encoder_2.to(device)
unet.to(device)

# Freeze the modules that are not trained
pipe.vae.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)
pipe.text_encoder_2.requires_grad_(False)

# Put the VAE in eval mode
pipe.vae.eval()

# ================== SDXL용 헬퍼 함수 ==================
def compute_time_ids(original_size, crops_coords_top_left, target_size, device, dtype):
    """Build the SDXL ``time_ids`` conditioning row.

    The three sequences are joined with ``+`` in the order original size,
    crop offset, target size, and returned as a tensor with a leading
    dimension of 1 on ``device`` with the requested ``dtype``.
    """
    flat_ids = list(original_size + crops_coords_top_left + target_size)
    # Create the flat vector first, then add the leading row dimension.
    return torch.tensor(flat_ids, dtype=dtype, device=device).unsqueeze(0)

def _encode_with(tokenizer, text_encoder, prompt, device):
    """Tokenize ``prompt`` and run ``text_encoder`` on it without gradients.

    The prompt is padded/truncated to ``tokenizer.model_max_length`` and the
    encoder is asked for all hidden states. Returns the encoder's raw output.
    """
    input_ids = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)

    with torch.no_grad():
        return text_encoder(input_ids, output_hidden_states=True)


def encode_prompt_sdxl(pipe, prompt, device):
    """Encode ``prompt`` with both text encoders of ``pipe``.

    Args:
        pipe: object exposing ``tokenizer``, ``tokenizer_2``, ``text_encoder``
            and ``text_encoder_2`` (presumably a StableDiffusionXLPipeline).
        prompt: a string or list of strings to encode.
        device: device the token ids are moved to before encoding.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)`` where ``prompt_embeds`` is the
        second-to-last hidden state of each encoder concatenated along the
        last dimension, and ``pooled_prompt_embeds`` is element ``[0]`` of the
        second encoder's output.
    """
    encoder_out = _encode_with(pipe.tokenizer, pipe.text_encoder, prompt, device)
    encoder_out_2 = _encode_with(pipe.tokenizer_2, pipe.text_encoder_2, prompt, device)

    # Only the second encoder's [0] output is returned; the first encoder's
    # [0] output was previously stored in an unused local and is dropped.
    pooled_prompt_embeds = encoder_out_2[0]

    prompt_embeds = torch.concat(
        [encoder_out.hidden_states[-2], encoder_out_2.hidden_states[-2]],
        dim=-1,
    )
    return prompt_embeds, pooled_prompt_embeds

# ================== Load the dataset ==================
print("데이터셋 로딩 중...")
# Only the first 100 training examples are used.
dataset = load_dataset(DATASET_NAME, split='train[:100]')
print(f'original dataset:  {dataset.column_names}')

# Resize to a RESOLUTION x RESOLUTION square, convert to a tensor, then
# subtract 0.5 and divide by 0.5 (presumably mapping [0, 1] pixels to [-1, 1]).
transform = transforms.Compose([
    transforms.Resize((RESOLUTION, RESOLUTION)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

def preprocess(example):
    """Convert one dataset row into the fields used for training.

    Reads ``example['image']`` (converted to RGB and passed through
    ``transform``) and ``example['text']``; returns them under the keys
    ``'pixel_values'`` and ``'caption'``.
    """
    image = example['image'].convert('RGB')
    return {
        'pixel_values': transform(image),
        'caption': example['text']
    }

# Replace the original columns with the preprocessed ones and expose them in
# torch format to the DataLoader.
dataset = dataset. map(preprocess, remove_columns=dataset.column_names)
dataset.set_format(type='torch', columns=['pixel_values', 'caption'])
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# ================== Training setup ==================
# Only parameters with requires_grad=True are handed to the optimizer.
optimizer = torch.optim. AdamW(
    filter(lambda p: p.requires_grad, unet.parameters()),
    lr=LEARNING_RATE,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8
)
noise_scheduler = DDPMScheduler.from_pretrained(MODEL_NAME, subfolder='scheduler')

# Mixed precision disabled (stability first).
# NOTE(review): this flag is not read anywhere in the visible cell.
use_amp = False

# ================== Training loop ==================
print("LoRA 학습 시작!")
unet.train()

# Parameters that receive gradients, collected once so gradient clipping does
# not rebuild the filter on every optimizer step.
trainable_params = [p for p in unet.parameters() if p.requires_grad]

# The SDXL size/crop conditioning is the same for every sample (fixed
# resolution, no crop), so build one row up front and repeat it per batch.
base_time_ids = compute_time_ids(
    (RESOLUTION, RESOLUTION),
    (0, 0),
    (RESOLUTION, RESOLUTION),
    device,
    unet.dtype,
)
num_batches = len(dataloader)

for epoch in range(EPOCHS):
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    total_loss = 0.0
    valid_steps = 0
    accumulated = 0  # micro-batches back-propagated since the last optimizer step

    for step, batch in enumerate(progress_bar):

        pixel_values = batch["pixel_values"].to(device, dtype=torch.float32)

        # Encode images to latents (float32), without tracking gradients
        with torch.no_grad():
            latent_dist = pipe.vae.encode(pixel_values).latent_dist
            latents = latent_dist.sample()
            latents = latents * pipe.vae.config.scaling_factor
            latents = latents.to(dtype=unet.dtype)

        # Add noise at a random timestep per sample
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],), device=device
        ).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Normalise captions to a list of strings
        captions = batch["caption"]
        if isinstance(captions, torch.Tensor):
            # NOTE(review): str() of tensor elements is unlikely to be a
            # meaningful caption; kept as in the original — confirm this
            # branch is ever reached.
            captions = [str(c) for c in captions]
        elif isinstance(captions, str):
            captions = [captions]
        else:
            captions = list(captions)

        # SDXL text embeddings
        prompt_embeds, pooled_prompt_embeds = encode_prompt_sdxl(pipe, captions, device)
        prompt_embeds = prompt_embeds.to(dtype=unet.dtype)
        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=unet.dtype)

        # SDXL time_ids, one row per sample in the batch
        add_time_ids = base_time_ids.repeat(latents.shape[0], 1)

        # UNet noise prediction
        noise_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={
                "text_embeds":  pooled_prompt_embeds,
                "time_ids": add_time_ids
            },
            return_dict=False
        )[0]

        # Loss computed in float32
        loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())

        # NaN/inf check: drop the whole current accumulation window and start
        # a fresh one, so the next optimizer step again averages a full window.
        if not torch.isfinite(loss):
            print(f"Step {step}: nan/inf detected, skipping...")
            optimizer.zero_grad()
            accumulated = 0
            continue

        # Scale only the gradient contribution; `loss` itself stays unscaled
        # so the reported values are the true per-batch loss.
        (loss / GRADIENT_ACCUMULATION_STEPS).backward()
        accumulated += 1

        # Step once a full window is accumulated, and also on the final batch
        # so a partial window is applied instead of leaking into the next epoch.
        if accumulated == GRADIENT_ACCUMULATION_STEPS or step + 1 == num_batches:
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(trainable_params, MAX_GRAD_NORM)
            optimizer.step()
            optimizer.zero_grad()
            accumulated = 0

        loss_value = loss.item()
        total_loss += loss_value
        valid_steps += 1
        progress_bar.set_postfix({"loss": f"{loss_value:.4f}"})

    if valid_steps > 0:
        avg_loss = total_loss / valid_steps
        print(f"Epoch {epoch+1} 완료 - 평균 Loss:  {avg_loss:.4f}")
    else:
        print(f"Epoch {epoch+1} 완료 - 유효한 step 없음")

# ================== Save the LoRA weights ==================
print("LoRA 가중치 저장 중...")
# Save the PEFT-wrapped UNet first, before it is merged below.
os.makedirs(LORA_OUTPUT_DIR, exist_ok=True)
unet.save_pretrained(LORA_OUTPUT_DIR)
print(f"저장 완료: {LORA_OUTPUT_DIR}")
# ================== Merge LoRA and save the whole model ==================
# Merge the LoRA weights into the base model
unet = unet.merge_and_unload()
# Assign the merged UNet back to the pipeline
pipe.unet = unet
# Save the whole pipeline
os.makedirs(MERGED_OUTPUT_DIR, exist_ok=True)
pipe.save_pretrained(MERGED_OUTPUT_DIR,safe_serialization=True)
print(f"저장 완료: {MERGED_OUTPUT_DIR}")

print("LoRA 학습 완료!")

In [None]:
# 추론 - 모델 전체


import torch
from diffusers import StableDiffusionXLPipeline
import zipfile

import os  # local import: this cell's import lines only bring in torch, diffusers and zipfile

MODEL_ARCHIVE = "sdxl-merged-model.zip"
MODEL_DIR = "./sdxl-merged-model"
# NOTE(review): the training cell saves the merged pipeline to
# "./sdxl-merged-output"; this cell expects an archive that unpacks to
# "./sdxl-merged-model" — presumably zipped/renamed outside the notebook; confirm.

# Extract the archive only when the model directory is not already present, so
# re-running the cell neither needs the zip nor re-extracts it.
if not os.path.isdir(MODEL_DIR):
    if not os.path.isfile(MODEL_ARCHIVE):
        raise FileNotFoundError(
            f"Neither the model directory {MODEL_DIR!r} nor the archive "
            f"{MODEL_ARCHIVE!r} exists."
        )
    print("압축 해제 중...")
    with zipfile.ZipFile(MODEL_ARCHIVE, 'r') as zip_ref:
        zip_ref.extractall("./")
    print("압축 해제 완료")

# Load the model (no download)
print("모델 로딩 중...")
# Fall back to CPU like the other cells; float16 is only used on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_DIR,
    torch_dtype=torch_dtype,
    local_files_only=True  # load from local files only
)
pipe.to(device)
pipe.enable_vae_slicing()
torch.cuda.empty_cache()

# Generate an image
print("이미지 생성 중...")
image = pipe(
    prompt="a ninja cat in naruto anime style, highly detailed",
    negative_prompt="blurry, low quality, distorted, ugly",
    num_inference_steps=30,
    guidance_scale=7.5,
    width=512,
    height=512,
    generator=torch.Generator(device=device).manual_seed(42)
).images[0]

image.save("output.png")
print("저장 완료:  output.png")
# Bare last expression so the notebook renders the generated image.
image