## Algorithm 1: Naive Paired Cost Analysis

In [1]:
import os
import re
import json
import argparse
from collections import defaultdict

def parse_files(directory):
    """
    Parses all 'ipra_analysis_*.txt' files in a directory. This is done by
    reading all files to build a complete model of the program's call graph
    and register usage.

    Files are visited in sorted filename order so that, when the same function
    is reported by more than one file, the entry that wins (the last one read)
    does not depend on the arbitrary order returned by os.listdir().

    Args:
        directory: path of the directory holding the analysis dumps.

    Returns:
        A 4-tuple (callee_save_costs, callee_call_sites, function_hotness,
        tail_functions):
          - callee_save_costs: dict[function] -> number of whitespace-separated
            register names found on its 'CSRegUsage:' line.
          - callee_call_sites: defaultdict[callee] -> list of
            {"caller", "live_csrs", "count"} dicts, one per 'Calls:' line.
          - function_hotness: dict[function] -> True when IsFunctionEntryHot is 1.
          - tail_functions: dict[caller] -> IsTailCall value of the most recently
            parsed call made by that caller (later calls overwrite earlier ones).
    """
    callee_save_costs = {}
    callee_call_sites = defaultdict(list)
    function_hotness = {}
    tail_functions = {}

    func_pattern = re.compile(r"IPRA: Function: (.*?)\[")
    usage_pattern = re.compile(r"CSRegUsage: (.*?) IsFunctionEntryHot: (\d+)")
    call_pattern = re.compile(r"Calls: (.*?)\[.*\] IsTailCall: (\d+).*? LivingCSRegs: (.*)")
    mbb_pattern = re.compile(r"MBB: \d+.*?MBBCount: (\d+)")

    # os.listdir() makes no ordering guarantee; sorting keeps the "last file
    # wins" overwrites below reproducible from run to run.
    files_to_process = [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.startswith('ipra_analysis_') and name.endswith('.txt')
    ]
    print(f"Found {len(files_to_process)} profile files to process.")

    for filepath in files_to_process:
        with open(filepath, 'r', errors='ignore') as f:
            # Parsing state is per file: a call or usage line is attributed to
            # the most recent 'IPRA: Function:' line seen in the same file.
            current_function = None
            current_mbb_count = 0
            for line in f:
                func_match = func_pattern.search(line)
                if func_match:
                    new_function = func_match.group(1).strip()
                    if new_function != current_function:
                        # A new function starts: forget the previous block count.
                        current_mbb_count = 0
                    current_function = new_function

                usage_match = usage_pattern.search(line)
                if usage_match and current_function:
                    regs_str = usage_match.group(1).strip()
                    num_regs = len(regs_str.split()) if regs_str else 0
                    callee_save_costs[current_function] = num_regs
                    is_hot_str = usage_match.group(2).strip()
                    function_hotness[current_function] = (int(is_hot_str) == 1)

                mbb_match = mbb_pattern.search(line)
                if mbb_match:
                    # Call lines that follow inherit this block's count.
                    current_mbb_count = int(mbb_match.group(1))

                call_match = call_pattern.search(line)
                if call_match and current_function:
                    callee_name = call_match.group(1).strip()
                    is_tail_call = (int(call_match.group(2).strip()) == 1)
                    tail_functions[current_function] = is_tail_call
                    live_regs_str = call_match.group(3).strip()

                    num_live_regs = len(live_regs_str.split()) if live_regs_str else 0
                    callee_call_sites[callee_name].append({
                        "caller": current_function,
                        "live_csrs": num_live_regs,
                        "count": current_mbb_count
                    })

    print(f"Found callee-save costs for {len(callee_save_costs)} unique functions.")
    print(f"Found call sites for {len(callee_call_sites)} unique callees.")
    return callee_save_costs, callee_call_sites, function_hotness, tail_functions

