### Imports

In [1]:
import torch
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


### Data

In [2]:
dataset = load_dataset("alpindale/light-novels")

Generating train split: 100%|██████████| 9240994/9240994 [00:04<00:00, 2260670.95 examples/s]


In [9]:
text_dataset = "\n".join(dataset["train"]["text"])

### Tokenizer

In [20]:
# char based tokenization and not BPE (for simplicity)
vocab = list(set(text_dataset))
encoder_map = {v: i for i, v in enumerate(vocab)}
decoder_map = {i: v for i, v in enumerate(vocab)}

encode = lambda sentence: [encoder_map[character] for character in sentence]
decode = lambda lofints: "".join([decoder_map[i] for i in lofints])

### Model

In [26]:
embedding = torch.randn((100, 32), dtype=torch.float32)

In [27]:
embedding[5]

tensor([-0.3904,  0.1600, -0.0364,  0.7662, -2.0117,  0.2296,  1.9901, -1.1554,
        -1.3125,  0.8000,  1.6572,  0.1503,  0.6569,  0.0314, -1.1379,  0.1635,
        -1.0883, -0.1894,  0.8328, -0.1696,  0.2562, -0.2693, -1.0407, -0.2597,
        -0.9397, -1.2543, -0.2378, -1.8291, -1.8532,  0.7988,  0.2974, -0.1380])

In [267]:
class BasicTransformerNetwork:

    def __init__(self, vocab_size, model_dim, nb_heads, attn_size):
        
        # model vector sizes
        self.vocab_size = vocab_size
        self.model_dim = model_dim
        self.nb_heads = nb_heads

        # layers
        self.attn_projs = {}
        for n in range(self.nb_heads):
            self.attn_projs[f"k_proj_{n}"] = torch.randn((model_dim, attn_size), dtype=torch.float32)
            self.attn_projs[f"q_proj_{n}"] = torch.randn((model_dim, attn_size), dtype=torch.float32)
            self.attn_projs[f"v_proj_{n}"] = torch.randn((model_dim, attn_size), dtype=torch.float32)
        self.embedding = torch.randn((vocab_size, model_dim), dtype=torch.float32)
        self.attn_proj_back = torch.randn((attn_size * nb_heads, model_dim), dtype=torch.float32)
        self.ff_layer = torch.randn((model_dim, model_dim), dtype=torch.float32)
        self.ff_layer2 = torch.randn((model_dim, model_dim), dtype=torch.float32)
        self.ff_bias = torch.randn((1, model_dim), dtype=torch.float32)
        self.ff_bias2 = torch.randn((1, model_dim), dtype=torch.float32)
        self.output = torch.randn((model_dim, vocab_size), dtype=torch.float32)

        # parameters and activations tracker
        self.parameters = {}
        self.cache = {}
        self.activations = {}

    def positional_encoding(self, x_pos):
        """sinus non-trainable positional encoding to add with input"""
        seq_len = len(x_pos)
        positions = x_pos.unsqueeze(1)
        div_term = torch.arange(0, self.model_dim, 2) * (- torch.log(torch.tensor(1e4)) / self.model_dim)
        pe = torch.zeros(seq_len, self.model_dim)
        pe[:, 0::2] = torch.sin(positions * div_term)
        pe[:, 1::2] = torch.cos(positions * div_term)
        return pe
    
    def residual_connexion(self, current_layer, dragged_layer):
        """generic residual connexion operation"""
        return current_layer + dragged_layer

    def input_block(self, x, x_pos):
        """embedding input and positional encoding input"""
        input_embed = self.embedding[x]
        position_embed = self.positional_encoding(x_pos)
        total_input = input_embed + position_embed
        total_input, total_input_mean, total_input_std = self.layernorm(total_input)
        total_input = self.dropout(total_input, 0.2)
        return total_input
    
    def attention_block(self, x):
        """full attention block"""
        all_attn_outputs = []
        for n in range(self.nb_heads):
            K, Q, V = (
                x @ self.attn_projs[f"k_proj_{n}"],
                x @ self.attn_projs[f"q_proj_{n}"],
                x @ self.attn_projs[f"v_proj_{n}"]
            )
            pre_attention = Q @ K.T
            scale = pre_attention / torch.sqrt(torch.tensor(self.model_dim, dtype=torch.float32))
            ones = torch.ones(scale.size(), dtype=torch.float32)
            mask = torch.tril(ones)
            masked_scale = scale * mask
            masked_scale = masked_scale.masked_fill(mask == 0, float("-inf"))
            attn_softm = torch.softmax(masked_scale, dim=1)
            attn_softm = self.dropout(attn_softm, 0.2)
            attn_output = attn_softm @ V
            all_attn_outputs.append(attn_output)
        all_attn_outputs = torch.cat(all_attn_outputs, dim=-1)
        all_attn_outputs = all_attn_outputs @ self.attn_proj_back
        all_attn_outputs = self.dropout(all_attn_outputs, 0.2)
        return all_attn_outputs

    def linear_block(self, x):
        """linear layer block before output"""
        self.cache["linear_block_input"] = x
        x, x_mean, x_std = self.layernorm(x)
        self.cache["linear_block_normalized_input"] = x
        self.cache["linear_block_mean"] = x_mean
        self.cache["linear_block_std"] = x_std
        ff_outout = x @ self.ff_layer + self.ff_bias # a0 * W1 + b1
        ff_outout = self.dropout(ff_outout, 0.2) # 
        ff_activation = self.relu_activation(ff_outout) # a1 = activation(W1 * x + b1)
        ff_output2 = ff_activation @ self.ff_layer2 + self.ff_bias2 # a1 * W2  + b2
        logits = ff_output2 @ self.embedding.T
        return logits
    
    def dropout(self, x, p):
        mask = (torch.rand_like(x) > p).float()
        return x * mask
    
    def layernorm(self, x):
        """generic normalization layer"""
        x_mean = torch.mean(x, dim=1, keepdim=True)
        x_std = torch.std(x, dim=1, keepdim=True)
        return (x - x_mean) / (x_std + 1e-6), x_mean, x_std
    
    def relu_activation(self, x):
        """generic rectified linear unit layer"""
        return torch.relu(x)

    def forward(self, x, x_idx):
        """full transformer forward pass"""
        total_input = self.input_block(x, x_idx)
        attention = self.attention_block(total_input)
        residual1 = self.residual_connexion(attention, total_input)
        logits = self.linear_block(residual1)
        return logits

