Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ DiffusionMagic focused on the following areas:
- Cross-platform (Windows/Linux/Mac)
- Modular design, latest best optimizations for speed and memory

## Stable diffusion XL Colab
You can run Stable Diffusion XL 0.9 on Google Colab:
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1KrmcU2gONIQ2WihI1s6uITgDDzkbKaJK?usp=sharing)

![ DiffusionMagic](https://raw.githubusercontent.com/rupeshs/diffusionmagic/main/docs/images/diffusion_magic.PNG)
## Features
- Supports various Stable Diffusion workflows
Expand Down Expand Up @@ -113,6 +117,7 @@ Or we can clone the model use the local folder path as model id.
## Linting (Development)
Run the following commands from src folder
`mypy --ignore-missing-imports --explicit-package-bases .`

`flake8 --max-line-length=100 .`
## Contribute
Contributions are welcome.
Expand Down
3 changes: 2 additions & 1 deletion configs/stable_diffusion_models.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ lllyasviel/sd-controlnet-hed
lllyasviel/sd-controlnet-openpose
lllyasviel/sd-controlnet-depth
lllyasviel/sd-controlnet-scribble
lllyasviel/sd-controlnet-seg
lllyasviel/sd-controlnet-seg
stabilityai/stable-diffusion-xl-base-1.0
13 changes: 7 additions & 6 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@ dependencies:
- torchvision=0.15.0
- numpy=1.19.2
- pip:
- accelerate==0.17.1
- diffusers==0.14.0
- gradio==3.17.1
- safetensors==0.2.8
- accelerate==0.21.0
- diffusers==0.19.3
- gradio==3.32.0
- safetensors==0.3.1
- scipy==1.10.0
- transformers==4.26.0
- transformers==4.31.0
- pydantic==1.10.4
- mypy==1.0.0
- black==23.1.0
- flake8==6.0.0
- markupsafe==2.0.1
- opencv-contrib-python==4.7.0.72
- controlnet-aux==0.0.1
- controlnet-aux==0.0.1
- invisible-watermark==0.2.0
86 changes: 86 additions & 0 deletions src/backend/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from backend.controlnet.ControlContext import ControlnetContext
from backend.stablediffusion.stablediffusion import StableDiffusion
from backend.stablediffusion.stablediffusionxl import StableDiffusionXl
from settings import AppSettings


Expand All @@ -30,6 +31,7 @@ def __init__(self, compute: Computing):
self.stable_diffusion_depth = StableDiffusionDepthToImage(compute)
self.stable_diffusion_pix_to_pix = StableDiffusionInstructPixToPix(compute)
self.controlnet = ControlnetContext(compute)
self.stable_diffusion_xl = StableDiffusionXl(compute)
self.app_settings = AppSettings().get_settings()
self.model_id = self.app_settings.model_settings.model_id
self.low_vram_mode = self.app_settings.low_memory_mode
Expand Down Expand Up @@ -78,6 +80,15 @@ def _init_stable_diffusion(self):
)
self.pipe_initialized = True

def _init_stable_diffusion_xl(self):
if not self.pipe_initialized:
print("Initializing stable diffusion xl pipeline")
self.stable_diffusion_xl.get_text_to_image_xl_pipleline(
self.model_id,
self.low_vram_mode,
)
self.pipe_initialized = True

