# Step #2 : Supervised Fine-Tuning (SFT) of Language Model (LM)

## Task : from prompt, generate response

## Use the pre-trained SSL-LM model from Step #1 and fine-tune it with a training set of (prompt, response) pairs

### Xavier Bresson, xavier.bresson@gmail.com, https://twitter.com/xbresson

### Number of data points for GPT-3, 175B parameters
+ Step #1 : 300B tokens
+ Step #2 : 10k-100k pairs (prompt, response)
+ Step #3 : 100k-1M triples (prompt, positive response, negative response)
+ Step #4 : 10k-100k prompts

### Number of data points for this tutorial
+ Step #1 : 3M tokens
+ Step #2 : 10k pairs (prompt, response)
+ Step #3 : 10k triples (prompt, positive response, negative response)
+ Step #4 : 1k prompts

### Objectives
+ Supervised learning with auto-regressive prediction of sequences
+ Train with batches of (prompt, response) pairs for fast training on GPU
+ Load a pre-trained LM network


In [1]:
# For Google Colaboratory: mount Google Drive and move into the lab folder.
# Outside Colab this cell does nothing.
import sys, os
if 'google.colab' in sys.modules:
    # mount google drive
    from google.colab import drive
    drive.mount('/content/gdrive')
    path_to_file = '/content/gdrive/My Drive/ACE_NLP_Dec23_codes/codes/labs_vanillaLLMs'
    print(path_to_file)
    # move to Google Drive directory
    os.chdir(path_to_file)
    # plain-Python equivalent of the IPython shell escape `!pwd`, so the code
    # stays valid Python outside a notebook
    print(os.getcwd())

In [1]:
# Libraries
import torch
print(torch.__version__)
import torch.nn as nn
import torch.optim as optim
import time
import matplotlib.pyplot as plt
import logging
logging.getLogger().setLevel(logging.CRITICAL) # remove warnings
import os, datetime


2.1.0


## Time stamp for save/load data


In [5]:
# Time stamp used to tag the files saved/loaded by this notebook.
# A fresh stamp is generated first; when use_saved_time_stamp is True it is
# replaced by the stamp of a previous run so that run's files are reused.
time_stamp = datetime.datetime.now().strftime("%y-%m-%d--%H-%M-%S")

# make sure the dataset folder exists (exist_ok avoids a check-then-create race)
data_dir = 'dataset'
os.makedirs(data_dir, exist_ok=True)

# select a time stamp (toggle: the second assignment wins)
use_saved_time_stamp = False
use_saved_time_stamp = True
if use_saved_time_stamp:
    time_stamp = '23-11-24--12-53-11' # trained on GPU on xxx

print('time_stamp:', time_stamp, '\n')


time_stamp: 23-11-24--12-53-11 



## Load dictionary of tokens from step #1


In [6]:
# Reload the token dictionary saved in step #1 so SFT uses the same token <-> index mapping
load_file_dictionary = 'dataset/step1_02_SSL_dictionary_23-11-23--12-26-17.pt'
dictionary, num_tokens, token2index, index2token = torch.load(load_file_dictionary) # saved as [dictionary, num_tokens, token2index, index2token]
for label, value in (('dictionary:', dictionary),
                     ('num_tokens (unique):', num_tokens),
                     ('token2index:', token2index),
                     ('index2token:', index2token)):
    print(label, value, '\n')
def func_tokens2indices(list_tokens):
    """Map tokens to dictionary indices, e.g. ['Let', '5', 'be', 'the'] => [113, 82, 114, 115]."""
    return [token2index[token] for token in list_tokens]

def func_indices2tokens(list_ints):
    """Map dictionary indices back to tokens, e.g. [113, 82, 114, 115] => ['Let', '5', 'be', 'the']."""
    return [index2token[integer] for integer in list_ints]

def func_str2tokens(input_str):
    """Split a string on whitespace into tokens, e.g. 'Let 5 be the' => ['Let', '5', 'be', 'the']."""
    return input_str.split()

def func_tokens2str(list_str):
    """Join tokens with single spaces, e.g. ['Let', '5', 'be', 'the'] => 'Let 5 be the'."""
    return ' '.join(list_str)


dictionary: ['50', '56', '62', '68', '74', '80', '86', '<SEP>', '20', '24', '28', '32', '36', '14', '18', '22', '26', '30', '34', '38', '42', '46', '54', '58', '53', '61', '69', '77', '85', '93', '37', '45', '51', '66', '71', '76', '81', '91', '96', '0', '9', '27', '19', '23', '31', '35', '39', '43', '47', '2', '10', '33', '41', '49', '57', '65', '73', '89', '97', '88', '29', '44', '59', '64', '83', '92', '95', '98', '40', '15', '21', '25', '55', '60', '78', '87', '79', '84', '94', '48', '52', '90', '5', '8', '11', '17', '100', '99', '75', '63', '67', '70', '72', '3', '12', '82', '1', '7', '13', '16', '6', '4', 'generate', 'an', 'arithmetic', 'series', 'with', 'terms', 'starting', 'value', 'and', 'common', 'difference', 'Let', 'be', 'the', 'number', 'of', 'then', 'write', 'make', 'a', 'type', 'which', 'starts', 'at', 'elements', '<PAD>', '<EOS>'] 

