Skip to content

Commit

Permalink
fix(api): blend embeddings into second tokenizer/text encoder for SDXL
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Sep 25, 2023
1 parent e338fcd commit fc02fa6
Showing 1 changed file with 34 additions and 10 deletions.
44 changes: 34 additions & 10 deletions api/onnx_web/diffusers/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,6 @@ def load_pipeline(
vae_components = load_vae(server, device, model, params)
components.update(vae_components)

# additional options for panorama pipeline
if params.is_panorama():
components["window"] = params.tiles // 8
components["stride"] = params.stride // 8

pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class.from_pretrained(
Expand All @@ -228,8 +223,15 @@ def load_pipeline(
components["text_encoder_2_session"], pipe
)

if "tokenizer" in components:
pipe.tokenizer = components["tokenizer"]

if "tokenizer_2" in components:
pipe.tokenizer_2 = components["tokenizer_2"]

if "unet_session" in components:
# unload old UNet
logger.debug("unloading previous Unet")
pipe.unet = None
run_gc([device])

Expand Down Expand Up @@ -300,20 +302,25 @@ def load_text_encoders(
torch_dtype,
params: ImageParams,
):
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
tokenizer = CLIPTokenizer.from_pretrained(
model,
subfolder="tokenizer",
torch_dtype=torch_dtype,
)

components = {}
components["tokenizer"] = tokenizer

text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
text_encoder_2 = None
components = {
"tokenizer": tokenizer,
}

if params.is_xl():
text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
tokenizer_2 = CLIPTokenizer.from_pretrained(
model,
subfolder="tokenizer_2",
torch_dtype=torch_dtype,
)
components["tokenizer_2"] = tokenizer_2

# blend embeddings, if any
if embeddings is not None and len(embeddings) > 0:
Expand All @@ -339,6 +346,23 @@ def load_text_encoders(
)
),
)
components["tokenizer"] = tokenizer

if params.is_xl():
text_encoder_2, tokenizer_2 = blend_textual_inversions(
server,
text_encoder_2,
tokenizer_2,
list(
zip(
embedding_models,
embedding_weights,
embedding_names,
[None] * len(embedding_models),
)
),
)
components["tokenizer_2"] = tokenizer_2

# blend LoRAs, if any
if loras is not None and len(loras) > 0:
Expand Down

0 comments on commit fc02fa6

Please sign in to comment.