In [1]:
from transformers import LlamaForSequenceClassification
import torch
from torch.nn.utils import clip_grad_value_
import torch.optim as optim
import numpy as np
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
MAX_LEN = 1200

##### change model path and data used to train!

#### 1. Fine-Tune Value Function

In [2]:
version = "5"
Model_Path = f'../Model/PRM_LORA_merge{version}_code'
head_path = f'../Model/model_score{version}_code.pth'
next_version = str(int(version) + 1)

In [3]:
import pickle
with open(f"../llmOutputs/PRM/data_V1_code{version}.pickle", "rb") as f:
    data_V = pickle.load(f)
# with open("../llmOutputs/PRM/data_pi1.pickle", "rb") as f:
#     data_pi = pickle.load(f)
with open(f"../llmOutputs/PRM/completed_paths_y_code{version}.pickle", "rb") as f:
    completed_paths_y = pickle.load(f)

# TODO: remove in next iteration
with open(f"../llmOutputs/PRM/completed_paths_y_code{version}_rlPolicy.pickle", "rb") as f:
    completed_paths_y2 = pickle.load(f)

completed_paths_y.extend(completed_paths_y2)
data = []
for text,y in data_V:
    data.append([text.replace("<｜begin▁of▁sentence｜>User: ",""),y])
for y,score,text,code,prob_i,exit_i in completed_paths_y:
    data.append([text.replace("<｜begin▁of▁sentence｜>User: ",""),y])
import random
random.shuffle(data)
texts,ys = zip(*data)

In [4]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-math-7b-rl")
texts = tokenizer.batch_encode_plus(texts,return_attention_mask=False,add_special_tokens=True,\
                                    truncation=True, max_length=MAX_LEN)['input_ids']

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
def from_gen(texts,ys):
    data = list(zip(texts,ys))
    random.shuffle(data)
    for text,y in data:
        text = torch.tensor(text,device='cuda')[None]
        y = torch.tensor([y],device='cuda',dtype=torch.float32)
        yield text,y

In [6]:
epochs = 1
accumulation_steps = 64
verbose = 1024
lr = 6e-5
clip = 6e-3

In [7]:
from transformers import LlamaForSequenceClassification,BitsAndBytesConfig,AutoConfig
import torch
from peft import (
    get_peft_model,
    PeftType,
    LoraConfig)

In [8]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = LlamaForSequenceClassification.from_pretrained(Model_Path,\
                                                       num_labels=1,\
                                                       device_map="auto",
                                                       torch_dtype="auto",
                                                       quantization_config=quantization_config,
                                                       attn_implementation="flash_attention_2"
                                                       )
model.gradient_checkpointing_enable()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at ../Model/PRM_LORA_merge5_code and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
peft_config = LoraConfig(r=8, # low rank 
                         lora_alpha = 16, # see below 
                         lora_dropout = 0.1, 
                         bias="none",#'none', 'all' or 'lora_only' 
                         target_modules = [ "q_proj", 
                                            "k_proj", 
                                            "v_proj", 
                                            "o_proj", 
                                            "gate_proj", 
                                            "up_proj", 
                                            "down_proj" 
                                        ] 
                        )
base_model = get_peft_model(model.model, peft_config)
base_model.gradient_checkpointing_enable()
# model.config.pad_token_id = tokenizer.pad_token_id
base_model.print_trainable_parameters()
model.score = model.score.float()
model.score.load_state_dict(torch.load(head_path))
model.score.weight.requires_grad_(True);

trainable params: 18,739,200 || all params: 6,509,674,496 || trainable%: 0.2879


FT head

In [10]:
base_params = [param for param in base_model.parameters() if param.requires_grad]
trainable_params = list(model.score.parameters())
                    # list(topic_model.parameters())
optimizer = torch.optim.Adam(trainable_params,lr = lr)

In [11]:
# %debug
loss_fn = torch.nn.BCEWithLogitsLoss()
train_loss = 0
count_loss = 0

for epoch in range(epochs):
    for i,(text,y) in enumerate(from_gen(texts,ys)):
        with torch.no_grad():
            hidden_states = base_model(text)[0][:,-1].float() # 1,d
        logits = model.score(hidden_states)[:,0] # 1,
        loss = loss_fn(logits,y)
        loss.backward()
        train_loss += loss.item()
        count_loss += 1
            
        if (i + 1) % accumulation_steps == 0:
            # clip_grad_value_(trainable_params,clip)
            clip_grad_value_(trainable_params,clip)
            optimizer.step()
            optimizer.zero_grad()

        if (i + 1) % verbose == 0:
            print(f"iter: {i}, \n train loss: {train_loss/count_loss}")
            train_loss = 0
            count_loss = 0
            
        torch.cuda.empty_cache()

iter: 1023, 
 train loss: 0.22503854452406813
iter: 2047, 
 train loss: 0.19690279222504614
iter: 3071, 
 train loss: 0.17981932103066356
iter: 4095, 
 train loss: 0.1670584841767777
iter: 5119, 
 train loss: 0.16929041448020143
iter: 6143, 
 train loss: 0.19290116799493262
iter: 7167, 
 train loss: 0.21981135978239763
iter: 8191, 
 train loss: 0.1823250137604191
iter: 9215, 
 train loss: 0.20385578264722426
iter: 10239, 
 train loss: 0.17016758029421908
iter: 11263, 
 train loss: 0.15799389205130865
