In [None]:
import torch
from src.models.swin.window_attention import WindowAttention

# Example config
dim = 96
window_size = (7, 7)
num_heads = 3
attn_dropout = 0.1
proj_dropout = 0.2

# Create the module
attn = WindowAttention(
    dim=dim,
    window_size=window_size,
    num_heads=num_heads,
    attn_dropout=attn_dropout,
    proj_dropout=proj_dropout,
)
# torch.no_grad() only switches off autograd; dropout layers stay active until
# the module is put in eval mode. Without this the "inference" pass below is
# stochastic. (Assumes WindowAttention is an nn.Module -- it is called like one.)
attn.eval()

# Dummy input
B = 2                     # batch size
num_windows = 4           # e.g. 4 local windows
N = window_size[0] * window_size[1]  # tokens per window
x = torch.randn(B * num_windows, N, dim)  # [num_windows*B, N, C]

# Forward pass
with torch.no_grad():
    out = attn(x)  # no attn_mask for plain W-MSA

# Check result: the recorded output shows the input shape is preserved.
assert out.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(out.shape)}"
print(f"Input shape:  {x.shape}")
print(f"Output shape: {out.shape}")

Input shape:  torch.Size([8, 49, 96])
Output shape: torch.Size([8, 49, 96])


In [2]:
# Hyper-parameters for the shifted-window (SW-MSA) example.
dim, num_heads = 96, 3
window_size = (7, 7)
attn_dropout, proj_dropout = 0.1, 0.2

# Build the attention module first so parameter initialisation happens
# before the dummy input is sampled.
attn = WindowAttention(
    dim=dim,
    window_size=window_size,
    num_heads=num_heads,
    attn_dropout=attn_dropout,
    proj_dropout=proj_dropout,
)

# ---- Dummy input ----
B = 2                                  # batch size
H, W = 56, 56                          # fake image height & width
N = window_size[0] * window_size[1]    # 49 tokens per window
grid_h = H // window_size[0]           # windows along the height
grid_w = W // window_size[1]           # windows along the width
num_windows = grid_h * grid_w
x = torch.randn((B * num_windows, N, dim))

# ---- Build shifted-window attention mask ----
def build_attn_mask(H, W, window_size, shift_size):
    """Build the additive attention mask used for shifted-window attention.

    The H x W grid is split into up to 3 x 3 regions along each axis
    (everything before the last window, the last window minus the shifted
    strip, and the shifted strip). Each region gets a distinct id; two
    tokens in the same window may attend to each other only when their
    region ids match.

    Args:
        H: Grid height in tokens; must be a multiple of ``window_size``.
        W: Grid width in tokens; must be a multiple of ``window_size``.
        window_size: Side length of the (square) attention window.
        shift_size: Size of the shifted strip along each axis. With
            ``shift_size == 0`` every token ends up in the same region and
            the returned mask is all zeros.

    Returns:
        Float tensor of shape ``(num_windows, window_size**2, window_size**2)``
        with ``num_windows = (H // window_size) * (W // window_size)``.
        Entries are ``0.0`` where the two tokens share a region and
        ``-100.0`` where they do not.
    """
    img_mask = torch.zeros((1, H, W, 1))

    # The same three bands are used along the height and the width.
    bands = (
        slice(0, -window_size),
        slice(-window_size, -shift_size),
        slice(-shift_size, None),
    )
    region_id = 0
    for h in bands:
        for w in bands:
            img_mask[:, h, w, :] = region_id
            region_id += 1

    # Tokens per window is derived from the argument rather than read from a
    # notebook-level global, so the function works on a fresh kernel and for
    # any window size.
    tokens_per_window = window_size * window_size

    # Partition the region map into non-overlapping windows (row-major) and
    # flatten each window to a vector of region ids.
    mask_windows = (
        img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
        .permute(0, 1, 3, 2, 4, 5)
        .contiguous()
        .view(-1, tokens_per_window)
    )

    # Pairwise difference of region ids: zero means "same region".
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
    return attn_mask

shift_size = window_size[0] // 2
attn_mask = build_attn_mask(H, W, window_size[0], shift_size)

# One mask per window position, shared across the batch.
assert attn_mask.shape == (num_windows, N, N)
print("Attention mask shape:", attn_mask.shape)

# ---- Forward pass ----
# torch.no_grad() does not disable dropout; switch the module to eval mode so
# the forward pass is deterministic. (Assumes WindowAttention is an nn.Module.)
attn.eval()
with torch.no_grad():
    out = attn(x, attn_mask=attn_mask)

assert out.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(out.shape)}"
print(f"Input shape:  {x.shape}")
print(f"Output shape: {out.shape}")

Attention mask shape: torch.Size([64, 49, 49])
Input shape:  torch.Size([128, 49, 96])
Output shape: torch.Size([128, 49, 96])


In [3]:
# Walk through the relative-position-index construction on a tiny 2x2 window.
window_size = (2, 2)
win_h, win_w = window_size

# Row and column index of every token in the window, stacked as (2, Wh, Ww).
coords_h = torch.arange(win_h)
coords_w = torch.arange(win_w)
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij'))
print("Coords:")
print(coords)

# Pairwise offsets between all tokens: (2, Wh*Ww, Wh*Ww).
coords_flatten = coords.flatten(1)
relative_coords = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1)
print("Relative coords:")
print(relative_coords)

# Move the (row, col) pair to the last axis: (Wh*Ww, Wh*Ww, 2).
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
# Shift both offsets so they start at 0, then scale the row offset by the
# number of distinct column offsets so that row + col is a unique flat index.
relative_coords[..., 0] += win_h - 1
relative_coords[..., 1] += win_w - 1
relative_coords[..., 0] *= 2 * win_w - 1
print("Relative processed coords:")
print(relative_coords)

relative_position_index = relative_coords.sum(dim=-1)  # (Wh*Ww, Wh*Ww)
print("Relative position index:")
print(relative_position_index)

Coords:
tensor([[[0, 0],
         [1, 1]],

        [[0, 1],
         [0, 1]]])
Relative coords:
tensor([[[ 0,  0, -1, -1],
         [ 0,  0, -1, -1],
         [ 1,  1,  0,  0],
         [ 1,  1,  0,  0]],

        [[ 0, -1,  0, -1],
         [ 1,  0,  1,  0],
         [ 0, -1,  0, -1],
         [ 1,  0,  1,  0]]])
Relative processed coords:
tensor([[[3, 1],
         [3, 0],
         [0, 1],
         [0, 0]],

        [[3, 2],
         [3, 1],
         [0, 2],
         [0, 1]],

        [[6, 1],
         [6, 0],
         [3, 1],
         [3, 0]],

        [[6, 2],
         [6, 1],
         [3, 2],
         [3, 1]]])
Relative position index:
tensor([[4, 3, 1, 0],
        [5, 4, 2, 1],
        [7, 6, 4, 3],
        [8, 7, 5, 4]])
