In [None]:
import torch
from torch import nn
from typing import Optional, Tuple, Union, Dict
from transformers.utils import is_flash_attn_2_available

# flash-attn is an optional dependency: import its kernels only when transformers
# reports FlashAttention-2 as usable in this environment.
if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func, flash_attn_func
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
else:
    # Bind every optional name (not just one of them) so that code further down
    # can test for `None` instead of tripping over a NameError.
    flash_attn_varlen_func = None
    flash_attn_func = None
    _flash_attention_forward = None

In [None]:
class Config:
    """Minimal attention configuration consumed by the attention modules below.

    Args:
        hidden_size: Width of the hidden states (the attention embedding dimension).
        num_attention_heads: Number of attention heads.
        attention_dropout: Dropout probability applied to the attention weights.

    Calling ``Config()`` with no arguments yields the defaults 768 / 12 / 0.0.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        num_attention_heads: int = 12,
        attention_dropout: float = 0.0,
    ) -> None:
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout


# Shared configuration instance used by the cells that follow.
config = Config()

In [None]:
class VisionFlashAttention2(nn.Module):
    """Multi-head self-attention computed with the FlashAttention-2 kernel.

    Exposes the same four projection layers (``q_proj``, ``k_proj``, ``v_proj``,
    ``out_proj``) as ``SigLipAttention`` so the two modules can share weights.

    Args:
        config: Object providing ``hidden_size`` and ``num_attention_heads``;
            ``attention_dropout`` is read when present and defaults to 0.0.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_attention_heads``.
        ImportError: If FlashAttention-2 is not available in this environment.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # Same sanity check as SigLipAttention: without it the `.view(...)` in
        # forward() fails later with a much less helpful shape error.
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).")
        # Fail at construction with a clear message; `flash_attn_func` is only
        # imported when this predicate is true.
        if not is_flash_attn_2_available():
            raise ImportError("VisionFlashAttention2 requires FlashAttention-2 (the `flash_attn` package), which is not available.")
        self.scale = self.head_dim**-0.5
        self.dropout = getattr(config, "attention_dropout", 0.0)

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=torch.bfloat16)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=torch.bfloat16)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=torch.bfloat16)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=torch.bfloat16)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run unmasked self-attention over ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``(batch_size, q_len, embed_dim)``.
            attention_mask: Accepted for signature parity with ``SigLipAttention``
                but NOT applied; attention is always computed over the full sequence.
            output_attentions: Accepted for signature parity; ignored.

        Returns:
            ``(attn_output, None)`` where ``attn_output`` has shape
            ``(batch_size, q_len, embed_dim)``. The second element is always
            ``None``: this module never returns attention weights.
        """
        batch_size, q_len, _ = hidden_states.size()

        # The heads axis stays at dim 2, i.e. (batch, seq_len, num_heads, head_dim);
        # unlike the eager implementation there is no transpose to (batch, heads, seq, dim).
        per_head_shape = (batch_size, q_len, self.num_heads, self.head_dim)
        query_states = self.q_proj(hidden_states).view(per_head_shape)
        key_states = self.k_proj(hidden_states).view(per_head_shape)
        value_states = self.v_proj(hidden_states).view(per_head_shape)

        attn_output = flash_attn_func(
            query_states,
            key_states,
            value_states,
            # Dropout only while training, mirroring nn.functional.dropout(training=...)
            # in the eager reference implementation.
            dropout_p=self.dropout if self.training else 0.0,
            softmax_scale=self.scale,
        )
        # Merge the head axes back into a single embedding axis.
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()

        attn_output = self.out_proj(attn_output)
        return attn_output, None

In [None]:
class SigLipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Eager (explicit matmul + softmax) implementation.

    Args:
        config: Object providing ``hidden_size``, ``num_attention_heads`` and
            ``attention_dropout``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_attention_heads``.
    """

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).")
        # Standard 1/sqrt(head_dim) scaling of the attention logits.
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=torch.bfloat16)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=torch.bfloat16)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=torch.bfloat16)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, dtype=torch.bfloat16)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: Tensor of shape ``(batch_size, q_len, embed_dim)``.
            attention_mask: Optional additive mask of shape
                ``(batch_size, 1, q_len, k_v_seq_len)``; it is added to the scaled
                attention logits before the softmax.
            output_attentions: Ignored; the attention weights are always returned.

        Returns:
            ``(attn_output, attn_weights)`` where ``attn_output`` has shape
            ``(batch_size, q_len, embed_dim)`` and ``attn_weights`` has shape
            ``(batch_size, num_heads, q_len, k_v_seq_len)``. The weights are the
            post-dropout values that were multiplied with the value states.

        Raises:
            ValueError: If the attention weights, the mask, or the attention
                output do not have the expected shape.
        """

        batch_size, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        k_v_seq_len = key_states.shape[-2]
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
            raise ValueError(f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}")

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
                raise ValueError(f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}")
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32 for the softmax, then cast back to the working dtype
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
            raise ValueError(f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}")

        # (batch, heads, seq, head_dim) -> (batch, seq, embed)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights

In [None]:
# Build both attention variants on the GPU (FlashAttention module first, then the
# eager reference, so parameter initialisation order is unchanged).
flash_attn = VisionFlashAttention2(config).cuda()
siglip_attn = SigLipAttention(config).cuda()

# Point the FlashAttention module at the eager module's projection layers so both
# implementations compute with exactly the same parameters.
for _proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
    setattr(flash_attn, _proj_name, getattr(siglip_attn, _proj_name))

In [None]:
# Random bf16 input of shape (batch, seq_len, hidden); the last dimension matches
# config.hidden_size. Allocated directly on the GPU rather than on the host and
# then copied over.
x = torch.rand((10, 729, 768), dtype=torch.bfloat16, device="cuda")
flash_attn_output, _ = flash_attn(x)
siglip_attn_output, _ = siglip_attn(x)

# Largest element-wise gap between the two implementations (compared in fp32).
# Both run in bf16 with shared weights, so a small non-zero value is expected.
(flash_attn_output.float() - siglip_attn_output.float()).abs().max().item()