In [2]:
from src.models.segment_anything.modeling.image_encoder import Attention as ImageEncoderAttention, FlashRelativePositionAttention
import torch
def initialize_weights(module):
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
# Test both implementations
B, H, W, C = 2, 8, 8, 64
input_tensor = torch.randn(B, H, W, C)

original_attention = ImageEncoderAttention(
    dim=C, num_heads=8, use_rel_pos=True, input_size=(H, W)
)
optimized_attention = FlashRelativePositionAttention(
    dim=C, num_heads=8, use_rel_pos=True, input_size=(H, W)
)
original_attention.apply(initialize_weights)
optimized_attention.load_state_dict(original_attention.state_dict())

# Get outputs
original_output = original_attention(input_tensor)
optimized_output = optimized_attention(input_tensor)

# Compare outputs
print("Original Output Shape:", original_output.shape)
print("Optimized Output Shape:", optimized_output.shape)
print("Difference between outputs:", torch.abs(original_output - optimized_output).max().item())
print("torch.equal:", torch.equal(original_output, optimized_output))

Original Output Shape: torch.Size([2, 8, 8, 64])
Optimized Output Shape: torch.Size([2, 8, 8, 64])
Difference between outputs: 1.7881393432617188e-07
torch.equal: False


In [None]:
from src.models.segment_anything.modeling.transformer import Attention, FlashAttention
import torch
# Test both implementations

def initialize_weights(module):
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
            
torch.manual_seed(42)
batch_size, seq_len, embed_dim = 2, 16, 64
num_heads = 8

# Random input tensors
q = torch.randn(batch_size, seq_len, embed_dim)
k = torch.randn(batch_size, seq_len, embed_dim)
v = torch.randn(batch_size, seq_len, embed_dim)

# Instantiate both attention mechanisms
sam_attention = Attention(embedding_dim=embed_dim, num_heads=num_heads)
optimized_sam_attention = FlashAttention(embedding_dim=embed_dim, num_heads=num_heads)

sam_attention.apply(initialize_weights)
optimized_sam_attention.load_state_dict(sam_attention.state_dict())
# Get outputs
original_output = sam_attention(q, k, v)
optimized_output = optimized_sam_attention(q, k, v)

# Compare outputs
print("Original Output Shape:", original_output.shape)
print("Optimized Output Shape:", optimized_output.shape)
print("Difference between outputs:", torch.abs(original_output - optimized_output).max().item())
print("Are equal:", torch.equal(original_output, optimized_output))

Original Output Shape: torch.Size([2, 16, 64])
Optimized Output Shape: torch.Size([2, 16, 64])
Difference between outputs: 4.470348358154297e-07
Are equal: False
