Skip to content

Commit

Permalink
fix(models): update attn parameter name in CNet
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Sep 5, 2023
1 parent 0fafc71 commit f73dce5
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions api/onnx_web/models/cnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def __init__(
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
num_attention_heads=attention_head_dim[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
Expand All @@ -307,7 +307,7 @@ def __init__(
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
Expand All @@ -321,7 +321,7 @@ def __init__(
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
num_attention_heads=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
Expand Down Expand Up @@ -367,7 +367,7 @@ def __init__(
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
num_attention_heads=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
Expand Down

0 comments on commit f73dce5

Please sign in to comment.