In [1]:
# Reload edited modules automatically before each cell executes (autoreload mode 2).
%reload_ext autoreload
%autoreload 2

In [2]:
import sys
# Make the evaluation helpers importable (presumably where `prompting`,
# `metrics` and `plotting` live -- confirm against the repo layout).
# Guarded so re-running this cell does not keep appending duplicate entries.
EVALUATION_DIR = "../../evaluation/"
if EVALUATION_DIR not in sys.path:
    sys.path.append(EVALUATION_DIR)

In [3]:
from prompting import user_nlls
from collections import defaultdict
from metrics import torch_compute_confidence_interval
from tqdm import tqdm

In [5]:
# User under evaluation and the context modes compared in the sweep.
user_id = "1308026329"
modes = ["none", "user", "peer", "random"]

# Base settings passed to `user_nlls`. "stride" and "mode" are intentionally
# absent: the sweep loop fills them in for every run.
config = dict(
    from_disk=True,
    device="cuda",
    user_id=user_id,
    model_id="gpt2",
    ctxt_len=600,
    window_len=None,
    seq_sep="\n",
    batched=True,
    batch_size=8,
    token_level_nlls=True,
)

### Load

In [None]:
import pandas as pd

# Optional shortcut: reload a previously saved sweep instead of recomputing it.
# NOTE(review): the sweep cell further down reassigns `results`, so skip that
# cell when working from the cached file.
results = pd.read_csv("out/stride-length.csv")

# Files written with pandas' default `index=True` come back with a spurious
# "Unnamed: 0" column (and "Unnamed: 0.1", ... after repeated load/save
# cycles). Drop any such columns so the frame only holds the real fields.
results = results.loc[:, ~results.columns.str.startswith("Unnamed:")]

With a context length of 600 tokens and GPT-2's maximum sequence length of 1024 tokens, the window size is 1024 - 600 = 424 tokens.

This means the stride can be any integer from 1 to 424.

The default `"half"` stride would be 424 / 2 = 212.

In [36]:
import numpy as np

# Log-spaced integer strides in [1, 424). 424 is the window size for a context
# length of 600 (1024 - 600). Truncating to int produces repeated small values,
# so the grid is de-duplicated (np.unique also returns it sorted).
log_spaced = np.geomspace(start=1, stop=424, num=11, endpoint=False, dtype=int)
strides = np.unique(log_spaced)

strides

array([  1,   3,   5,   9,  15,  27,  46,  81, 141, 244])

In [37]:
# Sweep every (stride, context mode) pair and record the mean NLL together
# with its confidence interval (confidence level 0.9).
results = []

for stride in tqdm(strides, desc="stride", position=0):
    for mode in tqdm(modes, desc="mode", position=1, leave=False):
        # Build a per-run copy instead of mutating the shared `config`, so no
        # stale "stride"/"mode" keys linger after the loop (or after an
        # interrupted run) and other cells keep seeing the base settings.
        run_config = {**config, "stride": stride, "mode": mode}

        nlls = user_nlls(run_config)
        mean, ci = torch_compute_confidence_interval(data=nlls, confidence=0.9)

        results.append({
            "mean": mean,
            "ci": ci,
            "context": mode,
            "stride": stride,
        })

  0%|                                                                 | 0/10 [00:00<?, ?it/s]
  0%|                                                                  | 0/4 [00:00<?, ?it/s][A

  0%|                                                                  | 0/2 [00:00<?, ?it/s][A[A

 50%|█████████████████████████████                             | 1/2 [00:00<00:00,  4.21it/s][A[A

100%|██████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.29it/s][A[A

                                                                                             [A[A
 25%|██████████████▌                                           | 1/4 [00:03<00:10,  3.62s/it][A

  0%|                                                                  | 0/2 [00:00<?, ?it/s][A[A

 50%|█████████████████████████████                             | 1/2 [00:00<00:00,  1.45it/s][A[A

100%|██████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.46it/s][A[A

    

 75%|███████████████████████████████████████████▌              | 3/4 [00:12<00:04,  4.26s/it][A

  0%|                                                                  | 0/2 [00:00<?, ?it/s][A[A

 50%|█████████████████████████████                             | 1/2 [00:00<00:00,  1.47it/s][A[A

100%|██████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.47it/s][A[A

                                                                                             [A[A
100%|██████████████████████████████████████████████████████████| 4/4 [00:16<00:00,  4.34s/it][A
 40%|██████████████████████▊                                  | 4/10 [01:07<01:41, 16.89s/it][A
  0%|                                                                  | 0/4 [00:00<?, ?it/s][A

  0%|                                                                  | 0/2 [00:00<?, ?it/s][A[A

 50%|█████████████████████████████                             | 1/2 [00:00<00:00,  4.42it/s][A[A

100%|

100%|██████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.04it/s][A[A

                                                                                             [A[A
 50%|█████████████████████████████                             | 2/4 [00:08<00:08,  4.29s/it][A

  0%|                                                                  | 0/3 [00:00<?, ?it/s][A[A

 33%|███████████████████▎                                      | 1/3 [00:00<00:01,  1.47it/s][A[A

 67%|██████████████████████████████████████▋                   | 2/3 [00:01<00:00,  1.47it/s][A[A

100%|██████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.05it/s][A[A

                                                                                             [A[A
 75%|███████████████████████████████████████████▌              | 3/4 [00:13<00:04,  4.46s/it][A

  0%|                                                                  | 0/3 [00:00<?, ?it/s][A[A

## Plot

In [38]:
from plotting import line

In [39]:
# Mean NLL against stride, one trace per context mode; the "ci" column is
# passed as the error term and also exposed on hover.
line(
    data_frame=results,
    x="stride",
    y="mean",
    error_y="ci",
    error_y_mode="bands",
    color="context",
    markers=True,
    hover_data=["ci"],
    title="Avg NLL",
)

### Save 

In [40]:
import pandas as pd

# Persist the sweep so it can be reloaded later without recomputing it.
from pathlib import Path

out_path = Path("out/stride-length.csv")
# `to_csv` does not create missing directories, so make sure "out/" exists.
out_path.parent.mkdir(parents=True, exist_ok=True)

df = pd.DataFrame(results)
# index=False: the default integer index carries no information and would
# come back as an "Unnamed: 0" column when the file is read again.
df.to_csv(out_path, index=False)