In [106]:
def calculate_benefits(callee_save_costs, callee_call_sites, function_hotness, tail_functions):
    """
    Scores every hot callee by its summed per-call-site benefit.

    For each call site the contribution is
    (callee's CSR count - live CSRs at the site) * execution count of the site.
    Callees that are not marked hot, or whose total is not strictly positive,
    are left out of the result.

    Args:
        callee_save_costs: dict[function] -> callee-saved register count;
            functions missing from the map are treated as cost 0.
        callee_call_sites: dict[callee] -> list of dicts with at least the
            keys "live_csrs" and "count".
        function_hotness: dict[function] -> bool; missing means not hot.
        tail_functions: accepted for interface compatibility; not read here.

    Returns:
        defaultdict(int) mapping callee -> positive benefit score.
    """
    benefit_scores = defaultdict(int)
    print(f"Calculating benefit scores...")

    for callee, sites in callee_call_sites.items():
        if function_hotness.get(callee, False):
            callee_cost = callee_save_costs.get(callee, 0)
            total_dynamic_benefit = sum(
                (callee_cost - site["live_csrs"]) * site["count"]
                for site in sites
            )
            # Only callees with a net win are reported.
            if total_dynamic_benefit > 0:
                benefit_scores[callee] = total_dynamic_benefit

    return benefit_scores

In [107]:
# --- Algorithm 1 configuration -------------------------------------------
LIVENESS_DATA_DIR = './fdo_liveness_output'
# The two constants below are defined here but are not passed to
# calculate_benefits(), which takes neither a penalty nor a threshold.
SIZE_PENALTY = 0.1
PRESERVE_NONE_THRESHOLD = 0

# Parse the dumps, then score every hot callee.
costs, sites, function_hotness, tail_functions = parse_files(LIVENESS_DATA_DIR)
candidate_scores = calculate_benefits(
    callee_save_costs=costs,
    callee_call_sites=sites,
    function_hotness=function_hotness,
    tail_functions=tail_functions,
)

# Plain-dict snapshot of the scores; later cells filter and serialize it.
output_data = dict(candidate_scores)
print(f"Found {len(candidate_scores)} candidate functions meeting the threshold.")

Found 1833 profile files to process.
Found callee-save costs for 93305 unique functions.
Found call sites for 59511 unique callees.
Calculating benefit scores...
Found 1730 candidate functions meeting the threshold.


In [108]:
import os
import re
import json
from collections import defaultdict


def filter_dangerous_functions(directory):
    """
    Collects the flag names printed for each function in the
    'ipra_prera_analysis_*.txt' files of a directory.

    Args:
        directory: path of the directory holding the pre-RA analysis dumps.

    Returns:
        dict[function name] -> list of flag names found after the
        '[compilation unit]' part of its 'IPRA: Function:' line. When a
        function appears on several lines, the last one read wins.
    """
    files_to_process = [os.path.join(directory, f) for f in os.listdir(directory) if f.startswith('ipra_prera_analysis_') and f.endswith('.txt')]
    # Regex to capture the main components of a line
    main_pattern = re.compile(r"^IPRA: Function: (.+?)\[(.*?)\]\s*(.*)$")
    # Regex to find all flag names within the flags part of the line
    flag_pattern = re.compile(r"(\w+): \d+")
    parsed_functions = {}
    for filepath in files_to_process:
        with open(filepath, 'r', errors='ignore') as f:
            for line in f:
                main_match = main_pattern.match(line.strip())

                if main_match:
                    func_name = main_match.group(1)
                    flags_string = main_match.group(3)
                    present_flags = flag_pattern.findall(flags_string)

                    parsed_functions[func_name] = present_flags
    return parsed_functions


# NOTE(review): LIVENESS_PRERA_DATA_DIR is not assigned by any cell above this
# one in the notebook; confirm which directory Algorithm 1 should read here.
parsed_functions = filter_dangerous_functions(LIVENESS_PRERA_DATA_DIR)

# Any candidate carrying at least one of these flags is dropped.
discarded_flags = {'HasAddressTaken', 'MustTailCall', 'IsInterposable', 'UsesAreIndirectCall', 'AllUsesAreNotCall'}
discarded_functions = set()
discarded_counts = defaultdict(int)  # flag name -> number of discarded candidates carrying it

