# Step #3 : Supervised Learning (SL) of Reward Model (RM)

## Task : from prompt, rank positive and negative responses

## Use pre-trained SFT-LM model from Step #2 and froze it to use as backbone for training the reward

### Xavier Bresson, xavier.bresson@gmail.com, https://twitter.com/xbresson

### Number of data points for GPT-3, 175B parameters
+ Step #1 : 300B tokens
+ Step #2 : 10k-100k pairs (prompt, response)
+ Step #3 : 100k-1M triples (prompt, positive response, negative response)
+ Step #4 : 10k-100k prompts

### Number of data points for this tutorial
+ Step #1 : 3M tokens
+ Step #2 : 10k pairs (prompt, response)
+ Step #3 : 10k triples (prompt, positive response, negative response)
+ Step #4 : 1k prompts

### Objectives
+ Supervised learning of reward
+ Freeze a LM network to use as backbone for reward prediction
+ Train with batch of pairs (prompt, positive response, negative) for fast training with GPU


In [1]:
# For Google Colaboratory
import sys, os
if 'google.colab' in sys.modules:
    # mount google drive
    from google.colab import drive
    drive.mount('/content/gdrive')
    path_to_file = '/content/gdrive/My Drive/ACE_NLP_Dec23_codes/codes/labs_vanillaLLMs'
    print(path_to_file)
    # move to Google Drive directory
    os.chdir(path_to_file)
    !pwd

In [2]:
# Libraries
import torch
print(torch.__version__)
import torch.nn as nn
import torch.optim as optim
import time
import matplotlib.pyplot as plt
import logging
logging.getLogger().setLevel(logging.CRITICAL) # remove warnings
import os, datetime


2.1.0


## Time stamp for save/load data


In [3]:
# save time stamp
time_stamp = datetime.datetime.now().strftime("%y-%m-%d--%H-%M-%S")

# check dataset folder exists 
data_dir = os.path.join('dataset')
if not os.path.exists(data_dir):
    os.makedirs(data_dir)
    
# select a time stamp
use_saved_time_stamp = False
use_saved_time_stamp = True
if use_saved_time_stamp:
    time_stamp = '23-11-24--13-06-33' # trained on GPU on xxx

print('time_stamp:', time_stamp, '\n')


time_stamp: 23-11-24--13-06-33 



## Load dictionary of tokens from step #1


In [6]:
load_file_dictionary = 'dataset/step1_02_SSL_dictionary_23-11-23--12-26-17.pt'
dictionary, num_tokens, token2index, index2token = torch.load(load_file_dictionary) # load dictionary of tokens
print('dictionary:',dictionary,'\n')
print('num_tokens (unique):',num_tokens,'\n')
print('token2index:', token2index,'\n')
print('index2token:', index2token,'\n')
func_tokens2indices = lambda list_tokens: [token2index[token] for token in list_tokens] # ['Let', '5', 'be', 'the'] => [113, 46, 114, 115]
func_indices2tokens = lambda list_ints: [index2token[integer] for integer in list_ints] # [113, 46, 114, 115] => ['Let', '5', 'be', 'the']
func_str2tokens = lambda input_str: [token_str for token_str in input_str.split()]      # 'Let 5 be the' => ['Let', '5', 'be', 'the']
func_tokens2str = lambda list_str: ' '.join(list_str)                                   # ['Let', '5', 'be', 'the'] => 'Let 5 be the'
func_strs2pytorchs = lambda list_strs: torch.tensor([int(token_str) for token_str in list_strs])          # ['2', '4', '6', '8'] => tensor([2, 4, 6, 8]) 
func_pytorchs2strs = lambda list_pytorchs: ' '.join([str(pytorch) for pytorch in list_pytorchs.tolist()]) # tensor([2, 4, 6, 8]) => ['2', '4', '6', '8']
func_ints2str = lambda list_ints: ' '.join([str(integer) for integer in list_ints]) # [8, 15, 22, 29] => '8 15 22 29'
func_str2ints = lambda input_str: [int(string) for string in input_str.split()]     # '8 15 22 29' => [8, 15, 22, 29]
 

dictionary: ['50', '56', '62', '68', '74', '80', '86', '<SEP>', '20', '24', '28', '32', '36', '14', '18', '22', '26', '30', '34', '38', '42', '46', '54', '58', '53', '61', '69', '77', '85', '93', '37', '45', '51', '66', '71', '76', '81', '91', '96', '0', '9', '27', '19', '23', '31', '35', '39', '43', '47', '2', '10', '33', '41', '49', '57', '65', '73', '89', '97', '88', '29', '44', '59', '64', '83', '92', '95', '98', '40', '15', '21', '25', '55', '60', '78', '87', '79', '84', '94', '48', '52', '90', '5', '8', '11', '17', '100', '99', '75', '63', '67', '70', '72', '3', '12', '82', '1', '7', '13', '16', '6', '4', 'generate', 'an', 'arithmetic', 'series', 'with', 'terms', 'starting', 'value', 'and', 'common', 'difference', 'Let', 'be', 'the', 'number', 'of', 'then', 'write', 'make', 'a', 'type', 'which', 'starts', 'at', 'elements', '<PAD>', '<EOS>'] 

num_tokens (unique): 129 