num_tokens (unique): 129 

token2index: {'50': 0, '56': 1, '62': 2, '68': 3, '74': 4, '80': 5, '86': 6, '<SEP>': 7, '20': 8, '24': 9, '28':

## Generate training tuples/pairs of (prompt, response)
### In real NLP applications, training pairs are prepared by humans; here they are generated synthetically


In [9]:
# generate arithmetic series
m = max_value = 100 # maximum value in the sequence
def arithmetic_series(m, s, d, n):
    """Return up to n terms s, s+d, s+2d, ..., stopping before the first term larger than m.

    Args:
        m: maximum allowed value (inclusive).
        s: first term.
        d: common difference.
        n: requested number of terms.

    Returns:
        List of terms; shorter than n when the cap m is reached first.
    """
    series = []
    for k in range(n):
        term = s + k * d
        if term > m:
            break # cap reached: the series is truncated here
        series.append(term)
    return series

# generate training data, i.e. pairs of prompt and response
#   prompt = [ generate arithmetic series of 5 terms with difference 2 starting at 3 ]
# response = [ 3, 5, 7, 9, 11 ]

def _print_training_preview(prompts, responses, num_preview=3):
    """Print the number of (prompt, response) pairs and the first `num_preview` of them.

    Each pair is printed twice: as index tensors and decoded back to text.
    Shared by the generate/save branch and the load branch below.
    """
    print('number of training data (prompt, response) :',len(prompts),'\n')
    for idx, (prompt, response) in enumerate(zip(prompts[:num_preview],responses[:num_preview])):
        print('training_set[%d] (pytorch) : %s : %s ' % (idx, prompt, response) )
        prompt = func_tokens2str(func_indices2tokens(prompt.tolist()))
        response = func_tokens2str(func_indices2tokens(response.tolist()))
        print('training_set[%d] (token) : %s : %s ' % (idx, prompt, response) , '\n' )

save_training_data = False
#save_training_data = True
if save_training_data:

    # "collect" high-quality "human" training set (here: synthetic arithmetic series)
    list_prompt = []
    list_response = []
    num_training_data = 12 # debug
    num_training_data = 10000 # number of pairs of (prompt, response), e.g. GPU 10,000 pairs (prompt, response)
    m = max_value # maximum value in the sequence (loop invariant)
    start = time.time()
    for idx in range(num_training_data):

        # parameters for arithmetic series (torch RNG calls kept in the same order)
        s = torch.randint(low=0, high=m, size=(1,)).item() # starting integer of the series
        d = torch.randint(low=1, high=10, size=(1,)).item() # value of common difference
        n = torch.randint(low=5, high=15, size=(1,)).item() # number of element in the series

        # generate prompt : sample a prompt between 3 candidate prompts
        prompt_candidates = {
            1: 'generate an arithmetic series with ' + str(n) + ' terms starting with value ' + str(s) + ' and common difference ' + str(d),
            2: 'make a series of arithmetic type which starts at ' + str(s) + ' with ' + str(n) + ' elements and ' + str(d) + ' common difference value',
            3: 'Let ' + str(n) + ' be the number of terms ' + str(s) + ' the starting number and ' + str(d) + ' the common difference then write the arithmetic series',
        }
        random_int = torch.randint(low=1, high=3+1, size=(1,)).item() # random number in {1,2,3}
        #random_int = 1 # debug
        prompt = prompt_candidates[random_int]
        response = arithmetic_series(m,s,d,n)

        # convert from tokens to integer indices, then to pytorch
        prompt = torch.tensor(func_tokens2indices(func_str2tokens(prompt))) # str -> tokens -> indices
        response = torch.tensor(func_tokens2indices([str(v) for v in response])) # ints -> tokens (str) -> indices

        # append
        list_prompt.append(prompt)
        list_response.append(response)

        # track
        if not idx%1000:
            print('idx: %d, time(sec): %.3f' % (idx, time.time()-start) )

    # print
    _print_training_preview(list_prompt, list_response)

    # save training data
    save_file = data_dir + '/step2_01_SFT_training_set_' + time_stamp + '.pt'
    print('save_file:', save_file, '\n')
    torch.save([list_prompt, list_response],save_file) # save list of prompts and responses

else:

    # load training data
    load_file = data_dir + '/step2_01_SFT_training_set_' + time_stamp + '.pt'
    print('load_file:', load_file, '\n')
    list_prompt, list_response = torch.load(load_file)

    # print
    _print_training_preview(list_prompt, list_response)


load_file: dataset/step2_01_SFT_training_set_23-11-24--12-53-11.pt 

number of training data (prompt, response) : 10000 

training_set[0] (pytorch) : tensor([102, 103, 104, 105, 106, 100, 107, 108, 106, 109,  66, 110, 111, 112,
         93]) : tensor([66, 67]) 
training_set[0] (token) : generate an arithmetic series with 6 terms starting with value 95 and common difference 3 : 95 98  

training_set[1] (pytorch) : tensor([113,  82, 114, 115, 116, 117, 107,  84, 115, 108, 116, 110, 100, 115,
        111, 112, 118, 119, 115, 104, 105]) : tensor([84, 85, 43, 60, 45]) 
training_set[1] (token) : Let 5 be the number of terms 11 the starting number and 6 the common difference then write the arithmetic series : 11 17 23 29 35  

training_set[2] (pytorch) : tensor([113,  83, 114, 115, 116, 117, 107,  97, 115, 108, 116, 110,  82, 115,
        111, 112, 118, 119, 115, 104, 105]) : tensor([97, 94, 85, 15, 41, 11, 30, 20]) 
training_set[2] (token) : Let 8 be the number of terms 7 the starting number

## Get batch of sampled indices of (prompt,response)


In [10]:
# hyper-parameters for batching the (prompt, response) pairs
num_prompt_response = len(list_prompt) # total number of (prompt, response) pairs
batch_size = 3   # debug value, overridden on the next line
batch_size = 100 # batch size, 500 GPU
num_batch = divmod(num_prompt_response, batch_size)[0] # number of full batches per epoch (remainder unused)
print('num_prompt_response: %d, batch_size: %d, num_batch: %d\n' % (num_prompt_response,batch_size,num_batch))

# sample batch of prompt+response
def get_batch(batch_size, list_prompt_response_idx):
    """Sample a batch of indices without replacement and return it with the remaining indices.

    Args:
        batch_size: number of indices to sample (B).
        list_prompt_response_idx: 1-D tensor of (prompt, response) indices not yet used in the
            current epoch. Entries are assumed unique (the caller builds it with torch.arange).

    Returns:
        (batch_idx, new_list_prompt_response_idx): the sampled indices, and the input with the
        sampled positions removed (original order kept, same dtype/device). When at most
        `batch_size` indices remain, all of them are returned and the new list is empty.
    """
    num_remaining = list_prompt_response_idx.size(0)
    sampled_pos = torch.randperm(num_remaining)[:batch_size] # sample B positions in the remaining list
    batch_idx = list_prompt_response_idx[sampled_pos] # and extract them
    # Remove the sampled positions with one boolean mask instead of a Python loop doing an
    # `in` test per element (O(N*B) per call, and one device op per element on GPU).
    keep = torch.ones(num_remaining, dtype=torch.bool, device=list_prompt_response_idx.device)
    keep[sampled_pos.to(keep.device)] = False
    new_list_prompt_response_idx = list_prompt_response_idx[keep] # empty on the last batch of an epoch
    return batch_idx, new_list_prompt_response_idx

# # one epoch, debug
# list_prompt_response_idx = torch.arange(num_prompt_response) # list of prompt+response indices
# for i in range(num_batch):
#     print('batch :',i)
#     print('list_prompt_response_idx (before) :',list_prompt_response_idx, list_prompt_response_idx.size())
#     batch_idx, list_prompt_response_idx = get_batch(batch_size, list_prompt_response_idx) # sample a batch of indices (prompt,response)
#     print('batch_idx :', batch_idx, batch_idx.size())
#     print('list_prompt_response_idx (after) :',list_prompt_response_idx, list_prompt_response_idx.size(),'\n')


num_prompt_response: 10000, batch_size: 100, num_batch: 100



## Train LM model by supervised fine-tuning of LM model from step #1

## Dataset is composed of (prompt, response)


In [11]:
torch.manual_seed(0) # fixed initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    """Return the total count of scalar entries over all parameters of `net`."""
    return sum(param.numel() for param in net.parameters())

# GPU training: use CUDA when available (and report the first GPU's name), otherwise the CPU
use_cuda = torch.cuda.is_available()
if use_cuda:
    print(torch.cuda.get_device_name(0))
device = torch.device("cuda" if use_cuda else "cpu")
print('device:',device,'\n')

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    """Embedding lookup mapping integer token ids to learned d-dimensional vectors."""

    def __init__(self, num_tokens, d):
        super().__init__()
        # one learnable d-dim row per token; the attribute name determines the state_dict key
        self.token2vec = nn.Embedding(num_tokens, d)

    def forward(self, batch_int):
        # size=[batch_size, batch_length] (ints) -> size=[batch_size, batch_length, d]
        return self.token2vec(batch_int)

# multiple attention heads layer
class multiple_head_attention(nn.Module):
    """Causal multi-head self-attention built on nn.MultiheadAttention (batch_first=True).

    Each position attends only to itself and earlier positions.
    """

    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        d_head = d // num_heads
        assert d == d_head * num_heads # check divisibility
        self.MHA = nn.MultiheadAttention(d, num_heads, batch_first=True, dropout=dropout)
        # Causal mask, size=(context_length,context_length): `tril(ones)==0` is True strictly above
        # the diagonal, and True means *no* attention allowed in the pytorch implementation.
        # Registered as a non-persistent buffer: it moves with .to(device) and is not stored in
        # state_dict, so checkpoints saved without it still load.
        self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length))==0, persistent=False)
        self.context_length = context_length

    def forward(self, H):
        """Self-attention on H, size=[batch_size, batch_length, d] -> size=[batch_size, batch_length, d]."""
        current_batch_length = H.size(1)
        if current_batch_length <= self.context_length:
            # the top-left block of a causal mask is the causal mask of the shorter length
            attn_mask = self.mask[:current_batch_length, :current_batch_length]
        else: # longer than the cached mask: build one on the fly
            attn_mask = torch.tril(torch.ones(current_batch_length, current_batch_length, device=H.device))==0
        # the mask goes on H's own device rather than a global `device`
        H_heads = self.MHA(H, H, H, attn_mask=attn_mask.to(H.device))[0]
        return H_heads

