import json
import sys
import gzip
import os
from collections import defaultdict

def _select_thread(profile_data):
    """Return the first thread whose ``samples.stack`` column is non-empty.

    Raises:
        ValueError: if the profile has no threads, or no thread has samples.
    """
    threads = profile_data.get('threads')
    if not threads:
        raise ValueError("Profile data contains no threads.")
    for thread in threads:
        samples = thread.get('samples')
        if samples and samples.get('stack'):
            return thread
    raise ValueError("No thread with samples found in the profile.")


def _resolve_string_table(profile_data, thread):
    """Return the list used to turn string indices into names.

    Prefers the profile-wide ``shared.stringArray`` and falls back to the
    thread's own ``stringArray`` or ``stringTable``.

    Raises:
        ValueError: if none of those tables is present.
    """
    shared = profile_data.get('shared') or {}
    if 'stringArray' in shared:
        return shared['stringArray']
    for key in ('stringArray', 'stringTable'):
        if key in thread:
            return thread[key]
    raise ValueError("No string table found in profile or thread")


def _aggregate_samples(thread, string_table):
    """Fold every sample of *thread* into self costs and direct-call costs.

    A function is identified by (name, file name); a null ``fileName`` is
    reported as "Unknown File". The line number recorded for a function is
    the one of the first funcTable row seen with that identity.

    Returns:
        ``(files, funcs, self_cost, edges)``:
        files maps file id -> file name;
        funcs maps function id -> {'name', 'file_id', 'line'} (first-seen order);
        self_cost maps function id -> summed weight of samples whose leaf it is;
        edges maps caller id -> {callee id: summed weight of samples in which
        the caller directly calls the callee}.

    Raises:
        KeyError: if a required table/column is missing from *thread*.
    """
    stack_table = thread['stackTable']
    frame_table = thread['frameTable']
    func_table = thread['funcTable']
    samples = thread['samples']

    file_ids = {}       # file name -> file id (1-based, first-seen order)
    func_ids = {}       # (function name, file id) -> function id
    funcs = {}          # function id -> {'name', 'file_id', 'line'}
    by_func_index = {}  # funcTable row -> function id, so each row is resolved once

    def func_id_for(func_index):
        cached = by_func_index.get(func_index)
        if cached is not None:
            return cached
        func_name = string_table[func_table['name'][func_index]]
        file_name_index = func_table['fileName'][func_index]
        file_name = string_table[file_name_index] if file_name_index is not None else "Unknown File"
        file_id = file_ids.setdefault(file_name, len(file_ids) + 1)
        key = (func_name, file_id)
        func_id = func_ids.get(key)
        if func_id is None:
            func_id = len(func_ids) + 1
            func_ids[key] = func_id
            funcs[func_id] = {
                'name': func_name,
                'file_id': file_id,
                'line': func_table['lineNumber'][func_index],
            }
        by_func_index[func_index] = func_id
        return func_id

    # ``weight`` may be absent or explicitly null; such samples count as 1.
    weights = samples.get('weight')
    self_cost = defaultdict(float)
    edges = defaultdict(lambda: defaultdict(float))

    for sample_idx, stack_index in enumerate(samples['stack']):
        if stack_index is None:
            continue
        weight = 1
        if weights is not None and weights[sample_idx] is not None:
            weight = weights[sample_idx]

        # Follow prefix links from the leaf frame up to the root, then flip
        # the list so it reads root -> leaf. It is never empty here because
        # stack_index is not None.
        path = []
        curr = stack_index
        while curr is not None:
            frame_index = stack_table['frame'][curr]
            path.append(func_id_for(frame_table['func'][frame_index]))
            curr = stack_table['prefix'][curr]
        path.reverse()

        self_cost[path[-1]] += weight
        # Each adjacent pair is a direct call. Its summed weight is exactly the
        # inclusive cost spent under that call; a pair that recurs in one
        # stack is counted once per occurrence.
        for caller_id, callee_id in zip(path, path[1:]):
            edges[caller_id][callee_id] += weight

    files = {file_id: name for name, file_id in file_ids.items()}
    return files, funcs, self_cost, edges


