In [3]:
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# use Performer, as it had the best reported numbers

from performer_pytorch import SelfAttention as PerformerAttention

In [5]:
def exists(val):
    return val is not None

def get_module_device(module):
    return next(module.parameters()).device

def find_modules(nn_module, type):
    return [module for module in nn_module.modules() if isinstance(module, type)]

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads =  heads
        self.scale = dim_head ** -0.5
        self.causal = causal

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        b, n, d, h, device = *x.shape, self.heads, x.device
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')
            sim.masked_fill_(~mask, max_neg_value)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            causal_mask = rearrange(causal_mask, 'i j -> () i j')
            sim.masked_fill_(causal_mask, max_neg_value)

        attn = sim.softmax(dim = -1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)




In [6]:
# main class

class Omninet(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        pool_layer_tokens_every = 2,
        attn_dropout = 0.,
        ff_dropout = 0.,
        feature_redraw_interval = 1000
    ):
        super().__init__()

        layers = nn.ModuleList([])
        for ind in range(depth):
            num_layers = ind + 1
            should_pool = num_layers % pool_layer_tokens_every

            layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim = dim, dropout = ff_dropout)),
                PerformerAttention(dim = dim, heads= heads, dim_head = dim_head) if should_pool else None
            ]))

        self.layers = layers
        self.pool_num_layers = pool_layer_tokens_every

        # keep track of redrawing projection matrix for Performer
        self.feature_redraw_interval = feature_redraw_interval
        self.register_buffer('calls_since_last_redraw', torch.tensor(0))

    def fix_projection_matrices_(self):
        self.feature_redraw_interval = None

    def check_redraw_projections(self):
        if not self.training:
            return

        if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
            device = get_module_device(self)

            fast_attentions = find_modules(self, FastAttention)
            for fast_attention in fast_attentions:
                fast_attention.redraw_projection_matrix(device)

            self.calls_since_last_redraw.zero_()
            return

        self.calls_since_last_redraw += 1

    def forward(self, x, mask = None):
        self.check_redraw_projections()
        pool_num_layers = self.pool_num_layers

        hiddens = [x]

        for attn, ff, efficient_attn in self.layers:
            x = attn(x, mask = mask) + x
            x = ff(x) + x

            hiddens.append(x)
            if exists(efficient_attn):
                layers_to_pool = hiddens[-pool_num_layers:]
                num_layers = len(layers_to_pool)

                all_tokens = torch.stack(layers_to_pool)
                all_tokens = rearrange(all_tokens, 'l b n d -> b (n l) d')

                pool_attn_mask = None
                if exists(mask):
                    pool_attn_mask = repeat(mask, 'b n -> b (n l)', l = num_layers)

                attended_tokens = efficient_attn(all_tokens, mask = pool_attn_mask)

                attended_tokens = rearrange(attended_tokens, 'b n c -> b c n')
                pooled_tokens = F.max_pool1d(attended_tokens, kernel_size = num_layers, stride = num_layers)
                x += rearrange(pooled_tokens, 'b c n -> b n c')

        return x



In [7]:
# causal case is sufficiently different to warrant its own class
# use layer axial attention for now, until I rewrite the linear attention cuda kernel

class OmninetCausal(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        pool_layer_tokens_every = 2,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()

        self.layer_pos_emb = nn.Parameter(torch.randn(depth + 1, dim))

        layers = nn.ModuleList([])
        for ind in range(depth):
            num_layers = ind + 1
            should_pool = num_layers % pool_layer_tokens_every

            layers.append(nn.ModuleList([
                PreNorm(dim, Attention(causal = True, dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim = dim, dropout = ff_dropout)),
                Attention(dim = dim, heads= heads, dim_head = dim_head) if should_pool else None
            ]))

        self.layers = layers
        self.pool_num_layers = pool_layer_tokens_every

    def forward(self, x, mask = None):
        pool_num_layers = self.pool_num_layers

        b = x.shape[0]
        pos_embs = rearrange(self.layer_pos_emb, 'n d -> () n d')

        x += pos_embs[:, 0]
        hiddens = [x]

        for ind, (attn, ff, layer_axial_attn) in enumerate(self.layers):

            x = attn(x, mask = mask) + x
            x = ff(x) + x

            x += pos_embs[:, ind + 1]
            hiddens.append(x)

            if exists(layer_axial_attn):
                layers_to_pool = hiddens[-pool_num_layers:]
                num_layers = len(layers_to_pool)

                layer_tokens = rearrange(torch.stack(layers_to_pool), 'l b n d -> (b n) l d')

                attended_tokens = layer_axial_attn(layer_tokens)
                attended_tokens = rearrange(attended_tokens, '(b n) l d -> b n l d', b = b)
                pooled_attended_tokens = attended_tokens.max(dim = -2).values
                x += pooled_attended_tokens

        return x

In [8]:
omninet = Omninet(
    dim = 512,                     # model dimension
    depth = 6,                     # depth
    dim_head = 64,                 # dimension per head
    heads = 8,                     # number of heads
    pool_layer_tokens_every = 3,   # key to this paper - every N layers, omni attend to all tokens of all layers
    attn_dropout = 0.1,            # attention dropout
    ff_dropout = 0.1,              # feedforward dropout
    feature_redraw_interval = 1000 # how often to redraw the projection matrix for omni attention net - Performer
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

omninet(x, mask = mask) # (1, 1024, 512)

tensor([[[ 0.9816,  1.4537,  2.9661,  ...,  0.2320, -0.9211, -0.4982],
         [-1.1163,  0.2141,  0.1512,  ...,  1.9984,  0.1828, -0.5664],
         [ 0.9360, -0.6398, -0.1149,  ...,  0.2647, -0.9080,  0.2136],
         ...,
         [-0.9472,  0.6716,  1.9491,  ...,  0.5262, -0.3791, -0.6060],
         [ 3.3216,  0.5934,  2.6454,  ...,  0.7024, -0.3119, -1.1473],
         [ 0.0977, -0.9421,  1.7860,  ...,  2.1599,  1.3780, -1.1123]]],
       grad_fn=<AddBackward0>)

In [9]:
omninet = OmninetCausal(
    dim = 512,                     # model dimension
    depth = 6,                     # depth
    dim_head = 64,                 # dimension per head
    heads = 8,                     # number of heads
    pool_layer_tokens_every = 3,   # key to this paper - every N layers, omni attend to all tokens of all layers
    attn_dropout = 0.1,            # attention dropout
    ff_dropout = 0.1               # feedforward dropout
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

omninet(x, mask = mask) # (1, 1024, 512)

tensor([[[ 3.5471,  6.6507,  4.3266,  ...,  4.1534,  0.9080,  3.8440],
         [ 6.0989,  4.7633,  4.2315,  ...,  2.0625,  0.6029,  1.1094],
         [ 6.5093,  5.0774,  3.3120,  ...,  5.8083,  0.0604,  2.0568],
         ...,
         [ 4.2883,  5.6745,  4.2724,  ...,  4.3588,  4.1432,  3.7385],
         [ 5.8884,  4.6082,  4.3106,  ...,  3.4019,  2.5031,  1.3954],
         [ 5.8154,  4.0144,  2.9318,  ...,  5.6129,  1.8431, -0.3573]]],
       grad_fn=<AddBackward0>)