# Step #4 : Reinforcement Learning (RL) of Language Model (LM)

## Task : from prompt, generate response

## Use pre-trained models:
### SFT-LM model from Step #2 as reference model to fine-tune w.r.t. rank reward
### SL-RM model from Step #3 as reward model to evaluate generation

## Data structure : prompt

### Xavier Bresson, xavier.bresson@gmail.com, https://twitter.com/xbresson

### Number of data points for GPT-3, 175B parameters
+ Step #1 : 300B tokens
+ Step #2 : 10k-100k pairs (prompt, response)
+ Step #3 : 100k-1M triples (prompt, positive response, negative response)
+ Step #4 : 10k-100k prompts

### Number of data points for this tutorial
+ Step #1 : 1M tokens
+ Step #2 : 10k pairs (prompt, response)
+ Step #3 : 10k triples (prompt, positive response, negative response)
+ Step #4 : 1k prompts

### Objectives
+ Adapt PPO reinforcement learning technique to language generation
+ Train with batch of prompt training for fast training with GPU
+ Use pre-trained models from steps 2 and 3


In [1]:
# For Google Colaboratory
import sys, os
if 'google.colab' in sys.modules:
    # mount google drive
    from google.colab import drive
    drive.mount('/content/gdrive')
    path_to_file = '/content/gdrive/My Drive/ACE_NLP_Dec23_codes/codes/labs_vanillaLLMs'
    print(path_to_file)
    # move to Google Drive directory
    os.chdir(path_to_file)
    !pwd

Mounted at /content/gdrive
/content/gdrive/My Drive/ACE_NLP_Dec23_codes/codes/labs_vanillaLLMs
/content/gdrive/My Drive/ACE_NLP_Dec23_codes/codes/labs_vanillaLLMs


In [2]:
# Libraries
import torch
print(torch.__version__)
import torch.nn as nn
import torch.optim as optim
import time
import matplotlib.pyplot as plt
import logging
logging.getLogger().setLevel(logging.CRITICAL) # remove warnings
import os, datetime


2.1.0+cu118


## Time stamp for save/load data


In [3]:
# save time stamp
time_stamp = datetime.datetime.now().strftime("%y-%m-%d--%H-%M-%S")

# check dataset folder exists
data_dir = os.path.join('dataset')
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

# select a time stamp
use_saved_time_stamp = False
#use_saved_time_stamp = True
if use_saved_time_stamp:
    time_stamp = '23-11-28--15-33-57' # trained on GPU on xxx

print('time_stamp:', time_stamp, '\n')


time_stamp: 23-12-04--15-21-18 



## Load dictionary of tokens (step #1)

In [4]:
load_file_dictionary = 'dataset/step1_02_SSL_dictionary_23-12-04--10-32-53.pt'
dictionary, num_tokens, token2index, index2token = torch.load(load_file_dictionary) # load dictionary of tokens
print('dictionary:',dictionary,'\n')
print('num_tokens (unique):',num_tokens,'\n')
print('token2index:', token2index,'\n')
print('index2token:', index2token,'\n')
func_tokens2indices = lambda list_tokens: [token2index[token] for token in list_tokens] # ['Let', '5', 'be', 'the'] => [113, 46, 114, 115]
func_indices2tokens = lambda list_ints: [index2token[integer] for integer in list_ints] # [113, 46, 114, 115] => ['Let', '5', 'be', 'the']
func_str2tokens = lambda input_str: [token_str for token_str in input_str.split()]      # 'Let 5 be the' => ['Let', '5', 'be', 'the']
func_tokens2str = lambda list_str: ' '.join(list_str)                                   # ['Let', '5', 'be', 'the'] => 'Let 5 be the'


