# 스테이블 디퓨전으로 이미지 생성 서비스 만들기
- 간단한 스케치를 기반으로 이미지 생성
- 스케치가 되어 있는 이미지를 업로드해서 생성

In [2]:
# !pip install diffusers

In [9]:
import os
from typing import IO
import gradio as gr
import requests
from tqdm import tqdm
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
import torch

In [2]:
# Display size of the image widgets in pixels (also passed as `shape` below).
WIDTH = 512
HEIGHT = 512

# Prototype layout: prompt inputs plus a "Canvas" tab (draw) and a "File" tab
# (upload). The Generate buttons are not wired to any callback in this cell.
# NOTE(review): `source`, `tool`, `shape` and `brush_radius` are Gradio 3.x
# gr.Image arguments (3.40.0 is installed per the cell output); they are not
# available on gr.Image in Gradio 4 — pin gradio<4 or migrate before upgrading.
with gr.Blocks() as app :
    # Positive / negative prompt inputs
    gr.Markdown("## 프롬프트 입력")
    with gr.Row() :
        prompt = gr.Textbox(label = 'Prompt')
    with gr.Row() :
        n_prompt = gr.Textbox(label = "negative prompt")
        
    # Sketch-to-image inputs
    gr.Markdown('## 스케치 to 이미지 생성')
    with gr.Row() :
        with gr.Column() :
            with gr.Tab("Canvas") :
                with gr.Row() :
                    # Free-hand drawing canvas; type='pil' delivers a PIL image to callbacks.
                    canvas = gr.Image(
                        label = 'Draw',
                        source = 'canvas',
                        image_mode = 'RGB',
                        tool = 'color-sketch',
                        interactive = True,
                        width = WIDTH,
                        height = HEIGHT,
                        shape = (WIDTH, HEIGHT),
                        brush_radius = 20,
                        type = 'pil'
                    )
                with gr.Row() :
                    canvas_run_btn = gr.Button(value = 'Generate')

            with gr.Tab("File") :
                with gr.Row() :
                    # Uploaded image that can also be drawn on (color-sketch tool).
                    file = gr.Image(
                        label = 'Upload',
                        source = 'upload',
                        image_mode = 'RGB',
                        tool = 'color-sketch',
                        interactive = True,
                        width = WIDTH,
                        height = HEIGHT,
                        shape = (WIDTH, HEIGHT),
                        type = 'pil'
                    )
                with gr.Row() :
                    file_run_btn = gr.Button(value = 'Generate')

IMPORTANT: You are using gradio version 3.40.0, however version 4.29.0 is available, please upgrade.
--------


