In [22]:
import json
import numpy as np
import os
import pandas as pd
import pickle as pkl
from tqdm.notebook import tqdm

In [18]:
# Run configuration read by the cells below.
seed = 0                                 # picks split/seed{seed}.json and the matching cached prediction files
data_dir = '../chunwei-data'             # folder containing the split files and the record dump
output_dir = f'{data_dir}/predictions'   # folder the cached per-layer prediction pickles are read from
num_labels = 10                          # number of bins; also appears as class{num_labels} in file names

In [19]:
# Upper edges of `num_labels` equal-width bins spanning [0, 512]; the leading 0 edge is dropped.
# NOTE(review): 512 is presumably the largest value being predicted — confirm and name it.
bin_edges = np.linspace(start=0, stop=512, num=num_labels + 1)[1:]

In [7]:
# Read the ids assigned to the test split for the configured seed.
split_path = os.path.join(data_dir, 'split', f'seed{seed}.json')
with open(split_path, 'r') as split_file:
    split = json.load(split_file)
test_ids = set(split['test'])

# Read the record dump, leaving out the first entry of its 'records' list.
# NOTE(review): why the first record is skipped is not visible here — confirm it is intentional.
records_path = os.path.join(data_dir, 'all_layers_1k.json')
with open(records_path, 'r') as records_file:
    records = json.load(records_file)['records'][1:]

# Keep only the records whose id belongs to the test split.
test_records = list(filter(lambda rec: rec['record_id'] in test_ids, records))

In [15]:
def _to_output_entry(rec):
    """Build the output entry (id, trimmed prompt, ground truth) for one record.

    The prompt is the part of ``rec['output']`` that starts at
    'Below is an instruction' and stops right before the first
    '### Response:', with '### Response:' appended again.
    ``str.index`` raises ValueError when either marker is missing.
    """
    raw_text = rec['output']
    start = raw_text.index('Below is an instruction')
    end = raw_text.index('### Response:')
    trimmed = raw_text[start:end] + '### Response:'
    return {
        "id": rec['record_id'],
        "prompt": trimmed,
        "groundtruth": rec['iteration_count'],
    }


# One entry per test record.
output = [_to_output_entry(rec) for rec in test_records]

# Reorder the entries by ascending id.
ids = [entry['id'] for entry in output]
sorted_idx = np.argsort(ids)
output = [output[pos] for pos in sorted_idx]

In [31]:
def convert_logit_to_pred(logits, bin_edges):
    """Convert class logits to a scalar: the softmax-weighted mean of ``bin_edges``.

    Args:
        logits: 1-D sequence of unnormalised class scores.
        bin_edges: 1-D sequence of the same length giving the value tied to each class.

    Returns:
        ``sum_i softmax(logits)[i] * bin_edges[i]`` as a NumPy float.

    Raises:
        ValueError: if either input is not 1-D or the two lengths differ.
            (Previously a length mismatch was silently truncated by ``zip``.)
    """
    logits = np.asarray(logits, dtype=float)
    bin_edges = np.asarray(bin_edges, dtype=float)
    if logits.ndim != 1 or logits.shape != bin_edges.shape:
        raise ValueError(
            f'logits and bin_edges must be 1-D with equal length, '
            f'got shapes {logits.shape} and {bin_edges.shape}'
        )

    # Subtract the maximum before exponentiating so large logits cannot overflow.
    exp_x = np.exp(logits - np.amax(logits))
    proba = exp_x / np.sum(exp_x)
    return np.sum(proba * bin_edges)


# For every layer index, attach that layer's per-record prediction vector to `output`.
for layer_idx in tqdm(range(32)):

    cached_pred_path = os.path.join(output_dir, f'L{layer_idx}_class{num_labels}_seed{seed}.pkl')
    # Layers without a cached prediction file are skipped, so their
    # `pred_refined_L{layer_idx}` key is simply absent from the records.
    if not os.path.isfile(cached_pred_path):
        continue

    # NOTE: unpickling can execute arbitrary code — only load files produced by this project.
    with open(cached_pred_path, 'rb') as fin:
        pred_raw = pkl.load(fin)

    # Collapse each raw logit vector into a scalar prediction.
    pred = pd.DataFrame([
        {
            'id': raw['id'],
            'remaining_steps': raw['remaining_steps'],
            'pred': convert_logit_to_pred(raw['pred'], bin_edges),
        }
        for raw in pred_raw
    ])

    # Per record id, list the predictions from most to fewest remaining steps.
    pred_by_id = dict()
    for record_id, sub_df in pred.groupby(by='id'):
        descending = np.argsort(sub_df.remaining_steps.values)[::-1]
        pred_by_id[record_id] = sub_df.pred.values[descending].tolist()

    # A KeyError here means the cached file lacks predictions for a test record.
    for rec in output:
        rec[f'pred_refined_L{layer_idx}'] = pred_by_id[rec['id']]

  0%|          | 0/32 [00:00<?, ?it/s]

In [33]:
# Persist the per-record prompts, ground truth and per-layer predictions as JSON.
result_path = f'eval/alpaca-by-layer-seed{seed}-class{num_labels}.json'
# Create the target folder first so the write cannot fail on a missing directory.
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, 'w') as fout:
    json.dump(output, fout, indent=2)