def _render_callgrind(cmd_name, files, funcs, self_cost, edges):
    """Format aggregated costs as callgrind text (one block per function).

    Compressed names are written as ``(id) name`` at their first use, inside
    the fl/cfl context they belong to, and as ``(id)`` afterwards. Costs are
    rounded to integers because callgrind costs are integral.
    """
    output = [
        "# callgrind format",
        "version: 1",
        f"cmd: {cmd_name}",
        "events: Samples",
        "",
    ]
    defined_files = set()
    defined_funcs = set()

    def file_ref(file_id):
        if file_id in defined_files:
            return f"({file_id})"
        defined_files.add(file_id)
        return f"({file_id}) {files[file_id]}"

    def func_ref(func_id):
        if func_id in defined_funcs:
            return f"({func_id})"
        defined_funcs.add(func_id)
        return f"({func_id}) {funcs[func_id]['name']}"

    for func_id, details in funcs.items():
        # No per-call-site line is tracked, so both self cost and outgoing
        # calls sit on the function's own line (1 when unknown).
        line = details['line'] or 1
        output.append(f"fl={file_ref(details['file_id'])}")
        output.append(f"fn={func_ref(func_id)}")
        output.append(f"{line} {round(self_cost.get(func_id, 0))}")

        for callee_id, cost in edges.get(func_id, {}).items():
            callee = funcs[callee_id]
            # cfl is only needed when the callee lives in a different file.
            if callee['file_id'] != details['file_id']:
                output.append(f"cfl={file_ref(callee['file_id'])}")
            output.append(f"cfn={func_ref(callee_id)}")
            edge_cost = round(cost)
            output.append(f"calls={edge_cost} {callee['line'] or 1}")
            output.append(f"{line} {edge_cost}")
        output.append("")

    return "\n".join(output)


def fxprof_to_callgrind(profile_data):
    """
    Convert Firefox Profiler data (parsed JSON) to callgrind format,
    including file and line number information.

    Only the first thread that has samples is converted. Each sample adds its
    weight (1 when weights are absent or null) to the self cost of its leaf
    function and to every direct caller->callee call on its stack.

    Args:
        profile_data: the decoded profile JSON object.

    Returns:
        The callgrind file contents as one newline-separated string.

    Raises:
        ValueError: if there are no threads, no thread with samples, or no
            string table.
        KeyError: if the selected thread lacks a required table.
    """
    thread = _select_thread(profile_data)
    print(f"Processing thread: {thread.get('name', 'Unnamed Thread')} (PID: {thread.get('pid')}, TID: {thread.get('tid')})", file=sys.stderr)

    string_table = _resolve_string_table(profile_data, thread)
    files, funcs, self_cost, edges = _aggregate_samples(thread, string_table)
    cmd_name = profile_data.get('meta', {}).get('product', 'UnknownApp')
    return _render_callgrind(cmd_name, files, funcs, self_cost, edges)

def output_path_for(input_filepath):
    """Return the ``.capture`` path written next to *input_filepath*.

    Only a trailing ``.json.gz`` or ``.json`` suffix is replaced. Any other
    name gets ``.capture`` appended, so the input file is never overwritten.
    """
    for suffix in ('.json.gz', '.json'):
        if input_filepath.endswith(suffix):
            return input_filepath[:-len(suffix)] + '.capture'
    return input_filepath + '.capture'


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python render_profile.py <path_to_profile.json|.json.gz>", file=sys.stderr)
        sys.exit(1)

    input_filepath = sys.argv[1]
    output_filepath = output_path_for(input_filepath)

    try:
        # Gzip is detected from the extension. JSON text is UTF-8, so decode
        # it explicitly rather than relying on the platform default.
        opener = gzip.open if input_filepath.endswith('.json.gz') else open
        with opener(input_filepath, 'rt', encoding='utf-8') as f:
            profile = json.load(f)

        callgrind_output = fxprof_to_callgrind(profile)
        # Function and file names may be non-ASCII; write UTF-8 explicitly.
        with open(output_filepath, 'w', encoding='utf-8') as f:
            f.write(callgrind_output)
        print(f"Saved callgrind file to {output_filepath}")

    except FileNotFoundError:
        print(f"Error: Input file not found at '{input_filepath}'", file=sys.stderr)
        sys.exit(1)
    except (json.JSONDecodeError, ValueError, KeyError) as e:
        # KeyError: a required profile table is missing (malformed profile).
        print(f"Error processing profile file: {e}", file=sys.stderr)
        sys.exit(1)
    except OSError as e:
        # e.g. a corrupt gzip stream or a permission error on read/write.
        print(f"Error reading or writing file: {e}", file=sys.stderr)
        sys.exit(1)