In [6]:
# Serve the UI in a separate browser tab (inline=False) and create a temporary public share link (share=True).
app.launch(inline = False, share = True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://a8b6f5f98efdf2363c.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




In [7]:
# Stop the Gradio server started by app.launch() above (frees its port).
app.close()

Closing server running on port: 7860


## 모델 다운로드 UI 구현하기

In [5]:
# Minimal model-download UI: the button passes the URL textbox value to
# download_model() and shows the returned local file in a gr.File component.
# NOTE: download_model is defined in a later cell; that cell must run first.
with gr.Blocks() as app :
    gr.Markdown('## 모델 다운로드')
    with gr.Row() :
        model_url = gr.Textbox(label = '모델 URL', placeholder = 'http://civitai.com/')
        download_model_btn = gr.Button(value = '모델 다운로드')
        
    with gr.Row() :
        model_file = gr.File(label = '모델 File')
        
    # inputs: [model_url] -> download_model(url); output: returned path -> model_file
    download_model_btn.click(
        download_model,
        [model_url],
        [model_file]
    )

IMPORTANT: You are using gradio version 3.40.0, however version 4.29.0 is available, please upgrade.
--------


In [20]:
# Serve the UI in a separate browser tab (inline=False) and create a temporary public share link (share=True).
app.launch(inline = False, share = True)

Running on local URL:  http://127.0.0.1:7860

Thanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB
Running on public URL: https://f0b1d07ba918cead6e.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




models/disneyPixarCartoon_v10.safetensors: 100%|██████████████████████████████████| 3.95G/3.95G [11:24<00:00, 6.20MiB/s]


[info] file downloaded : models/disneyPixarCartoon_v10.safetensors


In [18]:
# Stop the Gradio server started by app.launch() above (frees its port).
app.close()

Closing server running on port: 7860


## 모델 다운로드 기능 구현하기

In [4]:
import os
# 디렉토리 안에 있는 파일명을 리스트로 만들어주는 라이브러리
import glob

# Module-level state shared by the functions in this notebook:
#   MODEL_PATH - checkpoint to load; set by download_model() or discovered
#                by init_pipeline() in ./models
#   PIPELINE   - loaded img2img pipeline; None until init_pipeline() succeeds.
#                Defined here so sketch_to_image() can check it before a model
#                is loaded instead of failing with NameError.
MODEL_PATH = None
PIPELINE = None

# Download a file from a URL with a progress bar
def download_from_url(url, file_path, chunk_size = 1024, timeout = (10, 60)) :
    """Stream ``url`` into ``file_path`` while showing a tqdm progress bar.

    The body is first written to ``<file_path>.part`` and renamed to
    ``file_path`` only after the whole response has been received. An
    interrupted download therefore never leaves a truncated file under the
    final name, which ``download_model`` would otherwise treat as a finished
    download. It also never shows up in a ``*.safetensors`` glob.

    Args:
        url: HTTP(S) URL to fetch.
        file_path: Destination path; its parent directory must already exist.
        chunk_size: Bytes requested per ``iter_content`` chunk.
        timeout: ``requests`` timeout, a (connect, read) tuple in seconds. The
            read timeout bounds the wait between received bytes, not the whole
            transfer, so large model files are not cut off.

    Raises:
        requests.RequestException: on connection errors, timeouts or HTTP error
            status codes. Any partial ``.part`` file is removed before re-raising.
    """
    tmp_path = f'{file_path}.part'
    try :
        # `with` releases the streamed connection even if writing fails midway.
        with requests.get(url, stream = True, timeout = timeout) as resp :
            resp.raise_for_status()
            total = int(resp.headers.get('content-length', 0))  # 0 when the server omits it
            with open(tmp_path, 'wb') as file, tqdm(desc = file_path, total = total, unit = 'iB',
                                                    unit_scale = True, unit_divisor = 1024) as bar :
                for data in resp.iter_content(chunk_size = chunk_size) :
                    bar.update(file.write(data))
        # Publish the file under its final name only once it is complete.
        os.replace(tmp_path, file_path)
    except Exception as e :
        print(f'[error] {e}')
        raise
    finally :
        # On any failure (including KeyboardInterrupt) drop the incomplete file;
        # after a successful os.replace() the temp path no longer exists.
        if os.path.exists(tmp_path) :
            os.remove(tmp_path)

# Download a Civitai model and remember its local path
def download_model(url: str) -> str :
    """Download the first listed version of a Civitai model into ``models/``.

    The numeric model id is taken from ``url`` and the model metadata is
    fetched from the Civitai REST API. The primary file of
    ``modelVersions[0]`` is then downloaded with ``download_from_url``. If that
    file already exists locally, the download is skipped.

    Args:
        url: Civitai model page URL such as
            ``https://civitai.com/models/12345/some-name``. The scheme may be
            http or https, the host may include ``www.``, and a query string is
            allowed. A bare model id is also accepted.

    Returns:
        Local path ``models/<filename>``. The same path is stored in the
        module-level ``MODEL_PATH`` so ``init_pipeline`` loads this model.

    Raises:
        ValueError: if no model id is found in ``url`` or the API lists no
            usable file.
        requests.RequestException: if the API request fails or returns an HTTP
            error status.
    """
    import re

    global MODEL_PATH  # remembered for init_pipeline()

    # Match '/models/<digits>' anywhere (any scheme/host, trailing slug or
    # query string) or a bare numeric id. Replacing one exact prefix broke on
    # 'http://civitai.com/models/...' — the placeholder's own form.
    match = re.search(r'(?:^|/models/)(\d+)', url.strip())
    if match is None :
        raise ValueError(f'could not find a Civitai model id in {url!r}')
    model_id = match.group(1)

    try :
        # Metadata request: a short timeout keeps a hung call from blocking the
        # Gradio worker; raise_for_status avoids parsing an HTTP error page.
        response = requests.get(f'https://civitai.com/api/v1/models/{model_id}', timeout = 60)
        response.raise_for_status()
    except Exception as e :
        print(f'Error : {e}')
        raise

    version = response.json()['modelVersions'][0]
    files = version.get('files') or []
    if not files :
        raise ValueError(f'model {model_id} has no downloadable files')
    # A version can ship several files; use the one flagged primary so the saved
    # file name matches the file that is actually downloaded.
    model_file = next((f for f in files if f.get('primary')), files[0])
    download_url = model_file.get('downloadUrl') or version['downloadUrl']
    # The name comes from a remote service: keep only the base name so it cannot
    # point outside the models/ directory (e.g. '../something').
    filename = os.path.basename(model_file['name'])
    if not filename :
        raise ValueError(f'model {model_id} returned an empty file name')

    file_path = f'models/{filename}'
    if os.path.exists(file_path) :
        print(f'[info] File already exists : {file_path}')
        MODEL_PATH = file_path
        return file_path

    os.makedirs('models', exist_ok = True)
    download_from_url(download_url, file_path)
    print(f'[info] file downloaded : {file_path}')

    MODEL_PATH = file_path
    return file_path

# Find the most recently modified model file in a folder (e.g. ./models)
def find_latest_model_in_directory(directory) :
    """Return the most recently modified ``*.safetensors`` file in ``directory``.

    Args:
        directory: Folder to search (not recursive); a trailing separator is fine.

    Returns:
        Path of the newest ``.safetensors`` file by modification time, or
        ``None`` if there is none (including when the folder does not exist).
    """
    # glob.escape stops characters such as '[' in the folder name from being
    # treated as wildcards (which silently matched nothing). os.path.join avoids
    # a doubled separator when `directory` already ends with '/'.
    pattern = os.path.join(glob.escape(directory), '*.safetensors')
    model_files = glob.glob(pattern)
    if not model_files :
        return None
    return max(model_files, key = os.path.getmtime)

## 다운로드한 모델 불러와서 초기화하기

In [13]:
def init_pipeline() :
    """Load the Stable Diffusion img2img pipeline from a local checkpoint.

    The checkpoint is ``MODEL_PATH`` when a model was downloaded in this
    session. Otherwise the newest ``*.safetensors`` file in ``./models`` is
    used and remembered in ``MODEL_PATH``. On success the pipeline is stored
    in the module-level ``PIPELINE``.

    Returns:
        Status text for the "Model Load Check" textbox: ``'Model Loaded!'`` on
        success, otherwise a message starting with ``'Error'``.
    """
    global MODEL_PATH, PIPELINE

    if MODEL_PATH is None :
        MODEL_PATH = find_latest_model_in_directory('./models/')
    if MODEL_PATH is None :
        return "Error: No model found in ./models"

    # float16 is not reliably supported for inference on CPU, so half precision
    # is used only on a CUDA GPU; otherwise load in float32 on the CPU.
    use_cuda = torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'
    dtype = torch.float16 if use_cuda else torch.float32

    try :
        PIPELINE = StableDiffusionImg2ImgPipeline.from_single_file(
            MODEL_PATH,
            torch_dtype = dtype,
            variant = 'fp16' if use_cuda else 'fp32',
            use_safetensors = True,
        ).to(device)
    except Exception as e :
        print(f'[error] {e}')
        # Report the failure in the status textbox instead of returning None.
        return f'Error: failed to load {MODEL_PATH}: {e}'

    print('[info] initiallized pipeline')
    return 'Model Loaded!'

In [14]:
# Minimal model-loading UI: the button calls init_pipeline() and the status
# string it returns is displayed in the textbox.
with gr.Blocks() as app :
    gr.Markdown('## 모델 불러오기')
    with gr.Row() :
        load_model_btn = gr.Button(value = '모델 불러오기')
    with gr.Row() :
        is_model_check = gr.Textbox(label = 'Model Load Check', value = 'model not loaded')
    
    # No input components (None); init_pipeline()'s return value goes to is_model_check.
    load_model_btn.click(
        init_pipeline,
        None,
        [is_model_check]
    )

IMPORTANT: You are using gradio version 3.40.0, however version 4.29.0 is available, please upgrade.
--------


In [15]:
# Enable Gradio's request queue, then serve in a separate tab with a temporary public share link.
app.queue().launch(inline = False, share = True)

Running on local URL:  http://127.0.0.1:7861
Running on public URL: https://2023fa78bcd092a206.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




Fetching 11 files: 100%|█████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 95522.45it/s]
Some weights of the model checkpoint were not used when initializing CLIPTextModel: 
 ['text_model.embeddings.position_ids']
