feat: add parameter for ControlNet selection
ssube committed Apr 12, 2023
1 parent fbf5767 commit 9e017ee
Showing 11 changed files with 68 additions and 15 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/chain/blend_controlnet.py
@@ -28,7 +28,7 @@ def blend_controlnet(
     params = params.with_args(**kwargs)
     source = stage_source or source
     logger.info(
-        "blending image using controlnet, %s steps: %s", params.steps, params.prompt
+        "blending image using ControlNet, %s steps: %s", params.steps, params.prompt
     )
 
     pipe = load_pipeline(
1 change: 0 additions & 1 deletion api/onnx_web/convert/diffusion/diffusers.py
@@ -174,7 +174,6 @@ def convert_diffusion_diffusers(
     if is_torch_2_0:
         pipe_cnet.set_attn_processor(CrossAttnProcessor())
 
-
     cnet_path = output_path / "cnet" / ONNX_MODEL
     onnx_export(
         pipe_cnet,
22 changes: 14 additions & 8 deletions api/onnx_web/diffusers/load.py
@@ -116,14 +116,13 @@ def load_pipeline(
     scheduler_name: str,
     device: DeviceParams,
     lpw: bool,
+    control: Optional[str] = None,
     inversions: Optional[List[Tuple[str, float]]] = None,
     loras: Optional[List[Tuple[str, float]]] = None,
 ):
     inversions = inversions or []
     loras = loras or []
 
-    controlnet = "canny"  # TODO; from params
-
     torch_dtype = (
         torch.float16 if "torch-fp16" in server.optimizations else torch.float32
     )
@@ -186,6 +185,8 @@ def load_pipeline(
     }
 
     text_encoder = None
+
+    # Textual Inversion blending
     if inversions is not None and len(inversions) > 0:
         logger.debug("blending Textual Inversions from %s", inversions)
         inversion_names, inversion_weights = zip(*inversions)
@@ -225,7 +226,7 @@ def load_pipeline(
             )
         )
 
-    # test LoRA blending
+    # LoRA blending
     if loras is not None and len(loras) > 0:
         lora_names, lora_weights = zip(*loras)
         lora_models = [
@@ -278,8 +279,12 @@ def load_pipeline(
             )
         )
 
-    if controlnet is not None:
-        components["controlnet"] = OnnxRuntimeModel.from_pretrained(controlnet)
+    if control is not None:
+        components["controlnet"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(
+            control,
+            provider=device.ort_provider(),
+            sess_options=device.sess_options(),
+        ))
 
     pipe = pipeline.from_pretrained(
         model,
@@ -360,7 +365,7 @@ def __init__(self, server, wrapped):
         self.server = server
         self.wrapped = wrapped
 
-    def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
+    def __call__(self, sample=None, timestep=None, encoder_hidden_states=None, **kwargs):
         global timestep_dtype
         timestep_dtype = timestep.dtype
 
@@ -382,6 +387,7 @@ def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
             sample=sample,
             timestep=timestep,
             encoder_hidden_states=encoder_hidden_states,
+            **kwargs,
         )
 
     def __getattr__(self, attr):
@@ -393,15 +399,15 @@ def __init__(self, server, wrapped):
         self.server = server
         self.wrapped = wrapped
 
-    def __call__(self, latent_sample=None):
+    def __call__(self, latent_sample=None, **kwargs):
         global timestep_dtype
 
         logger.trace("VAE parameter types: %s", latent_sample.dtype)
         if latent_sample.dtype != timestep_dtype:
             logger.info("converting VAE sample dtype")
             latent_sample = latent_sample.astype(timestep_dtype)
 
-        return self.wrapped(latent_sample=latent_sample)
+        return self.wrapped(latent_sample=latent_sample, **kwargs)
 
     def __getattr__(self, attr):
         return getattr(self.wrapped, attr)
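This is the core of the change: load_pipeline now receives the ControlNet path through the new control argument instead of the hard-coded "canny" placeholder, and loads it as an ONNX session bound to the job's device rather than through from_pretrained. The **kwargs added to the UNet and VAE wrappers lets any extra inputs introduced by the ControlNet pipeline pass through to the wrapped models. A minimal standalone sketch of the new loading path, with a hypothetical model path and a fixed provider standing in for device.ort_provider():

    from diffusers import OnnxRuntimeModel

    # Sketch of the loading pattern above; the path is hypothetical and the
    # provider would normally come from device.ort_provider().
    session = OnnxRuntimeModel.load_model(
        "./models/control/canny/cnet/model.onnx",
        provider="CPUExecutionProvider",
    )
    controlnet = OnnxRuntimeModel(session)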
5 changes: 3 additions & 2 deletions api/onnx_web/diffusers/run.py
@@ -239,8 +239,9 @@ def run_img2img_pipeline(
         params.scheduler,
         job.get_device(),
         params.lpw,
-        inversions,
-        loras,
+        control=params.control,
+        inversions=inversions,
+        loras=loras,
     )
     progress = job.get_progress_callback()
     if params.lpw:
16 changes: 16 additions & 0 deletions api/onnx_web/params.py
@@ -160,6 +160,18 @@ def torch_str(self) -> str:
 
 
 class ImageParams:
+    model: str
+    scheduler: str
+    prompt: str
+    cfg: float
+    steps: int
+    seed: int
+    negative_prompt: Optional[str]
+    lpw: bool
+    eta: float
+    batch: int
+    control: Optional[str]
+
     def __init__(
         self,
         model: str,
@@ -172,6 +184,7 @@ def __init__(
         lpw: bool = False,
         eta: float = 0.0,
         batch: int = 1,
+        control: Optional[str] = None,
     ) -> None:
         self.model = model
         self.scheduler = scheduler
@@ -183,6 +196,7 @@ def __init__(
         self.lpw = lpw or False
         self.eta = eta
         self.batch = batch
+        self.control = control
 
     def tojson(self) -> Dict[str, Optional[Param]]:
         return {
@@ -196,6 +210,7 @@ def tojson(self) -> Dict[str, Optional[Param]]:
             "lpw": self.lpw,
             "eta": self.eta,
             "batch": self.batch,
+            "control": self.control,
         }
 
     def with_args(self, **kwargs):
@@ -210,6 +225,7 @@ def with_args(self, **kwargs):
             kwargs.get("lpw", self.lpw),
             kwargs.get("eta", self.eta),
             kwargs.get("batch", self.batch),
+            kwargs.get("control", self.control),
         )
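With the new field in place, control travels with the rest of the image parameters: it survives with_args() overrides and appears in the JSON payload. A quick sketch, using illustrative constructor values that are not from this commit:

    # Illustrative values only; argument order follows the annotations above.
    params = ImageParams(
        "stable-diffusion-v1-5",  # model
        "ddim",                   # scheduler
        "a photo of a cat",       # prompt
        7.5,                      # cfg
        25,                       # steps
        42,                       # seed
        control="canny",
    )
    params = params.with_args(control="openpose")
    print(params.tojson()["control"])  # openpose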
2 changes: 2 additions & 0 deletions api/onnx_web/server/params.py
@@ -42,6 +42,7 @@ def pipeline_from_request(
     device = platform
 
     # pipeline stuff
+    control = get_not_empty(request.args, "control", get_config_value("control"))
     lpw = get_not_empty(request.args, "lpw", "false") == "true"
     model = get_not_empty(request.args, "model", get_config_value("model"))
     model_path = get_model_path(server, model)
@@ -132,6 +133,7 @@ def pipeline_from_request(
         lpw=lpw,
         negative_prompt=negative_prompt,
         batch=batch,
+        control=control,
     )
     size = Size(width, height)
     return (device, params, size)
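The server reads control like the other pipeline settings: a non-empty query parameter wins, and an empty or missing value falls back to the "control" default from the server config, which is an empty string (no ControlNet). The resulting query string looks something like this; a sketch using only standard-library helpers, with illustrative values:

    from urllib.parse import urlencode

    # Query string for a request that selects a ControlNet.
    query = urlencode({
        "model": "stable-diffusion-v1-5",
        "lpw": "false",
        "control": "canny",
    })
    # model=stable-diffusion-v1-5&lpw=false&control=canny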
4 changes: 4 additions & 0 deletions api/params.json
@@ -18,6 +18,10 @@
     "max": 30,
     "step": 0.1
   },
+  "control": {
+    "default": "",
+    "keys": []
+  },
   "correction": {
     "default": "",
     "keys": []
4 changes: 4 additions & 0 deletions gui/examples/config.json
@@ -22,6 +22,10 @@
     "max": 30,
     "step": 0.1
   },
+  "control": {
+    "default": "",
+    "keys": []
+  },
   "correction": {
     "default": "",
     "keys": []
8 changes: 7 additions & 1 deletion gui/src/client/api.ts
@@ -32,6 +32,11 @@ export interface ModelParams {
    * Use the long prompt weighting pipeline.
    */
   lpw: boolean;
+
+  /**
+   * ControlNet to be used.
+   */
+  control: string;
 }
 
 /**
@@ -191,7 +196,7 @@ export interface ReadyResponse {
 
 export interface NetworkModel {
   name: string;
-  type: 'inversion' | 'lora';
+  type: 'control' | 'inversion' | 'lora';
   // TODO: add token
   // TODO: add layer/token count
 }
@@ -392,6 +397,7 @@ export function appendModelToURL(url: URL, params: ModelParams) {
   url.searchParams.append('upscaling', params.upscaling);
   url.searchParams.append('correction', params.correction);
   url.searchParams.append('lpw', String(params.lpw));
+  url.searchParams.append('control', params.control);
 }
 
 /**
14 changes: 14 additions & 0 deletions gui/src/components/control/ModelControl.tsx
@@ -116,6 +116,20 @@ export function ModelControl() {
           });
         }}
       />
+      <QueryMenu
+        id='control'
+        labelKey='model'
+        name={t('modelType.control')}
+        query={{
+          result: models,
+          selector: (result) => result.networks.filter((network) => network.type === 'control').map((network) => network.name),
+        }}
+        onSelect={(control) => {
+          setModel({
+            control,
+          });
+        }}
+      />
     </Stack>
     <Stack direction='row' spacing={2}>
       <FormControlLabel
5 changes: 3 additions & 2 deletions gui/src/state.ts
@@ -500,11 +500,12 @@ export function createStateSlices(server: ServerParams) {
 
   const createModelSlice: Slice<ModelSlice> = (set) => ({
     model: {
+      control: server.control.default,
+      correction: server.correction.default,
+      lpw: false,
       model: server.model.default,
       platform: server.platform.default,
       upscaling: server.upscaling.default,
-      correction: server.correction.default,
-      lpw: false,
     },
     setModel(params) {
       set((prev) => ({
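Taken together, the selection flows end to end: the new QueryMenu in ModelControl lists networks with type 'control' and stores the choice in the model slice, appendModelToURL sends it as the control query parameter, pipeline_from_request reads it into ImageParams, and load_pipeline loads the matching ONNX ControlNet for the job's device.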
