
A question about the rpb in LegacyNeighborhoodAttention2D #61

Closed
lartpang opened this issue Oct 3, 2022 · 5 comments
Labels
question Further information is requested

Comments

lartpang commented Oct 3, 2022

My question

Why is the same relative position index used for several positions in the middle?

Information

def apply_pb(self, attn, height, width):
    """
    RPB implementation by @qwopqwop200
    https://github.com/qwopqwop200/Neighborhood-Attention-Transformer
    """
    num_repeat_h = torch.ones(self.kernel_size, dtype=torch.long)
    num_repeat_w = torch.ones(self.kernel_size, dtype=torch.long)
    num_repeat_h[self.kernel_size // 2] = height - (self.kernel_size - 1)
    num_repeat_w[self.kernel_size // 2] = width - (self.kernel_size - 1)
    bias_hw = (
        self.idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2 * self.kernel_size - 1)
    ) + self.idx_w.repeat_interleave(num_repeat_w)
    bias_idx = bias_hw.unsqueeze(-1) + self.idx_k
    # Index flip
    # Our RPB indexing in the kernel is in a different order, so we flip these indices to ensure weights match.
    bias_idx = torch.flip(bias_idx.reshape(-1, self.kernel_size ** 2), [0])
    return attn + self.rpb.flatten(1, 2)[:, bias_idx].reshape(
        self.num_heads, height * width, 1, self.kernel_size ** 2
    ).transpose(0, 1)

A simple visualization:

[figure: rpb — one subplot per (h, w) position, highlighting that position's index window]

The related indexing code is copied from LegacyNeighborhoodAttention2D:

# %%
import matplotlib.pyplot as plt
import numpy as np
import torch

kernel_size = 3
height = width = 5
rpb_size = 2 * kernel_size - 1

# %%
fig, axes = plt.subplots(nrows=height, ncols=width, figsize=(8, 8))
shared_bg = np.zeros((height, width), dtype=np.uint8)

# %%
idx_h = torch.arange(0, kernel_size)
idx_w = torch.arange(0, kernel_size)
idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).reshape(-1)
print(idx_k.reshape(kernel_size, kernel_size))

# %%
num_repeat_h = torch.ones(kernel_size, dtype=torch.long)
num_repeat_w = torch.ones(kernel_size, dtype=torch.long)
num_repeat_h[kernel_size // 2] = height - (kernel_size - 1)
num_repeat_w[kernel_size // 2] = width - (kernel_size - 1)
bias_hw = (
    idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2 * kernel_size - 1)
) + idx_w.repeat_interleave(num_repeat_w)
bias_idx = (bias_hw.unsqueeze(-1) + idx_k).reshape(-1, kernel_size ** 2)
print(bias_idx)
'''
tensor([[ 0,  1,  2,  5,  6,  7, 10, 11, 12],
        [ 1,  2,  3,  6,  7,  8, 11, 12, 13],
        [ 1,  2,  3,  6,  7,  8, 11, 12, 13],
        [ 1,  2,  3,  6,  7,  8, 11, 12, 13],
        [ 2,  3,  4,  7,  8,  9, 12, 13, 14],
        [ 5,  6,  7, 10, 11, 12, 15, 16, 17],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 7,  8,  9, 12, 13, 14, 17, 18, 19],
        [ 5,  6,  7, 10, 11, 12, 15, 16, 17],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 7,  8,  9, 12, 13, 14, 17, 18, 19],
        [ 5,  6,  7, 10, 11, 12, 15, 16, 17],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 6,  7,  8, 11, 12, 13, 16, 17, 18],
        [ 7,  8,  9, 12, 13, 14, 17, 18, 19],
        [10, 11, 12, 15, 16, 17, 20, 21, 22],
        [11, 12, 13, 16, 17, 18, 21, 22, 23],
        [11, 12, 13, 16, 17, 18, 21, 22, 23],
        [11, 12, 13, 16, 17, 18, 21, 22, 23],
        [12, 13, 14, 17, 18, 19, 22, 23, 24]])
'''

# %%
for h in range(height):
    for w in range(width):
        new_bg = shared_bg.flatten().copy()
        new_bg[bias_idx[h * width + w]] = 255  # row index into bias_idx is h * width + w
        new_bg = new_bg.reshape(height, width)
        axes[h, w].imshow(new_bg)

# %%
plt.show()
alihassanijr (Member) commented Oct 3, 2022

Hello and thank you for your interest,

Just to give a bit of background: RPB is in theory a continuous function, at least the way we intended it for NA/DiNA.
Here we're only learning a discrete set of weights because our kernel size is typically fixed.

As for its implementation here, most tokens (non-edge cases) share an identical RPB grid: north, south, east, west, and the positions in between, e.g. northwest.
And of course there's a magnitude: 1 north, 2 north, etc.

As a result, if you look at a visualization of NA, you will see that, edge cases aside, the key-value positions across the rest of the feature map are identical: the query is centered and its neighbors wrap around it, hence they share the same RPB.

It is different for the edge cases precisely because they are not centered. For instance, the north-west (top-left) most pixel always attends to an "extended neighborhood", as explained in the original NAT paper; therefore its relative positional biases with respect to its key-value pairs, i.e. its neighborhood, differ from those of the non-edge cases, where the query is always centered.

To clarify further, you can try plotting much larger inputs, in which case you would see the RPB indices differ only along the edges and corners, and an identical RPB index everywhere in the middle.
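
For example, a quick check along these lines (a rough sketch reusing your indexing code, assuming kernel_size=3 and height=width=9) should confirm that every non-edge row of bias_idx is identical:

import torch

kernel_size = 3
height = width = 9
rpb_size = 2 * kernel_size - 1

# same index construction as in your snippet (before the flip)
idx_h = torch.arange(kernel_size)
idx_w = torch.arange(kernel_size)
idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).reshape(-1)
num_repeat_h = torch.ones(kernel_size, dtype=torch.long)
num_repeat_w = torch.ones(kernel_size, dtype=torch.long)
num_repeat_h[kernel_size // 2] = height - (kernel_size - 1)
num_repeat_w[kernel_size // 2] = width - (kernel_size - 1)
bias_hw = idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * rpb_size + idx_w.repeat_interleave(num_repeat_w)
bias_idx = bias_hw.unsqueeze(-1) + idx_k  # (height, width, kernel_size ** 2)

# all positions at least kernel_size // 2 away from the border share one index window
pad = kernel_size // 2
center = bias_idx[pad:height - pad, pad:width - pad]
print(torch.all(center == center[0, 0]))  # expected: tensor(True)
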
By the way, thank you for taking the time to plot these, I'm sure it'll help other users as well.

I hope this explains the idea, but if that's not the case, please let us know so we can clarify further.

alihassanijr added the question (Further information is requested) label on Oct 3, 2022
lartpang (Author) commented Oct 3, 2022

@alihassanijr

Thanks for your reply!

About the original question

In my original example, some settings were getting in the way of my understanding. I have reworked the code, and it is more intuitive now.
But this also leads to another question; see the discussion in the next section.

import matplotlib.pyplot as plt
import numpy as np
import torch

# specify the height and width of the feature map
height = width = 10

# construct a figure containing height*width subfigures corresponding to different (h,w) pixel
fig, axes = plt.subplots(nrows=height, ncols=width, figsize=(8, 8))
fig.suptitle('All Index Windows of RPE for each position of H-W Plane')

# specify the size of kernel for position bias
kernel_size = 5

# construct a shared relative position bias map
rpb_size = 2 * kernel_size - 1
shared_rpb_bg = np.zeros((rpb_size, rpb_size), dtype=np.uint8)

idx_h = torch.arange(0, kernel_size)
idx_w = torch.arange(0, kernel_size)
# absolute 1D indices in the left-top window of the rpe map (2*kernel_size-1, 2*kernel_size-1)
# other window indices can be obtained by adding a new start index on this `idx_k`
idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).reshape(-1)

# construct indices of the window in rpe map for each (h,w) pixel
num_repeat_h = torch.ones(kernel_size, dtype=torch.long)
num_repeat_w = torch.ones(kernel_size, dtype=torch.long)
num_repeat_h[kernel_size // 2] = height - (kernel_size - 1)
num_repeat_w[kernel_size // 2] = width - (kernel_size - 1)
# the base h and w of the four edge regions is different from others
bias_hw = idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * rpb_size + idx_w.repeat_interleave(num_repeat_w)
# each (h,w) in the H-W plane corresponds to a window of kernel_size*kernel_size containing indices
bias_idx = (bias_hw.unsqueeze(-1) + idx_k).reshape(-1, kernel_size ** 2) # height*width,kernel_size**2

# traverse all positions to visualize and highlight their own index window in the shared rpe map
for h in range(height):
    for w in range(width):
        new_rpb_bg = shared_rpb_bg.flatten().copy()

        new_start_idx = h * width + w
        new_rpb_bg[bias_idx[new_start_idx]] = 255  # index the specific window in rpb map
        new_rpb_bg = new_rpb_bg.reshape(rpb_size, rpb_size)
        axes[h, w].imshow(new_rpb_bg)
        axes[h, w].set_title(f"Win {(h,w)}")

plt.show()

[figure: rpb-k5-h10 — all index windows of RPE for each position of the H-W plane, kernel_size=5, height=width=10]

About the relative position bias for NAT

Let's consider a simple case: kernel_size=3, so the rpb map has shape [2*3-1, 2*3-1] = [5, 5].

The real relative indices of the rpb map are:

(-2, -2), (-2, -1), (-2, 0), (-2, 1), (-2, 2), 
(-1, -2), (-1, -1), (-1, 0), (-1, 1), (-1, 2), 
(0, -2), (0, -1), (0, 0), (0, 1), (0, 2),
(1, -2), (1, -1), (1, 0), (1, 1), (1, 2),
(2, -2), (2, -1), (2, 0), (2, 1), (2, 2),

In traditional attention, the RPB looks simpler, since there are no edge regions that require special consideration, so at first I thought it only needs a single index pattern. (Correction: in fact, the traditional RPB is more like a special form of NAT consisting only of edge regions, so the implementation here is generic.)
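
For reference, the usual window/global-attention relative position index can be built like this (a rough sketch of the common Swin-style construction, not code from this repo); every query position inside the window gets its own offset pattern, just like the edge-region windows listed below:

import torch

window_size = 3  # plays the same role as kernel_size above

# (2, ws, ws): the (h, w) coordinate of every position inside the window
coords = torch.stack(torch.meshgrid(
    torch.arange(window_size), torch.arange(window_size), indexing="ij"))
coords_flat = coords.flatten(1)  # (2, ws*ws)

# pairwise (dh, dw) offsets between every query and key position
relative = coords_flat[:, :, None] - coords_flat[:, None, :]  # (2, ws*ws, ws*ws)
relative = relative.permute(1, 2, 0).contiguous()

# shift offsets into [0, 2*ws - 2] and flatten into an index over the (2*ws-1)^2 bias table
relative[:, :, 0] += window_size - 1
relative[:, :, 1] += window_size - 1
relative[:, :, 0] *= 2 * window_size - 1
relative_position_index = relative.sum(-1)  # (ws*ws, ws*ws)

print(relative_position_index)  # each row (query position) has a different pattern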

In the convolution-like NAT operation, the RPB becomes more complicated.

The index window in the non-edge region consists of the following pattern:

(-1, -1), (-1, 0), (-1, 1), 
(0, -1), (0, 0), (0, 1), 
(1, -1), (1, 0), (1, 1),

The index windows in the edge regions are:

# start at the pixel (h=0, w=0), we denote the matrix as $W_{0,0}$
(0, 0), (0, 1), (0, 2);
(1, 0), (1, 1), (1, 2);
(2, 0), (2, 1), (2, 2);

# start at the pixel (h=0, w=1), we denote the matrix as $W_{0,1}$
(0, -1), (0, 0), (0, 1);
(1, -1), (1, 0), (1, 1);
(2, -1), (2, 0), (2, 1);

# start at the pixel (h=1, w=0), we denote the matrix as $W_{1,0}$
(-1, 0), (-1, 1), (-1, 2);
(0, 0), (0, 1), (0, 2);
(1, 0), (1, 1), (1, 2);

# ....

# start at the pixel (h=height-1, w=width-2), we denote the matrix as $W_{height-1,width-2}$
(-2, -1), (-2, 0), (-2, 1),
(-1, -1), (-1, 0), (-1, 1),
(0, -1), (0, 0), (0, 1),

# start at the pixel (h=height-1, w=width-1), we denote the matrix as $W_{height-1,width-1}$
(-2, -2), (-2, -1), (-2, 0), 
(-1, -2), (-1, -1), (-1, 0), 
(0, -2), (0, -1), (0, 0),

In the current implementation of the RPB in LegacyNeighborhoodAttention2D, the index pattern does not seem to correspond to the real indices of the rpb map described above.
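
To make this concrete, each flat index in bias_idx can be decoded back into a (dh, dw) offset, treating the center of the (2*kernel_size-1, 2*kernel_size-1) map as offset (0, 0) (a small sketch based on the kernel_size=3, height=width=5 example above). The window of pixel (0, 0) then decodes to offsets (-2, -2) ... (0, 0) rather than the (0, 0) ... (2, 2) window of $W_{0,0}$:

import torch

kernel_size = 3
height = width = 5
rpb_size = 2 * kernel_size - 1

# rebuild bias_idx exactly as in the snippet above (no flip applied)
idx_h = torch.arange(kernel_size)
idx_w = torch.arange(kernel_size)
idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).reshape(-1)
num_repeat_h = torch.ones(kernel_size, dtype=torch.long)
num_repeat_w = torch.ones(kernel_size, dtype=torch.long)
num_repeat_h[kernel_size // 2] = height - (kernel_size - 1)
num_repeat_w[kernel_size // 2] = width - (kernel_size - 1)
bias_hw = idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * rpb_size + idx_w.repeat_interleave(num_repeat_w)
bias_idx = (bias_hw.unsqueeze(-1) + idx_k).reshape(-1, kernel_size ** 2)

# decode flat indices into (dh, dw) offsets; map center (k-1, k-1) corresponds to (0, 0)
dh = bias_idx // rpb_size - (kernel_size - 1)
dw = bias_idx % rpb_size - (kernel_size - 1)
print(torch.stack([dh, dw], dim=-1)[0].reshape(kernel_size, kernel_size, 2))
# -> (-2, -2), (-2, -1), (-2, 0), (-1, -2), ..., (0, 0) for pixel (0, 0)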

alihassanijr (Member) commented Oct 3, 2022

Maybe it's the flip?

# Index flip
# Our RPB indexing in the kernel is in a different order, so we flip these indices to ensure weights match.
bias_idx = torch.flip(bias_idx.reshape(-1, self.kernel_size**2), [0])

We have this additional flip to make sure the behavior is identical to the behavior programmed in NATTEN.
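
As a rough check (continuing the kernel_size=3, height=width=5 example and the same decoding idea as above), applying that flip makes the top-left pixel's window decode to the (0, 0) ... (2, 2) offsets expected for $W_{0,0}$:

import torch

kernel_size = 3
height = width = 5
rpb_size = 2 * kernel_size - 1

# same bias_idx construction as before, now followed by the flip from apply_pb
idx_h = torch.arange(kernel_size)
idx_w = torch.arange(kernel_size)
idx_k = ((idx_h.unsqueeze(-1) * rpb_size) + idx_w).reshape(-1)
num_repeat_h = torch.ones(kernel_size, dtype=torch.long)
num_repeat_w = torch.ones(kernel_size, dtype=torch.long)
num_repeat_h[kernel_size // 2] = height - (kernel_size - 1)
num_repeat_w[kernel_size // 2] = width - (kernel_size - 1)
bias_hw = idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * rpb_size + idx_w.repeat_interleave(num_repeat_w)
bias_idx = (bias_hw.unsqueeze(-1) + idx_k).reshape(-1, kernel_size ** 2)
bias_idx = torch.flip(bias_idx, [0])  # the flip in question

# decode the window of pixel (h=0, w=0) into (dh, dw) offsets, map center = (0, 0)
dh = bias_idx[0] // rpb_size - (kernel_size - 1)
dw = bias_idx[0] % rpb_size - (kernel_size - 1)
print(torch.stack([dh, dw], dim=-1).reshape(kernel_size, kernel_size, 2))
# -> (0, 0), (0, 1), (0, 2), (1, 0), ..., (2, 2), i.e. the expected W_{0,0} window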

lartpang (Author) commented Oct 3, 2022

@alihassanijr

Oh, I understand it now. Thank you so much for your patient reply.

alihassanijr (Member) commented:

I'm closing this issue now because we're moving our extension to its own separate repository, and due to inactivity.

Please feel free to reopen it if you still have questions, or open an issue in NATTEN if it's related to that.