for func in output_data:
    # Candidates with no pre-RA record are kept as they are.
    if func not in parsed_functions or func in discarded_functions:
        continue
    flags = set(parsed_functions[func])
    intersection = flags.intersection(discarded_flags)
    if intersection:
        for flag in intersection:
            discarded_counts[flag] += 1
        discarded_functions.add(func)

print(f"found {len(discarded_functions)} to discard")
print(discarded_counts)

found 63 to discard
defaultdict(<class 'int'>, {'HasAddressTaken': 63, 'AllUsesAreNotCall': 62, 'UsesAreIndirectCall': 1})


In [109]:
# Drop every flagged function from the candidate map.
before_len = len(output_data)
for func in discarded_functions:
    # pop() with a default keeps this cell safe to re-run: on a second run the
    # entries are already gone and a bare pop() would raise KeyError.
    output_data.pop(func, None)
after_len = len(output_data)
print(f"Filtered out {before_len - after_len} functions. Now total {len(output_data)} functions left.")

Filtered out 63 functions. Now total 1667 functions left.


In [110]:
# Serialize the surviving candidates as {"functions": {name: score}}.
OUTPUT_FILE = f'{LIVENESS_DATA_DIR}/liveness_profdata.json'
filtered_output_data = dict(output_data)  # shallow copy; no further filtering happens here
print(f"Final filtering, {len(filtered_output_data)} functions.")
output_dict = {"functions": filtered_output_data}
with open(OUTPUT_FILE, 'w') as f:
    f.write(json.dumps(output_dict, indent=2))

print(f"\n✅ Successfully merged profile data into '{OUTPUT_FILE}'")

Final filtering, 1667 functions.

✅ Successfully merged profile data into './fdo_liveness_output/liveness_profdata.json'


## Algorithm 2: Propagating Costs via Bottom-Up Call Graph

In [2]:
import os
import re
import json
import argparse
from collections import defaultdict

def parse_files(directory):
    """
    Parses all 'ipra_analysis_*.txt' files in a directory to build a model
    of the program's call graph and register usage.

    Files are visited in sorted filename order so that, when the same function
    is reported by more than one file, the entry that wins (the last one read)
    does not depend on the arbitrary order returned by os.listdir().

    Args:
        directory: path of the directory holding the analysis dumps.

    Returns:
        A 6-tuple (callee_save_costs, callee_call_sites, successors,
        predecessors, all_functions, function_hotness):
          - callee_save_costs: dict[function] -> number of whitespace-separated
            register names found on its 'CSRegUsage:' line.
          - callee_call_sites: defaultdict[callee] -> list of
            {"caller", "live_csrs", "count", "is_tail_call"} dicts, one per
            'Calls:' line.
          - successors: defaultdict[caller] -> set of callees.
          - predecessors: defaultdict[callee] -> set of callers.
          - all_functions: set of every function seen as a definition or callee.
          - function_hotness: dict[function] -> True when IsFunctionEntryHot is 1.
    """
    callee_save_costs = {}
    function_hotness = {} # Store hotness for each function
    callee_call_sites = defaultdict(list)
    # The call graph is represented as Caller -> set(Callees)
    successors = defaultdict(set)
    predecessors = defaultdict(set)

    func_pattern = re.compile(r"IPRA: Function: (.*?)\[")
    usage_pattern = re.compile(r"CSRegUsage: (.*?) IsFunctionEntryHot: (\d+)")
    call_pattern = re.compile(r"Calls: (.*?)\[.*\] IsTailCall: (\d+).*? LivingCSRegs: (.*)")
    mbb_pattern = re.compile(r"MBB: \d+.*?MBBCount: (\d+)")

    # os.listdir() makes no ordering guarantee; sorting keeps the "last file
    # wins" overwrites below reproducible from run to run.
    files_to_process = [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.startswith('ipra_analysis_') and name.endswith('.txt')
    ]
    print(f"Found {len(files_to_process)} profile files to process.")

    all_functions = set()

    for filepath in files_to_process:
        with open(filepath, 'r', errors='ignore') as f:
            # Parsing state is per file: a call or usage line is attributed to
            # the most recent 'IPRA: Function:' line seen in the same file.
            current_function = None
            current_mbb_count = 0
            for line in f:
                func_match = func_pattern.search(line)
                if func_match:
                    new_function = func_match.group(1).strip()
                    if new_function != current_function:
                        # A new function starts: forget the previous block count.
                        current_mbb_count = 0
                    current_function = new_function
                    all_functions.add(current_function)

                usage_match = usage_pattern.search(line)
                if usage_match and current_function:
                    regs_str = usage_match.group(1).strip()
                    num_regs = len(regs_str.split()) if regs_str else 0
                    callee_save_costs[current_function] = num_regs
                    is_hot_str = usage_match.group(2).strip()
                    function_hotness[current_function] = (int(is_hot_str) == 1)

                mbb_match = mbb_pattern.search(line)
                if mbb_match:
                    # Call lines that follow inherit this block's count.
                    current_mbb_count = int(mbb_match.group(1))

                call_match = call_pattern.search(line)
                if call_match and current_function:
                    callee_name = call_match.group(1).strip()
                    is_tail_call_str = call_match.group(2).strip()

                    all_functions.add(callee_name)
                    live_regs_str = call_match.group(3).strip()
                    num_live_regs = len(live_regs_str.split()) if live_regs_str else 0

                    callee_call_sites[callee_name].append({
                        "caller": current_function,
                        "live_csrs": num_live_regs,
                        "count": current_mbb_count,
                        "is_tail_call": (int(is_tail_call_str) == 1) # Store tail call info
                    })
                    successors[current_function].add(callee_name)
                    predecessors[callee_name].add(current_function)

    print(f"Found {len(all_functions)} unique functions in the call graph.")
    return callee_save_costs, callee_call_sites, successors, predecessors, all_functions, function_hotness

