In [2]:
import torch
import torch.nn as nn
import math
from transformers import LlamaTokenizer
from icecream import ic
from config.config import conf
from model.Transformer import Transformer
import torch.utils.data as Data
import os
from tqdm import tqdm, trange
import shutil
from torch.utils.tensorboard import SummaryWriter
from ordered_set import OrderedSet

In [3]:
############################################### norm ###############################################
class LayerNorm(nn.Module):
    '''Layer normalization over the last dimension with a learnable scale and shift.

    Args:
        shape (int): size of the last (feature) dimension of the input
        eps (float): constant added to the variance so the square root never divides by zero
    '''
    def __init__(self, shape, eps = 1e-5):
        super().__init__()
        # learnable per-feature scale (initialised to 1) and shift (initialised to 0)
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        self.eps = eps

    # author's FLOP estimate: 8 * len * d_model
    def forward(self, x):
        # statistics are computed along the last axis only; keepdim lets them broadcast back
        mu = x.mean(-1, keepdim=True)
        # biased (population) variance
        sigma2 = x.var(-1, unbiased=False, keepdim=True)
        x_hat = (x - mu) / torch.sqrt(sigma2 + self.eps)
        return self.gamma * x_hat + self.beta

In [4]:
############################################### attention ###############################################
class SelfAttention(nn.Module):
    '''Scaled dot-product attention.

    Args:
        scale_factor (float): value the raw q.k scores are divided by
            (the caller in this file passes sqrt(dim_k))
        dropout (float): dropout rate applied to the attention weights before they
            are multiplied with ``v``. The default 0.0 leaves the weights untouched.
    '''
    def __init__(self, scale_factor, dropout=0.0):
        super(SelfAttention, self).__init__()
        self.scale_factor = scale_factor
        # honour the documented `dropout` argument (it used to be silently ignored)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        '''Compute softmax(q @ k^T / scale_factor) @ v.

        Args:
            q, k, v: tensors whose last two dimensions are (length, features);
                any number of leading dimensions is accepted as long as they broadcast
            mask: optional tensor broadcastable to the score matrix; positions where
                ``mask == 0`` are excluded from attention

        Returns:
            (output, weights): the attended values and the attention weights
            (softmax output, before dropout)
        '''
        # matmul & scale -- author's FLOP estimate: 2*d_k*len*len + len*len
        # transpose the last two axes so both 3-D and 4-D (multi-head) inputs work
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale_factor

        # optional mask -- author's FLOP estimate: 2*len*len
        if mask is not None:
            # a large negative score becomes ~0 probability after the softmax
            scores = scores.masked_fill(mask == 0, -1e9)

        # softmax -- author's FLOP estimate: 3*len*len
        weights = torch.softmax(scores, dim=-1)

        # weighted sum of the values -- author's FLOP estimate: 2*d_k*len*len
        output = torch.matmul(self.dropout(weights), v)
        return output, weights
    

class MultiAttention(nn.Module):
    '''Multi-head attention followed by dropout, a residual connection and layer norm.

    Args:
        n_heads (int): number of heads
        dim (int): length of the embedding vector (input and output width)
        dim_k (int): per-head width of the query/key projections
        dim_v (int): per-head width of the value projection
        dropout (float): dropout rate applied to the output projection
    '''
    def __init__(self, n_heads = 8, dim = 512, dim_k = 64, dim_v = 64,  dropout=0.1):
        super(MultiAttention, self).__init__()
        self.n_heads = n_heads
        self.dim_k = dim_k
        self.dim_v = dim_v

        # the bias-free linear layers play the role of the weight matrices W^Q, W^K, W^V, W^O
        self.Wq = nn.Linear(dim, n_heads * dim_k, bias=False)
        self.Wk = nn.Linear(dim, n_heads * dim_k, bias=False)
        self.Wv = nn.Linear(dim, n_heads * dim_v, bias=False)
        self.fc = nn.Linear(n_heads * dim_v, dim, bias=False)

        self.attention = SelfAttention(scale_factor=dim_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = LayerNorm(dim, eps=1e-6)

    # author's FLOP estimate for one call:
    # 8 * len * d_model * d_model + 4 * len * len * d_model + 3 * len * len * h + 9 * len * d_model
    def forward(self, q, k, v, mask=None):
        '''Attend from ``q`` to ``k``/``v``.

        Args:
            q, k, v: [batch_size, seq_len, dim]
            mask: optional [batch_size, len_q, len_k] tensor; a head axis is inserted
                here so the same mask is broadcast over all heads

        Returns:
            (x, attn): the normalized output [batch_size, len_q, dim] and the
            attention tensor returned by ``self.attention``
        '''
        batch_size = q.size(0)
        # kept for the residual connection; it must have width `dim` like the output of self.fc
        residual = q

        # multiply by W^Q, W^K, W^V, then split the heads:
        # (batch_size, length, n_heads, dim_k) => (batch_size, n_heads, length, dim_k)
        # author's FLOP estimate: 3 * 2 * len * d_model * d_model
        query = self.Wq(q).view(batch_size, -1, self.n_heads, self.dim_k).transpose(1, 2)
        key   = self.Wk(k).view(batch_size, -1, self.n_heads, self.dim_k).transpose(1, 2)
        value = self.Wv(v).view(batch_size, -1, self.n_heads, self.dim_v).transpose(1, 2)

        if mask is not None:
            # add the head axis: (batch_size, 1, len_q, len_k)
            mask = mask.unsqueeze(1)

        # author's FLOP estimate: 4 * d_model * len * len + 3 * len * len * h
        x, attn = self.attention(query, key, value, mask=mask)

        # move the head axis back and concatenate the heads:
        # (batch_size, n_heads, length, dim_v) => (batch_size, length, n_heads * dim_v)
        # The attention output has the width of the *values*, so the merge must use
        # dim_v (using dim_k only worked when dim_k == dim_v).
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.dim_v)

        # final linear layer -- author's FLOP estimate: 2 * len * d_model * d_model
        # > we apply dropout to the output of each sub-layer, before it is added to the sub-layer input and normalized.
        x = self.dropout(self.fc(x))
        # add & norm -- author's FLOP estimate: len * d_model + 8 * len * d_model
        x = x + residual
        x = self.layer_norm(x)
        return x, attn

In [5]:
############################################### embedding & encoding ###############################################
class Embeddings(nn.Module):
    '''Token embedding lookup whose output is scaled by sqrt(dim).

    Args:
        dim (int): length of the embedding vector
        vocab (int): number of words in the vocabulary
    '''
    def __init__(self, dim, vocab):
        super().__init__()
        # width of one embedding vector
        self.dim = dim
        # look-up table: one learnable `dim`-sized row per vocabulary entry
        self.lut = nn.Embedding(vocab, dim)

    def forward(self, x):
        scale = math.sqrt(self.dim)
        return self.lut(x) * scale
    
class RoPE(nn.Module):
    '''Additive sinusoidal positional encoding followed by dropout.

    NOTE: despite the class name, this is the fixed sin/cos table that is *added*
    to the embeddings; nothing is rotated. The name is kept because other code
    refers to it.

    Args:
        dim (int): length of the embedding vector, same as embedding dim
            (even or odd)
        max_len (int): max length of the input sequence; longer inputs cannot be
            encoded because the table only has ``max_len`` rows
        dropout (float): dropout rate
    '''
    def __init__(self, dim, max_len=5000, dropout=0.1):
        super(RoPE, self).__init__()

        # positional encoding matrix: one row per position, one column per feature
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        angles = position * div_term
        # even columns hold the sine, odd columns the cosine
        pe[:, 0::2] = torch.sin(angles)
        # with an odd `dim` there is one odd column fewer than even columns,
        # so drop the surplus frequency instead of failing on a shape mismatch
        pe[:, 1::2] = torch.cos(angles[:, :dim // 2])
        # leading axis of size 1 broadcasts over the batch
        pe = pe.unsqueeze(0)
        # a buffer follows the module across devices but is not a trainable parameter
        self.register_buffer('pe', pe)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [batch_size, seq_len, dim] embedding
        # the buffer carries no gradient, so it can be sliced and added directly
        x = x + self.pe[:, :x.size(1)]
        # > We apply dropout to the sums of the embeddings and the positional encodings in both the encoder and decoder stacks.
        return self.dropout(x)

In [6]:
############################################### ff ###############################################
class FeedForward(nn.Module):
    '''Position-wise feed-forward sub-layer: Linear -> ReLU -> Linear -> Dropout,
    wrapped in a residual connection and layer normalization.

    Args:
        dim (int): dimension of input/output
        hidden_dim (int): dimension of hidden layer
        dropout (float): dropout rate
    '''
    def __init__(self, dim = 512, hidden_dim = 2048, dropout=0.1):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim, bias=False)
        contract = nn.Linear(hidden_dim, dim, bias=False)
        # > we apply dropout to the output of each sub-layer, before it is added to the sub-layer input and normalized.
        self.linear = nn.Sequential(expand, nn.ReLU(), contract, nn.Dropout(dropout))
        self.layer_norm = LayerNorm(dim, eps=1e-6)

    # author's FLOP estimate: 4 * len*d_model * d_ff + 9 * len * d_model
    def forward(self, x):
        # x: [batch_size, len, dim]
        # add & norm: the sub-layer output is summed with its own input, then normalized
        return self.layer_norm(self.linear(x) + x)

In [7]:
############################################### decoder ###############################################
# mask out previous tokens by a triangular matrix
# position of previous tokens will be True
def mask_subsequence(sequence):
    '''Build a causal (lower-triangular) attention mask for a batch of token ids.

    Args:
        sequence: 2-D tensor of shape (batch_size, seq_len); only its shape is used

    Returns:
        bool tensor of shape (batch_size, seq_len, seq_len) on the CPU where entry
        [b, i, j] is True when j <= i, i.e. position i may look at itself and at
        earlier positions
    '''
    n_batch, n_tokens = sequence.shape
    steps = torch.arange(n_tokens)
    # rows are query positions, columns are key positions
    causal = steps.unsqueeze(0) <= steps.unsqueeze(1)
    # one identical copy of the triangle per batch element
    return causal.unsqueeze(0).repeat(n_batch, 1, 1)

class DecoderLayer(nn.Module):
    '''One decoder layer: masked multi-head self-attention followed by the feed-forward network.

    Args:
        dim (int): length of the embedding vector
        ff_dim (int): length of the hidden layer in the feedforward network
        n_heads (int): number of heads
    '''
    def __init__(self, dim, ff_dim, n_heads):
        super().__init__()
        # queries, keys and values all get the same per-head width
        head_dim = dim // n_heads
        self.decoder_masked_atten = MultiAttention(
            n_heads=n_heads, dim=dim, dim_k=head_dim, dim_v=head_dim
        )
        self.ffnet = FeedForward(dim=dim, hidden_dim=ff_dim)

    def forward(self, inputs, masked_atten_mask):
        '''Run one layer.

        Args:
            inputs: the layer input, used as query, key and value (self-attention)
            masked_atten_mask: mask handed to the masked multi-head attention

        Returns:
            (outputs, masked_attention): the layer output and the attention tensor
        '''
        attended, masked_attention = self.decoder_masked_atten(
            inputs, inputs, inputs, masked_atten_mask
        )
        return self.ffnet(attended), masked_attention
    
# decoder stack (self-attention only: this file defines no encoder and no cross-attention)
class Decoder(nn.Module):
    '''Decoder stack: token embedding + positional encoding + ``n_layers`` decoder layers.

    Args:
        vocab (int): number of words in the vocabulary
        emb_dim (int): length of the embedding vector
        ff_dim (int): length of the hidden layer in the feedforward network
        context_len (int): length of the context. Accepted for interface
            compatibility; it is not used inside this class (the positional
            encoding is built with its own default ``max_len``).
        n_layers (int): number of layers
        n_heads (int): number of heads
        device: stored as ``self.device`` for backward compatibility. The attention
            mask is created on the device of the inputs, so this value no longer
            has to match where the model actually lives.
    '''
    def __init__(self, vocab, emb_dim, ff_dim, context_len, n_layers = 6, n_heads = 8, device = 'cuda'):
        super(Decoder, self).__init__()
        self.embedding = Embeddings(emb_dim, vocab)
        self.encoding = RoPE(emb_dim)
        self.layers = nn.ModuleList([DecoderLayer(emb_dim, ff_dim, n_heads) for _ in range(n_layers)])
        self.device = device

    def forward(self, inputs):
        '''Encode a batch of token ids.

        Args:
            inputs: token ids of shape (batch_size, seq_len)

        Returns:
            the output of the last decoder layer
        '''
        # causal mask only: every position may attend to itself and to earlier positions.
        # Build it on the inputs' device so the module keeps working after `.to(...)`
        # and on hosts without the device that was passed to the constructor.
        masked_atten_mask = mask_subsequence(inputs).to(inputs.device)
        # embedding & positional encoding
        embedding = self.embedding(inputs)
        outputs = self.encoding(embedding)
        # decode; the per-layer attention tensors are not needed here
        for layer in self.layers:
            outputs, _ = layer(outputs, masked_atten_mask)
        return outputs

In [8]:
############################################### transformer ###############################################
class Transformer(nn.Module):
    '''Decoder-only language model: decoder stack plus a linear projection onto the vocabulary.

    Args:
        emb_dim (int): length of the embedding vector
        ff_dim (int): hidden width of the feed-forward networks
        context_len (int): maximum number of trailing tokens fed to the decoder in ``generate``
        dec_layers (int): number of decoder layers
        n_heads (int): number of attention heads
        vocab (int): vocabulary size; integral floats such as the default ``1e4`` are
            converted with ``int()`` because the layers require an integer size
        device: forwarded to ``Decoder``
    '''
    def __init__(self, emb_dim = 512, ff_dim=2048, context_len=256,
                 dec_layers = 6, n_heads = 8,
                 vocab = 1e4, device = 'cuda'):
        super(Transformer, self).__init__()
        # the float default used to crash the construction of the layers below
        vocab = int(vocab)
        self.decoder = Decoder(vocab, emb_dim, ff_dim, context_len, dec_layers, n_heads, device)
        self.projection = nn.Linear(emb_dim, vocab, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.context_len = context_len

    def forward(self, inputs):
        '''Return unnormalized logits flattened to (batch_size * seq_len, vocab).'''
        dec_outputs = self.decoder(inputs)
        # softmax is included in CrossEntropyLoss
        # author's FLOP estimate: 2 * d_model * len * vocab
        outputs = self.projection(dec_outputs)
        return outputs.view(-1, outputs.size(-1))

    @torch.no_grad()  # sampling needs no gradients; avoids building a graph at every step
    def generate(self, context, max_len = 200, terminate = None):
        '''Autoregressively sample up to ``max_len`` new tokens.

        Args:
            context: token ids of shape (batch_size, seq_len) used as the prompt
            max_len (int): maximum number of tokens to append
            terminate: optional token id; generation stops as soon as every sequence
                in the batch has sampled it at least once. Sequences that finish
                early keep being extended until the whole batch is done.

        Returns:
            ``context`` with the sampled tokens appended along dim 1

        The method does not switch the module to eval mode; call ``model.eval()``
        first if dropout should be disabled.
        '''
        finished = torch.zeros(context.size(0), dtype=torch.bool, device=context.device)
        for _ in range(max_len):
            # only the last `context_len` tokens are visible to the model
            dec_outputs = self.decoder(context[:, -self.context_len:])
            outputs = self.projection(dec_outputs)
            # distribution over the next token = prediction at the last position
            outputs = outputs[:, -1, :]
            outputs = self.softmax(outputs)
            # sample one token id as the next
            next_token = torch.multinomial(outputs, 1)
            context = torch.cat([context, next_token], dim=1)
            # terminate if the end token has been produced (per sequence, so that
            # batches larger than one no longer crash on `.item()`)
            if terminate is not None:
                finished |= next_token.squeeze(1) == terminate
                if finished.all():
                    break
        return context

In [9]:
# wget -c https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# git clone https://github.com/facebookresearch/llama
# apply as the guidance in the README.md
# bash llama/download.sh

# https://github.com/bl0nder/makespeare/blob/main/makespeare.py
def tokenize(path: str = "input.txt"):
    '''Tokenize a text file and build a compact vocabulary from the tokens that occur in it.

    The text is split with the LLaMA tokenizer loaded from ``"tokenizer.model"``.
    Ids are then re-assigned densely: the special tokens ``<s>`` and ``</s>`` get
    0 and 1, every other distinct token gets the next free id in order of first
    appearance. Each pair of consecutive newline tokens (``<0x0A><0x0A>``) is
    rewritten to ``</s><s>`` so that blank-line separated passages become
    separate sequences.

    Args:
        path (str): path of the text file to read (decoded as UTF-8)

    Returns:
        (vocab, reverse_vocab, id_text):
            vocab (dict): token string -> id
            reverse_vocab (dict): id -> token string
            id_text (torch.LongTensor): the whole file as a 1-D tensor of ids

    Side effects:
        prints the vocabulary size and the number of input tokens.
    '''
    tokenizer = LlamaTokenizer.from_pretrained("tokenizer.model")
    # beginning and end of sequence
    special_tokens = ['<s>', '</s>']
    # NOTE(review): per the author's debugging notes, in the tokenizer's own vocabulary
    # '<unk>' is 0, '<s>' is 1, '</s>' is 2 and '<0x0A>' (newline) is 13;
    # those ids are NOT the ones assigned below.

    # read with an explicit encoding so the result does not depend on the platform default
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    raw_tokens = tokenizer.tokenize(text)

    # an ordered set keeps first-appearance order, so the id assignment is reproducible
    tokens = OrderedSet(raw_tokens)
    for tk in special_tokens:
        tokens.discard(tk)

    # special tokens take the first ids, the remaining tokens follow
    vocab = {}
    reverse_vocab = {}
    for idx, tk in enumerate(list(special_tokens) + list(tokens)):
        vocab[tk] = idx
        reverse_vocab[idx] = tk

    # the text always starts a sequence
    raw_tokens.insert(0, '<s>')
    i = 0
    while i < len(raw_tokens) - 1:
        if raw_tokens[i] == '<0x0A>' and raw_tokens[i + 1] == '<0x0A>':
            # a blank line ends the current sequence and begins a new one
            raw_tokens[i] = '</s>'
            raw_tokens[i + 1] = '<s>'
            # skip the token that was just rewritten
            i += 1
        i += 1
    # a trailing '<s>' would open a sequence with no content
    if raw_tokens[-1] == '<s>':
        raw_tokens.pop()

    id_text = torch.LongTensor([vocab[token] for token in raw_tokens])
    print(f'{len(vocab)} tokens in total')
    print(f'{len(id_text) /1e6} M input tokens in total')

    return vocab, reverse_vocab, id_text

# decode a list of ids to a readable string
def decode(ids, reverse_vocab):
    '''Turn the first sequence of a batch of ids into a readable string.

    Args:
        ids: indexable batch of id tensors; only ``ids[0]`` is decoded
        reverse_vocab (dict): id -> token string

    Returns:
        str: the decoded text, with ``<0x0A>``, ``</s>`` and ``<s>`` each replaced by a newline
    '''
    tokenizer = LlamaTokenizer.from_pretrained("tokenizer.model", add_prefix_space=False)
    pieces = []
    for token_id in ids[0]:
        pieces.append(reverse_vocab[token_id.item()])
    # NOTE(review): the tokens are joined into ONE string before being handed to
    # convert_tokens_to_string, so the tokenizer receives a str rather than a list
    # of tokens -- kept as is because the replacements below depend on it; confirm
    # this is intended.
    text = tokenizer.convert_tokens_to_string(''.join(pieces))
    # newline token first, then the sequence markers
    for marker in ('<0x0A>', '</s>', '<s>'):
        text = text.replace(marker, '\n')
    return text

class TinyDataset(Data.Dataset):
    '''Sliding-window next-token dataset over one long 1-D tensor of token ids.

    Item ``idx`` is the pair ``(x, y)`` where ``x`` is the window of
    ``context_length`` ids starting at ``idx`` and ``y`` is the same window
    shifted by one position. Windows that run past the end of the data are
    right-padded with ``pad_idx``.

    Args:
        inputs: 1-D tensor of token ids
        context_length (int): length of every returned window
        pad_idx (int): id used to fill windows that are too short
    '''
    def __init__(self, inputs, context_length, pad_idx = 0):
        super().__init__()
        self.inputs = inputs
        self.context_length = context_length
        self.pad_idx = pad_idx

    def __len__(self):
        # one item per start position, including the ones that need padding
        return self.inputs.size(0)

    def _padded_window(self, start):
        '''Return ``context_length`` ids beginning at ``start``, padded on the right if needed.'''
        window = self.inputs[start: start + self.context_length]
        missing = self.context_length - window.size(0)
        if missing > 0:
            filler = torch.full([missing], self.pad_idx, dtype=torch.long)
            window = torch.cat([window, filler])
        return window

    def __getitem__(self, idx):
        # the target is the input shifted one token to the left
        return self._padded_window(idx), self._padded_window(idx + 1)

# Script entry point: tokenize the corpus, build the model, (optionally) train it,
# then load a checkpoint and sample one continuation of the prompt "<s> CAMILLO:".
# All hyper-parameters and paths come from `conf` (imported from config.config).
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vocab, reverse_vocab, id_text = tokenize()
    # first 90 % of the token stream for training, the rest for validation
    split_idx = int(len(id_text) * 0.9)
    train_data = id_text[:split_idx]
    val_data = id_text[split_idx:]
    
    model = Transformer(emb_dim=conf['emb_dim'], ff_dim=conf['ff_dim'], context_len=conf['context_length'],
                        dec_layers=conf['decoder_layers'], n_heads=conf['heads'],
                        vocab=len(vocab), device=device).to(device)
    print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')
    print((sum(p.numel() for p in model.parameters()) - conf['emb_dim'] * len(vocab))/1e6, 'M non-embedding parameters')
    criterion = nn.CrossEntropyLoss()
    # 5e-5
    optimizer = torch.optim.AdamW(model.parameters(), lr=conf['lr'])
    # 1e-3
    # optimizer = torch.optim.SGD(model.parameters(), lr=conf['lr'], momentum=0.99)
    
    # the newline token doubles as the padding id; both loaders shuffle
    train_loader = Data.DataLoader(TinyDataset(train_data, conf['context_length'], vocab['<0x0A>']), conf['batch_size'], True)
    valid_loader = Data.DataLoader(TinyDataset(val_data, conf['context_length'], vocab['<0x0A>']), conf['batch_size'], True)
    os.makedirs(os.path.join(conf['ckpt_path'], conf['exp']), exist_ok=True)

    # Training loop, disabled in this run; kept as the record of how the checkpoints
    # loaded below were produced.
    # print('Start training...')
    # shutil.copyfile('config/config.py', os.path.join(conf['ckpt_path'], conf['exp'], 'config.py'))
    # writer = SummaryWriter(os.path.join('logs', conf['exp']))
    # for idx in trange(1, int(conf['iterations']) + 1):
    #     inputs, targets = next(iter(train_loader))
    #     inputs, targets = inputs.to(device), targets.to(device)
    #     optimizer.zero_grad()
    #     outputs = model(inputs)
    #     loss = criterion(outputs, targets.view(-1))
    #     loss.backward()
    #     optimizer.step()
        
    #     # lr decay
    #     decay_steps = conf['lr_decay'] * 1000
    #     decay_factor = 0.1 ** (1 / decay_steps)
    #     if idx > int(conf['iterations'] * conf['decay_initiation']):
    #         for param_group in optimizer.param_groups:
    #             param_group['lr'] = param_group['lr'] * decay_factor
            
    #     if idx % (int(conf['log_iter']) / 10) == 0:
    #         writer.add_scalar('loss/loss', loss, idx)
    #         writer.add_scalar('lr', optimizer.param_groups[0]['lr'], idx)
            
    #     if idx % int(conf['log_iter']) == 0:
    #         val_loss = 0
    #         with torch.no_grad():
    #             model.eval()
    #             for i in range(int(conf['val_iterations'])):
    #                 x, y = next(iter(valid_loader))
    #                 x, y = x.to(device), y.to(device)
    #                 outputs = model(x)
    #                 val_loss += criterion(outputs, y.view(-1))
    #             writer.add_scalar('loss/val_loss', val_loss / conf['val_iterations'], idx)
    #             model.train()
    #         tqdm.write(f'Iteration: {idx} loss = {loss:.8f} val_loss = {val_loss / conf["val_iterations"]:.8f}')
            
    #     if idx % int(conf['ckpt_iter']) == 0:
    #         torch.save(model.state_dict(), os.path.join(conf['ckpt_path'], conf['exp'], f'model_{idx}.pt'))
      
    # writer.close()
    # print('Done!')
    
    # generate
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    # NOTE(review): the checkpoint step is hard-coded to 30000 while the output file
    # name below uses conf["iterations"]; confirm the two are meant to differ.
    ckpt = torch.load(os.path.join(conf['ckpt_path'], conf['exp'], 'model_30000.pt'), map_location=device)
    # ckpt = torch.load(os.path.join(conf['ckpt_path'], conf['exp'], f'model_{int(conf["iterations"])}.pt'))
    model.load_state_dict(ckpt)
    # eval mode switches dropout off for sampling
    model.eval()
    # fixed prompt: "<s> CAMILLO:"
    inputs = torch.LongTensor([[vocab['<s>'], vocab['▁C'], vocab['AM'], vocab['ILL'], vocab['O'], vocab[':']]]).to(device)
    # no gradients are needed while sampling
    with torch.no_grad():
        result = model.generate(inputs, 64, vocab['</s>'])
    res = decode(result, reverse_vocab)
    print(res)
    # append to the experiment's output file so successive samples accumulate
    with open(os.path.join(conf['ckpt_path'], conf['exp'], f'output_{int(conf["iterations"])}.txt'), 'a+', encoding='utf-8') as f:
        f.write(res)
        f.write('\n----------------------------------------------\n')


You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


7986 tokens in total
0.368633 M input tokens in total
33.359872 M parameters
29.27104 M non-embedding parameters

 CAMILLO:
Come, come; 'tis a second to me:
I have no hat you know, sir:
You do not wedged, I have underta'enLoved
His regiment lies half a mile at least
South from the mighty power of the king.



In [2]:
# Load the TensorBoard notebook extension and open the dashboard on the ./logs directory
%load_ext tensorboard
%tensorboard --logdir ./logs

Reusing TensorBoard on port 6006 (pid 301142), started 0:01:24 ago. (Use '!kill 301142' to kill it.)

# Parameters & FLOP

## Parameter

Notation that will be used: 

|   symbol    |                           meaning                            |
| :---------: | :----------------------------------------------------------: |
|     $V$     |                      size of vocabulary                      |
|     $N$     | number of encoder layers and decoder layers (assumed to be the same here for simplicity) |
| $d_{model}$ |                length of the embedding vector                |
|  $d_{ff}$   | size of the hidden layer of the feedforward network, usually $4d_{model}$ |
|     $h$     |       number of the heads used in multi-head attention       |
|    $d_k$    | length of query vectors and key vectors, usually $h\cdot d_k = d_{model}$ |
|    $d_v$    |          length of value vectors, usually $d_v=d_k$          |

Firstly, we adopt the same hyper-parameters as the base model of the transformer model, and that is

- $d_{model}=512$
- $d_{ff}=2048$
- $h=8$
- $d_k=d_v=64$
- $V=7986\approx8000$ (given by llama2 tokenizer on [tiny Shakespeare dataset](https://huggingface.co/datasets/tiny_shakespeare))

We use the decoder-only transformer model. Theoretically we can have

- parameters of embeddings: $P_{emb}=Vd_{model}$
- parameters of one layer normalization: $P_{layernorm}=2d_{model}$
- parameters of multi-head attention: $P_{multi}=2hd_kd_{model}+2hd_vd_{model}+P_{layernorm}$
- parameters of the feedforward network: $P_{ff}=2d_{model}d_{ff}+P_{layernorm}=8d_{model}^2+P_{layernorm}$
- parameters of one decoder layer: $P_{declayer}=P_{multi}+P_{ff}$
- parameters of the projection layer: $P_{proj}=Vd_{model}$
- parameters of the non-embeddings: $P_{model}=NP_{declayer}+P_{proj}=N(12d_{model}^2+4d_{model})+Vd_{model}$
- parameters of the full model with embeddings: $P_{full}=P_{emb}+NP_{declayer}+P_{proj}=N(12d_{model}^2+4d_{model})+2Vd_{model}$

With the equation, we can tell that theoretically the model has **22974976** parameters. 

Experimentally, we can calculate the number of parameters with code like `sum(p.numel() for p in model.parameters())`, and we will obtain the same number **22974976**. 



## FLOP

Before calculation, some extra notations need to be demonstrated:

|   symbol    |                         meaning                          |
| :---------: | :------------------------------------------------------: |
|    $N_b$    |               batch size used in training                |
|  $d_{ctx}$  |                    length of context                     |
| $d_{token}$ | how many tokens used in training, $d_{token}=d_{ctx}N_b$ |
|    $N_i$    |               total iterations of training               |

Firstly, we should obtain the FLOPs of the used GPU. With the `deviceQuery` tool from cuda, we can know that RTX 4090 has 16384 cores and a max clock rate of 2.6 GHz. Then we can approximately calculate the computation capability by $2 * 2.6 * 16384=85196.8\ GFLOPs\approx 85TFLOPs$. And there is another way to justify the FLOPs, that is to check out [the whitepaper published by NVIDIA](https://images.nvidia.com/aem-dam/Solutions/Data-Center/l4/nvidia-ada-gpu-architecture-whitepaper-v2.1.pdf). From the whitepaper we can know the official number of RTX 4090 is **82.6 TFLOPs**, which is quite similar to that calculated by hand. So we can confirm that the GPU can perform 82.6 TFLOP per second. 

Let's calculate the FLOP of the decoder-only transformer model. Theoretically we have 

- FLOP of embedding and encoding: $C_{emb}=d_{token}d_{model}+d_{token}d_{model}=2d_{token}d_{model}$
- FLOP of one layer normalization on input of $(d_{token}, d_{model})$: $C_{layernorm}=8d_{token}d_{model}$
- FLOP of softmax on input of $(d_{token}, d_{model})$: $C_{softmax}=3d_{token}d_{model}$
- FLOP of self attention: $C_{self}=2d_kd_{token}^2+d_{token}^2+2d_{token}^2+3d_{token}^2+2d_kd_{token}^2=(4d_k+6)d_{token}^2$
- FLOP of multi-head attention: $C_{multi}=3*2d_{token}d_{model}^2+hC_{self}+2d_{token}d_{model}^2+d_{token}d_{model}+C_{layernorm}$
- FLOP of the feedforward network on input of $(d_{token}, d_{model})$: $C_{ff}=4d_{token}d_{model}d_{ff}+d_{token}d_{model}+C_{layernorm}=16d_{token}d_{model}^2+9d_{token}d_{model}$
- FLOP of one decoder layer: $C_{declayer}=C_{ff}+C_{multi}$
- FLOP of the projection layer: $C_{proj}=2Vd_{token}d_{model}$
- FLOP of the model per token: $C=(C_{emb}+NC_{declayer}+C_{proj})/d_{token}\approx N(24d_{model}^2+18d_{model})+2Vd_{model}$

From the equations above, we can tell that as $N$ and $d_{model}$ increase, $C$ and $2P_{model}$ would become closer and closer, and we will obtain the famous scaling law $C=2P$. 

We specify all the hyper-parameters and calculate expected FLOP specifically with 

- $d_{ctx}=256$
- $N_b=32$
- $N_i=100000$

The expected FLOP per token would be **202.76 M**. Further we can have the expected total FLOP to be **498308 TFLOP** (3 times the calculation result including forward and backward), and the expected training time would be around **6076 seconds**. 

Experimentally, we can simply record the training time used with the specified hyper-parameters, and the total training time is around **6000 seconds**. That is quite acceptable.  


# Experiments

Firstly we need to process the dataset. The tokenizer of llama2 includes some special tokens, like `<s>` specifying the beginning of the sequence and `</s>` specifying the ending of the sequence. Although the problem description asks us to generate *poems* of Shakespeare, the input data in fact consists of multiple *plays* rather than *poems*. I added `<s>` and `</s>` to mark each dialogue as one sequence. 

Then I trained the model with the hyper-parameters like

```python
conf = {
    'exp': 'base',
    # length of embedding vector
    'emb_dim': 512,
    'ff_dim': 512*4,
    'heads': 8,
    'decoder_layers': 8,
    # how many tokens processed at a time
    'context_length': 256,
    'batch_size': 32,
    'iterations': 8e4,
}
```

The input data contains only **0.36 M** tokens in total, but the model has **29 M** non-embedding parameters, so we can clearly observe over-fitting: the validation loss becomes larger and larger. 

So a smaller model is necessary. I tried another model like 

```python
conf = {
    'exp': 'trial2',
    # length of embedding vector
    'emb_dim': 24,
    'ff_dim': 24*4,
    'heads': 8,
    'decoder_layers': 6,
    # how many tokens processed at a time
    'context_length': 256,
    'batch_size': 32,
    'iterations': 6e4,
    
}
```

And the over-fitting becomes weaker. The loss curve becomes more beautiful too. 

From base model: 

```
 CAMILLO:
Do you blanks too:
You know the just cause remove.

----------------------------------------------

 CAMILLO:
I would not stay awhile.

----------------------------------------------

 CAMILLO:
Well, my lord.

----------------------------------------------
```

From biggerlr model:

```
 CAMILLO:
O, not depart to Warwick, how long by.
I'I twice, be great: hold you be still: knock and threats
Great-board, from his majesty as flower of words,
dy, but thought of how o' hence and loathed the rottench
----------------------------------------------

 CAMILLO:
I am a little world is sweet boy.

----------------------------------------------

 CAMILLO:
Are you well; it?

----------------------------------------------
```

From smallerlr model:

```
 CAMILLO:
She be made flatter? I know for my bottomister from Here kind
Ege,
That a disposition she lay that I saw my table North by the Duke:
Hounds, or thy father

----------------------------------------------

 CAMILLO:
Ineless hour serviceured Edward.

----------------------------------------------

 CAMILLO:
My son, you to:
The ranksithmetic makes the green; Lord H Watch rifketh herraint,
Butigter by earth
Oxmpold, a devil.

----------------------------------------------
```

From trial2 model:

```
 CAMILLO:
Your marriage?

----------------------------------------------

 CAMILLO:
O me?

----------------------------------------------

 CAMILLO:
Go, sir, know it late not is cruel
My humble that standtesy impaleful step thee well
But well! in the heaven.

----------------------------------------------
```