fix(api): convert back to model format after blending, convert samples as needed (#274)
ssube committed Mar 22, 2023
1 parent 0315a8c commit c8aad85
Showing 5 changed files with 22 additions and 28 deletions.
api/onnx_web/convert/diffusion/diffusers.py (18 changes: 9 additions & 9 deletions)
@@ -97,8 +97,8 @@ def convert_diffusion_diffusers(
     single_vae = model.get("single_vae")
     replace_vae = model.get("vae")
 
-    torch_dtype = ctx.torch_dtype()
-    logger.debug("using Torch dtype %s for pipeline", torch_dtype)
+    dtype = ctx.torch_dtype()
+    logger.debug("using Torch dtype %s for pipeline", dtype)
 
     dest_path = path.join(ctx.model_path, name)
     model_index = path.join(dest_path, "model_index.json")
@@ -117,7 +117,7 @@ def convert_diffusion_diffusers(
 
     pipeline = StableDiffusionPipeline.from_pretrained(
         source,
-        torch_dtype=torch_dtype,
+        torch_dtype=dtype,
         use_auth_token=ctx.token,
     ).to(ctx.training_device)
     output_path = Path(dest_path)
@@ -174,11 +174,11 @@ def convert_diffusion_diffusers(
         pipeline.unet,
         model_args=(
             torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
-                device=ctx.training_device, dtype=torch_dtype
+                device=ctx.training_device, dtype=dtype
             ),
-            torch.randn(2).to(device=ctx.training_device, dtype=torch_dtype),
+            torch.randn(2).to(device=ctx.training_device, dtype=dtype),
             torch.randn(2, num_tokens, text_hidden_size).to(
-                device=ctx.training_device, dtype=torch_dtype
+                device=ctx.training_device, dtype=dtype
             ),
             unet_scale,
         ),
@@ -230,7 +230,7 @@ def convert_diffusion_diffusers(
             model_args=(
                 torch.randn(
                     1, vae_latent_channels, unet_sample_size, unet_sample_size
-                ).to(device=ctx.training_device, dtype=torch_dtype),
+                ).to(device=ctx.training_device, dtype=dtype),
                 False,
             ),
             output_path=output_path / "vae" / "model.onnx",
@@ -255,7 +255,7 @@ def convert_diffusion_diffusers(
             vae_encoder,
             model_args=(
                 torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
-                    device=ctx.training_device, dtype=torch_dtype
+                    device=ctx.training_device, dtype=dtype
                 ),
                 False,
             ),
@@ -279,7 +279,7 @@ def convert_diffusion_diffusers(
             model_args=(
                 torch.randn(
                     1, vae_latent_channels, unet_sample_size, unet_sample_size
-                ).to(device=ctx.training_device, dtype=torch_dtype),
+                ).to(device=ctx.training_device, dtype=dtype),
                 False,
             ),
             output_path=output_path / "vae_decoder" / "model.onnx",
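
The changes above are a straight rename (torch_dtype to dtype) so this converter matches the other conversion modules; the value still comes from ctx.torch_dtype(). As a rough sketch of what the export-time dummy inputs look like once that dtype is applied (the function name and default shapes here are illustrative, not the project's exact signature):

import torch

def make_unet_args(device, dtype,
                   in_channels=4, sample_size=64,
                   num_tokens=77, hidden_size=768):
    # Dummy tensors for torch.onnx.export; batch of 2 covers the
    # conditioned and unconditioned halves of classifier-free guidance.
    return (
        torch.randn(2, in_channels, sample_size, sample_size).to(device=device, dtype=dtype),
        torch.randn(2).to(device=device, dtype=dtype),
        torch.randn(2, num_tokens, hidden_size).to(device=device, dtype=dtype),
    )
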
api/onnx_web/convert/diffusion/lora.py (2 changes: 1 addition & 1 deletion)
@@ -62,7 +62,7 @@ def blend_loras(
 ):
     # always load to CPU for blending
     device = torch.device("cpu")
-    dtype = context.torch_dtype()
+    dtype = torch.float32
 
     base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
     lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
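
Pinning the blend dtype to torch.float32, instead of the context's configured dtype, keeps the LoRA merge in full precision even when the server runs fp16 models; per the commit title, the result is converted back to the model's format afterwards. A minimal sketch of why, under the assumption that blending is a weighted sum of tensors:

import numpy as np

def blend(base, deltas):
    # Accumulate in float32 so repeated half-precision adds do not
    # compound rounding error, then cast back to the stored dtype.
    acc = base.astype(np.float32)
    for delta, weight in deltas:
        acc += weight * delta.astype(np.float32)
    return acc.astype(base.dtype)
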
api/onnx_web/convert/diffusion/textual_inversion.py (10 changes: 5 additions & 5 deletions)
@@ -22,7 +22,7 @@ def blend_textual_inversions(
 ) -> Tuple[ModelProto, CLIPTokenizer]:
     # always load to CPU for blending
     device = torch.device("cpu")
-    dtype = context.numpy_dtype()
+    dtype = np.float32
     embeds = {}
 
     for name, weight, base_token, inversion_format in inversions:
@@ -131,11 +131,11 @@ def blend_textual_inversions(
         for n in text_encoder.graph.initializer
         if n.name == "text_model.embeddings.token_embedding.weight"
     ][0]
-    embedding_weights = numpy_helper.to_array(embedding_node)
+    base_weights = numpy_helper.to_array(embedding_node)
 
-    weights_dim = embedding_weights.shape[1]
+    weights_dim = base_weights.shape[1]
     zero_weights = np.zeros((num_added_tokens, weights_dim))
-    embedding_weights = np.concatenate((embedding_weights, zero_weights), axis=0)
+    embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)
 
     for token, weights in embeds.items():
         token_id = tokenizer.convert_tokens_to_ids(token)
@@ -149,7 +149,7 @@
             == "text_model.embeddings.token_embedding.weight"
         ):
             new_initializer = numpy_helper.from_array(
-                embedding_weights.astype(dtype), embedding_node.name
+                embedding_weights.astype(base_weights.dtype), embedding_node.name
             )
             logger.trace("new initializer data type: %s", new_initializer.data_type)
             del text_encoder.graph.initializer[i]
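
Two fixes land here: the blend itself now always runs at np.float32, and the rebuilt initializer is cast to base_weights.dtype, the dtype of the weights actually stored in the ONNX graph, rather than to a context-derived dtype that could disagree with the model. A sketch of that round-trip, where replace_initializer is a hypothetical helper name:

import numpy as np
from onnx import GraphProto, numpy_helper

def replace_initializer(graph, name, blended, like):
    # Re-serialize the blended table with the dtype of the original
    # initializer ("like"), so an fp16 text encoder gets fp16 weights back.
    for i, node in enumerate(graph.initializer):
        if node.name == name:
            new_node = numpy_helper.from_array(blended.astype(like.dtype), name)
            del graph.initializer[i]
            graph.initializer.insert(i, new_node)
            return
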
api/onnx_web/diffusers/load.py (14 changes: 7 additions & 7 deletions)
@@ -360,15 +360,15 @@ def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
         global timestep_dtype
         timestep_dtype = timestep.dtype
 
-        logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
-        if "onnx-fp16" in self.server.optimizations:
-            logger.info("converting UNet sample to ONNX fp16")
-            sample = sample.astype(np.float16)
-            encoder_hidden_states = encoder_hidden_states.astype(np.float16)
-        elif sample.dtype != timestep.dtype:
-            logger.info("converting UNet sample to timestep dtype")
+        logger.trace("UNet parameter types: %s, %s, %s", sample.dtype, timestep.dtype, encoder_hidden_states.dtype)
+        if sample.dtype != timestep.dtype:
+            logger.trace("converting UNet sample to timestep dtype")
             sample = sample.astype(timestep.dtype)
+
+        if encoder_hidden_states.dtype != timestep.dtype:
+            logger.trace("converting UNet hidden states to timestep dtype")
+            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
 
         return self.wrapped(
             sample=sample,
             timestep=timestep,
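
The wrapper no longer special-cases the onnx-fp16 optimization; it simply coerces both the sample and the encoder hidden states to the timestep's dtype on every call, which is the "convert samples as needed" half of the commit. A stripped-down sketch of that wrapper, with logging and the timestep_dtype global omitted:

class UNetWrapper:
    # Minimal dtype-harmonizing wrapper around an ONNX UNet callable,
    # assuming numpy-array inputs that may arrive with mixed precisions.
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
        if sample.dtype != timestep.dtype:
            sample = sample.astype(timestep.dtype)
        if encoder_hidden_states.dtype != timestep.dtype:
            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
        return self.wrapped(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
        )
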
api/onnx_web/server/context.py (6 changes: 0 additions & 6 deletions)
@@ -86,9 +86,3 @@ def torch_dtype(self):
             return torch.float16
         else:
             return torch.float32
-
-    def numpy_dtype(self):
-        if "torch-fp16" in self.optimizations or "onnx-fp16" in self.optimizations:
-            return np.float16
-        else:
-            return np.float32
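
With the blending paths hard-coded to full precision (lora.py and textual_inversion.py above), numpy_dtype() has no remaining callers and is dropped; only the Torch-side helper survives. Its retained shape is roughly as follows, assuming it keys off the same torch-fp16 flag as the deleted method and that optimizations is a set of feature-flag strings:

import torch

class ServerContext:
    # Sketch of the surviving helper; `optimizations` is assumed to be
    # parsed from the server config.
    def __init__(self, optimizations=None):
        self.optimizations = optimizations or set()

    def torch_dtype(self):
        if "torch-fp16" in self.optimizations:
            return torch.float16
        else:
            return torch.float32
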