In [3]:
from collections import defaultdict, deque

def tarjan_scc(nodes, successors):
    """
    Tarjan's strongly-connected-components algorithm, implemented with an
    explicit work stack instead of recursion so that long call chains cannot
    hit Python's recursion limit (RecursionError).

    Args:
        nodes: iterable of node ids (e.g., function names)
        successors: dict[node] -> set[node] (caller -> callees)
    Returns:
        sccs: list[list[node]] where each inner list is one SCC (arbitrary order)
        comp_id: dict[node] -> int  component index for each node
    """
    next_index = 0
    idx = {}         # discovery index of each visited node
    low = {}         # smallest discovery index reachable from the node
    onstack = set()  # membership test for `stack`
    stack = []       # Tarjan's node stack
    sccs = []

    for root in nodes:
        if root in idx:
            continue

        idx[root] = low[root] = next_index
        next_index += 1
        stack.append(root)
        onstack.add(root)
        # Each frame is (node, iterator over its not-yet-examined successors);
        # keeping the iterator lets us resume where we left off after a child
        # has been fully explored.
        work = [(root, iter(successors.get(root, ())))]

        while work:
            v, pending = work[-1]
            descended = False
            for w in pending:
                if w not in idx:
                    # Tree edge: "recurse" into w by pushing a new frame.
                    idx[w] = low[w] = next_index
                    next_index += 1
                    stack.append(w)
                    onstack.add(w)
                    work.append((w, iter(successors.get(w, ()))))
                    descended = True
                    break
                elif w in onstack:
                    # Back/cross edge into the current DFS stack.
                    low[v] = min(low[v], idx[w])
            if descended:
                continue

            # All successors of v are done: this is the "return" from v.
            work.pop()

            # root of an SCC
            if low[v] == idx[v]:
                component = []
                while True:
                    w = stack.pop()
                    onstack.remove(w)
                    component.append(w)
                    if w == v:
                        break
                sccs.append(component)

            # Propagate v's low-link to its DFS parent, as the recursive form
            # does right after the child call returns.
            if work:
                parent = work[-1][0]
                low[parent] = min(low[parent], low[v])

    # component id map
    comp_id = {}
    for cid, comp in enumerate(sccs):
        for v in comp:
            comp_id[v] = cid
    return sccs, comp_id


