# Decoder-only Transformer

In [1]:
# ! pip install lightning

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

In [3]:
# Vocabulary: every token the model knows, mapped to its integer id.
# Ids are assigned by position in the tuple below, starting at 0.
# NOTE(review): the key 'awosome"' (with the trailing double quote) looks like
# a typo for 'awesome', but the other cells look it up under this exact
# spelling, so it is kept unchanged here.
token_to_id = {
    token: index
    for index, token in enumerate(
        ('what', 'is', 'statquest', 'awosome"', '<EOS>')
    )
}

In [4]:
# Inverse lookup (integer id -> token string), used to turn the model's
# predicted ids back into readable tokens.
id_to_token = dict(zip(token_to_id.values(), token_to_id.keys()))

In [5]:
# Training prompts, one row per example, encoded as token ids.
# NOTE(review): the second prompt starts with 'awosome"' rather than
# 'statquest' -- confirm this is the intended training sentence.
inputs = torch.tensor([
    [token_to_id[token] for token in sentence]
    for sentence in (
        ('what', 'is', 'statquest', '<EOS>', 'awosome"'),
        ('awosome"', 'is', 'what', '<EOS>', 'awosome"'),
    )
])

# Targets: each row is the matching input row shifted left by one token,
# with a closing '<EOS>' appended, so position i holds the token that
# should follow input position i.
labels = torch.tensor([
    [token_to_id[token] for token in sentence]
    for sentence in (
        ('is', 'statquest', '<EOS>', 'awosome"', '<EOS>'),
        ('is', 'what', '<EOS>', 'awosome"', '<EOS>'),
    )
])

In [6]:
# Pair every input row with its label row.
dataset = TensorDataset(inputs, labels)
# One example per batch (the DataLoader default, spelled out here).
dataloader = DataLoader(dataset, batch_size=1)

In [7]:
# word embedding
# nn.Embedding()

In [8]:
# Position Encoding

In [9]:
class PositionEncoding(nn.Module):
    '''
    Adds fixed sinusoidal position encodings to word embeddings.

    For position `pos` and even embedding index `2i`:
        PE(pos, 2i)     = sin(pos / 10000^(2i / d_model))
        PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))
    '''

    def __init__(self, d_model, max_len):
        '''
        d_model: dimension of the model, number of word embedding values per token
        max_len: max number of tokens our SimpleGPT can process -- input and output combined
        '''

        super().__init__()

        # Column vector of positions 0 .. max_len - 1, shape (max_len, 1).
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)
        # Even embedding indices 0, 2, 4, ... (the "2i" of the formula).
        # For d_model=2 this is just tensor([0.]).
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()

        div_term = 1 / torch.tensor(10000.0) ** (embedding_index / d_model)
        # Broadcasts to shape (max_len, ceil(d_model / 2)).
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        # Even columns take the sine, odd columns the cosine.
        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd column than even columns,
        # so drop the surplus angle column instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A buffer is not a trainable parameter, but it follows the module
        # when it is moved to another device (e.g. a GPU).
        self.register_buffer('pe', pe)

    def forward(self, word_embeddings):
        '''
        word_embeddings: tensor of shape (seq_len, d_model) or
                         (batch, seq_len, d_model)

        Returns the embeddings with the first seq_len position encodings
        added; the output has the same shape as the input.
        '''
        # The sequence axis is always the second-to-last one, which for the
        # un-batched (seq_len, d_model) case is axis 0.
        seq_len = word_embeddings.size(-2)
        return word_embeddings + self.pe[:seq_len, :]

In [10]:
# Masked Self-Attention: (Word Embedding + Positional Encoding) * Weights = Query | Key | Value

In [11]:
class Attention(nn.Module):
    '''
    Single-head scaled dot-product attention with an optional mask:

        attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
    '''

    def __init__(self, d_model=2):
        '''
        d_model: number of embedding values per token; the query, key and
                 value projections all map d_model -> d_model.
        '''

        super().__init__()

        # nn.Linear without bias is just a learnable weight matrix.
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        # Address the token (row) and feature (column) axes from the end, so
        # the same code handles (seq_len, d_model) and
        # (batch, seq_len, d_model) inputs. For 2-D inputs these are the same
        # axes as 0 and 1.
        self.row_dim = -2
        self.col_dim = -1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        '''
        encodings_for_q / _k / _v: position-encoded embeddings, shape
            (seq_len, d_model) or (batch, seq_len, d_model)
        mask: optional boolean tensor broadcastable to the similarity matrix;
            True marks positions that must NOT be attended to.

        Returns the attention output, same shape as the value projection.
        '''

        q = self.W_q(encodings_for_q)  # (..., n, d_k)
        k = self.W_k(encodings_for_k)  # (..., n, d_k)
        v = self.W_v(encodings_for_v)  # (..., n, d_v)

        # Similarity of every query with every key: (..., n, n).
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        # Scale by sqrt(d_k) to keep the softmax inputs in a reasonable range.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # Masking prevents early tokens from cheating and looking at later
            # tokens. The mask is moved to the scores' device so a mask built
            # on the CPU still works when the model runs on a GPU.
            scaled_sims = scaled_sims.masked_fill(
                mask=mask.to(scaled_sims.device), value=-1e9
            )

        # Percentage of influence each token has on each other token.
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        attention_scores = torch.matmul(attention_percents, v)
        return attention_scores