# Transformer block layer
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: causal self-attention, then a 4x-wide MLP, each with a residual connection."""

    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        # Sub-modules are created in the same order and under the same attribute names:
        # this fixes the parameter initialization order and the state_dict keys.
        self.MHA = multiple_head_attention(d, context_length, num_heads, dropout)
        self.LN_MHA = nn.LayerNorm(d)
        hidden = 4 * d
        self.MLP = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d),
        )
        self.LN_MLP = nn.LayerNorm(d)

    def forward(self, H):
        """H: size=[batch_size, batch_length, d] -> same size."""
        attn_update = self.MHA(self.LN_MHA(H))
        H = H + attn_update
        mlp_update = self.MLP(self.LN_MLP(H))
        return H + mlp_update

# Predict next token(t+1) given context = {token(t), token(t-1), ..., token(t-context_length)}
#
#  Prediction is auto-regressive, i.e. one token at a time (different from step #1, one-shot prediction)
#
#  Example of auto-regressive prediction for one SINGLE prompt
#
#             context
#             ------- <= context length (= 4 tokens) to predict next token
#                   |<= starts auto-regressive prediction at end of prompt = 5
#  seq    = 1 2 3 4 5 6 7 8 9 1
#           --------- ----------
#            prompt |  response
#                   |
#                   |<= predicts next_token = 6
#  target =         6 7 8 9 1 eos <= labeled tokens used for loss / training
#                   -------------
#                  len(response)+1 = num_tokens to predict
#                   |
# predicted_seq =   9 4 1 9 2 7   <= predicted tokens (one token at a time) used for loss / training
#
#
#  Example of auto-regressive prediction for a BATCH of prompts (GPU)
#
# Prepare a batch of sampled (prompt,response)
#  Let the tokens P = <PAD> (padding) and E = <EOS> (end-of-sequence)
#
#                      max(len_prompt, context_length) columns (all prompts have the same length with padding if needed) => GPU
#                 ---------------------
# batch_prompt = [ P, P, P, 1, 2, 3, 4 ]  // prompt = [1, 2, 3, 4]
#              = [ 3, 4, 5, 6, 7, 8, 9 ]  // prompt = [3, 4, 5, 6, 7, 8, 9]
#                         ...
#              = [ P, P, P, P, P, 5, 6 ]  // prompt = [5, 6]
#                        -------------
#                           context <= context length (= 5 tokens) to predict next token
#                                    | <= starts auto-regressive prediction at end of batch_seq
#                                    |
# batch_predicted_seq            = [ 9, 0, 5, 3, 5, 1 ]
#                                = [ 2, 5, 1, 7, 9, 3 ]
#                                          ...
#                                = [ 6, 2, 1, 8, 3, 9 ]
#                                   -----------------
#                                       max_len_response+1 (all generated responses have the same length ) => GPU
#                                    |
# batch_target                   = [ 5, 6, 7, 8, E, E ]
#                                = [ 0, 1, 2, E, E, E ]
#                                          ...
#                                = [ 7, 8, 9, 0, 1, E ]
#
#  NOTE: the padded batch_target above is only a picture. In the code, each sample contributes
#  its own response followed by a single <EOS>; these are concatenated, and the scores are cut
#  to the same per-sample lengths, so extra generated positions never enter the loss.
#
# class of supervised fine-tuning LM network (step 2)
class SFT_LM(nn.Module):
    """Supervised fine-tuning LM (step #2).

    Token + positional embeddings, a stack of causal transformer blocks, and a linear layer
    scoring the next token.

    Args:
        num_tokens: vocabulary size.
        d: embedding dimension.
        context_length: number of most recent tokens the network sees.
        num_heads, dropout, num_layers: transformer hyper-parameters.
        padding_int: index of <PAD>; a size-[1] tensor as built in this notebook, or an int.
        eos_int: index of <EOS>; same form as padding_int.
    """
    def __init__(self, num_tokens, d, context_length, num_heads, dropout, num_layers, padding_int, eos_int):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        # positional indices {0,1,...,context_length-1}; a non-persistent buffer follows the module
        # on .to(device) and is not stored in state_dict, so existing checkpoints still load
        self.register_buffer('seq_pos_encoding', torch.arange(context_length), persistent=False)
        self.PE_embedding = nn.Embedding(context_length, d) # positional encoding embedding layer
        self.transformer_blocks = nn.ModuleList([ TransformerBlock(d, context_length, num_heads, dropout) for _ in range(num_layers) ]) # multiple transformer block layers
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
        self.context_length = context_length
        self.padding = padding_int
        self.eos = eos_int

    def forward(self, batch_idx, list_prompt, list_response):
        """Run len_response+1 auto-regressive steps on a batch and collect the scores for the loss.

        Args:
            batch_idx: 1-D tensor of indices into list_prompt / list_response, size=[batch_size].
            list_prompt, list_response: lists of 1-D token-index tensors (the whole training set).

        Returns:
            score_seq: size=[sum_i (len_response_i+1), num_tokens], next-token scores kept for each
                sample's response tokens plus its <EOS>.
            target_seq: size=[sum_i (len_response_i+1)], each response followed by <EOS>, concatenated.
            predicted_seq: size=[batch_size, len_response+1], the sampled tokens (prompt removed),
                where len_response is the longest response in the batch.

        At every step the token appended to the context is the model's own sample
        (torch.multinomial), not the ground-truth response token.
        """
        dev = self.seq_pos_encoding.device # device the module lives on (no global `device`)
        pad = torch.as_tensor(self.padding, device=dev)
        eos = torch.as_tensor(self.eos, device=dev).reshape(1)
        prompts = [list_prompt[idx] for idx in batch_idx] # len(prompts)=batch_size
        len_prompt = max(len(prompt) for prompt in prompts) # max prompt length in the batch
        responses = [list_response[idx] for idx in batch_idx] # len(responses)=batch_size
        len_response = max(len(response) for response in responses) # max response length in the batch
        batch_size = batch_idx.size(0)
        width = max(len_prompt, self.context_length) # columns reserved for the padded prompt
        predicted_seq = torch.ones(batch_size, width, dtype=torch.long, device=dev) * pad # initialize with padding
        for idx in range(batch_size): predicted_seq[idx, -prompts[idx].size(0):] = prompts[idx] # fill with prompt, right-aligned
        predicted_seq_scores = []
        for _ in range(len_response+1): # number of auto-regressive predictions
            context = predicted_seq[:,-self.context_length:] # size=[batch_size, context_length]
            H = self.token2vec(context) + self.PE_embedding(self.seq_pos_encoding[:context.size(1)]).unsqueeze(0) # size=[batch_size, context_length, d]
            for transformer_block in self.transformer_blocks: H = transformer_block(H) # size=[batch_size, context_length, d]
            token_scores = H[:,-1,:] # last position predicts the next token, size=[batch_size, d]
            token_scores = self.token_prediction(token_scores) # compute scores, size=[batch_size, num_tokens]
            token_probs = torch.softmax(token_scores, dim=1) # compute probs, size=[batch_size, num_tokens]
            next_token = torch.multinomial(token_probs, num_samples=1) # sample next token, size=[batch_size, 1]
            predicted_seq = torch.cat((predicted_seq, next_token), dim=1) # size=[batch_size, current_seq_len+1]
            predicted_seq_scores.append(token_scores.unsqueeze(1)) # append size=[batch_size, 1, num_tokens]
        predicted_seq_scores = torch.cat(predicted_seq_scores, dim=1) # size=[batch_size, len_response+1, num_tokens]
        # keep, per sample, only the scores for its own response tokens + <EOS>
        score_seq = []
        for idx in range(batch_size): score_seq.append(predicted_seq_scores[idx,:responses[idx].size(0)+1,:]) # size=[len_response_idx+1, num_tokens]
        score_seq = torch.cat(score_seq, dim=0) # size=[num_tokens_for_loss, num_tokens]
        # matching targets: each response followed by one <EOS>, concatenated (no padding)
        target_seq = []
        for idx in range(batch_size):
            target_seq.append(responses[idx].to(dev))
            target_seq.append(eos)
        target_seq = torch.cat(target_seq, dim=0) # size=[num_tokens_for_loss]
        predicted_seq = predicted_seq[:,width:] # drop the prompt columns, size=[batch_size, len_response+1]
        return score_seq, target_seq, predicted_seq # return prediction sequences and scores w.r.t. the dictionary of tokens

# from pre-training : load pre-trained SSL-LM network from step #1
checkpoint_file = "checkpoint/step1_checkpoint_SSL_LM_23-11-23--12-26-17.pkl"
checkpoint = torch.load(checkpoint_file, map_location=device) # tensors mapped onto the current device
epoch, tot_time, loss = (checkpoint[key] for key in ('epoch', 'tot_time', 'loss'))
print('Load pre-trained SSL-LM: \n checkpoint file: {:s}\n epoch: {:d}, time: {:.3f}min, loss={:.4f}'.format(checkpoint_file,epoch,tot_time,loss))

# architecture hyper-parameters of the step #1 network
net_parameters = checkpoint['net_parameters']
num_tokens, d, num_heads = (net_parameters[key] for key in ('num_tokens', 'd', 'num_heads'))
batch_length = net_parameters['batch_length']
# fixed value used instead of batch_length; PE_embedding has context_length rows, so this
# presumably has to match the step #1 weights for load_state_dict to succeed
context_length = 40
dropout, num_layers = net_parameters['dropout'], net_parameters['num_layers']
print(' num_tokens: %d, d: %d, context_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d\n' % (num_tokens, d, context_length, num_heads, dropout, num_layers) )

# special tokens, as size-[1] tensors on the training device
padding_int = torch.tensor([token2index['<PAD>']]).to(device) # padding token
eos_int = torch.tensor([token2index['<EOS>']]).to(device) # end-of-sequence token
print('num_tokens: %d, padding_int: %d, eos_int: %d\n' % (num_tokens, padding_int, eos_int))

# build the SFT network and initialize it with the step #1 weights
SFT_LMnet = SFT_LM(num_tokens, d, context_length, num_heads, dropout, num_layers, padding_int, eos_int).to(device) # GPU training <==
SFT_LMnet.load_state_dict(checkpoint['SSL_LMnet_dict'])
num_param = number_param(SFT_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )
del checkpoint # release the loaded checkpoint

# # from scratch : train SFT-LM network from random init (i.e. no pre-trained network)
# d_head = 6; num_heads = 16; d = num_heads * d_head; num_layers = 2; dropout = 0.1; context_length = 20 # debug
# #d_head = 64; num_heads = 6; d = num_heads * d_head; num_layers = 6; dropout = 0.1; context_length = 40 # GPU training <==
# print('num_tokens: %d, d: %d, context_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d\n' % (num_tokens, d, context_length, num_heads, dropout, num_layers) )
# padding_int = torch.tensor([func_tokens2indices('<PAD>'.split())[0]]).to(device) # end-of-sentence token for batch
# eos_int = torch.tensor([func_tokens2indices('<EOS>'.split())[0]]).to(device) # end-of-sentence token for batch
# print('num_tokens: %d, padding_int: %d, eos_int: %d\n' % (num_tokens, padding_int, eos_int))
# SFT_LMnet = SFT_LM(num_tokens, d, context_length, num_heads, dropout, num_layers, padding_int, eos_int)
# SFT_LMnet = SFT_LMnet.to(device) # GPU training <==
# num_param = number_param(SFT_LMnet)
# print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# optimizer
optimizer = torch.optim.AdamW(SFT_LMnet.parameters(), lr=3e-4) # AdamW, the standard optimizer for LMs
warmup = 500 # 50(debug), 500(GPU), number of batches used for warmup <==

def warmup_lr_factor(t):
    """Linear warmup: scale the base lr by t/warmup for the first `warmup` steps, then keep it at 1."""
    return min(t/warmup, 1.0)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lr_factor) # warmup learning rate scheduler, good for LM (softmax)

# save checkpoint: hyper-parameters stored alongside the weights so the network can be rebuilt
net_parameters = {
    'num_tokens': num_tokens,
    'd': d,
    'num_heads': num_heads,
    'context_length': context_length,
    'dropout': dropout,
    'num_layers': num_layers,
    'padding_int': padding_int,
    'eos_int': eos_int,
}
checkpoint_dir = "checkpoint"
os.makedirs(checkpoint_dir, exist_ok=True) # no-op if the folder already exists (no check-then-create race)
print('checkpoint file :', checkpoint_dir + '/step2_checkpoint_SFT_LM_' + time_stamp + '.pkl', '\n')

# batching parameters
num_prompt_response = len(list_prompt) # total number of (prompt, response) pairs
batch_size = 3   # debug value, overridden on the next line
batch_size = 100 # batch size, 500 GPU <==
num_batch = divmod(num_prompt_response, batch_size)[0] # number of full batches per epoch (remainder unused)
print('num_prompt_response: %d, batch_size: %d, num_batch: %d\n' % (num_prompt_response, batch_size, num_batch))

# Train network to predict response from prompt
num_epochs = 11 # 1001(debug), 11(GPU), number of epochs <==
print('num_epochs: ',num_epochs,'\n')
criterion = nn.CrossEntropyLoss() # classification loss over dict of tokens; holds no learned state, so built once
start = time.time()
for epoch in range(num_epochs): # number of epochs
    # set training mode explicitly so dropout is active even if the network was switched to
    # eval mode earlier (generate() calls LMnet.eval())
    SFT_LMnet.train()
    list_prompt_response_idx = torch.arange(num_prompt_response).to(device) # initialize the list of prompt+response indices
    running_loss = 0.0 # tracking total loss value
    for k in range(num_batch): # number of batches in one epoch
        batch_idx, list_prompt_response_idx = get_batch(batch_size, list_prompt_response_idx) # sample a batch of indices (prompt,response)
        # sizes: [num_tokens_for_loss, num_tokens], [num_tokens_for_loss], [batch_size, len_response+1]
        score_seq, target_seq, predicted_seq = SFT_LMnet(batch_idx.to(device), list_prompt, list_response)
        loss = criterion(score_seq, target_seq)
        running_loss += loss.detach().cpu().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step() # one warmup step per batch
    loss_epoch = running_loss / num_batch
    if not epoch%1: # 10(debug), 1(GPU) <== report and save every this many epochs
        print('Epoch: %d, time(min): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, (time.time()-start)/60, optimizer.param_groups[0]['lr'], loss_epoch) )
        # print one prompt of the last batch
        idx_prompt = 0
        print('prompt        :',func_tokens2str(func_indices2tokens(list_prompt[batch_idx[idx_prompt]].tolist())))
        print('response      :',func_tokens2str(func_indices2tokens(list_response[batch_idx[idx_prompt]].tolist())))
        print('predicted_seq :',func_tokens2str(func_indices2tokens(predicted_seq[idx_prompt,:][:list_response[batch_idx[idx_prompt]].size(0)].tolist())),'\n' )
        # save checkpoint
        torch.save({
            'epoch': epoch,
            'tot_time': time.time()-start,
            'loss': loss_epoch,
            'net_parameters': net_parameters,
            'SFT_LMnet_dict': SFT_LMnet.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            }, '{}.pkl'.format(checkpoint_dir + "/step2_checkpoint_SFT_LM_" + time_stamp ))
        # Stopping condition
        if loss_epoch < 0.1:
            print("\n loss value is small -- training stopped\n")
            break

# GPU training time : Epoch: 8, time(min): 9.769, lr= 0.000300, loss_epoch: 0.050


NVIDIA RTX A5000
device: cuda 

Load pre-trained SSL-LM: 
 checkpoint file: checkpoint/step1_checkpoint_SSL_LM_23-11-23--12-26-17.pkl
 epoch: 10, time: 794.094min, loss=1.0310
 num_tokens: 129, d: 384, context_length: 40, num_heads: 6, dropout: 0.10, num_layers: 6

num_tokens: 129, padding_int: 127, eos_int: 128

num_net_parameters: 10761345 / 10.76 million

checkpoint file : checkpoint/step2_checkpoint_SFT_LM_23-11-24--12-53-11.pkl 

num_prompt_response: 10000, batch_size: 100, num_batch: 100

num_epochs:  11 

Epoch: 0, time(min): 1.087, lr= 0.000060, loss_epoch: 6.854
prompt        : generate an arithmetic series with 9 terms starting with value 88 and common difference 1
response      : 88 89 90 91 92 93 94 95 96
predicted_seq : 8 72 78 85 0 3 <EOS> 93 <EOS> 

Epoch: 1, time(min): 2.164, lr= 0.000120, loss_epoch: 4.538
prompt        : generate an arithmetic series with 8 terms starting with value 70 and common difference 1
response      : 70 71 72 73 74 75 76 77
predicted_seq : 99 

## Load pre-trained SFT-LM network


In [13]:
# load pre-trained SFT-LM network
checkpoint_file = checkpoint_dir + '/step2_checkpoint_SFT_LM_' + time_stamp + '.pkl'
checkpoint = torch.load(checkpoint_file, map_location=device)
epoch = checkpoint['epoch']
tot_time = checkpoint['tot_time']
loss = checkpoint['loss']
print('Load pre-trained SFT-LM: \n checkpoint file: {:s}\n epoch: {:d}, time: {:.3f}min, loss={:.4f}'.format(checkpoint_file,epoch,tot_time,loss))
net_parameters = checkpoint['net_parameters']
num_tokens = net_parameters['num_tokens']
d = net_parameters['d']
num_heads = net_parameters['num_heads']
context_length = net_parameters['context_length']
dropout = net_parameters['dropout']
num_layers = net_parameters['num_layers']
print(' num_tokens: %d, d: %d, context_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d\n' % (num_tokens, d, context_length, num_heads, dropout, num_layers) )
# Special tokens are rebuilt from the loaded dictionary (the checkpoint also stores
# 'padding_int'/'eos_int', which are not used here).
padding_int = torch.tensor([token2index['<PAD>']]).to(device) # padding token
eos_int = torch.tensor([token2index['<EOS>']]).to(device) # end-of-sequence token
print('num_tokens: %d, padding_int: %d, eos_int: %d\n' % (num_tokens, padding_int, eos_int))
SFT_LMnet = SFT_LM(num_tokens, d, context_length, num_heads, dropout, num_layers, padding_int, eos_int)
SFT_LMnet = SFT_LMnet.to(device)
SFT_LMnet.load_state_dict(checkpoint['SFT_LMnet_dict']) # load pre-trained SFT-LM network from step #2
num_param = number_param(SFT_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )
del checkpoint # release the loaded checkpoint

# generate new sentence of any length
def generate(LMnet, prompt, max_length_gen_seq):
    """Auto-regressively sample a continuation of `prompt`.

    Args:
        LMnet: network exposing token2vec, PE_embedding, seq_pos_encoding, transformer_blocks,
            token_prediction, context_length, padding and eos attributes (as SFT_LM does).
        prompt: 1-D tensor of token indices.
        max_length_gen_seq: maximum number of tokens to generate.

    Returns:
        1-D tensor of generated token indices (prompt excluded). Generation stops early when
        <EOS> is sampled, and the <EOS> itself is not included.

    Side effect: puts LMnet in eval mode.
    """
    LMnet.eval()
    dev = next(LMnet.parameters()).device # run on the network's own device, not a global one
    width = max(prompt.size(0), LMnet.context_length) # columns reserved for the padded prompt
    with torch.no_grad(): # inference only: do not build an autograd graph
        predicted_seq = torch.ones(1, width, dtype=torch.long, device=dev) * LMnet.padding # initialize with padding
        if prompt.size(0) > 0: # `-0:` would select the whole row, so an empty prompt is skipped
            predicted_seq[:, -prompt.size(0):] = prompt # fill with prompt, right-aligned
        for k in range(max_length_gen_seq):
            context = predicted_seq[:,-LMnet.context_length:] # size=[1, context_length]
            H = LMnet.token2vec(context) + LMnet.PE_embedding(LMnet.seq_pos_encoding[:context.size(1)]).unsqueeze(0) # size=[1, context_length, d]
            for transformer_block in LMnet.transformer_blocks: H = transformer_block(H) # size=[1, context_length, d]
            token_scores = H[:,-1,:] # last position predicts the next token, size=[1, d]
            token_scores = LMnet.token_prediction(token_scores) # compute scores, size=[1, num_tokens]
            token_probs = torch.softmax(token_scores, dim=1) # compute probs, size=[1, num_tokens]
            next_token = torch.multinomial(token_probs, num_samples=1) # sample next token, size=[1, 1]
            #next_token = torch.max(token_probs, dim=1).indices[0].view(1,1) # greedy alternative, size=(1,1)
            if next_token==LMnet.eos:
                break
            predicted_seq = torch.cat((predicted_seq, next_token), dim=1) # size=[1, current_seq_len+1]
    gen_seq = predicted_seq[0][width:]
    return gen_seq

# print one prompt: pick a random training pair and compare its response with a generated one
def _indices_to_text(indices):
    """Decode a 1-D tensor of token indices into a space-separated string."""
    return func_tokens2str(func_indices2tokens(indices.tolist()))

idx_prompt = torch.randint(low=0, high=num_prompt_response, size=(1,)).item() # random number in {0,...,num_prompt_response-1}
print('idx_prompt :',idx_prompt)
prompt = list_prompt[idx_prompt]
response = list_response[idx_prompt]
print('prompt     :',_indices_to_text(prompt))
print('response   :',_indices_to_text(response))
gen_seq = generate(SFT_LMnet, prompt, max_length_gen_seq=15)
print('gen_seq    :',_indices_to_text(gen_seq))


Load pre-trained SFT-LM: 
 checkpoint file: checkpoint/step2_checkpoint_SFT_LM_23-11-24--12-53-11.pkl
 epoch: 8, time: 586.115min, loss=0.0496
 num_tokens: 129, d: 384, context_length: 40, num_heads: 6, dropout: 0.10, num_layers: 6

num_tokens: 129, padding_int: 127, eos_int: 128

num_net_parameters: 10761345 / 10.76 million

idx_prompt : 2371
prompt     : Let 7 be the number of terms 55 the starting number and 6 the common difference then write the arithmetic series
response   : 55 61 67 73 79 85 91
gen_seq    : 55 61 67 73 79 85 91
