diff --git a/models/attention_processor.py b/models/attention_processor.py index 1bfaa0258155..e39bdc0429c1 100644 --- a/models/attention_processor.py +++ b/models/attention_processor.py @@ -166,22 +166,28 @@ def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None ): is_lora = hasattr(self, "processor") and isinstance( - self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor) + self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor) ) is_custom_diffusion = hasattr(self, "processor") and isinstance( self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor) ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + LoRAAttnAddedKVProcessor, + ), + ) if use_memory_efficient_attention_xformers: - if self.added_kv_proj_dim is not None: - # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP - # which uses this type of cross attention ONLY because the attention mask of format - # [0, ..., -10.000, ..., 0, ...,] is not supported + if is_added_kv_processor and (is_lora or is_custom_diffusion): raise NotImplementedError( - "Memory efficient attention with `xformers` is currently not supported when" - " `self.added_kv_proj_dim` is defined." + f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}" ) - elif not is_xformers_available(): + if not is_xformers_available(): raise ModuleNotFoundError( ( "Refer to https://github.com/facebookresearch/xformers for more information on how to install" @@ -233,6 +239,15 @@ def set_use_memory_efficient_attention_xformers( processor.load_state_dict(self.processor.state_dict()) if hasattr(self.processor, "to_k_custom_diffusion"): processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." + ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) else: processor = XFormersAttnProcessor(attention_op=attention_op) else: @@ -889,6 +904,71 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states +class XFormersAttnAddedKVProcessor: + r""" + Processor for implementing memory efficient attention using xFormers. + + Args: + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + """ + + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + class XFormersAttnProcessor: r""" Processor for implementing memory efficient attention using xFormers. @@ -1428,6 +1508,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0, + XFormersAttnAddedKVProcessor, LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor, diff --git a/models/unet_2d_condition.py b/models/unet_2d_condition.py index 484f9323c69f..106346070d94 100644 --- a/models/unet_2d_condition.py +++ b/models/unet_2d_condition.py @@ -261,6 +261,7 @@ def __init__( if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: diff --git a/pipelines/versatile_diffusion/modeling_text_unet.py b/pipelines/versatile_diffusion/modeling_text_unet.py index af647fe810aa..a0dbdaa75230 100644 --- a/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/pipelines/versatile_diffusion/modeling_text_unet.py @@ -364,6 +364,7 @@ def __init__( if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: