In [45]:
import torch
import torch.nn as nn

In [46]:
def batch_loader(raw_dataset, context_length, batch_size):
    """Draw a random batch of (input, target) token windows from a text.

    The text is tokenized by whitespace. Each target window is the matching
    input window shifted one token to the right.

    Args:
        raw_dataset: Text to sample from; split on whitespace into tokens.
        context_length: Number of tokens in every window.
        batch_size: Number of windows to draw (start positions may repeat).

    Returns:
        A tuple ``(X, Y)`` of two lists, each holding ``batch_size`` lists of
        ``context_length`` string tokens.
    """
    tokens = raw_dataset.split()
    # Exclusive upper bound for start positions: leaves room for the target
    # window, which ends one token later than the input window.
    start_limit = len(tokens) - context_length
    starts = torch.randint(low=0, high=start_limit, size=(batch_size,)).tolist()
    X = [tokens[start:start + context_length] for start in starts]
    Y = [tokens[start + 1:start + 1 + context_length] for start in starts]
    return X, Y

In [47]:
# Smoke-test batch_loader on a tiny eight-word corpus.
raw_dataset = "Hello darkness my old friend how are you"
context_length = 4
batch_size = 4
# No seed is set in this cell, so the sampled windows depend on the current
# state of torch's global random number generator.
X, Y = batch_loader(raw_dataset, context_length, batch_size)
# Each row of Y is the matching row of X shifted one token to the right.
print("X = ", X)
print("Y = ", Y )

X =  [['my', 'old', 'friend', 'how'], ['old', 'friend', 'how', 'are'], ['old', 'friend', 'how', 'are'], ['Hello', 'darkness', 'my', 'old']]
Y =  [['old', 'friend', 'how', 'are'], ['friend', 'how', 'are', 'you'], ['friend', 'how', 'are', 'you'], ['darkness', 'my', 'old', 'friend']]


