# Language modeling instead of a classification head

The [MATH-SHEPHERD](https://huggingface.co/datasets/peiyi9979/Math-Shepherd) paper presents a more elegant solution: train with a language modeling objective, then estimate step correctness directly from the tokens that mark the end of each step. Although their code has not been released, it is quite simple to implement compared to my previous method for PRM.

Use native Unsloth.

In [1]:
from transformers import LlamaForSequenceClassification, AutoTokenizer, LlamaForCausalLM
import torch
import wandb
import os
from unsloth import FastLanguageModel
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False.


model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "checkpoints/llama3-8b-gsm8k-1epoch", # "unsloth/tinyllama" for 16bit loading
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

#Use LoRA to reduce memory usage:
model = FastLanguageModel.get_peft_model(
    model,
    r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 32,
    lora_dropout = 0, # Currently only supports dropout = 0
    bias = "none",    # Currently only supports bias = "none"
    use_gradient_checkpointing = "unsloth", # @@@ IF YOU GET OUT OF MEMORY - set to True @@@
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)


==((====))==  Unsloth: Fast Llama patching release 2024.4
   \\   /|    GPU: NVIDIA GeForce RTX 4090. Max memory: 23.642 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2. CUDA = 8.9. CUDA Toolkit = 11.8.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = True.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.62it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Unsloth 2024.4 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [2]:

prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant to solve math problems step by step <|eot_id|><|start_header_id|>user<|end_header_id|>

{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{}"""

def formatting_prompts_func(examples):
    texts = []
    
    for instruction, responses, next_response, rating in zip(examples['instruction'], examples['responses'], examples['next_response'], examples['rating']):
        # Tag every previous step with " + \n", then append the step being rated.
        # Building the string per step avoids a stray leading tag when `responses` is empty.
        combined_responses = "".join(step + " + \n" for step in responses) + next_response
        # Tag the rated step: " -" if it is incorrect (rating == -1), " +" otherwise
        if rating == -1:
            combined_responses = combined_responses + " - \n"
        else:
            combined_responses = combined_responses + " + \n"
        
        # Format the text with the prompt template
        text = prompt.format(instruction, combined_responses) 
        texts.append(text)

    
    return {"text": texts,}

    # # Tokenize all texts at once using the tokenizer
    # model_inputs = tokenizer(texts, padding="max_length", truncation=True, max_length=512)

    # # Add labels to the model inputs
    # model_inputs['labels'] = labels
    
    # return model_inputs
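
To make the target format concrete, here is a minimal, self-contained sketch of how a single example ends up tagged. The helper name `tag_steps` and the toy steps are hypothetical and only for illustration; the helper follows the same tagging convention as `formatting_prompts_func`, without the prompt template.

```python
def tag_steps(previous_steps, next_step, rating):
    """Tag previous steps with ' +' and the rated step with ' +' or ' -'."""
    # Every earlier step is followed by the "correct" tag and a newline
    tagged = "".join(step + " + \n" for step in previous_steps)
    # The step being rated gets " -" only when its rating is -1
    tag = " - \n" if rating == -1 else " + \n"
    return tagged + next_step + tag


# Toy data (hypothetical, not taken from the dataset)
example = tag_steps(
    previous_steps=[
        "There are 3 boxes with 4 pens each.",
        "So there are 3 * 4 = 12 pens.",
    ],
    next_step="Adding 2 more gives 15 pens.",
    rating=-1,
)
print(repr(example))
# 'There are 3 boxes with 4 pens each. + \nSo there are 3 * 4 = 12 pens. + \nAdding 2 more gives 15 pens. - \n'
```

Each line of the assistant turn therefore ends in a tag, and only the last tag depends on the rating.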


If we treat this as a language modeling task, solutions that end at an incorrect step should also be included in the dataset. This means we still need to append all the incorrect steps.

We can use a token to represent each step's correctness, such as `+` and `-`.
During inference, we look at the model's prediction at the position where this token is expected and use softmax to get the probability of `+`, which is the probability that the step is correct.

During training, because the decoder is auto-regressive, we don't need a separate objective for predicting the correctness of the step. The correctness token is simply part of the training text, so the model learns to predict it through the usual next-token loss.
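
The inference step described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the paper's implementation: the function name `correctness_probability`, the token ids, and the logit values are all hypothetical, and restricting the softmax to the two tag tokens is one possible reading of "use softmax to get the probability of correctness". With a real model, the logits would come from the model's output at the position just before the tag, and the ids from the tokenizer.

```python
import math


def correctness_probability(logits, plus_id, minus_id):
    """Return P('+') from a softmax restricted to the two tag tokens."""
    plus_logit, minus_logit = logits[plus_id], logits[minus_id]
    # Subtract the max logit for numerical stability before exponentiating
    m = max(plus_logit, minus_logit)
    exp_plus = math.exp(plus_logit - m)
    exp_minus = math.exp(minus_logit - m)
    return exp_plus / (exp_plus + exp_minus)


# Hypothetical token ids and next-token logits (assumptions for illustration)
PLUS_ID, MINUS_ID = 10, 12
next_token_logits = {PLUS_ID: 2.0, MINUS_ID: 0.0, 99: -1.0}

p_correct = correctness_probability(next_token_logits, PLUS_ID, MINUS_ID)
print(round(p_correct, 3))  # 0.881
```

Restricting the softmax to the two tags ignores the probability mass on all other tokens, so the two outcomes always sum to 1. Taking the softmax over the full vocabulary and reading off the `+` entry is an alternative that gives a different number.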

In [3]:
from datasets import load_dataset

# Load and preprocess the dataset
dataset = load_dataset("Birchlabs/openai-prm800k-stepwise-critic", split='train')
dataset = dataset.filter(lambda x: x['rating'] is not None)  # Filter entries without ratings

# Filter out the examples that have 'next_response' in the responses of the solution
dataset = dataset.filter(lambda x: not(x['rating'] == 1 and x['is_solution'] == False))

#convert ratings of 0 to 1 so we have only binary labels
dataset = dataset.map(lambda x: {'rating': 1 if x['rating'] == 0 else x['rating']})

dataset = dataset.map(formatting_prompts_func, batched=True)  # Apply the preprocessing function

len(dataset)

Filter: 100%|██████████| 1015027/1015027 [00:11<00:00, 89580.28 examples/s]
Map: 100%|██████████| 369283/369283 [00:16<00:00, 21922.39 examples/s]
Map: 100%|██████████| 369283/369283 [00:01<00:00, 189351.47 examples/s]


369283

In [4]:
from trl import SFTTrainer
from transformers import TrainingArguments
from transformers.utils import logging
logging.set_verbosity_info()

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, 
    args = TrainingArguments(
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 4,
        warmup_ratio = 0.1,
        num_train_epochs = 1,
        learning_rate = 2e-5,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 100,
        save_steps= 5000,
        save_total_limit=2,
        optim = "adamw_8bit",
        weight_decay = 0.1,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "checkpoints/llama3-8b-critic-lora",
        report_to= "wandb"
    ),
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
Map (num_proc=2): 100%|██████████| 369283/369283 [00:51<00:00, 7231.68 examples/s] 
Using auto half precision backend


In [5]:
#@title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA GeForce RTX 4090. Max memory = 23.642 GB.
15.404 GB of memory reserved.


In [6]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 369,283 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4
\        /    Total batch size = 8 | Total steps = 46,160
 "-____-"     Number of trainable parameters = 83,886,080
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjq394[0m ([33mneurorunner[0m). Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss


KeyboardInterrupt: 

In [None]:
model.save_pretrained("checkpoints/llama3-8b-critic-lora") # Local saving
