In [1]:

import os, psutil, gc
import time 
import json
import pprint
import copy 

from collections import defaultdict
import random
import numpy as np

import multiprocessing as mp

from dataclasses import dataclass

In [2]:
import torch 
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams, PoolingParams

from sal.config import Config
from sal.search.utils import build_conv, generate_k_steps, last
from sal.utils.score import aggregate_scores

# from core import select_diverse_v31
from core.reward_models import RLHFFlow

from utils.load_data import load_data_prm800k

INFO 06-18 08:57:37 [__init__.py:244] Automatically detected platform cuda.


In [3]:
if torch.cuda.is_available():
    GPUS = os.environ.get('CUDA_VISIBLE_DEVICES', "0").split(',')
    print(GPUS)
else:
    print("CUDA is not available.")

['0', '1']


In [4]:
# base_dir
base_dir = '/groups/kjun/tnn/datasets/'

# dataset path
data_dir = base_dir + "/prm800k/math_splits"

# llm and prm path
llm_dir = base_dir + "/Llama-3.2-1B-Instruct-GGUF/Llama-3.2-1B-Instruct.Q4_K_M.gguf"
prm_dir = base_dir + "/Llama3.1-8B-PRM-Deepseek-Data-GGUF/Llama3.1-8B-PRM-Deepseek-Data.Q4_K_M.gguf"

llm_tokenizer_dir = base_dir + "/Llama-3.2-1B-Instruct"
prm_tokenizer_dir = base_dir + "/Llama3.1-8B-PRM-Deepseek-Data"

In [5]:
#  load data 
data_by_levels = load_data_prm800k(data_dir)

# load random_seeds     
# random_seeds = np.loadtxt("random_seeds.txt").astype("int64")
# random_seeds = [int(seed) for seed in random_seeds]

1: 43
2: 90
3: 105
4: 128
5: 134


In [6]:
# baseline: gpu_memory_utilization=0.2
# use the standard model 
llm_vllm = LLM(
        model = llm_tokenizer_dir,
        tensor_parallel_size=1,
        gpu_memory_utilization = 0.7,  # Utilize 50% of GPU memory
        # enable_prefix_caching=True,  # V100 doesn't support enable_prefix_caching 
        # enable_chunked_prefill=False, # and enable_chunked_prefill
        max_model_len = 5000,
        dtype = "float16",
        seed = 0)
    
    # # use the gguf quantized model 
    # llm_regular = LLM(
    #     model = llm_dir,
    #     tokenizer = llm_tokenizer_dir,
    #     tensor_parallel_size=1,
    #     gpu_memory_utilization = 0.2,  # Utilize 50% of GPU memory
    #     max_model_len = 5000,
    #     dtype = "float16",
    #     seed = 123)


gc.collect();torch.cuda.empty_cache();
print('#--- memory:', torch.cuda.memory_allocated(0)/(1024**3))
print('#--- memory:', torch.cuda.memory_allocated(1)/(1024**3))

INFO 06-18 08:57:54 [config.py:823] This model supports multiple tasks: {'generate', 'classify', 'reward', 'embed', 'score'}. Defaulting to 'generate'.
INFO 06-18 08:57:54 [llm_engine.py:230] Initializing a V0 LLM engine (v0.9.1) with config: model='/groups/kjun/tnn/datasets//Llama-3.2-1B-Instruct', speculative_config=None, tokenizer='/groups/kjun/tnn/datasets//Llama-3.2-1B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=5000, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metric

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 06-18 08:57:58 [default_loader.py:272] Loading weights took 1.32 seconds
INFO 06-18 08:57:59 [model_runner.py:1203] Model loading took 2.3185 GiB and 1.569181 seconds
INFO 06-18 08:58:00 [worker.py:294] Memory profiling takes 0.66 seconds
INFO 06-18 08:58:00 [worker.py:294] the current vLLM instance can use total_gpu_memory (31.73GiB) x gpu_memory_utilization (0.70) = 22.21GiB
INFO 06-18 08:58:00 [worker.py:294] model weights take 2.32GiB; non_torch_memory takes 0.09GiB; PyTorch activation peak memory takes 1.19GiB; the rest of the memory reserved for KV Cache is 18.62GiB.
INFO 06-18 08:58:00 [executor_base.py:113] # cuda blocks: 38125, # CPU blocks: 8192
INFO 06-18 08:58:00 [executor_base.py:118] Maximum concurrency for 5000 tokens per request: 122.00x
INFO 06-18 08:58:02 [model_runner.py:1513] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in