In [97]:
class GPT(nn.Module):
    """Decoder-only transformer mapping token ids to next-token probabilities.

    The model adds learned token and position embeddings, runs them through
    ``num_blocks`` pre-norm transformer blocks (causal multi-head
    self-attention followed by a feed-forward network, each wrapped in a skip
    connection), applies a final LayerNorm and projects to the vocabulary.

    Reproducibility note: ``torch.manual_seed(0)`` is called in every
    constructor and in several ``forward`` methods. This makes results
    repeatable, but it also means that
      * every attention head and every transformer block starts from the same
        initial weights, because each one is built right after a reseed, and
      * every forward pass resets torch's *global* random number generator,
        which affects any other code drawing random numbers afterwards.

    Args:
        vocab_size: Number of distinct token ids.
        context_length: Maximum sequence length (size of the position table).
        model_dim: Width of the embeddings and of the residual stream.
        num_blocks: Number of stacked transformer blocks.
        num_heads: Attention heads per block; must divide ``model_dim``.

    Raises:
        ValueError: If ``num_blocks > 0`` and ``model_dim`` is not a positive
            multiple of ``num_heads``.
    """

    def __init__(self, vocab_size: int, context_length: int, model_dim: int, num_blocks: int, num_heads: int):
        super().__init__()
        torch.manual_seed(0)
        self.word_embeddings = nn.Embedding(vocab_size, model_dim)
        self.position_embeddings = nn.Embedding(context_length, model_dim)
        self.transformer_blocks = nn.Sequential()
        for _ in range(num_blocks):
            self.transformer_blocks.append(self.TransformerBlock(model_dim, num_heads))
        self.final_norm = nn.LayerNorm(model_dim)
        self.vocab_projection = nn.Linear(model_dim, vocab_size)

    def forward(self, context):
        """Return next-token probabilities for every position of ``context``.

        Args:
            context: Integer tensor of token ids, indexed as (batch, time).
                The time dimension must not exceed the ``context_length``
                the model was built with.

        Returns:
            Tensor of shape (batch, time, vocab_size) holding softmax
            probabilities rounded to 4 decimals (so rows may not sum to
            exactly 1).
        """
        torch.manual_seed(0)
        embedded = self.word_embeddings(context)
        context_length = context.shape[1]
        # Build the position ids on the same device as the input so the model
        # also works after being moved off the CPU.
        positions = torch.arange(context_length, device=context.device)
        embedded = embedded + self.position_embeddings(positions)

        raw_output = self.vocab_projection(self.final_norm(self.transformer_blocks(embedded)))
        # raw_output is batch by context_length by vocab_size

        probabilities = nn.functional.softmax(raw_output, dim=-1)
        return torch.round(probabilities, decimals=4)

    class TransformerBlock(nn.Module):
        """Pre-norm transformer block: attention and feed-forward, each with a skip connection."""

        class MultiHeadedSelfAttention(nn.Module):
            """Runs several causal attention heads and concatenates their outputs."""

            class SingleHeadAttention(nn.Module):
                """One causal scaled dot-product attention head."""

                def __init__(self, model_dim: int, head_size: int):
                    super().__init__()
                    torch.manual_seed(0)
                    self.key_gen = nn.Linear(model_dim, head_size, bias=False)
                    self.query_gen = nn.Linear(model_dim, head_size, bias=False)
                    self.value_gen = nn.Linear(model_dim, head_size, bias=False)

                def forward(self, embedded):
                    """Attend over earlier positions; returns (batch, time, head_size)."""
                    k = self.key_gen(embedded)
                    q = self.query_gen(embedded)
                    v = self.value_gen(embedded)

                    scores = q @ torch.transpose(k, 1, 2)  # @ is the same as torch.matmul()
                    context_length, attention_dim = k.shape[1], k.shape[2]
                    scores = scores / (attention_dim ** 0.5)

                    # Causal mask: position t may only look at positions <= t.
                    # Created on the scores' device to avoid a device mismatch.
                    lower_triangular = torch.tril(
                        torch.ones(context_length, context_length, device=scores.device)
                    )
                    mask = lower_triangular == 0
                    scores = scores.masked_fill(mask, float('-inf'))
                    scores = nn.functional.softmax(scores, dim=2)

                    return scores @ v

            def __init__(self, model_dim: int, num_heads: int):
                super().__init__()
                # The concatenated head outputs must be exactly model_dim wide,
                # otherwise the residual addition in the block fails later with
                # an opaque shape error.
                if num_heads <= 0 or model_dim % num_heads != 0:
                    raise ValueError(
                        f"model_dim ({model_dim}) must be a positive multiple of num_heads ({num_heads})"
                    )
                torch.manual_seed(0)
                self.att_heads = nn.ModuleList()
                for _ in range(num_heads):
                    self.att_heads.append(self.SingleHeadAttention(model_dim, model_dim // num_heads))

            def forward(self, embedded):
                """Concatenate all head outputs along the feature dimension."""
                head_outputs = [head(embedded) for head in self.att_heads]
                return torch.cat(head_outputs, dim=2)

        class VanillaNeuralNetwork(nn.Module):
            """Position-wise feed-forward network: expand 4x, ReLU, project back, dropout."""

            def __init__(self, model_dim: int):
                super().__init__()
                torch.manual_seed(0)
                self.up_projection = nn.Linear(model_dim, model_dim * 4)
                self.relu = nn.ReLU()
                self.down_projection = nn.Linear(model_dim * 4, model_dim)
                self.dropout = nn.Dropout(0.2)  # using p = 0.2

            def forward(self, x):
                # Reseeding right before dropout means the dropout layer draws
                # from a freshly seeded generator on every call.
                torch.manual_seed(0)
                return self.dropout(self.down_projection(self.relu(self.up_projection(x))))

        def __init__(self, model_dim: int, num_heads: int):
            super().__init__()
            torch.manual_seed(0)
            self.attention = self.MultiHeadedSelfAttention(model_dim, num_heads)
            self.linear_network = self.VanillaNeuralNetwork(model_dim)
            self.first_norm = nn.LayerNorm(model_dim)
            self.second_norm = nn.LayerNorm(model_dim)

        def forward(self, embedded):
            """Apply attention then the feed-forward network, each added back to its input."""
            torch.manual_seed(0)
            embedded = embedded + self.attention(self.first_norm(embedded))  # skip connection
            embedded = embedded + self.linear_network(self.second_norm(embedded))  # another skip connection
            return embedded


In [102]:
def generate(model, new_chars, context, context_length, int_to_char):
    """Autoregressively sample ``new_chars`` characters from ``model``.

    Args:
        model: Callable taking a (batch, time) tensor of token ids and
            returning a (batch, time, vocab_size) tensor of scores.
        new_chars: Number of characters to generate.
        context: Integer tensor of shape (1, time) with the starting token
            ids. The batch size must be 1 because each sampled id is read
            with ``.item()``.
        context_length: Maximum number of trailing tokens fed to the model.
        int_to_char: Mapping from token id to its character.

    Returns:
        The generated characters joined into one string (empty when
        ``new_chars`` is 0).

    Raises:
        KeyError: If a sampled token id is missing from ``int_to_char``.
    """
    # torch.manual_seed returns the (global) default generator, seeded with 0.
    generator = torch.manual_seed(0)
    initial_state = generator.get_state()
    res = []
    # Sampling needs no gradients; skip building the autograd graph.
    # Dropout stays active unless the caller has put the model in eval mode.
    with torch.no_grad():
        for _ in range(new_chars):
            # Keep only the most recent context_length tokens.
            if context.shape[1] > context_length:
                context = context[:, -context_length:]
            prediction = model(context)  # B, T, Vocab_size
            last_time_step = prediction[:, -1, :]  # B, Vocab_size

            # NOTE(review): the GPT class in this notebook already returns
            # softmax probabilities, so with that model this softmax is applied
            # a second time and flattens the distribution - confirm whether
            # `model` is meant to return logits.
            probabilities = nn.functional.softmax(last_time_step, dim=-1)
            next_char = torch.multinomial(probabilities, 1, generator=generator)
            # Rewind the generator so every step samples from the same RNG state.
            generator.set_state(initial_state)
            context = torch.cat((context, next_char), dim=-1)
            res.append(int_to_char[next_char.item()])
    # Return only after all characters were generated (the return used to sit
    # inside the loop, ending generation after the first character).
    return ''.join(res)
    

In [103]:
# Positional arguments are (vocab_size, context_length, model_dim, num_blocks, num_heads).
model = GPT(104, 128 , 252, 6, 6)
new_chars = 1
# Start generation from a single token with id 0 (batch of one).
context = torch.zeros(1, 1, dtype=int)
context_length = 128
# NOTE(review): this mapping has no entry for id 4 although the model's
# vocab_size is 104 (ids 0-103); sampling id 4 would raise KeyError in
# generate(). Restore the missing character from the original vocabulary.
int_to_char={0: '\n', 1: ' ', 2: '!', 3: '"', 5: '%', 6: '&', 7: "'", 8: '(', 9: ')', 10: '*', 11: '+', 12: ',', 13: '-', 14: '.', 15: '/', 16: '0', 17: '1', 18: '2', 19: '3', 20: '4', 21: '5', 22: '6', 23: '7', 24: '8', 25: '9', 26: ':', 27: ';', 28: '?', 29: 'A', 30: 'B', 31: 'C', 32: 'D', 33: 'E', 34: 'F', 35: 'G', 36: 'H', 37: 'I', 38: 'J', 39: 'K', 40: 'L', 41: 'M', 42: 'N', 43: 'O', 44: 'P', 45: 'Q', 46: 'R', 47: 'S', 48: 'T', 49: 'U', 50: 'V', 51: 'W', 52: 'X', 53: 'Y', 54: 'Z', 55: '[', 56: ']', 57: '_', 58: 'a', 59: 'b', 60: 'c', 61: 'd', 62: 'e', 63: 'f', 64: 'g', 65: 'h', 66: 'i', 67: 'j', 68: 'k', 69: 'l', 70: 'm', 71: 'n', 72: 'o', 73: 'p', 74: 'q', 75: 'r', 76: 's', 77: 't', 78: 'u', 79: 'v', 80: 'w', 81: 'x', 82: 'y', 83: 'z', 84: '{', 85: '|', 86: '}', 87: 'à', 88: 'á', 89: 'è', 90: 'é', 91: 'ë', 92: 'ñ', 93: 'ó', 94: 'ú', 95: '\u2005', 96: '–', 97: '—', 98: '‘', 99: '’', 100: '“', 101: '”', 102: '…', 103: '\u205f'}

In [104]:
# Generate text from `model`, starting at `context`, and print the resulting string.
print(generate(model, new_chars, context, context_length, int_to_char))

%
