In [1]:
import os
# Switch the working directory to the parent folder; the relative path
# 'results_nlp.json' opened later is resolved against this directory.
os.chdir('..')

In [3]:
import json
import numpy as np
from scipy.stats import hmean
from collections import defaultdict

# Load the results from the JSON file (path is relative to the current working directory)
with open('results_nlp.json') as f:
    results_dict = json.load(f)

# Define the models of interest and their corresponding baselines with labels.
# Each key is a model ID that is looked up in results_dict; each value is a
# (baseline ID, display label) pair, where the baseline ID is likewise a
# results_dict key and the label is the name printed in the tables.
models_of_interest = {
    "stage-final-llava-v15-pythia+1p4b": ("reproduction-align-pythia+1p4b", "LLaVA + Pythia (1.4B)"),
    "stage-final-llava-v15-pythia+1p4b-instruct": ("reproduction-align-pythia+1p4b-instruct", "LLaVA + Pythia Instruct (1.4B)"),
    "reproduction-llava-v15+7b+stage-finetune+x7": ("reproduction-llava-v15+7b+stage-align+x7", "LLaVA + LLaMA2 Instruct (7B)"),
    "reproduction-llama2": ("vila_base_llm", "LLaVA + LLaMA2 Base (7B)")
}

# Function to format values as percentages, or return "-" when the value is missing
def format_value(value):
    """Render a fractional score as a percentage string with two decimals.

    Args:
        value: Score as a fraction (for example 0.4397), or NaN / None when
            no score is available.

    Returns:
        ``value * 100`` formatted with ``"{:.2f}"``, or ``"-"`` when the
        value is NaN or None.
    """
    # np.isnan(None) raises TypeError, so None (e.g. a JSON null) is checked first.
    if value is None or np.isnan(value):
        return "-"
    return "{:.2f}".format(value * 100)

# Instruction-tuned models are recognised by the word "Instruct" in their display label
def is_instruction_fine_tuned(label):
    """Return True when ``label`` contains the substring "Instruct", else False."""
    if "Instruct" in label:
        return True
    return False

# Collect one summary row per model for the LaTeX table
table_data = []

# Benchmark groupings used for the aggregate scores
vl_tasks = ["vqa-v2", "textvqa-ocr", "textvqa-pure", "gqa"]
nlu_tasks = ["wsc273", "winogrande", "arc_easy", "arc_challenge"]
nlg_task = "lambada_standard"
nl_tasks = nlu_tasks + [nlg_task]


def _hmean_over(scores, tasks):
    """Return the harmonic mean of ``scores[task]`` over the given task names."""
    return hmean([scores[task] for task in tasks])


for model, (baseline, label) in models_of_interest.items():
    accuracies = results_dict[model]
    baseline_accuracies = results_dict[baseline]

    # Aggregates for the model itself
    avg_acc_vl = _hmean_over(accuracies, vl_tasks)
    avg_acc_nl = _hmean_over(accuracies, nl_tasks)
    avg_acc_nlu = _hmean_over(accuracies, nlu_tasks)
    avg_acc_nlg = accuracies[nlg_task]

    # The same aggregates for its baseline
    baseline_avg_nlu = _hmean_over(baseline_accuracies, nlu_tasks)
    baseline_avg_nlg = baseline_accuracies[nlg_task]
    baseline_avg_acc_nl = _hmean_over(baseline_accuracies, nl_tasks)

    # Deltas are baseline minus model, so a positive value means the model
    # scores below its baseline.
    # NOTE(review): the summary table header marks Delta with an up arrow --
    # confirm that this sign convention is the intended one.
    delta_nl = baseline_avg_acc_nl - avg_acc_nl
    delta_nlu = baseline_avg_nlu - avg_acc_nlu
    delta_nlg = baseline_avg_nlg - avg_acc_nlg

    summary = ", ".join([
        f"Avg. VL = {avg_acc_vl*100:.2f}",
        f"Avg. NL = {avg_acc_nl*100:.2f}",
        f"Delta NL = {delta_nl*100:.2f}",
        f"Delta NLU = {delta_nlu*100:.2f}",
        f"Delta NLG = {delta_nlg*100:.2f}",
    ])
    print(f"Model {label}: {summary}")

    table_data.append((label, is_instruction_fine_tuned(label), avg_acc_vl, avg_acc_nl, delta_nl))

# Sort rows by Avg. VL accuracy, highest first. Rows that tie on VL accuracy
# are ordered by NL delta, smallest first: the delta is negated in the key and
# reverse=True then flips both components of the key.
table_data = sorted(table_data, key=lambda x: (x[2], -x[4]), reverse=True)