iter: 12287, 
 train loss: 0.16727510656346567
iter: 13311, 
 train loss: 0.19434696441385313
iter: 14335, 
 train loss: 0.14517014800367178
iter: 15359, 
 train loss: 0.2005108221746923
iter: 16383, 
 train loss: 0.1813583382499928
iter: 17407, 
 train loss: 0.17493958562772605
iter: 18431, 
 train loss: 0.21305208181183843
iter: 19455, 
 train loss: 0.19158705954032484
iter: 20479, 
 train loss: 0.18986034848603595
iter: 21503, 
 train loss: 0.16337557299448235
iter: 22527, 
 train loss:

FT head and backbone

In [12]:
base_params = [param for param in base_model.parameters() if param.requires_grad]
trainable_params =  base_params + list(model.score.parameters())
                    # list(topic_model.parameters())
optimizer = torch.optim.Adam(trainable_params,lr = lr)

In [13]:
loss_fn = torch.nn.BCEWithLogitsLoss()
train_loss = 0
count_loss = 0

for epoch in range(epochs):
    for i,(text,y) in enumerate(from_gen(texts,ys)):
        hidden_states = base_model(text)[0][:,-1].float() # 1,d
        logits = model.score(hidden_states)[:,0] # 1,
        loss = loss_fn(logits,y)
        loss.backward()
        train_loss += loss.item()
        count_loss += 1
            
        if (i + 1) % accumulation_steps == 0:
            # clip_grad_value_(trainable_params,clip)
            clip_grad_value_(trainable_params,clip)
            optimizer.step()
            optimizer.zero_grad()

        if (i + 1) % verbose == 0:
            print(f"iter: {i}, \n train loss: {train_loss/count_loss}")
            train_loss = 0
            count_loss = 0
            
        torch.cuda.empty_cache()

iter: 1023, 
 train loss: 0.19985639877086214
iter: 2047, 
 train loss: 0.205439392640983
iter: 3071, 
 train loss: 0.1569971104763681
iter: 4095, 
 train loss: 0.19266401687127654
iter: 5119, 
 train loss: 0.1545685434666666
iter: 6143, 
 train loss: 0.18454259471036494
iter: 7167, 
 train loss: 0.15482757396239322
iter: 8191, 
 train loss: 0.17882152074889746
iter: 9215, 
 train loss: 0.16669698896112095
iter: 10239, 
 train loss: 0.15513845881832822
iter: 11263, 
 train loss: 0.178124901346564
iter: 12287, 
 train loss: 0.17896800365087984
iter: 13311, 
 train loss: 0.16427032280080311
iter: 14335, 
 train loss: 0.14667499106144533
iter: 15359, 
 train loss: 0.1548061366620459
iter: 16383, 
 train loss: 0.16954763650483073
iter: 17407, 
 train loss: 0.17013626196603582
iter: 18431, 
 train loss: 0.15307152537025104
iter: 19455, 
 train loss: 0.1349567736742756
iter: 20479, 
 train loss: 0.15425231566041475
iter: 21503, 
 train loss: 0.13406191560534353
iter: 22527, 
 train loss: 0.1

In [14]:
torch.save(model.score.state_dict(), f'../Model/model_score{next_version}_code.pth')
peft_model_id = f"../Model/PRM_LORA{next_version}_code"
!mkdir peft_model_id
base_model.save_pretrained(peft_model_id)

mkdir: cannot create directory ‘peft_model_id’: File exists




In [15]:
del model,base_model,texts,hidden_states,logits,loss
import gc
gc.collect()
torch.cuda.empty_cache()
from transformers import LlamaModel
model = LlamaModel.from_pretrained(Model_Path,\
                                    device_map="auto",
                                    torch_dtype="auto",
                                    attn_implementation="flash_attention_2"
                                    )
from peft import PeftModel
peft_model_id = f"../Model/PRM_LORA{next_version}_code"
base_model = PeftModel.from_pretrained(model, peft_model_id)
base_model2 = base_model.merge_and_unload()
# !mkdir '../Model/PRM_LORA_merge3_code'
base_model2.save_pretrained(f'../Model/PRM_LORA_merge{next_version}_code')

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [16]:
# from transformers import LlamaForSequenceClassification
# model = LlamaForSequenceClassification.from_pretrained(f'../Model/PRM_LORA_merge{next_version}_code',\
#                                                        num_labels=1,\
#                                                        ignore_mismatched_sizes=True,
#                                                        device_map="auto",
#                                                        torch_dtype="auto",
#                                                        attn_implementation="flash_attention_2"
#                                                        )
# model.score.load_state_dict(torch.load(f'../Model/model_score{next_version}_code.pth'))
# loss_fn = torch.nn.BCEWithLogitsLoss()
# train_loss = 0
# count_loss = 0

# for epoch in range(epochs):
#     for i,(text,y) in enumerate(from_gen(texts,ys)):
#         with torch.no_grad():
#             hidden_states = model.model(text)[0][:,-1] # 1,d
#             logits = model.score(hidden_states)[:,0] # 1,
#             loss = loss_fn(logits,y)
#         train_loss += loss.item()
#         count_loss += 1

#         if (i + 1) % verbose == 0:
#             print(f"iter: {i}, \n train loss: {train_loss/count_loss}")
#             train_loss = 0
#             count_loss = 0
            
#         torch.cuda.empty_cache()