# Binary Summarization

This notebook contains experiments for binary summarization.

## 1. Load Dataset

We first load the test datasets. Here, `test_dataset` is for decompiled-only and retrieval-augmented inputs, while `probed_test_dataset` is for ProRec inputs.

In [None]:
from importlib import reload
from datasets import load_dataset
from tqdm import tqdm
import json
import os
import pickle


# Local locations for the Hugging Face download cache and for the result
# files that the later cells write.
CACHE_DIR = '../save/.cache'
OUTPUT_DIR = '../save/binsum'
# exist_ok=True creates the directory when it is missing and is a no-op
# otherwise, so no separate existence check is needed.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Only the 'test' split of each dataset is used.
test_dataset = load_dataset(
    'PurCL/lmpa-prorec-test-1k',
    cache_dir=CACHE_DIR,
)['test']
probed_test_dataset = load_dataset(
    'PurCL/lmpa-prorec-test-probed-c34b-1k',
    cache_dir=CACHE_DIR,
)['test']
print(test_dataset)
print(probed_test_dataset)

## 2. Load Datastore

The retrieval results are pre-computed and stored in the `test_dataset` in the form of datastore source function ids. Therefore, we need to access the datastore to map those ids to real functions.

In [None]:
from huggingface_hub import hf_hub_download
from evaluator.binsum_evaluator import tokenizer
from tools.retriever import CrossModalRetriever

# Fetch the two index files of the datastore. hf_hub_download returns the
# local path of the downloaded file, so the index directory is derived from
# that path rather than from a hard-coded snapshot commit hash, which would
# go stale as soon as the hub repository receives a new commit.
DATASTORE_REPO_ID = 'PurCL/casp-moco-lmpa-c-only'
keys_path = hf_hub_download(
    repo_id=DATASTORE_REPO_ID,
    subfolder='index',
    filename='keys.pkl',
    cache_dir=CACHE_DIR,
)
hf_hub_download(
    repo_id=DATASTORE_REPO_ID,
    subfolder='index',
    filename='key_embeddings.npy',
    cache_dir=CACHE_DIR,
)
# Both files are requested from the same 'index' subfolder, so the directory
# holding keys.pkl is the index directory.
index_path = os.path.dirname(keys_path)

retriever = CrossModalRetriever()   # without model
retriever.load_index(index_path)

## 3. Pre-Processing

We randomly sample 100 examples for reproduction. The pre-processing includes normalizing strings in the decompiled code and constructing inputs for three methods to compare.

In [None]:
import random
import re

# Draw the evaluation subset with a dedicated, seeded generator so the same
# examples are selected on every run; the global `random` state is left
# untouched. Change SAMPLE_SEED to draw a different subset.
NUM_SAMPLES = 100   # use a smaller value (e.g. 10) for a quick trial run
SAMPLE_SEED = 0
sampled_indexes = random.Random(SAMPLE_SEED).sample(
    range(len(test_dataset)),
    # Never ask for more examples than the dataset holds.
    min(NUM_SAMPLES, len(test_dataset)),
)


# Matches a single- or double-quoted literal, including backslash-escaped
# characters inside the quotes.
string_pattern = r'([\'"])(?:(?=(\\?))\2.)*?\1'


def normalize_string(raw_decompiled_function):
    """Return the code with every quoted literal replaced by ``<some_string>``."""
    return re.compile(string_pattern).sub('<some_string>', raw_decompiled_function)


def extract_probed(probed_sources):
    """Strip everything up to and including the first ``<asm_token>`` marker.

    For each string, the text before the first ``<asm_token>`` is removed
    together with the marker and the single character that follows it (the
    skipped length is that of ``'<asm_token>\\n'``). Strings without the
    marker are returned unchanged.
    """
    marker = '<asm_token>'
    skip = len(marker + '\n')

    def _after_marker(text):
        position = text.find(marker)
        return text if position < 0 else text[position + skip:]

    return [_after_marker(text) for text in probed_sources]