Capturing CUDA graph shapes:   0%|          | 0/35 [00:00<?, ?it/s]

INFO 06-18 08:58:19 [model_runner.py:1671] Graph capturing finished in 17 secs, took 0.13 GiB
INFO 06-18 08:58:19 [llm_engine.py:428] init engine (profile, create kv cache, warmup model) took 19.95 seconds
#--- memory: 20.959717750549316
#--- memory: 0.0


In [7]:
tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
llm_tf = AutoModelForCausalLM.from_pretrained(llm_tokenizer_dir).to("cuda:1")
# model_regular.generation_config.pad_token_id = tokenizer.eos_token_id
gc.collect();torch.cuda.empty_cache();
print('#--- memory:', torch.cuda.memory_allocated(0)/(1024**3))

#--- memory: 20.959717750549316


In [8]:
prm = RLHFFlow(model_path=prm_tokenizer_dir, device_map='cuda:2')

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

RuntimeError: CUDA error: invalid device ordinal
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
@dataclass
class Beam:
    q_idx: int
    question: str
    templated_conv: str
    current_text: str | None
    next_texts: list[str] | None
    lookahead_texts: list[str] | None
    stop_reasons: list[str | None] | None
    best_scores: list[float]  # the PRM scores
    all_scores: list[list[float]]  # all PRM scores
    previous_text: str | None
    pruned: False
    history: list[str]
    completed: bool = False
    completion_tokens: int = 0


def _diverse_select(K, V, q_embeds, q_lprobs, q_ppl, q_scores, ds_alpha, ds_beta=0):
    num_arms = len(q_embeds)
    _V = copy.deepcopy(V)
    selected_idxes = []
    tol = 0.0001
    for it in range(K):
        _V_inv = np.linalg.inv(_V)
        q_diversity = np.einsum('ij,jk,ik->i', q_embeds, _V_inv, q_embeds)
        # print(q_scores.shape)
        # print(q_diversity.shape)
        q_vals = ds_beta*q_scores + ds_alpha*q_diversity
        # print(q_vals.shape)
        max_val = np.max([val for idx, val in enumerate(q_vals) if idx not in selected_idxes])
        # candidate_idxes = np.where(np.abs(q_vals-max_val) < tol)[0]
        candidate_idxes = [
            arm_idx for arm_idx, arm_val in enumerate(q_vals)
            if (np.abs(max_val - arm_val) <= tol) and (arm_idx not in selected_idxes)
        ]

        best_idx = max(candidate_idxes, key=lambda i: q_lprobs[i])
        # print(q_vals)
        # print(q_lprobs)
        # print(candidate_idxes)
        # print(best_idx)
        
        best_embeds = q_embeds[best_idx]
        best_embeds = best_embeds.reshape(-1, 1)
        # print(best_embeds.shape)

        # update V
        _V = _V + np.matmul(best_embeds, best_embeds.T)

        # update selected_idxes
        selected_idxes.append(best_idx)

        # print(_V.shape)
        # print(max_val)
        # print(max_idx)
        # print(selected_idxes)

    return selected_idxes
    



In [None]:
stop

