In [1]:
import os
import json

# --- Notebook inputs: edit these values to point the table at a different run ---
# The three values are joined, in this order, into the directory that holds the
# per-client result files: <evaluation_dir>/<exp_name>/<communication_rounds>/
evaluation_dir = "evaluations_llm"
exp_name = "none-1B"
# Kept as a string because it is used verbatim as a path component.
communication_rounds = "20"


In [2]:
def generate_latex_table(evaluation_dir, exp_name, communication_rounds):
    """Build a LaTeX ``tabular`` of per-client evaluation ratios by test domain.

    For each client index ``i`` in ``0..7`` the file
    ``<evaluation_dir>/<exp_name>/<communication_rounds>/client_<i>_results.json``
    is read and ``data['categories'][<domain>]['ratio']`` is collected for each
    of the eight domains. Row ``i`` is labelled with the i-th domain
    abbreviation (presumably the domain client ``i`` was trained on -- confirm
    against the training setup).

    The lowest value in every test-domain column and the lowest row average are
    wrapped in ``\\textbf{}``. Values that could not be loaded are rendered as
    ``--`` and are ignored when computing averages and minima.

    Args:
        evaluation_dir: Root directory containing evaluation outputs.
        exp_name: Experiment sub-directory name.
        communication_rounds: Round sub-directory name (a string path component).

    Returns:
        The table as a single newline-joined string. The output uses
        ``\\multirow``, so the consuming document needs the ``multirow`` package.
    """
    # Map domain abbreviations to JSON category keys
    domains = {
        'BS': 'brainstorming',
        'CL': 'classification',
        'CQ': 'closed_qa',
        'CW': 'creative_writing',
        'GQ': 'general_qa',
        'IE': 'information_extraction',
        'OQ': 'open_qa',
        'SM': 'summarization'
    }
    abbrs = list(domains)
    n = len(abbrs)
    missing = "--"

    def _mean(values):
        """Average of the non-missing values, or None when there are none."""
        present = [v for v in values if v is not None]
        return sum(present) / len(present) if present else None

    def _lowest(values):
        """Smallest non-missing value, or None when there are none."""
        present = [v for v in values if v is not None]
        return min(present) if present else None

    def _cell(value, best=None):
        """Format one number; bold it when it equals the ``best`` reference."""
        if value is None:
            return missing
        text = f"{value:.3f}"
        if best is not None and value == best:
            text = r"\textbf{" + text + "}"
        return text

    # Load results into an n x n matrix (rows: clients, columns: test domains)
    results = [[None] * n for _ in range(n)]
    base_path = os.path.join(evaluation_dir, exp_name, communication_rounds)
    for i in range(n):
        filepath = os.path.join(base_path, f"client_{i}_results.json")
        try:
            # Context manager so the handle is closed even if parsing fails.
            with open(filepath, encoding="utf-8") as fh:
                data = json.load(fh)
            for j, abbr in enumerate(abbrs):
                key = domains[abbr]
                if key in data['categories']:
                    results[i][j] = data['categories'][key]['ratio']
        except (OSError, ValueError, KeyError, TypeError) as e:
            # Best effort: a missing or malformed client file leaves its
            # remaining cells empty instead of aborting the whole table.
            print(f"Warning for client {i}: {e}")

    # Row / column averages; the overall figure is the mean of the row averages.
    row_avgs = [_mean(row) for row in results]
    col_avgs = [_mean(results[i][j] for i in range(n)) for j in range(n)]
    overall_avg = _mean(row_avgs)

    # Minima used for bolding -- missing entries must not count as zero.
    col_min = [_lowest(results[i][j] for i in range(n)) for j in range(n)]
    min_row_avg = _lowest(row_avgs)

    # Build LaTeX table lines with the fixed header layout
    lines = [
        r"\begin{tabular}{l l|cccccccc|c}",
        r"\hline",
        r"\textbf{Method} & \textbf{Training} & \multicolumn{8}{c|}{\textbf{Test Domain}} & \textbf{Avg.} \\",
        r" & \textbf{Domain} & \textbf{BS} & \textbf{CL} & \textbf{CQ} & \textbf{CW} & \textbf{GQ} & \textbf{IE} & \textbf{OQ} & \textbf{SM} & \\",
        r"\hline",
        r"\multirow{%d}{*}{Individual}" % n,
    ]

    # Data rows: empty "Method" cell (covered by the multirow), label, values, average
    for i in range(n):
        row = ["", abbrs[i]]
        row.extend(_cell(results[i][j], col_min[j]) for j in range(n))
        row.append(_cell(row_avgs[i], min_row_avg))
        lines.append(" & ".join(row) + r" \\")

    # Column-average row with multicolumn formatting
    lines.append(r"\hline")
    avg_cells = [_cell(v) for v in (col_avgs + [overall_avg])]
    lines.append(
        r"\multicolumn{2}{c|}{\textbf{Average}} & "
        + " & ".join(avg_cells)
        + r" \\"
    )

    lines.append(r"\hline")
    lines.append(r"\end{tabular}")
    return "\n".join(lines)


In [3]:
# Render the table for the configured run and print the raw LaTeX so it can be
# copied straight into a document (print keeps the backslashes unescaped).
latex_code = generate_latex_table(
    evaluation_dir=evaluation_dir,
    exp_name=exp_name,
    communication_rounds=communication_rounds,
)
print(latex_code)


\begin{tabular}{l l|cccccccc|c}
\hline
\textbf{Method} & \textbf{Training} & \multicolumn{8}{c|}{\textbf{Test Domain}} & \textbf{Avg.} \\
 & \textbf{Domain} & \textbf{BS} & \textbf{CL} & \textbf{CQ} & \textbf{CW} & \textbf{GQ} & \textbf{IE} & \textbf{OQ} & \textbf{SM} & \\
\hline
\multirow{8}{*}{Individual}
 & BS & 1.072 & 1.265 & 1.248 & 1.241 & 1.166 & 1.491 & 1.156 & 1.314 & 1.244 \\
 & CL & 1.253 & 1.311 & 1.192 & 1.338 & 1.142 & 1.425 & 1.195 & 1.240 & 1.262 \\
 & CQ & 1.316 & 1.460 & 1.136 & 1.337 & 1.241 & 1.465 & 1.315 & 1.196 & 1.308 \\
 & CW & \textbf{0.957} & \textbf{1.227} & 1.305 & \textbf{1.065} & \textbf{0.991} & 1.495 & 1.019 & 1.360 & \textbf{1.177} \\
 & GQ & 1.063 & 1.456 & 1.269 & 1.183 & 1.059 & 1.444 & \textbf{1.011} & 1.320 & 1.226 \\
 & IE & 1.206 & 1.533 & \textbf{1.088} & 1.243 & 1.166 & 1.384 & 1.303 & 1.182 & 1.263 \\
 & OQ & 1.137 & 1.563 & 1.223 & 1.265 & 1.203 & 1.420 & 1.277 & 1.299 & 1.298 \\
 & SM & 1.137 & 1.436 & 1.143 & 1.283 & 1.172 & \textbf{1.252