# Binary Function Name Recovery

This notebook contains experiments for binary function name recovery.

## 1. Load Dataset

We first load the test datasets. Here, `test_dataset` is for decompiled-only and retrieval-augmented inputs, while `probed_test_dataset` is for ProRec inputs.

In [None]:
from importlib import reload
from datasets import load_dataset
from tqdm import tqdm
import json
import os
import pickle


# load dataset
# Locations of the Hugging Face download cache and of the experiment outputs.
CACHE_DIR = '../save/.cache'
# NOTE(review): the directory is named 'binsum' although this notebook is about
# function name recovery -- confirm it does not clash with another experiment.
OUTPUT_DIR = '../save/binsum'
# exist_ok makes this a no-op when the directory is already there and avoids
# the check-then-create race of a separate os.path.exists() test.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Inputs for the decompiled-only and retrieval-augmented settings.
test_dataset = load_dataset(
    'PurCL/lmpa-prorec-test-1k',
    cache_dir=CACHE_DIR,
)['test']
# Inputs for the ProRec setting (carries the probed sources).
probed_test_dataset = load_dataset(
    'PurCL/lmpa-prorec-test-probed-c34b-1k',
    cache_dir=CACHE_DIR,
)['test']
print(test_dataset)
print(probed_test_dataset)

## 2. Load Datastore

The retrieval results are pre-computed and stored in `test_dataset` as ids of source functions in the datastore. We therefore need to load the datastore to map those ids back to the actual functions.

In [None]:
from huggingface_hub import hf_hub_download
from evaluator.binsum_evaluator import tokenizer
from tools.retriever import CrossModalRetriever

# Fetch the two files that make up the retrieval index. hf_hub_download
# returns the local path of the cached file, so the index directory is derived
# from that path instead of hard-coding a snapshot commit hash (which would
# stop matching as soon as the hub repository receives a new commit).
keys_path = hf_hub_download(
    repo_id='PurCL/casp-moco-lmpa-c-only',
    subfolder='index',
    filename='keys.pkl',
    cache_dir=CACHE_DIR,
)
hf_hub_download(
    repo_id='PurCL/casp-moco-lmpa-c-only',
    subfolder='index',
    filename='key_embeddings.npy',
    cache_dir=CACHE_DIR,
)
# Both files are requested from the same repo and subfolder, so the directory
# that holds keys.pkl is the index directory.
index_path = os.path.dirname(keys_path)

retriever = CrossModalRetriever()   # without model
retriever.load_index(index_path)

## 3. Pre-Processing

We randomly sample 100 examples for reproduction (the sample is not seeded, so the selected examples differ between runs). Pre-processing consists of normalizing string literals in the decompiled code and constructing the inputs for the three methods being compared.

In [None]:
import random
import re

# Draw 100 distinct positions of test_dataset at random (sampling without
# replacement; no seed is set here).
sampled_indexes = random.sample(range(len(test_dataset)), k=100)


# Matches a single- or double-quoted literal up to its matching closing quote,
# treating a backslash plus the following character as one escaped unit.
string_pattern = r'([\'"])(?:(?=(\\?))\2.)*?\1'


def normalize_string(raw_decompiled_function):
    """Replace every quoted literal in the code with the token '<some_string>'.

    Args:
        raw_decompiled_function: Text of the decompiled function.

    Returns:
        A copy of the text in which each match of ``string_pattern`` has been
        substituted by the placeholder; text without matches is returned as is.
    """
    return re.sub(
        string_pattern,
        '<some_string>',
        raw_decompiled_function,
    )


def extract_probed(probed_sources):
    """Drop everything up to and including the '<asm_token>' marker line.

    Args:
        probed_sources: Iterable of strings.

    Returns:
        A new list with one string per input. A string without the marker is
        kept unchanged; otherwise only the text that follows the first marker
        is kept, skipping ``len('<asm_token>\\n')`` characters from the start
        of the marker.
    """
    marker = '<asm_token>'
    skip = len(marker + '\n')

    def _after_marker(text):
        position = text.find(marker)
        return text if position < 0 else text[position + skip:]

    return [_after_marker(text) for text in probed_sources]

