feat(api): convert CNet for existing diffusion models
ssube committed Apr 15, 2023
1 parent 2c75311 commit 0dd8272
Showing 2 changed files with 145 additions and 125 deletions.
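
For orientation, a minimal sketch of what this commit enables: backfilling the ControlNet-ready UNet for a model that was converted before ControlNet support landed. The model name, output directory, and SD 1.5 shape defaults below are illustrative assumptions, not values from this commit; only convert_diffusion_diffusers_cnet and its signature come from the diff below.

# Sketch only; `conversion` stands in for an existing ConversionContext.
from pathlib import Path

import torch

from onnx_web.convert.diffusion.diffusers import convert_diffusion_diffusers_cnet

conversion = ...  # an existing ConversionContext from the onnx-web app

convert_diffusion_diffusers_cnet(
    conversion,
    "runwayml/stable-diffusion-v1-5",      # hypothetical source model
    "cpu",
    Path("models/stable-diffusion-v1-5"),  # hypothetical output directory
    torch.float32,
    unet_in_channels=4,    # assumed SD 1.5 latent channels
    unet_sample_size=64,   # assumed: 512 px / VAE scale factor 8
    num_tokens=77,         # assumed CLIP sequence length
    text_hidden_size=768,  # assumed CLIP hidden size
)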
api/onnx_web/convert/diffusion/diffusers.py (144 additions, 124 deletions)
@@ -34,6 +34,141 @@
 logger = getLogger(__name__)
 
 
+def convert_diffusion_diffusers_cnet(
+    conversion: ConversionContext,
+    source: str,
+    device: str,
+    output_path: Path,
+    dtype,
+    unet_in_channels,
+    unet_sample_size,
+    num_tokens,
+    text_hidden_size,
+):
+    # CNet
+    pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet").to(
+        device=device, dtype=dtype
+    )
+
+    if is_torch_2_0:
+        pipe_cnet.set_attn_processor(CrossAttnProcessor())
+
+    cnet_path = output_path / "cnet" / ONNX_MODEL
+    onnx_export(
+        pipe_cnet,
+        model_args=(
+            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
+                device=device, dtype=dtype
+            ),
+            torch.randn(2).to(device=device, dtype=dtype),
+            torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
+            torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
+                device=device, dtype=dtype
+            ),
+            torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
+                device=device, dtype=dtype
+            ),
+            torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
+                device=device, dtype=dtype
+            ),
+            torch.randn(2, 320, unet_sample_size // 2, unet_sample_size // 2).to(
+                device=device, dtype=dtype
+            ),
+            torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
+                device=device, dtype=dtype
+            ),
+            torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
+                device=device, dtype=dtype
+            ),
+            torch.randn(2, 640, unet_sample_size // 4, unet_sample_size // 4).to(
+                device=device, dtype=dtype
+            ),
+            torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
+                device=device, dtype=dtype
+            ),
+            torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
+                device=device, dtype=dtype
+            ),
+            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
+                device=device, dtype=dtype
+            ),
+            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
+                device=device, dtype=dtype
+            ),
+            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
+                device=device, dtype=dtype
+            ),
+            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
+                device=device, dtype=dtype
+            ),
+            False,
+        ),
+        output_path=cnet_path,
+        ordered_input_names=[
+            "sample",
+            "timestep",
+            "encoder_hidden_states",
+            "down_block_0",
+            "down_block_1",
+            "down_block_2",
+            "down_block_3",
+            "down_block_4",
+            "down_block_5",
+            "down_block_6",
+            "down_block_7",
+            "down_block_8",
+            "down_block_9",
+            "down_block_10",
+            "down_block_11",
+            "mid_block_additional_residual",
+            "return_dict",
+        ],
+        output_names=[
+            "out_sample"
+        ],  # has to be different from "sample" for correct tracing
+        dynamic_axes={
+            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+            "timestep": {0: "batch"},
+            "encoder_hidden_states": {0: "batch", 1: "sequence"},
+            "down_block_0": {0: "batch", 2: "height", 3: "width"},
+            "down_block_1": {0: "batch", 2: "height", 3: "width"},
+            "down_block_2": {0: "batch", 2: "height", 3: "width"},
+            "down_block_3": {0: "batch", 2: "height2", 3: "width2"},
+            "down_block_4": {0: "batch", 2: "height2", 3: "width2"},
+            "down_block_5": {0: "batch", 2: "height2", 3: "width2"},
+            "down_block_6": {0: "batch", 2: "height4", 3: "width4"},
+            "down_block_7": {0: "batch", 2: "height4", 3: "width4"},
+            "down_block_8": {0: "batch", 2: "height4", 3: "width4"},
+            "down_block_9": {0: "batch", 2: "height8", 3: "width8"},
+            "down_block_10": {0: "batch", 2: "height8", 3: "width8"},
+            "down_block_11": {0: "batch", 2: "height8", 3: "width8"},
+            "mid_block_additional_residual": {0: "batch", 2: "height8", 3: "width8"},
+        },
+        opset=conversion.opset,
+        half=conversion.half,
+        external_data=True,  # UNet is > 2GB, so the weights need to be split
+    )
+    cnet_model_path = str(cnet_path.absolute().as_posix())
+    cnet_dir = path.dirname(cnet_model_path)
+    cnet = load_model(cnet_model_path)
+
+    # clean up existing tensor files
+    rmtree(cnet_dir)
+    mkdir(cnet_dir)
+
+    # collate external tensor files into one
+    save_model(
+        cnet,
+        cnet_model_path,
+        save_as_external_data=True,
+        all_tensors_to_one_file=True,
+        location=ONNX_WEIGHTS,
+        convert_attribute=False,
+    )
+    del pipe_cnet
+
+
 @torch.no_grad()
 def convert_diffusion_diffusers(
     conversion: ConversionContext,
@@ -54,6 +189,7 @@ def convert_diffusion_diffusers(
 
     dest_path = path.join(conversion.model_path, name)
     model_index = path.join(dest_path, "model_index.json")
+    model_cnet = path.join(dest_path, "cnet", ONNX_MODEL)
 
     # diffusers go into a directory rather than .onnx file
     logger.info(
Expand All @@ -64,9 +200,13 @@ def convert_diffusion_diffusers(
logger.info("converting model with single VAE")

if path.exists(dest_path) and path.exists(model_index):
# TODO: check if CNet has been converted
logger.info("ONNX model already exists, skipping")
return (False, dest_path)
if not path.exists(model_cnet):
logger.info("ONNX model was converted without a ControlNet UNet, converting one")
convert_diffusion_diffusers_cnet(conversion, source, device, output_path, dtype, unet_in_channels, unet_sample_size, num_tokens, text_hidden_size)
return (True, dest_path)
else:
logger.info("ONNX model already exists, skipping")
return (False, dest_path)

pipeline = StableDiffusionPipeline.from_pretrained(
source,
@@ -166,127 +306,7 @@ def convert_diffusion_diffusers(
     )
     del pipeline.unet
 
-    # CNet
-    pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet").to(
-        device=device, dtype=dtype
-    )
-
-    if is_torch_2_0:
-        pipe_cnet.set_attn_processor(CrossAttnProcessor())
-
-    cnet_path = output_path / "cnet" / ONNX_MODEL
-    onnx_export(
-        pipe_cnet,
-        model_args=(
-            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
-                device=device, dtype=dtype
-            ),
-            torch.randn(2).to(device=device, dtype=dtype),
-            torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
-            torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
-                device=device, dtype=dtype
-            ),
-            torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
-                device=device, dtype=dtype
-            ),
-            torch.randn(2, 320, unet_sample_size, unet_sample_size).to(
-                device=device, dtype=dtype
-            ),
-            torch.randn(2, 320, unet_sample_size // 2, unet_sample_size // 2).to(
-                device=device, dtype=dtype
-            ),
-            torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
-                device=device, dtype=dtype
-            ),
-            torch.randn(2, 640, unet_sample_size // 2, unet_sample_size // 2).to(
-                device=device, dtype=dtype
-            ),
-            torch.randn(2, 640, unet_sample_size // 4, unet_sample_size // 4).to(
-                device=device, dtype=dtype
-            ),
-            torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
-                device=device, dtype=dtype
-            ),
-            torch.randn(2, 1280, unet_sample_size // 4, unet_sample_size // 4).to(
-                device=device, dtype=dtype
-            ),
-            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
-                device=device, dtype=dtype
-            ),
-            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
-                device=device, dtype=dtype
-            ),
-            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
-                device=device, dtype=dtype
-            ),
-            torch.randn(2, 1280, unet_sample_size // 8, unet_sample_size // 8).to(
-                device=device, dtype=dtype
-            ),
-            False,
-        ),
-        output_path=cnet_path,
-        ordered_input_names=[
-            "sample",
-            "timestep",
-            "encoder_hidden_states",
-            "down_block_0",
-            "down_block_1",
-            "down_block_2",
-            "down_block_3",
-            "down_block_4",
-            "down_block_5",
-            "down_block_6",
-            "down_block_7",
-            "down_block_8",
-            "down_block_9",
-            "down_block_10",
-            "down_block_11",
-            "mid_block_additional_residual",
-            "return_dict",
-        ],
-        output_names=[
-            "out_sample"
-        ],  # has to be different from "sample" for correct tracing
-        dynamic_axes={
-            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-            "timestep": {0: "batch"},
-            "encoder_hidden_states": {0: "batch", 1: "sequence"},
-            "down_block_0": {0: "batch", 2: "height", 3: "width"},
-            "down_block_1": {0: "batch", 2: "height", 3: "width"},
-            "down_block_2": {0: "batch", 2: "height", 3: "width"},
-            "down_block_3": {0: "batch", 2: "height2", 3: "width2"},
-            "down_block_4": {0: "batch", 2: "height2", 3: "width2"},
-            "down_block_5": {0: "batch", 2: "height2", 3: "width2"},
-            "down_block_6": {0: "batch", 2: "height4", 3: "width4"},
-            "down_block_7": {0: "batch", 2: "height4", 3: "width4"},
-            "down_block_8": {0: "batch", 2: "height4", 3: "width4"},
-            "down_block_9": {0: "batch", 2: "height8", 3: "width8"},
-            "down_block_10": {0: "batch", 2: "height8", 3: "width8"},
-            "down_block_11": {0: "batch", 2: "height8", 3: "width8"},
-            "mid_block_additional_residual": {0: "batch", 2: "height8", 3: "width8"},
-        },
-        opset=conversion.opset,
-        half=conversion.half,
-        external_data=True,  # UNet is > 2GB, so the weights need to be split
-    )
-    cnet_model_path = str(cnet_path.absolute().as_posix())
-    cnet_dir = path.dirname(cnet_model_path)
-    cnet = load_model(cnet_model_path)
-
-    # clean up existing tensor files
-    rmtree(cnet_dir)
-    mkdir(cnet_dir)
-
-    # collate external tensor files into one
-    save_model(
-        cnet,
-        cnet_model_path,
-        save_as_external_data=True,
-        all_tensors_to_one_file=True,
-        location=ONNX_WEIGHTS,
-        convert_attribute=False,
-    )
-    del pipe_cnet
+    convert_diffusion_diffusers_cnet(conversion, source, device, output_path, dtype, unet_in_channels, unet_sample_size, num_tokens, text_hidden_size)
 
     # VAE
     if replace_vae is not None:
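
The twelve down_block dummies in the new function above follow the Stable Diffusion UNet's residual pyramid: channel counts step 320 → 640 → 1280 while the spatial size halves at each stage. A small sketch of those shapes, assuming unet_sample_size = 64 (a common SD 1.5 default, not stated in this diff):

# Reproduces the (channels, height, width) pattern of the dummy residuals above.
size = 64  # assumed unet_sample_size
shapes = (
    [(320, size, size)] * 3
    + [(320, size // 2, size // 2)]
    + [(640, size // 2, size // 2)] * 2
    + [(640, size // 4, size // 4)]
    + [(1280, size // 4, size // 4)] * 2
    + [(1280, size // 8, size // 8)] * 3
)
for index, shape in enumerate(shapes):
    print(f"down_block_{index}:", shape)
print("mid_block_additional_residual:", (1280, size // 8, size // 8))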
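As a quick sanity check after conversion, the exported graph can be inspected with onnxruntime. A minimal sketch, assuming ONNX_MODEL resolves to "model.onnx" and the hypothetical output directory used earlier:

from onnxruntime import InferenceSession

session = InferenceSession("models/stable-diffusion-v1-5/cnet/model.onnx")
for tensor in session.get_inputs():
    # expected: sample, timestep, encoder_hidden_states, down_block_0..11,
    # and mid_block_additional_residual, with symbolic batch/height/width axes
    print(tensor.name, tensor.shape)
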
gui/src/components/input/QueryList.tsx (1 addition, 1 deletion)
@@ -64,7 +64,7 @@ export function QueryList<T>(props: QueryListProps<T>) {
 
   function noneLabel(): Maybe<string> {
     if (showNone) {
-      return t(`${labelKey}.none`);
+      return 'none';
     }
 
     return undefined;