In [2]:
import torch
from src.models.swin.window_utils import window_partition

# Demo: split an unshifted 8x8 single-channel feature map into 4x4 windows.
B, H, W, C = (1, 8, 8, 1)
window_size = 4

# Fill the tensor with its own flat index so every element stays identifiable
# after partitioning.
x = torch.arange(B * H * W * C).view(B, H, W, C)

# No cyclic shift in this cell, so the roll below is skipped and x is used as-is.
shift_size = 0

if shift_size > 0:
    # Roll towards lower indices along the H and W axes (dims 1 and 2).
    x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

print("Input shape:", x.shape)
print("Input tensor (channel 0):")
print(x[0, :, :, 0])  # show only first channel

# Partition into windows; the leading dimension of the result indexes windows.
windows = window_partition(x, window_size)
print("\nOutput shape:", windows.shape)
print("Number of windows:", windows.shape[0])

# Visualize the first window's first channel
print("\nFirst window (channel 0):")
print(windows[0, :, :, 0])

# Visualize the second window's first channel
print("\nSecond window (channel 0):")
print(windows[1, :, :, 0])

Input shape: torch.Size([1, 8, 8, 1])
Input tensor (channel 0):
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29, 30, 31],
        [32, 33, 34, 35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44, 45, 46, 47],
        [48, 49, 50, 51, 52, 53, 54, 55],
        [56, 57, 58, 59, 60, 61, 62, 63]])

Output shape: torch.Size([4, 4, 4, 1])
Number of windows: 4

First window (channel 0):
tensor([[ 0,  1,  2,  3],
        [ 8,  9, 10, 11],
        [16, 17, 18, 19],
        [24, 25, 26, 27]])

Second window (channel 0):
tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15],
        [20, 21, 22, 23],
        [28, 29, 30, 31]])


In [3]:
import torch

# Demo: cyclically shift the 8x8 feature map by half a window, then split it
# into 4x4 windows. Relies on `window_partition` imported in an earlier cell.
B, H, W, C = (1, 8, 8, 1)
window_size = 4

# Fill the tensor with its own flat index so every element stays identifiable
# after shifting and partitioning.
x = torch.arange(B * H * W * C).view(B, H, W, C)

# Shift by half the window size.
shift_size = window_size // 2

if shift_size > 0:
    # Roll towards lower indices along the H and W axes (dims 1 and 2);
    # elements pushed past index 0 wrap around to the far edge.
    x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

print("Input shape:", x.shape)
print("Input tensor (channel 0):")
print(x[0, :, :, 0])  # show only first channel

# Partition into windows; the leading dimension of the result indexes windows.
windows = window_partition(x, window_size)
print("\nOutput shape:", windows.shape)
print("Number of windows:", windows.shape[0])

# Visualize the first window's first channel
print("\nFirst window (channel 0):")
print(windows[0, :, :, 0])

# Visualize the last window's first channel (index 3 of the 4 windows); this is
# the window that mixes values wrapped around from both axes.
print("\nLast window (channel 0):")
print(windows[3, :, :, 0])

Input shape: torch.Size([1, 8, 8, 1])
Input tensor (channel 0):
tensor([[18, 19, 20, 21, 22, 23, 16, 17],
        [26, 27, 28, 29, 30, 31, 24, 25],
        [34, 35, 36, 37, 38, 39, 32, 33],
        [42, 43, 44, 45, 46, 47, 40, 41],
        [50, 51, 52, 53, 54, 55, 48, 49],
        [58, 59, 60, 61, 62, 63, 56, 57],
        [ 2,  3,  4,  5,  6,  7,  0,  1],
        [10, 11, 12, 13, 14, 15,  8,  9]])

Output shape: torch.Size([4, 4, 4, 1])
Number of windows: 4

First window (channel 0):
tensor([[18, 19, 20, 21],
        [26, 27, 28, 29],
        [34, 35, 36, 37],
        [42, 43, 44, 45]])

Last window (channel 0):
tensor([[54, 55, 48, 49],
        [62, 63, 56, 57],
        [ 6,  7,  0,  1],
        [14, 15,  8,  9]])


In [None]:
# Demo: label every pixel of an H x W map with the ID (0..8) of the 3x3 region
# it falls into. Regions are delimited by window_size and shift_size measured
# from the bottom/right edges.
H, W = (8, 8)
window_size = 4
shift_size = window_size // 2

