Transformer From Scratch 
========================

Reading along with Dan Jurafsky and James H. Martin's [Speech and Language Processing](https://web.stanford.edu/~jurafsky/slp3/), I decided to follow Chapter 8 of their book to implement a Transformer in PyTorch. My goal is to have a working transformer that I can train on the guitar dataset. I know some linear algebra, the book gives essentially the entire algorithm in terms of linear algebra, and PyTorch provides abstractions for linear algebra that are convenient yet still informative. I had no reason not to pursue this project on top of what I originally proposed to do.

Attention Layer
---------------

At the heart of the Transformer is the **attention layer**. It is a mechanism that allows words (tokens) to gain contextual meaning from their surrounding words (tokens). It can have multiple **"heads"**, where each head can be thought of as a specialist who asks a particular set of questions about the data. For instance, one head could focus solely on grammar while another focuses on sentiment (even though that might not be exactly what happens under the hood).

Each head's job, then, is to ask the right kinds of *questions* to decide which of the previous words it has seen matter most to the current word. To do this, each head has three main components: the **Query**, **Key**, and **Value** weight matrices.
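As a rough sketch of what one head computes (scaled dot-product attention as SLP3 describes it): with $Q = XW^Q$, $K = XW^K$ and $V = XW^V$, a head outputs $\text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$, where $M$ is a mask of $0$ and $-\infty$ entries. The function below is a standalone illustration with a hypothetical name (`single_head_attention`) and random weights. It is not the `AttentionLayer` from `myTransformer`.

```python
import math
import torch

def single_head_attention(X, W_Q, W_K, W_V, mask=None):
    """X: (batch, N, model_dim); W_Q, W_K: (model_dim, key_dim); W_V: (model_dim, value_dim)."""
    Q = X @ W_Q  # queries: what each token is looking for
    K = X @ W_K  # keys: what each token offers for matching
    V = X @ W_V  # values: what each token contributes once selected
    scores = Q @ K.transpose(-2, -1) / math.sqrt(K.size(-1))  # (batch, N, N)
    if mask is not None:
        scores = scores + mask  # -inf entries get zero weight after the softmax
    weights = torch.softmax(scores, dim=-1)  # each row sums to 1
    return weights @ V, weights  # per token: a weighted sum of the values

torch.manual_seed(0)
batch, N, model_dim, key_dim = 2, 5, 8, 4
X = torch.rand(batch, N, model_dim)
W_Q, W_K, W_V = (torch.randn(model_dim, key_dim) for _ in range(3))
causal = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)  # hide future tokens
out, weights = single_head_attention(X, W_Q, W_K, W_V, mask=causal)
print(out.shape)  # torch.Size([2, 5, 4])
```

The returned `weights` make the "weighted sum" view explicit: row $i$ says how much token $i$ draws from each earlier token.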

<!--
    essentially, at the end of the day it is a weighted sum, but it's obviously a lot more complicated than that
    don't forget to write out the equations that I have referenced
    maybe throw in some pictures
    say something about how masking and softmax are used to determine which keys to focus on
    also explain how results from different heads are consolidated at the end
-->
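Results from multiple heads are usually consolidated by concatenating each head's output and projecting the result back to `model_dim` with an output matrix $W^O$. The hypothetical `MultiHeadAttentionSketch` below shows that step, assuming each head's value dimension equals `key_dim`. It is a sketch and not necessarily how `AttentionLayer` is implemented.

```python
import math
import torch
from torch import nn

class MultiHeadAttentionSketch(nn.Module):
    def __init__(self, model_dim, key_dim, num_heads):
        super().__init__()
        self.num_heads, self.key_dim = num_heads, key_dim
        # one projection per role; its output columns are split into heads afterwards
        self.W_Q = nn.Linear(model_dim, num_heads * key_dim, bias=False)
        self.W_K = nn.Linear(model_dim, num_heads * key_dim, bias=False)
        self.W_V = nn.Linear(model_dim, num_heads * key_dim, bias=False)
        # W_O maps the concatenated heads back to model_dim
        self.W_O = nn.Linear(num_heads * key_dim, model_dim, bias=False)

    def split(self, T):
        # (batch, N, heads * key_dim) -> (batch, heads, N, key_dim)
        b, n, _ = T.shape
        return T.view(b, n, self.num_heads, self.key_dim).transpose(1, 2)

    def forward(self, X, mask=None):
        Q, K, V = self.split(self.W_Q(X)), self.split(self.W_K(X)), self.split(self.W_V(X))
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.key_dim)  # (batch, heads, N, N)
        if mask is not None:
            scores = scores + mask  # an (N, N) mask broadcasts over batch and heads
        heads = torch.softmax(scores, dim=-1) @ V  # (batch, heads, N, key_dim)
        b, h, n, d = heads.shape
        concat = heads.transpose(1, 2).reshape(b, n, h * d)  # concatenate the heads
        return self.W_O(concat)  # (batch, N, model_dim)

X = torch.rand(10, 10, 24)
mha = MultiHeadAttentionSketch(model_dim=24, key_dim=3, num_heads=4)
print(mha(X).shape)  # torch.Size([10, 10, 24])
```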

In [1]:
from myTransformer import *
batch_size = 10
N = 10
model_dim = 24
num_heads = 4
key_dim = 3

M = 8
X = torch.rand((batch_size, N, model_dim)) # 10 sequences of N=10 tokens, each a model_dim=24 vector
Y = torch.rand((batch_size, M, model_dim)) # 10 sequences of M=8 tokens (stand-in for encoder output)
# causal mask: 0 where token i may attend to token j (j <= i), -inf for future tokens
mask = torch.tensor([[0 if i>= j else -torch.inf for j in range(N)] for i in range(N)])

multihead_attention = AttentionLayer(model_dim=model_dim, key_dim=key_dim, num_heads=num_heads)
multihead_attention(X, H_enc=Y, mask=mask).shape
#multihead_attention.to("cuda")
#multihead_attention(X.to("cuda"), Y.to("cuda"), mask=mask.to("cuda"))

torch.Size([10, 10, 24])
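The causal mask in the cell above can also be built without a Python loop. This sketch, using only standard `torch` calls, checks that `torch.triu` produces the same matrix: $-\infty$ strictly above the diagonal (future tokens) and $0$ on and below it.

```python
import torch

N = 10
# list-comprehension version, as in the cell above
mask_loop = torch.tensor([[0 if i >= j else float("-inf") for j in range(N)] for i in range(N)])
# vectorized version: keep -inf only strictly above the diagonal, zero elsewhere
mask_triu = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)
print(torch.equal(mask_loop, mask_triu))  # True
```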

In [2]:
encoder_block = TransformerBlock(N=N, model_dim=model_dim, key_dim=key_dim, hidden_dim=8, num_heads=num_heads)
decoder_block = TransformerBlock(N=N, model_dim=model_dim, key_dim=key_dim, hidden_dim=8, num_heads=num_heads, cross_attention=True)
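encoder_block(X, mask=mask)  # output not displayed; only the last line's value is shown below
decoder_block(X, Y)          # Y plays the role of the encoder output for cross-attention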

tensor([[[ 7.5296e-01,  3.2767e-01,  1.0773e+00,  ...,  0.0000e+00,
           6.2788e-01,  3.0886e-01],
         [ 3.5242e-01,  8.7099e-01, -7.0052e-02,  ...,  8.5288e-01,
           3.2177e-01,  7.5983e-01],
         [ 4.9739e-01,  6.1805e-01,  9.0899e-01,  ...,  1.4382e+00,
           8.5608e-01,  8.9060e-01],
         ...,
         [ 7.3464e-01,  5.5780e-01,  5.1116e-01,  ...,  1.3091e+00,
           9.7294e-01,  4.7608e-01],
         [ 9.2129e-01,  4.6460e-01,  5.4470e-01,  ...,  1.5370e+00,
           8.6947e-01,  3.8889e-01],
         [ 2.9361e-01,  6.3682e-01,  4.8986e-01,  ...,  1.1732e+00,
           7.7545e-01,  5.0968e-04]],

        [[ 1.1431e+00,  1.5871e+00,  5.3359e-01,  ...,  0.0000e+00,
           5.9994e-01,  3.5556e-01],
         [ 1.3322e+00,  1.5784e+00,  4.2434e-01,  ...,  1.0190e+00,
           2.0703e-01,  0.0000e+00],
         [ 9.2406e-01,  1.1009e+00,  6.6255e-01,  ...,  3.5642e-01,
           5.7891e-01, -1.5057e-01],
         ...,
         [ 0.0000e+00,  1
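For reference, here is one common way a transformer block is assembled, the pre-norm residual layout: layer norm, multi-head attention, residual add, then layer norm, feedforward network, residual add. It swaps in PyTorch's built-in `nn.MultiheadAttention` (so each head's dimension is `model_dim // num_heads` rather than a free `key_dim`), and `PreNormBlockSketch` is a hypothetical name. `TransformerBlock` in `myTransformer` may order these pieces differently.

```python
import torch
from torch import nn

class PreNormBlockSketch(nn.Module):
    def __init__(self, model_dim, num_heads, hidden_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(model_dim)
        self.attn = nn.MultiheadAttention(model_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(model_dim)
        self.ffn = nn.Sequential(
            nn.Linear(model_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, model_dim)
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, X, mask=None):
        T = self.norm1(X)
        attn_out, _ = self.attn(T, T, T, attn_mask=mask, need_weights=False)
        H = X + self.drop(attn_out)                    # residual around attention
        return H + self.drop(self.ffn(self.norm2(H)))  # residual around the FFN

X = torch.rand(10, 10, 24)
mask = torch.triu(torch.full((10, 10), float("-inf")), diagonal=1)
block = PreNormBlockSketch(model_dim=24, num_heads=4, hidden_dim=8)
block.eval()  # disable dropout for a deterministic forward pass
out = block(X, mask=mask)
print(out.shape)  # torch.Size([10, 10, 24])
```

The residual connections keep the output the same shape as the input, which is what lets blocks be stacked as in the next cell.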

In [3]:
stack = TransformerStack(N=N, model_dim=model_dim, key_dim=key_dim, hidden_dim=8, num_heads=num_heads, num_stack=9)
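stack.state_dict()  # returns the weights; not displayed since it isn't the last line
stack.train()       # training mode: dropout (if any) is active, likely explaining the exact zeros below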
stack(X)

tensor([[[ 1.2460,  1.5768, -2.8658,  ...,  4.0664, -0.4537, -0.6890],
         [ 2.7587,  3.6024, -0.0000,  ..., -2.0651, -0.2213,  2.5566],
         [ 0.1862,  3.2953,  2.2246,  ...,  0.0000, -0.2880,  0.0477],
         ...,
         [ 0.0000,  0.7855,  0.7146,  ...,  0.0902, -0.0000, -0.0000],
         [-0.1667, -0.0000,  1.0640,  ...,  3.4153,  0.4446, -0.3815],
         [ 0.3954, -0.3886, -0.8220,  ..., -0.1606, -0.7062,  0.0294]],

        [[-0.1389, -0.6492, -1.7427,  ...,  0.4945, -0.0881, -0.8141],
         [-0.0083,  3.5833, -0.1980,  ..., -0.2740,  1.6221,  2.7117],
         [-0.0079,  0.0000, -1.3811,  ...,  0.1921,  0.3176, -1.9613],
         ...,
         [-1.0546,  0.7826,  0.0911,  ...,  0.0000, -0.1618, -0.8852],
         [ 2.3896,  0.0000,  1.1581,  ...,  3.2611, -1.2631,  0.0000],
         [-0.6762, -0.2197, -1.0679,  ...,  3.4527,  0.5191,  0.2945]],

        [[-0.2005, -0.3512, -0.8756,  ...,  3.0672,  0.0000, -0.1345],
         [ 0.0070,  1.5368, -0.6318,  ..., -2
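The exact `0.0000` entries scattered through the outputs above are most likely dropout at work, assuming the blocks in `myTransformer` contain `nn.Dropout` layers: `stack.train()` enables dropout, which zeroes random activations and rescales the survivors by $1/(1-p)$. This small standalone check shows the difference between training and evaluation mode for a lone `nn.Dropout`.

```python
import torch
from torch import nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.1)
x = torch.ones(4, 6)

drop.train()
y_train = drop(x)  # some entries zeroed, survivors scaled by 1 / (1 - 0.1)
drop.eval()
y_eval = drop(x)   # identity in eval mode
print(y_train)
print(torch.equal(y_eval, x))  # True
```

So for inference or listening tests, calling `stack.eval()` turns dropout off; wrapping the call in `torch.no_grad()` additionally skips gradient bookkeeping.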

In [4]:
# adapted from https://pytorch-tutorials-preview.netlify.app/beginner/transformer_tutorial.html
# I don't completely understand positional encoding yet, but my intuition is that
# it is analogous to how binary digits encode numbers: lower bits flip more often
# than higher bits, and this is modeled by sinusoids of decreasing frequency.
# It also takes advantage of the trigonometric angle-addition formulas: PE(pos + k)
# is a fixed linear function (a rotation) of PE(pos), which supposedly helps the
# model figure out relative positions.
# https://medium.com/thedeephub/positional-encoding-explained-a-deep-dive-into-transformer-pe-65cfe8cfe10b
import math

import torch
from torch import nn

class PositionalEncoding(nn.Module):

    def __init__(self, model_dim: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(2i / model_dim) for each even dimension index 2i
        div_term = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000.0) / model_dim))
        # stored batch-first as (1, max_len, model_dim) to match the
        # (batch_size, N, model_dim) tensors used elsewhere in this notebook
        pe = torch.zeros(1, max_len, model_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)  # saved with the module, but not a trainable parameter

    def forward(self, X):
        """
        Arguments:
            X: Tensor, shape ``[batch_size, seq_len, model_dim]``
        """
        X = X + self.pe[:, :X.size(1)]  # add position encodings, broadcast over the batch
        return self.dropout(X)
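The encoding above sets $PE(pos, 2i) = \sin(pos \cdot \omega_i)$ and $PE(pos, 2i+1) = \cos(pos \cdot \omega_i)$ with $\omega_i = 10000^{-2i/d}$, where $d$ is `model_dim`. The angle-addition formulas then make each $(\sin, \cos)$ pair at position $pos + k$ a rotation by $k\omega_i$ of the pair at $pos$, a linear map that depends only on the offset $k$. This standalone check (small `model_dim`, `max_len` and `k` chosen arbitrarily) verifies that numerically.

```python
import math
import torch

model_dim, max_len, k = 16, 100, 5
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000.0) / model_dim))
pe = torch.zeros(max_len, model_dim)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# For each frequency w, the (sin, cos) pair at p + k is the pair at p rotated by w * k:
#   sin(w(p + k)) = sin(wp) cos(wk) + cos(wp) sin(wk)
#   cos(w(p + k)) = cos(wp) cos(wk) - sin(wp) sin(wk)
c, s = torch.cos(div_term * k), torch.sin(div_term * k)
sin_p, cos_p = pe[:-k, 0::2], pe[:-k, 1::2]
shifted = torch.zeros_like(pe[:-k])
shifted[:, 0::2] = sin_p * c + cos_p * s
shifted[:, 1::2] = cos_p * c - sin_p * s
print(torch.allclose(shifted, pe[k:], atol=1e-4))  # True
```

The rotation coefficients `c` and `s` depend only on `k`, never on the absolute position, which is the property the comment in the cell above alludes to.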

In [5]:
# What I still need to do to finish the encoder-decoder architecture:
# the only difference in the decoder is the cross-attention layer,
# which is much like the self-attention layer except that it uses both the final
# H of the encoder and that of the decoder for query-key matching; thus the decoder
# needs to take in the memory (output) from the encoder.
pos = PositionalEncoding(model_dim=model_dim)
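A minimal sketch of the cross-attention described in the comments above, assuming the standard formulation: queries come from the decoder's states while keys and values come from the encoder's output, so the score matrix is $N \times M$ rather than $N \times N$. `cross_attention` and the random weights are hypothetical; this is not the `H_enc` path of `AttentionLayer`, whose internals are not shown here. The shapes mirror `X` and `Y` from In [1].

```python
import math
import torch

def cross_attention(H_dec, H_enc, W_Q, W_K, W_V):
    """H_dec: (batch, N, model_dim) decoder states; H_enc: (batch, M, model_dim) encoder output."""
    Q = H_dec @ W_Q  # queries come from the decoder
    K = H_enc @ W_K  # keys come from the encoder output
    V = H_enc @ W_V  # values come from the encoder output
    scores = Q @ K.transpose(-2, -1) / math.sqrt(K.size(-1))  # (batch, N, M)
    # no causal mask here: every decoder position may look at every encoder position
    return torch.softmax(scores, dim=-1) @ V  # (batch, N, value_dim)

torch.manual_seed(0)
H_dec, H_enc = torch.rand(10, 10, 24), torch.rand(10, 8, 24)
W_Q, W_K, W_V = (torch.randn(24, 3) for _ in range(3))
out = cross_attention(H_dec, H_enc, W_Q, W_K, W_V)
print(out.shape)  # torch.Size([10, 10, 3])
```

The output has one row per decoder token, so it can be added back into the decoder's residual stream even though the encoder sequence has a different length.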


In [11]:
import random
import dac
from ICMTSMTGuitarData import *
from jupyter_audio_utils import *
import torch

model_path = dac.utils.download(model_type="44khz")
model = dac.DAC.load(model_path, weights_only=True).eval()

mono_data2 = ICMTSMTGuitarDataMono(DEFAULT_PEDAL_PROBS)
test_2 = mono_data2[random.randint(0, len(mono_data2)-1)] 

waveform, sr = test_2[0] 
play_audio(*test_2[0])

x = model.preprocess(waveform, sr)

with torch.no_grad():
    z, codes, latents, _, _ = model.encode(x.unsqueeze(dim=0))

audio_tokens = z.transpose(-2,-1) # (batch_size, seq_len, latent_dim): one token per DAC time step
N = audio_tokens.shape[-2] 
model_dim = 32

# project DAC's latents down to model_dim; the hope (untested) is that this distills the features that actually matter
linear = torch.nn.Linear(in_features=audio_tokens.shape[-1], out_features=model_dim)

key_dim = 256
hidden_dim = 512
num_heads=4
print(audio_tokens.shape)

compressed = linear(audio_tokens)
print(compressed.shape)
encoder = TransformerStack(N=N, model_dim=model_dim, key_dim=key_dim, hidden_dim=256, num_heads=num_heads, num_stack=1)
latent = encoder(compressed)

torch.Size([1, 173, 1024])
torch.Size([1, 173, 32])


In [15]:
reverse = torch.nn.Linear(in_features=model_dim, out_features=audio_tokens.shape[-1])

noise = reverse(latent).transpose(-1, -2) # back to (batch_size, latent_dim, seq_len), the layout decode expects
y = model.decode(noise)
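To sanity-check the shapes in this round trip without downloading the DAC weights, here is a sketch that replaces DAC's `z` with a random tensor of the shape printed above and swaps `TransformerStack` for PyTorch's `nn.TransformerEncoderLayer`. Both substitutions are assumptions made only for illustration.

```python
import torch
from torch import nn

latent_dim, seq_len, model_dim = 1024, 173, 32  # shapes printed in In [11]
z = torch.randn(1, latent_dim, seq_len)  # stand-in for DAC's z: (batch, latent_dim, seq_len)

down = nn.Linear(latent_dim, model_dim)  # plays the role of `linear`
encoder = nn.TransformerEncoderLayer(d_model=model_dim, nhead=4, dim_feedforward=256, batch_first=True)
up = nn.Linear(model_dim, latent_dim)  # plays the role of `reverse`

tokens = z.transpose(-2, -1)  # (1, 173, 1024): one token per time step
latent = encoder(down(tokens))  # (1, 173, 32)
z_hat = up(latent).transpose(-2, -1)  # (1, 1024, 173): same layout as z
print(z_hat.shape)  # torch.Size([1, 1024, 173])
```

Because `down`, the encoder, and `up` (like `linear`, `encoder`, and `reverse` above) are untrained, the decoded audio should be expected to sound like noise until these layers are trained.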


In [33]:
y_detached = y.detach().squeeze().unsqueeze(dim=0)
play_audio(y_detached, 44100)

In [None]:
from myTransformer import *

# blueprint
class EncoderDecoder(nn.Module):
    def __init__(
        self,
        N,
        model_dim,
        key_dim,
        encoder_mask,
        encoder_hidden_dim,
        encoder_num_stack,
        encoder_num_heads,
        decoder_hidden_dim,
        decoder_num_stack,
        decoder_num_heads,
        decoder_vocab
    ):
        super().__init__()  # must run before any submodule is assigned
        self.positional_encoder = PositionalEncoding(model_dim=model_dim)
        self.encoder_stack = TransformerStack(
            N=N,
            model_dim=model_dim,
            key_dim=key_dim,
            hidden_dim=encoder_hidden_dim,
            num_heads=encoder_num_heads,
            num_stack=encoder_num_stack
        )
        self.decoder_stack = TransformerStack(
            N=N,
            model_dim=model_dim,
            key_dim=key_dim,
            hidden_dim=decoder_hidden_dim,
            num_heads=decoder_num_heads,
            num_stack=decoder_num_stack,
            cross_attention=True
        )
        self.language_head = None  # not implemented yet
        self.decoder_vocab = decoder_vocab
        # a buffer moves with .to(device) and is saved in state_dict, but is not trained
        self.register_buffer("mask", encoder_mask)

    def forward(self, X):  # X: (batch_size, N, model_dim)
        X = self.positional_encoder(X)
        H1 = self.encoder_stack(X, mask=self.mask)
        # NOTE: a full encoder-decoder would feed the (shifted) target sequence
        # to the decoder here, rather than the encoder input X
        H2 = self.decoder_stack(X, H1)
        if self.language_head is None:
            return H2  # language head not implemented yet: return decoder states
        return self.language_head(H2)
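One way the missing `language_head` could look, assuming `decoder_vocab` is meant to be a vocabulary size: a linear layer mapping each decoder state to one logit per vocabulary entry, with a softmax turning the logits into a next-token distribution. `LanguageHeadSketch`, `vocab_size = 50`, and the random targets are all hypothetical.

```python
import torch
from torch import nn

class LanguageHeadSketch(nn.Module):
    """Hypothetical head: maps decoder states to logits over a vocabulary of size vocab_size."""
    def __init__(self, model_dim, vocab_size):
        super().__init__()
        self.proj = nn.Linear(model_dim, vocab_size)

    def forward(self, H):  # H: (batch, N, model_dim)
        return self.proj(H)  # logits: (batch, N, vocab_size)

torch.manual_seed(0)
head = LanguageHeadSketch(model_dim=24, vocab_size=50)
H2 = torch.rand(10, 10, 24)  # stand-in for the decoder output
logits = head(H2)
probs = torch.softmax(logits, dim=-1)  # next-token distribution at each position
targets = torch.randint(0, 50, (10, 10))  # made-up token ids, only to show the loss call
loss = nn.functional.cross_entropy(logits.reshape(-1, 50), targets.reshape(-1))
print(logits.shape, loss.item())
```

Returning raw logits and leaving the softmax to `cross_entropy` during training is the usual choice, since it is more numerically stable than taking the log of `probs`.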