def build_condensed_dag(nodes, successors, predecessors):
    """
    Contract each strongly connected component to one node and link the
    resulting component ids wherever an original edge crosses components.

    Args:
        nodes: iterable of original node ids; it is walked once by
            tarjan_scc and once more here.
        successors: dict[node] -> iterable of successor nodes.
        predecessors: accepted for interface symmetry; not read here.
    Returns:
        sccs: list[list[node]]  original nodes per component id
        comp_id: dict[node]->int
        dag_succ: dict[cid] -> set[cid]
        dag_pred: dict[cid] -> set[cid]
    """
    sccs, comp_id = tarjan_scc(nodes, successors)

    dag_succ, dag_pred = defaultdict(set), defaultdict(set)

    for src in nodes:
        src_comp = comp_id[src]
        for dst in successors.get(src, ()):
            dst_comp = comp_id[dst]
            if src_comp == dst_comp:
                # Edge inside one component: it disappears in the condensation.
                continue
            dag_succ[src_comp].add(dst_comp)
            dag_pred[dst_comp].add(src_comp)

    # Give every component a key in both maps, including isolated ones.
    for cid in range(len(sccs)):
        dag_succ.setdefault(cid, set())
        dag_pred.setdefault(cid, set())

    return sccs, comp_id, dag_succ, dag_pred


def scc_order_bottom_up(dag_succ, dag_pred):
    """
    Order the component ids so that a component is emitted only after all of
    its successors: sinks come first, then whatever becomes a sink once the
    already-emitted components are removed.

    Args:
        dag_succ: dict[cid] -> set[cid] of outgoing edges.
        dag_pred: dict[cid] -> set[cid] of incoming edges.
    Returns:
        order: list[int] of component ids in bottom-up order.
    """
    all_cids = set(dag_succ) | set(dag_pred)

    # Number of successors still waiting to be emitted, per component.
    outdeg = {}
    for cid in all_cids:
        outdeg[cid] = len(dag_succ.get(cid, ()))

    ready = deque(cid for cid in all_cids if outdeg[cid] == 0)
    order = []

    while ready:
        sink = ready.popleft()
        order.append(sink)
        for parent in dag_pred.get(sink, ()):
            outdeg[parent] -= 1
            if not outdeg[parent]:
                ready.append(parent)

    # A condensation is acyclic, so nothing should be left over; if something
    # is, tack it on at the end rather than dropping it.
    if len(order) != len(all_cids):
        emitted = set(order)
        order.extend(cid for cid in all_cids if cid not in emitted)
    return order


def function_order_bottom_up(nodes, successors, predecessors):
    """
    One-call helper that condenses the graph, orders the components bottom-up
    and returns everything computed along the way in a dict with keys:
      - "sccs": list of lists of original nodes,
      - "scc_bottom_up_order": component ids, bottom-up,
      - "grouped_functions_bottom_up": one list of nodes per component,
        in bottom-up component order,
      - "flat_functions_bottom_up": the same nodes as a single flat list,
      - "comp_id", "condensed_successors", "condensed_predecessors":
        the maps produced by build_condensed_dag.
    """
    sccs, comp_id, dag_succ, dag_pred = build_condensed_dag(nodes, successors, predecessors)
    scc_bottom_up = scc_order_bottom_up(dag_succ, dag_pred)

    grouped = []  # members of each component, kept together
    flat = []     # same members, concatenated in the same order
    for cid in scc_bottom_up:
        members = list(sccs[cid])
        grouped.append(members)
        flat.extend(members)

    return {
        "sccs": sccs,
        "scc_bottom_up_order": scc_bottom_up,
        "grouped_functions_bottom_up": grouped,
        "flat_functions_bottom_up": flat,
        "comp_id": comp_id,
        "condensed_successors": dag_succ,
        "condensed_predecessors": dag_pred,
    }


