# Parse JSON result files

In [1]:
import json
import glob
import os
import pandas as pd
import re

In [2]:
def list_files(path, seed=None):
    """
    List the non-empty JSON files located directly inside `path`.

    Args:
        path: directory searched with the `*.json` glob (not recursive).
        seed: when not None, keep only the files whose path contains
            `seed{seed}_`. A seed of 0 is a valid filter value.

    Returns:
        The matching paths, sorted by the path stripped of its last
        `_`-separated piece.
    """
    files = glob.glob(os.path.join(path, "*.json"))
    # Compare against None explicitly: a plain truthiness test would treat
    # seed=0 as "no seed requested" and return the files of every seed.
    if seed is not None:
        seed_tag = f'seed{seed}_'
        files = [f for f in files if seed_tag in f]  # filter filename with the seed
    files = [f for f in files if os.path.getsize(f) > 0]  # ignore empty files
    # Sort on everything before the last underscore-separated piece of the path.
    return sorted(files, key=lambda x: "_".join(x.split('_')[:-1]))

In [3]:
def load_suffixes(path, seed=None):
    """
    Load best suffixes from the JSON files of `path` into a DataFrame.

    The files are the ones returned by `list_files(path, seed=seed)`. The
    content of every file is concatenated into one list of records
    (presumably each file holds a JSON list of dicts -- TODO confirm against
    the code writing these files). Any record value that is a one-element
    list is replaced by that element.

    The string length is read from a `str_length_XX` component of `path`;
    when none is found, 4 is used and an info message is printed.

    Args:
        path: directory containing the JSON suffix files.
        seed: optional seed used to filter the files (see `list_files`).

    Returns:
        A DataFrame of the records with two extra columns: 'number', the
        digits extracted from 'targets' (NA when the pattern does not match),
        and 'str_length', the string length as an int.
    """
    # List the files from the arguments rather than depending on whatever
    # global variable named `files` happens to exist.
    files = list_files(path, seed=seed)
    data = []
    for file in files:
        with open(file, 'r') as f:
            data += json.load(f)
    print(f'{len(data)} suffixes loaded from {len(files)} files.')
    # Unwrap single-element lists so these fields become scalar cells.
    for suffix in data:
        for k, v in suffix.items():
            if isinstance(v, list) and len(v) == 1:
                suffix[k] = v[0]
    # Accept the component at the start or end of the path as well, and with
    # either path separator, so it is not silently replaced by the default.
    str_length_search = re.search(r'(?:^|[\\/])str_length_(\d+)(?:[\\/]|$)', path)
    if str_length_search:
        # Always an int, so the column has the same type as with the default.
        str_length = int(str_length_search.group(1))
    else:
        print('[INFO] String length not detected from suffix path (/str_length_XX/). Using 4 by default.')
        str_length = 4
    df = pd.DataFrame(data)
    # Capture exactly `str_length` digits following ": " in the target text.
    df['number'] = df['targets'].str.extract(r': (\d{' + str(str_length) + '})')
    df['str_length'] = str_length
    n_missing = pd.isna(df['number']).sum()
    if n_missing > 0:
        print(f"[ERROR] extracting targeted number: {n_missing} NA values!")
    return df

In [4]:
def get_args(filename):
    """
    Extract the run settings encoded in the path of a result file.

    The path must contain, in this order, a `str_length_<int>/` component, a
    `model_<name>/` component and an `_offset<int>_` fragment.

    Args:
        filename: path of the result file.

    Returns:
        A `(str_length, model, offset)` tuple, with `str_length` and `offset`
        converted to int and `model` kept as a string.

    Raises:
        ValueError: if the path does not follow the expected layout.
    """
    pattern = r"str_length_(\d+)/.*model_(\w+)/.*_offset(\d+)_"
    match = re.search(pattern, filename)
    if not match:
        # Name the offending path: an empty ValueError gives no clue about
        # which of the many processed files is malformed.
        raise ValueError(
            f"Cannot parse str_length, model and offset from {filename!r}: "
            "expected '.../str_length_<int>/.../model_<name>/..._offset<int>_...'"
        )
    str_length = int(match.group(1))
    model = match.group(2)
    offset = int(match.group(3))
    return str_length, model, offset


In [5]:
# Collect the result files of the three string-length runs (3, 4 and 5),
# keeping one list per run and a combined list in that same order.
f1, f2, f3 = (
    list_files(run_dir)
    for run_dir in (
        '../results/method_random/type_number/str_length_3/model_llama2',
        '../results/method_random/type_number/str_length_4/model_llama2',
        '../results/method_random/type_number/str_length_5/model_llama2',
    )
)
files = [*f1, *f2, *f3]

In [6]:
# Flatten the test logs of every result file into one row per logged
# evaluation, then gather all rows in a single DataFrame.
stats = []

for file in files:
    with open(file, 'r') as f:
        data = json.load(f)

    nb_prefixes = len(data['best'])
    n_steps = data['params']['n_steps']
    n_test_steps = data['params']['test_steps']

    str_length, model, data_offset = get_args(file)

    # +1 because there is an eval at the start and the end
    nb_log_per_suffix = 1 + n_steps // n_test_steps
    # Number of test entries a complete run is expected to contain.
    max_n_data = nb_log_per_suffix * nb_prefixes

    if len(data['tests']) > max_n_data:
        print(f"[INFO] The last {len(data['tests']) - max_n_data} will be ignore. Most likely a partial computation that failed in between.")

    # Pairing with range(max_n_data) stops after the expected number of
    # entries, which ignores the partial run logged when the node crashed.
    for i, test in zip(range(max_n_data), data['tests']):
        # Consecutive groups of nb_log_per_suffix entries belong to the same
        # data point; the position inside the group gives the step.
        idx_data, log_position = divmod(i, nb_log_per_suffix)
        stats.append({
            'model': model,
            'str_length': str_length,
            'Step': log_position * n_test_steps,
            'idx_data': data_offset + idx_data,
            'Loss': test['n_loss'][0],
        })

df = pd.DataFrame(stats)
    

[INFO] The last 129 will be ignore. Most likely a partial computation that failed in between.
[INFO] The last 63 will be ignore. Most likely a partial computation that failed in between.
[INFO] The last 78 will be ignore. Most likely a partial computation that failed in between.
[INFO] The last 76 will be ignore. Most likely a partial computation that failed in between.
[INFO] The last 92 will be ignore. Most likely a partial computation that failed in between.
[INFO] The last 78 will be ignore. Most likely a partial computation that failed in between.
[INFO] The last 75 will be ignore. Most likely a partial computation that failed in between.
[INFO] The last 61 will be ignore. Most likely a partial computation that failed in between.
[INFO] The last 63 will be ignore. Most likely a partial computation that failed in between.
[INFO] The last 61 will be ignore. Most likely a partial computation that failed in between.
[INFO] The last 82 will be ignore. Most likely a partial computation 

In [7]:
# Display the table of per-step losses assembled in the previous cell.
df

Unnamed: 0,model,str_length,Step,idx_data,Loss
0,llama2,3,0,0,1.972656
1,llama2,3,10,0,1.465820
2,llama2,3,20,0,1.317383
3,llama2,3,30,0,1.167969
4,llama2,3,40,0,1.088867
...,...,...,...,...,...
45295,llama2,5,1460,99,0.130005
45296,llama2,5,1470,99,0.103943
45297,llama2,5,1480,99,0.100342
45298,llama2,5,1490,99,0.085266


In [8]:
# Save the per-step loss table as CSV. No `index` argument is passed, so
# pandas' default handling of the DataFrame index applies.
df.to_csv('../results/loss_steps.csv')