dictionary: ['45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '<SEP>', '87', '94', '76', '81', '86', '91', '96', '78', '59', '63', '67', '14', '17', '20', '23', '26', '29', '32', '35', '38', '41', '44', '42', '62', '72', '77', '82', '92', '97', '65', '18', '28', '33', '43', '73', '74', '75', '79', '80', '83', '84', '0', '7', '21', '70', '30', '34', '36', '40', '31', '37', '39', '61', '64', '71', '89', '95', '93', '99', '25', '27', '66', '10', '12', '16', '22', '24', '85', '60', '68', '100', '98', '88', '69', '1', '13', '19', '90', '6', '11', '9', '15', '2', '3', '4', '5', '8', 'generate', 'an', 'arithmetic', 'series', 'with', 'terms', 'starting', 'value', 'and', 'common', 'difference', 'Let', 'be', 'the', 'number', 'of', 'then', 'write', 'make', 'a', 'type', 'which', 'starts', 'at', 'elements', '<PAD>', '<EOS>'] 

num_tokens (unique): 129 

token2index: {'45': 0, '46': 1, '47': 2, '48': 3, '49': 4, '50': 5, '51': 6, '52': 7, '53': 8, '54': 9, '55': 10

## Generate the RL training set of prompts


In [5]:
# generate arithmetic series
m = max_value = 100 # maximum value in the sequence
def arithmetic_series(m, s, d, n):
    seq = []
    for i in range(n):
        v = s + i * d
        if v <= m:
            seq.append(v)
        else:
            break
    return seq

# generate training data, i.e. list of prompts
#   prompt = [ generate arithmetic series of 5 terms with difference 2 starting at 3 ]
save_training_data = False
save_training_data = True
if save_training_data:

    # "collect" training set
    list_prompt_RL = []
    num_training_data = 12 # debug
    num_training_data = 1000 # number of pairs of (prompt, response), e.g. GPU 1,000 pairs (prompt, response)
    start = time.time()
    for idx in range(num_training_data):

        # parameters for arithmetic series
        m = max_value # maximum value in the sequence
        s = torch.randint(low=0, high=m, size=(1,)).item() # starting integer of the series
        d = torch.randint(low=1, high=10, size=(1,)).item() # value of common difference
        n = torch.randint(low=5, high=15, size=(1,)).item() # number of element in the series
        #print('max_value: %d, start_value: %d, common_difference: %d, number_of_terms: %d' % (m,s,d,n))

        # generate prompt : sample a prompt between 3 candidate prompts
        prompt = {}
        prompt[1] = 'generate an arithmetic series with ' + str(n) + ' terms starting with value ' + str(s) + ' and common difference ' + str(d)
        prompt[2] = 'make a series of arithmetic type which starts at ' + str(s) + ' with ' + str(n) + ' elements and ' + str(d) + ' common difference value'
        prompt[3] = 'Let ' + str(n) + ' be the number of terms ' + str(s) + ' the starting number and ' + str(d) + ' the common difference then write the arithmetic series'
        random_int = torch.randint(low=1, high=3+1, size=(1,)).item() # random number in {1,2,3}
        #random_int = 1 # debug
        prompt = prompt[random_int]
        #response = arithmetic_series(m,s,d,n)

        # covert from token to pytorch
        prompt = [str(i) for i in prompt.split()] # convert a string into seq of tokens (w/ string type)
        prompt = func_tokens2indices(prompt) # convert from token (str) to index (int)
        prompt = torch.tensor(prompt) # convert to pytorch

        # append
        list_prompt_RL.append(prompt)

        # track
        if not idx%1000:
            print('idx: %d, time(sec): %.3f' % (idx, time.time()-start) )

    # print
    print('number of training data / prompt :',len(list_prompt_RL),'\n')
    for idx, prompt in enumerate(list_prompt_RL[:3]):
        prompt = func_tokens2str(func_indices2tokens(prompt.tolist()))
        print('training_set[%d] : %s ' % (idx, prompt) , '\n' )

    # save training data
    save_file = data_dir + '/step4_01_RL_training_set_' + time_stamp + '.pt'
    print('save_file:', save_file, '\n')
    torch.save([list_prompt_RL],save_file) # save list of prompts

else:

    # load training data
    load_file = data_dir + '/step4_01_RL_training_set_' + time_stamp + '.pt'
    print('load_file:', load_file, '\n')
    list_prompt_RL = torch.load(load_file)[0]

    # print
    print('number of training data / prompt :',len(list_prompt_RL),'\n')
    for idx, prompt in enumerate(list_prompt_RL[:3]):
        prompt = func_tokens2str(func_indices2tokens(prompt.tolist()))
        print('training_set[%d] : %s ' % (idx, prompt) , '\n' )


idx: 0, time(sec): 0.046
number of training data / prompt : 1000 

training_set[0] : Let 6 be the number of terms 2 the starting number and 3 the common difference then write the arithmetic series  

training_set[1] : Let 12 be the number of terms 24 the starting number and 8 the common difference then write the arithmetic series  

training_set[2] : make a series of arithmetic type which starts at 68 with 12 elements and 8 common difference value  

save_file: dataset/step4_01_RL_training_set_23-12-04--15-21-18.pt 



## Get batch of sampled indices of RL prompts

In [6]:
# batching parameters
num_prompt_RL = len(list_prompt_RL) # number of prompt sequences
batch_size = 3 # debug
batch_size = 100 # batch size, 500 GPU
num_batch_RL = num_prompt_RL // batch_size # number of batches
print('num_prompt_RL: %d, batch_size: %d, num_batch_RL: %d\n' % (num_prompt_RL,batch_size,num_batch_RL))

# sample batch of RL prompt
def get_batch_RL(batch_size, list_prompts_idx):
    batch_idx = torch.randperm(list_prompts_idx.size(0))[:batch_size] # sample B number of batch indices
    batch_idx = list_prompts_idx[batch_idx] # and extract from remaining list of batch indices
    if list_prompts_idx.size(0) > batch_size:
        new_list_prompts_idx = torch.stack([i for i in list_prompts_idx if i not in batch_idx]) # remove the sampled batch indices from the list of indices
    else:
        new_list_prompts_idx = torch.tensor([]) # last batch of epoch, i.e. return empty list
    return batch_idx, new_list_prompts_idx

# # one epoch, debug
# list_prompts_idx = torch.arange(num_prompt_RL) # list of RL prompt indices
# for i in range(3): # num_batch_RL
#     print('batch :',i)
#     print('list_prompts_idx (before) :',list_prompts_idx, list_prompts_idx.size())
#     batch_idx, list_prompts_idx = get_batch_RL(batch_size, list_prompts_idx) # sample a batch of indices (prompt,response)
#     print('batch_idx :', batch_idx, batch_idx.size())
#     print('list_prompts_idx (after) :',list_prompts_idx, list_prompts_idx.size(),'\n')


num_prompt_RL: 1000, batch_size: 100, num_batch_RL: 10



## Transformers backbone for all models, i.e. step #2, step #3 and step #4


In [7]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# GPU training
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    device = torch.device("cuda") # use GPU
else:
    device = torch.device("cpu")
print('device:',device,'\n')

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

# multiple attention heads layer
class multiple_head_attention(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        d_head = d // num_heads
        assert d == d_head * num_heads # check divisiblity
        self.MHA = nn.MultiheadAttention(d, num_heads, batch_first=True, dropout=dropout)
        self.mask = torch.tril(torch.ones(context_length, context_length))==0 # mask to make attention to previous tokens only : { token(<=t) }, size=(context_length,context_length)
                   # torch.tril(ones) = True in the up-right part, True means *no* attention allowed in pytorch implementation
        self.context_length = context_length
    def forward(self, H):
        if H.size(1) == self.context_length: # training
            attn_mask = self.mask
        else: # when batch_length not= context_length, e.g. inference time / sequence generation
            current_batch_length = H.size(1)
            attn_mask = torch.tril(torch.ones(current_batch_length, current_batch_length))==0
        H_heads = self.MHA(H, H, H, attn_mask=attn_mask.to(device))[0] # pytorch implementation, size=[batch_size, batch_length, d]
        return H_heads

# Transformer block layer
class TransformerBlock(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        self.MHA = multiple_head_attention(d, context_length, num_heads, dropout)
        self.LN_MHA = nn.LayerNorm(d)
        self.MLP = nn.Sequential(nn.Linear(d,4*d), nn.ReLU(), nn.Dropout(dropout), nn.Linear(4*d,d))
        self.LN_MLP = nn.LayerNorm(d)
    def forward(self, H):
        H = H + self.MHA(self.LN_MHA(H)) # size=[batch_size, batch_length, d]
        H = H + self.MLP(self.LN_MLP(H)) # size=[batch_size, batch_length, d]
        return H


Tesla T4
device: cuda 



## Load pre-trained SFT-LM network (step #2)


In [8]:
# class of supervised fine-tuning LM network (step 2)
class SFT_LM(nn.Module):
    def __init__(self, num_tokens, d, context_length, num_heads, dropout, num_layers, padding_int, eos_int):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.seq_pos_encoding = torch.arange(context_length, device=device) # positional encoding = {0,1,2,...,context_length-1}
        self.PE_embedding = nn.Embedding(context_length, d) # positional encoding embedding layer
        self.transformer_blocks = nn.ModuleList([ TransformerBlock(d, context_length, num_heads, dropout) for _ in range(num_layers) ]) # multiple transformer block layers
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
        self.context_length = context_length
        self.padding = padding_int
        self.eos = eos_int
    # Note : No forward function is needed

# load pre-trained SFT-LM network (step 2)
checkpoint_file = "checkpoint/step2_checkpoint_SFT_LM_23-12-04--13-08-04.pkl"
checkpoint = torch.load(checkpoint_file, map_location=device)
epoch = checkpoint['epoch']
tot_time = checkpoint['tot_time']
loss = checkpoint['loss']
print('Load pre-trained SFT-LM: \n checkpoint file: {:s}\n epoch: {:d}, time: {:.3f}min, loss={:.4f}'.format(checkpoint_file,epoch,tot_time,loss))
net_parameters = SFT_LM_net_parameters = checkpoint['net_parameters']
num_tokens = net_parameters['num_tokens']
d = net_parameters['d']
num_heads = net_parameters['num_heads']
context_length = net_parameters['context_length']
dropout = net_parameters['dropout']
num_layers = net_parameters['num_layers']
padding_int = net_parameters['padding_int']
eos_int = net_parameters['eos_int']
print(' num_tokens: %d, d: %d, context_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d\n' % (num_tokens, d, context_length, num_heads, dropout, num_layers) )
padding_int = torch.tensor([func_tokens2indices('<PAD>'.split())[0]]).to(device) # end-of-sentence token for batch
eos_int = torch.tensor([func_tokens2indices('<EOS>'.split())[0]]).to(device) # end-of-sentence token for batch
print('num_tokens: %d, padding_int: %d, eos_int: %d\n' % (num_tokens, padding_int, eos_int))
SFT_LMnet = SFT_LM(num_tokens, d, context_length, num_heads, dropout, num_layers, padding_int, eos_int)
SFT_LMnet = SFT_LMnet.to(device)
SFT_LMnet.load_state_dict(checkpoint['SFT_LMnet_dict']) # load pre-trained SFT-LM network (step 2)
num_param = number_param(SFT_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )
del checkpoint

# check model prediction
# generate new sentence of any length
def generate(LMnet, prompt, max_length_gen_seq):
    LMnet.eval()
    predicted_seq = torch.ones(1, max(prompt.size(0),LMnet.context_length)).long().to(device) * LMnet.padding # initiliaze with padding
    predicted_seq[:, -prompt.size(0):] = prompt # fill batch_predicted_seq with prompt, right-aligned
    for k in range(max_length_gen_seq):
        context = predicted_seq[:,-LMnet.context_length:] # size=[batch_size, context_length
        H = LMnet.token2vec(context) + LMnet.PE_embedding(LMnet.seq_pos_encoding[:context.size(1)]).unsqueeze(0) # size=[batch_size, context_length, d]
        for transformer_block in LMnet.transformer_blocks: H = transformer_block(H) # size=(batch_size, context_length, d)
        token_scores = H[:,-1,:] # extract last token to predict the next one, size=[batch_size, d]
        token_scores = LMnet.token_prediction(token_scores) # compute scores, size=[batch_size, num_tokens]
        token_probs = torch.softmax(token_scores, dim=1) # compute probs, size=[batch_size, num_tokens]
        next_token = torch.multinomial(token_probs, num_samples=1) # sample next token, size=[batch_size, 1]
        #next_token = torch.max(token_probs, dim=1).indices[0].view(1,1) # size=(1,1)
        if next_token==LMnet.eos:
            break
        predicted_seq = torch.cat((predicted_seq, next_token), dim=1) # size=[batch_size, current_seq_len+1]
    gen_seq = predicted_seq[0][max(prompt.size(0),LMnet.context_length):]
    return gen_seq
prompt_idx = torch.randint(0, num_prompt_RL, (1,))
prompt_RL = list_prompt_RL[prompt_idx]
print('prompt_RL :',func_tokens2str(func_indices2tokens(prompt_RL.tolist())))
gen_seq = generate(SFT_LMnet, prompt_RL, 15)
print('gen_seq   :',func_tokens2str(func_indices2tokens(gen_seq.tolist())))


Load pre-trained SFT-LM: 
 checkpoint file: checkpoint/step2_checkpoint_SFT_LM_23-12-04--13-08-04.pkl
 epoch: 1, time: 457.656min, loss=1.6250
 num_tokens: 129, d: 384, context_length: 40, num_heads: 6, dropout: 0.10, num_layers: 6

num_tokens: 129, padding_int: 127, eos_int: 128

num_net_parameters: 10761345 / 10.76 million

prompt_RL : make a series of arithmetic type which starts at 82 with 10 elements and 9 common difference value
gen_seq   : 82 91 100


## Load pre-trained RM network (step #3)


In [9]:
# Supervised learning network for reward (step 3)
class SL_RM(nn.Module):
    def __init__(self, SFT_LM, d, context_length, padding_int, eos_int):
        super().__init__()
        self.SFT_LM = SFT_LM # token embedding layer
        self.reward_prediction = nn.Sequential(nn.LayerNorm(d), nn.Linear(d,d), nn.ReLU(), nn.Linear(d,1)) # reward prediction layer
        self.context_length = context_length
        self.padding = padding_int
        self.eos = eos_int
    # Note : No forward function is needed

# pre-trained SL-RM network
checkpoint_file = 'checkpoint/step3_checkpoint_SL_RM_23-12-04--15-15-36.pkl'
checkpoint = torch.load(checkpoint_file, map_location=device)
epoch = checkpoint['epoch']
tot_time = checkpoint['tot_time']
loss = checkpoint['loss']
print('Load pre-trained SL-RM: \n checkpoint file: {:s}\n epoch: {:d}, time: {:.3f}min, loss={:.4f}'.format(checkpoint_file,epoch,tot_time,loss))
net_parameters = checkpoint['net_parameters']
num_tokens = net_parameters['num_tokens']
d = net_parameters['d']
num_heads = net_parameters['num_heads']
context_length = net_parameters['context_length']
dropout = net_parameters['dropout']
num_layers = net_parameters['num_layers']
padding_int = net_parameters['padding_int']
eos_int = net_parameters['eos_int']
print(' num_tokens: %d, d: %d, context_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d\n' % (num_tokens, d, context_length, num_heads, dropout, num_layers) )
padding_int = torch.tensor([func_tokens2indices('<PAD>'.split())[0]]).to(device) # end-of-sentence token for batch
eos_int = torch.tensor([func_tokens2indices('<EOS>'.split())[0]]).to(device) # end-of-sentence token for batch
print('num_tokens: %d, padding_int: %d, eos_int: %d\n' % (num_tokens, padding_int, eos_int))
SL_RMnet = SL_RM(SFT_LMnet, d, context_length, padding_int, eos_int)
SL_RMnet = SL_RMnet.to(device)
SL_RMnet.load_state_dict(checkpoint['SL_RMnet_dict']) # load pre-trained SL-RM network from step #3
num_param = number_param(SL_RMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )
del checkpoint

# check model prediction
# compute rank reward
def rank(SL_RMnet, prompt, response):
    len_prompt_response = prompt.size(0) + response.size(0)
    batch_seq = torch.ones(1, max(len_prompt_response,SL_RMnet.context_length)).long().to(device) * SL_RMnet.padding # initiliaze with padding
    batch_seq[:, -len_prompt_response:] = torch.cat((prompt,response),dim=0) # fill batch_predicted_seq with prompt, right-aligned
    H = SL_RMnet.SFT_LM.token2vec(batch_seq) + SL_RMnet.SFT_LM.PE_embedding(SL_RMnet.SFT_LM.seq_pos_encoding[:batch_seq.size(1)]).unsqueeze(0) # size=[2* batch_size, context_length, d]
    for transformer_block in SL_RMnet.SFT_LM.transformer_blocks: H = transformer_block(H) # size=[1, context_length, d)
    token_score = H[:,-1,:] # extract last token scores to predict rewards, size=[1, d]
    reward_score = SL_RMnet.reward_prediction(token_score) # compute reward scores, size=[1, 1]
    return reward_score
prompt_idx = torch.randint(0, num_prompt_RL, (1,))
prompt_RL = list_prompt_RL[prompt_idx].to(device)
print('prompt_RL :',func_tokens2str(func_indices2tokens(prompt_RL.tolist())))
gen_seq = generate(SFT_LMnet, prompt_RL, 15)
print('gen_seq   :',func_tokens2str(func_indices2tokens(gen_seq.tolist())))
reward_score = rank(SL_RMnet, prompt_RL, gen_seq)
print('rank      :',reward_score.item())


Load pre-trained SL-RM: 
 checkpoint file: checkpoint/step3_checkpoint_SL_RM_23-12-04--15-15-36.pkl
 epoch: 100, time: 57.364min, loss=0.0573
 num_tokens: 129, d: 384, context_length: 40, num_heads: 6, dropout: 0.10, num_layers: 6

num_tokens: 129, padding_int: 127, eos_int: 128

num_net_parameters: 10910338 / 10.91 million

prompt_RL : generate an arithmetic series with 10 terms starting with value 99 and common difference 3
gen_seq   : 99
rank      : 4.972538471221924


## Reinforcement learning LM network (step #4)


In [17]:
# Understanding how to cast the task "prompt => text generation" as a reinforcement learning PPO technique
#
# defined state/s, action/a, reward/r in the LM context :
#  state : prompt
#  action : generated response with RL-LM network (step 4)
#  reward : trained rank score of prompt+response (step 3)
#
# an episode in the standard RL context :
#  s_0 => ... => s_t => a_t ~ policy_net(s_t) => r_t, s_t+1 => ... => end of episode
#
# an episode in the LM context :
#  s = prompt => a = response ~ policy_net(s), r = rank(prompt+response)
#
# important note : There is NO time t in the RL-LM setting !
#                   Mostly because there is no trained reward r_t for partial response
#
# NO value function V(s) is required to be learned !
#  in the standard setting, the value function provides the predicted total discounted reward
#                           to reach the end of the episode
#  in the LM setting, the value function is given by the learned rank function in step 3
#                     it predicts the rank of the prompt+response
#  reminder : min_V || V_t - dr_t ||^2, dr_t = sum_{l=0} gamma^l r_t+l
#              no t => V = dr = r
#
# advantage function is simply A = rank in this setting !
#  reminder of advantage equation : A_t = sum_{l=0} (gamma * beta)^l delta_t+l, delta_t = r_t + gamma * V_t+1 - V_t
#                                   No t => A_t = A = delta = r + (gamma-1) V = r + (gamma-1) r = gamma.r
#             min_Policy - min( ratio_t * A_t , clip(ratio_t) * A_t )
#         <=> min_Policy - min( ratio * A , clip(ratio) * A )
#         <=> min_Policy - min( ratio * gamma.r , clip(ratio) * gamma.r )
#         <=> min_Policy - min( ratio * r , clip(ratio) * r ) as gamma>0 does not change the solution
#
# define the ratio = Policy_Net(a|s) / Policy_Net_previous(a|s), Policy_Net = Probability_RL_LM
#                    Policy_Net(response|prompt) / Policy_Net_previous(response|prompt)
#
# final advantage function is composed of two terms (for maximizing human alignment) :
#  advantage = rank(prompt+response) + beta * mean_{token in response} log( Probability_RL_LM(response|prompt) /
#                                                                           Probability_SFT_LM(response|prompt) )
#              <------------------>           <----------------------------------------------------------------->
#             human ranking (step 3)                           human response to prompt (step 2)
#
# goal : train the RL-LM network to maximize the rank(response)
#         but let the network explore diverse responses learned in step 2
#         (otherwise, RL-LM will only learn one response)
#
# class of reinforcement learning LM network (step 4)
class RL_LM(nn.Module):
    def __init__(self, num_tokens, d, context_length, num_heads, dropout, num_layers, padding_int, eos_int):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.seq_pos_encoding = torch.arange(context_length, device=device) # positional encoding = {0,1,2,...,context_length-1}
        self.PE_embedding = nn.Embedding(context_length, d) # positional encoding embedding layer
        self.transformer_blocks = nn.ModuleList([ TransformerBlock(d, context_length, num_heads, dropout) for _ in range(num_layers) ]) # multiple transformer block layers
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
        self.context_length = context_length
        self.padding = padding_int
        self.eos = eos_int
    # y_RL, prob_PolicyNet_y_RL ~ LM_RL(x_RL)
    def forward(self, batch_idx, list_prompt_RL, len_response): # batch_idx.size=[batch_size], len(list_prompt) =[list_prompt_RL]
        prompts = [list_prompt_RL[idx] for idx in batch_idx] # sample list of prompts, len(prompts)=num_prompt_response
        len_prompt = max([len(prompt) for prompt in prompts]) # compute max of prompt lengths
        batch_size = batch_idx.size(0)
        y_RL = torch.ones(batch_size, max(len_prompt,self.context_length)).long().to(device) * self.padding # initialize with padding
        for idx in range(batch_size): y_RL[idx, -prompts[idx].size(0):] = prompts[idx] # fill batch_predicted_seq with prompt, right-aligned
        x_RL = y_RL # batch of RL prompts
        prob_PolicyNet_y_RL = torch.tensor([]).to(device)
        for idx in range(len_response): # number of auto-regressive prediction
            context = y_RL[:,-self.context_length:] # size=[batch_size, context_length
            H = self.token2vec(context) + self.PE_embedding(self.seq_pos_encoding[:context.size(1)]).unsqueeze(0) # size=[batch_size, context_length, d]
            for transformer_block in self.transformer_blocks: H = transformer_block(H) # size=(batch_size, context_length, d)
            token_scores = H[:,-1,:] # extract last token to predict the next one, size=[batch_size, d]
            token_scores = self.token_prediction(token_scores) # compute scores, size=[batch_size, num_tokens]
            token_probs = torch.softmax(token_scores, dim=1) # compute probs, size=[batch_size, num_tokens]
            next_token = torch.multinomial(token_probs, num_samples=1) # sample next token, size=[batch_size, 1]
            next_token_probs = token_probs[torch.arange(batch_size), next_token.squeeze()].unsqueeze(1) # probability of next token, size=[batch_size]
            y_RL = torch.cat((y_RL, next_token), dim=1) # size=[batch_size, current_seq_len+1]
            prob_PolicyNet_y_RL = torch.cat((prob_PolicyNet_y_RL, next_token_probs), dim=1) # size=[batch_size, idx+1]
        y_RL = y_RL[:,-len_response:] # size=[batch_size, len_response]
        mask_eos = torch.zeros(batch_size, len_response).to(device) # size=[batch_size, len_response]
        indices_batch, indices_token = torch.where(y_RL == self.eos) # size=[batch_size, len_(y_RL==self.eos)]
        for b in range(batch_size):
            indices_all_eos = torch.where(indices_batch == b)[0]
            if indices_all_eos.numel()>0:
                indices_first_eos = indices_token[indices_all_eos[0]] + 1 # first index s.t. y_RL==self.eos
                mask_eos[b,:indices_first_eos] = 1.0 # fill out mask with 1 to identify tokens selected for PPO loss
        return x_RL, y_RL, prob_PolicyNet_y_RL, mask_eos # x_RL=[batch_size, context_length], y_RL=[batch_size, len_response], prob_PolicyNet_y_RL=[batch_size, len_response], mask_eos=[batch_size, len_response]
    # prob_y_LM_SFT = LM_RL(x_RL+y_RL)
    def forward_SFT(self, SFT_LMnet, x_RL, y_RL): # x_RL=[batch_size, context_length], y_RL=[batch_size, len_response]
        batch_size = x_RL.size(0); len_response = y_RL.size(1)
        xy_RL = torch.cat( (x_RL,y_RL), dim=1) # size=(batch_size, context_length+len_response)
        # no need auto-regressive, i.e. one-shot prediction of prob of y_SFT
        context = xy_RL[:,-SFT_LMnet.context_length:] # size=[batch_size, context_length]
        H = SFT_LMnet.token2vec(context) + SFT_LMnet.PE_embedding(self.seq_pos_encoding[:context.size(1)]).unsqueeze(0) # size=[batch_size, context_length, d]
        for transformer_block in SFT_LMnet.transformer_blocks: H = transformer_block(H) # size=(batch_size, context_length, d)
        response_token_scores = H[:,-len_response:,:] # extract last token to predict the next one, size=[batch_size, len_response, d]
        response_token_scores = SFT_LMnet.token_prediction(response_token_scores) # compute scores, size=[batch_size, len_response, num_tokens]
        response_token_probs = torch.softmax(response_token_scores, dim=2) # compute probs, size=[batch_size, len_response, num_tokens]
        prob_y_LM_SFT = torch.tensor([]).to(device)
        for idx in range(len_response):
            prob_y_t = response_token_probs[torch.arange(batch_size),idx,y_RL[torch.arange(batch_size),idx]].unsqueeze(1) # size=[batch_size, 1]
            prob_y_LM_SFT = torch.cat((prob_y_LM_SFT,prob_y_t),dim=1) # size=[batch_size, current_seq_len+1]
        return prob_y_LM_SFT # size=[batch_size, len_response]
    # RM_xy_RL = RM(x_RL+y_RL)
    def forward_RM(self, SL_RMnet, x_RL, y_RL):  # x_RL=[batch_size, context_length], y_RL=[batch_size, len_response]
        batch_size = x_RL.size(0); len_response = y_RL.size(1)
        xy_RL = torch.cat( (x_RL,y_RL), dim=1) # size=(batch_size, context_length+len_response)
        # no need auto-regressive, i.e. one-shot prediction of prob of RM(x_RL+y_RL)
        context = xy_RL[:,-SL_RMnet.context_length:] # size=[batch_size, context_length]
        H = SL_RMnet.SFT_LM.token2vec(context) + SL_RMnet.SFT_LM.PE_embedding(self.seq_pos_encoding[:context.size(1)]).unsqueeze(0) # size=[batch_size, context_length, d]
        for transformer_block in SL_RMnet.SFT_LM.transformer_blocks: H = transformer_block(H) # size=(batch_size, context_length, d)
        token_scores = H[:,-1,:] # extract last token to predict rewards, size=[batch_size, len_response, d]
        RM_xy_RL = SL_RMnet.reward_prediction(token_scores).squeeze() # compute reward scores, size=[batch_size]
        return RM_xy_RL # size=[batch_size]


# use parameters of pre-trained SFT-LM network (step 2) for RL_LM network
print('Parameters of pre-trained SFT-LM network (step 2)')
num_tokens = net_parameters['num_tokens']
d = net_parameters['d']
num_heads = net_parameters['num_heads']
context_length = net_parameters['context_length']
dropout = net_parameters['dropout']
num_layers = net_parameters['num_layers']
padding_int = net_parameters['padding_int']
eos_int = net_parameters['eos_int']
print(' num_tokens: %d, d: %d, context_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d' % (num_tokens, d, context_length, num_heads, dropout, num_layers) )
padding_int = torch.tensor([func_tokens2indices('<PAD>'.split())[0]]).to(device) # end-of-sentence token for batch
eos_int = torch.tensor([func_tokens2indices('<EOS>'.split())[0]]).to(device) # end-of-sentence token for batch
print(' num_tokens: %d, padding_int: %d, eos_int: %d\n' % (num_tokens, padding_int, eos_int))

# RL_LM network
RL_LMnet = RL_LM(num_tokens, d, context_length, num_heads, dropout, num_layers, padding_int, eos_int)
RL_LMnet = RL_LMnet.to(device)
num_param = number_param(RL_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# initialize RL-LM with pre-trained SFT-LM network (step 2)
checkpoint_file = "checkpoint/step2_checkpoint_SFT_LM_23-12-04--13-08-04.pkl"
checkpoint = torch.load(checkpoint_file, map_location=device)
RL_LMnet.load_state_dict(checkpoint['SFT_LMnet_dict'])

# optimizer
optimizer = torch.optim.AdamW(RL_LMnet.parameters(), lr=1e-5) # lr must be smaller because RM can have high value
warmup = 1 # 50(debug), 50(GPU), number of batches used for warmup <==
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda t: min(t/warmup, 1.0) ) # warmup learning rate scheduler, good for LM (softmax)
clip_value = 0.2 # clipping value for PPO
num_iter_policy_loss = 4 # 10 number of iteration for policy loss
beta = 0.01 # weight for similarity between RL policy (i.e. human preferences) and SFT-LM (i.e. human prompt-response)

# save checkpoint
net_parameters = {}
net_parameters['num_tokens'] = num_tokens
net_parameters['d'] = d
net_parameters['num_heads'] = num_heads
net_parameters['context_length'] = context_length
net_parameters['dropout'] = dropout
net_parameters['num_layers'] = num_layers
net_parameters['padding_int'] = padding_int
net_parameters['eos_int'] = eos_int
print('checkpoint :',"step4_checkpoint_RL_LM_" + time_stamp + '.pkl', '\n')
checkpoint_dir = os.path.join("checkpoint")
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# batching parameters
num_prompt_RL = len(list_prompt_RL) # number of prompt sequences
batch_size = 3 # debug
batch_size = 50 # batch size, 500 GPU <==
batch_size = 25
num_batch_RL = num_prompt_RL // batch_size # number of batches
print('num_prompt_RL: %d, batch_size: %d, num_batch_RL: %d\n' % (num_prompt_RL, batch_size, num_batch_RL))

# Train network to predict response from prompt
len_response = 15
num_epochs = 1 # 1001(debug), 11(GPU), number of epochs <==
print('num_epochs: ',num_epochs,'\n')
start = time.time()
for epoch in range(num_epochs): # number of epochs
    list_prompts_RL_idx = torch.arange(num_prompt_RL).to(device) # initialize the list of prompt
    running_loss = 0.0 # tracking total loss value
    for k in range(num_batch_RL): # number of batches in one epoch
        # y_RL, prob_PolicyNet_y_RL ~ LM_RL(x_RL) : get a batch x_RL of RL prompts and generate responses y_RL and their probabilities prob_PolicyNet_y_RL
        batch_idx, list_prompts_RL_idx = get_batch_RL(batch_size, list_prompts_RL_idx) # sample a batch of indices (prompt,response)
        x_RL, y_RL, prob_PolicyNet_y_RL, mask_eos = RL_LMnet(batch_idx.to(device), list_prompt_RL, len_response) # x_RL=[batch_size, context_length], y_RL=[batch_size, len_response], prob_PolicyNet_y_RL=[batch_size, len_response]
        # RM(x_RL+y_RL) : compute rank score of RL responses with SL-RM network
        RM_xy_RL = RL_LMnet.forward_RM(SL_RMnet, x_RL, y_RL) # size=[batch_size]
        # prob_y_LM_SFT = LM_RL(x_RL+y_RL) : compute probabilities prob_y_SFT of RL responses with reference LM-SFT network
        prob_y_LM_SFT = RL_LMnet.forward_SFT(SFT_LMnet, x_RL, y_RL) # size=[batch_size, len_response]
        # compute advantage function for PPO loss
        advantage = RM_xy_RL - beta * ( torch.log(prob_PolicyNet_y_RL) - torch.log(prob_y_LM_SFT) ).mean(dim=1) # 0.1 # size=[batch_size]
        advantage = advantage.unsqueeze(1).detach() # size=[batch_size,1]
        # Run PPO a few iterations
        log_probs_previous = torch.log(prob_PolicyNet_y_RL) # use log_probs from generation step as reference (fixed during optimization), size=[batch_size, len_response]
        for k in range(num_iter_policy_loss):
            _, _, prob_PolicyNet_y_RL, _ = RL_LMnet(batch_idx.to(device), list_prompt_RL, len_response) # from same prompts, generate new probabilities, size=[batch_size, len_response]
            log_probs = torch.log(prob_PolicyNet_y_RL) # size=[batch_size, len_response]
            policy_ratio = torch.exp( log_probs - log_probs_previous.detach()) # ratio between new optimized policy and previous one, size=[batch_size, len_response]
            clipped_ratio = policy_ratio.clamp(1.0 - clip_value, 1.0 + clip_value) # clipped ratio to allow small changes only, size=[batch_size, len_response]
            policy_ratio = mask_eos * policy_ratio   # tokens after eos do not contribute to the loss, size=[batch_size, len_response]
            clipped_ratio = mask_eos * clipped_ratio # tokens after eos do not contribute to the loss, size=[batch_size, len_response]
            loss = - torch.min( policy_ratio * advantage , clipped_ratio * advantage ).mean() # select the loss with smallest change, scalar
            running_loss += loss.detach().cpu().item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
    loss_epoch = running_loss / num_batch_RL
    if not epoch%1: # 10(debug), 1(GPU) <==
        print('Epoch: %d, time(min): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, (time.time()-start)/60, optimizer.param_groups[0]['lr'], loss_epoch) )
        # save checkpoint
        torch.save({
            'epoch': epoch,
            'tot_time': time.time()-start,
            'loss': loss_epoch,
            'net_parameters': net_parameters,
            'RL_LMnet_dict': RL_LMnet.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            }, '{}.pkl'.format(checkpoint_dir + "/step4_checkpoint_RL_LM_" + time_stamp ))
        # print one prompt
        idx_prompt = 0
        print('prompt        :',func_tokens2str(func_indices2tokens(x_RL[idx_prompt][torch.where(x_RL[idx_prompt]==padding_int)[0][-1]+1:].tolist())))   # remove all padding tokens
        print('predicted_seq :',func_tokens2str(func_indices2tokens(y_RL[idx_prompt][:torch.where(y_RL[idx_prompt]==eos_int)[0][0]+1].tolist())),'\n' ) # remove all tokens from first eos token
#         # Stopping condition
#         if loss_epoch < 0.1:
#             print("\n loss value is small -- training stopped\n")
#             break

# GPU training time : Epoch: 4, time(min): 2.185, lr= 0.000010, loss_epoch: -24.283


Parameters of pre-trained SFT-LM network (step 2)
 num_tokens: 129, d: 384, context_length: 40, num_heads: 6, dropout: 0.10, num_layers: 6
 num_tokens: 129, padding_int: 127, eos_int: 128

num_net_parameters: 10761345 / 10.76 million

checkpoint : step4_checkpoint_RL_LM_23-12-04--15-21-18.pkl 

num_prompt_RL: 1000, batch_size: 25, num_batch_RL: 40

num_epochs:  1 

Epoch: 0, time(min): 1.104, lr= 0.000010, loss_epoch: 34.566
prompt        : Let 9 be the number of terms 70 the starting number and 7 the common difference then write the arithmetic series
predicted_seq : 70 79 82 91 98 97 100 <EOS> 



## Load pre-trained RL-LM network


In [24]:
# pre-trained SL-RM network
checkpoint_file = checkpoint_dir + '/step4_checkpoint_RL_LM_' + time_stamp + '.pkl'
checkpoint = torch.load(checkpoint_file, map_location=device)
epoch = checkpoint['epoch']
tot_time = checkpoint['tot_time']
loss = checkpoint['loss']
print('Load pre-trained RL-LM: \n checkpoint file: {:s}\n epoch: {:d}, time: {:.3f}min, loss={:.4f}'.format(checkpoint_file,epoch,tot_time,loss))
net_parameters = checkpoint['net_parameters']
num_tokens = net_parameters['num_tokens']
d = net_parameters['d']
num_heads = net_parameters['num_heads']
context_length = net_parameters['context_length']
dropout = net_parameters['dropout']
num_layers = net_parameters['num_layers']
padding_int = net_parameters['padding_int']
eos_int = net_parameters['eos_int']
print(' num_tokens: %d, d: %d, context_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d\n' % (num_tokens, d, context_length, num_heads, dropout, num_layers) )
padding_int = torch.tensor([func_tokens2indices('<PAD>'.split())[0]]).to(device) # end-of-sentence token for batch
eos_int = torch.tensor([func_tokens2indices('<EOS>'.split())[0]]).to(device) # end-of-sentence token for batch
print('num_tokens: %d, padding_int: %d, eos_int: %d\n' % (num_tokens, padding_int, eos_int))
RL_LMnet = RL_LM(num_tokens, d, context_length, num_heads, dropout, num_layers, padding_int, eos_int)
RL_LMnet = RL_LMnet.to(device)
RL_LMnet.load_state_dict(checkpoint['RL_LMnet_dict']) # load pre-trained RL-LM network from step #4
num_param = number_param(RL_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )
del checkpoint

# generate new sentence of any length
def generate(LMnet, prompt, max_length_gen_seq):
    LMnet.eval()
    predicted_seq = torch.ones(1, max(prompt.size(0),LMnet.context_length)).long().to(device) * LMnet.padding # initiliaze with padding
    predicted_seq[:, -prompt.size(0):] = prompt # fill batch_predicted_seq with prompt, right-aligned
    for k in range(max_length_gen_seq):
        context = predicted_seq[:,-LMnet.context_length:] # size=[batch_size, context_length
        H = LMnet.token2vec(context) + LMnet.PE_embedding(LMnet.seq_pos_encoding[:context.size(1)]).unsqueeze(0) # size=[batch_size, context_length, d]
        for transformer_block in LMnet.transformer_blocks: H = transformer_block(H) # size=(batch_size, context_length, d)
        token_scores = H[:,-1,:] # extract last token to predict the next one, size=[batch_size, d]
        token_scores = LMnet.token_prediction(token_scores) # compute scores, size=[batch_size, num_tokens]
        token_probs = torch.softmax(token_scores, dim=1) # compute probs, size=[batch_size, num_tokens]
        next_token = torch.multinomial(token_probs, num_samples=1) # sample next token, size=[batch_size, 1]
        #next_token = torch.max(token_probs, dim=1).indices[0].view(1,1) # size=(1,1)
        if next_token==LMnet.eos:
            break
        predicted_seq = torch.cat((predicted_seq, next_token), dim=1) # size=[batch_size, current_seq_len+1]
    gen_seq = predicted_seq[0][max(prompt.size(0),LMnet.context_length):]
    return gen_seq

# print one prompt
idx_prompt = torch.randint(low=0, high=num_prompt_RL, size=(1,)).item() # random number in {0,...,num_prompt_RL-1}
print('idx_prompt :',idx_prompt)
prompt = list_prompt_RL[idx_prompt]
print('prompt     :',func_tokens2str(func_indices2tokens(prompt.tolist())))
gen_seq = generate(RL_LMnet, prompt, max_length_gen_seq=15)
print('gen_seq    :',func_tokens2str(func_indices2tokens(gen_seq.tolist())))


Load pre-trained RL-LM: 
 checkpoint file: checkpoint/step4_checkpoint_RL_LM_23-12-04--15-21-18.pkl
 epoch: 0, time: 66.213min, loss=34.5662
 num_tokens: 129, d: 384, context_length: 40, num_heads: 6, dropout: 0.10, num_layers: 6

num_tokens: 129, padding_int: 127, eos_int: 128

num_net_parameters: 10761345 / 10.76 million

idx_prompt : 683
prompt     : make a series of arithmetic type which starts at 61 with 11 elements and 8 common difference value
gen_seq    : 61 67 75 85 89 93 91
