In [56]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from einops import rearrange

In [57]:
class PixelEmbedding(nn.Module):
    """Turn every pixel of an N x N grid into a token embedding with a learned position.

    Each pixel's channel vector is linearly projected to ``embed_dim`` and a
    learnable positional embedding (one row per pixel) is added.

    Args:
        N: Side length of the square grid; the resulting sequence length is ``N * N``.
        embed_dim: Size of each token embedding.
        in_channels: Number of channels per pixel. Defaults to 2, the value that
            was previously hard-coded.
    """
    def __init__(self, N, embed_dim = 64, in_channels = 2):
        super().__init__()

        num_pixels = N * N

        self.projection = nn.Linear(in_channels, embed_dim)
        # One learnable embedding row per pixel position, initialised from N(0, 1).
        self.positional_embeddings = nn.Parameter(torch.randn(num_pixels, embed_dim), requires_grad=True)

    def forward(self, x):
        """Embed a batch of grids.

        Args:
            x: Tensor of shape (Batch, in_channels, N, N).

        Returns:
            Tensor of shape (Batch, N*N, embed_dim).
        """
        # (Batch, C, N, N) -> (Batch, N*N, C): one token per pixel, row-major order.
        x = rearrange(x, 'b c h w -> b (h w) c')

        x = self.projection(x)

        # Add the position embedding (broadcast over the batch axis). Done
        # out-of-place so no intermediate tensor is mutated.
        x = x + self.positional_embeddings

        return x

In [58]:
class MultiHeadAttention(nn.Module):
    """Unmasked multi-head scaled dot-product self-attention.

    Args:
        emb_dim: Size of the input (and output) token embeddings.
        d_model: Total width of the key/query/value projections, split evenly
            across the heads.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """
    def __init__(self, emb_dim = 64, d_model = 64, num_heads = 8, dropout = 0.1):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.d_model = d_model

        self.key_map = nn.Linear(emb_dim, d_model)
        self.query_map = nn.Linear(emb_dim, d_model)
        self.value_map = nn.Linear(emb_dim, d_model)

        # Per-head feature size.
        self.d_k = d_model // num_heads

        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(d_model, emb_dim)

    def forward(self, x):
        """Apply self-attention (no masking).

        Args:
            x: Tensor of shape (batch, n, emb_dim).

        Returns:
            Tensor of shape (batch, n, emb_dim).
        """
        # Split keys, queries and values into heads:
        # (batch, n, h*d) -> (batch, h, n, d) with h = num_heads, d = d_k.
        key = rearrange(self.key_map(x), 'b n (h d) -> b h n d', h = self.num_heads)
        query = rearrange(self.query_map(x), 'b n (h d) -> b h n d', h = self.num_heads)
        value = rearrange(self.value_map(x), 'b n (h d) -> b h n d', h = self.num_heads)

        # Dot product of every query with every key: (batch, heads, query_len, key_len).
        energy = torch.einsum('bhqd, bhkd -> bhqk', query, key)

        # Scale the logits BEFORE the softmax. Dividing after the softmax would
        # leave the logits untempered and make each row sum to 1/sqrt(d_k).
        scaling = self.d_k ** 0.5
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)

        # Weighted sum of the values over the key axis.
        out = torch.einsum('bhql, bhlv -> bhqv', att, value)

        # Concatenate the heads back together: (batch, h, n, d) -> (batch, n, h*d).
        out = rearrange(out, 'b h n d -> b n (h d)')

        out = self.projection(out)

        return out