Loading pipeline components...: 100%|█████████████████████████████████████████████████████| 6/6 [00:29<00:00,  4.89s/it]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com

[info] initiallized pipeline


In [16]:
# Stop the Gradio server started by launch() above (frees its port).
app.close()

Closing server running on port: 7861


## 스케치 to 이미지 생성기능 구현

In [17]:
def sketch_to_image(sketch, prompt, negative_prompt) :
    """Turn a sketch into four images with the loaded img2img pipeline.

    Args:
        sketch: Image from the Canvas or File ``gr.Image`` component
            (``type='pil'``); ``None`` when nothing was drawn or uploaded.
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt (may be empty).

    Returns:
        The generated images (``.images`` of the pipeline output), shown in
        the output ``gr.Gallery``.

    Raises:
        gr.Error: if the pipeline is not loaded, no sketch was given, or
            generation fails. Raising lets Gradio display the message; a
            gallery cannot render a returned string or exception object.
    """
    # PIPELINE is a module-level global that is only assigned by
    # init_pipeline(), so look it up without assuming it exists.
    pipeline = globals().get('PIPELINE')
    if pipeline is None :
        raise gr.Error("error! pipeline is not initialized")
    if sketch is None :
        raise gr.Error("error! no sketch image was provided")

    # The pipeline is called with lists: one copy of the sketch per prompt.
    prompts = [prompt]
    negative_prompts = [negative_prompt]
    images = [sketch] * len(prompts)

    try :
        result = pipeline(
            image = images,
            prompt = prompts,
            negative_prompt = negative_prompts,
            # Module-level canvas size (lowercase height/width were undefined).
            # NOTE(review): img2img generally sizes output from the input image,
            # so these may be ignored by the installed diffusers — confirm.
            height = HEIGHT,
            width = WIDTH,
            num_images_per_prompt = 4,
            num_inference_steps = 20,
            strength = 0.7,
        ).images
    except Exception as e :
        print(e)
        raise gr.Error(f'image generation failed: {e}') from e

    # TODO: when running on CUDA, torch.cuda.empty_cache() could be called here
    # to release cached GPU memory between requests.
    return result