In [8]:
def calculate_benefits_bottom_up(callee_save_costs, callee_call_sites, successors, predecessors, all_functions, function_hotness, size_penalty, threshold, num_callee_saved_regs=6):
    """
    Calculates benefit scores using a bottom-up traversal of the call graph
    to model the cascading effects of the preserve_none optimization.

    Args:
        callee_save_costs: dict[function] -> callee-saved register count.
        callee_call_sites: dict[callee] -> list of call-site dicts with keys
            "caller", "live_csrs", "count" and optionally "is_tail_call".
        successors: dict[caller] -> set of callees.
        predecessors: dict[callee] -> set of callers. Functions without an
            entry are treated as having no callers; the map is not modified.
        all_functions: iterable of every function in the call graph.
        function_hotness: dict[function] -> bool; missing means not hot.
        size_penalty: weight applied to the static (code size) cost.
        threshold: a candidate is returned only if its score exceeds this.
        num_callee_saved_regs: number of callee-saved registers of the target;
            defaults to 6 (x86-64).

    Returns:
        dict[function] -> adjusted score, for every function that was selected
        as a candidate (score > 0) and whose score is also above `threshold`.
    """
    res = function_order_bottom_up(all_functions, successors, predecessors)
    sorted_nodes = res["flat_functions_bottom_up"]
    print(f"Topologically sorted {len(sorted_nodes)} functions for bottom-up processing.")

    final_candidates = set()
    # This dictionary simulates how a caller's own save cost might increase
    # as its callees become preserve_none.
    effective_cs_usage = defaultdict(int, callee_save_costs)
    final_scores = {}

    for callee in sorted_nodes:
        # Skip any non-hot function:
        if not function_hotness.get(callee, False):
            final_scores[callee] = float('-inf')
            continue

        # 1. Calculate the benefit for the current function using the most up-to-date
        #    cost information for itself and its callees.
        callee_cost = effective_cs_usage[callee]
        total_dynamic_benefit = 0
        sum_of_caller_static_costs = 0

        for site in callee_call_sites.get(callee, []):
            # Tail-call sites contribute nothing to either cost.
            if site.get("is_tail_call", False):
                continue
            # Per-site caller cost: registers that are NOT live across the call.
            caller_cost = num_callee_saved_regs - site["live_csrs"]

            exec_count = site["count"]
            total_dynamic_benefit += (callee_cost - caller_cost) * exec_count
            sum_of_caller_static_costs += caller_cost

        # Static cost: the factor 2 is applied to both the per-site caller
        # costs and the callee cost being removed.
        total_static_cost = (2 * sum_of_caller_static_costs) - (2 * callee_cost)
        adjusted_score = total_dynamic_benefit - (size_penalty * total_static_cost)
        final_scores[callee] = adjusted_score

        # 2. Make a decision for the current function.
        if adjusted_score > 0:
            final_candidates.add(callee)

            # 3. Propagate the cost of this decision upwards to its callers.
            #    We assume the cost pushed up is the original, static number of
            #    registers the callee was responsible for.
            original_callee_cost = callee_save_costs.get(callee, 0)
            # .get() instead of indexing: a plain dict must not raise KeyError
            # for functions nobody calls, and a defaultdict passed in by the
            # caller must not grow empty entries as a side effect.
            for caller in predecessors.get(callee, ()):
                # This simulates the increased register pressure on the caller.
                effective_cs_usage[caller] += original_callee_cost

    return {func: score for func, score in final_scores.items() if func in final_candidates and score > threshold}


In [9]:
# --- Algorithm 2 configuration -------------------------------------------
LIVENESS_DATA_DIR = './thinly_linked_fdo_liveness_output'
SIZE_PENALTY = 0.1            # weight of the static cost in the adjusted score
PRESERVE_NONE_THRESHOLD = 0   # minimum adjusted score for a reported candidate

# Parse the dumps, then score functions bottom-up over the call graph.
costs, sites, successors, predecessors, all_nodes, function_hotness = parse_files(LIVENESS_DATA_DIR)
candidate_scores = calculate_benefits_bottom_up(
    callee_save_costs=costs,
    callee_call_sites=sites,
    successors=successors,
    predecessors=predecessors,
    all_functions=all_nodes,
    function_hotness=function_hotness,
    size_penalty=SIZE_PENALTY,
    threshold=PRESERVE_NONE_THRESHOLD,
)

# Plain-dict snapshot of the scores; later cells filter and serialize it.
output_data = dict(candidate_scores)

print(f"Found {len(candidate_scores)} candidate functions meeting the threshold.")

Found 1686 profile files to process.
Found 82923 unique functions in the call graph.
Topologically sorted 82923 functions for bottom-up processing.
Found 2280 candidate functions meeting the threshold.