token2index: {'50': 0, '56': 1, '62': 2, '68': 3, '74': 4, '80': 5, '86': 6, '<SEP>': 7, '20': 8, '24': 9, '28':

## Load pre-trained SFT-LM network (step #2)


In [7]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# GPU training
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    device = torch.device("cuda") # use GPU
else:
    device = torch.device("cpu")
print('device:',device,'\n')

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

# multiple attention heads layer
class multiple_head_attention(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        d_head = d // num_heads
        assert d == d_head * num_heads # check divisiblity
        self.MHA = nn.MultiheadAttention(d, num_heads, batch_first=True, dropout=dropout)
        self.mask = torch.tril(torch.ones(context_length, context_length))==0 # mask to make attention to previous tokens only : { token(<=t) }, size=(context_length,context_length)
                   # torch.tril(ones) = True in the up-right part, True means *no* attention allowed in pytorch implementation
        self.context_length = context_length
    def forward(self, H):
        if H.size(1) == self.context_length: # training 
            attn_mask = self.mask
        else: # when batch_length not= context_length, e.g. inference time / sequence generation 
            current_batch_length = H.size(1)
            attn_mask = torch.tril(torch.ones(current_batch_length, current_batch_length))==0
        H_heads = self.MHA(H, H, H, attn_mask=attn_mask.to(device))[0] # pytorch implementation, size=[batch_size, batch_length, d]
        return H_heads

# Transformer block layer
class TransformerBlock(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        self.MHA = multiple_head_attention(d, context_length, num_heads, dropout)
        self.LN_MHA = nn.LayerNorm(d)
        self.MLP = nn.Sequential(nn.Linear(d,4*d), nn.ReLU(), nn.Dropout(dropout), nn.Linear(4*d,d))
        self.LN_MLP = nn.LayerNorm(d)
    def forward(self, H):
        H = H + self.MHA(self.LN_MHA(H)) # size=[batch_size, batch_length, d]
        H = H + self.MLP(self.LN_MLP(H)) # size=[batch_size, batch_length, d]
        return H

# class of supervised fine-tuning LM network (step 2)
class SFT_LM(nn.Module):
    def __init__(self, num_tokens, d, context_length, num_heads, dropout, num_layers, padding_int, eos_int):
        super().__init__()
        self.token2vec = token2vec(num_tokens, d) # token embedding layer
        self.seq_pos_encoding = torch.arange(context_length, device=device) # positional encoding = {0,1,2,...,context_length-1}
        self.PE_embedding = nn.Embedding(context_length, d) # positional encoding embedding layer
        self.transformer_blocks = nn.ModuleList([ TransformerBlock(d, context_length, num_heads, dropout) for _ in range(num_layers) ]) # multiple transformer block layers
        self.token_prediction = nn.Linear(d, num_tokens) # next token prediction layer
        self.context_length = context_length
        self.padding = padding_int
        self.eos = eos_int 
    # Note : No forward function is needed
        
# load pre-trained SFT-LM network (step 2)
checkpoint_file = "checkpoint/step2_checkpoint_SFT_LM_23-11-24--12-53-11.pkl"
checkpoint = torch.load(checkpoint_file, map_location=device)
epoch = checkpoint['epoch']
tot_time = checkpoint['tot_time']
loss = checkpoint['loss']
print('Load pre-trained SFT-LM: \n checkpoint file: {:s}\n epoch: {:d}, time: {:.3f}min, loss={:.4f}'.format(checkpoint_file,epoch,tot_time,loss))
net_parameters = SFT_LM_net_parameters = checkpoint['net_parameters']
num_tokens = net_parameters['num_tokens']
d = net_parameters['d']
num_heads = net_parameters['num_heads']
context_length = net_parameters['context_length']
dropout = net_parameters['dropout']
num_layers = net_parameters['num_layers']
padding_int = net_parameters['padding_int']
eos_int = net_parameters['eos_int']
print(' num_tokens: %d, d: %d, context_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d\n' % (num_tokens, d, context_length, num_heads, dropout, num_layers) )
padding_int = torch.tensor([func_tokens2indices('<PAD>'.split())[0]]).to(device) # end-of-sentence token for batch
eos_int = torch.tensor([func_tokens2indices('<EOS>'.split())[0]]).to(device) # end-of-sentence token for batch
print('num_tokens: %d, padding_int: %d, eos_int: %d\n' % (num_tokens, padding_int, eos_int))
SFT_LMnet = SFT_LM(num_tokens, d, context_length, num_heads, dropout, num_layers, padding_int, eos_int)
SFT_LMnet = SFT_LMnet.to(device)
SFT_LMnet.load_state_dict(checkpoint['SFT_LMnet_dict']) # load pre-trained SFT-LM network (step 2)
num_param = number_param(SFT_LMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )
del checkpoint

# generate new sentence of any length
def generate(LMnet, prompt, max_length_gen_seq):
    #LMnet.train()
    predicted_seq = torch.ones(1, max(prompt.size(0),LMnet.context_length)).long().to(device) * LMnet.padding # initiliaze with padding
    predicted_seq[:, -prompt.size(0):] = prompt # fill batch_predicted_seq with prompt, right-aligned
    for k in range(max_length_gen_seq):
        context = predicted_seq[:,-LMnet.context_length:] # size=[batch_size, context_length
        H = LMnet.token2vec(context) + LMnet.PE_embedding(LMnet.seq_pos_encoding[:context.size(1)]).unsqueeze(0) # size=[batch_size, context_length, d]
        for transformer_block in LMnet.transformer_blocks: H = transformer_block(H) # size=(batch_size, context_length, d)
        token_scores = H[:,-1,:] # extract last token to predict the next one, size=[batch_size, d]
        token_scores = LMnet.token_prediction(token_scores) # compute scores, size=[batch_size, num_tokens]
        token_probs = torch.softmax(token_scores, dim=1) # compute probs, size=[batch_size, num_tokens]
        next_token = torch.multinomial(token_probs, num_samples=1) # sample next token, size=[batch_size, 1]
        #next_token = torch.max(token_probs, dim=1).indices[0].view(1,1) # size=(1,1)
        if next_token==LMnet.eos:
            break
        predicted_seq = torch.cat((predicted_seq, next_token), dim=1) # size=[batch_size, current_seq_len+1]
    gen_seq = predicted_seq[0][max(prompt.size(0),LMnet.context_length):]
    return gen_seq


NVIDIA RTX A5000
device: cuda 

Load pre-trained SFT-LM: 
 checkpoint file: checkpoint/step2_checkpoint_SFT_LM_23-11-24--12-53-11.pkl
 epoch: 8, time: 586.115min, loss=0.0496
 num_tokens: 129, d: 384, context_length: 40, num_heads: 6, dropout: 0.10, num_layers: 6

num_tokens: 129, padding_int: 127, eos_int: 128

num_net_parameters: 10761345 / 10.76 million



## Generate training set of (prompt, positive response, negative response) with rewards

### Use pre-trained SFT-LM network (step #2)


In [10]:
# generate arithmetic series
m = max_value = 100 # maximum value in the sequence
def arithmetic_series(m, s, d, n):
    seq = []
    for i in range(n):
        v = s + i * d
        if v <= m:
            seq.append(v)
        else:
            break
    return seq

# Generate training data:
#  for each prompt, we collect one positive response and the associated reward
#                          and one negative response and the associated reward
#
# The response is generated auto-regressively with the pre-trained LM of step 2
#
# Two responses are sampled and the positive response is the one with the largest reward value 
#
# Reward is defined as r = r_min + ( r_max - r_min ) / (  1 + beta * || exact_response - generated_response || )
#                          with r_min=1 (worst), r_max=7 (best)
#
# Training data structure:
#  list_positive_responses = [ [prompt_1 + positive_response_1], [positive_reward_1] 
#                              [prompt_2 + positive_response_2], [positive_reward_2] 
#                                ...
#                              [prompt_N + positive_response_N], [positive_reward_N] ]
#
#  list_negative_responses = [ [prompt_1 + negative_response_1], [negative_reward_1] 
#                              [prompt_2 + negative_response_2], [negative_reward_2] 
#                                ...
#                              [prompt_N + negative_response_N], [negative_reward_N] ]
#
save_training_data = False
#save_training_data = True
if save_training_data:
    
    # collect "human" training set
    list_positive_response = [] # list of prompts + positive responses
    list_negative_response = [] # list of prompts + negative responses
    num_training_data = 12 # debug
    num_training_data = 500 # number of pairs of (prompt, positive response) and (prompt, negative response), e.g. GPU 10,000 training data
    start = time.time()
    num_data = 0
    num_iterations = 0
    start = time.time()
    while num_data < num_training_data:

        # parameters for arithmetic series
        m = max_value # maximum value in the sequence
        s = torch.randint(low=0, high=m, size=(1,)).item() # starting integer of the series
        d = torch.randint(low=1, high=10, size=(1,)).item() # value of common difference
        n = torch.randint(low=5, high=15, size=(1,)).item() # number of element in the series
        #print('max_value: %d, start_value: %d, common_difference: %d, number_of_terms: %d' % (m,s,d,n))

        # generate prompt : sample a prompt between 3 candidate prompts
        prompt = {}
        prompt[1] = 'generate an arithmetic series with ' + str(n) + ' terms starting with value ' + str(s) + ' and common difference ' + str(d)
        prompt[2] = 'make a series of arithmetic type which starts at ' + str(s) + ' with ' + str(n) + ' elements and ' + str(d) + ' common difference value'
        prompt[3] = 'Let ' + str(n) + ' be the number of terms ' + str(s) + ' the starting number and ' + str(d) + ' the common difference then write the arithmetic series'
        random_int = torch.randint(low=1, high=3+1, size=(1,)).item() # random number in {1,2,3}
        prompt_str = prompt[random_int]
        #print('prompt         :',prompt_str)
        prompt_ind_pytorch = torch.tensor(func_tokens2indices(func_str2tokens(prompt_str))) # convert str to pytorch indices
        prompt_seq1_ind_pytorch = prompt_ind_pytorch; prompt_seq2_ind_pytorch = prompt_ind_pytorch # initializing prompt+response 
    
        # exact response
        exact_response_token_pytorch = torch.tensor(arithmetic_series(m,s,d,n))
        #print('exact_response :',func_ints2str(exact_response_token_pytorch.tolist()))
        
        # generate two responses
        gen_seq1_ind_pytorch = generate(SFT_LMnet, prompt_ind_pytorch, max_length_gen_seq=20) # sample one response
        gen_seq2_ind_pytorch = generate(SFT_LMnet, prompt_ind_pytorch, max_length_gen_seq=20) # sample another response
        
        # remove non-integer tokens, i.e. words
        list_int = [str(x) for x in torch.arange(max_value).tolist()] # list of all integers in string format, e.g. ['0', '1', '2', ... , '99']
        gen_seq1_token = func_indices2tokens(gen_seq1_ind_pytorch.tolist()) # convert to tokens, e.g. ['90', '46', 'make', '71']
        gen_seq1_token = [str(i) for i in gen_seq1_token if i in list_int] # remove non-integer tokens, e.g. ['90', '46', '71']
        gen_seq1_ind_pytorch = torch.tensor(func_tokens2indices(gen_seq1_token)) # back to pytorch indices, e.g. tensor([45, 68, 3])
        gen_seq2_token = func_indices2tokens(gen_seq2_ind_pytorch.tolist()) # convert to tokens
        gen_seq2_token = [str(i) for i in gen_seq2_token if i in list_int] # remove non-integer tokens
        gen_seq2_ind_pytorch = torch.tensor(func_tokens2indices(gen_seq2_token)) # back to pytorch indices
        #print('seq1           :',func_tokens2str(gen_seq1_token))
        #print('seq2           :',func_tokens2str(gen_seq2_token))
        
        if gen_seq1_ind_pytorch.size(0)>0 and gen_seq2_ind_pytorch.size(0)>0: # generated sequences have integer tokens, otherwise go to bext generation

            # concatenate prompt + generated sequences
            prompt_seq1_ind_pytorch = torch.cat( (prompt_seq1_ind_pytorch, gen_seq1_ind_pytorch) ) # concatenate prompt + seq1
            prompt_seq2_ind_pytorch = torch.cat( (prompt_seq2_ind_pytorch, gen_seq2_ind_pytorch) ) # concatenate prompt + seq2

            # compute rewards
            max_size = max(exact_response_token_pytorch.size(0), gen_seq1_ind_pytorch.size(0), gen_seq2_ind_pytorch.size(0)) # e.g. 4
            exact_response_token_pytorch = nn.functional.pad(exact_response_token_pytorch,(0, max_size-exact_response_token_pytorch.size(0)), 'constant', exact_response_token_pytorch[-1]) # padding to get same-size vectors, e.g. tensor([ 0,  8, 16, 24])
            gen_seq1_token_pytorch = func_strs2pytorchs(gen_seq1_token)        
            gen_seq1_token_pytorch = nn.functional.pad(gen_seq1_token_pytorch,(0, max_size-gen_seq1_token_pytorch.size(0)), 'constant', gen_seq1_token_pytorch[-1]) # padding to get same-size vectors, e.g. tensor([ 0,  8, 8, 8])
            gen_seq2_token_pytorch = func_strs2pytorchs(gen_seq2_token)        
            gen_seq2_token_pytorch = nn.functional.pad(gen_seq2_token_pytorch,(0, max_size-gen_seq2_token_pytorch.size(0)), 'constant', gen_seq2_token_pytorch[-1]) # padding to get same-size vectors, e.g. tensor([ 0,  7, 15, 15])
            r_min = 1; r_max = 7
            reward1 = r_min + (r_max - r_min) * ( 1 + 0.1*( (exact_response_token_pytorch - gen_seq1_token_pytorch).abs() ).float().sum().sqrt() )**(-1)
            reward2 = r_min + (r_max - r_min) * ( 1 + 0.1*( (exact_response_token_pytorch - gen_seq2_token_pytorch).abs() ).float().sum().sqrt() )**(-1)
            #print('reward #1      : %.2f' % reward1.item() ) 
            #print('reward #2      : %.2f' % reward2.item() )

            # add samples to training dataset if reward1 not= reward2
            if reward1 > reward2:
                list_positive_response.append([prompt_seq1_ind_pytorch, reward1])
                list_negative_response.append([prompt_seq2_ind_pytorch, reward2])
                num_data += 1
            elif reward1 < reward2:
                list_positive_response.append([prompt_seq2_ind_pytorch, reward2])
                list_negative_response.append([prompt_seq1_ind_pytorch, reward1])
                num_data += 1

            # print
            if not num_iterations%500: # 2 (debug), 1000 (GPU) 
                print('num_iterations: %d, num_data: %d, time(min): %.3f' % (num_iterations, num_data, (time.time()-start)/60) )
                print('prompt         :',prompt_str)
                print('exact_response :',func_ints2str(exact_response_token_pytorch.tolist()))
                print('seq1           :',func_tokens2str(gen_seq1_token))
                print('seq2           :',func_tokens2str(gen_seq2_token))
                print('reward #1      : %.2f' % reward1.item() ) 
                print('reward #2      : %.2f' % reward2.item() )
            num_iterations += 1
        
    # print
    print('\nnumber of training data (prompt, positive response, negative response) :',len(list_positive_response),'\n')
    for idx, (positive, negative) in enumerate(zip(list_positive_response[:3],list_negative_response[:3])):
        pos_response = func_tokens2str(func_indices2tokens(positive[0].tolist()))
        neg_response = func_tokens2str(func_indices2tokens(negative[0].tolist()))
        pos_reward = positive[1].item()
        neg_reward = negative[1].item()
        print('training_set[%d]: ' % idx )
        print('  pos_response : %s, pos_reward : %.2f ' % (pos_response, pos_reward) )
        print('  neg_response : %s, neg_reward : %.2f ' % (neg_response, neg_reward), '\n' )
    
    # save training data
    save_file = data_dir + '/step3_01_SLRM_training_set_' + time_stamp + '.pt'
    print('save_file:', save_file, '\n')
    torch.save([list_positive_response, list_negative_response],save_file) # save list of positive and negative responses

else:
      
    # load training data
    load_file = data_dir + '/step3_01_SLRM_training_set_' + time_stamp + '.pt'
    print('load_file:', load_file, '\n')
    list_positive_response, list_negative_response = torch.load(load_file) # load list of positive and negative responses

    # print
    print('number of training data (prompt, positive response, negative response) :',len(list_positive_response),'\n')
    for idx, (positive, negative) in enumerate(zip(list_positive_response[:3],list_negative_response[:3])):
        pos_response = func_tokens2str(func_indices2tokens(positive[0].tolist()))
        neg_response = func_tokens2str(func_indices2tokens(negative[0].tolist()))
        pos_reward = positive[1].item()
        neg_reward = negative[1].item()
        print('training_set[%d]: ' % idx )
        print('  pos_response : %s, pos_reward : %.2f ' % (pos_response, pos_reward) )
        print('  neg_response : %s, neg_reward : %.2f ' % (neg_response, neg_reward), '\n' )


load_file: dataset/step3_01_SLRM_training_set_23-11-24--13-06-33.pt 

number of training data (prompt, positive response, negative response) : 500 

training_set[0]: 
  pos_response : Let 5 be the number of terms 15 the starting number and 5 the common difference then write the arithmetic series 15 20 25 30 35, pos_reward : 7.00 
  neg_response : Let 5 be the number of terms 15 the starting number and 5 the common difference then write the arithmetic series 15 23 25 30 35, neg_reward : 6.11  

training_set[1]: 
  pos_response : Let 12 be the number of terms 12 the starting number and 2 the common difference then write the arithmetic series 12 14 16 18 20 22 24 26 28 30 32 34, pos_reward : 7.00 
  neg_response : Let 12 be the number of terms 12 the starting number and 2 the common difference then write the arithmetic series 12 13 16 18 20 22 24 26 28 30 32 34, neg_reward : 6.45  

training_set[2]: 
  pos_response : make a series of arithmetic type which starts at 79 with 12 elements and

## Get batch of sampled indices for (positive response, negative response)

In [11]:
# batching parameters
num_prompt_response = len(list_positive_response) # number of prompt sequences
batch_size = 3 # debug
#batch_size = 100 # batch size, 500 GPU 
num_batch = num_prompt_response // batch_size # number of batches
print('num_prompt_response: %d, batch_size: %d, num_batch: %d\n' % (num_prompt_response,batch_size,num_batch))

# sample batch of prompt+response
def get_batch(batch_size, list_prompt_response_idx):
    batch_idx = torch.randperm(list_prompt_response_idx.size(0))[:batch_size] # sample B number of batch indices
    batch_idx = list_prompt_response_idx[batch_idx] # and extract from remaining list of batch indices
    if list_prompt_response_idx.size(0) > batch_size:
        new_list_prompt_response_idx = torch.stack([i for i in list_prompt_response_idx if i not in batch_idx]) # remove the sampled batch indices from the list of indices
    else:
        new_list_prompt_response_idx = torch.tensor([]) # last batch of epoch, i.e. return empty list
    return batch_idx, new_list_prompt_response_idx

# # one epoch (debug)
# list_prompt_response_idx = torch.arange(num_prompt_response) # list of prompt+response indices
# for i in range(num_batch):
#     print('batch :',i)
#     print('list_prompt_response_idx (before) :',list_prompt_response_idx, list_prompt_response_idx.size())
#     batch_idx, list_prompt_response_idx = get_batch(batch_size, list_prompt_response_idx) # sample a batch of indices (prompt,response)
#     print('batch_idx :', batch_idx, batch_idx.size())
#     print('list_prompt_response_idx (after) :',list_prompt_response_idx, list_prompt_response_idx.size(),'\n')


num_prompt_response: 500, batch_size: 3, num_batch: 166



## Train reward model with supervised learning 

## Use pre-trained LM model from step #2 with frozen layers as backbone 

## Dataset is composed of (positive response, negative response)

In [14]:
torch.manual_seed(0) # use same initial seed for reproducibility

# compute number of network parameters
def number_param(net):
    nb_param = 0
    for param in net.parameters():
        nb_param += param.numel()
    return nb_param

# GPU training
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    device = torch.device("cuda") # use GPU
else:
    device = torch.device("cpu")
print('device:',device,'\n')

# token embedding layer : convert seq of integers to seq of vectors
class token2vec(nn.Module):
    def __init__(self, num_tokens, d):
        super().__init__()
        self.token2vec = nn.Embedding(num_tokens, d) # map integer to one-hot vector (num_tokens dimensions), and project vector to d-dimentional space
    def forward(self, batch_int):
        batch_vec = self.token2vec(batch_int) # size=[batch_size, batch_length, d]
        return batch_vec

# multiple attention heads layer
class multiple_head_attention(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        d_head = d // num_heads
        assert d == d_head * num_heads # check divisiblity
        self.MHA = nn.MultiheadAttention(d, num_heads, batch_first=True, dropout=dropout)
        self.mask = torch.tril(torch.ones(context_length, context_length))==0 # mask to make attention to previous tokens only : { token(<=t) }, size=(context_length,context_length)
                   # torch.tril(ones) = True in the up-right part, True means *no* attention allowed in pytorch implementation
        self.context_length = context_length
    def forward(self, H):
        if H.size(1) == self.context_length: # training <==
            attn_mask = self.mask
        else: # when batch_length not= context_length, e.g. inference time / sequence generation <==
            current_batch_length = H.size(1)
            attn_mask = torch.tril(torch.ones(current_batch_length, current_batch_length))==0
        H_heads = self.MHA(H, H, H, attn_mask=attn_mask.to(device))[0] # pytorch implementation, size=[batch_size, batch_length, d]
        return H_heads

# Transformer block layer
class TransformerBlock(nn.Module):
    def __init__(self, d, context_length, num_heads, dropout):
        super().__init__()
        self.MHA = multiple_head_attention(d, context_length, num_heads, dropout)
        self.LN_MHA = nn.LayerNorm(d)
        self.MLP = nn.Sequential(nn.Linear(d,4*d), nn.ReLU(), nn.Dropout(dropout), nn.Linear(4*d,d))
        self.LN_MLP = nn.LayerNorm(d)
    def forward(self, H):
        H = H + self.MHA(self.LN_MHA(H)) # size=[batch_size, batch_length, d]
        H = H + self.MLP(self.LN_MLP(H)) # size=[batch_size, batch_length, d]
        return H
    
# Predict positive reward given context = { prompt + positive response }
#         negative reward given context = { prompt + negative response }
#
# Prediction is done into 2 stages :
#         1. compute self-attention between last token of the sequence and the context using pre-trained SFT-LM (step 2)
#         2. apply small MLP to predict scalar reward
#
# Example of reward prediction for a BATCH of sequences = (prompt + positive/negative response) => GPU
#
# Prepare a batch of sampled (prompt+response)
#  Let the token P = <Padding> 
#
#                    context_size = 7 (all prompt+response have the same context length with padding if needed) => GPU
#                 ---------------------
# batch_seq    = [ P, P, 1, 2, 3, 4, 5 ]  // prompt = [1, 2, 3] + positive response = [4, 5]
#              = [ 1, 2, 3, 4, 7, 8, 1 ]  // prompt = [1, 2, 3] + negative response = [4, 7, 8, 1]
#                         ...
#              = [ P, P, 5, 6, 7, 8, 9 ]  // prompt = [5, 6] + positive response = [7, 8, 9]
#              = [ P, P, P, 5, 6, 3, 4 ]  // prompt = [5, 6] + negative response = [3, 4]
#                        -------------
#                                    | <= compute self-attention between last token = 4 and context = [5, 6, 3]
#                                    | <= then apply MLP to predict reward 
# batch_reward_scores            = [ 5.3 ] // predicted positive reward for [1, 2, 3, 4, 5]
#                                = [ 2.9 ] // predicted negative reward for [1, 2, 3, 4, 7, 8, 1]
#                                    ...
#                                = [ 6.2 ] // predicted positive reward for [5, 6, 7, 8, 9]
#                                = [ 1.8 ] // predicted negative reward for [5, 6, 3, 4]
#
# Supervised learning network for reward (step 3)
class SL_RM(nn.Module):
    def __init__(self, SFT_LM, d, context_length, padding_int, eos_int):
        super().__init__()
        self.SFT_LM = SFT_LM # token embedding layer
        self.reward_prediction = nn.Sequential(nn.LayerNorm(d), nn.Linear(d,d), nn.ReLU(), nn.Linear(d,1)) # reward prediction layer 
        self.context_length = context_length
        self.padding = padding_int 
        self.eos = eos_int 
    def forward(self, batch_idx, list_positive_response, list_negative_response): # batch_idx.size=[batch_size], len(list_prompt,list_response) =[num_prompt_response]
        pos_responses = [list_positive_response[idx][0] for idx in batch_idx] # sample list of pos_responses, len(pos_responses)=num_prompt_response
        len_pos_response = max([len(response) for response in pos_responses]) # compute max of pos_response lengths
        neg_responses = [list_negative_response[idx][0] for idx in batch_idx] # sample list of neg_responses, len(neg_responses)=num_prompt_response
        len_neg_response = max([len(response) for response in neg_responses]) # compute max of neg_response lengths
        len_response = max(len_pos_response, len_neg_response) # compute max of pos_response and neg_response lengths
        batch_size = batch_idx.size(0)
        batch_seq = torch.ones(2* batch_size, max(len_response,self.context_length)).long().to(device) * self.padding # context, initialize with padding, size=[2* batch_size, context_length]
        for idx in range(batch_size): batch_seq[2*idx, -pos_responses[idx].size(0):] = pos_responses[idx] # fill context with pos_responses, right-aligned
        for idx in range(batch_size): batch_seq[2*idx+1, -neg_responses[idx].size(0):] = neg_responses[idx] # fill context with neg_responses, right-aligned
        H = self.SFT_LM.token2vec(batch_seq) + self.SFT_LM.PE_embedding(self.SFT_LM.seq_pos_encoding[:batch_seq.size(1)]).unsqueeze(0) # size=[2* batch_size, context_length, d]
        for transformer_block in self.SFT_LM.transformer_blocks: H = transformer_block(H) # size=[2* batch_size, context_length, d)
        token_scores = H[:,-1,:] # extract last token scores to predict rewards, size=[2* batch_size, d]
        reward_scores = self.reward_prediction(token_scores) # compute reward scores, size=[2* batch_size, 1]
        return reward_scores
        
# use parameters of pre-trained SFT-LM network (step 2) for SL_RM network
print('Parameters of pre-trained SFT-LM network (step 2)')
num_tokens = net_parameters['num_tokens']
d = net_parameters['d']
num_heads = net_parameters['num_heads']
context_length = net_parameters['context_length']
dropout = net_parameters['dropout']
num_layers = net_parameters['num_layers']
padding_int = net_parameters['padding_int']
eos_int = net_parameters['eos_int']
print(' num_tokens: %d, d: %d, context_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d' % (num_tokens, d, context_length, num_heads, dropout, num_layers) )
padding_int = torch.tensor([func_tokens2indices('<PAD>'.split())[0]]).to(device) # end-of-sentence token for batch
eos_int = torch.tensor([func_tokens2indices('<EOS>'.split())[0]]).to(device) # end-of-sentence token for batch
print(' num_tokens: %d, padding_int: %d, eos_int: %d\n' % (num_tokens, padding_int, eos_int))

# SL_RM network
SL_RMnet = SL_RM(SFT_LMnet, d, context_length, padding_int, eos_int)
SL_RMnet = SL_RMnet.to(device)
num_param = number_param(SL_RMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )

# freeze SFT-LM network (step 2) during training
for name, param in SL_RMnet.named_parameters():
    if param.requires_grad and 'SFT_LM' in name:
        param.requires_grad = False
        
# optimizer
optimizer = torch.optim.AdamW(SL_RMnet.parameters(), lr=3e-4) # standard optimizer for LMs
warmup = 500 # 50(debug), 500(GPU), number of batches used for warmup 
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda t: min(t/warmup, 1.0) ) # warmup learning rate scheduler, good for LM (softmax)

# save checkpoint
net_parameters = SFT_LM_net_parameters  
checkpoint_dir = os.path.join("checkpoint")
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print('checkpoint file :', checkpoint_dir + '/step3_checkpoint_SL_RM_' + time_stamp + '.pkl', '\n')

# batching parameters
num_prompt_response = len(list_positive_response) # number of prompt+response sequences
batch_size = 3 # debug
batch_size = 100 # batch size, 100 GPU <==
num_batch = num_prompt_response // batch_size # number of batches
print('num_prompt_response: %d, batch_size: %d, num_batch: %d\n' % (num_prompt_response, batch_size, num_batch))

# Understanding the loss
#
# 1. loss_rewards = MSE( predicted_reward, label_reward )
#
# reward_scores                  = [ 5.3 ] // predicted positive reward for seq_1 <= index 2*i   for pos, i = 0, 1, ..., num_prompts-1
#   = predicted_reward           = [ 2.9 ] // predicted negative reward for seq_1 <= index 2*i+1 for neg, i = 0, 1, ..., num_prompts-1
#                                    ...
#                                = [ 6.2 ] // predicted Positive reward for seq_B
#                                = [ 1.8 ] // predicted Negative reward for seq_B
#
# 2. loss_rank = - log( sigmoid( positive_reward - negative_reward ) )
#
# 3. total_loss = loss_rewards + cst * loss_rank
#
# Train network to predict reward from (pos_response, neg_response)
num_epochs = 1001 # 1001(debug), 1001(GPU), number of epochs 
print('num_epochs: ',num_epochs,'\n')
start = time.time()
for epoch in range(num_epochs): # number of epochs
    list_prompt_response_idx = torch.arange(num_prompt_response).to(device) # initialize the list of (pos_response, neg_response) indices
    running_loss = 0.0 # tracking total loss value
    for k in range(num_batch): # number of batches in one epoch
        batch_idx, list_prompt_response_idx = get_batch(batch_size, list_prompt_response_idx) # sample a batch of indices (pos_response, neg_response)
        reward_scores = SL_RMnet(batch_idx.to(device), list_positive_response, list_negative_response) # predict rewards, size=[2*batch_size, 1]
        reward_labels = [ [list_positive_response[idx][1],list_negative_response[idx][1]] for idx in batch_idx ]
        reward_labels = torch.tensor(reward_labels).view(2*batch_size,1).to(device) # size=[2*batch_size, 1]
        diff_rewards = reward_scores[0:2*batch_size:2,:] - reward_scores[1:2*batch_size+1:2,:] # difference of rewards, sise=[batch_size, 1]
        loss_rank = - torch.log(torch.sigmoid(diff_rewards)).mean() # rank loss
        loss_mse = nn.MSELoss()(reward_scores, reward_labels) # regression loss for rewards
        loss = 1.0*loss_mse + 0.1* loss_rank
        running_loss += loss.detach().cpu().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    loss_epoch = running_loss / num_batch
    if not epoch%100: # 1(debug), 100(GPU)
        print('Epoch: %d, time(min): %.3f, lr= %.6f, loss_epoch: %.3f' % (epoch, (time.time()-start)/60, optimizer.param_groups[0]['lr'], loss_epoch) )
        # print rewards
        idx_prompt = 0
        print('pos response token :',func_tokens2str(func_indices2tokens(list_positive_response[batch_idx[idx_prompt]][0].tolist())))
        print('pos reward label   :',reward_labels[2*idx_prompt,:].squeeze().item())
        print('pos reward pred    :',reward_scores[2*idx_prompt,:].squeeze().item())
        print('neg response token :',func_tokens2str(func_indices2tokens(list_negative_response[batch_idx[idx_prompt]][0].tolist())))
        print('neg reward label   :',reward_labels[2*idx_prompt+1,:].squeeze().item())
        print('neg reward pred    :',reward_scores[2*idx_prompt+1,:].squeeze().item(),'\n')
        # save checkpoint
        torch.save({
            'epoch': epoch,
            'tot_time': time.time()-start,
            'loss': loss_epoch,
            'net_parameters': net_parameters,
            'SL_RMnet_dict': SL_RMnet.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            }, '{}.pkl'.format(checkpoint_dir + "/step3_checkpoint_SL_RM_" + time_stamp ))
        # Stopping condition
        if loss_epoch < 0.01: 
            print("\n loss value is small -- training stopped\n")
            break

# GPU training time : Epoch: 1000, time(min): 3.051, lr= 0.000300, loss_epoch: 0.177


NVIDIA RTX A5000
device: cuda 

Parameters of pre-trained SFT-LM network (step 2)
 num_tokens: 129, d: 384, context_length: 40, num_heads: 6, dropout: 0.10, num_layers: 6
 num_tokens: 129, padding_int: 127, eos_int: 128

num_net_parameters: 10910338 / 10.91 million

checkpoint file : checkpoint/step3_checkpoint_SL_RM_23-11-24--13-06-33.pkl 

num_prompt_response: 500, batch_size: 100, num_batch: 5

num_epochs:  1001 

Epoch: 0, time(min): 0.003, lr= 0.000003, loss_epoch: 43.714
pos response token : Let 14 be the number of terms 33 the starting number and 3 the common difference then write the arithmetic series 33 36 39 42 45 48 51 54 57 60 63 66 69 72
pos reward label   : 7.0
pos reward pred    : -0.19049903750419617
neg response token : Let 14 be the number of terms 33 the starting number and 3 the common difference then write the arithmetic series 33 36 39 42 45 48 51 54 57 60 63 66 69 66
neg reward label   : 5.819474697113037
neg reward pred    : -0.29708772897720337 

Epoch: 100, ti

## Load pre-trained SL-RM network


In [16]:
# pre-trained SL-RM network
checkpoint_file = 'checkpoint_file/step3_checkpoint_SL_RM_23-11-24--13-06-33.pkl'
checkpoint = torch.load(checkpoint_file, map_location=device)
epoch = checkpoint['epoch']
tot_time = checkpoint['tot_time']
loss = checkpoint['loss']
print('Load pre-trained SFT-LM: \n checkpoint file: {:s}\n epoch: {:d}, time: {:.3f}min, loss={:.4f}'.format(checkpoint_file,epoch,tot_time,loss))
net_parameters = checkpoint['net_parameters']
num_tokens = net_parameters['num_tokens']
d = net_parameters['d']
num_heads = net_parameters['num_heads']
context_length = net_parameters['context_length']
dropout = net_parameters['dropout']
num_layers = net_parameters['num_layers']
padding_int = net_parameters['padding_int']
eos_int = net_parameters['eos_int']
print(' num_tokens: %d, d: %d, context_length: %d, num_heads: %d, dropout: %.2f, num_layers: %d\n' % (num_tokens, d, context_length, num_heads, dropout, num_layers) )
padding_int = torch.tensor([func_tokens2indices('<PAD>'.split())[0]]).to(device) # end-of-sentence token for batch
eos_int = torch.tensor([func_tokens2indices('<EOS>'.split())[0]]).to(device) # end-of-sentence token for batch
print('num_tokens: %d, padding_int: %d, eos_int: %d\n' % (num_tokens, padding_int, eos_int))
SL_RMnet = SL_RM(SFT_LMnet, d, context_length, padding_int, eos_int)
SL_RMnet = SL_RMnet.to(device)
SL_RMnet.load_state_dict(checkpoint['SL_RMnet_dict']) # load pre-trained SL-RM network from step #3
num_param = number_param(SL_RMnet)
print('num_net_parameters: %d / %.2f million\n' % (num_param, num_param/1e6) )
del checkpoint
    
# print
batch_idx = torch.randperm(len(list_positive_response))[:3] # select 3 indices 
reward_scores = SL_RMnet(batch_idx.to(device), list_positive_response, list_negative_response) # predict rewards, size=[2*3, 1]
reward_labels = [ [list_positive_response[idx][1],list_negative_response[idx][1]] for idx in batch_idx ]
reward_labels = torch.tensor(reward_labels).view(2*batch_idx.size(0),1) # size=[2*3, 1]
for idx in range(batch_idx.size(0)): 
    print('pos response token :',func_tokens2str(func_indices2tokens(list_positive_response[batch_idx[idx]][0].tolist())))
    print('pos reward label   :',reward_labels[2*idx,:].squeeze().item())
    print('pos reward pred    :',reward_scores[2*idx,:].squeeze().item())
    print('neg response token :',func_tokens2str(func_indices2tokens(list_negative_response[batch_idx[idx]][0].tolist())))
    print('neg reward label   :',reward_labels[2*idx+1,:].squeeze().item())
    print('neg reward pred    :',reward_scores[2*idx+1,:].squeeze().item(),'\n')
        

Load pre-trained SFT-LM: 
 checkpoint file: checkpoint/step3_checkpoint_SL_RM_23-11-24--13-06-33.pkl
 epoch: 1000, time: 183.084min, loss=0.1769
 num_tokens: 129, d: 384, context_length: 40, num_heads: 6, dropout: 0.10, num_layers: 6

num_tokens: 129, padding_int: 127, eos_int: 128

num_net_parameters: 10910338 / 10.91 million

pos response token : generate an arithmetic series with 10 terms starting with value 46 and common difference 3 46 49 52 55 58 61 64 67 70 73
pos reward label   : 7.0
pos reward pred    : 7.064010143280029
neg response token : generate an arithmetic series with 10 terms starting with value 46 and common difference 3 46 49 52 55 58 61 64 67 64 73
neg reward label   : 5.819474697113037
neg reward pred    : 5.924700736999512 

pos response token : make a series of arithmetic type which starts at 15 with 12 elements and 2 common difference value 15 17 19 21 23 25 27 29 31 33 35 37
pos reward label   : 7.0
pos reward pred    : 6.833649158477783
neg response token : m