In [53]:
import torch
import torch.nn as nn
from torch.nn import functional as F

SEED = 0  # fixed seed so random tensors and weight init are reproducible across runs
torch.manual_seed(SEED)

# Fall back to CPU so the notebook still runs (Restart & Run All) without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
obs = torch.rand([13, 1]).to(device)  # (13, 1) random tensor; presumably a placeholder observation



208

In [68]:
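# TODO (knowledge gap): the sinusoidal positional encoding below is not yet fully understood — revisit before relying on it.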
from einops import rearrange

class SinusoidalPosition(nn.Module):
    """Sinusoidal position embeddings over reversed positions (relative-position style).

    Positions are enumerated from ``seq_len - 1`` down to ``0``. Each position is
    multiplied by a bank of inverse frequencies, and the ``sin`` and ``cos`` of
    the products are concatenated along the last axis.

    Args:
        dim: Nominal embedding size. ``ceil(dim / min_timescale)`` frequencies are
            generated, so the output width is ``2 * ceil(dim / min_timescale)``.
            For an even ``dim`` with the default step of 2 this equals ``dim``;
            an odd ``dim`` such as 13 yields 14 columns.
        min_timescale: Step between successive frequency exponents.
        max_timescale: Base of the geometric frequency progression.
    """

    def __init__(self, dim, min_timescale = 2., max_timescale = 1e4):
        super().__init__()
        freqs = torch.arange(0, dim, min_timescale)
        inv_freqs = max_timescale ** (-freqs / dim)
        # Registered as a buffer so it follows the module through .to(device).
        self.register_buffer('inv_freqs', inv_freqs)

    def forward(self, seq_len):
        """Return a ``(seq_len, 2 * num_freqs)`` embedding on the buffer's device.

        Row 0 corresponds to position ``seq_len - 1`` and the last row to position 0.
        """
        # Build positions on the same device/dtype as ``inv_freqs``; creating them
        # on the CPU breaks the product once the module has been moved to a GPU.
        seq = torch.arange(seq_len - 1, -1, -1.,
                           device=self.inv_freqs.device, dtype=self.inv_freqs.dtype)
        # Outer product: (seq_len,) x (num_freqs,) -> (seq_len, num_freqs).
        sinusoidal_inp = torch.outer(seq, self.inv_freqs)
        return torch.cat((sinusoidal_inp.sin(), sinusoidal_inp.cos()), dim = -1)

pos_embedding = SinusoidalPosition(13)
embedding = pos_embedding(16)

# An odd dim (13) yields ceil(13 / 2) = 7 frequencies -> 14 sin/cos columns, not 13.
assert embedding.shape == (16, 14)

# To visualise the embedding (requires matplotlib):
# import matplotlib.pyplot as plt
# plt.plot(embedding.detach().cpu().numpy());

embedding.shape

torch.Size([16, 14])

In [57]:
class Head(nn.Module):
    """A single causal self-attention head.

    Args:
        head_size: Width of the key/query/value projections.
        seq_len: Maximum sequence length; sizes the lower-triangular causal mask.
        emb_size: Width of the input embeddings.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(
            self,
            head_size:int,
            seq_len  :int,
            emb_size :int,
            dropout  :float
            ):
        # nn.Module.__init__ must run before submodules are assigned;
        # ``super(Head).__init__()`` never invoked it.
        super().__init__()
        self.key   = nn.Linear(emb_size,head_size,bias=False)
        self.query = nn.Linear(emb_size,head_size,bias=False)
        self.value = nn.Linear(emb_size,head_size,bias=False)
        self.register_buffer('tril',torch.tril(torch.ones(seq_len,seq_len)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, embed, mask=None):
        """Attend over ``embed``.

        Args:
            embed: Tensor of shape ``(..., T, emb_size)`` with ``T <= seq_len``.
            mask: Optional tensor broadcastable to ``(..., T, T)``; positions where
                ``mask == 0`` are blocked. Defaults to the causal (lower-triangular)
                mask built in ``__init__``.

        Returns:
            Tensor of shape ``(..., T, head_size)``.
        """
        T = embed.shape[-2]  # sequence length
        key   = self.key(embed)
        query = self.query(embed)

        # Scaled dot-product scores (..., T, T), scaled by 1/sqrt(head_size).
        weights = query @ key.transpose(-2, -1) * key.shape[-1] ** -0.5
        if mask is None:
            mask = self.tril[:T, :T]
        weights = weights.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)
        return weights @ self.value(embed)

class MultiHeadAttention(nn.Module):
    """Several causal attention heads run in parallel, concatenated and projected.

    Args:
        num_heads: Number of heads; must divide ``emb_size`` exactly.
        emb_size: Width of the input and output embeddings.
        seq_len: Maximum sequence length forwarded to each head's causal mask.
        dropout: Dropout probability used inside each head and after the projection.
    """

    def __init__(
            self,
            num_heads:int,
            emb_size :int,
            seq_len  :int,
            dropout  :float
            ):
        # Must call nn.Module.__init__ before registering submodules.
        super().__init__()
        head_size = emb_size // num_heads
        assert (num_heads*head_size == emb_size), 'emb_size must be an integer multiple of num_heads'

        self.heads = nn.ModuleList(
            [Head(head_size,seq_len,emb_size,dropout) for _ in range(num_heads)]
        )
        self.proj    = nn.Linear(num_heads*head_size,emb_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self,embed):
        """Map ``(..., T, emb_size)`` to ``(..., T, emb_size)``."""
        # Concatenate head outputs along the feature axis (dim 0 would stack
        # them along the batch axis and break the projection).
        out = torch.cat([head(embed) for head in self.heads], dim=-1)
        return self.dropout(self.proj(out))

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Linear -> Dropout.

    Args:
        emb_size: Input and output width.
        ff_expansion: Hidden-width multiplier (hidden = emb_size * ff_expansion).
        dropout: Dropout probability on the output. The default of 0.0 makes the
            dropout layer an identity, matching the previous behaviour.
    """

    def __init__(
            self,
            emb_size    :int,
            ff_expansion:int = 4,
            dropout     :float = 0.0
            ):
        # Must call nn.Module.__init__ before registering submodules.
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_size,emb_size*ff_expansion),
            nn.ReLU(),
            nn.Linear(emb_size*ff_expansion,emb_size),
            nn.Dropout(dropout),  # parameter-free; state_dict keys are unchanged
        )

    def forward(self,input):
        """Apply the MLP independently at every position of ``input``."""
        return self.net(input)

class DecoderBlock(nn.Module):
    """Pre-norm transformer decoder block.

    Computes ``x = x + attention(LN1(x))`` followed by ``x = x + ffnet(LN2(x))``.

    Args:
        num_heads: Number of attention heads.
        seq_len: Maximum sequence length for the causal mask.
        emb_size: Embedding width.
        ff_expansion: Feed-forward hidden-width multiplier.
        dropout: Dropout probability used by the attention sub-layer.
    """

    def __init__(
            self,
            num_heads   :int,
            seq_len     :int,
            emb_size    :int,
            ff_expansion:int,
            dropout     :float
            ):
        # Must call nn.Module.__init__ before registering submodules.
        super().__init__()
        self.attention = MultiHeadAttention(num_heads,emb_size,seq_len,dropout)
        self.ffnet   = FeedForward(emb_size,ff_expansion)
        # LayerNorm's second positional argument is ``eps``; passing emb_size
        # there made eps huge and effectively disabled the normalisation.
        self.lnorm_1 = nn.LayerNorm(emb_size)
        self.lnorm_2 = nn.LayerNorm(emb_size)

    def forward(self,embed):
        """Map ``(..., T, emb_size)`` to the same shape."""
        # Both sub-layers add onto the residual stream, so the attention output
        # is kept in the block's result rather than only reaching it via ffnet.
        embed = embed + self.attention(self.lnorm_1(embed))
        return embed + self.ffnet(self.lnorm_2(embed))
    
class Decoder(nn.Module):
    """Decoder-only transformer stack.

    Structure: token embedding + learned position embedding, ``num_layers``
    DecoderBlocks, a final LayerNorm, and a linear output head.

    Args:
        num_layers: Number of stacked DecoderBlocks.
        num_heads: Attention heads per block.
        seq_len: Used as the token-embedding table size, the output width of
            ``fc_out``, and the causal-mask size passed to each block.
            NOTE(review): this looks like it doubles as a vocabulary size — confirm.
        obs_len: Number of learned positions; inputs longer than this fail.
        emb_size: Embedding width.
        ff_expansion: Feed-forward hidden-width multiplier.
        dropout: Dropout probability inside the blocks.
    """

    def __init__(
            self,
            num_layers  :int,
            num_heads   :int,
            seq_len     :int,
            obs_len     :int,
            emb_size    :int,
            ff_expansion:int,
            dropout     :float
            ):
        # Must call nn.Module.__init__ before registering submodules.
        super().__init__()
        self.embedding = nn.Embedding(seq_len,emb_size)
        self.pos_embedding = nn.Embedding(obs_len,emb_size)
        self.decoders = nn.Sequential(
            *[DecoderBlock(num_heads,seq_len,emb_size,ff_expansion,dropout) for _ in range(num_layers)]
        )
        self.lnorm  = nn.LayerNorm(emb_size)
        self.fc_out = nn.Linear(emb_size,seq_len)

    def forward(self,x):
        """Run the decoder on integer indices ``x``.

        Args:
            x: Integer indices whose dim 1 is the sequence axis (presumably
                ``(B, T)`` — confirm). ``T`` must not exceed ``obs_len``, nor the
                blocks' causal-mask size.

        Returns:
            The ``fc_out`` output, whose last dimension is ``seq_len``.
        """
        embed = self.embedding(x)
        # Positions 0..T-1 as integer indices on the input's device; random
        # floats are not valid nn.Embedding indices.
        positions = torch.arange(x.shape[1], device=x.device)
        embed = embed + self.pos_embedding(positions)
        out = self.decoders(embed)
        return self.fc_out(self.lnorm(out))