# Number of context functions attached to each augmented input
# (3 and 4 were the other values tried).
TOP_K = 5
# Retrieved functions are cut to this many tokenizer tokens.
MAX_RETRIEVED_TOKENS = 512

batch_lmpa = []
batch_dec = []
batch_dec_w_pro = []
batch_src = []
batch_dec_w_ret = []
for ex_q, ex_p in zip(test_dataset.select(sampled_indexes),
                      probed_test_dataset.select(sampled_indexes)):
    # Both datasets must describe the same function at the same position.
    assert ex_q['src'] == ex_p['oracle_source']
    batch_lmpa.append(ex_q['lmpa'])

    # SECURITY: eval() executes whatever expression the dataset field holds;
    # only run this on a dataset you trust.
    # NOTE(review): if the field is always a plain literal, ast.literal_eval
    # would be a safer replacement -- confirm against the dataset.
    lmpa_record = eval(ex_q['lmpa'])
    # Keep the part of the query that precedes the 'Q:[' marker.
    decompiled_function = normalize_string(
        lmpa_record['query'].split('Q:[')[0].strip()
    )
    batch_dec.append(decompiled_function)
    batch_dec_w_pro.append(
        (decompiled_function, extract_probed(ex_p['probed_sources'])[:TOP_K])
    )
    batch_src.append(ex_p['oracle_source'])

    # Only the first TOP_K retrieved ids are used, so only those are looked
    # up in the datastore and re-tokenized. retrieved_index[0] (presumably
    # the matching scores) is not needed here.
    retrieved_sources = []
    for idx in ex_q['retrieved_index'][1][:TOP_K]:
        ret_src = retriever.index['keys'][int(idx)]
        retrieved_sources.append(
            tokenizer.decode(tokenizer.encode(ret_src)[:MAX_RETRIEVED_TOKENS])
        )
    batch_dec_w_ret.append((decompiled_function, retrieved_sources))

## 4. Prompt GPT3.5

First, construct the prompters. Provide your own OpenAI API key.

In [None]:
from tools.prompter import (
    OpenAIDecompSummarizer, 
    OpenAISourceSummarizer, 
)

# OpenAI API key: read from the OPENAI_API_KEY environment variable when it
# is set, so the key does not have to be pasted into (and saved with) the
# notebook. Otherwise replace the placeholder below.
secret_key = os.environ.get('OPENAI_API_KEY', 'YOUR_OPENAI_KEY')

MODEL_NAME = 'gpt-3.5-turbo-1106'
MODEL_NAME_SHORT = 'gpt3.5'   # used in the names of the result files

# One prompter for original source functions, one for decompiled functions.
src_summarizer = OpenAISourceSummarizer(
    secret_key,
    model_name=MODEL_NAME,
)
dec_summarizer = OpenAIDecompSummarizer(
    secret_key,
    model_name=MODEL_NAME,
)

# Keyword arguments forwarded to every generate call below.
generation_config = {'temperature': 0.8}

### 4.1 Generate Summarization for Source Function

In [None]:
# Summarize every original source function. The list is filled one response
# at a time, so responses already received remain available if a call fails.
srcsum_results = []
for source_function in batch_src:
    response = src_summarizer.generate(source_function, **generation_config)
    srcsum_results.append(response)

# Store the raw responses for later reuse.
srcsum_path = os.path.join(OUTPUT_DIR, f'srcsum_{MODEL_NAME_SHORT}_results.pkl')
with open(srcsum_path, 'wb') as out_file:
    pickle.dump(srcsum_results, out_file)

In [None]:
# Spot-check the text of one generated source summary.
sample_response = srcsum_results[7]
print(sample_response.choices[0].message.content)

### 4.2 Generate Summarization with Only Decompiled Function