# 이미지 생성 전체 코드

In [1]:
import os
from typing import IO
import gradio as gr
import requests
from tqdm import tqdm
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
import torch
# 디렉토리 안에 있는 파일명을 리스트로 만들어주는 라이브러리
import glob

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pixel size of the sketch/upload image widgets (also passed to the pipeline call).
WIDTH, HEIGHT = 512, 512

# Module-level state shared by the functions that follow:
#   MODEL_PATH - checkpoint to load (set by download_model / init_pipeline)
#   PIPELINE   - loaded img2img pipeline, None until init_pipeline() succeeds
MODEL_PATH = PIPELINE = None

In [7]:
# Download a file from a URL with a progress bar
def download_from_url(url, file_path, chunk_size = 1024, timeout = (10, 60)) :
    """Stream ``url`` into ``file_path`` while showing a tqdm progress bar.

    The body is first written to ``<file_path>.part`` and renamed to
    ``file_path`` only after the whole response has been received. An
    interrupted download therefore never leaves a truncated file under the
    final name, which ``download_model`` would otherwise treat as a finished
    download. It also never shows up in a ``*.safetensors`` glob.

    Args:
        url: HTTP(S) URL to fetch.
        file_path: Destination path; its parent directory must already exist.
        chunk_size: Bytes requested per ``iter_content`` chunk.
        timeout: ``requests`` timeout, a (connect, read) tuple in seconds. The
            read timeout bounds the wait between received bytes, not the whole
            transfer, so large model files are not cut off.

    Raises:
        requests.RequestException: on connection errors, timeouts or HTTP error
            status codes. Any partial ``.part`` file is removed before re-raising.
    """
    tmp_path = f'{file_path}.part'
    try :
        # `with` releases the streamed connection even if writing fails midway.
        with requests.get(url, stream = True, timeout = timeout) as resp :
            resp.raise_for_status()
            total = int(resp.headers.get('content-length', 0))  # 0 when the server omits it
            with open(tmp_path, 'wb') as file, tqdm(desc = file_path, total = total, unit = 'iB',
                                                    unit_scale = True, unit_divisor = 1024) as bar :
                for data in resp.iter_content(chunk_size = chunk_size) :
                    bar.update(file.write(data))
        # Publish the file under its final name only once it is complete.
        os.replace(tmp_path, file_path)
    except Exception as e :
        print(f'[error] {e}')
        raise
    finally :
        # On any failure (including KeyboardInterrupt) drop the incomplete file;
        # after a successful os.replace() the temp path no longer exists.
        if os.path.exists(tmp_path) :
            os.remove(tmp_path)

