Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 110 additions & 20 deletions multigen/pipes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import importlib
from enum import Enum

import torch

from PIL import Image, ImageFilter
Expand All @@ -7,9 +9,9 @@
import numpy as np
from typing import Optional, Type
from diffusers import DPMSolverMultistepScheduler
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DiffusionPipeline
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers import StableDiffusionControlNetInpaintPipeline, DDIMScheduler
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DiffusionPipeline, StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler, StableDiffusionXLControlNetPipeline
from diffusers import StableDiffusionControlNetInpaintPipeline, StableDiffusionXLControlNetInpaintPipeline, DDIMScheduler
from diffusers.schedulers import KarrasDiffusionSchedulers

from .pipelines.masked_stable_diffusion_img2img import MaskedStableDiffusionImg2ImgPipeline
Expand Down Expand Up @@ -51,8 +53,6 @@ class BasePipe:
def __init__(self, model_id: str,
sd_pipe_class: Optional[Type[DiffusionPipeline]]=None,
pipe: Optional[DiffusionPipeline] = None, **args):
if sd_pipe_class is None:
sd_pipe_class = self._class
self.pipe = pipe
self._scheduler = None
self._hypernets = []
Expand All @@ -62,12 +62,33 @@ def __init__(self, model_id: str,
args = {**args}
if 'torch_dtype' not in args:
args['torch_dtype']=torch.float16

if self.pipe is None:
assert sd_pipe_class is not None
if model_id.endswith('.safetensors'):
self.pipe = sd_pipe_class.from_single_file(model_id, **args)
constructor_args = dict()
if isinstance(self, Cond2ImPipe):
constructor_args['controlnet'] = args['controlnet']

if sd_pipe_class is None:
if model_id.endswith('.safetensors'):
try:
self.pipe = StableDiffusionPipeline.from_single_file(model_id, **args)
except TypeError as e:
pass
self.pipe = StableDiffusionXLPipeline.from_single_file(model_id, **args)
else:
# we can't use specific class, because we dont know if it is sdxl
self.pipe = DiffusionPipeline.from_pretrained(model_id, **args)
if 'custom_pipeline' not in args:
# create correct class if custom_pipeline is not specified
# at this stage we know that the model is sdxl or sd
self.pipe = self.from_pipe(self.pipe, **constructor_args)

else:
self.pipe = sd_pipe_class.from_pretrained(model_id, **args)
if model_id.endswith('.safetensors'):
self.pipe = sd_pipe_class.from_single_file(model_id, **args)
else:
self.pipe = sd_pipe_class.from_pretrained(model_id, **args)

self.pipe.to("cuda")
# self.pipe.enable_attention_slicing()
# self.pipe.enable_vae_slicing()
Expand Down Expand Up @@ -134,16 +155,25 @@ def setup(self, steps=50, **args):
self.pipe.text_encoder.load_state_dict(prev_encoder.state_dict())
else:
self.pipe.text_encoder = CLIPTextModel.from_pretrained(self.model_id, subfolder="text_encoder",
num_hidden_layers=12 - clip_skip)
num_hidden_layers=12 - clip_skip)
self.pipe.text_encoder.to(prev_encoder.device)
self.pipe.text_encoder.to(prev_encoder.dtype)
if 'scheduler' in args:
# TODO? add scheduler to config?
self.try_set_scheduler(dict(scheduler=args['scheduler']))

def from_pipe(self, pipe, **args):
    """
    Re-wrap an already loaded pipeline in this wrapper's own pipeline class.

    A StableDiffusionXLPipeline is rebuilt as ``self._classxl`` and a
    StableDiffusionPipeline as ``self._class``; in both cases the new object
    is constructed from ``pipe.components`` plus the extra keyword arguments.
    Any other pipeline is treated as a custom pipeline and returned as is.

    :param pipe: the loaded pipeline to convert.
    :param args: extra constructor keyword arguments for the target class.
    :return: an instance of the selected class, or ``pipe`` itself.
    """
    if isinstance(pipe, StableDiffusionXLPipeline):
        target_cls = self._classxl
    elif isinstance(pipe, StableDiffusionPipeline):
        target_cls = self._class
    else:
        # neither sd nor sdxl: it's a custom pipeline, keep it untouched
        return pipe
    return target_cls(**pipe.components, **args)


class Prompt2ImPipe(BasePipe):
_class = StableDiffusionPipeline
_classxl = StableDiffusionXLPipeline