def diffusion_image_to_image(
self,
image,
Expand Down Expand Up @@ -355,3 +366,78 @@ def diffusion_control_to_image(
"CannyToImage",
)
return images

def diffusion_text_to_image_xl(
    self,
    prompt,
    neg_prompt,
    image_height,
    image_width,
    inference_steps,
    scheduler,
    guidance_scale,
    num_images,
    attention_slicing,
    vae_slicing,
    seed,
) -> Any:
    """Generate images from a text prompt using the SDXL pipeline.

    Bundles the UI arguments into a StableDiffusionSetting, lazily
    initializes the SDXL pipeline, runs text-to-image, saves the
    resulting images and returns them.
    """
    settings = StableDiffusionSetting(
        prompt=prompt,
        negative_prompt=neg_prompt,
        image_height=image_height,
        image_width=image_width,
        inference_steps=inference_steps,
        guidance_scale=guidance_scale,
        number_of_images=num_images,
        scheduler=scheduler,
        seed=seed,
        attention_slicing=attention_slicing,
        vae_slicing=vae_slicing,
    )
    self._init_stable_diffusion_xl()
    generated = self.stable_diffusion_xl.text_to_image_xl(settings)
    self._save_images(generated, "TextToImage")
    return generated

def diffusion_image_to_image_xl(
    self,
    image,
    strength,
    prompt,
    neg_prompt,
    image_height,
    image_width,
    inference_steps,
    scheduler,
    guidance_scale,
    num_images,
    attention_slicing,
    seed,
) -> Any:
    """Transform an input image with the SDXL image-to-image pipeline.

    Bundles the UI arguments into a StableDiffusionImageToImageSetting,
    lazily initializes the SDXL pipeline, runs image-to-image, saves
    the resulting images and returns them.
    """
    settings = StableDiffusionImageToImageSetting(
        image=image,
        strength=strength,
        prompt=prompt,
        negative_prompt=neg_prompt,
        image_height=image_height,
        image_width=image_width,
        inference_steps=inference_steps,
        guidance_scale=guidance_scale,
        number_of_images=num_images,
        scheduler=scheduler,
        seed=seed,
        attention_slicing=attention_slicing,
    )
    self._init_stable_diffusion_xl()
    generated = self.stable_diffusion_xl.image_to_image(settings)
    self._save_images(generated, "ImageToImage")
    return generated
3 changes: 3 additions & 0 deletions src/backend/stablediffusion/stable_diffusion_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class StableDiffusionType(str, Enum):
controlnet_depth = "controlnet_depth"
controlnet_scribble = "controlnet_scribble"
controlnet_seg = "controlnet_seg"
stable_diffusion_xl = "StableDiffusionXl"


def get_diffusion_type(
Expand Down Expand Up @@ -44,4 +45,6 @@ def get_diffusion_type(
stable_diffusion_type = StableDiffusionType.controlnet_scribble
elif "controlnet-seg" in model_id:
stable_diffusion_type = StableDiffusionType.controlnet_seg
elif "stable-diffusion-xl" in model_id:
stable_diffusion_type = StableDiffusionType.stable_diffusion_xl
return stable_diffusion_type
181 changes: 181 additions & 0 deletions src/backend/stablediffusion/stablediffusionxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from time import time

import torch
from diffusers import (
DiffusionPipeline,
StableDiffusionXLImg2ImgPipeline,
)
from PIL import Image

from backend.computing import Computing
from backend.stablediffusion.modelmeta import ModelMeta
from backend.stablediffusion.models.scheduler_types import SchedulerType
from backend.stablediffusion.models.setting import (
StableDiffusionImageToImageSetting,
StableDiffusionSetting,
)
from backend.stablediffusion.scheduler_mixin import SamplerMixin


class StableDiffusionXl(SamplerMixin):
    """Stable Diffusion XL text-to-image and image-to-image backend.

    Wraps the diffusers SDXL pipelines. The image-to-image pipeline is
    built from the text-to-image pipeline's components (sharing the
    VAE/UNet/text encoders), so ``get_text_to_image_xl_pipleline()``
    must be called before either generation method.
    """

    def __init__(self, compute: Computing):
        self.compute = compute
        self.pipeline = None
        # Built from the text-to-image pipeline's components in
        # get_text_to_image_xl_pipleline(); None until then.
        self.img_to_img_pipeline = None
        self.device = self.compute.name
        super().__init__()

    def get_text_to_image_xl_pipleline(
        self,
        model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",
        low_vram_mode: bool = False,
        sampler: str = SchedulerType.DPMSolverMultistepScheduler.value,
    ):
        """Load the SDXL pipeline (fp16 variant when available on CUDA)
        and derive the shared image-to-image pipeline from it.

        :param model_id: Hugging Face repo id or local path; may be a
            LoRA model, in which case its base model is loaded first.
        :param low_vram_mode: offload model parts to CPU sequentially.
        :param sampler: scheduler name used as the default sampler.
        :raises Exception: if the pipeline failed to load.
        """
        repo_id = model_id
        model_meta = ModelMeta(repo_id)
        is_lora_model = model_meta.is_loramodel()
        if is_lora_model:
            print("LoRA model detected")
            self.model_id = model_meta.get_lora_base_model()
            print(f"LoRA base model - {self.model_id}")
        else:
            self.model_id = model_id

        self.low_vram_mode = low_vram_mode
        print(f"StableDiffusion - {self.compute.name},{self.compute.datatype}")
        print(f"using model {model_id}")
        self.default_sampler = self.find_sampler(
            sampler,
            self.model_id,
        )
        tic = time()
        self._load_model()
        delta = time() - tic
        print(f"Model loaded in {delta:.2f}s ")

        if self.pipeline is None:
            raise Exception("Text to image pipeline not initialized")
        if is_lora_model:
            # LoRA weights are applied on top of the base model's UNet.
            self.pipeline.unet.load_attn_procs(repo_id)
        self._pipeline_to_device()
        # Reuse the loaded components so the img2img pipeline costs no
        # additional VRAM.
        components = self.pipeline.components
        self.img_to_img_pipeline = StableDiffusionXLImg2ImgPipeline(**components)

    def text_to_image_xl(self, setting: StableDiffusionSetting):
        """Run SDXL text-to-image and return the generated PIL images.

        :raises Exception: if the pipeline has not been initialized.
        """
        if self.pipeline is None:
            raise Exception("Text to image pipeline not initialized")

        self.pipeline.scheduler = self.find_sampler(
            setting.scheduler,
            self.model_id,
        )
        generator = None
        if setting.seed != -1:
            print(f"Using seed {setting.seed}")
            generator = torch.Generator(self.device).manual_seed(setting.seed)

        # NOTE: attention slicing is intentionally not used for SDXL;
        # setting.attention_slicing is ignored here.
        if setting.vae_slicing:
            self.pipeline.enable_vae_slicing()
        else:
            self.pipeline.disable_vae_slicing()

        images = self.pipeline(
            setting.prompt,
            guidance_scale=setting.guidance_scale,
            num_inference_steps=setting.inference_steps,
            height=setting.image_height,
            width=setting.image_width,
            negative_prompt=setting.negative_prompt,
            num_images_per_prompt=setting.number_of_images,
            generator=generator,
        ).images
        return images

    def _pipeline_to_device(self):
        """Move the pipeline to the compute device, or enable sequential
        CPU offload when running in low VRAM mode."""
        if self.low_vram_mode:
            print("Running in low VRAM mode,slower to generate images")
            self.pipeline.enable_sequential_cpu_offload()
        else:
            if self.compute.name == "cuda":
                self.pipeline = self.pipeline.to("cuda")
            elif self.compute.name == "mps":
                self.pipeline = self.pipeline.to("mps")

    def _load_full_precision_model(self):
        """Load the model in the compute's default precision."""
        self.pipeline = DiffusionPipeline.from_pretrained(
            self.model_id,
            torch_dtype=self.compute.datatype,
            scheduler=self.default_sampler,
        )

    def _load_model(self):
        """Load the model, preferring the fp16 safetensors variant on
        CUDA and falling back to full precision when it is missing."""
        if self.compute.name == "cuda":
            try:
                self.pipeline = DiffusionPipeline.from_pretrained(
                    self.model_id,
                    torch_dtype=self.compute.datatype,
                    scheduler=self.default_sampler,
                    use_safetensors=True,
                    variant="fp16",
                )
            except Exception as ex:
                # Fixed garbled log message; behavior is unchanged.
                print(
                    f"The fp16 variant of the model was not found; using full precision model, {ex}"
                )
                self._load_full_precision_model()
        else:
            self._load_full_precision_model()

    def image_to_image(self, setting: StableDiffusionImageToImageSetting):
        """Run SDXL image-to-image on ``setting.image`` and return the
        generated PIL images.

        :raises Exception: if the scheduler is empty or the pipeline has
            not been initialized via get_text_to_image_xl_pipleline().
        """
        if setting.scheduler is None:
            raise Exception("Scheduler cannot be empty")
        # Guard added: previously calling this before pipeline setup
        # raised a bare AttributeError.
        if self.img_to_img_pipeline is None:
            raise Exception("Image to image pipeline not initialized")

        print("Running image to image pipeline")
        self.img_to_img_pipeline.scheduler = self.find_sampler(
            setting.scheduler,
            self.model_id,
        )
        generator = None
        # NOTE(review): unlike text_to_image_xl, this also skips seeding
        # for seed == 0 / None — kept as-is to preserve behavior.
        if setting.seed != -1 and setting.seed:
            print(f"Using seed {setting.seed}")
            generator = torch.Generator(self.device).manual_seed(setting.seed)

        if setting.attention_slicing:
            self.img_to_img_pipeline.enable_attention_slicing()
        else:
            self.img_to_img_pipeline.disable_attention_slicing()

        # Toggle VAE slicing on the img2img pipeline for consistency
        # (it shares the same VAE object with the base pipeline).
        if setting.vae_slicing:
            self.img_to_img_pipeline.enable_vae_slicing()
        else:
            self.img_to_img_pipeline.disable_vae_slicing()

        init_image = setting.image.resize(
            (
                setting.image_width,
                setting.image_height,
            ),
            Image.Resampling.LANCZOS,
        )
        images = self.img_to_img_pipeline(
            image=init_image,
            strength=setting.strength,
            prompt=setting.prompt,
            guidance_scale=setting.guidance_scale,
            num_inference_steps=setting.inference_steps,
            negative_prompt=setting.negative_prompt,
            num_images_per_prompt=setting.number_of_images,
            generator=generator,
        ).images
        return images
2 changes: 1 addition & 1 deletion src/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
VERSION = "2.0.0-beta.0"
VERSION = "3.0.0"
STABLE_DIFFUSION_MODELS_FILE = "stable_diffusion_models.txt"
APP_SETTINGS_FILE = "settings.yaml"
CONFIG_DIRECTORY = "configs"
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/web/depth_to_image_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def random_seed():
show_label=True,
elem_id="gallery",
).style(
grid=2,
columns=2,
)
generate_btn.click(
fn=generate_callback_fn,
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/web/image_inpainting_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def random_seed():
label="Number of images to generate",
)
attn_slicing = gr.Checkbox(
label="Attention slicing (Enable if low VRAM)",
label="Attention slicing (Not used)",
value=True,
)
seed = gr.Number(
Expand Down Expand Up @@ -105,7 +105,7 @@ def random_seed():
show_label=True,
elem_id="gallery",
).style(
grid=2,
columns=2,
)
generate_btn.click(
fn=generate_callback_fn,
Expand Down
Loading