In [1]:
from transformers import LlamaForSequenceClassification
import torch
from torch.nn.utils import clip_grad_value_
import torch.optim as optim
import numpy as np
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
MAX_LEN = 1600

#### 1. Fine-Tune Value Function

In [2]:
Model_Path = '../Model/PRM_LORA_merge2'
head_path = '../Model/model_score2.pth'

In [3]:
import pickle
with open("../llmOutputs/PRM/data_V1.pickle", "rb") as f:
    data_V = pickle.load(f)
# with open("../llmOutputs/PRM/data_pi1.pickle", "rb") as f:
#     data_pi = pickle.load(f)
with open("../llmOutputs/PRM/completed_paths_y.pickle", "rb") as f:
    completed_paths_y = pickle.load(f)

data = []
for text,y in data_V:
    data.append([text.replace("<｜begin▁of▁sentence｜>User: ",""),y])
for y,_,text in completed_paths_y:
    data.append([text.replace("<｜begin▁of▁sentence｜>User: ",""),y])
import random
random.shuffle(data)
texts,ys = zip(*data)

In [4]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-math-7b-rl")
texts = tokenizer.batch_encode_plus(texts,return_attention_mask=False,add_special_tokens=True,\
                                    truncation=True, max_length=MAX_LEN)['input_ids']

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
def from_gen(texts,ys):
    data = list(zip(texts,ys))
    random.shuffle(data)
    for text,y in data:
        text = torch.tensor(text,device='cuda')[None]
        y = torch.tensor([y],device='cuda',dtype=torch.float32)
        yield text,y

In [6]:
epochs = 1
accumulation_steps = 64
verbose = 1024
lr = 6e-5
clip = 6e-3

In [7]:
from transformers import LlamaForSequenceClassification,BitsAndBytesConfig,AutoConfig
import torch
from peft import (
    get_peft_model,
    PeftType,
    LoraConfig)

In [8]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = LlamaForSequenceClassification.from_pretrained(Model_Path,\
                                                       num_labels=1,\
                                                       device_map="auto",
                                                       torch_dtype="auto",
                                                       quantization_config=quantization_config,
                                                       attn_implementation="flash_attention_2"
                                                       )
model.gradient_checkpointing_enable()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at ../Model/PRM_LORA_merge2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
peft_config = LoraConfig(r=8, # low rank 
                         lora_alpha = 16, # see below 
                         lora_dropout = 0.1, 
                         bias="none",#'none', 'all' or 'lora_only' 
                         target_modules = [ "q_proj", 
                                            "k_proj", 
                                            "v_proj", 
                                            "o_proj", 
                                            "gate_proj", 
                                            "up_proj", 
                                            "down_proj" 
                                        ] 
                        )
base_model = get_peft_model(model.model, peft_config)
base_model.gradient_checkpointing_enable()
# model.config.pad_token_id = tokenizer.pad_token_id
base_model.print_trainable_parameters()
model.score = model.score.float()
model.score.load_state_dict(torch.load(head_path))
model.score.weight.requires_grad_(True);

trainable params: 18,739,200 || all params: 6,509,674,496 || trainable%: 0.287866928085493


FT head

In [10]:
base_params = [param for param in base_model.parameters() if param.requires_grad]
trainable_params = list(model.score.parameters())
                    # list(topic_model.parameters())
optimizer = torch.optim.Adam(trainable_params,lr = lr)

In [None]:
# %debug
loss_fn = torch.nn.BCEWithLogitsLoss()
train_loss = 0
count_loss = 0

for epoch in range(epochs):
    for i,(text,y) in enumerate(from_gen(texts,ys)):
        with torch.no_grad():
            hidden_states = base_model(text)[0][:,-1].float() # 1,d
        logits = model.score(hidden_states)[:,0] # 1,
        loss = loss_fn(logits,y)
        loss.backward()
        train_loss += loss.item()
        count_loss += 1
            
        if (i + 1) % accumulation_steps == 0:
            # clip_grad_value_(trainable_params,clip)
            clip_grad_value_(trainable_params,clip)
            optimizer.step()
            optimizer.zero_grad()

        if (i + 1) % verbose == 0:
            print(f"iter: {i}, \n train loss: {train_loss/count_loss}")
            train_loss = 0
            count_loss = 0
            
        torch.cuda.empty_cache()

FT head and backbone

In [None]:
base_params = [param for param in base_model.parameters() if param.requires_grad]
trainable_params =  base_params + list(model.score.parameters())
                    # list(topic_model.parameters())
