# Step #1 : Self-Supervised Learning (SSL) of Language Model (LM)

## Task : from context, predict next word

### Xavier Bresson, xavier.bresson@gmail.com, https://twitter.com/xbresson

### Number of data points for GPT-3, 175B parameters
+ Step #1 : 300B tokens
+ Step #2 : 10k-100k pairs (prompt, response)
+ Step #3 : 100k-1M triples (prompt, positive response, negative response)
+ Step #4 : 10k-100k prompts

### Number of data points for this tutorial
+ Step #1 : 3M tokens
+ Step #2 : 10k pairs (prompt, response)
+ Step #3 : 10k triples (prompt, positive response, negative response)
+ Step #4 : 1k prompts

### Objectives
+ Step-by-step approach to self-supervised learning of LM
+ Implementation of word prediction with { previous word, bag-of-words, attention mechanism }
+ Implement Transformer architecture with single head, multiple heads, PE, RC, LN, MLP, and dropout
+ Train with batch of sequences for fast training with GPU
+ Save the pre-trained LM network for Step #2


In [1]:
# For Google Colaboratory
import sys, os
if 'google.colab' in sys.modules:
    # mount google drive
    from google.colab import drive
    drive.mount('/content/gdrive')
    path_to_file = '/content/gdrive/My Drive/ACE_NLP_Dec23_codes/codes/labs_vanillaLLMs'
    print(path_to_file)
    # move to Google Drive directory
    os.chdir(path_to_file)
    !pwd

Mounted at /content/gdrive
/content/gdrive/My Drive/ACE_NLP_Dec23_codes/codes/labs_vanillaLLMs
/content/gdrive/My Drive/ACE_NLP_Dec23_codes/codes/labs_vanillaLLMs


In [2]:
# Libraries
import torch
print(torch.__version__)
import torch.nn as nn
import torch.optim as optim
import time
import matplotlib.pyplot as plt
import logging
logging.getLogger().setLevel(logging.CRITICAL) # remove warnings
import os, datetime


2.1.0+cu118


## Time stamp for save/load data


In [4]:
# save time stamp
time_stamp = datetime.datetime.now().strftime("%y-%m-%d--%H-%M-%S")

# check dataset folder exists
data_dir = os.path.join('dataset')
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

# select a time stamp
use_saved_time_stamp = False
#use_saved_time_stamp = True
if use_saved_time_stamp:
    time_stamp = '23-11-23--12-26-17' # trained on GPU on '23-11-23--12-26-17'

print('time_stamp:', time_stamp, '\n')


time_stamp: 23-12-04--10-32-53 



## Generate training sequence

In [5]:
# generate arithmetic series
m = max_value = 100 # maximum value in the sequence
def arithmetic_series(m, s, d, n):
    seq = []
    for i in range(n):
        v = s + i * d
        if v <= m:
            seq.append(v)
        else:
            break
    return seq

# generate training data, i.e. a long sequence of tokens as
#  seq = [ 2, 4, 6, <SEP>, 14, 17, 20, 22, <SEP>, 6, 8, ... ]
save_training_data = False
save_training_data = True
if save_training_data:

    # parameters for arithmetic series
    m = max_value # maximum value in the sequence
    s = torch.randint(low=0, high=m, size=(1,)).item() # starting integer of the series
    d = torch.randint(low=1, high=10, size=(1,)).item() # value of common difference
    n = torch.randint(low=5, high=15, size=(1,)).item() # number of element in the series
    print('max_value: %d, start_value: %d, common_difference: %d, number_of_terms: %d' % (m,s,d,n))
    seq = arithmetic_series(m,s,d,n)
    print('an arithmetic series:',seq)

    # generate and save a sequence of arithmetic series separated with token <SEP>
    len_dataset = 100 # debug, e.g. 100
    len_dataset = 3000000 # length of the sequence of arithmetic series, e.g. 3M
    seq = []
    separator_token = '<SEP>' # separator token between series
    start = time.time()
    while len(seq)<=len_dataset:
        s = torch.randint(low=0, high=m, size=(1,)).item() # starting integer of the series
        d = torch.randint(low=1, high=10, size=(1,)).item() # value of common difference
        n = torch.randint(low=5, high=15, size=(1,)).item() # number of element in the series
        series = arithmetic_series(m,s,d,n) # generate arithmetic series
        series_token = [str(i) for i in series] # convert seq of integers into seq of tokens w/ string type
        series_token.append(separator_token) # append separator token
        seq.extend(series_token) # append one generated series to the sequence
    seq_tokens = seq[:len_dataset] # truncate the sequence to "len_dataset" number of tokens
    print('len(seq) data: %d, time(sec): %.3f' % (len(seq_tokens), time.time()-start) )

    # print
    print('number of tokens in the sequence :',len(seq_tokens))
    print('print first 50 tokens :',seq_tokens[:50],'\n')

    # save training data
    save_file = data_dir + '/step1_01_SSL_training_set_token_' + time_stamp + '.pt'
    print('save_file:', save_file, '\n')
    torch.save([seq_tokens],save_file) # save the sequence

else:

    # load data
    load_file = data_dir + '/step1_01_SSL_training_set_token_' + time_stamp + '.pt'
    print('load_file:', load_file, '\n')
    seq_tokens = torch.load(load_file)[0]
    print('number of tokens in the sequence :',len(seq_tokens))
    print('print first 50 tokens :',seq_tokens[:50])


max_value: 100, start_value: 85, common_difference: 5, number_of_terms: 6
an arithmetic series: [85, 90, 95, 100]
len(seq) data: 3000000, time(sec): 8.064
number of tokens in the sequence : 3000000
print first 50 tokens : ['45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '<SEP>', '87', '94', '<SEP>', '76', '81', '86', '91', '96', '<SEP>', '78', '87', '96', '<SEP>', '47', '51', '55', '59', '63', '67', '<SEP>', '14', '17', '20', '23', '26', '29', '32', '35', '38', '41', '44', '47', '50', '53', '<SEP>'] 

save_file: dataset/step1_01_SSL_training_set_token_23-12-04--10-32-53.pt 



## Get dictionary of tokens and convert sequence of tokens to integers

In [6]:
save_dictionary = False
save_dictionary = True
if save_dictionary:

    # create the dictionary of tokens by extracting unique tokens (words)
    load_file = data_dir + '/step1_01_SSL_training_set_token_' + time_stamp + '.pt'
    print('load_file:', save_file, '\n')
    print('number of tokens in the sequence :',len(seq_tokens),'\n')
    dictionary = []
    num_tokens = 0
    for token in seq_tokens:
        if token not in dictionary:
            dictionary.append(token)
            num_tokens += 1
    print('dictionary:',dictionary,'\n')
    print('num_tokens (unique):',num_tokens,'\n')

    # add tokens to the dictionary for step #2
    tokens_for_step2 = ['generate', 'an', 'arithmetic', 'series', 'with', 'terms', 'starting', 'value', \
                        'and', 'common', 'difference', 'Let', 'be', 'the', 'number', 'of', 'then', 'write', \
                        'make', 'a', 'type', 'which', 'starts', 'at', 'elements', '<PAD>', '<EOS>']
    for token in torch.arange(m).tolist(): tokens_for_step2.append(str(token))
    # update dictionary
    for token in tokens_for_step2:
        if token not in dictionary:
            dictionary.append(token); num_tokens += 1
    print('updated dictionary:',dictionary,'\n')
    print('num_tokens (unique):',num_tokens,'\n')

    # token2index : dict w/ key=token(str) and value=index(int)
    # index2token : dict w/ key=index(int) and value=token(str)
    token2index = { token:index for index,token in enumerate(dictionary) }
    index2token = { index:token for index,token in enumerate(dictionary) }
    print('token2index:', token2index,'\n')
    print('index2token:', index2token,'\n')

    # func_tokens2indices : function that converts token (str) to indices (int) for token embedding
    # func_indices2tokens : function that converts indices (int) to token (str)
    # func_str2tokens : function that converts a string into tokens (str)
    # func_tokens2str : function that converts tokens (str) to a string
    func_tokens2indices = lambda list_tokens: [token2index[token] for token in list_tokens] # ['Let', '5', 'be', 'the'] => [113, 46, 114, 115]
    func_indices2tokens = lambda list_ints: [index2token[integer] for integer in list_ints] # [113, 46, 114, 115] => ['Let', '5', 'be', 'the']
    func_str2tokens = lambda input_str: [token_str for token_str in input_str.split()]      # 'Let 5 be the' => ['Let', '5', 'be', 'the']
    func_tokens2str = lambda list_str: ' '.join(list_str)                                   # ['Let', '5', 'be', 'the'] => 'Let 5 be the'

    # example
    seq_token = seq_tokens[:10] # first tokens
    print('seq_token:', seq_token,'\n')
    seq_ind = func_tokens2indices(seq_token) # token (str) to indices (int)
    print('seq_ind:', seq_ind,'\n')
    seq_token = func_indices2tokens(seq_ind) # indices (int) to token (str)
    print('seq_token:', seq_token,'\n')

    # convert long seq from tokens to torch integers for training
    seq = torch.tensor(func_tokens2indices(seq_tokens))
    print('number of tokens in the sequence :',seq.size(0),'\n')

    # save dictionary and training data
    save_file_dictionary = data_dir + '/step1_02_SSL_dictionary_' + time_stamp + '.pt'
    print('save_file_dictionary:', save_file_dictionary, '\n')
    torch.save([dictionary, num_tokens, token2index, index2token], save_file_dictionary) # save dictionary of tokens
    save_file_seq = data_dir + '/step1_03_SSL_training_set_int_' + time_stamp + '.pt'
    print('save_file_seq:', save_file_seq, '\n')
    torch.save([seq], save_file_seq) # save the sequence of integers

else:

    # load dictionary and training data
    load_file_dictionary = data_dir + '/step1_02_SSL_dictionary_' + time_stamp + '.pt'
    print('load_file_dictionary:', load_file_dictionary, '\n')
    dictionary, num_tokens, token2index, index2token = torch.load(load_file_dictionary) # load dictionary of tokens
    load_file_seq = data_dir + '/step1_03_SSL_training_set_int_' + time_stamp + '.pt'
    print('load_file_seq:', load_file_seq, '\n')
    seq = torch.load(load_file_seq)[0] # load the sequence of integers

    # print
    print('dictionary:',dictionary,'\n')
    print('num_tokens (unique):',num_tokens,'\n')
    print('token2index:', token2index,'\n')
    print('index2token:', index2token,'\n')
    print('number of tokens in the sequence :',len(seq),'\n')
    func_tokens2indices = lambda list_tokens: [token2index[token] for token in list_tokens] # ['Let', '5', 'be', 'the'] => [113, 46, 114, 115]
    func_indices2tokens = lambda list_ints: [index2token[integer] for integer in list_ints] # [113, 46, 114, 115] => ['Let', '5', 'be', 'the']
    func_str2tokens = lambda input_str: [token_str for token_str in input_str.split()]      # 'Let 5 be the' => ['Let', '5', 'be', 'the']
    func_tokens2str = lambda list_str: ' '.join(list_str)                                   # ['Let', '5', 'be', 'the'] => 'Let 5 be the'

    # example
    seq_token = seq_tokens[:10] # first tokens
    print('seq_token:', seq_token,'\n')
    seq_ind = func_tokens2indices(seq_token) # token (str) to indices (int)
    print('seq_ind:', seq_ind,'\n')
    seq_token = func_indices2tokens(seq_ind) # indices (int) to token (str)
    print('seq_token:', seq_token,'\n')


load_file: dataset/step1_01_SSL_training_set_token_23-12-04--10-32-53.pt 

number of tokens in the sequence : 3000000 

dictionary: ['45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '<SEP>', '87', '94', '76', '81', '86', '91', '96', '78', '59', '63', '67', '14', '17', '20', '23', '26', '29', '32', '35', '38', '41', '44', '42', '62', '72', '77', '82', '92', '97', '65', '18', '28', '33', '43', '73', '74', '75', '79', '80', '83', '84', '0', '7', '21', '70', '30', '34', '36', '40', '31', '37', '39', '61', '64', '71', '89', '95', '93', '99', '25', '27', '66', '10', '12', '16', '22', '24', '85', '60', '68', '100', '98', '88', '69', '1', '13', '19', '90', '6', '11', '9', '15', '2', '3', '4', '5', '8'] 

num_tokens (unique): 102 

updated dictionary: ['45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '<SEP>', '87', '94', '76', '81', '86', '91', '96', '78', '59', '63', '67', '14', '17', '20', '23', '26', '29', '32', '35', '38'

## Get batch of sub-sequences

In [7]:
# prepare batch of sub-sequences of the long sequence of tokens
#
#                              seq_len
#                  ------------------------------
# seq            = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, ... ]
#                       |<= start_idx = 2 (randomly selected in [0,1,...,batch_length-1])
#                       -------
#                     batch_length = 3 tokens
#
# batch_seq      = [ [0, 6, 3],   |
#                    [2, 9, 5],   | batch_size
#                    [2, 3, 4] ]  |
#                     -------
#                   batch_length
#
# batch_target =   [ [6, 3, 8],   |
#  = batch_seq + 1   [9, 5, 6],   | batch_size
#                    [3, 4, 5] ]  |
#                     -------
#                   batch_length
#

# parameters
seq_len = seq.size(0) # length of the long sequence
batch_size = 3; batch_length = 6 # debug
batch_size = 100; batch_length = 100 # GPU
num_subseq = seq_len // batch_length # number of subsequences
num_batch = seq_len // (batch_size * batch_length) # number of batches
start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # new starting index at each new epoch, random integer in {0,batch_length-1}
list_batch_idx = torch.arange(num_batch) # list of batch indices, [0,1,...,num_batch-1]
print('seq_len: %d, batch_size: %d, batch_length: %d, num_subseq: %d, num_batch: %d\n' % (seq_len, batch_size, batch_length, num_subseq, num_batch) )

# create batch of sub-sequences
def get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx):
    batch_idx = torch.randperm(list_batch_idx.size(0))[:batch_size] # sample "batch_size" number of batch indices
    batch_idx = list_batch_idx[batch_idx] # sample from remaining list of batch indices
    batch_seq = torch.stack([seq[start_idx+i*batch_length : start_idx+(i+1)*batch_length] for i in batch_idx]) # extract batch at start_idx with batch_length, size=[batch_size, batch_length]
    target_seq = torch.stack([seq[start_idx+i*batch_length+1 : start_idx+(i+1)*batch_length+1] for i in batch_idx]) # target = batch_seq shifted by +1 to predict next token, size=[batch_size, batch_length]
    if list_batch_idx.size(0) > batch_size:
        list_batch_idx = torch.stack([i for i in list_batch_idx if i not in batch_idx]) # remove sampled batch indices from the list of batch indices, size=[rem_num_batch_indices]
    else:
        list_batch_idx = torch.tensor([]) # last batch of epoch, size=[] (empty tensor)
    return batch_seq, target_seq, list_batch_idx

# print example
for _ in range(1): # e.g. num_batch for one full epoch
    print('list_batch_idx (before) :',list_batch_idx, list_batch_idx.size(),'\n')
    batch_seq, target_seq, list_batch_idx = get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx) # generate a batch of subsequences
    print('batch_seq               :',batch_seq, batch_seq.size(),'\n')
    print('target_seq              :',target_seq, target_seq.size(),'\n')
    print('list_batch_idx (after)  :',list_batch_idx, list_batch_idx.size(),'\n')


seq_len: 3000000, batch_size: 100, batch_length: 100, num_subseq: 30000, num_batch: 300

list_batch_idx (before) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154, 155, 156, 157,

## Define class of token embedding

In [8]:
# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

# print example
start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # size=[1]
list_batch_idx = torch.arange(num_batch) # size=[batch_size]
batch_int, _, _ = get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx) # generate a batch, size=[batch_size, batch_length]
print('batch_int :',batch_int.size())

