In [None]:
import torch

In [None]:
class QKVAttention(torch.nn.Module):
    """Scaled dot-product attention over inputs that are already split into heads.

    Expects ``q`` shaped (batch, q_len, heads, head_dim) and ``k``/``v`` shaped
    (batch, kv_len, heads, head_dim). Each head attends independently over the
    index dimension. The heads are then merged back into one channel vector per
    query, giving an output of shape (batch, q_len, q_channels).

    Args:
        scale: multiplier applied to the raw q.k scores before the softmax.
        n: expected query length. It is kept as an attribute for backward
            compatibility; the actual query length is read from ``q``.
        q_channels: merged output channels. Must equal heads * head_dim, or the
            final reshape raises.
    """

    def __init__(self, scale, n, q_channels):
        super().__init__()
        self.softmax = torch.nn.Softmax(dim=-1)
        self.scale = scale
        self.n = n
        self.q_channels = q_channels

    def forward(self, q, k, v):
        # Read batch and query length from q rather than trusting self.n.
        # Reshaping to a fixed n silently folded extra queries into the batch
        # dimension whenever q_len != n.
        batch, q_len = q.shape[0], q.shape[1]

        # (B, L, H, D) -> (B, H, L, D): heads act as an extra batch dimension,
        # so each head attends over the index dimension separately.
        q = q.permute([0, 2, 1, 3])
        k = k.permute([0, 2, 1, 3])
        v = v.permute([0, 2, 1, 3])

        # (B, H, Lq, D) @ (B, H, D, Lkv) -> (B, H, Lq, Lkv); softmax over Lkv.
        scores = torch.matmul(q, k.transpose(-2, -1))
        weights = self.softmax(scores * self.scale)

        # (B, H, Lq, Lkv) @ (B, H, Lkv, D) -> (B, H, Lq, D)
        attention = torch.matmul(weights, v)

        # Back to (B, Lq, H, D) so each query index keeps its own H*D channels when
        # the heads are merged.
        attention = attention.permute([0, 2, 1, 3])
        return attention.reshape([batch, q_len, self.q_channels])

In [None]:
# Smoke test: 3 heads x 2 channels per head are merged into 6 channels per query.
attn = QKVAttention(0.5, n=2, q_channels=6)
q = torch.randn([1, 2, 3, 2])
k = torch.randn([1, 10, 3, 2])
v = torch.randn([1, 10, 3, 2])
qkv_out = attn(q, k, v)
assert qkv_out.shape == (1, 2, 6), qkv_out.shape
qkv_out.shape

In [None]:
class MLP(torch.nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear.

    The hidden layer is ``wide_factor`` times wider than the input. The output
    has ``num_channels`` features again, so the block can sit on a residual path.

    Args:
        num_channels: size of the last dimension of the input and the output.
        wide_factor: multiplier for the hidden layer width.
    """

    def __init__(self, num_channels, wide_factor):
        super().__init__()
        hidden = num_channels * wide_factor
        # Both projections keep their bias; drop it here if an experiment needs to.
        self.lin_in = torch.nn.Linear(num_channels, hidden)
        self.lin_out = torch.nn.Linear(hidden, num_channels)
        self.gelu = torch.nn.GELU()

    def forward(self, x):
        return self.lin_out(self.gelu(self.lin_in(x)))

In [None]:
# Smoke test: the MLP maps 32 channels back to 32 channels.
in_val = torch.randn([10, 32])
mlp = MLP(32, 4)
mlp_out = mlp(in_val)
assert mlp_out.shape == (10, 32), mlp_out.shape
mlp_out.shape

In [None]:
class Attention(torch.nn.Module):
    """Multi-head attention with learned query/key/value/output projections.

    Queries arrive with ``q_channels`` features and keys/values with
    ``kv_channels`` features. All three are projected to ``q_channels`` and split
    into ``num_heads`` heads of ``q_channels // num_heads`` channels each.

    Args:
        num_heads: number of attention heads; must divide ``q_channels``.
        n: passed through to ``QKVAttention``.
        q_channels: feature size of the query input and of the output.
        kv_channels: feature size of the key and value inputs.

    Raises:
        ValueError: if ``q_channels`` is not divisible by ``num_heads``.
    """

    def __init__(self, num_heads, n, q_channels, kv_channels):
        super().__init__()
        if q_channels % num_heads != 0:
            raise ValueError('Bro, your damn channels don\'t work with your head count!', q_channels, num_heads)
        # kv_channels need not be divisible by num_heads: k/v are projected to
        # q_channels before being split into heads.
        self.num_heads = num_heads
        self.q_head_channels = q_channels // num_heads

        self.q_trans = torch.nn.Linear(q_channels, q_channels)
        self.k_trans = torch.nn.Linear(kv_channels, q_channels)
        self.v_trans = torch.nn.Linear(kv_channels, q_channels)
        self.out_trans = torch.nn.Linear(q_channels, q_channels)
        # Scores are computed per head, so the softmax temperature is
        # 1/sqrt(channels per head), not 1/sqrt(total channels).
        self.qkvAttn = QKVAttention(self.q_head_channels ** -0.5, n, q_channels)

    def _split_heads(self, x):
        """Reshape (batch, length, q_channels) into (batch, length, num_heads, q_head_channels)."""
        return x.reshape([-1, x.shape[1], self.num_heads, self.q_head_channels])

    def forward(self, q_in, k_in, v_in):
        """Attend from ``q_in`` to ``k_in``/``v_in``.

        Assumes 3-D inputs shaped (batch, length, channels); the second dimension
        is read as the sequence length. Returns (batch, q_len, q_channels).
        """
        q = self._split_heads(self.q_trans(q_in))
        k = self._split_heads(self.k_trans(k_in))
        v = self._split_heads(self.v_trans(v_in))

        attn = self.qkvAttn(q, k, v)
        return self.out_trans(attn)

In [None]:
class SelfAttention(torch.nn.Module):
    """Pre-norm self-attention block: x + Attn(LN(x)), then x + MLP(LN(x)).

    Args:
        n: passed through to ``Attention``.
        num_channels: feature size of the input and output.
        heads: number of attention heads; must divide ``num_channels``.
        wide_factor: hidden-width multiplier of the MLP.
        p_dropout: dropout probability applied to the attention output and the
            MLP output before each residual add (active only in training mode).
    """

    def __init__(self, n, num_channels, heads, wide_factor, p_dropout=0.1):
        super().__init__()
        self.in_norm = torch.nn.LayerNorm(num_channels)
        self.mlp_norm = torch.nn.LayerNorm(num_channels)
        self.attn = Attention(heads, n, num_channels, num_channels)
        self.mlp = MLP(num_channels, wide_factor)
        self.dropout = torch.nn.Dropout(p=p_dropout)

    def forward(self, in_val):
        normed = self.in_norm(in_val)
        attn = self.dropout(self.attn(normed, normed, normed))
        # Residual adds are out-of-place. `x = in_val; x += attn` overwrote the
        # caller's tensor, which is also the input in_norm needs intact for backward.
        x = in_val + attn

        mlp = self.dropout(self.mlp(self.mlp_norm(x)))
        return x + mlp

In [None]:
# Smoke test: a self-attention block preserves the input shape.
selfattn = SelfAttention(3, 4, 2, 4)
garbo = torch.randn([1, 3, 4])
selfattn_out = selfattn(garbo)
assert selfattn_out.shape == garbo.shape, selfattn_out.shape
selfattn_out.shape

In [129]:
class CrossAttention(torch.nn.Module):
    """Pre-norm cross-attention block: queries attend to a separate key/value sequence.

    Computes x = q + Attn(LN(q), LN(kv), LN(kv)), then x = x + MLP(LN(x)).
    The residual connections keep the output in q's shape.

    Args:
        n: passed through to ``Attention``.
        q_channels: feature size of q and of the output.
        kv_channels: feature size of kv.
        heads: number of attention heads; must divide ``q_channels``.
        wide_factor: hidden-width multiplier of the MLP.
        p_dropout: dropout probability applied to the attention output and the
            MLP output before each residual add (active only in training mode).
    """

    def __init__(self, n, q_channels, kv_channels, heads, wide_factor, p_dropout=0.1):
        super().__init__()
        self.q_norm = torch.nn.LayerNorm(q_channels)
        self.kv_norm = torch.nn.LayerNorm(kv_channels)
        self.mlp_norm = torch.nn.LayerNorm(q_channels)
        self.attn = Attention(heads, n, q_channels, kv_channels)
        self.mlp = MLP(q_channels, wide_factor)
        self.dropout = torch.nn.Dropout(p=p_dropout)

    def forward(self, q_kv):
        """Apply the block to a ``(q, kv)`` pair.

        Takes a single tuple rather than two arguments so the block can be the
        first stage of a ``torch.nn.Sequential`` (as ``PerceiverEncoder`` uses it).
        """
        q, kv = q_kv
        q_normed = self.q_norm(q)
        kv_normed = self.kv_norm(kv)
        attn = self.dropout(self.attn(q_normed, kv_normed, kv_normed))
        # Out-of-place residual: `x = q; x += attn` wrote into the caller's q,
        # which q_norm also needs intact for backward.
        x = q + attn

        mlp = self.dropout(self.mlp(self.mlp_norm(x)))
        return x + mlp

In [131]:
# Smoke test: cross-attention output follows the query shape, not the kv shape.
q = torch.randn([1, 5, 4])
kv = torch.randn([1, 10, 20])
cross = CrossAttention(5, 4, 20, 4, 4)
cross_out = cross((q, kv))
assert cross_out.shape == (1, 5, 4), cross_out.shape
cross_out.shape

torch.Size([1, 5, 4])

In [141]:
class PerceiverEncoder(torch.nn.Module):
    """Cross-attention into ``kv`` followed by latent self-attention blocks, applied repeatedly.

    The same block (one CrossAttention plus ``latent_count`` SelfAttention
    layers) is applied ``repeat_count`` times with shared weights. Each pass
    feeds the previous pass's latent output back in as the query and
    re-attends to the same ``kv``.

    Args:
        n: passed through to the attention blocks.
        q_channels: channels of the latent/query array.
        kv_channels: channels of the input (key/value) array.
        heads: number of attention heads; must divide ``q_channels``.
        wide_factor: MLP hidden-width multiplier.
        latent_count: number of SelfAttention blocks after the cross-attention.
        repeat_count: number of passes through the shared block; must be >= 1.
        p_dropout: dropout probability for every sub-block.

    Raises:
        ValueError: if ``repeat_count`` is less than 1.
    """

    def __init__(self, n, q_channels, kv_channels, heads, wide_factor, latent_count, repeat_count=1, p_dropout=0.1):
        super().__init__()
        if repeat_count < 1:
            raise ValueError('repeat_count must be >= 1', repeat_count)
        self.repeat_count = repeat_count

        latent_blocks = [SelfAttention(n, q_channels, heads, wide_factor, p_dropout) for _ in range(latent_count)]
        self.block = torch.nn.Sequential(
            CrossAttention(n, q_channels, kv_channels, heads, wide_factor, p_dropout),
            *latent_blocks,
        )

    def forward(self, q, kv):
        # Carry the latent state through every pass and return the final one.
        # Previously the loop result was discarded and a fresh pass on q was
        # returned instead.
        x = q
        for _ in range(self.repeat_count):
            x = self.block((x, kv))
        return x

In [147]:
# Smoke test: the encoder returns latents shaped like its query input.
encoder = PerceiverEncoder(n=10, q_channels=32, kv_channels=64, heads=4, wide_factor=4, latent_count=6, repeat_count=5)
q = torch.randn([6, 10, 32])
kv = torch.randn([6, 20, 64])
enc_out = encoder(q, kv)
assert enc_out.shape == (6, 10, 32), enc_out.shape
enc_out.shape

torch.Size([6, 10, 32])

In [149]:
class PerceiverInternal(torch.nn.Module):
    """Perceiver encoder plus a cross-attention decoder driven by learned output queries.

    A learned latent array of shape (n, q_channels) queries the input through
    ``PerceiverEncoder``. A learned output-query array of shape ``q_out_dim``
    then cross-attends to the encoded latents.

    Args:
        n: number of latent vectors (index length of the latent array).
        q_channels: channels per latent vector.
        kv_channels: channels per input element.
        heads: attention heads; must divide both ``q_channels`` and ``q_out_dim[1]``.
        wide_factor: MLP hidden-width multiplier.
        latent_count: number of latent SelfAttention blocks in the encoder.
        q_out_dim: ``(number of output queries, channels per output query)``.
        repeat_count: passes through the encoder block.
        p_dropout: dropout probability for all attention/MLP blocks.
    """

    def __init__(self, n, q_channels, kv_channels, heads, wide_factor, latent_count, q_out_dim, repeat_count=1, p_dropout=0.1):
        super().__init__()
        self.encoder = PerceiverEncoder(n, q_channels, kv_channels, heads, wide_factor, latent_count, repeat_count, p_dropout)

        out_len, out_channels = q_out_dim
        q_out = torch.zeros(out_len, out_channels)
        torch.nn.init.xavier_normal_(q_out)
        self.q_out = torch.nn.Parameter(q_out)

        # The decoder's queries are q_out, so its channel count is the last entry
        # of q_out_dim (LayerNorm acts on the last dim) and its index length is
        # out_len, not the latent count n.
        self.out_cross = CrossAttention(out_len, q_channels=out_channels, kv_channels=q_channels, heads=heads, wide_factor=wide_factor, p_dropout=p_dropout)

        # Learned latent array used as the encoder's query; the input is only
        # attended to as keys/values.
        latents = torch.zeros(n, q_channels)
        torch.nn.init.xavier_normal_(latents)
        self.latents = torch.nn.Parameter(latents)

    def forward(self, in_val):
        """Encode ``in_val`` into latents, then decode with the learned output queries.

        Assumes ``in_val`` is (batch, length, kv_channels); dim 0 is read as the
        batch size. Returns (batch, q_out_dim[0], q_out_dim[1]).
        """
        batch = in_val.shape[0]
        # repeat (rather than expand) gives each sample its own storage, so the
        # learned parameters are never aliased by tensors downstream blocks write to.
        latents = self.latents.unsqueeze(0).repeat(batch, 1, 1)
        x = self.encoder(latents, in_val)

        q_out = self.q_out.unsqueeze(0).repeat(batch, 1, 1)
        return self.out_cross((q_out, x))

In [None]:
percInternal = PerceiverInternal(n=10, q_channels=32, kv_channels=64, heads=4, wide_factor=4, latent_count=3, q_out_dim=(37,42)