In [291]:
def compute_loss(logits, y):
    exp_logits = torch.exp(logits - torch.max(logits))
    softmax = exp_logits / torch.sum(exp_logits, dim=-1, keepdim=True)
    loss = - torch.log(softmax.gather(1, y) + 1e-9).mean() 
    return softmax, loss

In [153]:
def layernorm_backprop(dout, x, mean, std):

    xmu = x - mean
    xsgma = 1 / (std + 1e-6)
    dnorm = dout * xsgma

    dx = dnorm
    dx -= dnorm.mean(dim=1, keepdims=1)
    dx -= (xmu / (std + 1e-6) ** 2) * (dnorm * xmu)/mean(dim=1, keepdims=1)
    return dx

def linear_block_backprop(y, output_softmax, parameters, activations, cache):

    grads = {}
    dZ2 = output_softmax - output_softmax.gather(1, y)

    grads["dW2"] = torch.matmul(activations[f"A1"].T, dZ2)
    grads["db2"] = torch.sum(dZ2, dim=0, keepdim=True)

    dA1 = torch.matmul(dZ2, parameters[f"W2"].T)
    dZ1 = dA1 * (activations["A0"] > 0) # this is the input of the linear block

    grads["dW1"] = torch.matmul(activations["A0"].T, dZ1)
    grads["db1"] = torch.sum(dZ1, dim=0, keepdim=True)

    dA0 = torch.matmul(dZ1, parameters["W1"].T)

    grads["dA0"] = layernorm_backprop(
        dA0,
        cache["linear_block_input"],
        cache["linear_block_mean"],
        cache["linear_block_std"],
    )
    
    
    return grads

In [None]:
def optimizer_step():
    pass

In [268]:
model_dim, nb_heads, attn_size = 64, 1, 32
transformer = BasicTransformerNetwork(vocab_size=len(vocab), model_dim=model_dim, nb_heads=nb_heads, attn_size=attn_size)

In [164]:
example = "This is an example sentence"
input_ids = encode(example)
pos_ids = list(range(len(input_ids)))

input_ids = torch.tensor(input_ids, dtype=torch.long)
pos_ids = torch.tensor(pos_ids, dtype=torch.long)

In [269]:
transformer.forward(input_ids, pos_ids).size()

torch.Size([27, 3165])

In [285]:
logits = torch.randn((27, 64), dtype=torch.float32)
y = torch.randint(0, 63, (1, 27), dtype=torch.long)
prediction = torch.tensor(30, dtype=torch.float32)
exp_logits = torch.exp(logits - torch.max(logits))
softmax = exp_logits / torch.sum(exp_logits, dim=-1, keepdim=True)

In [286]:
y.size(), y.unsqueeze(1).size(), softmax.size()

(torch.Size([1, 27]), torch.Size([1, 1, 27]), torch.Size([27, 64]))

In [282]:
y

tensor([[ 4],
        [58],
        [17],
        [61],
        [37],
        [54],
        [46],
        [ 8],
        [ 1],
        [39],
        [46],
        [ 4],
        [12],
        [47],
        [ 0],
        [32],
        [42],
        [ 5],
        [43],
        [54],
        [53],
        [58],
        [26],
        [44],
        [21],
        [ 0],
        [ 7]])

In [290]:
softmax.gather(1, y)

torch.Size([1, 27])

In [None]:
torch.gather()