### Relative position bias

Ref Video: [Soroush Mehraban Video](https://www.youtube.com/watch?v=Ws2RAh_VDyU)

* In Vision Transformer paper, where images are divided into small patches to convert this as a token, given into the trasnformer it gives some output. Which shown in the below image

![fig_1](data/relative_position_bias/Screenshot%20from%202024-12-31%2013-55-18.png)

* If We change the order of the patches or the tokens, output tokenms will changes. On this cases both of the image input are same, when we merge the output it seems different. (See the below image)

![fig_1](data/relative_position_bias/Screenshot%20from%202024-12-31%2014-06-25.png)



![fig 3](data/relative_position_bias/Screenshot%20from%202024-12-31%2014-06-41.png)

* To mitigate this problem, Absolute position Embedding was introduced in the original ViT paper.

<img src="data/relative_position_bias/Screenshot from 2024-12-31 14-24-56.png" alt="fig 4" style="width:800px;height:500px;">

* But, There are some limitation in the Absolute Position Encoding refer the human image. So, they introduces Relative position Bias






* Consider you have an input image size of `4 * 4` and split into the patches with patch size of `2 * 2` to each with the dimension of `1`, it results in the layer input of size `1 * 4 * 1`  (`B * nT * C`).




In [21]:
import torch
from torch import nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sample_input = torch.randn((1,1,8,8)).to(device)

In [13]:
class PatchPartition(nn.Module):
  def __init__(self, in_channels, patch_size, embed_dims, patch_norm_needed:bool = True):
    super().__init__()
    self.patch_size = patch_size
    self.embed_dim = embed_dims
    self.embedding = nn.Conv2d(in_channels=in_channels, out_channels=self.embed_dim,kernel_size=self.patch_size, stride=self.patch_size)
    self.patch_norm_needed = patch_norm_needed
    if self.patch_norm_needed:
      self.patch_norm = nn.LayerNorm(self.embed_dim)
  def forward(self, x: torch.Tensor):
    x=self.embedding(x)
    print(x.shape)
    x = x.permute(0,2,3,1)
    B,H,W,C = x.shape
    x = x.reshape(B,-1,C)
    if self.patch_norm_needed:
      x=self.patch_norm(x)
    return x
partition_layer=PatchPartition(in_channels=1, patch_size=[2,2],embed_dims=1).to(device)
layer_input=partition_layer(sample_input)
layer_input.shape

torch.Size([1, 1, 4, 4])


torch.Size([1, 16, 1])

In [43]:
import pandas as pd

def append_df_to_markdown(file_name: str, df: pd.DataFrame):
    """
    Appends the given DataFrame as Markdown to the specified file.
    If the file doesn't exist, it creates a new one.

    Args:
        file_name (str): The name of the Markdown file.
        df (pd.DataFrame): The DataFrame to append as Markdown.

    Returns:
        None
    """
    try:
        # Convert the DataFrame to Markdown format
        markdown_data = df.to_markdown(index=False)
        
        # Open the file in append mode and write the Markdown data
        with open(file_name, 'a') as file:
            file.write('\n\n' + markdown_data + '\n\n')
        
        print(f"DataFrame successfully appended to {file_name}.")
    except Exception as e:
        print(f"An error occurred while appending to the file: {e}")



import imgkit
import os

def add_border_to_table_and_convert_to_image(styled_df, image_name=None):
    # Define custom CSS for the table border
    css_style = """
        <style>
            table, th, td {
                border: 1px solid black;
                border-collapse: collapse;
                padding: 5px;
            }
            th {
                background-color: #f2f2f2;
            }
        </style>
    """
    css_style = """
        <style>
            body {
                margin: 0;
                padding: 0;
                width: 100%;
                text-align: center;
            }
            table {
                width: 100%;  /* Ensure the table uses the full width */
                table-layout: fixed;  /* Ensure fixed table layout */
                border: 1px solid black;
                border-collapse: collapse;
                margin: 0 auto;  /* Center the table */
            }
            th, td {
                border: 1px solid black;
                padding: 8px;  /* Adjust padding for better spacing */
                text-align: center;  /* Align text to the center */
            }
            th {
                background-color: #f2f2f2;
            }
        </style>
    """
    

    
    
    # Convert the DataFrame to HTML and append the CSS style
    html_table = css_style + styled_df.to_html(index=True,escape=False)
    
    # Write the table to an HTML file
    html_filename = "output_table_with_border.html"
    with open(html_filename, "w") as file:
        file.write(html_table)
    
    # Setup directory for saving images
    image_folder = "table_images"
    os.makedirs(image_folder, exist_ok=True)
    
    # Get the current number of images
    existing_images = os.listdir(image_folder)
    image_count = len([f for f in existing_images if f.endswith(".png")])
    
    # Define the image file path with incrementing names
    if image_name is None:
        image_name=f"table_image_{image_count + 1}.png"
    image_filename = os.path.join(image_folder,image_name)
    
    # Use imgkit to convert the HTML file to an image
    imgkit.from_file(html_filename, image_filename)
    
    return f"Table exported to {html_filename} and image saved as {image_filename}"

# Define a function to style a specific substring
def style_first_substring(val):
    if isinstance(val, str):  # Ensure the value is a string
        # Change color of the substring "Dev" (example)
        temp_data=[int(i) for i in val.strip("]").strip("[").split(",")]
        first_value=temp_data[0] 
        second_value=temp_data[1]
        return f'[<span style="color: red;">{first_value}</span>, {second_value}]'
    return val

# Define a function to style a specific substring
def style_second_substring(val):
    if isinstance(val, str):  # Ensure the value is a string
        # Change color of the substring "Dev" (example)
        temp_data=[int(i) for i in val.strip("]").strip("[").split(",")]
        first_value=temp_data[0] 
        second_value=temp_data[1]
        return f'[{first_value},<span style="color: red;">{second_value}</span>]'
    return val

# Relative position Bias Table

* Relative position bias is a learnable parameter which is added the to the attention matrix to capture the positial relationship of each patches in the images.

* This addition of bias doesn't change the overall output. But it will ensure that model capacity to learn complext relationship is unaffected.

* To work in this flow, consider You have the layer input of shape `1 * 4 * 4 * 1`. It's represented as `B * H * W * C`. Where 
    
    * B - Batch size of the input
    * H - Window height of the input matrix
    * W - Window width of the input matrix
    * C - Dimension of the each patches




In [3]:
# consider you have 4 * 4 patches
import pandas as pd
patches  = torch.randn(4,4).to(device=device)
tokens = patches.view(-1,1)
tokens.shape

NameError: name 'torch' is not defined

In [3]:
import pandas as pd
import numpy as np


# Define row and column indices
window_size=[3,3]
row_index = np.arange(window_size[0])
col_name = np.arange(window_size[1])

# Create the DataFrame with desired values
data = {col: [[i, col] for i in row_index] for col in col_name}
base_image_patches = pd.DataFrame(data, index=row_index)
add_border_to_table_and_convert_to_image(base_image_patches,"base_image_patches.png")
base_image_patches

QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Unnamed: 0,0,1,2
0,"[0, 0]","[0, 1]","[0, 2]"
1,"[1, 0]","[1, 1]","[1, 2]"
2,"[2, 0]","[2, 1]","[2, 2]"


* We have now 4 * 4 patches, relative position bias is added to make a positional relationship with each patches. So each patche in a 4 * 4 patches have relationship with all other patches. 

* For an example, consider a first patch which is `[0,0]` is related to all other poistion. as you seen in the images

In [None]:
# Realtive position table

In [4]:
relative_position_index=pd.DataFrame()

for i in base_image_patches.values.reshape(-1):
    for j in base_image_patches.values.reshape(-1):
        relative_position_index.loc[str(i),str(j)]=str([i[0]-j[0],i[1]-j[1]])
add_border_to_table_and_convert_to_image(relative_position_index,image_name="realtive_position_index.png")

relative_position_index

QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Unnamed: 0,"[0, 0]","[0, 1]","[0, 2]","[1, 0]","[1, 1]","[1, 2]","[2, 0]","[2, 1]","[2, 2]"
"[0, 0]","[0, 0]","[0, -1]","[0, -2]","[-1, 0]","[-1, -1]","[-1, -2]","[-2, 0]","[-2, -1]","[-2, -2]"
"[0, 1]","[0, 1]","[0, 0]","[0, -1]","[-1, 1]","[-1, 0]","[-1, -1]","[-2, 1]","[-2, 0]","[-2, -1]"
"[0, 2]","[0, 2]","[0, 1]","[0, 0]","[-1, 2]","[-1, 1]","[-1, 0]","[-2, 2]","[-2, 1]","[-2, 0]"
"[1, 0]","[1, 0]","[1, -1]","[1, -2]","[0, 0]","[0, -1]","[0, -2]","[-1, 0]","[-1, -1]","[-1, -2]"
"[1, 1]","[1, 1]","[1, 0]","[1, -1]","[0, 1]","[0, 0]","[0, -1]","[-1, 1]","[-1, 0]","[-1, -1]"
"[1, 2]","[1, 2]","[1, 1]","[1, 0]","[0, 2]","[0, 1]","[0, 0]","[-1, 2]","[-1, 1]","[-1, 0]"
"[2, 0]","[2, 0]","[2, -1]","[2, -2]","[1, 0]","[1, -1]","[1, -2]","[0, 0]","[0, -1]","[0, -2]"
"[2, 1]","[2, 1]","[2, 0]","[2, -1]","[1, 1]","[1, 0]","[1, -1]","[0, 1]","[0, 0]","[0, -1]"
"[2, 2]","[2, 2]","[2, 1]","[2, 0]","[1, 2]","[1, 1]","[1, 0]","[0, 2]","[0, 1]","[0, 0]"


In [51]:
# Each element lies in the range of 2M-1 , we will create a relative position bias table
import random
from IPython.display import display, HTML


random.seed(10)
relative_position_bias_table=pd.DataFrame()



cnt = 0 
for i in range(-window_size[0]+1,window_size[0]): # along x axis
    for j in range(-window_size[1]+1,window_size[1]): # along y axis
        relative_position_bias_table.loc[str(i),str(j)]=f"{round(random.random(),4)}<br><br>(idx={cnt})"
        cnt+=1

# Define a function to style a specific substring
def style_bias_table_substring(val):
    if isinstance(val, str):  # Ensure the value is a string
        # print(val)

        if "(idx=10)" in val:
            print("yes")

            val=val.replace("(idx=10)","(idx=10) +2")

            return f'<span style="color: red;">{val}</span>'
        if "(idx=12)" in val:
            return f'<span style="color: blue;">{val}</span>'
    return val

relative_position_bias_table_styled = relative_position_bias_table.copy().style.format(style_bias_table_substring)

add_border_to_table_and_convert_to_image(relative_position_bias_table_styled,image_name="realtive_position_bias_table_picking_1.png")


html_output = relative_position_bias_table.to_html(escape=False)
display(HTML(html_output))


yes
QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Unnamed: 0,-2,-1,0,1,2
-2,0.5714 (idx=0),0.4289 (idx=1),0.5781 (idx=2),0.2061 (idx=3),0.8133 (idx=4)
-1,0.8236 (idx=5),0.6535 (idx=6),0.1602 (idx=7),0.5207 (idx=8),0.3278 (idx=9)
0,0.25 (idx=10),0.9528 (idx=11),0.9966 (idx=12),0.0446 (idx=13),0.8602 (idx=14)
1,0.6032 (idx=15),0.3816 (idx=16),0.2836 (idx=17),0.675 (idx=18),0.4568 (idx=19)
2,0.6859 (idx=20),0.6618 (idx=21),0.133 (idx=22),0.7678 (idx=23),0.9824 (idx=24)


In [61]:
import random
from IPython.display import display, HTML


random.seed(10)
relative_position_bias_table=pd.DataFrame()
relative_position_bias_table_orig = pd.DataFrame()



cnt = 0 
for i in range(-window_size[0]+1,window_size[0]): # along x axis
    for j in range(-window_size[1]+1,window_size[1]): # along y axis
        rand_data=round(random.random(),4)
        relative_position_bias_table.loc[str(i),str(j)]=f"{rand_data}<br><br>(idx={cnt})"
        relative_position_bias_table_orig.loc[str(i),str(j)]=rand_data
        cnt+=1

# Define a function to style a specific substring
def style_bias_table_substring(val):
    if isinstance(val, str):  # Ensure the value is a string
        # print(val)

        if "(idx=10)" in val:
            print("yes")

            val=val.replace("(idx=10)","(idx=10) +1")

            return f'<span style="color: red;">{val}</span>'
        if "(idx=11)" in val:
            return f'<span style="color: blue;">{val}</span>'
    return val

relative_position_bias_table_styled = relative_position_bias_table.copy().style.format(style_bias_table_substring)

add_border_to_table_and_convert_to_image(relative_position_bias_table_styled,image_name="realtive_position_bias_table_picking_2.png")


html_output = relative_position_bias_table.to_html(escape=False)
display(HTML(html_output))


yes
QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Unnamed: 0,-2,-1,0,1,2
-2,0.5714 (idx=0),0.4289 (idx=1),0.5781 (idx=2),0.2061 (idx=3),0.8133 (idx=4)
-1,0.8236 (idx=5),0.6535 (idx=6),0.1602 (idx=7),0.5207 (idx=8),0.3278 (idx=9)
0,0.25 (idx=10),0.9528 (idx=11),0.9966 (idx=12),0.0446 (idx=13),0.8602 (idx=14)
1,0.6032 (idx=15),0.3816 (idx=16),0.2836 (idx=17),0.675 (idx=18),0.4568 (idx=19)
2,0.6859 (idx=20),0.6618 (idx=21),0.133 (idx=22),0.7678 (idx=23),0.9824 (idx=24)


array([0.5714, 0.4289, 0.5781, 0.2061, 0.8133, 0.8236, 0.6535, 0.1602,
       0.5207, 0.3278, 0.25  , 0.9528, 0.9966, 0.0446, 0.8602, 0.6032,
       0.3816, 0.2836, 0.675 , 0.4568, 0.6859, 0.6618, 0.133 , 0.7678,
       0.9824])

In [11]:
temp_relative_position_idx_styled = relative_position_index.copy().style.format(style_first_substring)
add_border_to_table_and_convert_to_image(temp_relative_position_idx_styled,image_name="realtive_position_index_x_axis_step_1.png")
temp_relative_position_idx_styled

QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Unnamed: 0,"[0, 0]","[0, 1]","[0, 2]","[1, 0]","[1, 1]","[1, 2]","[2, 0]","[2, 1]","[2, 2]"
"[0, 0]","[0, 0]","[0, -1]","[0, -2]","[-1, 0]","[-1, -1]","[-1, -2]","[-2, 0]","[-2, -1]","[-2, -2]"
"[0, 1]","[0, 1]","[0, 0]","[0, -1]","[-1, 1]","[-1, 0]","[-1, -1]","[-2, 1]","[-2, 0]","[-2, -1]"
"[0, 2]","[0, 2]","[0, 1]","[0, 0]","[-1, 2]","[-1, 1]","[-1, 0]","[-2, 2]","[-2, 1]","[-2, 0]"
"[1, 0]","[1, 0]","[1, -1]","[1, -2]","[0, 0]","[0, -1]","[0, -2]","[-1, 0]","[-1, -1]","[-1, -2]"
"[1, 1]","[1, 1]","[1, 0]","[1, -1]","[0, 1]","[0, 0]","[0, -1]","[-1, 1]","[-1, 0]","[-1, -1]"
"[1, 2]","[1, 2]","[1, 1]","[1, 0]","[0, 2]","[0, 1]","[0, 0]","[-1, 2]","[-1, 1]","[-1, 0]"
"[2, 0]","[2, 0]","[2, -1]","[2, -2]","[1, 0]","[1, -1]","[1, -2]","[0, 0]","[0, -1]","[0, -2]"
"[2, 1]","[2, 1]","[2, 0]","[2, -1]","[1, 1]","[1, 0]","[1, -1]","[0, 1]","[0, 0]","[0, -1]"
"[2, 2]","[2, 2]","[2, 1]","[2, 0]","[1, 2]","[1, 1]","[1, 0]","[0, 2]","[0, 1]","[0, 0]"


In [12]:
# At first we need to change the both x and y axis distace starts from 0 instead of negative values
temp_relative_position_idx = relative_position_index.copy()

for i in temp_relative_position_idx.index:
    for j in temp_relative_position_idx.columns:
        temp_data=[int(i) for i in temp_relative_position_idx.loc[i,j].strip("]").strip("[").split(",")]
        temp_data[0]+=window_size[0]-1
        temp_relative_position_idx.loc[i,j]=str(temp_data)

temp_relative_position_idx_styled = temp_relative_position_idx.style.format(style_first_substring)
add_border_to_table_and_convert_to_image(temp_relative_position_idx_styled,image_name="realtive_position_index_x_axis_step_2.png")
# temp_relative_position_idx_styled.to_html()
temp_relative_position_idx_styled

QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Unnamed: 0,"[0, 0]","[0, 1]","[0, 2]","[1, 0]","[1, 1]","[1, 2]","[2, 0]","[2, 1]","[2, 2]"
"[0, 0]","[2, 0]","[2, -1]","[2, -2]","[1, 0]","[1, -1]","[1, -2]","[0, 0]","[0, -1]","[0, -2]"
"[0, 1]","[2, 1]","[2, 0]","[2, -1]","[1, 1]","[1, 0]","[1, -1]","[0, 1]","[0, 0]","[0, -1]"
"[0, 2]","[2, 2]","[2, 1]","[2, 0]","[1, 2]","[1, 1]","[1, 0]","[0, 2]","[0, 1]","[0, 0]"
"[1, 0]","[3, 0]","[3, -1]","[3, -2]","[2, 0]","[2, -1]","[2, -2]","[1, 0]","[1, -1]","[1, -2]"
"[1, 1]","[3, 1]","[3, 0]","[3, -1]","[2, 1]","[2, 0]","[2, -1]","[1, 1]","[1, 0]","[1, -1]"
"[1, 2]","[3, 2]","[3, 1]","[3, 0]","[2, 2]","[2, 1]","[2, 0]","[1, 2]","[1, 1]","[1, 0]"
"[2, 0]","[4, 0]","[4, -1]","[4, -2]","[3, 0]","[3, -1]","[3, -2]","[2, 0]","[2, -1]","[2, -2]"
"[2, 1]","[4, 1]","[4, 0]","[4, -1]","[3, 1]","[3, 0]","[3, -1]","[2, 1]","[2, 0]","[2, -1]"
"[2, 2]","[4, 2]","[4, 1]","[4, 0]","[3, 2]","[3, 1]","[3, 0]","[2, 2]","[2, 1]","[2, 0]"


In [13]:
temp_relative_position_idx_styled = temp_relative_position_idx.copy().style.format(style_second_substring)
add_border_to_table_and_convert_to_image(temp_relative_position_idx_styled,image_name="realtive_position_index_x_axis_step_3.png")
temp_relative_position_idx_styled

QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Unnamed: 0,"[0, 0]","[0, 1]","[0, 2]","[1, 0]","[1, 1]","[1, 2]","[2, 0]","[2, 1]","[2, 2]"
"[0, 0]","[2,0]","[2,-1]","[2,-2]","[1,0]","[1,-1]","[1,-2]","[0,0]","[0,-1]","[0,-2]"
"[0, 1]","[2,1]","[2,0]","[2,-1]","[1,1]","[1,0]","[1,-1]","[0,1]","[0,0]","[0,-1]"
"[0, 2]","[2,2]","[2,1]","[2,0]","[1,2]","[1,1]","[1,0]","[0,2]","[0,1]","[0,0]"
"[1, 0]","[3,0]","[3,-1]","[3,-2]","[2,0]","[2,-1]","[2,-2]","[1,0]","[1,-1]","[1,-2]"
"[1, 1]","[3,1]","[3,0]","[3,-1]","[2,1]","[2,0]","[2,-1]","[1,1]","[1,0]","[1,-1]"
"[1, 2]","[3,2]","[3,1]","[3,0]","[2,2]","[2,1]","[2,0]","[1,2]","[1,1]","[1,0]"
"[2, 0]","[4,0]","[4,-1]","[4,-2]","[3,0]","[3,-1]","[3,-2]","[2,0]","[2,-1]","[2,-2]"
"[2, 1]","[4,1]","[4,0]","[4,-1]","[3,1]","[3,0]","[3,-1]","[2,1]","[2,0]","[2,-1]"
"[2, 2]","[4,2]","[4,1]","[4,0]","[3,2]","[3,1]","[3,0]","[2,2]","[2,1]","[2,0]"


In [14]:
# At first we need to change the both x and y axis distace starts from 0 instead of negative values
temp_relative_position_idx = temp_relative_position_idx.copy()

for i in temp_relative_position_idx.index:
    for j in temp_relative_position_idx.columns:
        temp_data=[int(i) for i in temp_relative_position_idx.loc[i,j].strip("]").strip("[").split(",")]
        # temp_data[0]+=window_size[1]-0
        temp_data[1]+=window_size[1]-1
        temp_relative_position_idx.loc[i,j]=str(temp_data)

temp_relative_position_idx_styled = temp_relative_position_idx.style.format(style_second_substring)
add_border_to_table_and_convert_to_image(temp_relative_position_idx_styled,image_name="realtive_position_index_x_axis_step_4.png")
# temp_relative_position_idx_styled.to_html()
temp_relative_position_idx_styled



QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Unnamed: 0,"[0, 0]","[0, 1]","[0, 2]","[1, 0]","[1, 1]","[1, 2]","[2, 0]","[2, 1]","[2, 2]"
"[0, 0]","[2,2]","[2,1]","[2,0]","[1,2]","[1,1]","[1,0]","[0,2]","[0,1]","[0,0]"
"[0, 1]","[2,3]","[2,2]","[2,1]","[1,3]","[1,2]","[1,1]","[0,3]","[0,2]","[0,1]"
"[0, 2]","[2,4]","[2,3]","[2,2]","[1,4]","[1,3]","[1,2]","[0,4]","[0,3]","[0,2]"
"[1, 0]","[3,2]","[3,1]","[3,0]","[2,2]","[2,1]","[2,0]","[1,2]","[1,1]","[1,0]"
"[1, 1]","[3,3]","[3,2]","[3,1]","[2,3]","[2,2]","[2,1]","[1,3]","[1,2]","[1,1]"
"[1, 2]","[3,4]","[3,3]","[3,2]","[2,4]","[2,3]","[2,2]","[1,4]","[1,3]","[1,2]"
"[2, 0]","[4,2]","[4,1]","[4,0]","[3,2]","[3,1]","[3,0]","[2,2]","[2,1]","[2,0]"
"[2, 1]","[4,3]","[4,2]","[4,1]","[3,3]","[3,2]","[3,1]","[2,3]","[2,2]","[2,1]"
"[2, 2]","[4,4]","[4,3]","[4,2]","[3,4]","[3,3]","[3,2]","[2,4]","[2,3]","[2,2]"


In [15]:
add_border_to_table_and_convert_to_image(temp_relative_position_idx,image_name="realtive_position_index_x_axis_step_5.png")


QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


'Table exported to output_table_with_border.html and image saved as table_images/realtive_position_index_x_axis_step_5.png'

In [54]:
# At first we need to change the both x and y axis distace starts from 0 instead of negative values
temp_relative_position_idx_1 = temp_relative_position_idx.copy()

for i in temp_relative_position_idx_1.index:
    for j in temp_relative_position_idx_1.columns:
        temp_data=[int(i) for i in temp_relative_position_idx_1.loc[i,j].strip("]").strip("[").split(",")]
        temp_data[0]*=2 * window_size[1]-1
        # temp_data[1]+=window_size[1]-1
        temp_relative_position_idx_1.loc[i,j]=str(temp_data)
temp_relative_position_idx_styled = temp_relative_position_idx_1.style.format(style_first_substring)
add_border_to_table_and_convert_to_image(temp_relative_position_idx_styled,image_name="realtive_position_index_x_axis_step_7.png")
temp_relative_position_idx_1

QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Unnamed: 0,"[0, 0]","[0, 1]","[0, 2]","[1, 0]","[1, 1]","[1, 2]","[2, 0]","[2, 1]","[2, 2]"
"[0, 0]","[10, 2]","[10, 1]","[10, 0]","[5, 2]","[5, 1]","[5, 0]","[0, 2]","[0, 1]","[0, 0]"
"[0, 1]","[10, 3]","[10, 2]","[10, 1]","[5, 3]","[5, 2]","[5, 1]","[0, 3]","[0, 2]","[0, 1]"
"[0, 2]","[10, 4]","[10, 3]","[10, 2]","[5, 4]","[5, 3]","[5, 2]","[0, 4]","[0, 3]","[0, 2]"
"[1, 0]","[15, 2]","[15, 1]","[15, 0]","[10, 2]","[10, 1]","[10, 0]","[5, 2]","[5, 1]","[5, 0]"
"[1, 1]","[15, 3]","[15, 2]","[15, 1]","[10, 3]","[10, 2]","[10, 1]","[5, 3]","[5, 2]","[5, 1]"
"[1, 2]","[15, 4]","[15, 3]","[15, 2]","[10, 4]","[10, 3]","[10, 2]","[5, 4]","[5, 3]","[5, 2]"
"[2, 0]","[20, 2]","[20, 1]","[20, 0]","[15, 2]","[15, 1]","[15, 0]","[10, 2]","[10, 1]","[10, 0]"
"[2, 1]","[20, 3]","[20, 2]","[20, 1]","[15, 3]","[15, 2]","[15, 1]","[10, 3]","[10, 2]","[10, 1]"
"[2, 2]","[20, 4]","[20, 3]","[20, 2]","[15, 4]","[15, 3]","[15, 2]","[10, 4]","[10, 3]","[10, 2]"


In [63]:
# At first we need to change the both x and y axis distace starts from 0 instead of negative values
temp_relative_position_idx_2 = temp_relative_position_idx.copy()

for i in temp_relative_position_idx_2.index:
    for j in temp_relative_position_idx_2.columns:
        temp_data=[int(i) for i in temp_relative_position_idx_2.loc[i,j].strip("]").strip("[").split(",")]
        temp_data[0]*=2 * window_size[1]-1
        # temp_data[1]+=window_size[1]-1
        temp_relative_position_idx_2.loc[i,j]=str(sum(temp_data))
# temp_relative_position_idx_styled = temp_relative_position_idx_2.style.format(style_first_substring)
add_border_to_table_and_convert_to_image(temp_relative_position_idx_2,image_name="realtive_position_index_x_axis_step_8.png")
temp_relative_position_idx_2

QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Unnamed: 0,"[0, 0]","[0, 1]","[0, 2]","[1, 0]","[1, 1]","[1, 2]","[2, 0]","[2, 1]","[2, 2]"
"[0, 0]",12,11,10,7,6,5,2,1,0
"[0, 1]",13,12,11,8,7,6,3,2,1
"[0, 2]",14,13,12,9,8,7,4,3,2
"[1, 0]",17,16,15,12,11,10,7,6,5
"[1, 1]",18,17,16,13,12,11,8,7,6
"[1, 2]",19,18,17,14,13,12,9,8,7
"[2, 0]",22,21,20,17,16,15,12,11,10
"[2, 1]",23,22,21,18,17,16,13,12,11
"[2, 2]",24,23,22,19,18,17,14,13,12


In [64]:
# At first we need to change the both x and y axis distace starts from 0 instead of negative values
temp_relative_position_idx_3 = temp_relative_position_idx.copy()
flattened_bias_table=relative_position_bias_table_orig.values.reshape(-1)

for i in temp_relative_position_idx_3.index:
    for j in temp_relative_position_idx_3.columns:
        temp_data=[int(i) for i in temp_relative_position_idx_3.loc[i,j].strip("]").strip("[").split(",")]
        temp_data[0]*=2 * window_size[1]-1
        final_index=sum(temp_data)
        # temp_data[1]+=window_size[1]-1
        temp_relative_position_idx_3.loc[i,j]=flattened_bias_table[final_index]

# temp_relative_position_idx_styled = temp_relative_position_idx_2.style.format(style_first_substring)
add_border_to_table_and_convert_to_image(temp_relative_position_idx_3,image_name="realtive_position_index_x_axis_step_9.png")
temp_relative_position_idx_3

QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Unnamed: 0,"[0, 0]","[0, 1]","[0, 2]","[1, 0]","[1, 1]","[1, 2]","[2, 0]","[2, 1]","[2, 2]"
"[0, 0]",0.9966,0.9528,0.25,0.1602,0.6535,0.8236,0.5781,0.4289,0.5714
"[0, 1]",0.0446,0.9966,0.9528,0.5207,0.1602,0.6535,0.2061,0.5781,0.4289
"[0, 2]",0.8602,0.0446,0.9966,0.3278,0.5207,0.1602,0.8133,0.2061,0.5781
"[1, 0]",0.2836,0.3816,0.6032,0.9966,0.9528,0.25,0.1602,0.6535,0.8236
"[1, 1]",0.675,0.2836,0.3816,0.0446,0.9966,0.9528,0.5207,0.1602,0.6535
"[1, 2]",0.4568,0.675,0.2836,0.8602,0.0446,0.9966,0.3278,0.5207,0.1602
"[2, 0]",0.133,0.6618,0.6859,0.2836,0.3816,0.6032,0.9966,0.9528,0.25
"[2, 1]",0.7678,0.133,0.6618,0.675,0.2836,0.3816,0.0446,0.9966,0.9528
"[2, 2]",0.9824,0.7678,0.133,0.4568,0.675,0.2836,0.8602,0.0446,0.9966


In [257]:
relative_position_bias_flatten_table=pd.DataFrame()
for i in range(-window_size[0]+1,window_size[0]): # along x axis
    for j in range(-window_size[1]+1,window_size[1]): # along y axis
        relative_position_bias_flatten_table.loc[str(i),str(j)]=random.random()
row_index=[str([int(i),int(j)])for i in relative_position_bias_table.index.values.tolist() for j in relative_position_bias_table.columns.values.tolist()]

flatten_df=pd.DataFrame(relative_position_bias_flatten_table.values.reshape(-1),index=row_index).reset_index()
flatten_df.columns = ["realtive_index","learnable_data"]
flatten_df

Unnamed: 0,realtive_index,learnable_data
0,"[-2, -3]",0.664492
1,"[-2, -2]",0.23924
2,"[-2, -1]",0.51524
3,"[-2, 0]",0.969378
4,"[-2, 1]",0.473377
5,"[-2, 2]",0.766129
6,"[-2, 3]",0.350526
7,"[-1, -3]",0.552074
8,"[-1, -2]",0.778742
9,"[-1, -1]",0.697927


In [226]:
import pandas as pd


# Define a function to style a specific substring
def style_substring(val):
    if isinstance(val, str):  # Ensure the value is a string
        # Change color of the substring "Dev" (example)
        temp_data=[int(i) for i in val.strip("]").strip("[").split(",")]
        first_value=temp_data[0] 
        second_value=temp_data[1]
        return f'[<span style="color: red;">{first_value}</span>, {second_value}]'
    return val

# Apply the style function to the entire DataFrame or specific columns
styled_df = data.style.format(style_substring)

# Display the styled DataFrame in Jupyter Notebook or save it as HTML
styled_df

# pd.options.html.border = 2
# Display the styled DataFrame in Jupyter Notebook or save it as HTML
styled_df.to_html("help.html",border=1)


In [224]:
# at first get the index in the table
# we need change the first index to  factor which go at the max
window_size=[3,4]
for i in data.index:
    for j in data.columns:
        temp_data=[int(i) for i in data.loc[i,j].strip("]").strip("[").split(",")]
        temp_data[0]+=window_size[0]-1
        data.loc[i,j]=str(temp_data)

styled_df = data.style.format(style_substring)
styled_df.to_html("help.html",border=1)
styled_df

Unnamed: 0,"[0, 0]","[0, 1]","[0, 2]","[0, 3]","[1, 0]","[1, 1]","[1, 2]","[1, 3]","[2, 0]","[2, 1]","[2, 2]","[2, 3]"
"[0, 0]","[2, 0]","[2, -1]","[2, -2]","[2, -3]","[1, 0]","[1, -1]","[1, -2]","[1, -3]","[0, 0]","[0, -1]","[0, -2]","[0, -3]"
"[0, 1]","[2, 1]","[2, 0]","[2, -1]","[2, -2]","[1, 1]","[1, 0]","[1, -1]","[1, -2]","[0, 1]","[0, 0]","[0, -1]","[0, -2]"
"[0, 2]","[2, 2]","[2, 1]","[2, 0]","[2, -1]","[1, 2]","[1, 1]","[1, 0]","[1, -1]","[0, 2]","[0, 1]","[0, 0]","[0, -1]"
"[0, 3]","[2, 3]","[2, 2]","[2, 1]","[2, 0]","[1, 3]","[1, 2]","[1, 1]","[1, 0]","[0, 3]","[0, 2]","[0, 1]","[0, 0]"
"[1, 0]","[3, 0]","[3, -1]","[3, -2]","[3, -3]","[2, 0]","[2, -1]","[2, -2]","[2, -3]","[1, 0]","[1, -1]","[1, -2]","[1, -3]"
"[1, 1]","[3, 1]","[3, 0]","[3, -1]","[3, -2]","[2, 1]","[2, 0]","[2, -1]","[2, -2]","[1, 1]","[1, 0]","[1, -1]","[1, -2]"
"[1, 2]","[3, 2]","[3, 1]","[3, 0]","[3, -1]","[2, 2]","[2, 1]","[2, 0]","[2, -1]","[1, 2]","[1, 1]","[1, 0]","[1, -1]"
"[1, 3]","[3, 3]","[3, 2]","[3, 1]","[3, 0]","[2, 3]","[2, 2]","[2, 1]","[2, 0]","[1, 3]","[1, 2]","[1, 1]","[1, 0]"
"[2, 0]","[4, 0]","[4, -1]","[4, -2]","[4, -3]","[3, 0]","[3, -1]","[3, -2]","[3, -3]","[2, 0]","[2, -1]","[2, -2]","[2, -3]"
"[2, 1]","[4, 1]","[4, 0]","[4, -1]","[4, -2]","[3, 1]","[3, 0]","[3, -1]","[3, -2]","[2, 1]","[2, 0]","[2, -1]","[2, -2]"


In [225]:
window_size=[3,4]
for i in data.index:
    for j in data.columns:
        temp_data=[int(i) for i in data.loc[i,j].strip("]").strip("[").split(",")]
        temp_data[1]+=window_size[1]-1
        data.loc[i,j]=str(temp_data)

styled_df = data.style.format(style_substring)
styled_df.to_html("help.html",index=False,border=1,classes='table table-bordered')
styled_df

Unnamed: 0,"[0, 0]","[0, 1]","[0, 2]","[0, 3]","[1, 0]","[1, 1]","[1, 2]","[1, 3]","[2, 0]","[2, 1]","[2, 2]","[2, 3]"
"[0, 0]","[2, 3]","[2, 2]","[2, 1]","[2, 0]","[1, 3]","[1, 2]","[1, 1]","[1, 0]","[0, 3]","[0, 2]","[0, 1]","[0, 0]"
"[0, 1]","[2, 4]","[2, 3]","[2, 2]","[2, 1]","[1, 4]","[1, 3]","[1, 2]","[1, 1]","[0, 4]","[0, 3]","[0, 2]","[0, 1]"
"[0, 2]","[2, 5]","[2, 4]","[2, 3]","[2, 2]","[1, 5]","[1, 4]","[1, 3]","[1, 2]","[0, 5]","[0, 4]","[0, 3]","[0, 2]"
"[0, 3]","[2, 6]","[2, 5]","[2, 4]","[2, 3]","[1, 6]","[1, 5]","[1, 4]","[1, 3]","[0, 6]","[0, 5]","[0, 4]","[0, 3]"
"[1, 0]","[3, 3]","[3, 2]","[3, 1]","[3, 0]","[2, 3]","[2, 2]","[2, 1]","[2, 0]","[1, 3]","[1, 2]","[1, 1]","[1, 0]"
"[1, 1]","[3, 4]","[3, 3]","[3, 2]","[3, 1]","[2, 4]","[2, 3]","[2, 2]","[2, 1]","[1, 4]","[1, 3]","[1, 2]","[1, 1]"
"[1, 2]","[3, 5]","[3, 4]","[3, 3]","[3, 2]","[2, 5]","[2, 4]","[2, 3]","[2, 2]","[1, 5]","[1, 4]","[1, 3]","[1, 2]"
"[1, 3]","[3, 6]","[3, 5]","[3, 4]","[3, 3]","[2, 6]","[2, 5]","[2, 4]","[2, 3]","[1, 6]","[1, 5]","[1, 4]","[1, 3]"
"[2, 0]","[4, 3]","[4, 2]","[4, 1]","[4, 0]","[3, 3]","[3, 2]","[3, 1]","[3, 0]","[2, 3]","[2, 2]","[2, 1]","[2, 0]"
"[2, 1]","[4, 4]","[4, 3]","[4, 2]","[4, 1]","[3, 4]","[3, 3]","[3, 2]","[3, 1]","[2, 4]","[2, 3]","[2, 2]","[2, 1]"


In [208]:
styled_df.to_html()

'<style type="text/css">\n</style>\n<table id="T_75c8d">\n  <thead>\n    <tr>\n      <th class="blank level0" >&nbsp;</th>\n      <th id="T_75c8d_level0_col0" class="col_heading level0 col0" >[0, 0]</th>\n      <th id="T_75c8d_level0_col1" class="col_heading level0 col1" >[0, 1]</th>\n      <th id="T_75c8d_level0_col2" class="col_heading level0 col2" >[0, 2]</th>\n      <th id="T_75c8d_level0_col3" class="col_heading level0 col3" >[0, 3]</th>\n      <th id="T_75c8d_level0_col4" class="col_heading level0 col4" >[1, 0]</th>\n      <th id="T_75c8d_level0_col5" class="col_heading level0 col5" >[1, 1]</th>\n      <th id="T_75c8d_level0_col6" class="col_heading level0 col6" >[1, 2]</th>\n      <th id="T_75c8d_level0_col7" class="col_heading level0 col7" >[1, 3]</th>\n      <th id="T_75c8d_level0_col8" class="col_heading level0 col8" >[2, 0]</th>\n      <th id="T_75c8d_level0_col9" class="col_heading level0 col9" >[2, 1]</th>\n      <th id="T_75c8d_level0_col10" class="col_heading level0 col1

  relative_position_bias_table.style.applymap(color_cell)


Unnamed: 0,-3,-2,-1,0,1,2,3
-2,0.5714 (idx=0),0.4289 (idx=1),0.5781 (idx=2),0.2061 (idx=3),0.8133 (idx=4),0.8236 (idx=5),0.6535 (idx=6)
-1,0.1602 (idx=7),0.5207 (idx=8),0.3278 (idx=9),0.25 (idx=10),0.9528 (idx=11),0.9966 (idx=12),0.0446 (idx=13)
0,0.8602 (idx=14),0.6032 (idx=15),0.3816 (idx=16),0.2836 (idx=17),0.675 (idx=18),0.4568 (idx=19),0.6859 (idx=20)
1,0.6618 (idx=21),0.133 (idx=22),0.7678 (idx=23),0.9824 (idx=24),0.9694 (idx=25),0.6133 (idx=26),0.0443 (idx=27)
2,0.0041 (idx=28),0.134 (idx=29),0.941 (idx=30),0.3029 (idx=31),0.3661 (idx=32),0.8982 (idx=33),0.3144 (idx=34)


In [220]:
data=pd.DataFrame()

for i in base_image_patches.values.reshape(-1):
    for j in base_image_patches.values.reshape(-1):
        data.loc[str(i),str(j)]=str([i[0]-j[0],i[1]-j[1]])
 

add_border_to_table_and_convert_to_image(data)

QStandardPaths: XDG_RUNTIME_DIR not set, defaulting to '/tmp/runtime-root'
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


'Table exported to output_table_with_border.html and image saved as table_images/table_image_2.png'

In [143]:
import numpy as np
window_size=[3,4]
row_coords = np.arange(window_size[0])
column_coords = np.arange(window_size[1])
axis_1,axis_2=np.meshgrid(row_coords, column_coords) # 2, Wh, Ww
axis_1

array([[0, 1, 2],
       [0, 1, 2],
       [0, 1, 2],
       [0, 1, 2]])

In [29]:
"""Since you flatten the relative position index table, we have index of the flatten table we need to point the each values"""
# to use this logic 
# we need a relative position index table as you seen ableve
# it startes from 0,0 and move left to right at each row in backwards


# First calculate the how many number of rows we need to skip


# get pair-wise relative position index for each token inside the window and get the value from the bias table
import torch
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
coords_flatten

tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
        [0, 1, 2, 0, 1, 2, 0, 1, 2]])