# Number of context functions (probed or retrieved) kept per example.
TOP_K = 5

batch_lmpa = []         # raw 'lmpa' field of each example
batch_dec = []          # normalized decompiled function only
batch_dec_w_pro = []    # (decompiled function, top probed sources)
batch_src = []          # oracle source function
batch_dec_w_ret = []    # (decompiled function, top retrieved sources)
for ex_q, ex_p in zip(test_dataset.select(sampled_indexes),
                      probed_test_dataset.select(sampled_indexes)):
    # Both datasets must describe the same function at the same position.
    assert ex_q['src'] == ex_p['oracle_source'], 'test datasets are misaligned'
    batch_lmpa.append(ex_q['lmpa'])
    # NOTE(review): eval() executes whatever the dataset field contains; if the
    # field is always a plain literal, ast.literal_eval would be a safer parser.
    # The decompiled function is the part of the query before the 'Q:[' marker.
    decompiled_function = eval(ex_q['lmpa'])['query'].split('Q:[')[0].strip()
    decompiled_function = normalize_string(decompiled_function)
    batch_dec.append(decompiled_function)
    batch_dec_w_pro.append(
        (decompiled_function, extract_probed(ex_p['probed_sources'])[:TOP_K])
    )
    batch_src.append(ex_p['oracle_source'])
    # Only the first TOP_K retrieved entries are used, so slice before the
    # tokenizer round-trip instead of processing candidates that would be
    # discarded afterwards. Row 0 presumably holds the scores and row 1 the
    # datastore ids; only the ids are needed.
    retrieved_scores = ex_q['retrieved_index'][0][:TOP_K]
    retrieved_ids = ex_q['retrieved_index'][1][:TOP_K]
    retrieved_sources = []
    for _score, idx in zip(retrieved_scores, retrieved_ids):
        ret_src = retriever.index['keys'][int(idx)]
        # Keep at most the first 512 tokens of each retrieved source.
        retrieved_sources.append(tokenizer.decode(tokenizer.encode(ret_src)[:512]))
    batch_dec_w_ret.append((decompiled_function, retrieved_sources))

## 4. Prompt GPT3.5

First, construct the prompter. Provide your own OpenAI API key.

In [None]:
from tools.prompter import (
    OpenAIDecompFuncNamer 
)

# key
# Prefer the OPENAI_API_KEY environment variable so the real key does not have
# to be written into the notebook; when it is not set, the placeholder below is
# used and must be replaced by hand.
secret_key = os.environ.get('OPENAI_API_KEY', 'YOUR_OPENAI_KEY')

MODEL_NAME = 'gpt-3.5-turbo-1106'
MODEL_NAME_SHORT = 'gpt3.5'   # used in the names of the result files

dec_func_namer = OpenAIDecompFuncNamer(
    secret_key,
    model_name=MODEL_NAME,
)

# Keyword arguments forwarded to every generate call.
generation_config = {'temperature': 0.2}

### 4.1 Generate Function Name with Only Decompiled Function

In [None]:
# Query the model once per decompiled function (no extra context), keeping the
# raw responses in input order.
decname_results = []
for inputs in batch_dec:
    response = dec_func_namer.generate(inputs, **generation_config)
    decname_results.append(response)

# Pickle the raw responses.
# NOTE(review): the 'decsum_' prefix suggests summarization results -- confirm
# the file name is intended for function-name outputs.
decname_path = os.path.join(OUTPUT_DIR, f'decsum_{MODEL_NAME_SHORT}_results.pkl')
with open(decname_path, 'wb') as f:
    pickle.dump(decname_results, f)

### 4.2 Generate Function Name with Retrieval-based Contexts

In [None]:
# Query the model once per (decompiled function, retrieved sources) pair,
# keeping the raw responses in input order.
ragname_results = []
for inputs in batch_dec_w_ret:
    response = dec_func_namer.generate_with_augmentation(inputs, **generation_config)
    ragname_results.append(response)