# Download a Civitai model and remember its local path
def download_model(url: str) -> str :
    """Download the first listed version of a Civitai model into ``models/``.

    The numeric model id is taken from ``url`` and the model metadata is
    fetched from the Civitai REST API. The primary file of
    ``modelVersions[0]`` is then downloaded with ``download_from_url``. If that
    file already exists locally, the download is skipped.

    Args:
        url: Civitai model page URL such as
            ``https://civitai.com/models/12345/some-name``. The scheme may be
            http or https, the host may include ``www.``, and a query string is
            allowed. A bare model id is also accepted.

    Returns:
        Local path ``models/<filename>``. The same path is stored in the
        module-level ``MODEL_PATH`` so ``init_pipeline`` loads this model.

    Raises:
        ValueError: if no model id is found in ``url`` or the API lists no
            usable file.
        requests.RequestException: if the API request fails or returns an HTTP
            error status.
    """
    import re

    global MODEL_PATH  # remembered for init_pipeline()

    # Match '/models/<digits>' anywhere (any scheme/host, trailing slug or
    # query string) or a bare numeric id. Replacing one exact prefix broke on
    # 'http://civitai.com/models/...' — the placeholder's own form.
    match = re.search(r'(?:^|/models/)(\d+)', url.strip())
    if match is None :
        raise ValueError(f'could not find a Civitai model id in {url!r}')
    model_id = match.group(1)

    try :
        # Metadata request: a short timeout keeps a hung call from blocking the
        # Gradio worker; raise_for_status avoids parsing an HTTP error page.
        response = requests.get(f'https://civitai.com/api/v1/models/{model_id}', timeout = 60)
        response.raise_for_status()
    except Exception as e :
        print(f'Error : {e}')
        raise

    version = response.json()['modelVersions'][0]
    files = version.get('files') or []
    if not files :
        raise ValueError(f'model {model_id} has no downloadable files')
    # A version can ship several files; use the one flagged primary so the saved
    # file name matches the file that is actually downloaded.
    model_file = next((f for f in files if f.get('primary')), files[0])
    download_url = model_file.get('downloadUrl') or version['downloadUrl']
    # The name comes from a remote service: keep only the base name so it cannot
    # point outside the models/ directory (e.g. '../something').
    filename = os.path.basename(model_file['name'])
    if not filename :
        raise ValueError(f'model {model_id} returned an empty file name')

    file_path = f'models/{filename}'
    if os.path.exists(file_path) :
        print(f'[info] File already exists : {file_path}')
        MODEL_PATH = file_path
        return file_path

    os.makedirs('models', exist_ok = True)
    download_from_url(download_url, file_path)
    print(f'[info] file downloaded : {file_path}')

    MODEL_PATH = file_path
    return file_path

# Find the most recently modified model file in a folder (e.g. ./models)
def find_latest_model_in_directory(directory) :
    """Return the most recently modified ``*.safetensors`` file in ``directory``.

    Args:
        directory: Folder to search (not recursive); a trailing separator is fine.

    Returns:
        Path of the newest ``.safetensors`` file by modification time, or
        ``None`` if there is none (including when the folder does not exist).
    """
    # glob.escape stops characters such as '[' in the folder name from being
    # treated as wildcards (which silently matched nothing). os.path.join avoids
    # a doubled separator when `directory` already ends with '/'.
    pattern = os.path.join(glob.escape(directory), '*.safetensors')
    model_files = glob.glob(pattern)
    if not model_files :
        return None
    return max(model_files, key = os.path.getmtime)