def __init__(self, model_id: str,
pipe: Optional[StableDiffusionPipeline] = None,
Expand Down Expand Up @@ -174,6 +204,7 @@ def gen(self, inputs):
class Im2ImPipe(BasePipe):

_class = StableDiffusionImg2ImgPipeline
_classxl = StableDiffusionXLImg2ImgPipeline

def __init__(self, model_id, pipe: Optional[StableDiffusionImg2ImgPipeline] = None, **args):
super().__init__(model_id=model_id, pipe=pipe, **args)
Expand Down Expand Up @@ -256,21 +287,39 @@ def gen(self, inputs):
return img_compose


class ControlnetType(Enum):
    """
    Model family a ControlNet-conditioned pipe is built for.

    Cond2ImPipe compares its ``model_type`` against these members to choose
    between the plain and the ``xl`` variants of its checkpoint directory,
    checkpoint-name table, default conditioning scales and pipeline class.
    """
    # Stable Diffusion (non-XL) family
    stable_diffusion = 1
    # Stable Diffusion XL family
    stable_diffusion_xl = 2


class Cond2ImPipe(BasePipe):
_class = StableDiffusionControlNetPipeline
_classxl = StableDiffusionXLControlNetPipeline

# TODO: set path
cpath = "./models-cn/"
cpathxl = "./models-cn-xl/"

cmodels = {
"canny": "sd-controlnet-canny",
"pose": "control_v11p_sd15_openpose",
"ip2p": "control_v11e_sd15_ip2p",
"soft-sobel": "control_v11p_sd15_softedge",
"soft": "control_v11p_sd15_softedge",
"depth": "control_v11f1p_sd15_depth",
"inpaint": "control_v11p_sd15_inpaint"
"inpaint": "control_v11p_sd15_inpaint",
"qr": "controlnet_qrcode-control_v1p_sd15"
}

cmodelsxl = {
"qr": "controlnet-qr-pattern-sdxl",
}
cscalem = {

cond_scales_defaults_xl = {
"qr": 0.5
}

cond_scales_defaults = {
"canny": 0.75,
"pose": 1.0,
"ip2p": 0.5,
Expand All @@ -281,26 +330,57 @@ class Cond2ImPipe(BasePipe):
}

def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] = None,
             ctypes=["soft"], model_type=ControlnetType.stable_diffusion, **args):
    """
    Load the requested ControlNet models and build the conditioned pipeline.

    :param model_id: model identifier, forwarded to BasePipe.__init__.
    :param pipe: optional pipeline object, forwarded to BasePipe.__init__.
    :param ctypes: a conditioning type name or a list of names; every name
        must be a key of the table returned by ``get_cmodels()``.
    :param model_type: ControlnetType member that selects the checkpoint
        directory, checkpoint table and pipeline class.
    :param args: extra keyword arguments forwarded to BasePipe.__init__.
    :raises ValueError: if ``model_type`` is not a known ControlnetType.
    :raises KeyError: if a name in ``ctypes`` is missing from the table.
    """
    self.model_type = model_type
    if not isinstance(ctypes, list):
        ctypes = [ctypes]
    self.ctypes = ctypes
    self._condition_image = None
    # BasePipe.__init__ takes the pipeline dtype from 'torch_dtype', so the
    # ControlNets must follow the same key or they end up in a different
    # precision than the pipeline. The older 'torch_type' spelling keeps its
    # previous meaning (and precedence) for existing callers.
    if 'torch_type' in args:
        dtype = args['torch_type']
    else:
        dtype = args.get('torch_dtype', torch.float16)
    cpath = self.get_cpath()
    cmodels = self.get_cmodels()
    sd_class = self.get_sd_class()
    cnets = [ControlNetModel.from_pretrained(cpath + cmodels[c], torch_dtype=dtype) for c in ctypes]
    super().__init__(model_id=model_id, pipe=pipe, sd_pipe_class=sd_class, controlnet=cnets, **args)
    # FIXME: do we need to setup this specific scheduler here?
    # should we pass its name in setup to super?
    self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)

def get_cmodels(self):
    """
    Return the table mapping conditioning type names to ControlNet
    checkpoint names for this pipe's model type.

    :return: ``self.cmodelsxl`` for stable_diffusion_xl, ``self.cmodels``
        for stable_diffusion.
    :raises ValueError: if ``self.model_type`` is neither of the two.
    """
    kind = self.model_type
    if kind == ControlnetType.stable_diffusion_xl:
        return self.cmodelsxl
    if kind == ControlnetType.stable_diffusion:
        return self.cmodels
    raise ValueError(f"Unknown controlnet type: {kind}")