In [59]:
class FeedForwardBlock(nn.Sequential):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_dim: Input and output feature size.
        expansion: Multiplier giving the hidden width (``expansion * emb_dim``).
        dropout: Dropout probability applied after the activation.
    """
    def __init__(self, emb_dim = 64, expansion = 4, dropout = 0.1):
        hidden_dim = expansion * emb_dim
        widen = nn.Linear(emb_dim, hidden_dim)
        activation = nn.GELU()
        regularise = nn.Dropout(dropout)
        narrow = nn.Linear(hidden_dim, emb_dim)
        # Positional registration keeps the sub-module indices 0..3.
        super().__init__(widen, activation, regularise, narrow)

In [60]:
class Transformer_Encoder_Block(nn.Module):
    """Encoder block: self-attention then feed-forward, each wrapped in a residual connection.

    Every sub-layer receives a layer-normalised copy of its input, and its
    output passes through dropout before being added back to that input.

    Args:
        emb_dim: Token embedding size.
        d_model: Width of the attention projections.
        num_heads: Number of attention heads.
        expansion: Hidden-width multiplier of the feed-forward sub-layer.
        dropout: Dropout probability used throughout the block.
    """
    def __init__(self, emb_dim = 64, d_model = 64, num_heads = 8, expansion = 4, dropout = 0.1):
        super().__init__()
        # NOTE(review): this single LayerNorm instance serves both sub-layers,
        # so they share the same affine parameters — confirm this is intended.
        self.layernorm = nn.LayerNorm(emb_dim)
        self.multiheadattention = MultiHeadAttention(
            emb_dim = emb_dim,
            d_model = d_model,
            num_heads = num_heads,
            dropout = dropout,
        )
        self.feedforward = FeedForwardBlock(
            emb_dim = emb_dim,
            expansion = expansion,
            dropout = dropout,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the two residual sub-layers and return a tensor shaped like ``x``."""
        # Attention sub-layer with skip connection.
        attended = self.dropout(self.multiheadattention(self.layernorm(x)))
        after_attention = attended + x

        # Feed-forward sub-layer with skip connection.
        transformed = self.dropout(self.feedforward(self.layernorm(after_attention)))
        return transformed + after_attention

In [61]:
class TransformerEncoder(nn.Sequential):
    """A stack of ``depth`` encoder blocks applied one after another.

    Args:
        depth: Number of ``Transformer_Encoder_Block`` layers to stack.
        **kwargs: Forwarded unchanged to every ``Transformer_Encoder_Block``.
    """
    def __init__(self, depth = 4, **kwargs):
        blocks = []
        for _ in range(depth):
            blocks.append(Transformer_Encoder_Block(**kwargs))
        super().__init__(*blocks)

In [62]:
class ATnet(nn.Module):
    """Pixel-wise transformer that predicts one value in (0, 1) per grid cell.

    Pipeline: pixel embedding -> transformer encoder -> 1x1 ``Conv1d`` that
    reduces each token to a single channel -> sigmoid.

    Args:
        N: Side length of the square input grid.
        emb_dim: Token embedding size.
        depth: Number of encoder blocks.
        d_model: Width of the attention projections.
        num_heads: Number of attention heads.
        expansion: Hidden-width multiplier of the feed-forward sub-layers.
        dropout: Dropout probability used inside the encoder.
    """
    def __init__(self, N = 10, emb_dim = 64, depth = 4, d_model = 64, num_heads = 8, expansion = 4, dropout = 0.1):

        super().__init__()

        self.pixelembedding = PixelEmbedding(N, emb_dim)
        self.transformerencoder = TransformerEncoder(
            depth = depth,
            emb_dim = emb_dim,
            d_model = d_model,
            num_heads = num_heads,
            expansion = expansion,
            dropout = dropout,
        )
        # Kernel size 1: an independent linear read-out of every token.
        self.conv = nn.Conv1d(emb_dim, 1, 1)
        self.output = nn.Sigmoid()

    def forward(self, x):
        """Map a batch of grids to per-pixel outputs.

        Args:
            x: Tensor of shape (Batch, 2, N, N).

        Returns:
            Tensor of shape (Batch, N*N) with values in (0, 1).
        """
        x = self.pixelembedding(x)
        x = self.transformerencoder(x)
        # (Batch, N*N, C) --> (Batch, C, N*N), the layout Conv1d expects.
        x = rearrange(x, 'b l c -> b c l')
        x = self.conv(x)
        # Drop only the singleton channel axis; a bare squeeze() would also
        # remove the batch axis when Batch == 1.
        x = x.squeeze(1)
        x = self.output(x)

        return x
