In [2]:
from typing import Tuple, Optional, List, Union, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

In [None]:
class FeedForward(nn.Module):
    """
    Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.
    """

    def __init__(self,
                 in_features: int,
                 hidden_features: int,
                 out_features: int,
                 dropout: float = 0.1) -> None:
        """
        :param in_features: size of the last dimension of the input
        :param hidden_features: output size of the first linear layer
        :param out_features: output size of the second linear layer
        :param dropout: dropout probability applied after the activation and after the second linear layer
        """
        # nn.Module must be initialised before any sub-module is assigned;
        # without this call the assignments below raise AttributeError.
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        """
        Apply the block to ``x``.

        :param x: tensor whose last dimension is ``in_features``
        :return: tensor with the same leading dimensions and ``out_features`` as last dimension
        """
        x = self.dropout1(F.gelu(self.fc1(x)))
        return self.dropout2(self.fc2(x))
        

In [None]:
def get_relative_distances(window_size: int) -> torch.Tensor:
    """
    Return the pairwise coordinate offsets between all positions of a square window.

    NOTE(review): the original definition had no body (a SyntaxError). This implementation
    mirrors the raw offset computation used by
    ``WindowMultiHeadAttention.__make_pairwise_relative_positions`` (without the log
    scaling) -- confirm this is the intended contract.

    :param window_size: side length of the square window
    :return: integer tensor of shape [window_size ** 2, window_size ** 2, 2] where entry
        ``[i, j]`` is ``coordinates[j] - coordinates[i]`` in (row, column) order
    """
    indices = torch.arange(window_size)
    # Row-major ('ij') grid of (row, column) coordinates, flattened to [window_size ** 2, 2].
    coordinates = torch.stack(torch.meshgrid([indices, indices], indexing='ij'), dim=0)
    coordinates = coordinates.flatten(1).transpose(0, 1)
    return coordinates[None, :, :] - coordinates[:, None, :]


In [None]:
class WindowMultiHeadAttention(nn.Module):
    def __init__(self,
                 in_features: int,
                 window_size: int,
                 number_of_heads: int,
                 dropout_attention: float = 0.1,
                 dropout_projection: float = 0.1,
                 meta_network_hidden_features: int = 256,
                 sequential_self_attention: bool = False) -> None:
        super(WindowMultiHeadAttention, self).__init__()
        assert in_features % number_of_heads == 0, 'in_features must be divisible by number_of_heads'
        self.in_features: int = in_features
        self.window_size: int = window_size
        self.num_of_heads: int = number_of_heads
        self.sequential_self_attention: bool = sequential_self_attention
        self.linears: nn.Module = nn.ModuleList([nn.Linear(in_features, in_features) for _ in range(4)])
        self.attention_dropout: nn.Dropout = nn.Dropout(dropout_attention)
        self.dropout: nn.Dropout = nn.Dropout(dropout_projection)
        self.meta_network: nn.Module = nn.Sequential(
            nn.Linear(in_features=2, out_features=meta_network_hidden_features),
            nn.ReLU(),
            nn.Linear(in_features=meta_network_hidden_features, out_features=number_of_heads),
        )
        self.tau = nn.Parameter(torch.ones(1, number_of_heads, 1, 1))
        self.__make_pairwise_relative_positions()
    
    def __make_pairwise_relative_positions(self) -> None:
        indices = torch.arange(self.window_size)
        coordinates = torch.stack(torch.meshgrid([indices, indices]), dim=0)
        coordinates = coordinates.flatten(1).transpose(0, 1)
        relative_distances = coordinates[None, :, :] - coordinates[:, None, :]
        # shape of relative_distances_log: [window_size ** 2, window_size ** 2, 2]
        relative_distances_log = torch.sign(relative_distances) * torch.log(torch.abs(relative_distances) + 1)
        self.register_buffer('relative_distances_log', relative_distances_log)
    
    def update_resolution(self,
                          new_window_size: int) -> None:
        self.window_size: int = new_window_size
        self.__make_pairwise_relative_positions()
    
    def __get_relative_positional_encodings(self) -> torch.Tensor:
        relative_position_bias: torch.Tensor = self.meta_network(self.relative_distances_log)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().view(self.num_of_heads, self.window_size ** 2, self.window_size ** 2)
        return relative_position_bias.unsqueeze(0)
    
    def __self_attention(self,
                         query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         batch_size: int,
                         tokens: int,
                         mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        attention_map = torch.matmul(query, key) / (torch.norm(query, dim=-1, keepdim=True) * torch.norm(key, dim=-1, keepdim=True))
        attention_map = attention_map / max(0.01, self.tau)
        attention_map = attention_map + self.__get_relative_positional_encodings()
        if mask is not None:
            