# Start from zeros rather than uninitialised memory, so any pixel the region
# slices below failed to cover would show up as 0 instead of garbage.
img_mask = torch.zeros((H, W))

# Boundaries of the 3x3 areas: (start, stop) slice bounds, identical for both axes
h_regions = [(0, -window_size), (-window_size, -shift_size), (-shift_size, None)]
w_regions = list(h_regions)

# Assign unique region IDs to each of the 3x3 areas, row-major over (h, w)
region_id = 0
for h_start, h_stop in h_regions:
    for w_start, w_stop in w_regions:
        img_mask[h_start:h_stop, w_start:w_stop] = region_id
        region_id += 1

# Add trailing channel and leading batch dimensions: (H, W) -> (1, H, W, 1)
img_mask = img_mask.unsqueeze(2).unsqueeze(0)

print("Image mask shape:", img_mask.shape)

# Visualize image_mask
print("\nImage mask:")
print(img_mask[0, :, :, 0])  # drop batch and channel dims for better visibility

Image mask shape: torch.Size([1, 8, 8, 1])

Image mask:
tensor([[0., 0., 0., 0., 1., 1., 2., 2.],
        [0., 0., 0., 0., 1., 1., 2., 2.],
        [0., 0., 0., 0., 1., 1., 2., 2.],
        [0., 0., 0., 0., 1., 1., 2., 2.],
        [3., 3., 3., 3., 4., 4., 5., 5.],
        [3., 3., 3., 3., 4., 4., 5., 5.],
        [6., 6., 6., 6., 7., 7., 8., 8.],
        [6., 6., 6., 6., 7., 7., 8., 8.]])


In [None]:
# Demo: build the region-ID image mask, cut it into windows, and derive a
# per-window pairwise-difference matrix. A zero entry means the two positions
# carry the same region ID; a non-zero entry means they differ.
H, W = (8, 8)
window_size = 4
shift_size = window_size // 2

# Start from zeros rather than uninitialised memory, so any pixel the region
# slices below failed to cover would show up as 0 instead of garbage.
img_mask = torch.zeros((H, W))

# Boundaries of the 3x3 areas: (start, stop) slice bounds, identical for both axes
h_regions = [(0, -window_size), (-window_size, -shift_size), (-shift_size, None)]
w_regions = list(h_regions)

# Assign unique region IDs to each of the 3x3 areas, row-major over (h, w)
region_id = 0
for h_start, h_stop in h_regions:
    for w_start, w_stop in w_regions:
        img_mask[h_start:h_stop, w_start:w_stop] = region_id
        region_id += 1

# Add trailing channel and leading batch dimensions: (H, W) -> (1, H, W, 1)
img_mask = img_mask.unsqueeze(2).unsqueeze(0)

# Batch and channel sizes come from the mask itself (both are 1 here).
B, H, W, C = img_mask.shape

# Reshape to separate windows: [B, H, W, C] -> [B, H//M, M, W//M, M, C]
x = img_mask.view(B, H // window_size, window_size, W // window_size, window_size, C)

# Permute to group windows: [B, H//M, M, W//M, M, C] -> [B, H//M, W//M, M, M, C]
x = x.permute(0, 1, 3, 2, 4, 5).contiguous()

# Flatten batch and window dimensions: [B, H//M, W//M, M, M, C] -> [B*num_windows, M, M, C]
mask_windows = x.view(-1, window_size, window_size, C)
# Flatten each window to a vector of M*M region IDs (valid because C == 1).
mask_windows = mask_windows.view(-1, window_size * window_size)
# Broadcasted pairwise difference: attention_mask[w, i, j] = id[w, j] - id[w, i]
attention_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)

# NOTE(review): the label says "Image mask shape" but the value printed is the
# attention mask's shape; the string is kept unchanged to match the recorded output.
print("Image mask shape:", attention_mask.shape)

# Visualize attention mask of the first window
print("\nAttention mask (first window):")
print(attention_mask[0])

# Visualize attention mask of the second window
print("\nAttention mask (second window):")
print(attention_mask[1])

Image mask shape: torch.Size([4, 16, 16])

Attention mask (first window):
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0