In [1]:
import polars as pl
import numpy as np

In [2]:
# First, you need to run `model_multitoken_probing.py` (with ngram parameter ranging from 1 to 10) to generate the jsonl files in `multitoken_probes/` directory.
# Then, you can run this notebook to compute the resulting table.

In [3]:
outputs_dir = "multitoken_probes_lumi"

In [13]:
for aggregation in ["product", "avg"]:
    print("\n")
    print(f"Selecting the best hidden state by giving highest **{aggregation}** of individual piecewise accuracies")

    reports = []
    individual_accs = np.zeros((10, 10)) - np.nan

    for ngram in range(1, 11):
        filename = f"{outputs_dir}/results_ngram={ngram}_operand_idx=1.jsonl"
        df = pl.read_ndjson(filename)
        if aggregation == "avg":
            agg_col = pl.col("best_valid_acc").mean()
        else:
            agg_col = pl.col("best_valid_acc").product()
        
        best_hidden_state = df.group_by(["pos_idx", "block_idx"]).agg(agg_col).sort("best_valid_acc", descending=True).head(1)
        pos_idx = best_hidden_state.get_column("pos_idx")[0]
        block_idx = best_hidden_state.get_column("block_idx")[0]

        subset = df.filter(
            (pl.col("pos_idx") == pos_idx) & (pl.col("block_idx") == block_idx)
        ).select(
            ["pos_idx", "block_idx", "suboperand_idx", "test_acc"]
        )

        test_acc = subset.get_column("test_acc").product()
        reports.append(f"ngram={ngram:>2}, pos_idx={pos_idx}, block_idx={block_idx}, estimated acc of full number recovery = {test_acc:.0%}")

        #print(ngram)
        #display(subset)

        individual_accs[ngram - 1, -len(subset):] = subset.get_column("test_acc").to_numpy()

        if ngram == 10:
            display(subset)

    for report in reports:
        print(report)


    print("Individual accuracies:")

    individual_accs = pl.DataFrame(individual_accs, schema=[f"pos {- 10 + (i)}" for i in range(10)])
    individual_accs = individual_accs.to_pandas().map(lambda x: f"{x:.0%}" if not np.isnan(x) else "")
    individual_accs.index = [f"Length {i+1}" for i in range(10)]
    display(individual_accs)



Selecting the best hidden state by giving highest **product** of individual piecewise accuracies


pos_idx,block_idx,suboperand_idx,test_acc
i64,i64,i64,f64
-2,31,0,0.0
-2,31,1,0.002686
-2,31,2,0.0
-2,31,3,0.000244
-2,31,4,0.0
-2,31,5,0.0
-2,31,6,0.0
-2,31,7,0.002441
-2,31,8,0.004883
-2,31,9,0.772949


ngram= 1, pos_idx=-2, block_idx=15, estimated acc of full number recovery = 100%
ngram= 2, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 91%
ngram= 3, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 67%
ngram= 4, pos_idx=-2, block_idx=5, estimated acc of full number recovery = 5%
ngram= 5, pos_idx=-2, block_idx=6, estimated acc of full number recovery = 0%
ngram= 6, pos_idx=-2, block_idx=6, estimated acc of full number recovery = 0%
ngram= 7, pos_idx=-2, block_idx=8, estimated acc of full number recovery = 0%
ngram= 8, pos_idx=-2, block_idx=8, estimated acc of full number recovery = 0%
ngram= 9, pos_idx=-2, block_idx=8, estimated acc of full number recovery = 0%
ngram=10, pos_idx=-2, block_idx=31, estimated acc of full number recovery = 0%
Individual accuracies:


Unnamed: 0,pos -10,pos -9,pos -8,pos -7,pos -6,pos -5,pos -4,pos -3,pos -2,pos -1
Length 1,,,,,,,,,,100%
Length 2,,,,,,,,,91%,100%
Length 3,,,,,,,,75%,90%,100%
Length 4,,,,,,,22%,35%,72%,100%
Length 5,,,,,,20%,20%,13%,70%,99%
Length 6,,,,,14%,3%,15%,20%,69%,100%
Length 7,,,,9%,3%,3%,6%,24%,60%,99%
Length 8,,,8%,3%,2%,4%,7%,20%,63%,99%
Length 9,,6%,1%,1%,2%,4%,4%,17%,61%,100%
Length 10,0%,0%,0%,0%,0%,0%,0%,0%,0%,77%




Selecting the best hidden state by giving highest **avg** of individual piecewise accuracies


pos_idx,block_idx,suboperand_idx,test_acc
i64,i64,i64,f64
-2,2,0,0.005127
-2,2,1,0.0
-2,2,2,0.0
-2,2,3,0.0
-2,2,4,0.000732
-2,2,5,0.0
-2,2,6,0.004883
-2,2,7,0.557129
-2,2,8,0.881836
-2,2,9,0.99585


ngram= 1, pos_idx=-2, block_idx=17, estimated acc of full number recovery = 100%
ngram= 2, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 91%
ngram= 3, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 67%
ngram= 4, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 1%
ngram= 5, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 0%
ngram= 6, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 0%
ngram= 7, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 0%
ngram= 8, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 0%
ngram= 9, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 0%
ngram=10, pos_idx=-2, block_idx=2, estimated acc of full number recovery = 0%
Individual accuracies:


Unnamed: 0,pos -10,pos -9,pos -8,pos -7,pos -6,pos -5,pos -4,pos -3,pos -2,pos -1
Length 1,,,,,,,,,,100%
Length 2,,,,,,,,,91%,100%
Length 3,,,,,,,,75%,90%,100%
Length 4,,,,,,,2%,51%,90%,98%
Length 5,,,,,,0%,1%,56%,89%,100%
Length 6,,,,,0%,0%,0%,58%,88%,99%
Length 7,,,,0%,0%,0%,1%,60%,88%,100%
Length 8,,,0%,0%,0%,0%,0%,58%,89%,99%
Length 9,,0%,1%,0%,0%,0%,1%,59%,89%,100%
Length 10,1%,0%,0%,0%,0%,0%,0%,56%,88%,100%