optimizer = torch.optim.Adam(trainable_params,lr = lr)

In [None]:
loss_fn = torch.nn.BCEWithLogitsLoss()
train_loss = 0
count_loss = 0

for epoch in range(epochs):
    for i,(text,y) in enumerate(from_gen(texts,ys)):
        hidden_states = base_model(text)[0][:,-1].float() # 1,d
        logits = model.score(hidden_states)[:,0] # 1,
        loss = loss_fn(logits,y)
        loss.backward()
        train_loss += loss.item()
        count_loss += 1
            
        if (i + 1) % accumulation_steps == 0:
            # clip_grad_value_(trainable_params,clip)
            clip_grad_value_(trainable_params,clip)
            optimizer.step()
            optimizer.zero_grad()

        if (i + 1) % verbose == 0:
            print(f"iter: {i}, \n train loss: {train_loss/count_loss}")
            train_loss = 0
            count_loss = 0
            
        torch.cuda.empty_cache()

iter: 1023, 
 train loss: 0.27247323917254107
iter: 2047, 
 train loss: 0.26962604945765634
iter: 3071, 
 train loss: 0.26440685853231116
iter: 4095, 
 train loss: 0.2554999446256261
iter: 5119, 
 train loss: 0.29441069425047317
iter: 6143, 
 train loss: 0.24179704060770746
iter: 7167, 
 train loss: 0.25132670278708247
iter: 8191, 
 train loss: 0.25074620979512474
iter: 9215, 
 train loss: 0.2600249861352495
iter: 10239, 
 train loss: 0.2514084854383327
iter: 11263, 
 train loss: 0.2543830612930833
iter: 12287, 
 train loss: 0.2532799553118821
iter: 13311, 
 train loss: 0.21347219805011264
iter: 14335, 
 train loss: 0.21609343921682012
iter: 15359, 
 train loss: 0.19591548759262878
iter: 16383, 
 train loss: 0.20584044610222918
iter: 17407, 
 train loss: 0.18611015798251174
iter: 18431, 
 train loss: 0.20984597941969696


In [14]:
torch.save(model.score.state_dict(), '../Model/model_score3.pth')
peft_model_id = "../Model/PRM_LORA3"
!mkdir peft_model_id
base_model.save_pretrained(peft_model_id)

mkdir: cannot create directory ‘peft_model_id’: File exists




In [1]:
from transformers import LlamaModel
model = LlamaModel.from_pretrained('deepseek-ai/deepseek-math-7b-rl',\
                                    torch_dtype="auto",\
                                    attn_implementation="flash_attention_2")
from peft import PeftModel
peft_model_id = "../Model/PRM_LORA3"
base_model = PeftModel.from_pretrained(model, peft_model_id)
base_model2 = base_model.merge_and_unload()
!mkdir '../Model/PRM_LORA_merge3'
base_model2.save_pretrained('../Model/PRM_LORA_merge3')

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

#### 2. Fine-Tune Policy

In [2]:
MODEL_PATH = "deepseek-ai/deepseek-math-7b-rl"

In [3]:
import pickle
with open("../llmOutputs/PRM/completed_paths_y.pickle", "rb") as f:
    completed_paths_y = pickle.load(f)

data = []
for y,_,text in completed_paths_y:
    data.append([text.replace("<｜begin▁of▁sentence｜>User: ",""),y])
import random
random.shuffle(data)
texts,ys = zip(*data)

In [4]:
# normalize rewards
import numpy as np
ys = np.array(ys)
ys = (ys-ys.mean())/ys.std()

In [5]:
import re
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
pattern = r"Please reason step by step, and put your final answer within \\boxed\{\}."
input_ids = []
lengths = []
for text in texts:
    idx = re.search(pattern,text).end()
    question = tokenizer.encode(text[:idx],add_special_tokens=True)
    answer = tokenizer.encode(text[idx:],add_special_tokens=False)
    lengths.append(len(question))
    input_ids.append(question+answer)

In [7]:
def from_gen(texts,ys,lengths):
    data = list(zip(texts,ys,lengths))
    random.shuffle(data)
    for text,y,l in data:
        text = torch.tensor(text[:MAX_LEN],device='cuda')[None]
        yield text,y,l

In [8]:
epochs = 1
accumulation_steps = 64
verbose = 1024
lr = 6e-5
clip = 6e-3

In [9]:
from transformers import AutoModelForCausalLM,BitsAndBytesConfig
import torch
from peft import (
    get_peft_model,
    PeftType,
    LoraConfig)

