In [70]:
import os
from functools import partial
from pathlib import Path

import yaml

# Path("__file__") is a literal relative filename, so the former expression
# Path("__file__").resolve().parent simply evaluated to the resolved current
# working directory; say that directly.
FILEPATH = Path.cwd().resolve()
# Root of the experiment results tree. Defaults to the original machine-specific
# location; set the ISO_PUNC_RESULTS environment variable to use another one
# without editing the notebook.
RESULTSPATH = Path(os.environ.get("ISO_PUNC_RESULTS", "/home/pedro/mount/datafactory/iso_punc"))



In [71]:
# Table rows: display name -> dataset directory name under <results root>/results/.
# The dict order is the row order of the generated table.
indices2path = {
    "MNIST": "mnist",
    "FashionMNIST": "fmnist",
    "E-MNIST": "emnist:mnist",
    "E-LETTERS": "emnist:letters",
    "E-BALANCED": "emnist:balanced",
    "E-BYCLASS": "emnist:byclass",
}

# Baseline table columns as (column name, test bpd per dataset), with values in
# the same order as `indices2path`. The numbers are hard-coded here (presumably
# copied from the respective papers -- TODO cite the source).
# _get_data places the PVX columns before the last three entries of this list,
# so changing how many rows are active changes where the PVX columns land.
# Commented-out rows are excluded from the table.
data = [
    ("QPC", [1.18, 3.27, 1.66, 1.70, 1.73, 1.67]),
    ("HCLT", [1.21, 3.34, 1.70, 1.75, 1.78, 1.73]),
    ("SHCLT", [1.14, 3.27, 1.52, 1.58, 1.60, 1.54]),
    # ("PNC", [0.87, 2.51, 1.36, 1.33, 1.35, 1.27]),
    # ("IDF", [1.90, 3.47, 2.07, 1.95, 2.15, 1.98]),
    # ("BitSwap", [1.27, 3.28, 1.88, 1.84, 1.96, 1.87]),
]

def _get_data(FILEPATH, indices2path, data):
    """Collect mean test bpd for each PVX configuration and merge it with the baselines.

    For every configuration (precision ``fl``, ``rho`` in {"full", "rank1"},
    phase in {"zero", "2pi"}) and every dataset in ``indices2path`` this reads
    ``FILEPATH/results/<dir>/precision-<fl>/c-128/ph-<phase>/rho-<rho>/seed-<s>/test_results.yaml``
    for each seed, averages ``d["test"]["bpd"]`` over the seeds and rounds the
    mean to two decimals. A dataset whose results are missing or unreadable
    gets the placeholder ``-1`` and a warning naming the failing run.

    Args:
        FILEPATH: root directory containing the ``results`` tree; must support
            the ``/`` operator (e.g. ``pathlib.Path``).
        indices2path: mapping from table row name to dataset directory name;
            its iteration order is the row order.
        data: list of ``(column_name, values)`` baseline rows. It is not modified.

    Returns:
        ``(row_names, columns)``: ``row_names`` is ``list(indices2path)`` and
        ``columns`` maps each column name to its values in row order. The four
        PVX columns are placed before the last three baseline rows (at the front
        if there are fewer than three baselines).
    """
    import warnings

    seeds = range(42, 43)
    # PVX columns go in front of the last `n_trailing` baseline rows (the layout
    # the old in-place ``data.insert(-3, ...)`` produced). A copy is built
    # instead, so the caller's list -- bound once via functools.partial -- does
    # not grow on every call.
    n_trailing = 3
    pvx_rows = []

    for fl in (32,):
        for rho in ["full", "rank1"]:
            for phase in ["zero", "2pi"]:
                bpd_means = []
                for k, dataset_dir in indices2path.items():
                    try:
                        bpds = []
                        for s in seeds:
                            data_file = (
                                FILEPATH
                                / "results"
                                / f"{dataset_dir}"
                                / f"precision-{fl}"
                                / "c-128"
                                / f"ph-{phase}"
                                / f"rho-{rho}"
                                / f"seed-{s}"
                                / "test_results.yaml"
                            )
                            with open(data_file, "r") as file:
                                d = yaml.safe_load(file)
                            bpds.append(d["test"]["bpd"])
                        bpd_means.append(round(sum(bpds) / len(bpds), 2))
                    # Only expected failures become -1: missing/unreadable files,
                    # malformed YAML, an empty file (None -> TypeError) or a file
                    # without test.bpd. Programming errors and KeyboardInterrupt
                    # still propagate.
                    except (OSError, KeyError, TypeError, ZeroDivisionError, yaml.YAMLError) as err:
                        warnings.warn(
                            f"No usable results for {k} (precision-{fl}, rho-{rho}, ph-{phase}): "
                            f"{err!r}; reporting -1",
                            stacklevel=2,
                        )
                        bpd_means.append(-1)
                # Raw strings: "\p" and "\m" are invalid escape sequences in
                # ordinary literals (SyntaxWarning on Python 3.12+).
                tex_phase = "0" if phase == "zero" else r"2\pi"
                tex_rho = "FR" if rho == "full" else "R1"
                pvx_rows.append((rf"$\mathrm{{PVX}}_{{{tex_phase}}}^{{{tex_rho}}}$", bpd_means))

    rows = list(data)
    insert_at = max(len(rows) - n_trailing, 0)
    rows[insert_at:insert_at] = pvx_rows
    return list(indices2path.keys()), dict(rows)