# Assemble the summary LaTeX table: fixed preamble, one row per model, fixed closing
table_header = """
\\begin{table*}[h]
  \\caption{\\textbf{LLaVA Model Performance}}
  \\label{tab:model_performance}
  \\centering
  \\resizebox{\\linewidth}{!}{
    \\begin{tabular}{l|c|c|cc}
     \\toprule
     \\textbf{Model} & \\textbf{Instr.} & \\multicolumn{1}{c|}{\\textbf{VL Avg.}} & \\multicolumn{2}{c}{\\textbf{NL Avg.}} \\\\
     & & \\textbf{Acc $\\uparrow$} & \\textbf{Acc $\\uparrow$} & \\textbf{Delta $\\uparrow$} \\\\
     \\midrule
"""

table_footer = """
     \\bottomrule
    \\end{tabular}
  }
\\end{table*}
"""

body_rows = []
for row_label, row_instr, row_vl, row_nl, row_delta in table_data:
    # Check mark for instruction-tuned models, cross otherwise
    instr_mark = "\\ding{51}" if row_instr else "\\ding{55}"
    cells = [
        row_label,
        instr_mark,
        format_value(row_vl),
        format_value(row_nl),
        format_value(row_delta),
    ]
    body_rows.append(" & ".join(cells) + " \\\\\n")

latex_code = table_header + "".join(body_rows) + table_footer

print(latex_code)

Model LLaVA + Pythia (1.4B): Avg. VL = 43.97, Avg. NL = 45.51, Delta NL = 2.18, Delta NLU = 0.55, Delta NLG = 8.07
Model LLaVA + Pythia Instruct (1.4B): Avg. VL = 43.93, Avg. NL = 41.37, Delta NL = -1.16, Delta NLU = -1.20, Delta NLG = -1.01
Model LLaVA + LLaMA2 Instruct (7B): Avg. VL = 56.55, Avg. NL = 64.44, Delta NL = -0.36, Delta NLU = -0.98, Delta NLG = 2.04
Model LLaVA + LLaMA2 Base (7B): Avg. VL = 57.22, Avg. NL = 66.23, Delta NL = -1.84, Delta NLU = -2.15, Delta NLG = -0.43