# Pickle the raw responses.
ragname_path = os.path.join(OUTPUT_DIR, f'ragsum_{MODEL_NAME_SHORT}_results.pkl')
with open(ragname_path, 'wb') as f:
    pickle.dump(ragname_results, f)

### 4.3 Generate Function Name with ProRec Contexts

In [None]:
# Query the model once per (decompiled function, probed sources) pair,
# keeping the raw responses in input order.
parname_results = []
for inputs in batch_dec_w_pro:
    response = dec_func_namer.generate_with_augmentation(inputs, **generation_config)
    parname_results.append(response)

# Pickle the raw responses.
parname_path = os.path.join(OUTPUT_DIR, f'parsum_{MODEL_NAME_SHORT}_results.pkl')
with open(parname_path, 'wb') as f:
    pickle.dump(parname_results, f)

## 5. Evaluate with SymLM Metrics

In [None]:
import evaluator.bfname_evaluator
reload(evaluator.bfname_evaluator)
from evaluator.bfname_evaluator import (
    camel_to_snake,
    extract_function_name,
    evaluate_func_names
)

# Build one record per sampled example, combining its inputs with the name
# produced by each of the three methods.
agg_results = []
for decname_res, parname_res, ragname_res, lmpa_data, src, dw_pro, dw_ret in zip(
    decname_results, parname_results, ragname_results,
    batch_lmpa, batch_src, batch_dec_w_pro, batch_dec_w_ret,
):
    decname_content = decname_res.choices[0].message.content
    decname = camel_to_snake(extract_function_name(decname_content))
    if decname is None:
        # Stop here and show the response that could not be parsed.
        raise ValueError(
            f'no function name found in response: {decname_content!r}'
        )
    parname = camel_to_snake(extract_function_name(parname_res.choices[0].message.content))
    ragname = camel_to_snake(extract_function_name(ragname_res.choices[0].message.content))

    agg_results.append(
        {
            'lmpa': lmpa_data,
            'src': src,
            'dec': dw_pro[0],
            'decname': decname,
            'parname': parname,
            'probed_sources': dw_pro[1],
            'ragname': ragname,
            'retrieved_sources': dw_ret[1],
        }
    )
print(len(agg_results))

# The output file is opened only after every record was built, so a failure
# above cannot leave an empty or truncated JSON file behind.
with open(os.path.join(OUTPUT_DIR, f'aggregated_{MODEL_NAME_SHORT}_results.json'), 'w') as f:
    json.dump(agg_results, f, indent=2)

In [None]:
from evaluator.bfname_evaluator import (
    extract_function_name,
    evaluate_func_names,
)

def parse_src_func_name(src_func):
    """Extract the function name from the source text of a function.

    The name is taken to be the last whitespace-separated token in front of
    the first '(' (e.g. ``'static int foo (long n)'`` gives ``'foo'``). Any
    amount of whitespace between the name and the parenthesis is tolerated,
    and a pointer/reference sigil attached to the name (``'char *dup(...)'``)
    is removed.

    Args:
        src_func: Source code of one function, starting at its signature.

    Returns:
        The parsed name, or an empty string when nothing precedes the first
        '(' (instead of raising IndexError).
    """
    # Everything before the first '(' is the return type followed by the name.
    tokens = src_func.split('(')[0].split()
    if not tokens:
        return ''
    name = tokens[-1]
    # 'char *dup' / 'int &get': the sigil belongs to the return type. Keep the
    # raw token if stripping would leave nothing.
    return name.lstrip('*&') or name

def _predicted_names(responses):
    """Extract the function name from each response and pass it through camel_to_snake."""
    return [
        camel_to_snake(extract_function_name(res.choices[0].message.content))
        for res in responses
    ]


# Reference names are parsed from the oracle sources; each keyword argument
# carries the predictions of one method, in the same example order.
ground_truth_names = [parse_src_func_name(src) for src in batch_src]
evaluate_func_names(
    ground_truth_names,
    decname=_predicted_names(decname_results),
    ragname=_predicted_names(ragname_results),
    parname=_predicted_names(parname_results),
)