From 6e23572231ced2788bc69a5326ec60d8ec9f3b28 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Fri, 12 May 2023 08:54:09 -0600 Subject: [PATCH] attention refactor: the trilogy (#3387) * Replace `AttentionBlock` with `Attention` * use _from_deprecated_attn_block check re: @patrickvonplaten --- models/attention.py | 174 +----------------- models/attention_processor.py | 129 ++++++++++++- models/modeling_utils.py | 46 +++++ models/unet_2d_blocks.py | 67 +++++-- .../pipeline_stable_diffusion_upscale.py | 12 +- 5 files changed, 234 insertions(+), 194 deletions(-) diff --git a/models/attention.py b/models/attention.py index 134f84fc9d50..0b313b83d360 100644 --- a/models/attention.py +++ b/models/attention.py @@ -11,189 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Callable, Optional +from typing import Optional import torch import torch.nn.functional as F from torch import nn from ..utils import maybe_allow_in_graph -from ..utils.import_utils import is_xformers_available from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted - to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention. - - Parameters: - channels (`int`): The number of channels in the input and output. - num_head_channels (`int`, *optional*): - The number of channels in each head. If None, then `num_heads` = 1. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. - """ - - # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. 
Do not use it anymore - - def __init__( - self, - channels: int, - num_head_channels: Optional[int] = None, - norm_num_groups: int = 32, - rescale_output_factor: float = 1.0, - eps: float = 1e-5, - ): - super().__init__() - self.channels = channels - - self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 - self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) - - # define q,k,v as linear layers - self.query = nn.Linear(channels, channels) - self.key = nn.Linear(channels, channels) - self.value = nn.Linear(channels, channels) - - self.rescale_output_factor = rescale_output_factor - self.proj_attn = nn.Linear(channels, channels, bias=True) - - self._use_memory_efficient_attention_xformers = False - self._use_2_0_attn = True - self._attention_op = None - - def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3) - if merge_head_and_batch: - tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True): - head_size = self.num_heads - - if unmerge_head_and_batch: - batch_head_size, seq_len, dim = tensor.shape - batch_size = batch_head_size // head_size - - tensor = tensor.reshape(batch_size, head_size, seq_len, dim) - else: - batch_size, _, seq_len, dim = tensor.shape - - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size) - return tensor - - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None - ): - if use_memory_efficient_attention_xformers: - if not is_xformers_available(): - raise ModuleNotFoundError( - ( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers" - ), - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is" - " only available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - self._attention_op = attention_op - - def forward(self, hidden_states): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - scale = 1 / math.sqrt(self.channels / self.num_heads) - - _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers - use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn - - query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn) - key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn) - value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn) - - if self._use_memory_efficient_attention_xformers: - # Memory efficient attention - hidden_states = xformers.ops.memory_efficient_attention( - query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale - ) - hidden_states = hidden_states.to(query_proj.dtype) - elif use_torch_2_0_attn: - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.to(query_proj.dtype) - else: - attention_scores = torch.baddbmm( - torch.empty( - query_proj.shape[0], - query_proj.shape[1], - key_proj.shape[1], - dtype=query_proj.dtype, - device=query_proj.device, - ), - query_proj, - key_proj.transpose(-1, -2), - beta=0, - alpha=scale, - ) - attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) - hidden_states = torch.bmm(attention_probs, value_proj) - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn) - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): r""" diff --git a/models/attention_processor.py b/models/attention_processor.py index b727c76e2137..f88400da0333 100644 --- a/models/attention_processor.py +++ b/models/attention_processor.py @@ -65,6 +65,10 @@ def __init__( out_bias: bool = True, scale_qk: bool = True, only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block=False, processor: Optional["AttnProcessor"] = None, ): super().__init__() @@ -72,6 +76,12 @@ def __init__( cross_attention_dim = cross_attention_dim if 
cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block self.scale_qk = scale_qk self.scale = dim_head**-0.5 if self.scale_qk else 1.0 @@ -91,7 +101,7 @@ def __init__( ) if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) else: self.group_norm = None @@ -407,10 +417,22 @@ def __call__( encoder_hidden_states=None, attention_mask=None, ): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -434,6 +456,14 @@ def __call__( # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -474,11 +504,22 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4): self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query) @@ -502,6 +543,14 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -762,12 +811,23 @@ def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states 
+ + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -792,6 +852,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -801,6 +870,14 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) @@ -812,6 +889,9 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -840,6 +920,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -858,11 +947,22 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = 
attn.head_to_batch_dim(query).contiguous() @@ -887,6 +987,14 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -980,11 +1088,22 @@ def __init__(self, slice_size): self.slice_size = slice_size def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) dim = query.shape[-1] query = attn.head_to_batch_dim(query) @@ -1025,6 +1144,14 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states diff --git a/models/modeling_utils.py b/models/modeling_utils.py index ef14ec3d09ef..e7cfcd71062f 100644 --- a/models/modeling_utils.py +++ b/models/modeling_utils.py @@ -583,6 +583,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if device_map is None: param_device = "cpu" state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) # move the params from meta device to cpu missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) if len(missing_keys) > 0: @@ -625,6 +626,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, @@ -803,3 +805,47 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) else: return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + + def _convert_deprecated_attention_blocks(self, state_dict): + deprecated_attention_block_paths = [] + + def recursive_find_attn_block(name, module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_paths.append(name) + + for sub_name, sub_module in module.named_children(): + sub_name = sub_name if name == "" else f"{name}.{sub_name}" + recursive_find_attn_block(sub_name, sub_module) + + recursive_find_attn_block("", self) + + # NOTE: we have to check if the deprecated parameters are 
in the state dict + # because it is possible we are loading from a state dict that was already + # converted + + for path in deprecated_attention_block_paths: + # group_norm path stays the same + + # query -> to_q + if f"{path}.query.weight" in state_dict: + state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") + if f"{path}.query.bias" in state_dict: + state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") + + # key -> to_k + if f"{path}.key.weight" in state_dict: + state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") + if f"{path}.key.bias" in state_dict: + state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") + + # value -> to_v + if f"{path}.value.weight" in state_dict: + state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") + if f"{path}.value.bias" in state_dict: + state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") + + # proj_attn -> to_out.0 + if f"{path}.proj_attn.weight" in state_dict: + state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") + if f"{path}.proj_attn.bias" in state_dict: + state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") diff --git a/models/unet_2d_blocks.py b/models/unet_2d_blocks.py index 2f7b19b7328a..0004f074c563 100644 --- a/models/unet_2d_blocks.py +++ b/models/unet_2d_blocks.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from torch import nn -from .attention import AdaGroupNorm, AttentionBlock +from .attention import AdaGroupNorm from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D @@ -427,12 +427,17 @@ def __init__( for _ in range(num_layers): if self.add_attention: attentions.append( - AttentionBlock( + Attention( in_channels, - num_head_channels=attn_num_head_channels, + heads=in_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else in_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) else: @@ -711,12 +716,17 @@ def __init__( ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -1060,12 +1070,17 @@ def __init__( ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -1134,11 +1149,17 @@ def __init__( ) ) self.attentions.append( - AttentionBlock( + Attention( 
out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -1703,12 +1724,17 @@ def __init__( ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -2037,12 +2063,17 @@ def __init__( ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -2109,11 +2140,17 @@ def __init__( ) self.attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) diff --git a/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index b7530ac4ec5c..6bb463a6a65f 100644 --- a/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -19,11 +19,11 @@ import numpy as np import PIL import torch -import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -709,12 +709,14 @@ def __call__( # make sure the VAE is in float32 mode, as it overflows in float16 self.vae.to(dtype=torch.float32) - # TODO(Patrick, William) - clean up when attention is refactored - use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") - use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers + use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [ + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + ] # if xformers or torch_2_0 is used attention block does not need # to be in float32 which 
can save lots of memory - if not use_torch_2_0_attn and not use_xformers: + if not use_torch_2_0_or_xformers: self.vae.post_quant_conv.to(latents.dtype) self.vae.decoder.conv_in.to(latents.dtype) self.vae.decoder.mid_block.to(latents.dtype)
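
A note on the `use_torch_2_0_or_xformers` check added in the upscale pipeline above: `attentions[0].processor` is a processor *instance*, while the list it is tested against contains processor *classes*, so the membership test evaluates to `False` for these processors and the pipeline falls back to casting the VAE submodules. A minimal sketch of an `isinstance`-based variant that expresses the stated intent is below; it assumes a diffusers version that ships these processor classes, and the helper name `uses_torch_2_0_or_xformers` plus the `pipe` object in the usage comment are illustrative, not part of the patch.

```python
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)


def uses_torch_2_0_or_xformers(processor) -> bool:
    # Check the type of the active processor instance rather than testing the
    # instance for membership in a list of classes.
    return isinstance(
        processor, (AttnProcessor2_0, XFormersAttnProcessor, LoRAXFormersAttnProcessor)
    )


# Usage (illustrative):
# uses_torch_2_0_or_xformers(pipe.vae.decoder.mid_block.attentions[0].processor)
```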
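
For downstream code that still builds the removed `AttentionBlock` directly (rather than loading through `from_pretrained`, where `_convert_deprecated_attention_blocks` rewrites the state dict on the fly), the sketch below shows the migration this patch implies: construct an `Attention` with the same configuration `unet_2d_blocks.py` now uses, remap the old parameter names, and feed it the same 4-D `(batch, channel, height, width)` tensors, which the refactored processors now flatten, normalize, attend over, reshape back, add the residual to, and rescale. This is a minimal sketch assuming a diffusers version that includes this patch; the helper `remap_deprecated_attention_block_keys` and the concrete sizes are illustrative, not part of the library.

```python
from typing import Dict

import torch

from diffusers.models.attention_processor import Attention


def remap_deprecated_attention_block_keys(
    state_dict: Dict[str, torch.Tensor], prefix: str = ""
) -> Dict[str, torch.Tensor]:
    """Rename deprecated AttentionBlock parameters to the new Attention names.

    Mirrors ModelMixin._convert_deprecated_attention_blocks: query/key/value
    become to_q/to_k/to_v and proj_attn becomes to_out.0, while group_norm
    keeps its name. Keys that were already converted are left untouched.
    """
    renames = {"query": "to_q", "key": "to_k", "value": "to_v", "proj_attn": "to_out.0"}
    for old, new in renames.items():
        for suffix in ("weight", "bias"):
            old_key = f"{prefix}{old}.{suffix}"
            if old_key in state_dict:
                state_dict[f"{prefix}{new}.{suffix}"] = state_dict.pop(old_key)
    return state_dict


# Stand-in for an old AttentionBlock(channels=64, num_head_channels=32) state dict.
channels, num_head_channels = 64, 32
old_state_dict = {
    "group_norm.weight": torch.ones(channels),
    "group_norm.bias": torch.zeros(channels),
    "query.weight": torch.randn(channels, channels),
    "query.bias": torch.zeros(channels),
    "key.weight": torch.randn(channels, channels),
    "key.bias": torch.zeros(channels),
    "value.weight": torch.randn(channels, channels),
    "value.bias": torch.zeros(channels),
    "proj_attn.weight": torch.randn(channels, channels),
    "proj_attn.bias": torch.zeros(channels),
}

# Replacement module, configured the way unet_2d_blocks.py now does it.
attn = Attention(
    channels,
    heads=channels // num_head_channels,
    dim_head=num_head_channels,
    norm_num_groups=32,
    eps=1e-5,
    rescale_output_factor=1.0,
    residual_connection=True,
    bias=True,
    upcast_softmax=True,
    _from_deprecated_attn_block=True,  # loader flag from this patch; harmless here
)
attn.load_state_dict(remap_deprecated_attention_block_keys(old_state_dict))

# The refactored processors accept 4-D (batch, channel, height, width) input directly.
with torch.no_grad():
    out = attn(torch.randn(1, channels, 8, 8))
print(out.shape)  # torch.Size([1, 64, 8, 8])
```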