In [11]:
import os
import re
import json
from collections import defaultdict

LIVENESS_PRERA_DATA_DIR = './thinly_linked_fdo_liveness_output'
def filter_dangerous_functions(directory):
    """
    Read every 'ipra_prera_analysis_*.txt' file in `directory` and record,
    per function, the flag names printed on its 'IPRA: Function:' line.

    Args:
        directory: path of the directory holding the pre-RA analysis dumps.

    Returns:
        dict[function name] -> list of flag names. When a function appears on
        several lines, the last one read wins.
    """
    # name[compilation unit] followed by the rest of the line (the flags).
    main_pattern = re.compile(r"^IPRA: Function: (.+?)\[(.*?)\]\s*(.*)$")
    # NOTE(review): only the flag NAME is captured; the number after the colon
    # is not inspected, so 'Flag: 0' and 'Flag: 1' both count as present.
    # Confirm the dump only prints flags that are set.
    flag_pattern = re.compile(r"(\w+): \d+")

    parsed_functions = {}
    for entry in os.listdir(directory):
        if not (entry.startswith('ipra_prera_analysis_') and entry.endswith('.txt')):
            continue
        with open(os.path.join(directory, entry), 'r', errors='ignore') as f:
            for line in f:
                main_match = main_pattern.match(line.strip())
                if main_match is None:
                    continue
                func_name, _cu_name, flags_string = main_match.groups()
                parsed_functions[func_name] = flag_pattern.findall(flags_string)
    return parsed_functions


parsed_functions = filter_dangerous_functions(LIVENESS_PRERA_DATA_DIR)

# Any candidate carrying at least one of these flags is dropped.
discarded_flags = {'HasAddressTaken', 'MustTailCall', 'IsInterposable', 'UsesAreIndirectCall', 'AllUsesAreNotCall'}
discarded_functions = set()
discarded_counts = defaultdict(int)  # flag name -> number of discarded candidates carrying it

for func in output_data:
    # Candidates with no pre-RA record are kept as they are.
    if func not in parsed_functions or func in discarded_functions:
        continue
    flags = set(parsed_functions[func])
    intersection = flags.intersection(discarded_flags)
    if intersection:
        for flag in intersection:
            discarded_counts[flag] += 1
        discarded_functions.add(func)

print(f"found {len(discarded_functions)} to discard")
print(discarded_counts)

found 305 to discard
defaultdict(<class 'int'>, {'AllUsesAreNotCall': 300, 'HasAddressTaken': 305, 'UsesAreIndirectCall': 1})


In [12]:
# Drop every flagged function from the candidate map.
before_len = len(output_data)
for func in discarded_functions:
    # pop() with a default keeps this cell safe to re-run: on a second run the
    # entries are already gone and a bare pop() would raise KeyError.
    output_data.pop(func, None)
after_len = len(output_data)
print(f"Filtered out {before_len - after_len} functions. Now total {len(output_data)} functions left.")

Filtered out 305 functions. Now total 1975 functions left.


## Save output

In [14]:
# Serialize the surviving candidates as {"functions": {name: score}}.
OUTPUT_FILE = f'{LIVENESS_DATA_DIR}/liveness_profdata.json'
output_dict = {"functions": output_data}
with open(OUTPUT_FILE, 'w') as f:
    f.write(json.dumps(output_dict, indent=2))

print(f"\n✅ Successfully merged profile data into '{OUTPUT_FILE}'")


✅ Successfully merged profile data into './thinly_linked_fdo_liveness_output/liveness_profdata.json'


In [15]:
# For pn.syms: one surviving function name per line.
function_names = list(output_data.keys())
PN_SYMS_OUTPUT_PATH = f'{LIVENESS_DATA_DIR}/pn.syms'
with open(PN_SYMS_OUTPUT_PATH, 'w') as f:
    for name in function_names:
        f.write(f"{name}\n")
print(f"Successfully created '{PN_SYMS_OUTPUT_PATH}' with {len(function_names)} function symbols.")

Successfully created './thinly_linked_fdo_liveness_output/pn.syms' with 1975 function symbols.