\begin{table*}[h]
  \caption{\textbf{LLaVA Model Performance}}
  \label{tab:model_performance}
  \centering
  \resizebox{\linewidth}{!}{
    \begin{tabular}{l|c|c|cc}
     \toprule
     \textbf{Model} & \textbf{Instr.} & \multicolumn{1}{c|}{\textbf{VL Avg.}} & \multicolumn{2}{c}{\textbf{NL Avg.}} \\
     & & \textbf{Acc $\uparrow$} & \textbf{Acc $\uparrow$} & \textbf{Delta $\uparrow$} \\
     \midrule
LLaVA + LLaMA2 Base (7B) & \ding{55} & 57.22 & 66.23 & -1.84 \\
LLaVA + LLaMA2 Instruct (7B) & \ding{51} & 

In [4]:
# --- INSERT *AFTER* loading results_dict and defining models_of_interest, format_value, is_instruction_fine_tuned ---

# Define VL and NL splits and human-readable labels.
# The labels use plain ASCII hyphens: they are printed into LaTeX source, and
# the U+2011 non-breaking hyphen is not accepted by every pdfLaTeX/inputenc setup.
vl_cols = ["vqa-v2", "textvqa-ocr", "textvqa-pure", "gqa"]
vl_labels = {"vqa-v2": "VQA", "textvqa-ocr": "TextVQA-OCR", "textvqa-pure": "TextVQA-Pure", "gqa": "GQA"}

nl_cols = ["wsc273", "winogrande", "arc_easy", "arc_challenge", "lambada_standard"]
nl_labels = {"wsc273": "WSC273", "winogrande": "Winogrande", "arc_easy": "ARC-E", "arc_challenge": "ARC-C", "lambada_standard": "Lambada"}

# Print one LaTeX accuracy table covering the given benchmark columns
def emit_table(split_cols, split_labels, caption, label):
    """Print a LaTeX ``table*`` of per-benchmark accuracies, one row per model.

    Rows follow the iteration order of the module-level ``models_of_interest``.
    Scores are read from ``results_dict`` and rendered with ``format_value``;
    a benchmark that is missing for a model is passed on as NaN.

    Args:
        split_cols: Benchmark keys, in column order.
        split_labels: Mapping from benchmark key to its column header text.
        caption: LaTeX caption text.
        label: LaTeX label of the table.
    """
    # Float opening and resizebox wrapper
    print(r"\begin{table*}[h]")
    print(f"  \\caption{{{caption}}}")
    print(f"  \\label{{{label}}}")
    print(r"  \centering")
    print(r"  \resizebox{\linewidth}{!}{%")

    # Column spec: left-aligned model name, then one centred column per benchmark
    column_spec = "l|" + "c" * len(split_cols)
    print(f"    \\begin{{tabular}}{{{column_spec}}}")
    print(r"      \toprule")

    # Bold header cells
    header_cells = [f"\\textbf{{{split_labels[col]}}}" for col in split_cols]
    print(r"      \textbf{Model} & " + " & ".join(header_cells) + r" \\")
    print(r"      \midrule")

    # One data row per model of interest (the baseline ID is not needed here)
    for model_key, (_baseline, display_name) in models_of_interest.items():
        scores = results_dict[model_key]
        cells = [display_name]
        cells.extend(format_value(scores.get(col, np.nan)) for col in split_cols)
        print("      " + " & ".join(cells) + r" \\")

    # Close tabular, resizebox and float
    print(r"      \bottomrule")
    print(r"    \end{tabular}%")
    print(r"}")
    print(r"\end{table*}")
    print()

# Emit VL accuracy table.
# Captions use plain ASCII hyphens so the emitted LaTeX does not depend on
# engine support for the U+2011 non-breaking hyphen.
emit_table(
    vl_cols,
    vl_labels,
    caption=r"\textbf{Vision-Language Task Accuracies} -- post-training on each model.",
    label="tab:vl_accuracies"
)

# Emit NL accuracy table
emit_table(
    nl_cols,
    nl_labels,
    caption=r"\textbf{Natural Language Task Accuracies} -- post-training on each model.",
    label="tab:nl_accuracies"
)

\begin{table*}[h]
  \caption{\textbf{Vision‑Language Task Accuracies} -- post‑training on each model.}
  \label{tab:vl_accuracies}
  \centering
  \resizebox{\linewidth}{!}{%
    \begin{tabular}{l|cccc}
      \toprule
      \textbf{Model} & \textbf{VQA} & \textbf{TextVQA‑OCR} & \textbf{TextVQA‑Pure} & \textbf{GQA} \\
      \midrule
      LLaVA + Pythia (1.4B) & 66.17 & 38.49 & 35.49 & 46.09 \\
      LLaVA + Pythia Instruct (1.4B) & 66.46 & 39.12 & 34.35 & 46.88 \\
      LLaVA + LLaMA2 Instruct (7B) & 74.50 & 56.28 & 45.95 & 56.25 \\
      LLaVA + LLaMA2 Base (7B) & 75.88 & 55.21 & 45.43 & 60.25 \\
      \bottomrule
    \end{tabular}%
}
\end{table*}

\begin{table*}[h]
  \caption{\textbf{Natural Language Task Accuracies} -- post‑training on each model.}
  \label{tab:nl_accuracies}
  \centering
  \resizebox{\linewidth}{!}{%
    \begin{tabular}{l|ccccc}
      \toprule
      \textbf{Model} & \textbf{WSC273} & \textbf{Winogrande} & \textbf{ARC‑E} & \textbf{ARC‑C} & \textbf{Lambada} \\
      \

In [6]:
# --- INSERT *AFTER* defining results_dict, models_of_interest, format_value, is_instruction_fine_tuned ---

import numpy as np

# NL tasks & labels.
# Labels use plain ASCII hyphens: they are printed into LaTeX source, and the
# U+2011 non-breaking hyphen is not accepted by every pdfLaTeX/inputenc setup.
nl_cols   = ["wsc273", "winogrande", "arc_easy", "arc_challenge", "lambada_standard"]
nl_labels = {
    "wsc273":           "WSC273",
    "winogrande":       "Winogrande",
    "arc_easy":         "ARC-E",
    "arc_challenge":    "ARC-C",
    "lambada_standard": "Lambada"
}

# Build baseline LLM entries and LLaVA MLLM entries.
# Both lists hold (short label, instruct flag, results_dict key) tuples and
# differ only in the key: the baseline ID versus the model ID.
baseline_entries = []
ml_entries = []
for model_key, (baseline_id, label) in models_of_interest.items():
    # Strip the "LLaVA + " prefix and the " Base" marker, so that
    # "LLaMA2 Base (7B)" becomes "LLaMA2 (7B)"
    short_label = label.replace("LLaVA + ", "").replace(" Base", "")
    instruct_flag = is_instruction_fine_tuned(label)

    baseline_entries.append((short_label, instruct_flag, baseline_id))
    ml_entries.append((short_label, instruct_flag, model_key))

# Drop repeated items, keeping the first occurrence of each in its original position
def uniq(seq):
    """Return a list of the distinct items of ``seq`` in first-seen order.

    Items must be hashable.
    """
    # dict keys are unique and keep insertion order
    return list(dict.fromkeys(seq))

# Remove duplicate entries from both lists
baseline_entries = uniq(baseline_entries)
ml_entries = uniq(ml_entries)

# Group keys (e.g. "Pythia (1.4B)", "LLaMA2 (7B)"): the baseline labels with
# " Instruct" removed, de-duplicated in first-seen order
group_keys = list(dict.fromkeys(
    entry_label.replace(" Instruct", "")
    for entry_label, _, _ in baseline_entries
))

# Print the NL comparison table: a "Base LLMs" section followed by a
# "LLaVA MLLMs" section, each grouped by model family.
def _print_family_rows(entries, families, rule_after_last):
    """Print the grouped data rows of one table section.

    Args:
        entries: (label, instruct_flag, results_dict key) tuples.
        families: Family names in display order. An entry belongs to a family
            when its label, with "LLaVA + " and " Instruct" removed, equals
            the family name.
        rule_after_last: Whether a midrule is printed after the final family.
            Pass False for the last section so the table does not end with a
            midrule directly followed by the bottomrule.
    """
    last_pos = len(families) - 1
    for pos, family in enumerate(families):
        members = [
            (flag, rid)
            for lbl, flag, rid in entries
            if lbl.replace("LLaVA + ", "").replace(" Instruct", "") == family
        ]
        for i, (flag, rid) in enumerate(members):
            tick = r"\ding{51}" if flag else r"\ding{55}"
            vals = [format_value(results_dict[rid].get(c, np.nan)) for c in nl_cols]
            # Only the first row of a family carries the multirow name cell;
            # the following rows leave the first column empty.
            name_cell = f"\\multirow{{{len(members)}}}{{*}}{{{family}}} " if i == 0 else ""
            print(f"      {name_cell}& {tick} & " + " & ".join(vals) + r" \\")
        if rule_after_last or pos < last_pos:
            print(r"      \midrule")


# Model name + Instr. flag + one column per NL task
n_table_cols = len(nl_cols) + 2

# Begin LaTeX table
print(r"\begin{table*}[h]")
print(r"  \centering")
print(r"  \caption{\textbf{Natural Language Task Accuracies: Base LLMs vs.\ LLaVA}}")
print(r"  \label{tab:nl_base_vs_mllm}")
print(r"  \resizebox{\linewidth}{!}{%")
print(r"    \begin{tabular}{l|c|" + "c" * len(nl_cols) + r"}")
print(r"      \toprule")
print(r"      \textbf{Model} & \textbf{Instr.} & "
      + " & ".join([r"\textbf{" + nl_labels[c] + r"}" for c in nl_cols])
      + r" \\")
print(r"      \midrule")

# Base LLMs block (the trailing rule separates it from the next section header)
print(r"      \multicolumn{" + str(n_table_cols)
      + r"}{l}{\textbf{Base LLMs}} \\")
print(r"      \midrule")
_print_family_rows(baseline_entries, group_keys, rule_after_last=True)

# LLaVA MLLMs block (no trailing rule: the bottomrule follows directly)
print(r"      \multicolumn{" + str(n_table_cols)
      + r"}{l}{\textbf{LLaVA MLLMs}} \\")
print(r"      \midrule")
_print_family_rows(ml_entries, group_keys, rule_after_last=False)

print(r"      \bottomrule")
print(r"    \end{tabular}%")
print(r"  }")
print(r"\end{table*}")

\begin{table*}[h]
  \centering
  \caption{\textbf{Natural Language Task Accuracies: Base LLMs vs.\ LLaVA}}
  \label{tab:nl_base_vs_mllm}
  \resizebox{\linewidth}{!}{%
    \begin{tabular}{l|c|ccccc}
      \toprule
      \textbf{Model} & \textbf{Instr.} & \textbf{WSC273} & \textbf{Winogrande} & \textbf{ARC‑E} & \textbf{ARC‑C} & \textbf{Lambada} \\
      \midrule
      \multicolumn{7}{l}{\textbf{Base LLMs}} \\
      \midrule
      \multirow{2}{*}{Pythia (1.4B)} & \ding{55} & 70.70 & 56.51 & 61.74 & 27.47 & 48.98 \\
      & \ding{51} & 59.34 & 50.43 & 48.23 & 26.79 & 33.79 \\
      \midrule
      \multirow{2}{*}{LLaMA2 (7B)} & \ding{51} & 85.35 & 69.53 & 75.63 & 43.17 & 64.35 \\
      & \ding{55} & 80.59 & 69.22 & 76.26 & 43.43 & 68.27 \\
      \midrule
      \multicolumn{7}{l}{\textbf{LLaVA MLLMs}} \\
      \midrule
      \multirow{2}{*}{Pythia (1.4B)} & \ding{55} & 67.40 & 56.20 & 60.98 & 27.47 & 40.91 \\
      & \ding{51} & 64.84 & 54.06 & 51.94 & 25.68 & 34.80 \\
      \midrule
      \

In [8]:
# Build MLLM entries as (family label, instruct flag, results_dict key).
# Labels containing "Base (7B)" are left out, so the LLaMA2 family is
# represented only by its remaining (Instruct) entry.
ml_entries = [
    (
        label.replace("LLaVA + ", "").replace(" Instruct", ""),
        is_instruction_fine_tuned(label),
        model_key,
    )
    for model_key, (_, label) in models_of_interest.items()
    if "Base (7B)" not in label
]

# Unique family names, in order of first appearance
group_keys = list(dict.fromkeys(family for family, _, _ in ml_entries))

# Print grouped VL table: one multirow group per model family.
# The caption uses a plain ASCII hyphen so the emitted LaTeX does not depend
# on engine support for the U+2011 non-breaking hyphen.
print(r"\begin{table*}[h]")
print(r"  \centering")
print(r"  \caption{\textbf{Vision-Language Task Accuracies: Grouped by Model Family}}")
print(r"  \label{tab:vl_grouped_accuracies}")
print(r"  \resizebox{\linewidth}{!}{%")
print(r"    \begin{tabular}{l|c|" + "c" * len(vl_cols) + r"}")
print(r"      \toprule")
print(r"      \textbf{Model} & \textbf{Instr.} & "
      + " & ".join([r"\textbf{" + vl_labels[c] + r"}" for c in vl_cols])
      + r" \\")
print(r"      \midrule")

last_family_idx = len(group_keys) - 1
for fam_idx, fam in enumerate(group_keys):
    rows = [(lbl, flag, key) for lbl, flag, key in ml_entries if lbl == fam]
    for i, (lbl, flag, key) in enumerate(rows):
        tick = r"\ding{51}" if flag else r"\ding{55}"
        vals = [format_value(results_dict[key].get(c, np.nan)) for c in vl_cols]
        if i == 0:
            # First row of a family carries the multirow name cell
            print(f"      \\multirow{{{len(rows)}}}{{*}}{{{fam}}} & {tick} & "
                  + " & ".join(vals) + r" \\")
        else:
            print(f"      & {tick} & " + " & ".join(vals) + r" \\")

    # Rule only between families; after the last one the bottomrule follows,
    # so an extra midrule there would draw a redundant double rule.
    if fam_idx < last_family_idx:
        print(r"      \midrule")

print(r"      \bottomrule")
print(r"    \end{tabular}%")
print(r"  }")
print(r"\end{table*}")

\begin{table*}[h]
  \centering
  \caption{\textbf{Vision‑Language Task Accuracies: Grouped by Model Family}}
  \label{tab:vl_grouped_accuracies}
  \resizebox{\linewidth}{!}{%
    \begin{tabular}{l|c|cccc}
      \toprule
      \textbf{Model} & \textbf{Instr.} & \textbf{VQA} & \textbf{TextVQA‑OCR} & \textbf{TextVQA‑Pure} & \textbf{GQA} \\
      \midrule
      \multirow{2}{*}{Pythia (1.4B)} & \ding{55} & 66.17 & 38.49 & 35.49 & 46.09 \\
      & \ding{51} & 66.46 & 39.12 & 34.35 & 46.88 \\
      \midrule
      \multirow{1}{*}{LLaMA2 (7B)} & \ding{51} & 74.50 & 56.28 & 45.95 & 56.25 \\
      \midrule
      \bottomrule
    \end{tabular}%
  }
\end{table*}