In [None]:
# Summarize every decompiled function without any additional context.
decsum_results = []
for decompiled in batch_dec:
    response = dec_summarizer.generate(decompiled, **generation_config)
    decsum_results.append(response)

# Store the raw responses for later reuse.
decsum_path = os.path.join(OUTPUT_DIR, f'decsum_{MODEL_NAME_SHORT}_results.pkl')
with open(decsum_path, 'wb') as out_file:
    pickle.dump(decsum_results, out_file)

### 4.3 Generate Summarization with Retrieval-based Contexts

In [None]:
# Summarize every decompiled function together with its retrieved functions.
ragsum_results = []
for decompiled_with_retrieved in batch_dec_w_ret:
    response = dec_summarizer.generate_with_augmentation(
        decompiled_with_retrieved, **generation_config
    )
    ragsum_results.append(response)

# Store the raw responses for later reuse.
ragsum_path = os.path.join(OUTPUT_DIR, f'ragsum_{MODEL_NAME_SHORT}_results.pkl')
with open(ragsum_path, 'wb') as out_file:
    pickle.dump(ragsum_results, out_file)

### 4.4 Generate Summarization with ProRec Contexts

In [None]:
# Summarize every decompiled function together with its probed sources.
parsum_results = []
for decompiled_with_probed in batch_dec_w_pro:
    response = dec_summarizer.generate_with_augmentation(
        decompiled_with_probed, **generation_config
    )
    parsum_results.append(response)

# Store the raw responses for later reuse.
parsum_path = os.path.join(OUTPUT_DIR, f'parsum_{MODEL_NAME_SHORT}_results.pkl')
with open(parsum_path, 'wb') as out_file:
    pickle.dump(parsum_results, out_file)

## 5. Evaluation with CHRF and METEOR

In [None]:
import evaluator.binsum_evaluator
reload(evaluator.binsum_evaluator)
from evaluator.binsum_evaluator import (
    extract_summary,
    evaluate_automatic_metrics, 
)

# Collect one record per example. The records are built BEFORE the output
# file is opened: opening with 'w' truncates an existing file, so a failure
# while extracting summaries would otherwise leave an empty or partial JSON
# file behind.
agg_results = []
for srcsum_res, decsum_res, parsum_res, ragsum_res, dw_pro, dw_ret, src, lmpa_data in zip(
    srcsum_results, decsum_results, parsum_results, ragsum_results,
    batch_dec_w_pro, batch_dec_w_ret, batch_src, batch_lmpa,
):
    # The summary extracted from the source-function response is the reference.
    ref = extract_summary(srcsum_res.choices[0].message.content)
    pred_dec = extract_summary(decsum_res.choices[0].message.content)
    pred_pro = extract_summary(parsum_res.choices[0].message.content)
    pred_rag = extract_summary(ragsum_res.choices[0].message.content)
    agg_results.append(
        {
            'lmpa': lmpa_data,
            'src': src,
            'dec': dw_pro[0],   # the decompiled function
            'srcsum': ref,
            'decsum': pred_dec,
            'prosum': pred_pro,
            'probed_sources': dw_pro[1],
            'ragsum': pred_rag,
            'retrieved_sources': dw_ret[1],
        }
    )
print(len(agg_results))

with open(os.path.join(OUTPUT_DIR, f'aggregated_{MODEL_NAME_SHORT}_results.json'), 'w') as f:
    json.dump(agg_results, f, indent=2)

In [None]:
def _summaries(results):
    """Apply extract_summary to the first choice's message content of each response."""
    return [extract_summary(res.choices[0].message.content) for res in results]


# Evaluate the three decompiled-code variants; the source-function summaries
# are passed as the first (positional) argument.
# (`evaluate_detailed_automatic_metrics` is the alternative call that was
# left commented out here.)
evaluate_automatic_metrics(
    _summaries(srcsum_results),
    decsum=_summaries(decsum_results),
    ragsum=_summaries(ragsum_results),
    parsum=_summaries(parsum_results),
)