token2vec_layer = token2vec(num_tokens, d=128)
batch_vec = token2vec_layer(batch_int) # size=[batch_size, batch_length, d=128]
print('batch_vec :',batch_vec.size())


batch_int : torch.Size([100, 100])
batch_vec : torch.Size([100, 100, 128])


## Vanilla LM : Predict next token(t+1) given context = {current token : token(t)}

In [None]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

#    seq = [ 2, 3, 4, 5, 6 ]
#               - <= context = { current token } = 3
#               | <= predict next token = 4
# target = [ 3, 4, 5, 6, 7 ]
#               | <= score vector v must predict token "4"
# scores = [ v, v, v, v, v ]
#
# batch_seq      = [ [ 2, 3, 4, 5, 6 ],   |
#                    [ 2, 9, 5, 7, 3 ],   | batch_size
#                    [ 0, 6, 3, 5, 2 ] ]  |
#                     --------------
#                       batch_length
#
# batch_scores =   [ [ v, v, v, v, v ],   |
#                    [ v, v, v, v, v ],   | batch_size, v is a vector of "num_tokens" dimensions
#                    [ v, v, v, v, v ] ]  |
#                     --------------
#                      batch_length
#
# batch_target =   [ [ 3, 4, 5, 6, 7 ],   |
#  = batch_seq + 1   [ 9, 5, 7, 3, 1 ],   | batch_size
#                    [ 6, 3, 5, 2, 8 ] ]  |
#                     --------------
#                      batch_length
#
class vanillaLM(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
    def forward(self, batch_seq):
        batch_seq_vec = self.token2vec(batch_seq) # size=[batch_size, batch_length, d]
        batch_scores = self.token_prediction(batch_seq_vec) # size=[batch_size, batch_length, num_tokens]
        return batch_scores # return prediction scores for next token

# batching parameters
seq_len = seq.size(0) # length of the long sequence
batch_size = 5; batch_length = 20 # bebug
num_subseq = seq_len // batch_length # number of subsequences
num_batch = seq_len // (batch_size * batch_length) # number of batches
start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # new starting index at each new epoch, random integer in {0,batch_length-1}
list_batch_idx = torch.arange(num_batch) # list of batch indices, [0,1,...,num_batch-1]
print('seq_len: %d, batch_size: %d, batch_length: %d, num_subseq: %d, num_batch: %d\n' % (seq_len, batch_size, batch_length, num_subseq, num_batch) )

# network parameters
d = 128 # embedding dimension
print('num_tokens: %d, d: %d\n' % (num_tokens, d) )
vanillaLMnet = vanillaLM(num_tokens, d)
num_param = number_param(vanillaLMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# Train network to predict next token
optimizer = torch.optim.AdamW(vanillaLMnet.parameters(), lr=3e-4) # standard optimizer for LMs
num_epochs = 101 # 101(debug), number of epochs
start = time.time()
for epoch in range(num_epochs): # number of epochs
    list_batch_idx = torch.arange(num_subseq-1) # list of batch indices
    start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # size=[1]
    running_loss = 0.0 # tracking total loss value
    for _ in range(num_batch): # number of batches into one epoch
        batch_seq, target_seq, list_batch_idx = get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx) # generate a batch of subsequences
        batch_scores = vanillaLMnet(batch_seq) # size=[batch_size, batch_length, num_tokens]
        loss = nn.CrossEntropyLoss()(batch_scores.view(batch_scores.size(0)*batch_length, num_tokens), target_seq.view(batch_scores.size(0)*batch_length)) # classification loss over dict of tokens
        running_loss += loss.detach().cpu().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_epoch = running_loss / num_batch
    if not epoch%10:
        print('Epoch: %d, time(sec): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, time.time()-start, optimizer.param_groups[0]['lr'], loss_epoch) )


seq_len: 100, batch_size: 5, batch_length: 20, num_subseq: 5, num_batch: 1

num_tokens: 128, d: 128

num_net_parameters: 32896 / 0.03 million

Epoch: 0, time(sec): 0.062, lr= 0.000300, loss_epoch: 4.982
Epoch: 10, time(sec): 0.084, lr= 0.000300, loss_epoch: 4.766
Epoch: 20, time(sec): 0.100, lr= 0.000300, loss_epoch: 4.504
Epoch: 30, time(sec): 0.114, lr= 0.000300, loss_epoch: 4.216
Epoch: 40, time(sec): 0.130, lr= 0.000300, loss_epoch: 3.969
Epoch: 50, time(sec): 0.147, lr= 0.000300, loss_epoch: 3.732
Epoch: 60, time(sec): 0.164, lr= 0.000300, loss_epoch: 3.560
Epoch: 70, time(sec): 0.177, lr= 0.000300, loss_epoch: 3.353
Epoch: 80, time(sec): 0.190, lr= 0.000300, loss_epoch: 3.117
Epoch: 90, time(sec): 0.213, lr= 0.000300, loss_epoch: 2.897
Epoch: 100, time(sec): 0.233, lr= 0.000300, loss_epoch: 2.765


## Bag-Of-Token LM : Predict next token(t+1) given context = {bag of tokens : token(<=t)}
### Aggregator of tokens is the mean operator (as a bag-of-token is a set of unordered tokens)


In [None]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

#    seq = [ 2, 3, 4, 5, 6 ]
#            ---------- <= context = { tokens(<=t) } = 2,3,4,5
#                     | <= predict next token = 6
# target = [ 3, 4, 5, 6, 7 ]
#                     | <= score vector v must predict token "6"
# scores = [ v, v, v, v, v ]
#
# triu(ones(3,3)) = [ [1, 0, 0]
#                     [1, 1, 0]
#                     [1, 1, 1] ]
#
class BOT_LM(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
    def forward(self, batch_seq):
        batch_size = batch_seq.size(0); batch_len = batch_seq.size(1)
        batch_seq_vec = self.token2vec(batch_seq) # size=[batch_size, batch_length, d]
        mean_operator = torch.tril(torch.ones(batch_len,batch_len)).long() # mask to use previous tokens only : { token(<=t) }, size=[batch_len,batch_len]
        mean_operator = mean_operator/ torch.sum(mean_operator, dim=1).unsqueeze(1) # normalize w.r.t. number of previous tokens
        mean_operator = mean_operator.repeat(batch_size,1,1) # repeat masks batch_size times, size=(batch_size, batch_len, batch_len)
        batch_seq_vec =  mean_operator @ batch_seq_vec # matrix-matrix multiplication (B,L,L) @ (B,L,d) => (B,L,d), size=[batch_size, batch_length, num_tokens)
        batch_scores = self.token_prediction(batch_seq_vec) # size=[batch_size, batch_length, num_tokens]
        return batch_scores # return prediction scores for next token


# batching parameters
seq_len = seq.size(0) # length of the long sequence
batch_size = 5; batch_length = 20 # bebug
num_subseq = seq_len // batch_length # number of subsequences
num_batch = seq_len // (batch_size * batch_length) # number of batches
start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # new starting index at each new epoch, random integer in {0,batch_length-1}
list_batch_idx = torch.arange(num_batch) # list of batch indices, [0,1,...,num_batch-1]
print('seq_len: %d, batch_size: %d, batch_length: %d, num_subseq: %d, num_batch: %d\n' % (seq_len, batch_size, batch_length, num_subseq, num_batch) )

# network parameters
d = 128 # embedding dimension
print('num_tokens: %d, d: %d\n' % (num_tokens, d) )
BOT_LMnet = BOT_LM(num_tokens, d)
num_param = number_param(BOT_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# Train network to predict next token
optimizer = torch.optim.AdamW(BOT_LMnet.parameters(), lr=3e-4) # standard optimizer for LMs
num_epochs = 101 # 101(debug), number of epochs
start = time.time()
for epoch in range(num_epochs): # number of epochs
    list_batch_idx = torch.arange(num_subseq-1) # list of batch indices
    start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # size=[1]
    running_loss = 0.0 # tracking total loss value
    for _ in range(num_batch): # number of batches into one epoch
        batch_seq, target_seq, list_batch_idx = get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx) # generate a batch of subsequences
        batch_scores = BOT_LMnet(batch_seq) # size=[batch_size, batch_length, num_tokens]
        loss = nn.CrossEntropyLoss()(batch_scores.view(batch_scores.size(0)*batch_length, num_tokens), target_seq.view(batch_scores.size(0)*batch_length)) # classification loss over dict of tokens
        running_loss += loss.detach().cpu().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_epoch = running_loss / num_batch
    if not epoch%10:
        print('Epoch: %d, time(sec): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, time.time()-start, optimizer.param_groups[0]['lr'], loss_epoch) )


seq_len: 100, batch_size: 5, batch_length: 20, num_subseq: 5, num_batch: 1

num_tokens: 128, d: 128

num_net_parameters: 32896 / 0.03 million

Epoch: 0, time(sec): 0.004, lr= 0.000300, loss_epoch: 4.862
Epoch: 10, time(sec): 0.040, lr= 0.000300, loss_epoch: 4.763
Epoch: 20, time(sec): 0.066, lr= 0.000300, loss_epoch: 4.674
Epoch: 30, time(sec): 0.087, lr= 0.000300, loss_epoch: 4.585
Epoch: 40, time(sec): 0.108, lr= 0.000300, loss_epoch: 4.436
Epoch: 50, time(sec): 0.126, lr= 0.000300, loss_epoch: 4.419
Epoch: 60, time(sec): 0.143, lr= 0.000300, loss_epoch: 4.331
Epoch: 70, time(sec): 0.159, lr= 0.000300, loss_epoch: 4.258
Epoch: 80, time(sec): 0.177, lr= 0.000300, loss_epoch: 4.164
Epoch: 90, time(sec): 0.194, lr= 0.000300, loss_epoch: 4.067
Epoch: 100, time(sec): 0.219, lr= 0.000300, loss_epoch: 4.025


## Vanilla Self-Attention LM : Predict next token(t+1) given context = {token(<=t)}
### Aggregator of tokens is the self-attention operator : $\textrm{softmax}( HH^T / \sqrt{d} ) H$



In [None]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

class VSA_LM(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
    def forward(self, batch_seq):
        batch_size = batch_seq.size(0); batch_len = batch_seq.size(1)
        H = self.token2vec(batch_seq) # size=[batch_size, batch_length, d]
        attention_score = H @ H.transpose(2,1) * H.size(2)**-0.5 # HH^T/sqrt(d), (B,L,d) @ (B,d,L) => (B,L,L), size=[batch_size, batch_length, batch_length)
        mask = torch.tril(torch.ones(batch_len,batch_len)).long() # mask to use previous tokens only : { token(<=t) }, size=[batch_len,batch_len]
        attention_score = attention_score.masked_fill(mask==0, value=float('-inf')) # softmax(-inf)=0 prevents using next tokens for prediction, size=(batch_size, batch_len, batch_len)
        attention_score = torch.softmax(attention_score, dim=2) # sum weights = 1, size=[batch_size, batch_length, batch_len)
        batch_seq_vec = attention_score @ H # softmax( HH^T / sqrt(d) ) H, (B,L,L) @ (B,L,d) => (B,L,d), size=[batch_size, batch_length, d)
        batch_scores = self.token_prediction(batch_seq_vec) # size=[batch_size, batch_length, num_tokens]
        return batch_scores # return prediction scores for next token

# batching parameters
seq_len = seq.size(0) # length of the long sequence
batch_size = 5; batch_length = 20 # bebug
num_subseq = seq_len // batch_length # number of subsequences
num_batch = seq_len // (batch_size * batch_length) # number of batches
start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # new starting index at each new epoch, random integer in {0,batch_length-1}
list_batch_idx = torch.arange(num_batch) # list of batch indices, [0,1,...,num_batch-1]
print('seq_len: %d, batch_size: %d, batch_length: %d, num_subseq: %d, num_batch: %d\n' % (seq_len, batch_size, batch_length, num_subseq, num_batch) )

# network parameters
d = 128 # embedding dimension
print('num_tokens: %d, d: %d\n' % (num_tokens, d) )
VSA_LMnet = VSA_LM(num_tokens, d)
num_param = number_param(VSA_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# Train network to predict next token
optimizer = torch.optim.AdamW(VSA_LMnet.parameters(), lr=3e-4) # standard optimizer for LMs
num_epochs = 101 # 101(debug), number of epochs
start = time.time()
for epoch in range(num_epochs): # number of epochs
    list_batch_idx = torch.arange(num_subseq-1) # list of batch indices
    start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # size=[1]
    running_loss = 0.0 # tracking total loss value
    for _ in range(num_batch): # number of batches into one epoch
        batch_seq, target_seq, list_batch_idx = get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx) # generate a batch of subsequences
        batch_scores = VSA_LMnet(batch_seq) # size=[batch_size, batch_length, num_tokens]
        loss = nn.CrossEntropyLoss()(batch_scores.view(batch_scores.size(0)*batch_length, num_tokens), target_seq.view(batch_scores.size(0)*batch_length)) # classification loss over dict of tokens
        running_loss += loss.detach().cpu().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_epoch = running_loss / num_batch
    if not epoch%10:
        print('Epoch: %d, time(sec): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, time.time()-start, optimizer.param_groups[0]['lr'], loss_epoch) )


seq_len: 100, batch_size: 5, batch_length: 20, num_subseq: 5, num_batch: 1

num_tokens: 128, d: 128

num_net_parameters: 32896 / 0.03 million

Epoch: 0, time(sec): 0.004, lr= 0.000300, loss_epoch: 4.982
Epoch: 10, time(sec): 0.040, lr= 0.000300, loss_epoch: 4.766
Epoch: 20, time(sec): 0.062, lr= 0.000300, loss_epoch: 4.504
Epoch: 30, time(sec): 0.089, lr= 0.000300, loss_epoch: 4.216
Epoch: 40, time(sec): 0.106, lr= 0.000300, loss_epoch: 3.969
Epoch: 50, time(sec): 0.121, lr= 0.000300, loss_epoch: 3.732
Epoch: 60, time(sec): 0.136, lr= 0.000300, loss_epoch: 3.560
Epoch: 70, time(sec): 0.152, lr= 0.000300, loss_epoch: 3.354
Epoch: 80, time(sec): 0.167, lr= 0.000300, loss_epoch: 3.117
Epoch: 90, time(sec): 0.183, lr= 0.000300, loss_epoch: 2.898
Epoch: 100, time(sec): 0.198, lr= 0.000300, loss_epoch: 2.765


## Standard Self-Attention LM : Predict next token(t+1) given context = {token(<=t)}
### Aggregator of tokens is the self-attention operator : $\textrm{softmax}( QK^T / \sqrt{d} ) V$
### with learnable dictionary Q=Query, K=Key, V=Value


In [None]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

class SA_LM(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.query = nn.Linear(d, d, bias=False) # query embedding layer
        self.key = nn.Linear(d, d, bias=False) # key embedding layer
        self.value = nn.Linear(d, d) # value embedding layer
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
    def forward(self, batch_seq):
        batch_size = batch_seq.size(0); batch_len = batch_seq.size(1)
        H = self.token2vec(batch_seq) # size=[batch_size, batch_length, d]
        Q = self.query(H) # size=[batch_size, batch_length, d]
        K = self.key(H) # size=[batch_size, batch_length, d]
        V = self.value(H) # size=[batch_size, batch_length, d]
        attention_score = Q @ K.transpose(2,1) * H.size(2)**-0.5 # QK^T/sqrt(d), (B,L,d) @ (B,d,L) => (B,L,L), size=[batch_size, batch_length, batch_length)
        mask = torch.tril(torch.ones(batch_len,batch_len)).long() # mask to use previous tokens only : { token(<=t) }, size=[batch_len,batch_len]
        attention_score = attention_score.masked_fill(mask==0, value=float('-inf')) # softmax(-inf)=0 prevents using next tokens for prediction, size=(batch_size, batch_len, batch_len)
        attention_score = torch.softmax(attention_score, dim=2) # sum weights = 1, size=[batch_size, batch_length, batch_len)
        batch_seq_vec = attention_score @ V # softmax( QK^T / sqrt(d) ) V, (B,L,L) @ (B,L,d) => (B,L,d), size=[batch_size, batch_length, d)
        batch_scores = self.token_prediction(batch_seq_vec) # size=[batch_size, batch_length, num_tokens]
        return batch_scores # return prediction scores for next token

# batching parameters
seq_len = seq.size(0) # length of the long sequence
batch_size = 5; batch_length = 20 # bebug
num_subseq = seq_len // batch_length # number of subsequences
num_batch = seq_len // (batch_size * batch_length) # number of batches
start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # new starting index at each new epoch, random integer in {0,batch_length-1}
list_batch_idx = torch.arange(num_batch) # list of batch indices, [0,1,...,num_batch-1]
print('seq_len: %d, batch_size: %d, batch_length: %d, num_subseq: %d, num_batch: %d\n' % (seq_len, batch_size, batch_length, num_subseq, num_batch) )

# network parameters
d = 128 # embedding dimension
print('num_tokens: %d, d: %d\n' % (num_tokens, d) )
SA_LMnet = SA_LM(num_tokens, d)
num_param = number_param(SA_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# Train network to predict next token
optimizer = torch.optim.AdamW(SA_LMnet.parameters(), lr=3e-4) # standard optimizer for LMs
num_epochs = 101 # 101(debug), number of epochs
start = time.time()
for epoch in range(num_epochs): # number of epochs
    list_batch_idx = torch.arange(num_subseq-1) # list of batch indices
    start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # size=[1]
    running_loss = 0.0 # tracking total loss value
    for _ in range(num_batch): # number of batches into one epoch
        batch_seq, target_seq, list_batch_idx = get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx) # generate a batch of subsequences
        batch_scores = SA_LMnet(batch_seq) # size=[batch_size, batch_length, num_tokens]
        loss = nn.CrossEntropyLoss()(batch_scores.view(batch_scores.size(0)*batch_length, num_tokens), target_seq.view(batch_scores.size(0)*batch_length)) # classification loss over dict of tokens
        running_loss += loss.detach().cpu().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_epoch = running_loss / num_batch
    if not epoch%10:
        print('Epoch: %d, time(sec): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, time.time()-start, optimizer.param_groups[0]['lr'], loss_epoch) )


seq_len: 100, batch_size: 5, batch_length: 20, num_subseq: 5, num_batch: 1

num_tokens: 128, d: 128

num_net_parameters: 82176 / 0.08 million

Epoch: 0, time(sec): 0.005, lr= 0.000300, loss_epoch: 4.836
Epoch: 10, time(sec): 0.049, lr= 0.000300, loss_epoch: 4.691
Epoch: 20, time(sec): 0.083, lr= 0.000300, loss_epoch: 4.409
Epoch: 30, time(sec): 0.106, lr= 0.000300, loss_epoch: 4.104
Epoch: 40, time(sec): 0.129, lr= 0.000300, loss_epoch: 3.859
Epoch: 50, time(sec): 0.152, lr= 0.000300, loss_epoch: 3.456
Epoch: 60, time(sec): 0.177, lr= 0.000300, loss_epoch: 3.192
Epoch: 70, time(sec): 0.201, lr= 0.000300, loss_epoch: 2.951
Epoch: 80, time(sec): 0.233, lr= 0.000300, loss_epoch: 2.718
Epoch: 90, time(sec): 0.261, lr= 0.000300, loss_epoch: 2.662
Epoch: 100, time(sec): 0.283, lr= 0.000300, loss_epoch: 2.328


## Self-Attention LM : Predict next token(t+1) given context = {token(<=t)}
## Add positional encoding (PE) to self-attention / (single) attention head
###  PE is required to add ordering information to token


In [None]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

# single head attention layer
class head_attention(nn.Module):
    def __init__(self, d, context_length):
        super().__init__()
        self.query = nn.Linear(d, d, bias=False) # query embedding layer
        self.key = nn.Linear(d, d, bias=False) # key embedding layer
        self.value = nn.Linear(d, d) # value embedding layer
        self.mask = torch.tril(torch.ones(context_length, context_length)).long() # mask to use previous tokens only : { token(<=t) }, size=[context_length, context_length]
    def forward(self, H):
        Q = self.query(H) # size=[batch_size, batch_length, d]
        K = self.key(H) # size=[batch_size, batch_length, d]
        V = self.value(H) # size=[batch_size, batch_length, d]
        attention_score = Q @ K.transpose(2,1) * H.size(2)**-0.5 # QK^T/sqrt(d), (B,L,d) @ (B,d,L) => (B,L,L), size=[batch_size, batch_length, batch_length)
        attention_score = attention_score.masked_fill(self.mask==0, value=float('-inf')) # softmax(-inf)=0 prevents using next tokens for prediction, size=(batch_size, batch_len, batch_len)
        attention_score = torch.softmax(attention_score, dim=2) # sum weights = 1, size=[batch_size, batch_length, batch_len]
        H_head = attention_score @ V # softmax( QK^T / sqrt(d) ) V, (B,L,L) @ (B,L,d) => (B,L,d), size=[batch_size, batch_length, d]
        return H_head

class PE_LM(nn.Module):
    def __init__(self, num_tokens, d, context_length):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.PE_embedding = nn.Embedding(context_length, d) # positional encoding embedding layer
        self.HA = head_attention(d, context_length) # self-attention layer
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
    def forward(self, batch_seq):
        seq_pos_encoding = torch.arange(batch_seq.size(1)) # positional encoding = {0,1,2,...,batch_length-1}
        H = self.token2vec(batch_seq) + self.PE_embedding(seq_pos_encoding).unsqueeze(0) # size=[batch_size, batch_length, d]
        batch_seq_vec = self.HA(H) # (single) attention head, size=[batch_size, batch_length, d]
        batch_scores = self.token_prediction(batch_seq_vec) # size=[batch_size, batch_length, num_tokens]
        return batch_scores # return prediction scores for next token

# batching parameters
seq_len = seq.size(0) # length of the long sequence
batch_size = 5; batch_length = 20 # bebug
num_subseq = seq_len // batch_length # number of subsequences
num_batch = seq_len // (batch_size * batch_length) # number of batches
start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # new starting index at each new epoch, random integer in {0,batch_length-1}
list_batch_idx = torch.arange(num_batch) # list of batch indices, [0,1,...,num_batch-1]
print('seq_len: %d, batch_size: %d, batch_length: %d, num_subseq: %d, num_batch: %d\n' % (seq_len, batch_size, batch_length, num_subseq, num_batch) )

# network parameters
d = 128 # embedding dimension
print('num_tokens: %d, d: %d, batch_length: %d\n' % (num_tokens, d, batch_length) )
PE_LMnet = PE_LM(num_tokens, d, batch_length)
num_param = number_param(PE_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# Train network to predict next token
optimizer = torch.optim.AdamW(PE_LMnet.parameters(), lr=3e-4) # standard optimizer for LMs
num_epochs = 101 # 101(debug), number of epochs
start = time.time()
for epoch in range(num_epochs): # number of epochs
    list_batch_idx = torch.arange(num_subseq-1) # list of batch indices
    start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # size=[1]
    running_loss = 0.0 # tracking total loss value
    for _ in range(num_batch): # number of batches into one epoch
        batch_seq, target_seq, list_batch_idx = get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx) # generate a batch of subsequences
        batch_scores = PE_LMnet(batch_seq) # size=[batch_size, batch_length, num_tokens]
        loss = nn.CrossEntropyLoss()(batch_scores.view(batch_scores.size(0)*batch_length, num_tokens), target_seq.view(batch_scores.size(0)*batch_length)) # classification loss over dict of tokens
        running_loss += loss.detach().cpu().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_epoch = running_loss / num_batch
    if not epoch%10:
        print('Epoch: %d, time(sec): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, time.time()-start, optimizer.param_groups[0]['lr'], loss_epoch) )


seq_len: 100, batch_size: 5, batch_length: 20, num_subseq: 5, num_batch: 1

num_tokens: 128, d: 128, batch_length: 20

num_net_parameters: 84736 / 0.08 million

Epoch: 0, time(sec): 0.007, lr= 0.000300, loss_epoch: 4.904
Epoch: 10, time(sec): 0.053, lr= 0.000300, loss_epoch: 4.706
Epoch: 20, time(sec): 0.086, lr= 0.000300, loss_epoch: 4.494
Epoch: 30, time(sec): 0.112, lr= 0.000300, loss_epoch: 4.179
Epoch: 40, time(sec): 0.136, lr= 0.000300, loss_epoch: 3.960
Epoch: 50, time(sec): 0.160, lr= 0.000300, loss_epoch: 3.759
Epoch: 60, time(sec): 0.183, lr= 0.000300, loss_epoch: 3.465
Epoch: 70, time(sec): 0.208, lr= 0.000300, loss_epoch: 3.106
Epoch: 80, time(sec): 0.235, lr= 0.000300, loss_epoch: 3.159
Epoch: 90, time(sec): 0.264, lr= 0.000300, loss_epoch: 2.797
Epoch: 100, time(sec): 0.286, lr= 0.000300, loss_epoch: 2.823


## LM with Multiple Attention Heads  


In [None]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

# single head attention layer
class head_attention(nn.Module):
    def __init__(self, d, d_head, context_length):
        super().__init__()
        self.query = nn.Linear(d, d_head, bias=False) # query embedding layer
        self.key = nn.Linear(d, d_head, bias=False) # key embedding layer
        self.value = nn.Linear(d, d_head) # value embedding layer
        self.mask = torch.tril(torch.ones(context_length, context_length)).long() # mask to use previous tokens only : { token(<=t) }, size=[context_length, context_length]
    def forward(self, H):
        Q = self.query(H) # size=[batch_size, batch_length, d_head]
        K = self.key(H) # size=[batch_size, batch_length, d_head]
        V = self.value(H) # size=[batch_size, batch_length, d_head]
        attention_score = Q @ K.transpose(2,1) * H.size(2)**-0.5 # QK^T/sqrt(d_head), (B,L,d_head) @ (B,d_head,L) => (B,L,L), size=[batch_size, batch_length, batch_length)
        attention_score = attention_score.masked_fill(self.mask==0, value=float('-inf')) # softmax(-inf)=0 prevents using next tokens for prediction, size=(batch_size, batch_len, batch_len)
        attention_score = torch.softmax(attention_score, dim=2) # sum weights = 1, size=[batch_size, batch_length, batch_len]
        H_head = attention_score @ V # softmax( QK^T / sqrt(d_head) ) V, (B,L,L) @ (B,L,d) => (B,L,d), size=[batch_size, batch_length, d_head]
        return H_head

# multiple attention heads layer
class multiple_head_attention(nn.Module):
    def __init__(self, d, context_length, num_heads):
        super().__init__()
        d_head = d // num_heads # dim_head = d / num_heads, usually dimension per head is 64
        assert d == d_head * num_heads # check divisibility
        self.MHA = nn.ModuleList([ head_attention(d, d_head, context_length) for _ in range(num_heads) ])
        self.combined_heads = nn.Linear(d, d) # combination layer
    def forward(self, H):
        H_heads = []
        for HA_layer in self.MHA:
            H_heads.append(HA_layer(H)) # size=[batch_size, batch_length, d_head]
        H_heads = torch.cat(H_heads, dim=2) # size=[batch_size, batch_length, d]
        H_heads = self.combined_heads(H_heads) # size=[batch_size, batch_length, d]
        return H_heads

class MHA_LM(nn.Module):
    def __init__(self, num_tokens, d, context_length, num_heads):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.PE_embedding = nn.Embedding(context_length, d) # positional encoding embedding layer
        self.MHA = multiple_head_attention(d, context_length, num_heads) # multiple self-attention layers
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
    def forward(self, batch_seq):
        seq_pos_encoding = torch.arange(batch_seq.size(1)) # positional encoding = {0,1,2,...,batch_length-1}
        H = self.token2vec(batch_seq) + self.PE_embedding(seq_pos_encoding).unsqueeze(0) # size=[batch_size, batch_length, d]
        batch_seq_vec = self.MHA(H) # (single) attention head, size=[batch_size, batch_length, d]
        batch_scores = self.token_prediction(batch_seq_vec) # size=[batch_size, batch_length, num_tokens]
        return batch_scores # return prediction scores for next token

# batching parameters
seq_len = seq.size(0) # length of the long sequence
batch_size = 5; batch_length = 20 # bebug
num_subseq = seq_len // batch_length # number of subsequences
num_batch = seq_len // (batch_size * batch_length) # number of batches
start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # new starting index at each new epoch, random integer in {0,batch_length-1}
list_batch_idx = torch.arange(num_batch) # list of batch indices, [0,1,...,num_batch-1]
print('seq_len: %d, batch_size: %d, batch_length: %d, num_subseq: %d, num_batch: %d\n' % (seq_len, batch_size, batch_length, num_subseq, num_batch) )

# network parameters
d = 128 # embedding dimension
num_heads = 16
print('num_tokens: %d, d: %d, batch_length: %d, num_heads: %d\n' % (num_tokens, d, batch_length, num_heads) )
MHA_LMnet = MHA_LM(num_tokens, d, batch_length, num_heads)
num_param = number_param(MHA_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# Train network to predict next token
optimizer = torch.optim.AdamW(MHA_LMnet.parameters(), lr=3e-4) # standard optimizer for LMs
num_epochs = 101 # 101(debug), number of epochs
start = time.time()
for epoch in range(num_epochs): # number of epochs
    list_batch_idx = torch.arange(num_subseq-1) # list of batch indices
    start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # size=[1]
    running_loss = 0.0 # tracking total loss value
    for _ in range(num_batch): # number of batches into one epoch
        batch_seq, target_seq, list_batch_idx = get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx) # generate a batch of subsequences
        batch_scores = MHA_LMnet(batch_seq) # size=[batch_size, batch_length, num_tokens]
        loss = nn.CrossEntropyLoss()(batch_scores.view(batch_scores.size(0)*batch_length, num_tokens), target_seq.view(batch_scores.size(0)*batch_length)) # classification loss over dict of tokens
        running_loss += loss.detach().cpu().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_epoch = running_loss / num_batch
    if not epoch%10:
        print('Epoch: %d, time(sec): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, time.time()-start, optimizer.param_groups[0]['lr'], loss_epoch) )


seq_len: 100, batch_size: 5, batch_length: 20, num_subseq: 5, num_batch: 1

num_tokens: 128, d: 128, batch_length: 20, num_heads: 16

num_net_parameters: 101248 / 0.10 million

Epoch: 0, time(sec): 0.027, lr= 0.000300, loss_epoch: 4.851
Epoch: 10, time(sec): 0.163, lr= 0.000300, loss_epoch: 4.662
Epoch: 20, time(sec): 0.301, lr= 0.000300, loss_epoch: 4.352
Epoch: 30, time(sec): 0.426, lr= 0.000300, loss_epoch: 4.190
Epoch: 40, time(sec): 0.564, lr= 0.000300, loss_epoch: 3.760
Epoch: 50, time(sec): 0.698, lr= 0.000300, loss_epoch: 3.664
Epoch: 60, time(sec): 0.829, lr= 0.000300, loss_epoch: 3.652
Epoch: 70, time(sec): 0.948, lr= 0.000300, loss_epoch: 3.366
Epoch: 80, time(sec): 1.072, lr= 0.000300, loss_epoch: 3.095
Epoch: 90, time(sec): 1.199, lr= 0.000300, loss_epoch: 3.087
Epoch: 100, time(sec): 1.323, lr= 0.000300, loss_epoch: 2.884


## LM with (single) Transformer Block
## Add residual connection (RC) + layer normalization (LN) + dropout + feedforward / MLP


In [None]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

# single head attention layer
class head_attention(nn.Module):
    def __init__(self, d, d_head, context_length, dropout):
        super().__init__()
        self.query = nn.Linear(d, d_head, bias=False) # query embedding layer
        self.key = nn.Linear(d, d_head, bias=False) # key embedding layer
        self.value = nn.Linear(d, d_head) # value embedding layer
        self.mask = torch.tril(torch.ones(context_length, context_length)).long() # mask to use previous tokens only : { token(<=t) }, size=[context_length, context_length]
        self.dropout = nn.Dropout(dropout)
    def forward(self, H):
        Q = self.query(H) # size=[batch_size, batch_length, d_head]
        K = self.key(H) # size=[batch_size, batch_length, d_head]
        V = self.value(H) # size=[batch_size, batch_length, d_head]
        attention_score = Q @ K.transpose(2,1) * H.size(2)**-0.5 # QK^T/sqrt(d_head), (B,L,d_head) @ (B,d_head,L) => (B,L,L), size=[batch_size, batch_length, batch_length)
        attention_score = attention_score.masked_fill(self.mask==0, value=float('-inf')) # softmax(-inf)=0 prevents using next tokens for prediction, size=(batch_size, batch_len, batch_len)
        attention_score = torch.softmax(attention_score, dim=2) # sum weights = 1, size=[batch_size, batch_length, batch_len]
        attention_score = self.dropout(attention_score) # dropout attention scores
        H_head = attention_score @ V # softmax( QK^T / sqrt(d_head) ) V, (B,L,L) @ (B,L,d) => (B,L,d), size=[batch_size, batch_length, d_head]
        return H_head

# multiple attention heads layer
class multiple_head_attention(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        d_head = d // num_heads # dim_head = d / num_heads, usually dimension per head is 64
        assert d == d_head * num_heads # check divisibility
        self.MHA = nn.ModuleList([ head_attention(d, d_head, context_length, dropout) for _ in range(num_heads) ])
        self.combined_heads = nn.Linear(d, d) # combination layer
        self.dropout = nn.Dropout(dropout)
    def forward(self, H):
        H_heads = []
        for HA_layer in self.MHA:
            H_heads.append(HA_layer(H)) # size=[batch_size, batch_length, d_head]
        H_heads = torch.cat(H_heads, dim=2) # size=[batch_size, batch_length, d]
        H_heads = self.dropout(H_heads) # dropout attention activations
        H_heads = self.combined_heads(H_heads) # size=[batch_size, batch_length, d]
        return H_heads

# Transformer block layer
class TransformerBlock(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        self.MHA = multiple_head_attention(d, context_length, num_heads, dropout)
        self.LN_MHA = nn.LayerNorm(d)
        self.MLP = nn.Sequential(nn.Linear(d,4*d), nn.ReLU(), nn.Dropout(dropout), nn.Linear(4*d,d))
        self.LN_MLP = nn.LayerNorm(d)
    def forward(self, H):
        H = H + self.MHA(self.LN_MHA(H)) # size=[batch_size, batch_length, d]
        H = H + self.MLP(self.LN_MLP(H)) # size=[batch_size, batch_length, d]
        return H

class TB_LM(nn.Module):
    def __init__(self, num_tokens, d, context_length, num_heads, dropout):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.PE_embedding = nn.Embedding(context_length, d) # positional encoding embedding layer
        self.TB = TransformerBlock(d, context_length, num_heads, dropout) # transformer block layer
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
    def forward(self, batch_seq):
        seq_pos_encoding = torch.arange(batch_seq.size(1)) # positional encoding = {0,1,2,...,batch_length-1}
        H = self.token2vec(batch_seq) + self.PE_embedding(seq_pos_encoding).unsqueeze(0) # size=[batch_size, batch_length, d]
        batch_seq_vec = self.TB(H) # (single) transformer block, size=[batch_size, batch_length, d]
        batch_scores = self.token_prediction(batch_seq_vec) # size=[batch_size, batch_length, num_tokens]
        return batch_scores # return prediction scores for next token

# batching parameters
seq_len = seq.size(0) # length of the long sequence
batch_size = 5; batch_length = 20 # bebug
num_subseq = seq_len // batch_length # number of subsequences
num_batch = seq_len // (batch_size * batch_length) # number of batches
start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # new starting index at each new epoch, random integer in {0,batch_length-1}
list_batch_idx = torch.arange(num_batch) # list of batch indices, [0,1,...,num_batch-1]
print('seq_len: %d, batch_size: %d, batch_length: %d, num_subseq: %d, num_batch: %d\n' % (seq_len, batch_size, batch_length, num_subseq, num_batch) )

# network parameters
d = 128 # embedding dimension
num_heads = 16
dropout = 0.1
print('num_tokens: %d, d: %d, batch_length: %d, num_heads: %d, dropout: %.2f\n' % (num_tokens, d, batch_length, num_heads, dropout) )
TB_LMnet = TB_LM(num_tokens, d, batch_length, num_heads, dropout)
num_param = number_param(TB_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# Train network to predict next token
optimizer = torch.optim.AdamW(TB_LMnet.parameters(), lr=3e-4) # standard optimizer for LMs
num_epochs = 101 # 101(debug), number of epochs
start = time.time()
for epoch in range(num_epochs): # number of epochs
    list_batch_idx = torch.arange(num_subseq-1) # list of batch indices
    start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # size=[1]
    running_loss = 0.0 # tracking total loss value
    for _ in range(num_batch): # number of batches into one epoch
        batch_seq, target_seq, list_batch_idx = get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx) # generate a batch of subsequences
        batch_scores = TB_LMnet(batch_seq) # size=[batch_size, batch_length, num_tokens]
        loss = nn.CrossEntropyLoss()(batch_scores.view(batch_scores.size(0)*batch_length, num_tokens), target_seq.view(batch_scores.size(0)*batch_length)) # classification loss over dict of tokens
        running_loss += loss.detach().cpu().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_epoch = running_loss / num_batch
    if not epoch%10:
        print('Epoch: %d, time(sec): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, time.time()-start, optimizer.param_groups[0]['lr'], loss_epoch) )


seq_len: 100, batch_size: 5, batch_length: 20, num_subseq: 5, num_batch: 1

num_tokens: 128, d: 128, batch_length: 20, num_heads: 16, dropout: 0.10

num_net_parameters: 233472 / 0.23 million

Epoch: 0, time(sec): 0.063, lr= 0.000300, loss_epoch: 5.246
Epoch: 10, time(sec): 0.285, lr= 0.000300, loss_epoch: 4.359
Epoch: 20, time(sec): 0.439, lr= 0.000300, loss_epoch: 3.695
Epoch: 30, time(sec): 0.610, lr= 0.000300, loss_epoch: 3.433
Epoch: 40, time(sec): 0.784, lr= 0.000300, loss_epoch: 2.746
Epoch: 50, time(sec): 0.951, lr= 0.000300, loss_epoch: 2.358
Epoch: 60, time(sec): 1.130, lr= 0.000300, loss_epoch: 2.213
Epoch: 70, time(sec): 1.300, lr= 0.000300, loss_epoch: 1.916
Epoch: 80, time(sec): 1.460, lr= 0.000300, loss_epoch: 1.773
Epoch: 90, time(sec): 1.624, lr= 0.000300, loss_epoch: 1.489
Epoch: 100, time(sec): 1.783, lr= 0.000300, loss_epoch: 1.338


## PyTorch implementation of MHA vs. my implementation


In [None]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

# single head attention layer
class head_attention(nn.Module):
    def __init__(self, d, d_head, context_length, dropout):
        super().__init__()
        self.query = nn.Linear(d, d_head, bias=False) # query embedding layer
        self.key = nn.Linear(d, d_head, bias=False) # key embedding layer
        self.value = nn.Linear(d, d_head) # value embedding layer
        self.mask = torch.tril(torch.ones(context_length, context_length)).long() # mask to use previous tokens only : { token(<=t) }, size=[context_length, context_length]
        self.dropout = nn.Dropout(dropout)
    def forward(self, H):
        Q = self.query(H) # size=[batch_size, batch_length, d_head]
        K = self.key(H) # size=[batch_size, batch_length, d_head]
        V = self.value(H) # size=[batch_size, batch_length, d_head]
        attention_score = Q @ K.transpose(2,1) * H.size(2)**-0.5 # QK^T/sqrt(d_head), (B,L,d_head) @ (B,d_head,L) => (B,L,L), size=[batch_size, batch_length, batch_length)
        attention_score = attention_score.masked_fill(self.mask==0, value=float('-inf')) # softmax(-inf)=0 prevents using next tokens for prediction, size=(batch_size, batch_len, batch_len)
        attention_score = torch.softmax(attention_score, dim=2) # sum weights = 1, size=[batch_size, batch_length, batch_len]
        attention_score = self.dropout(attention_score) # dropout attention scores
        H_head = attention_score @ V # softmax( QK^T / sqrt(d_head) ) V, (B,L,L) @ (B,L,d) => (B,L,d), size=[batch_size, batch_length, d_head]
        return H_head

# # multiple attention heads layer -- my implementation
# class multiple_head_attention(nn.Module):
#     def __init__(self, d, context_length, num_heads, dropout):
#         super().__init__()
#         d_head = d // num_heads # dim_head = d / num_heads, usually dimension per head is 64
#         assert d == d_head * num_heads # check divisibility
#         self.MHA = nn.ModuleList([ head_attention(d, d_head, context_length, dropout) for _ in range(num_heads) ])
#         self.combined_heads = nn.Linear(d, d) # combination layer
#         self.dropout = nn.Dropout(dropout)
#     def forward(self, H):
#         H_heads = []
#         for HA_layer in self.MHA:
#             H_heads.append(HA_layer(H)) # size=[batch_size, batch_length, d_head]
#         H_heads = torch.cat(H_heads, dim=2) # size=[batch_size, batch_length, d]
#         H_heads = self.dropout(H_heads) # dropout attention activations
#         H_heads = self.combined_heads(H_heads) # size=[batch_size, batch_length, d]
#         return H_heads

# multiple attention heads layer -- PyTorch implementation
class multiple_head_attention(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        d_head = d // num_heads
        assert d == d_head * num_heads # check divisiblity
        self.MHA = nn.MultiheadAttention(d, num_heads, batch_first=True, dropout=dropout)
        self.mask = torch.tril(torch.ones(context_length, context_length))==0 # mask to make attention to previous tokens only : { token(<=t) }, size=(context_length,context_length)
                   # torch.tril(ones) = True in the up-right part, True means *no* attention allowed in pytorch implementation
    def forward(self, H):
        H_heads = self.MHA(H, H, H, attn_mask=self.mask)[0] # size=[batch_size, batch_length, d]
        return H_heads

# Transformer block layer
class TransformerBlock(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        self.MHA = multiple_head_attention(d, context_length, num_heads, dropout)
        self.LN_MHA = nn.LayerNorm(d)
        self.MLP = nn.Sequential(nn.Linear(d,4*d), nn.ReLU(), nn.Dropout(dropout), nn.Linear(4*d,d))
        self.LN_MLP = nn.LayerNorm(d)
    def forward(self, H):
        H = H + self.MHA(self.LN_MHA(H)) # size=[batch_size, batch_length, d]
        H = H + self.MLP(self.LN_MLP(H)) # size=[batch_size, batch_length, d]
        return H

class TBpytorch_LM(nn.Module):
    def __init__(self, num_tokens, d, context_length, num_heads, dropout):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.PE_embedding = nn.Embedding(context_length, d) # positional encoding embedding layer
        self.TB = TransformerBlock(d, context_length, num_heads, dropout) # transformer block layer
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
    def forward(self, batch_seq):
        seq_pos_encoding = torch.arange(batch_seq.size(1)) # positional encoding = {0,1,2,...,batch_length-1}
        H = self.token2vec(batch_seq) + self.PE_embedding(seq_pos_encoding).unsqueeze(0) # size=[batch_size, batch_length, d]
        batch_seq_vec = self.TB(H) # (single) transformer block, size=[batch_size, batch_length, d]
        batch_scores = self.token_prediction(batch_seq_vec) # size=[batch_size, batch_length, num_tokens]
        return batch_scores # return prediction scores for next token

# batching parameters
seq_len = seq.size(0) # length of the long sequence
batch_size = 5; batch_length = 20 # bebug
num_subseq = seq_len // batch_length # number of subsequences
num_batch = seq_len // (batch_size * batch_length) # number of batches
start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # new starting index at each new epoch, random integer in {0,batch_length-1}
list_batch_idx = torch.arange(num_batch) # list of batch indices, [0,1,...,num_batch-1]
print('seq_len: %d, batch_size: %d, batch_length: %d, num_subseq: %d, num_batch: %d\n' % (seq_len, batch_size, batch_length, num_subseq, num_batch) )

# network parameters
d = 128 # embedding dimension
num_heads = 16
dropout = 0.1
print('num_tokens: %d, d: %d, batch_length: %d, num_heads: %d, dropout: %.2f\n' % (num_tokens, d, batch_length, num_heads, dropout) )
TB_LMnet = TB_LM(num_tokens, d, batch_length, num_heads, dropout)
num_param = number_param(TB_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# Train network to predict next token
optimizer = torch.optim.AdamW(TB_LMnet.parameters(), lr=3e-4) # standard optimizer for LMs
num_epochs = 101 # 101(debug), number of epochs
start = time.time()
for epoch in range(num_epochs): # number of epochs
    list_batch_idx = torch.arange(num_subseq-1) # list of batch indices
    start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # size=[1]
    running_loss = 0.0 # tracking total loss value
    for _ in range(num_batch): # number of batches into one epoch
        batch_seq, target_seq, list_batch_idx = get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx) # generate a batch of subsequences
        batch_scores = TB_LMnet(batch_seq) # size=[batch_size, batch_length, num_tokens]
        loss = nn.CrossEntropyLoss()(batch_scores.view(batch_scores.size(0)*batch_length, num_tokens), target_seq.view(batch_scores.size(0)*batch_length)) # classification loss over dict of tokens
        running_loss += loss.detach().cpu().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_epoch = running_loss / num_batch
    if not epoch%10:
        print('Epoch: %d, time(sec): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, time.time()-start, optimizer.param_groups[0]['lr'], loss_epoch) )

# my implementation      : Time(sec): 17.707 / loss_epoch: 1.387  => 3x slower
# pytorch implementation : Time(sec): 6.556 / loss_epoch: 0.940


seq_len: 100, batch_size: 5, batch_length: 20, num_subseq: 5, num_batch: 1

num_tokens: 128, d: 128, batch_length: 20, num_heads: 16, dropout: 0.10

num_net_parameters: 233728 / 0.23 million

Epoch: 0, time(sec): 0.013, lr= 0.000300, loss_epoch: 5.363
Epoch: 10, time(sec): 0.079, lr= 0.000300, loss_epoch: 4.508
Epoch: 20, time(sec): 0.128, lr= 0.000300, loss_epoch: 3.843
Epoch: 30, time(sec): 0.178, lr= 0.000300, loss_epoch: 3.320
Epoch: 40, time(sec): 0.228, lr= 0.000300, loss_epoch: 2.939
Epoch: 50, time(sec): 0.280, lr= 0.000300, loss_epoch: 2.553
Epoch: 60, time(sec): 0.323, lr= 0.000300, loss_epoch: 2.078
Epoch: 70, time(sec): 0.370, lr= 0.000300, loss_epoch: 1.795
Epoch: 80, time(sec): 0.417, lr= 0.000300, loss_epoch: 1.796
Epoch: 90, time(sec): 0.469, lr= 0.000300, loss_epoch: 1.385
Epoch: 100, time(sec): 0.519, lr= 0.000300, loss_epoch: 1.091


## LM with multiple Transformer Blocks


In [None]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

# multiple attention heads layer
class multiple_head_attention(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        d_head = d // num_heads
        assert d == d_head * num_heads # check divisiblity
        self.MHA = nn.MultiheadAttention(d, num_heads, batch_first=True, dropout=dropout)
        self.mask = torch.tril(torch.ones(context_length, context_length))==0 # mask to make attention to previous tokens only : { token(<=t) }, size=(context_length,context_length)
                   # torch.tril(ones) = True in the up-right part, True means *no* attention allowed in pytorch implementation
    def forward(self, H):
        H_heads = self.MHA(H, H, H, attn_mask=self.mask)[0] # size=[batch_size, batch_length, d]
        return H_heads

# Transformer block layer
class TransformerBlock(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        self.MHA = multiple_head_attention(d, context_length, num_heads, dropout)
        self.LN_MHA = nn.LayerNorm(d)
        self.MLP = nn.Sequential(nn.Linear(d,4*d), nn.ReLU(), nn.Dropout(dropout), nn.Linear(4*d,d))
        self.LN_MLP = nn.LayerNorm(d)
    def forward(self, H):
        H = H + self.MHA(self.LN_MHA(H)) # size=[batch_size, batch_length, d]
        H = H + self.MLP(self.LN_MLP(H)) # size=[batch_size, batch_length, d]
        return H

class MTB_LM(nn.Module):
    def __init__(self, num_tokens, d, context_length, num_heads, dropout, num_layers):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.PE_embedding = nn.Embedding(context_length, d) # positional encoding embedding layer
        self.transformer_blocks = nn.ModuleList([ TransformerBlock(d, context_length, num_heads, dropout) for _ in range(num_layers) ]) # multiple transformer block layers
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
    def forward(self, batch_seq):
        seq_pos_encoding = torch.arange(batch_seq.size(1)) # positional encoding = {0,1,2,...,batch_length-1}
        H = self.token2vec(batch_seq) + self.PE_embedding(seq_pos_encoding).unsqueeze(0) # size=[batch_size, batch_length, d]
        for transformer_block in self.transformer_blocks:
            H = transformer_block(H) # size=[batch_size, batch_length, d]
        batch_scores = self.token_prediction(H) # size=[batch_size, batch_length, num_tokens]
        return batch_scores # return prediction scores for next token

# batching parameters
seq_len = seq.size(0) # length of the long sequence
batch_size = 5; batch_length = 20 # bebug
num_subseq = seq_len // batch_length # number of subsequences
num_batch = seq_len // (batch_size * batch_length) # number of batches
start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # new starting index at each new epoch, random integer in {0,batch_length-1}
list_batch_idx = torch.arange(num_batch) # list of batch indices, [0,1,...,num_batch-1]
print('seq_len: %d, batch_size: %d, batch_length: %d, num_subseq: %d, num_batch: %d\n' % (seq_len, batch_size, batch_length, num_subseq, num_batch) )

# network parameters
d = 128 # embedding dimension
num_heads = 16
dropout = 0.1
num_layers = 2
print('num_tokens: %d, d: %d, batch_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d\n' % (num_tokens, d, batch_length, num_heads, dropout, num_layers) )
MTB_LMnet = MTB_LM(num_tokens, d, batch_length, num_heads, dropout, num_layers)
num_param = number_param(MTB_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# Train network to predict next token
optimizer = torch.optim.AdamW(MTB_LMnet.parameters(), lr=3e-4) # standard optimizer for LMs
num_epochs = 101 # 101(debug), number of epochs
start = time.time()
for epoch in range(num_epochs): # number of epochs
    list_batch_idx = torch.arange(num_subseq-1) # list of batch indices
    start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # size=[1]
    running_loss = 0.0 # tracking total loss value
    for _ in range(num_batch): # number of batches into one epoch
        batch_seq, target_seq, list_batch_idx = get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx) # generate a batch of subsequences
        batch_scores = MTB_LMnet(batch_seq) # size=[batch_size, batch_length, num_tokens]
        loss = nn.CrossEntropyLoss()(batch_scores.view(batch_scores.size(0)*batch_length, num_tokens), target_seq.view(batch_scores.size(0)*batch_length)) # classification loss over dict of tokens
        running_loss += loss.detach().cpu().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_epoch = running_loss / num_batch
    if not epoch%10:
        print('Epoch: %d, time(sec): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, time.time()-start, optimizer.param_groups[0]['lr'], loss_epoch) )


seq_len: 100, batch_size: 5, batch_length: 20, num_subseq: 5, num_batch: 1

num_tokens: 128, d: 128, batch_length: 20, num_heads: 16, dropout: 0.10, num_layers: 2

num_net_parameters: 432000 / 0.43 million

Epoch: 0, time(sec): 0.023, lr= 0.000300, loss_epoch: 5.076
Epoch: 10, time(sec): 0.155, lr= 0.000300, loss_epoch: 4.135
Epoch: 20, time(sec): 0.265, lr= 0.000300, loss_epoch: 3.238
Epoch: 30, time(sec): 0.364, lr= 0.000300, loss_epoch: 2.606
Epoch: 40, time(sec): 0.455, lr= 0.000300, loss_epoch: 2.372
Epoch: 50, time(sec): 0.554, lr= 0.000300, loss_epoch: 1.615
Epoch: 60, time(sec): 0.651, lr= 0.000300, loss_epoch: 1.386
Epoch: 70, time(sec): 0.750, lr= 0.000300, loss_epoch: 1.056
Epoch: 80, time(sec): 0.851, lr= 0.000300, loss_epoch: 0.876
Epoch: 90, time(sec): 0.948, lr= 0.000300, loss_epoch: 0.875
Epoch: 100, time(sec): 1.044, lr= 0.000300, loss_epoch: 0.931


# Generate a new sequence of any length

In [None]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

# multiple attention heads layer
class multiple_head_attention(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        d_head = d // num_heads
        assert d == d_head * num_heads # check divisiblity
        self.MHA = nn.MultiheadAttention(d, num_heads, batch_first=True, dropout=dropout)
        self.mask = torch.tril(torch.ones(context_length, context_length))==0 # mask to make attention to previous tokens only : { token(<=t) }, size=(context_length,context_length)
                   # torch.tril(ones) = True in the up-right part, True means *no* attention allowed in pytorch implementation
        self.context_length = context_length
    def forward(self, H):
        if H.size(1) == self.context_length: # training <==
            attn_mask = self.mask
        else: # when batch_length not= context_length, e.g. inference time / sequence generation <==
            current_batch_length = H.size(1)
            attn_mask = torch.tril(torch.ones(current_batch_length, current_batch_length))==0
        H_heads = self.MHA(H, H, H, attn_mask=attn_mask)[0] # pytorch implementation, size=[batch_size, batch_length, d]
        return H_heads

# Transformer block layer
class TransformerBlock(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        self.MHA = multiple_head_attention(d, context_length, num_heads, dropout)
        self.LN_MHA = nn.LayerNorm(d)
        self.MLP = nn.Sequential(nn.Linear(d,4*d), nn.ReLU(), nn.Dropout(dropout), nn.Linear(4*d,d))
        self.LN_MLP = nn.LayerNorm(d)
    def forward(self, H):
        H = H + self.MHA(self.LN_MHA(H)) # size=[batch_size, batch_length, d]
        H = H + self.MLP(self.LN_MLP(H)) # size=[batch_size, batch_length, d]
        return H

class GEN_LM(nn.Module):
    def __init__(self, num_tokens, d, context_length, num_heads, dropout, num_layers):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.PE_embedding = nn.Embedding(context_length, d) # positional encoding embedding layer
        self.transformer_blocks = nn.ModuleList([ TransformerBlock(d, context_length, num_heads, dropout) for _ in range(num_layers) ]) # multiple transformer block layers
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
    def forward(self, batch_seq):
        seq_pos_encoding = torch.arange(batch_seq.size(1)) # positional encoding = {0,1,2,...,batch_length-1}
        H = self.token2vec(batch_seq) + self.PE_embedding(seq_pos_encoding).unsqueeze(0) # size=[batch_size, batch_length, d]
        for transformer_block in self.transformer_blocks:
            H = transformer_block(H) # size=[batch_size, batch_length, d]
        batch_scores = self.token_prediction(H) # size=[batch_size, batch_length, num_tokens]
        return batch_scores # return prediction scores for next token

# batching parameters
seq_len = seq.size(0) # length of the long sequence
batch_size = 5; batch_length = 20 # bebug
num_subseq = seq_len // batch_length # number of subsequences
num_batch = seq_len // (batch_size * batch_length) # number of batches
start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # new starting index at each new epoch, random integer in {0,batch_length-1}
list_batch_idx = torch.arange(num_batch) # list of batch indices, [0,1,...,num_batch-1]
print('seq_len: %d, batch_size: %d, batch_length: %d, num_subseq: %d, num_batch: %d\n' % (seq_len, batch_size, batch_length, num_subseq, num_batch) )

# network parameters
d = 128 # embedding dimension
num_heads = 16
dropout = 0.1
num_layers = 2
print('num_tokens: %d, d: %d, batch_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d\n' % (num_tokens, d, batch_length, num_heads, dropout, num_layers) )
GEN_LMnet = GEN_LM(num_tokens, d, batch_length, num_heads, dropout, num_layers)
num_param = number_param(GEN_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# Train network to predict next token
optimizer = torch.optim.AdamW(GEN_LMnet.parameters(), lr=3e-4) # standard optimizer for LMs
num_epochs = 11 # 101(debug), number of epochs
start = time.time()
for epoch in range(num_epochs): # number of epochs
    list_batch_idx = torch.arange(num_subseq-1) # list of batch indices
    start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # size=[1]
    running_loss = 0.0 # tracking total loss value
    for _ in range(num_batch): # number of batches into one epoch
        batch_seq, target_seq, list_batch_idx = get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx) # generate a batch of subsequences
        batch_scores = GEN_LMnet(batch_seq) # size=[batch_size, batch_length, num_tokens]
        loss = nn.CrossEntropyLoss()(batch_scores.view(batch_scores.size(0)*batch_length, num_tokens), target_seq.view(batch_scores.size(0)*batch_length)) # classification loss over dict of tokens
        running_loss += loss.detach().cpu().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_epoch = running_loss / num_batch
    if not epoch%10:
        print('Epoch: %d, time(sec): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, time.time()-start, optimizer.param_groups[0]['lr'], loss_epoch) )

# generate a new sentence of any length
def generate(LMnet, prompt_seq, max_length_gen_seq):
    gen_seq = prompt_seq # an initial sequence (a.k.a. prompt) to generate a longer sequence with LMnet
    for k in range(max_length_gen_seq):
        context_seq = gen_seq[:,-batch_length:] # size=[1, <=batch_length]
        score = LMnet(context_seq) # size=[1, batch_length, num_tokens]
        score_last_token = score[:,-1,:].squeeze(dim=1) # size=[1, num_tokens]
        prob_last_token = torch.softmax(score_last_token, dim=1) # size=[1, num_tokens]
        #idx_next_token = torch.multinomial(prob_last_token, num_samples=1) # size=[1,1]
        idx_next_token = torch.max(prob_last_token, dim=1).indices[0].view(1,1) # size=[1,1]
        gen_seq = torch.cat((gen_seq, idx_next_token), dim=1) # append next token, (size=[1, num_tokens+1]
    return gen_seq

# generate from a prompt sequence
#  prompt_seq = [ 2, 4, 6, 8 ]
#     gen_seq =              [ 10, 12, 14, 16, 18, 20, 22 ]
prompt_seq = get_batch(seq, batch_size=1, batch_length=4, start_idx=10, list_batch_idx=torch.arange(num_subseq-1))[0] # generate a small sequence to complete, size=[1,batch_length]
prompt_tokens = func_int2token(prompt_seq[0].tolist())
print('\nsequence   :', prompt_tokens)
gen_seq = generate(GEN_LMnet, prompt_seq, max_length_gen_seq=batch_length)[0][prompt_seq[0].size(0):]
seq_tokens = func_int2token(gen_seq.tolist())
print('prediction :               ', seq_tokens,'\n')


seq_len: 100, batch_size: 5, batch_length: 20, num_subseq: 5, num_batch: 1

num_tokens: 128, d: 128, batch_length: 20, num_heads: 16, dropout: 0.10, num_layers: 2

num_net_parameters: 432000 / 0.43 million

Epoch: 0, time(sec): 0.018, lr= 0.000300, loss_epoch: 5.076
Epoch: 10, time(sec): 0.128, lr= 0.000300, loss_epoch: 4.135

sequence   : 70 76 <SEP> 95
prediction :                <SEP> <SEP> 90 <SEP> <SEP> <SEP> 99 <SEP> 99 <SEP> 90 <SEP> 90 <SEP> <SEP> 90 <SEP> 99 <SEP> 99 



## Final version of Step #1 : SSL-LLM
## Add GPU training, saving pre-trained net, warmup learning rate, stopping condition


In [9]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# GPU training
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    device = torch.device("cuda") # use GPU
else:
    device = torch.device("cpu")
print('device:',device,'\n')

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

# multiple attention heads layer
class multiple_head_attention(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        d_head = d // num_heads
        assert d == d_head * num_heads # check divisiblity
        self.MHA = nn.MultiheadAttention(d, num_heads, batch_first=True, dropout=dropout)
        self.mask = torch.tril(torch.ones(context_length, context_length))==0 # mask to make attention to previous tokens only : { token(<=t) }, size=(context_length,context_length)
                   # torch.tril(ones) = True in the up-right part, True means *no* attention allowed in pytorch implementation
        self.context_length = context_length
    def forward(self, H):
        if H.size(1) == self.context_length: # training <==
            attn_mask = self.mask
        else: # when batch_length not= context_length, e.g. inference time / sequence generation <==
            current_batch_length = H.size(1)
            attn_mask = torch.tril(torch.ones(current_batch_length, current_batch_length))==0
        H_heads = self.MHA(H, H, H, attn_mask=attn_mask.to(device))[0] # pytorch implementation, size=[batch_size, batch_length, d]
        return H_heads

# Transformer block layer
class TransformerBlock(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        self.MHA = multiple_head_attention(d, context_length, num_heads, dropout)
        self.LN_MHA = nn.LayerNorm(d)
        self.MLP = nn.Sequential(nn.Linear(d,4*d), nn.ReLU(), nn.Dropout(dropout), nn.Linear(4*d,d))
        self.LN_MLP = nn.LayerNorm(d)
    def forward(self, H):
        H = H + self.MHA(self.LN_MHA(H)) # size=[batch_size, batch_length, d]
        H = H + self.MLP(self.LN_MLP(H)) # size=[batch_size, batch_length, d]
        return H

# class of self-supervised learning LM network (step 1)
class SSL_LM(nn.Module):
    def __init__(self, num_tokens, d, context_length, num_heads, dropout, num_layers):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.PE_embedding = nn.Embedding(context_length, d) # positional encoding embedding layer
        self.transformer_blocks = nn.ModuleList([ TransformerBlock(d, context_length, num_heads, dropout) for _ in range(num_layers) ]) # multiple transformer block layers
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
    def forward(self, batch_seq):
        seq_pos_encoding = torch.arange(batch_seq.size(1)).to(device) # positional encoding = {0,1,2,...,batch_length-1}
        H = self.token2vec(batch_seq) + self.PE_embedding(seq_pos_encoding).unsqueeze(0) # size=[batch_size, batch_length, d]
        for transformer_block in self.transformer_blocks:
            H = transformer_block(H) # size=[batch_size, batch_length, d]
        batch_scores = self.token_prediction(H) # size=[batch_size, batch_length, num_tokens]
        return batch_scores # return prediction scores for next token

# generate a new sentence of any length
def generate(LMnet, prompt_seq, max_length_gen_seq):
    gen_seq = prompt_seq # an initial sequence (a.k.a. prompt) to generate a longer sequence with LMnet
    for k in range(max_length_gen_seq):
        context_seq = gen_seq[:,-batch_length:] # size=[1, <=batch_length]
        score = LMnet(context_seq) # size=[1, batch_length, num_tokens]
        score_last_token = score[:,-1,:].squeeze(dim=1) # size=[1, num_tokens]
        prob_last_token = torch.softmax(score_last_token, dim=1) # size=[1, num_tokens]
        #idx_next_token = torch.multinomial(prob_last_token, num_samples=1) # size=[1,1]
        idx_next_token = torch.max(prob_last_token, dim=1).indices[0].view(1,1) # size=[1,1]
        gen_seq = torch.cat((gen_seq, idx_next_token), dim=1) # append next token, (size=[1, num_tokens+1]
    return gen_seq

# network parameters
d = 128; num_heads = 16; dropout = 0.1; num_layers = 2; batch_length = 20 # bebug
d_head = 64; num_heads = 6; d = num_heads * d_head; dropout = 0.1; num_layers = 6; batch_length = 40 # GPU training <==
print('num_tokens: %d, d: %d, batch_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d\n' % (num_tokens, d, batch_length, num_heads, dropout, num_layers) )
SSL_LMnet = SSL_LM(num_tokens, d, batch_length, num_heads, dropout, num_layers)
SSL_LMnet = SSL_LMnet.to(device)
num_param = number_param(SSL_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# optimizer
optimizer = torch.optim.AdamW(SSL_LMnet.parameters(), lr=3e-4) # standard optimizer for LMs
warmup = 500 # 50(debug), 500(GPU), number of batches used for warmup <==
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda t: min(t/warmup, 1.0) ) # warmup learning rate scheduler, good for LM (softmax)

# save checkpoint
net_parameters = {}
net_parameters['num_tokens'] = num_tokens
net_parameters['d'] = d
net_parameters['num_heads'] = num_heads
net_parameters['batch_length'] = batch_length
net_parameters['dropout'] = dropout
net_parameters['num_layers'] = num_layers
checkpoint_dir = os.path.join("checkpoint")
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print('checkpoint file :', checkpoint_dir + '/step1_checkpoint_SSL_LM_' + time_stamp + '.pkl', '\n')

# batching parameters
seq_len = seq.size(0) # length of the long sequence
batch_size = 5 # bebug
batch_size = 500 # GPU training <==
num_subseq = seq_len // batch_length # number of subsequences
num_batch = seq_len // (batch_size * batch_length) # number of batches
start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # new starting index at each new epoch, random integer in {0,batch_length-1}
list_batch_idx = torch.arange(num_batch) # list of batch indices, [0,1,...,num_batch-1]
print('seq_len: %d, batch_size: %d, batch_length: %d, num_subseq: %d, num_batch: %d\n' % (seq_len, batch_size, batch_length, num_subseq, num_batch) )

# Train network to predict next token
num_epochs = 1 # 1001(debug), 11(GPU), number of epochs <==
print('num_epochs :',num_epochs,'\n')
start = time.time()
for epoch in range(num_epochs): # number of epochs
    list_batch_idx = torch.arange(num_subseq-1) # list of batch indices
    start_idx = torch.randint(low=0, high=batch_length, size=(1,)) # size=[1]
    running_loss = 0.0 # tracking total loss value
    for _ in range(num_batch): # number of batches into one epoch
        batch_seq, target_seq, list_batch_idx = get_batch(seq, batch_size, batch_length, start_idx, list_batch_idx) # generate a batch of subsequences
        batch_seq, target_seq = batch_seq.to(device), target_seq.to(device) # GPU training <==
        batch_scores = SSL_LMnet(batch_seq) # size=[batch_size, batch_length, num_tokens]
        loss = nn.CrossEntropyLoss()(batch_scores.view(batch_scores.size(0)*batch_length, num_tokens), target_seq.view(batch_scores.size(0)*batch_length)) # classification loss over dict of tokens
        running_loss += loss.detach().cpu().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step() # warmup scheduler
    loss_epoch = running_loss / num_batch
    if not epoch%1: # 100(debug), 1(GPU) <==
        print('Epoch: %d, time(min): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, (time.time()-start)/60, optimizer.param_groups[0]['lr'], loss_epoch) )
         # save checkpoint
        torch.save({
            'epoch': epoch,
            'tot_time': time.time()-start,
            'loss': loss_epoch,
            'net_parameters': net_parameters,
            'SSL_LMnet_dict': SSL_LMnet.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            }, '{}.pkl'.format(checkpoint_dir + "/step1_checkpoint_SSL_LM_" + time_stamp ))
        # check prediction performance
        prompt_seq = get_batch(seq, batch_size=1, batch_length=4, start_idx=10, list_batch_idx=torch.arange(num_subseq-1))[0].to(device) # generate a small sequence to complete, size=[1,batch_length]
        prompt_tokens = func_tokens2str(func_indices2tokens(prompt_seq[0].tolist()))
        print('sequence   :', prompt_tokens)
        gen_seq = generate(SSL_LMnet, prompt_seq, max_length_gen_seq=batch_length)[0][prompt_seq[0].size(0):]
        seq_tokens = func_tokens2str(func_indices2tokens(gen_seq.tolist()))
        print('prediction :             ', seq_tokens,'\n')
    # Stopping condition
    if loss_epoch < 0.01:
        print("\n loss value is small -- training stopped\n")
        break

# GPU training time : Epoch: 10, time(min): 13.235, lr= 0.000300, loss_epoch: 1.031


Tesla T4
device: cuda 

num_tokens: 129, d: 384, batch_length: 40, num_heads: 6, dropout: 0.10, num_layers: 6

num_net_parameters: 10761345 / 10.76 million

checkpoint file : checkpoint/step1_checkpoint_SSL_LM_23-12-04--10-32-53.pkl 

seq_len: 3000000, batch_size: 500, batch_length: 40, num_subseq: 75000, num_batch: 150

num_epochs : 1 

Epoch: 0, time(min): 2.450, lr= 0.000090, loss_epoch: 3.723
sequence   : 62 66 70 74
prediction :              78 <SEP> 10 11 13 15 19 22 <SEP> 47 52 55 56 61 <SEP> 20 22 25 31 37 43 50 53 56 61 68 76 <SEP> 25 26 27 29 30 32 33 38 45 50 52 53 



## load pre-trained SSL-LM network

In [10]:
checkpoint_file = checkpoint_dir + '/step1_checkpoint_SSL_LM_' + time_stamp + '.pkl'
checkpoint = torch.load(checkpoint_file, map_location=device)
net_parameters = checkpoint['net_parameters']
num_tokens = net_parameters['num_tokens']
d = net_parameters['d']
num_heads = net_parameters['num_heads']
batch_length = net_parameters['batch_length']
dropout = net_parameters['dropout']
num_layers = net_parameters['num_layers']
epoch = checkpoint['epoch']
tot_time = checkpoint['tot_time']
loss = checkpoint['loss']
print('Load pre-trained SSL-LM: \n checkpoint file: {:s}\n epoch: {:d}, time: {:.3f}min, loss={:.4f}'.format(checkpoint_file,epoch,tot_time,loss))
print(' num_tokens: %d, d: %d, batch_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d\n' % (num_tokens, d, batch_length, num_heads, dropout, num_layers) )
SSL_LMnet = SSL_LM(num_tokens, d, batch_length, num_heads, dropout, num_layers)
SSL_LMnet = SSL_LMnet.to(device)
SSL_LMnet.load_state_dict(checkpoint['SSL_LMnet_dict'])
num_param = number_param(SSL_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )
del checkpoint

# check pre-trained network : generate from a prompt sequence
#  prompt_seq = 2, 4, 6, 8
#     gen_seq =             10, 12, 14, 16, 18, 20, 22
num_subseq = seq.size(0) // batch_length # number of subsequence
prompt_seq = get_batch(seq, batch_size=1, batch_length=4, start_idx=10, list_batch_idx=torch.arange(num_subseq-1))[0].to(device) # generate a small sequence to complete, size=[1,batch_length]
prompt_tokens = func_tokens2str(func_indices2tokens(prompt_seq[0].tolist()))
print('sequence   :', prompt_tokens)
gen_seq = generate(SSL_LMnet, prompt_seq, max_length_gen_seq=batch_length)[0][prompt_seq[0].size(0):]
seq_tokens = func_tokens2str(func_indices2tokens(gen_seq.tolist()))
print('prediction :             ', seq_tokens,'\n')


Load pre-trained SSL-LM: 
 checkpoint file: checkpoint/step1_checkpoint_SSL_LM_23-12-04--10-32-53.pkl
 epoch: 0, time: 146.999min, loss=3.7226
 num_tokens: 129, d: 384, batch_length: 40, num_heads: 6, dropout: 0.10, num_layers: 6

num_net_parameters: 10761345 / 10.76 million

sequence   : 48 51 54 57
prediction :              60 <SEP> 53 60 61 68 69 76 84 92 96 <SEP> 51 56 60 61 62 68 75 78 87 96 <SEP> 50 52 59 66 75 80 86 95 <SEP> 54 60 66 67 72 76 84 92 