In [12]:
class DecoderOnlyTransformer(L.LightningModule):
    '''
    Minimal decoder-only transformer: word embedding -> position encoding ->
    masked self-attention with a residual connection -> linear layer that
    produces one score per vocabulary token.
    '''

    def __init__(self, num_tokens=4, d_model=2, max_len=6):
        '''
        num_tokens: max number of tokens in the vocab
        d_model: number of embedding values per token
        max_len: max number of tokens (input and output combined) the
                 position encoding is precomputed for
        '''
        super().__init__()
        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)
        self.self_attention = Attention(d_model=d_model)
        self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, token_ids):
        '''
        token_ids: 1-D tensor of token ids for a single sequence.

        Returns a (seq_len, num_tokens) tensor of un-normalised scores: row i
        scores the token that should follow position i.
        '''
        word_embeddings = self.we(token_ids)
        position_encoded = self.pe(word_embeddings)

        # Causal mask: True above the diagonal, i.e. for every later token
        # that an earlier token must not look at. Built on the same device as
        # the input so masking also works when the model runs on a GPU.
        seq_len = token_ids.size(dim=0)
        mask = torch.tril(torch.ones((seq_len, seq_len), device=token_ids.device))
        mask = mask == 0

        self_attention_values = self.self_attention(position_encoded,
                                                    position_encoded,
                                                    position_encoded,
                                                    mask=mask)

        residual_connection_layers = position_encoded + self_attention_values
        fc_layer_output = self.fc_layer(residual_connection_layers)

        return fc_layer_output

    # The two hooks below must be methods of the LightningModule: defined as
    # free functions at module level, the Trainer never sees them and cannot
    # train the model.

    def configure_optimizers(self):
        '''Return the optimizer Lightning should use for training.'''
        return Adam(self.parameters(), lr=0.1)

    def training_step(self, batch, batch_idx):
        '''
        Compute the cross-entropy loss for one batch.

        Only the first example of the batch is used, so the DataLoader is
        expected to yield batches of size 1.
        '''
        input_tokens, labels = batch
        output = self(input_tokens[0])
        loss = self.loss(output, labels[0])
        return loss

In [14]:
# Size the vocabulary from the token table so the embedding and output
# layers always match the number of known tokens.
model = DecoderOnlyTransformer(
    num_tokens=len(token_to_id),
    d_model=2,
    max_len=6,
)

In [None]:
# Train for 30 passes over the training examples.
trainer = L.Trainer(max_epochs=30)
trainer.fit(model=model, train_dataloaders=dataloader)

In [15]:

# Prompt: "what is statquest <EOS>" -- the model should continue from here.
model_input = torch.tensor([token_to_id["what"],
                            token_to_id["is"],
                            token_to_id["statquest"],
                            token_to_id["<EOS>"]])
input_length = model_input.size(dim=0)

# Upper bound on prompt + generated tokens.
max_length = 6
eos_id = token_to_id["<EOS>"]

# Inference only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    # The last row of the output scores the token that follows the prompt.
    predictions = model(model_input)
    predicted_id = torch.argmax(predictions[-1, :]).unsqueeze(0)
    print(predicted_id)
    predicted_ids = predicted_id

    # Greedy decoding: feed each prediction back in until the model emits
    # <EOS> or the sequence reaches max_length tokens.
    for _ in range(input_length, max_length):
        if predicted_id.item() == eos_id:
            break

        model_input = torch.cat((model_input, predicted_id))

        predictions = model(model_input)
        predicted_id = torch.argmax(predictions[-1, :]).unsqueeze(0)
        predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens:\n")
# `token_id` rather than `id`, so the builtin id() is not shadowed for the
# rest of the session.
for token_id in predicted_ids:
    print("\t", id_to_token[token_id.item()])

tensor([1])
Predicted Tokens:

	 is
	 statquest
	 statquest
