In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="2"

from transformers import AutoTokenizer
from bertviz.transformers_neuron_view import BertModel
from bertviz.neuron_view import show
from bertviz import head_view
from transformers import AutoModel
from torch import nn
from transformers import AutoConfig
import torch
from math import sqrt
import torch.nn.functional as F

# Transformer

![Image Title](images/transformer.png)

*Encoder-decoder architecture of the Transformer, with the encoder shown in the upper half of the figure and the decoder in the lower half.*

## Transformer Encoder

The Transformer’s encoder consists of many encoder layers stacked next to each other. Each encoder layer receives a sequence of embeddings and feeds them through the following sub-layers:
1. A multi-head self-attention layer.
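2. A fully connected feed-forward layer that is applied to each input embedding independently (the same weights are shared across positions).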

The output embeddings of each encoder layer have the same size as the inputs.

The main role of the encoder stack is to "update" the input embeddings to produce representations that encode some contextual information in the sequence. For example, the word "apple" will be updated to be more "company-like" and less "fruit-like" if words such as "keynote" or "phone" appear near it.

Each of these sub-layers also has a skip connection and layer normalization, which are standard tricks to train deep neural networks effectively.

![Image Title](images/encoder.png)

*Zooming into the encoder layer.*

## Self-Attention

**Attention** is a mechanism that allows neural networks to assign a different amount of weight or "attention" to each element in a sequence. For text sequences, the elements are token embeddings. Each token is mapped to a vector of some fixed dimension. For example, in BERT each token is represented as a 768-dimensional vector.

The main idea behind self-attention is that instead of using a fixed embedding for each token, we can use the whole sequence to compute a weighted average of each embedding. Another way to formulate this is to say that given a sequence of token embeddings $x_1, \ldots, x_n$, self-attention produces a sequence of new embeddings $x'_1, \ldots, x'_n$ where each $x'_i$ is a linear combination of all the $x_j$:

$$x'_i = \sum_{j=1}^{n} w_{ji}x_j$$

The coefficients $(w_{ji})$ are called attention weights and are normalized so that $\sum_{j} w_{ji} = 1$.

Embeddings that are generated in this way are called contextualized embeddings.
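A small worked example of the weighted-average formula, using made-up 2-dimensional embeddings and hand-picked weights (in a real model the weights come from the attention computation described below). Here row $i$ of `W` holds the weights used to build $x'_i$, so each row sums to one.

```python
import torch

# Three toy 2-d token embeddings x_1, x_2, x_3 (illustrative values, not from a model)
X = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0]])

# Hand-picked attention weights; row i holds the weights used to build x'_i
W = torch.tensor([[0.50, 0.50, 0.00],
                  [0.00, 1.00, 0.00],
                  [0.25, 0.25, 0.50]])
assert torch.allclose(W.sum(dim=1), torch.ones(3))  # each row sums to 1

X_new = W @ X  # every x'_i is a weighted average of all the x_j
# x'_1 = [0.5, 0.5], x'_2 = [0.0, 1.0], x'_3 = [0.75, 0.75]
print(X_new)
```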

### Time flies like an arrow; fruit flies like a banana

![Image Title](images/contextualized_embeddings.png)

*Diagram showing how self-attention updates raw token embeddings (upper) into contextualized embeddings (lower) to create representations that incorporate information from the whole sequence.*

### Scaled Dot-Product Attention

There are several ways to implement a self-attention layer, but the most common is scaled dot-product attention, from the paper *Attention Is All You Need* that introduced the Transformer. There are four main steps needed to implement this mechanism:
1. **Create query, key, and value vectors.** Each token embedding is projected into query, key, and value vectors.
2. **Compute attention scores.** Determine how much the query and key vectors relate using a similarity function. As the name suggests, the similarity function for scaled dot-product attention is the dot product, computed efficiently as a matrix multiplication of the query and key matrices. Similar queries and keys will have a large dot product, while those that don't share much in common will have little to no overlap. The outputs from this step are called the attention scores, and for a sequence with $n$ input tokens there is a corresponding $n \times n$ matrix of attention scores.
3. **Compute attention weights.** Dot products can, in general, produce arbitrarily large numbers, which can destabilize the training process. To handle this, the attention scores are first multiplied by a scaling factor to normalize their variance and then normalized with a softmax so that the weights for each query sum to one. The resulting $n \times n$ matrix now contains all the attention weights $w_{ji}$.
4. **Update the token embeddings.** Once the attention weights are computed, we multiply them by the value vectors $v_1, \ldots, v_n$ to obtain an updated representation for each embedding, $x'_i = \sum_{j} w_{ji} v_j$.
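The four steps can be traced in a few lines of PyTorch. This is a sketch for a single sequence with random stand-in embeddings and random, untrained projection matrices `W_Q`, `W_K`, `W_V` (in a real model these are learned); the sizes 768 and 64 are chosen to mirror one BERT attention head and are otherwise arbitrary. In this layout `weights[i, j]` is the weight that query $i$ places on token $j$.

```python
import torch
import torch.nn.functional as F
from math import sqrt

torch.manual_seed(0)
n, d_model, d_k = 5, 768, 64  # 5 tokens; sizes mirror a single BERT head (assumption)

x = torch.randn(n, d_model)  # random stand-ins for token embeddings x_1..x_n
# Hypothetical, untrained projection matrices (learned in a real model)
W_Q = torch.randn(d_model, d_k) / sqrt(d_model)
W_K = torch.randn(d_model, d_k) / sqrt(d_model)
W_V = torch.randn(d_model, d_k) / sqrt(d_model)

# 1. Create query, key, and value vectors, each of shape (n, d_k)
q, k, v = x @ W_Q, x @ W_K, x @ W_V
# 2. Attention scores: all pairwise dot products, an n x n matrix
scores = q @ k.T
# 3. Attention weights: scale by 1/sqrt(d_k), then softmax over the keys
weights = F.softmax(scores / sqrt(d_k), dim=-1)
# 4. Updated embeddings: weighted sums of the value vectors
x_new = weights @ v

print(scores.shape, weights.shape, x_new.shape)  # torch.Size([5, 5]) torch.Size([5, 5]) torch.Size([5, 64])
print(weights.sum(dim=-1))  # the weights of each query sum to 1
```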

**The first step of self-attention calculation**: For each word, we create a query vector, a key vector, and a value vector. These vectors are created by multiplying the embedding by three weight matrices ($W_Q$, $W_K$, $W_V$) that are learned during training.

Notice that these new vectors are smaller in dimension than the embedding vector. In the original Transformer, their dimensionality is 64, while the embedding and encoder input/output vectors have dimensionality 512. They don't *have* to be smaller; this is an architectural choice that keeps the computation of multi-headed attention (mostly) constant.

![Image Title](images/qkv.png)

*Multiplying $x_1$ by the $W_Q$ weight matrix produces $q_1$, the "query" vector associated with that word. We end up creating a "query", a "key", and a "value" projection of each word in the input sentence.*

![Image Title](images/self_attention_process.png)

*Self-attention calculation.*

**The second step of self-attention calculation**: We're calculating the self-attention for the first word in this example, "Thinking". We need to score each word of the input sentence against this word. The score determines how much focus to place on other parts of the input sentence as we encode a word at a particular position.

The score is calculated by taking the dot product of the query vector with the key vector of the respective word we're scoring. So, if we're processing the self-attention for the word in position #1, the first score would be the dot product of $q_1$ and $k_1$. The second score would be the dot product of $q_1$ and $k_2$.

**The third and fourth steps** are to divide the scores by 8 (the square root of 64, the dimension of the key vectors used in the paper), which leads to more stable gradients. Other scaling values are possible, but this is the default. The result is then passed through a softmax operation, which normalizes the scores so that they are all positive and add up to 1.

This softmax score determines how much each word will be expressed at this position. Often the word at this position has the highest softmax score, but sometimes it is helpful to attend to another word that is relevant to the current one.

**The fifth step** is to multiply each value vector by the softmax score (in preparation for summarizing them). The intuition here is to keep intact the values of the word(s) we want to focus on and drown out irrelevant words (by multiplying them by tiny numbers like 0.001, for example).

**The sixth step** is to sum up the weighted value vectors. This produces the output of the self-attention layer at this position (for the first word).
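Why divide by $\sqrt{d_k}$ in the third step? If the entries of $q$ and $k$ are independent with zero mean and unit variance (an idealized assumption), their dot product has variance $d_k$, so raw scores grow with the key dimension and can push the softmax into saturation. A quick empirical check:

```python
import torch

torch.manual_seed(0)
d_k, n_pairs = 64, 50_000
q = torch.randn(n_pairs, d_k)  # entries ~ N(0, 1): an idealized assumption
k = torch.randn(n_pairs, d_k)

dots = (q * k).sum(dim=-1)                # one dot product per (q, k) pair
raw_var = dots.var().item()               # roughly d_k = 64
scaled_var = (dots / d_k ** 0.5).var().item()  # roughly 1 after dividing by sqrt(d_k) = 8
print(raw_var, scaled_var)
```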

$$\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V$$

![Image Title](images/attention_matrix_form.png)

*The self-attention calculation in matrix form.*

That concludes the self-attention calculation. The resulting vector is one we can send along to the feed-forward neural network.
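The matrix form computes exactly the step-by-step procedure for every position at once. The sketch below, using random `Q`, `K`, `V` as stand-ins, repeats steps 2 to 6 by hand for the first word and compares the result with the first row of $\text{softmax}(QK^T/\sqrt{d_k})V$:

```python
import torch
from math import sqrt

torch.manual_seed(0)
n, d_k = 4, 64  # 4 tokens; d_k = 64 as in the paper, so sqrt(d_k) = 8
Q, K, V = torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_k)

# Matrix form: all positions at once
matrix_out = torch.softmax(Q @ K.T / sqrt(d_k), dim=-1) @ V

# Step-by-step form for position 1 only (index 0)
scores_1 = torch.stack([Q[0] @ K[j] for j in range(n)])  # q_1 . k_j for every j
weights_1 = torch.softmax(scores_1 / sqrt(d_k), dim=0)    # divide by 8, then softmax
z_1 = sum(weights_1[j] * V[j] for j in range(n))          # sum of weighted value vectors

print(torch.allclose(z_1, matrix_out[0], atol=1e-5))
```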

In [2]:
model_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = BertModel.from_pretrained(model_ckpt)
config = AutoConfig.from_pretrained(model_ckpt)


In [3]:
text = "time flies like an arrow"
show(model, "bert", tokenizer, text, display_mode="light", layer=0, head=8)



In [4]:
text = "fruit flies like a banana"
show(model, "bert", tokenizer, text, display_mode="light", layer=0, head=8)


### Attention implementation

In [5]:
text = "time flies like an arrow"

inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
inputs.input_ids

tensor([[ 2051, 10029,  2066,  2019,  8612]])

In [6]:
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
tokens

['time', 'flies', 'like', 'an', 'arrow']

Then, we need to create some dense embeddings. "Dense" in this context means that each entry in the embeddings contains a non-zero value.

In [7]:
config

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.36.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

Each input ID will be mapped to one of the 30,522 embedding vectors stored in `nn.Embedding`, each with a size of 768. Note that the token embeddings at this point are independent of their context.

In [8]:
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
token_emb

Embedding(30522, 768)

In [9]:
# look-up table
inputs_embeds = token_emb(inputs.input_ids)
inputs_embeds.size()

torch.Size([1, 5, 768])

### Embedding Layer vs Linear Layer

Both `nn.Embedding` and a linear layer (`nn.Linear`) in PyTorch are used to transform input data into a different representation, but they serve different purposes and have different characteristics:

1. Purpose:
    * **Embedding Layer** (`nn.Embedding`): Specifically designed for handling sparse data where the input is categorical, such as word embeddings in NLP tasks. It maps discrete indices (representing words or tokens) to dense vectors, capturing semantic relationships between tokens.
    * **Linear Layer** (`nn.Linear`): A general-purpose layer used for linear transformations of input data. It's typically used for tasks such as classification, regression, or other tasks with continuous input features.

2. Input Data:
    * **Embedding Layer**: Expects integer indices as input, representing categories or tokens. Each integer index corresponds to a row in the embedding matrix.
    * **Linear Layer**: Expects continuous-valued input data, such as feature vectors or the output of other layers.

3. Output:
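    * **Embedding Layer**: Produces dense vector representations (embeddings) of the input tokens. The embedding dimension specified during initialization sets the size of each output vector, so the output shape is the input shape with that dimension appended; the vocabulary size only determines which indices are valid.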
    * **Linear Layer**: Produces a linear transformation of the input data. The output dimensions are determined by the number of output units specified during initialization.

4. Parameters:
    * **Embedding Layer**: The parameters are the embedding matrix itself, which is learned during training. The vocabulary size and embedding dimension determine the size of the embedding matrix.
    * **Linear Layer**: The parameters are the weight matrix and bias vector, both learned during training. The size of the weight matrix is determined by the input and output dimensions specified during initialization.

5. Usage:
    * **Embedding Layer**: Primarily used in NLP tasks to represent words or tokens as dense vectors. It is often used as the first layer in neural network architectures for processing text.
    * **Linear Layer**: Widely used in various neural network architectures for transforming input data or connecting layers. It's used in tasks like classification, regression, or any other task requiring linear transformations.
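One way to see the relationship between the two layers: an embedding lookup gives the same result as multiplying a one-hot encoding of the token IDs by a bias-free linear layer whose weight matrix is the transposed embedding table. The lookup simply skips the wasteful multiplication by zeros. The toy sizes below are an assumption chosen so the sketch runs instantly.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, hidden_size = 10, 4  # toy sizes, not BERT's 30522 x 768

emb = nn.Embedding(vocab_size, hidden_size)
lin = nn.Linear(vocab_size, hidden_size, bias=False)
with torch.no_grad():
    # nn.Linear stores its weight as (out_features, in_features), hence the transpose
    lin.weight.copy_(emb.weight.T)

ids = torch.tensor([[2, 7, 7, 0]])                        # integer token IDs, shape (1, 4)
one_hot = F.one_hot(ids, num_classes=vocab_size).float()  # shape (1, 4, 10)

out_emb = emb(ids)      # row lookup, shape (1, 4, 4)
out_lin = lin(one_hot)  # one-hot vectors times the weight matrix, shape (1, 4, 4)
print(torch.allclose(out_emb, out_lin))
```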

The next step is to create the query, key, and value vectors and calculate the attention scores using the dot-product as the similarity function:

In [10]:
# We’ll see later that the query, key, and value vectors are generated by applying independent weight matrices W_Q, W_K, W_V to the embeddings, but for now we’ve kept them equal for simplicity.

# In scaled dot-product attention, the dot products are divided by the square root of the key dimension so that we don’t get too many large numbers during training, which can cause the softmax we apply next to saturate.

query = key = value = inputs_embeds
dim_k = key.size(-1)
scores = torch.bmm(query, key.transpose(1,2)) / sqrt(dim_k) # bmm performs a batch matrix-matrix product
scores.size()

torch.Size([1, 5, 5])

In [11]:
weights = F.softmax(scores, dim=-1)
print(weights.shape)

torch.Size([1, 5, 5])


In [12]:
scores

tensor([[[25.8431, -2.1171, -1.0447,  0.4381,  2.0931],
         [-2.1171, 25.3569, -0.4879, -0.6890, -0.5012],
         [-1.0447, -0.4879, 27.0742, -1.2382,  0.0499],
         [ 0.4381, -0.6890, -1.2382, 27.0541, -0.9226],
         [ 2.0931, -0.5012,  0.0499, -0.9226, 29.2967]]],
       grad_fn=<DivBackward0>)

In [13]:
weights

tensor([[[1.0000e+00, 7.1950e-13, 2.1027e-12, 9.2625e-12, 4.8472e-11],
         [1.1700e-12, 1.0000e+00, 5.9666e-12, 4.8798e-12, 5.8879e-12],
         [6.1391e-13, 1.0713e-12, 1.0000e+00, 5.0590e-13, 1.8344e-12],
         [2.7595e-12, 8.9400e-13, 5.1620e-13, 1.0000e+00, 7.0777e-13],
         [1.5333e-12, 1.1453e-13, 1.9874e-13, 7.5151e-14, 1.0000e+00]]],
       grad_fn=<SoftmaxBackward0>)

In [14]:
torch.bmm(weights, value).shape

torch.Size([1, 5, 768])

In [16]:
weights.sum(dim=-1)

tensor([[1., 1., 1., 1., 1.]], grad_fn=<SumBackward1>)

In [17]:
attn_outputs = torch.bmm(weights, value)
attn_outputs.shape

torch.Size([1, 5, 768])

In [18]:
attn_outputs

tensor([[[ 2.1856e+00,  2.3719e-01,  5.9591e-01,  ...,  1.6541e+00,
          -2.2697e-01, -1.5417e-01],
         [-1.3669e+00, -4.5329e-01, -6.8402e-01,  ..., -4.8425e-02,
           6.9000e-01, -3.4179e-01],
         [ 1.3553e-01, -1.1197e+00, -1.4007e+00,  ...,  1.1851e-01,
          -1.5931e+00, -1.1045e+00],
         [-6.9330e-01,  3.7640e-01,  1.5337e+00,  ..., -5.6836e-01,
          -1.0827e+00, -9.6147e-01],
         [ 2.0786e-03, -8.7819e-01,  3.1212e-01,  ...,  1.6152e+00,
          -9.8263e-01, -3.2690e-02]]], grad_fn=<BmmBackward0>)

In [19]:
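def scaled_dot_product_attention(query, key, value):
    dim_k = query.size(-1)
    # Attention scores: (batch, seq_len, seq_len), scaled by 1/sqrt(d_k)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    # Attention weights: softmax over the keys, so each row sums to 1
    weights = F.softmax(scores, dim=-1)
    # Weighted sum of the value vectors: (batch, seq_len, value_dim)
    return torch.bmm(weights, value)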
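As a sanity check, the hand-written function can be compared with PyTorch's fused `torch.nn.functional.scaled_dot_product_attention`, assuming PyTorch 2.0 or later, where that function exists and by default scales by $1/\sqrt{d}$ with $d$ the last dimension of the query. The built-in is given an extra singleton "heads" axis here to match the `(batch, heads, seq_len, dim)` layout it is usually called with.

```python
import torch
import torch.nn.functional as F
from math import sqrt

def scaled_dot_product_attention(query, key, value):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)

torch.manual_seed(0)
q, k, v = (torch.randn(2, 5, 64) for _ in range(3))  # (batch, seq_len, dim)

ours = scaled_dot_product_attention(q, k, v)
# Built-in (PyTorch >= 2.0): add and then remove a singleton heads axis
builtin = F.scaled_dot_product_attention(
    q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)
).squeeze(1)
print(torch.allclose(ours, builtin, atol=1e-5))
```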

In practice, the self-attention layer applies three independent linear transformations to each embedding to generate the query, key, and value vectors.

These transformations project the embeddings, and each projection carries its own set of learnable parameters, which allows the self-attention layer to focus on different semantic aspects of the sequence.

In [20]:
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, hidden_state):
        attn_outputs = scaled_dot_product_attention(self.q(hidden_state), self.k(hidden_state), self.v(hidden_state))
        return attn_outputs

### Multi-Headed Attention

It also turns out to be beneficial to have multiple sets of linear projections, each one representing a so-called attention head.

But why do we need more than one attention head? The reason is that the softmax of a single head tends to focus mostly on one aspect of similarity.

Having several heads allows the model to focus on several aspects at once. For instance, one head can focus on subject-verb interaction, whereas another finds nearby adjectives.

Obviously, we don’t handcraft these relations into the model and they are fully learned from the data.

![Image Title](images/multi_head_att.png)

*Multi-headed attention*

In [21]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList([AttentionHead(embed_dim, head_dim) for _ in range(num_heads)] )
        self.output_linear = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, hidden_state):
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
        x = self.output_linear(x)
        return x

In practice, `head_dim` is chosen so that `embed_dim` is a multiple of it (`head_dim = embed_dim // num_heads`), which keeps the computation across each head constant. For example, BERT has 12 attention heads, so the dimension of each head is 768/12 = 64.
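`MultiHeadAttention` runs a Python loop over twelve `AttentionHead` modules. A common alternative, sketched below under our own naming (`FusedMultiHeadAttention` and `per_head_reference` are hypothetical), uses one `embed_dim × embed_dim` projection each for queries, keys, and values and splits the result into heads by reshaping. The final check slices the fused weights head by head and recomputes the output with an explicit loop, confirming that both formulations compute the same thing.

```python
import torch
from torch import nn
import torch.nn.functional as F
from math import sqrt

class FusedMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be a multiple of num_heads"
        self.num_heads, self.head_dim = num_heads, embed_dim // num_heads
        self.q = nn.Linear(embed_dim, embed_dim)  # all heads' query projections at once
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def split_heads(self, x):
        # (batch, seq, embed_dim) -> (batch, heads, seq, head_dim)
        b, n, _ = x.shape
        return x.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_state):
        q = self.split_heads(self.q(hidden_state))
        k = self.split_heads(self.k(hidden_state))
        v = self.split_heads(self.v(hidden_state))
        weights = F.softmax(q @ k.transpose(-2, -1) / sqrt(self.head_dim), dim=-1)
        out = weights @ v                             # (batch, heads, seq, head_dim)
        b, _, n, _ = out.shape
        out = out.transpose(1, 2).reshape(b, n, -1)   # concatenate the heads
        return self.output_linear(out)

def per_head_reference(mha, x):
    # Same computation, one head at a time, using slices of the fused weights
    d, outs = mha.head_dim, []
    for h in range(mha.num_heads):
        sl = slice(h * d, (h + 1) * d)
        q = F.linear(x, mha.q.weight[sl], mha.q.bias[sl])
        k = F.linear(x, mha.k.weight[sl], mha.k.bias[sl])
        v = F.linear(x, mha.v.weight[sl], mha.v.bias[sl])
        outs.append(F.softmax(q @ k.transpose(-2, -1) / sqrt(d), dim=-1) @ v)
    return mha.output_linear(torch.cat(outs, dim=-1))

torch.manual_seed(0)
mha = FusedMultiHeadAttention()
x = torch.randn(2, 5, 768)
fused_out = mha(x)
print(fused_out.shape, torch.allclose(fused_out, per_head_reference(mha, x), atol=1e-5))
```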

In [22]:
multihead_attn = MultiHeadAttention(config)
attn_output = multihead_attn(inputs_embeds)
attn_output.size()

torch.Size([1, 5, 768])

In [23]:
attn_output

tensor([[[-0.1212, -0.0313,  0.1591,  ...,  0.0748,  0.0909,  0.1840],
         [-0.1439, -0.1211,  0.1224,  ...,  0.0935,  0.0102,  0.1706],
         [-0.1026, -0.1296,  0.1588,  ...,  0.0783,  0.0108,  0.1866],
         [-0.0641, -0.1473,  0.0774,  ...,  0.0525,  0.0317,  0.1762],
         [-0.1801, -0.1613,  0.1175,  ...,  0.0094, -0.0077,  0.2061]]],
       grad_fn=<ViewBackward0>)

In [24]:
model = AutoModel.from_pretrained(model_ckpt, output_attentions=True)
sentence_a = "time flies like an arrow"
sentence_b = "fruit flies like a banana"
viz_inputs = tokenizer(sentence_a, sentence_b, return_tensors='pt')
attention = model(**viz_inputs).attentions
sentence_b_start = (viz_inputs.token_type_ids == 0).sum(dim=1)
tokens = tokenizer.convert_ids_to_tokens(viz_inputs.input_ids[0])
head_view(attention, tokens, sentence_b_start, heads=[8])


### Feed Forward Layer

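The feed-forward sub-layer in the encoder and decoder is just a simple two-layer fully connected neural network, but with a twist: instead of processing the whole sequence of embeddings as a single vector, it processes each embedding independently. For this reason it is often called a position-wise feed-forward layer. In BERT, the first layer expands each 768-dimensional embedding to `intermediate_size` = 3072 (four times `hidden_size`), and the second projects it back to 768.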
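Because `nn.Linear` acts only on the last dimension of its input, applying the feed-forward network to a whole `(batch, seq_len, hidden_size)` tensor gives the same result as applying it to each position separately. A quick check, using an `nn.Sequential` stand-in with BERT's sizes rather than the `FeedForward` class below (dropout omitted):

```python
import torch
from torch import nn

torch.manual_seed(0)
# Stand-in for the FeedForward class: 768 -> 3072 -> 768, as in the BERT config
ffn = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))

x = torch.randn(1, 5, 768)  # (batch, seq_len, hidden_size)
whole = ffn(x)              # all positions in one call
per_token = torch.stack([ffn(x[:, i]) for i in range(x.size(1))], dim=1)  # one position at a time
print(whole.shape, torch.allclose(whole, per_token, atol=1e-5))
```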

In [25]:
class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear_1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.linear_2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        print(x.shape)
        x = self.linear_1(x)
        print(x.shape)
        x = self.gelu(x)
        x = self.linear_2(x)
        x = self.dropout(x)
        return x

In [26]:
feed_forward = FeedForward(config)
ff_outputs = feed_forward(attn_outputs)
ff_outputs.size()

torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])


torch.Size([1, 5, 768])

### Adding Layer Normalization

![Image Title](images/normalization.png)

*Different arrangements of layer normalization in a transformer encoder layer.*

In [27]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)
        
    def forward(self, x):
        # Apply layer normalization and then copy input into query, key, value
        hidden_state = self.layer_norm_1(x)
        # Apply attention with a skip connection
        x = x + self.attention(hidden_state)
        # Apply feed-forward layer with a skip connection
        x = x + self.feed_forward(self.layer_norm_2(x))
        return x

In [28]:
encoder_layer = TransformerEncoderLayer(config)
inputs_embeds.shape, encoder_layer(inputs_embeds).size()

torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])


(torch.Size([1, 5, 768]), torch.Size([1, 5, 768]))
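The `TransformerEncoderLayer` above uses *pre*-layer normalization: each sub-layer sees a normalized input, and the skip connection adds the un-normalized input back. The original Transformer and BERT instead use *post*-layer normalization, normalizing after each residual addition; pre-LN tends to be more stable to train. Below is a minimal self-contained sketch of the post-LN arrangement. To keep it standalone it swaps in PyTorch's built-in `nn.MultiheadAttention` for the notebook's `MultiHeadAttention` class, and the class name `PostLNEncoderLayer` is hypothetical.

```python
import torch
from torch import nn

class PostLNEncoderLayer(nn.Module):
    def __init__(self, hidden_size=768, num_heads=12, intermediate_size=3072, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_size, num_heads, dropout=dropout, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size), nn.GELU(),
            nn.Linear(intermediate_size, hidden_size), nn.Dropout(dropout),
        )
        self.layer_norm_1 = nn.LayerNorm(hidden_size)
        self.layer_norm_2 = nn.LayerNorm(hidden_size)

    def forward(self, x):
        attn_out, _ = self.attention(x, x, x, need_weights=False)
        x = self.layer_norm_1(x + attn_out)                # normalize AFTER the residual addition
        x = self.layer_norm_2(x + self.feed_forward(x))    # same for the feed-forward sub-layer
        return x

torch.manual_seed(0)
layer = PostLNEncoderLayer().eval()  # eval() disables dropout
out = layer(torch.randn(1, 5, 768))
print(out.shape)  # torch.Size([1, 5, 768])
```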

## Positional Encodings

In [29]:
class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embeddings = nn.Embedding(config.vocab_size,
                                             config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        
    def forward(self, input_ids):
        # Create position IDs for the input sequence: 0, 1, ..., seq_length - 1
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        # Create token and position embeddings
        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        print(position_embeddings)
        # Combine token and position embeddings
        embeddings = token_embeddings + position_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

In [30]:
embedding_layer = Embeddings(config)
embedding_layer(inputs.input_ids).size()

tensor([[[ 0.9510, -1.6721,  0.1688,  ..., -0.2077,  1.0721,  0.6463],
         [-0.4921, -0.0740,  1.4788,  ..., -0.7184, -1.0116,  0.4600],
         [-0.8165,  1.2812, -0.4372,  ..., -0.6765,  1.1058, -0.0846],
         [ 0.3262, -0.3868, -0.2307,  ..., -0.0365,  1.0664, -0.2932],
         [-0.1875, -0.7765,  0.3548,  ...,  2.9257,  1.1251,  0.0363]]],
       grad_fn=<EmbeddingBackward0>)


torch.Size([1, 5, 768])

In [31]:
config.max_position_embeddings

512
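The `Embeddings` class learns its position embeddings, as BERT does, so it can only encode positions below `max_position_embeddings` (512 here). The original Transformer paper instead used fixed sinusoidal encodings, $PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$ and $PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$, which have no parameters and can be computed for any position. A minimal sketch, assuming an even `d_model`; the function name is ours:

```python
import math
import torch

def sinusoidal_positional_encoding(max_len, d_model):
    # Fixed (non-learned) encodings; assumes d_model is even
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
    # 1 / 10000^(2i / d_model) for i = 0 .. d_model/2 - 1
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                         * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return pe

pe = sinusoidal_positional_encoding(1024, 768)  # not limited to 512 positions
print(pe.shape)  # torch.Size([1024, 768])
```

The result can be added to the token embeddings in place of `self.position_embeddings(position_ids)`.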

In [32]:
class TransformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = Embeddings(config)
        self.layers = nn.ModuleList([TransformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
    def forward(self, x):
        x = self.embeddings(x)
        for layer in self.layers:
            x = layer(x)
        return x

In [33]:
encoder = TransformerEncoder(config)
encoder(inputs.input_ids).size()

tensor([[[-0.9269, -0.2352, -0.6885,  ...,  0.1051, -0.2508,  0.5033],
         [-0.4158, -1.0667, -0.3377,  ...,  1.6644, -0.6727,  0.4324],
         [ 0.2909, -0.4881,  0.8657,  ...,  0.4432, -0.1674,  0.0734],
         [ 0.4359,  0.2352, -1.3667,  ...,  2.3093, -0.8356, -0.8579],
         [ 1.0572,  0.3104,  0.5621,  ...,  1.9411,  0.4417,  0.5687]]],
       grad_fn=<EmbeddingBackward0>)
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])


torch.Size([1, 5, 768])

In [34]:
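# For classification tasks, it is common practice to just use the hidden state associated with the [CLS] token as the input feature.
# Note: `inputs` above was tokenized with add_special_tokens=False, so in the demo below position 0 holds "time", not [CLS].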

class TransformerForSequenceClassification(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = TransformerEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, x):
        x = self.encoder(x)[:, 0, :] # select hidden state of [CLS] token
        x = self.dropout(x)
        x = self.classifier(x)
        return x

In [35]:
config.num_labels = 3
encoder_classifier = TransformerForSequenceClassification(config)
encoder_classifier(inputs.input_ids).size()

tensor([[[-0.1515, -0.1657,  0.0414,  ..., -0.5167, -0.7033,  1.1934],
         [ 0.9560, -0.1290,  1.1958,  ...,  1.3158,  0.2679,  0.1183],
         [-0.3716,  0.5560, -0.0558,  ..., -0.1341, -0.8486,  0.9830],
         [ 0.6791, -1.3523,  0.0791,  ...,  0.1129,  0.0725, -2.1017],
         [-0.5500,  1.9636, -0.5887,  ..., -0.6279,  0.5956, -0.3951]]],
       grad_fn=<EmbeddingBackward0>)
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])
torch.Size([1, 5, 768])
torch.Size([1, 5, 3072])


torch.Size([1, 3])

## Transformer Decoder

![Image Title](images/decoder.png)

*Zooming into the Transformer decoder layer*

## Whisper summary

![Image Title](images/whisper_tasks.png)

*Whisper tasks.*

![Image Title](images/training_data.png)

*Training data. Of those 680k hours of audio, 117k hours cover 96 other languages.*

### Approach

**Data-processing**

1. Minimal data pre-processing: the model is trained to predict the raw text of transcripts without significant standardization.
2. They developed several automated filtering methods to improve transcript quality.
3. They developed many heuristics to detect and remove machine-generated transcripts from the training dataset.
    1. For example, an all-uppercase or all-lowercase transcript is very unlikely to be human-generated.
4. They use an audio language detector to ensure that the spoken language matches the language of the transcript, as identified by CLD2.
5. Audio files are broken into 30-second segments, each paired with the subset of the transcript that occurs within that time segment.
6. They train on all audio, including segments with no speech (though with sub-sampled probability), and use these segments as training data for voice activity detection.
7. After training an initial model, they aggregated information about its error rate on training data sources. They manually inspected these data sources, sorting by a combination of high error rates and data source size to identify and remove low-quality ones efficiently.

**Model**

1. They use an off-the-shelf architecture to avoid confounding their findings with model improvements.
2. They choose an encoder-decoder Transformer.
3. All audio is re-sampled to 16,000 Hz, and an 80-channel log-magnitude Mel spectrogram representation is computed on 25-millisecond windows with a stride of 10 milliseconds.
4. For feature normalization, they globally scale the input to lie between -1 and 1 with approximately zero mean across the pre-training dataset.
5. They use the same byte-level BPE text tokenizer as GPT-2 for the English-only models and refit the vocabulary (keeping the same size) for the multilingual models, to avoid excessive fragmentation on other languages since the GPT-2 BPE vocabulary is English-only.
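The numbers above pin down the shape of Whisper's input features. A back-of-the-envelope check (ignoring STFT padding details, which can shift the count by a frame):

```python
# Frame arithmetic for one Whisper input segment, using the numbers listed above
sample_rate = 16_000          # Hz
window_ms, stride_ms = 25, 10
segment_s = 30                # audio is processed in 30-second segments
n_mels = 80

window = sample_rate * window_ms // 1000  # samples per analysis window
hop = sample_rate * stride_ms // 1000     # samples between consecutive windows
n_samples = sample_rate * segment_s       # samples per 30-second segment
n_frames = n_samples // hop               # one spectrogram frame every 10 ms

print(window, hop, n_samples, n_frames)   # 400 160 480000 3000
print((n_mels, n_frames))                 # (80, 3000): log-Mel features per segment
```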

![Image Title](images/whisper_architecture.png)

*Whisper architecture.*

**Multitask Format**

1. Since their decoder is an audio-conditional language model, they also train it to condition on the history of the transcript text, in the hope that it will learn to use longer-range text context to resolve ambiguous audio.
2. Specifically, with some probability, they add the transcript text preceding the current audio segment to the decoder's context.
3. A sequence-to-sequence Transformer model is trained on many different speech processing tasks, including multilingual speech recognition, speech translation, spoken language identification, and voice activity detection. These tasks are jointly represented as a sequence of tokens to be predicted by the decoder, allowing for a single model to replace many stages of a traditional speech-processing pipeline.
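To make the multitask format concrete, here is a sketch that assembles the decoder's task-specifying prefix as plain strings. The helper `build_decoder_prompt` is hypothetical; the special-token spellings follow Whisper's released tokenizer as we understand it, and a real pipeline would map them to token IDs with the actual tokenizer rather than build strings.

```python
def build_decoder_prompt(language="en", task="transcribe", timestamps=False, previous_text=None):
    # Hypothetical helper: Whisper-style task tokens as strings, not real token IDs
    if task not in ("transcribe", "translate"):
        raise ValueError("task must be 'transcribe' or 'translate'")
    prompt = []
    if previous_text:  # optional transcript history preceding the current segment
        prompt += ["<|startofprev|>", previous_text]
    prompt += ["<|startoftranscript|>", f"<|{language}|>", f"<|{task}|>"]
    if not timestamps:
        prompt.append("<|notimestamps|>")
    return prompt

print(build_decoder_prompt("uk", "transcribe"))
# ['<|startoftranscript|>', '<|uk|>', '<|transcribe|>', '<|notimestamps|>']
```

The decoder then continues this prefix with the transcript (or translation) tokens.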

![Image Title](images/whisper_multitask.png)

# Homework

Implement a `Trainer` class to fine-tune Whisper (tiny) on the Toronto dataset. You may refer to the HF Trainer, but you have to implement it on your own.

Divide the dataset into train, eval, and test splits. Compare performance on the test set for the original model, the original model with the original LM, the fine-tuned model, and the fine-tuned model with a new LM.

**Sources**:
- Link to the Toronto dataset: https://drive.google.com/file/d/1j9d91QqE7_WnOnmEmidtOG55tpmxQUeJ/view?usp=sharing.
- Another training dataset: https://huggingface.co/datasets/skypro1111/whisper-dataset-ytb-uk. *I'm not sure about its quality*.

In [None]:
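class Trainer:
    def __init__(self, model, train_dataloader, eval_dataloader, optimizer, criterion=None, n_epochs=1, device="cpu"):
        self.model = model.to(device)
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.optimizer = optimizer
        # Hugging Face seq2seq models return a loss when `labels` are passed, so a separate criterion is optional
        self.criterion = criterion
        self.n_epochs = n_epochs
        self.device = device

        # logging, scheduler, checkpointing, etc.

    def predict(self, batch):
        # generate the model's prediction for a batch (a token sequence via generation)
        raise NotImplementedError

    def train_step(self, batch):
        # compute and return the loss on a single batch
        raise NotImplementedError

    def train_epoch(self, epoch):
        # loop over train_dataloader: train_step, backward pass, optimizer step; track loss and WER
        raise NotImplementedError

    def validate(self, epoch):
        # calculate WER on the validation subset
        raise NotImplementedError

    def train(self):
        for epoch in range(self.n_epochs):
            self.train_epoch(epoch)
            self.validate(epoch)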