In [None]:

    
def select_diverse_search(batch_of_questions, config, llm_vllm, llm_tf, llm_tokenizer, prm):
    sampling_params = SamplingParams(
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        top_p=config.top_p,
        stop=["\n\n"],
        include_stop_str_in_output=True,
        n=1,
    )

    ndim = 2048
    V = config.lam*np.eye(ndim)
    K = int(config.n / config.beam_width)
    
    completed_beams: list[Beam] = []
    beams: list[Beam] = []
    for q_idx, question in enumerate(batch_of_questions):
        for _ in range(config.n):
            beams.append(
                Beam(
                    q_idx=q_idx,
                    question=question,
                    templated_conv="",
                    current_text="",
                    next_texts=None,
                    lookahead_texts=None,
                    pruned=False,
                    completed=False,  # New flag to track completion
                    stop_reasons=None,
                    history=[],
                    best_scores=[],
                    all_scores=[],
                    previous_text=None,
                    completion_tokens=0,
                )
            ) 

    active_beams = beams
    # for i in tqdm(range(config.num_iterations), desc="Beam search iterations"):
    start_time = time.time()
    for it in range(config.num_iterations):
        # print(f"\n-> {it}")
        # if it != 0:
        #     # active_beams = [b for b in active_beams if not b.pruned]
        #     extended_beams = []
        #     for beam in active_beams:
        #         if beam.pruned:
        #             continue 
                    
        #         for j in range(config.beam_width):
        #             extended_beams.append(copy.deepcopy(beam))

        #     active_beams = extended_beams
        
        # print(len(active_beams))
        convs = [
            build_conv(b.question, b.current_text, config.system_prompt)
            for b in active_beams
        ]

        add_generation_prompt = it == 0
        continue_final_message = it > 0
    
        tokenizer = llm_vllm.get_tokenizer()
    
        if config.custom_chat_template is not None:
            tokenizer.chat_template = config.custom_chat_template
            
        templated_convs = tokenizer.apply_chat_template(
            convs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tokenize=False,
        )

        # Last iteration, generate to EOS
        if it == config.num_iterations - 1:
            sampling_params = SamplingParams(
                temperature=config.temperature,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                n=1,
            )

        lookahead = 0 if it == config.num_iterations - 1 else config.lookahead
        gen_results = generate_k_steps(
            templated_convs, lookahead, llm_vllm, sampling_params, 1
        )
        # total_time = time.time() - start_time
        # print(f"it takes {total_time:0.4f}s")

        # Collecct gen_results into beams
        for beam, gen_result in zip(active_beams, gen_results, strict=True):
            beam.next_texts = gen_result.next_texts
            beam.stop_reasons = gen_result.stop_reasons
            beam.lookahead_texts = gen_result.lookahead_texts
            beam.completion_tokens += gen_result.completion_tokens
            beam.current_text += gen_result.next_texts[0]
            # beam.history.append(beam.next_texts[0])
            beam.templated_prompt = gen_result.prompt
            # pprint.pprint(gen_result)
            # print(f"beam.next_texts = {beam.next_texts}")
            # print(f"beam.stop_reasons = {beam.stop_reasons}")
            # print(f"beam.lookahead_texts = {beam.lookahead_texts}")
            # print(f"beam.lookahead_texts = {beam.lookahead_texts}")
            # stop
            
            if (
                beam.stop_reasons[0] == "EOS"
                or beam.stop_reasons[0] == "length"
                or beam.next_texts[0] == ""
            ):
                beam.completed = True
                completed_beams.append(beam)
                continue
            
        # Filter out comleted beams 
        active_beams = [b for b in active_beams if not b.completed]
        # print(len(active_beams))

        # Early stopping if all beams are completed
        if len(active_beams) == 0:
            print("break")
            break

        # Compute prm scores
        all_prompts = []
        all_completions = []
        for b_idx, beam in enumerate(active_beams):
            all_prompts.append(beam.question)
            all_completions.append([beam.current_text])

        # print(all_completions)
        all_scores = prm.score(all_prompts, all_completions)
        # print(all_scores)
        # for score in all_scores:
        #     print(score)
        #     for s in score:
        #         print(s)
        all_agg_scores = [
            [aggregate_scores(s, config.agg_strategy) for s in score]
            for score in all_scores
        ]
        # print(all_agg_scores)
        
        # Extract completion's embeddings and other info
        batch_prms = [[] for _ in range(len(batch_of_questions))]
        batch_embeds = [[] for _ in range(len(batch_of_questions))]
        batch_log_probs = [[] for _ in range(len(batch_of_questions))]
        batch_ppl = [[] for _ in range(len(batch_of_questions))]
        batch_beams = [[] for _ in range(len(batch_of_questions))]
        
        batch_scores = [[] for _ in range(len(batch_of_questions))]
        for b_idx, beam in enumerate(active_beams):
            batch_scores[beam.q_idx].append(all_agg_scores[b_idx][0])
            
            with torch.no_grad():
                # get beam.current_text which include previous all steps upto now
                gen_prompt = beam.templated_prompt + beam.next_texts[0]
                # print(gen_prompt)
                # stop
                inputs = llm_tokenizer(gen_prompt, return_tensors="pt").to(llm_tf.device)
                outputs = llm_tf(**inputs, output_hidden_states=True)
    
                # Get last_token_embeds
                last_hidden_state = outputs.hidden_states[-1]
                last_token_embeds = last_hidden_state[:, -1, :].squeeze(0).detach().cpu().numpy()
                # print(last_token_embeds.shape)
    
                # Compute otuput_log_prob
                # Prepare labels: shift input_ids to the right by one
                labels = inputs['input_ids'][:, 1:]   
                shifted_logits = outputs.logits[:, :-1, :]
                loss_fct = CrossEntropyLoss(reduction='sum')
                completion_log_prob = -loss_fct(shifted_logits.view(-1, shifted_logits.size(-1)), labels.view(-1)).detach().cpu().numpy()
                completion_ppl = np.exp(completion_log_prob/len(labels))
                # print(sent_ppl)
                # print(loss)
    
                # normalize the embeds
                if config.normalize_embeds:
                    norm = np.linalg.norm(last_token_embeds)
                    last_token_embeds /= norm
                    # print(np.linalg.norm(last_token_embeds))
    
                batch_embeds[beam.q_idx].append(last_token_embeds)
                batch_log_probs[beam.q_idx].append(completion_log_prob)
                batch_ppl[beam.q_idx].append(completion_ppl)
                batch_beams[beam.q_idx].append(beam)
    
        # pprint.pprint(len(batch_completions_embeds))
        # pprint.pprint(len(batch_completions_log_probs))
        # pprint.pprint(len(batch_completions_ppl))
        # print(len(batch_beams))
        # print(len(batch_beams[0]))
        # total_time = time.time() - start_time
        # print(f"it takes {total_time:0.4f}s")

        
        # Use _select_diverse to diversify embeddings 
        for q_idx in range(len(batch_of_questions)):
            
            if len(batch_beams[q_idx]) <= K:
                continue 
            
            selected_idxes = _select_diverse_prm(
                K, V, batch_embeds[q_idx], batch_log_probs[q_idx], batch_ppl[q_idx], 
                batch_scores[q_idx], config.ds_alpha)
                
            for idx, beam in enumerate(batch_beams[q_idx]):
                if idx not in selected_idxes:
                    beam.pruned = True 

        # create next active beams
        next_active_beams = []
        for beam in active_beams:
            if beam.pruned:
                continue 
                
            for j in range(config.beam_width):
                next_active_beams.append(copy.deepcopy(beam))

        active_beams = next_active_beams

                        
        total_time = time.time() - start_time
        # print(f"it takes {total_time:0.4f}s")

    stop
    # Filter duplicate active beams
    if config.filter_duplicates:
        # Create a dictionary to filter duplicates and retain order
        unique_beam_dict = {}
        for i, b in enumerate(completed_beams):
            if b.current_text not in unique_beam_dict:
                unique_beam_dict[b.current_text] = (
                    i  # Map the unique text to its index
                )
        completed_beams = [completed_beams[i] for i in unique_beam_dict.values()]
            
    # Collect the completions from beams
    completions = [[] for _ in range(len(batch_of_questions))]
    # completion_ntokens = [[] for _ in range(len(batch_of_questions))]
    
    for beam in completed_beams:
        completions[beam.q_idx].append(beam.current_text)
        # completion_ntokens[beam.q_idx].append(beam.current_text)

    results = defaultdict(list)
    results["completions"] = completions
    # results["completion_ntokens"] = completion_ntokens
    
    return results

        

# general params
config = Config()
config.n = 8
config.beam_width = 2
config.lookahead = 0
config.num_iterations = 2

# diverse_select params
config.lam = 10
config.normalize_embeds = True

config.ds_alpha = 1.0

level = 4
num_questions = len(data_by_levels[level])
num_questions = 2
num_trials = 1
print(f"num_questions = {num_questions}")
print(f"num_trials = {num_trials}")

# get batch of questions
batch_of_questions = [data_by_levels[level][q_idx]['problem'] for q_idx in range(num_questions)]

results = select_diverse_search(batch_of_questions, config, llm_vllm, llm_tf, tokenizer, prm)

In [None]:
# print(len(beam_results))
# pprint.pprint(results)
print(len(results["completions"][2]))