In [30]:
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords

tensor([[[ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [-1,  0],
         [-1, -1],
         [-1, -2],
         [-2,  0],
         [-2, -1],
         [-2, -2]],

        [[ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [-1,  1],
         [-1,  0],
         [-1, -1],
         [-2,  1],
         [-2,  0],
         [-2, -1]],

        [[ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [-1,  2],
         [-1,  1],
         [-1,  0],
         [-2,  2],
         [-2,  1],
         [-2,  0]],

        [[ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [-1,  0],
         [-1, -1],
         [-1, -2]],

        [[ 1,  1],
         [ 1,  0],
         [ 1, -1],
         [ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [-1,  1],
         [-1,  0],
         [-1, -1]],

        [[ 1,  2],
         [ 1,  1],
         [ 1,  0],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [-1,  2],
  

In [31]:
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords

tensor([[[ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [-1,  0],
         [-1, -1],
         [-1, -2],
         [-2,  0],
         [-2, -1],
         [-2, -2]],

        [[ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [-1,  1],
         [-1,  0],
         [-1, -1],
         [-2,  1],
         [-2,  0],
         [-2, -1]],

        [[ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [-1,  2],
         [-1,  1],
         [-1,  0],
         [-2,  2],
         [-2,  1],
         [-2,  0]],

        [[ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [-1,  0],
         [-1, -1],
         [-1, -2]],

        [[ 1,  1],
         [ 1,  0],
         [ 1, -1],
         [ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [-1,  1],
         [-1,  0],
         [-1, -1]],

        [[ 1,  2],
         [ 1,  1],
         [ 1,  0],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [-1,  2],
  

In [21]:
relative_coords.shape

torch.Size([9, 9, 2])

In [32]:
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 0]

tensor([[2, 2, 2, 1, 1, 1, 0, 0, 0],
        [2, 2, 2, 1, 1, 1, 0, 0, 0],
        [2, 2, 2, 1, 1, 1, 0, 0, 0],
        [3, 3, 3, 2, 2, 2, 1, 1, 1],
        [3, 3, 3, 2, 2, 2, 1, 1, 1],
        [3, 3, 3, 2, 2, 2, 1, 1, 1],
        [4, 4, 4, 3, 3, 3, 2, 2, 2],
        [4, 4, 4, 3, 3, 3, 2, 2, 2],
        [4, 4, 4, 3, 3, 3, 2, 2, 2]])

In [33]:
relative_coords[:, :, 1] += window_size[1] - 1

relative_coords[:, :, 1]

tensor([[2, 1, 0, 2, 1, 0, 2, 1, 0],
        [3, 2, 1, 3, 2, 1, 3, 2, 1],
        [4, 3, 2, 4, 3, 2, 4, 3, 2],
        [2, 1, 0, 2, 1, 0, 2, 1, 0],
        [3, 2, 1, 3, 2, 1, 3, 2, 1],
        [4, 3, 2, 4, 3, 2, 4, 3, 2],
        [2, 1, 0, 2, 1, 0, 2, 1, 0],
        [3, 2, 1, 3, 2, 1, 3, 2, 1],
        [4, 3, 2, 4, 3, 2, 4, 3, 2]])

In [37]:
relative_coords[:, :, 0] * (2 * window_size[1] -1)

tensor([[10, 10, 10,  5,  5,  5,  0,  0,  0],
        [10, 10, 10,  5,  5,  5,  0,  0,  0],
        [10, 10, 10,  5,  5,  5,  0,  0,  0],
        [15, 15, 15, 10, 10, 10,  5,  5,  5],
        [15, 15, 15, 10, 10, 10,  5,  5,  5],
        [15, 15, 15, 10, 10, 10,  5,  5,  5],
        [20, 20, 20, 15, 15, 15, 10, 10, 10],
        [20, 20, 20, 15, 15, 15, 10, 10, 10],
        [20, 20, 20, 15, 15, 15, 10, 10, 10]])

In [24]:
# relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
# relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
relative_position_index.view(-1)

tensor([12, 11, 10,  7,  6,  5,  2,  1,  0, 13, 12, 11,  8,  7,  6,  3,  2,  1,
        14, 13, 12,  9,  8,  7,  4,  3,  2, 17, 16, 15, 12, 11, 10,  7,  6,  5,
        18, 17, 16, 13, 12, 11,  8,  7,  6, 19, 18, 17, 14, 13, 12,  9,  8,  7,
        22, 21, 20, 17, 16, 15, 12, 11, 10, 23, 22, 21, 18, 17, 16, 13, 12, 11,
        24, 23, 22, 19, 18, 17, 14, 13, 12])

In [25]:
relative_coords[:, :, 0]

tensor([[10, 10, 10,  5,  5,  5,  0,  0,  0],
        [10, 10, 10,  5,  5,  5,  0,  0,  0],
        [10, 10, 10,  5,  5,  5,  0,  0,  0],
        [15, 15, 15, 10, 10, 10,  5,  5,  5],
        [15, 15, 15, 10, 10, 10,  5,  5,  5],
        [15, 15, 15, 10, 10, 10,  5,  5,  5],
        [20, 20, 20, 15, 15, 15, 10, 10, 10],
        [20, 20, 20, 15, 15, 15, 10, 10, 10],
        [20, 20, 20, 15, 15, 15, 10, 10, 10]])

In [26]:
relative_coords[:, :, 1]

tensor([[2, 1, 0, 2, 1, 0, 2, 1, 0],
        [3, 2, 1, 3, 2, 1, 3, 2, 1],
        [4, 3, 2, 4, 3, 2, 4, 3, 2],
        [2, 1, 0, 2, 1, 0, 2, 1, 0],
        [3, 2, 1, 3, 2, 1, 3, 2, 1],
        [4, 3, 2, 4, 3, 2, 4, 3, 2],
        [2, 1, 0, 2, 1, 0, 2, 1, 0],
        [3, 2, 1, 3, 2, 1, 3, 2, 1],
        [4, 3, 2, 4, 3, 2, 4, 3, 2]])

In [27]:
relative_position_index

tensor([[12, 11, 10,  7,  6,  5,  2,  1,  0],
        [13, 12, 11,  8,  7,  6,  3,  2,  1],
        [14, 13, 12,  9,  8,  7,  4,  3,  2],
        [17, 16, 15, 12, 11, 10,  7,  6,  5],
        [18, 17, 16, 13, 12, 11,  8,  7,  6],
        [19, 18, 17, 14, 13, 12,  9,  8,  7],
        [22, 21, 20, 17, 16, 15, 12, 11, 10],
        [23, 22, 21, 18, 17, 16, 13, 12, 11],
        [24, 23, 22, 19, 18, 17, 14, 13, 12]])

In [72]:
# each patch or token are related to all others token, each token has a 2M-1 relative positions
# for an example 
total_positions=(2*4-1)*(2*4-1)
relative_position_bias_table = torch.randn(total_positions).to(device)

relative_position_bias_table


tensor([-0.6425,  0.2957,  0.1081,  1.8895, -0.0268, -0.4967,  0.8868,  1.0817,
         0.5598, -0.0542, -2.7614, -0.7690, -1.1982,  0.2637,  0.8191, -0.1034,
        -0.7390,  0.8793, -0.4680,  1.3096, -0.9973, -1.3211,  1.1276,  1.5724,
        -2.0000,  0.9001,  0.3333, -1.1355,  0.3097, -1.2543,  1.2419, -0.1193,
        -1.5015,  1.2621,  1.8425,  0.3884,  1.1521,  0.0786, -0.1984,  0.3343,
        -0.1842, -1.0856, -0.2024,  2.0813, -0.4464, -0.7416,  0.1038, -0.2292,
        -0.9966], device='cuda:0')

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [None]:
# get the index value for the each tokens, where in our case we have 16 tokens with 1 dimensions






In [153]:
# relative position index

window_size =[3,4]
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])

coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
coords_flatten

tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]])

In [154]:
coords_flatten[0,:, None]

tensor([[0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [1],
        [1],
        [2],
        [2],
        [2],
        [2]])

In [155]:
coords_flatten[:, :, None].shape

torch.Size([2, 12, 1])

In [156]:
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords

tensor([[[ 0,  0,  0,  0, -1, -1, -1, -1, -2, -2, -2, -2],
         [ 0,  0,  0,  0, -1, -1, -1, -1, -2, -2, -2, -2],
         [ 0,  0,  0,  0, -1, -1, -1, -1, -2, -2, -2, -2],
         [ 0,  0,  0,  0, -1, -1, -1, -1, -2, -2, -2, -2],
         [ 1,  1,  1,  1,  0,  0,  0,  0, -1, -1, -1, -1],
         [ 1,  1,  1,  1,  0,  0,  0,  0, -1, -1, -1, -1],
         [ 1,  1,  1,  1,  0,  0,  0,  0, -1, -1, -1, -1],
         [ 1,  1,  1,  1,  0,  0,  0,  0, -1, -1, -1, -1],
         [ 2,  2,  2,  2,  1,  1,  1,  1,  0,  0,  0,  0],
         [ 2,  2,  2,  2,  1,  1,  1,  1,  0,  0,  0,  0],
         [ 2,  2,  2,  2,  1,  1,  1,  1,  0,  0,  0,  0],
         [ 2,  2,  2,  2,  1,  1,  1,  1,  0,  0,  0,  0]],

        [[ 0, -1, -2, -3,  0, -1, -2, -3,  0, -1, -2, -3],
         [ 1,  0, -1, -2,  1,  0, -1, -2,  1,  0, -1, -2],
         [ 2,  1,  0, -1,  2,  1,  0, -1,  2,  1,  0, -1],
         [ 3,  2,  1,  0,  3,  2,  1,  0,  3,  2,  1,  0],
         [ 0, -1, -2, -3,  0, -1, -2, -3,  0, -1, -2, 

In [157]:
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords

tensor([[[ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [ 0, -3],
         [-1,  0],
         [-1, -1],
         [-1, -2],
         [-1, -3],
         [-2,  0],
         [-2, -1],
         [-2, -2],
         [-2, -3]],

        [[ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [-1,  1],
         [-1,  0],
         [-1, -1],
         [-1, -2],
         [-2,  1],
         [-2,  0],
         [-2, -1],
         [-2, -2]],

        [[ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [-1,  2],
         [-1,  1],
         [-1,  0],
         [-1, -1],
         [-2,  2],
         [-2,  1],
         [-2,  0],
         [-2, -1]],

        [[ 0,  3],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [-1,  3],
         [-1,  2],
         [-1,  1],
         [-1,  0],
         [-2,  3],
         [-2,  2],
         [-2,  1],
         [-2,  0]],

        [[ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 1, -3],
    

In [158]:
relative_coords

tensor([[[ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [ 0, -3],
         [-1,  0],
         [-1, -1],
         [-1, -2],
         [-1, -3],
         [-2,  0],
         [-2, -1],
         [-2, -2],
         [-2, -3]],

        [[ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [-1,  1],
         [-1,  0],
         [-1, -1],
         [-1, -2],
         [-2,  1],
         [-2,  0],
         [-2, -1],
         [-2, -2]],

        [[ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [-1,  2],
         [-1,  1],
         [-1,  0],
         [-1, -1],
         [-2,  2],
         [-2,  1],
         [-2,  0],
         [-2, -1]],

        [[ 0,  3],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [-1,  3],
         [-1,  2],
         [-1,  1],
         [-1,  0],
         [-2,  3],
         [-2,  2],
         [-2,  1],
         [-2,  0]],

        [[ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 1, -3],
    

In [159]:
relative_coords[:, :, 0]=relative_coords[:, :, 0]+window_size[0]-1
relative_coords[:, :, 0]

tensor([[2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0],
        [2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0],
        [2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0],
        [2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0],
        [3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1],
        [3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1],
        [3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1],
        [3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1],
        [4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 2, 2],
        [4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 2, 2],
        [4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 2, 2],
        [4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 2, 2]])

In [160]:
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 1]

tensor([[3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0],
        [4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1],
        [5, 4, 3, 2, 5, 4, 3, 2, 5, 4, 3, 2],
        [6, 5, 4, 3, 6, 5, 4, 3, 6, 5, 4, 3],
        [3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0],
        [4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1],
        [5, 4, 3, 2, 5, 4, 3, 2, 5, 4, 3, 2],
        [6, 5, 4, 3, 6, 5, 4, 3, 6, 5, 4, 3],
        [3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0],
        [4, 3, 2, 1, 4, 3, 2, 1, 4, 3, 2, 1],
        [5, 4, 3, 2, 5, 4, 3, 2, 5, 4, 3, 2],
        [6, 5, 4, 3, 6, 5, 4, 3, 6, 5, 4, 3]])

In [161]:
relative_coords[:, :, 0] *= 2 * window_size[1] - 1

relative_coords[:, :, 0]

tensor([[14, 14, 14, 14,  7,  7,  7,  7,  0,  0,  0,  0],
        [14, 14, 14, 14,  7,  7,  7,  7,  0,  0,  0,  0],
        [14, 14, 14, 14,  7,  7,  7,  7,  0,  0,  0,  0],
        [14, 14, 14, 14,  7,  7,  7,  7,  0,  0,  0,  0],
        [21, 21, 21, 21, 14, 14, 14, 14,  7,  7,  7,  7],
        [21, 21, 21, 21, 14, 14, 14, 14,  7,  7,  7,  7],
        [21, 21, 21, 21, 14, 14, 14, 14,  7,  7,  7,  7],
        [21, 21, 21, 21, 14, 14, 14, 14,  7,  7,  7,  7],
        [28, 28, 28, 28, 21, 21, 21, 21, 14, 14, 14, 14],
        [28, 28, 28, 28, 21, 21, 21, 21, 14, 14, 14, 14],
        [28, 28, 28, 28, 21, 21, 21, 21, 14, 14, 14, 14],
        [28, 28, 28, 28, 21, 21, 21, 21, 14, 14, 14, 14]])

In [162]:
relative_coords.sum(-1)

tensor([[17, 16, 15, 14, 10,  9,  8,  7,  3,  2,  1,  0],
        [18, 17, 16, 15, 11, 10,  9,  8,  4,  3,  2,  1],
        [19, 18, 17, 16, 12, 11, 10,  9,  5,  4,  3,  2],
        [20, 19, 18, 17, 13, 12, 11, 10,  6,  5,  4,  3],
        [24, 23, 22, 21, 17, 16, 15, 14, 10,  9,  8,  7],
        [25, 24, 23, 22, 18, 17, 16, 15, 11, 10,  9,  8],
        [26, 25, 24, 23, 19, 18, 17, 16, 12, 11, 10,  9],
        [27, 26, 25, 24, 20, 19, 18, 17, 13, 12, 11, 10],
        [31, 30, 29, 28, 24, 23, 22, 21, 17, 16, 15, 14],
        [32, 31, 30, 29, 25, 24, 23, 22, 18, 17, 16, 15],
        [33, 32, 31, 30, 26, 25, 24, 23, 19, 18, 17, 16],
        [34, 33, 32, 31, 27, 26, 25, 24, 20, 19, 18, 17]])

In [151]:
relative_coords

tensor([[[ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [ 0, -3],
         [-1,  0],
         [-1, -1],
         [-1, -2],
         [-1, -3],
         [-2,  0],
         [-2, -1],
         [-2, -2],
         [-2, -3]],

        [[ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [ 0, -2],
         [-1,  1],
         [-1,  0],
         [-1, -1],
         [-1, -2],
         [-2,  1],
         [-2,  0],
         [-2, -1],
         [-2, -2]],

        [[ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [ 0, -1],
         [-1,  2],
         [-1,  1],
         [-1,  0],
         [-1, -1],
         [-2,  2],
         [-2,  1],
         [-2,  0],
         [-2, -1]],

        [[ 0,  3],
         [ 0,  2],
         [ 0,  1],
         [ 0,  0],
         [-1,  3],
         [-1,  2],
         [-1,  1],
         [-1,  0],
         [-2,  3],
         [-2,  2],
         [-2,  1],
         [-2,  0]],

        [[ 1,  0],
         [ 1, -1],
         [ 1, -2],
         [ 1, -3],
    