# Load the downloaded model and initialize the pipeline
def init_pipeline() :
    """Load the Stable Diffusion img2img pipeline from a local checkpoint.

    The checkpoint is ``MODEL_PATH`` when a model was downloaded in this
    session. Otherwise the newest ``*.safetensors`` file in ``./models`` is
    used and remembered in ``MODEL_PATH``. On success the pipeline is stored
    in the module-level ``PIPELINE``.

    Returns:
        Status text for the "Model Load Check" textbox: ``'Model Loaded!'`` on
        success, otherwise a message starting with ``'Error'``.
    """
    global MODEL_PATH, PIPELINE

    if MODEL_PATH is None :
        MODEL_PATH = find_latest_model_in_directory('./models/')
    if MODEL_PATH is None :
        return "Error: No model found in ./models"

    # float16 is not reliably supported for inference on CPU, so half precision
    # is used only on a CUDA GPU; otherwise load in float32 on the CPU.
    use_cuda = torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'
    dtype = torch.float16 if use_cuda else torch.float32

    try :
        PIPELINE = StableDiffusionImg2ImgPipeline.from_single_file(
            MODEL_PATH,
            torch_dtype = dtype,
            variant = 'fp16' if use_cuda else 'fp32',
            use_safetensors = True,
        ).to(device)
    except Exception as e :
        print(f'[error] {e}')
        # Report the failure in the status textbox instead of returning None.
        return f'Error: failed to load {MODEL_PATH}: {e}'

    print('[info] initiallized pipeline')
    return 'Model Loaded!'

# Sketch-to-image generation
def sketch_to_image(sketch, prompt, negative_prompt) :
    """Turn a sketch into four images with the loaded img2img pipeline.

    Args:
        sketch: Image from the Canvas or File ``gr.Image`` component
            (``type='pil'``); ``None`` when nothing was drawn or uploaded.
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt (may be empty).

    Returns:
        The generated images (``.images`` of the pipeline output), shown in
        the output ``gr.Gallery``.

    Raises:
        gr.Error: if the pipeline is not loaded, no sketch was given, or
            generation fails. Raising lets Gradio display the message; a
            gallery cannot render a returned string or exception object.
    """
    # PIPELINE is a module-level global that is only assigned by
    # init_pipeline(), so look it up without assuming it exists.
    pipeline = globals().get('PIPELINE')
    if pipeline is None :
        raise gr.Error("error! pipeline is not initialized")
    if sketch is None :
        raise gr.Error("error! no sketch image was provided")

    # The pipeline is called with lists: one copy of the sketch per prompt.
    prompts = [prompt]
    negative_prompts = [negative_prompt]
    images = [sketch] * len(prompts)

    try :
        result = pipeline(
            image = images,
            prompt = prompts,
            negative_prompt = negative_prompts,
            # NOTE(review): img2img generally sizes output from the input image,
            # so these may be ignored by the installed diffusers — confirm.
            height = HEIGHT,
            width = WIDTH,
            num_images_per_prompt = 4,
            num_inference_steps = 20,
            strength = 0.7,
        ).images
    except Exception as e :
        print(e)
        raise gr.Error(f'image generation failed: {e}') from e

    # TODO: when running on CUDA, torch.cuda.empty_cache() could be called here
    # to release cached GPU memory between requests.
    return result