# Bind the results root, the dataset mapping and the baseline row list so later
# cells can call get_data() with no arguments. partial stores the list object
# that `data` refers to *now*; a later cell rebinds `data` to the dict returned
# by get_data(), so re-running this cell after that one would pass that dict
# instead of the (name, values) row list _get_data expects.
get_data = partial(_get_data, RESULTSPATH, indices2path, data)
get_data()

(['MNIST', 'FashionMNIST', 'E-MNIST', 'E-LETTERS', 'E-BALANCED', 'E-BYCLASS'],
 {'$\\mathrm{PVX}_{0}^{FR}$': [1.16, 3.37, 1.69, 1.62, 1.65, 1.47],
  '$\\mathrm{PVX}_{2\\pi}^{FR}$': [1.16, 3.55, 1.63, 1.6, 1.64, 1.47],
  '$\\mathrm{PVX}_{0}^{R1}$': [1.24, 3.44, 1.76, 1.7, 1.73, 1.56],
  '$\\mathrm{PVX}_{2\\pi}^{R1}$': [1.17, 3.55, 1.64, 1.61, 1.64, 1.49],
  'QPC': [1.18, 3.27, 1.66, 1.7, 1.73, 1.67],
  'HCLT': [1.21, 3.34, 1.7, 1.75, 1.78, 1.73],
  'SHCLT': [1.14, 3.27, 1.52, 1.58, 1.6, 1.54]})

In [72]:

# parameters = [25.8, 51.6, 38.5, 38.5, 38.5, 2.8, 24.1, 2.8]

In [73]:
def dict_to_latex_table(indices, data_dict, table_caption="Table Caption", label="table_label",
                        cmidrules=r"\cmidrule(lr){2-5}\cmidrule(lr){6-8}"):
    r"""Render column-oriented results as a LaTeX ``table*`` using booktabs rules.

    Each key of ``data_dict`` becomes a column whose header is the key. Row ``i``
    is labelled ``indices[i]`` and every cell is typeset as ``$x.xx$``. A column
    with fewer than ``len(indices)`` values leaves its remaining cells empty.

    Args:
        indices: row labels (strings), one per table row.
        data_dict: ordered mapping column name -> numbers in row order.
        table_caption: text placed in ``\caption{...}``.
        label: key placed in ``\label{...}``.
        cmidrules: rule line emitted directly after the header row. The default
            groups columns 2-5 and 6-8 (matching the current 4 + 3 column
            layout); pass ``""`` or ``None`` to omit it.

    Returns:
        The LaTeX source as a single string.
    """
    parts = [
        "\\begin{table*}[t]\n",
        "\\centering\n",
        "\\caption{" + table_caption + "}\n \\label{" + label + "}\n",
        # One left-aligned label column, then one centred column per entry.
        "\\begin{tabular}{l" + "c" * len(data_dict) + "}\n",
        # The leading " & " leaves the header cell above the row labels empty.
        " & " + " & ".join(data_dict.keys()) + " \\\\\n",
    ]
    if cmidrules:
        # Whole rule specification on one line, right after the header's \\.
        parts.append(cmidrules + "\n")

    # One table row per label, so no particular column (e.g. "QPC") must exist.
    for i, row_label in enumerate(indices):
        cells = [row_label]
        for values in data_dict.values():
            cells.append(f"${values[i]:.2f}$" if i < len(values) else "")
        parts.append(" & ".join(cells) + " \\\\\n")

    # NOTE: an optional "{\#} parameters" footer row (after \midrule) is disabled;
    # re-enabling it needs a `parameters` list aligned with the columns.

    parts.append("\\end{tabular}\n")
    parts.append("\\end{table*}\n")
    return "".join(parts)


In [74]:
# Reference: QPC256 has 102760448 parameters.
# Disabled check -- `parameters` only exists in a commented-out cell:
# assert len(parameters) == len(data.keys())

caption, label = "Test set bpd for MNIST datasets (lower is better).", "tab:mnist"

# NOTE: this rebinds `data` from the baseline row list to the column dict that
# get_data() returns.
indices, data = get_data()
latex_code = dict_to_latex_table(indices, data, table_caption=caption, label=label)

def save_to_file(filename, text):
    """Write ``text`` to ``filename`` as UTF-8, overwriting any existing file.

    Args:
        filename: destination path (``str`` or ``os.PathLike``).
        text: content to write.
    """
    # Explicit encoding: the default is locale/platform dependent, and LaTeX
    # captions or labels may contain non-ASCII characters.
    with open(filename, "w", encoding="utf-8") as file:
        file.write(text)


# Destination: FILEPATH/../tex_input/mnist_table.tex (".." kept literally in the path).
filename = FILEPATH.joinpath("..", "tex_input", "mnist_table.tex")
save_to_file(filename, latex_code)