In [10]:
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH,\
                                            device_map="auto",
                                            torch_dtype="auto",
                                            trust_remote_code=True,
                                            attn_implementation="flash_attention_2"
                                            )
# model.gradient_checkpointing_enable()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Fine-tune head

In [11]:
for param in model.model.parameters():
    param.requires_grad = False
for param in model.lm_head.parameters():
    param.requires_grad = True
model.lm_head = model.lm_head.float()

In [12]:
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params,lr = lr)

In [13]:
# %debug
import math
import gc
loss_fn = torch.nn.CrossEntropyLoss()
train_loss = 0
count_loss = 0

for epoch in range(epochs):
    for i,(text,y,l) in enumerate(from_gen(input_ids,ys,lengths)):
        if i > 0:
            del outs
            gc.collect()
            torch.cuda.empty_cache()
        
        with torch.no_grad():
            outs = model.model(text)[0].float() # 1,l,C
        outs = model.lm_head(outs)
        
        if torch.any(torch.isnan(outs)): continue
        loss = loss_fn(outs[0,l:-1],text[0,l+1:]) * y # (l,C), (l,)
        if math.isinf(loss.item()) or math.isnan(loss.item()): continue
        loss.backward()
        train_loss += loss.item()
        count_loss += 1
            
        if (i + 1) % accumulation_steps == 0:
            # clip_grad_value_(trainable_params,clip)
            clip_grad_value_(trainable_params,clip)
            optimizer.step()
            optimizer.zero_grad()

        if (i + 1) % verbose == 0:
            print(f"iter: {i}, \n train loss: {train_loss/count_loss}")
            train_loss = 0
            count_loss = 0

iter: 1023, 
 train loss: -0.013546783548008534
iter: 2047, 
 train loss: -0.02897482163643872
iter: 3071, 
 train loss: -0.01757911852109828
iter: 4095, 
 train loss: -0.034147584029597056
iter: 5119, 
 train loss: -0.037454732966580195
iter: 6143, 
 train loss: -0.034789292309142184
iter: 7167, 
 train loss: -0.030503184903864167
iter: 8191, 
 train loss: -0.03708435299995472
iter: 9215, 
 train loss: -0.02951386814311263
iter: 10239, 
 train loss: -0.06372847682087013
iter: 11263, 
 train loss: -0.07185396562272217
iter: 12287, 
 train loss: -0.03204263923362305


In [24]:
# torch.save(model.lm_head.state_dict(), '../Model/lm_head.pth')
model.lm_head.to(torch.bfloat16) 
model.save_pretrained("../Model/Policy1")

LORA -> OOM

In [11]:
peft_config = LoraConfig(r=8, # low rank 
                         lora_alpha = 16, # see below 
                         lora_dropout = 0.1, 
                         bias="none",#'none', 'all' or 'lora_only' 
                         target_modules = [ "q_proj", 
                                            "k_proj", 
                                            "v_proj", 
                                            "o_proj", 
                                            "gate_proj", 
                                            "up_proj", 
                                            "down_proj" 
                                        ] 
                        )
model = get_peft_model(model, peft_config)
# model.gradient_checkpointing_enable()
# model.config.pad_token_id = tokenizer.pad_token_id
model.print_trainable_parameters()

trainable params: 18,739,200 || all params: 6,929,104,896 || trainable%: 0.2704


In [12]:
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params,lr = lr)

In [None]:
# %debug
import math
import gc
loss_fn = torch.nn.CrossEntropyLoss()
train_loss = 0
count_loss = 0

for epoch in range(epochs):
    for i,(text,y,l) in enumerate(from_gen(input_ids,ys,lengths)):
        if i > 0:
            del outs
            gc.collect()
            torch.cuda.empty_cache()
        outs = model(text)[0] # 1,l,C
        if torch.any(torch.isnan(outs)): continue
        loss = loss_fn(outs[0,l:-1],text[0,l+1:]) * y # (l,C), (l,)
        if math.isinf(loss.item()) or math.isnan(loss.item()): continue
        loss.backward()
        train_loss += loss.item()
        count_loss += 1
            
        if (i + 1) % accumulation_steps == 0:
            # clip_grad_value_(trainable_params,clip)
            clip_grad_value_(trainable_params,clip)
            optimizer.step()
            optimizer.zero_grad()

        if (i + 1) % verbose == 0:
            print(f"iter: {i}, \n train loss: {train_loss/count_loss}")
            train_loss = 0
            count_loss = 0