In [8]:
# Full app: model download, model loading, prompt inputs, sketch canvas /
# image upload, and an output gallery. Both Generate buttons call
# sketch_to_image() with their own image source plus the shared prompts.
# NOTE(review): `source`, `tool`, `shape` and `brush_radius` are Gradio 3.x
# gr.Image arguments (3.40.0 is installed per the cell output); they are not
# available on gr.Image in Gradio 4 — pin gradio<4 or migrate before upgrading.
with gr.Blocks() as app :
    
    # Model download UI
    gr.Markdown('## 모델 다운로드')
    with gr.Row() :
        model_url = gr.Textbox(label = '모델 URL', placeholder = 'http://civitai.com/')
        download_model_btn = gr.Button(value = '모델 다운로드')
        
    with gr.Row() :
        model_file = gr.File(label = '모델 File')
        
    # Model loading UI
    gr.Markdown('## 모델 불러오기')
    with gr.Row() :
        load_model_btn = gr.Button(value = '모델 불러오기')
    with gr.Row() :
        is_model_check = gr.Textbox(label = 'Model Load Check', value = 'model not loaded')
    

    # Prompt inputs, sketch canvas and image-upload UI
    gr.Markdown("## 프롬프트 입력")
    with gr.Row() :
        prompt = gr.Textbox(label = 'Prompt')
    with gr.Row() :
        n_prompt = gr.Textbox(label = "negative prompt")
    
    # Sketch-to-image block
    gr.Markdown('## 스케치 to 이미지 생성')
    with gr.Row() :
        with gr.Column() :
            with gr.Tab("Canvas") :
                with gr.Row() :
                    # Free-hand drawing canvas; type='pil' delivers a PIL image to callbacks.
                    canvas = gr.Image(
                        label = 'Draw',
                        source = 'canvas',
                        image_mode = 'RGB',
                        tool = 'color-sketch',
                        interactive = True,
                        width = WIDTH,
                        height = HEIGHT,
                        shape = (WIDTH, HEIGHT),
                        brush_radius = 20,
                        type = 'pil'
                    )
                with gr.Row() :
                    canvas_run_btn = gr.Button(value = 'Generate')

            # File upload (the uploaded image can also be drawn on)
            with gr.Tab("File") :
                with gr.Row() :
                    file = gr.Image(
                        label = 'Upload',
                        source = 'upload',
                        image_mode = 'RGB',
                        tool = 'color-sketch',
                        interactive = True,
                        width = WIDTH,
                        height = HEIGHT,
                        shape = (WIDTH, HEIGHT),
                        type = 'pil'
                    )
                with gr.Row() :
                    file_run_btn = gr.Button(value = 'Generate')
                    
        # Result image gallery
        with gr.Column() :
            result_gallery = gr.Gallery(label = 'Output', height = 512)
        
    # "Load model" button: no inputs; init_pipeline()'s status string -> textbox
    load_model_btn.click(
        init_pipeline,
        None,
        [is_model_check]
    )
    
    # "Download model" button: URL textbox -> download_model(); returned path -> gr.File
    download_model_btn.click(
        download_model,
        [model_url],
        [model_file]
    )
    
    # Canvas tab Generate button: (drawing, prompt, negative prompt) -> gallery
    canvas_run_btn.click(
        sketch_to_image,
        [canvas, prompt, n_prompt],
        [result_gallery]
    )
    
    # File tab Generate button: (uploaded image, prompt, negative prompt) -> gallery
    file_run_btn.click(
        sketch_to_image,
        [file, prompt, n_prompt],
        [result_gallery]
    )

# Enable Gradio's request queue, then serve in a separate tab with a public share link.
# NOTE(review): share=True publishes a public *.gradio.live URL, and the model-load
# log reports the safety checker is disabled for this pipeline — review before sharing.
app.queue().launch(inline = False, share = True)

Running on local URL:  http://127.0.0.1:7861
IMPORTANT: You are using gradio version 3.40.0, however version 4.29.0 is available, please upgrade.
--------
Running on public URL: https://bd69caefc8f5a4eb57.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)





Fetching 11 files: 100%|█████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 53153.62it/s][A

Some weights of the model checkpoint were not used when initializing CLIPTextModel: 
 ['text_model.embeddings.position_ids']

Loading pipeline components...:  50%|██████████████████████████▌                          | 3/6 [00:02<00:02,  1.17it/s][A
Loading pipeline components...:  67%|███████████████████████████████████▎                 | 4/6 [00:03<00:01,  1.29it/s][A
Loading pipeline components...:  83%|████████████████████████████████████████████▏        | 5/6 [00:24<00:07,  7.40s/it][A
Loading pipeline components...: 100%|█████████████████████████████████████████████████████| 6/6 [00:26<00:00,  4.50s/it][A
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion l

[info] initiallized pipeline


  deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)

  0%|                                                                                            | 0/14 [00:00<?, ?it/s][A
  7%|██████                                                                              | 1/14 [00:37<08:10, 37.75s/it][A

In [6]:
# Stop the Gradio server started by launch() (frees its port).
app.close()

KeyboardInterrupt: 