def get_cpath(self):
    """
    Return the directory prefix under which ControlNet checkpoints for this
    pipe's model type are looked up.

    :return: ``self.cpathxl`` for stable_diffusion_xl, ``self.cpath`` for
        stable_diffusion.
    :raises ValueError: if ``self.model_type`` is neither of the two.
    """
    kind = self.model_type
    if kind == ControlnetType.stable_diffusion_xl:
        return self.cpathxl
    if kind == ControlnetType.stable_diffusion:
        return self.cpath
    raise ValueError(f"Unknown controlnet type: {kind}")

def get_sd_class(self):
    """
    Return the pipeline class to instantiate for this pipe's model type.

    :return: ``self._classxl`` for stable_diffusion_xl, ``self._class`` for
        stable_diffusion.
    :raises ValueError: if ``self.model_type`` is neither of the two.
    """
    kind = self.model_type
    if kind == ControlnetType.stable_diffusion_xl:
        return self._classxl
    if kind == ControlnetType.stable_diffusion:
        return self._class
    raise ValueError(f"Unknown controlnet type: {kind}")

def setup(self, fimage, width=None, height=None, image=None, cscales=None, guess_mode=False, **args):
super().setup(**args)
# TODO: allow multiple input images for multiple control nets
self.fname = fimage
image = Image.open(fimage) if image is None else image
self._condition_image = [image]
if cscales is None:
cscales = [CIm2ImPipe.cscalem[c] for c in self.ctypes]
cscales = [self.get_default_cond_scales()[c] for c in self.ctypes]
self.pipe_params.update({
"width": image.size[0] if width is None else width,
"height": image.size[1] if height is None else height,
Expand All @@ -309,6 +389,15 @@ def setup(self, fimage, width=None, height=None, image=None, cscales=None, guess
"num_inference_steps": 20
})

def get_default_cond_scales(self):
    """
    Return the table of default conditioning scales, keyed by conditioning
    type name, for this pipe's model type.

    :return: ``self.cond_scales_defaults_xl`` for stable_diffusion_xl,
        ``self.cond_scales_defaults`` for stable_diffusion.
    :raises ValueError: if ``self.model_type`` is neither of the two.
    """
    kind = self.model_type
    if kind == ControlnetType.stable_diffusion_xl:
        return self.cond_scales_defaults_xl
    if kind == ControlnetType.stable_diffusion:
        return self.cond_scales_defaults
    raise ValueError(f"Unknown controlnet type: {kind}")

def get_config(self):
cfg = super().get_config()
cfg.update({
Expand All @@ -329,8 +418,8 @@ def gen(self, inputs):
class CIm2ImPipe(Cond2ImPipe):

def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] = None,
ctypes=["soft"], **args):
super().__init__(model_id=model_id, pipe=pipe, ctypes=ctypes, **args)
ctypes=["soft"], model_type=ControlnetType.stable_diffusion, **args):
super().__init__(model_id=model_id, pipe=pipe, ctypes=ctypes, model_type=model_type, **args)
# The difference from Cond2ImPipe is that the conditional image is not
# taken as input but is obtained from an ordinary image, so this image
# should be processed, and the processor depends on the conditioning type
Expand Down Expand Up @@ -395,7 +484,8 @@ def _proc_cimg(self, oriImg):
# TODO: does it make sense to inherint it from Cond2Im or CIm2Im ?
class InpaintingPipe(BasePipe):
_class = StableDiffusionControlNetInpaintPipeline

_classxl = StableDiffusionXLControlNetInpaintPipeline

def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] = None,
**args):
dtype = torch.float16 if 'torch_type' not in args else args['torch_type']
Expand All @@ -413,7 +503,7 @@ def setup(self, fimage, mask_image, image=None, **args):
self._init_image = Image.open(fimage) if image is None else image
self._mask_image = mask_image
self._control_image = self._make_inpaint_condition(self._init_image, mask_image)
# self._condition_image = [image]

self.pipe_params.update({
# TODO: check if condtitioning_scale and guess_mode are in this pipeline and what is their effect
# "controlnet_conditioning_scale": cscales,
Expand Down