In [None]:
#| default_exp 19_map-amazon-meta-from-gpt-generations

In [None]:
#| hide
# nbdev boilerplate for this hidden cell: import the showdoc helpers, then call
# `nbdev.nbdev_export()`.
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import pandas as pd, re, numpy as np, os
from tqdm.auto import tqdm

In [None]:
#| export
from sugar.core import load_raw_file, save_raw_file

In [None]:
#| export
def extract_text_between_tags(text, tag='Label'):
    """Return the content of the first `<tag>...</tag>` pair found in `text`.

    Args:
        text: String to search. Any non-string value (e.g. `None` or a NaN cell)
            is treated as "no match".
        tag: Tag name, matched literally (regex metacharacters are escaped).

    Returns:
        The text between the first opening tag and the nearest closing tag, with
        surrounding whitespace stripped. Content spanning several lines is
        matched as well; each line break (plus the whitespace around it) is
        collapsed to a single space, because callers in this file interpolate the
        result into a one-line `"<title> :: <text>"` string. Returns `''` when no
        complete tag pair is present.
    """
    if not isinstance(text, str):
        return ''

    escaped_tag = re.escape(tag)
    pattern = fr"<{escaped_tag}>(.*?)</{escaped_tag}>"
    # DOTALL lets `.` cross newlines so multi-line content is not silently lost.
    match = re.search(pattern, text, flags=re.DOTALL)
    if match is None:
        return ''

    # Single-line content has no '\n', so it passes through this step unchanged.
    return re.sub(r'\s*\n\s*', ' ', match.group(1).strip())
    

In [None]:
#| export
def extract_generations(df, tag='Label'):
    """Extract the tagged text from every model response in `df`.

    Walks the `raw_model_response` column in positional order, applies
    `extract_text_between_tags` with the given `tag` to each entry, and pairs
    the results with the `title` column.

    Returns:
        A `(titles, generations)` tuple of two lists, one entry per row of `df`.
    """
    n_rows = len(df)
    extracted = [
        extract_text_between_tags(df['raw_model_response'].iloc[row], tag=tag)
        for row in range(n_rows)
    ]
    return df['title'].tolist(), extracted
    

In [None]:
#| export
def get_file_key(fname):
    """Return the integer index embedded in a generation file name.

    The name must start with optional lowercase letters followed by digits and
    a literal `.tsv` (e.g. `out12.tsv` -> 12). The key is used to sort files
    numerically rather than lexicographically.

    Raises:
        ValueError: if `fname` does not follow that pattern (for example a
            stray non-data entry in the directory).
    """
    # The dot is escaped so that only a real ".tsv" suffix is accepted; with a
    # bare `.` a name like "a123tsv" would match and yield the wrong key (12).
    match = re.match(r'[a-z]*([0-9]+)\.tsv', fname)
    if match is None:
        raise ValueError(
            f'Cannot extract a numeric file key from {fname!r}; '
            'expected a name like "<letters><digits>.tsv".'
        )
    return int(match.group(1))
    

In [None]:
#| export
def collate_generations(data_dir, tag='Label'):
    """Gather titles and extracted generations from every file in `data_dir`.

    Files are visited in ascending order of the number returned by
    `get_file_key`. Each one is read with `pd.read_table`, missing values are
    replaced by empty strings, and the rows are passed to `extract_generations`.

    Returns:
        A `(titles, generations)` tuple of two lists, concatenated across files
        in visiting order.
    """
    ordered_files = sorted(os.listdir(data_dir), key=get_file_key)

    all_titles, all_generations = [], []
    for fname in tqdm(ordered_files):
        frame = pd.read_table(f'{data_dir}/{fname}').fillna('')
        file_titles, file_generations = extract_generations(frame, tag=tag)
        all_titles += file_titles
        all_generations += file_generations

    return all_titles, all_generations
    

In [None]:
# Collate titles and extracted generations from the files under `test-outputs/`.
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/LF-AmazonTitles-1.3M_generations/test-outputs/'
tst_title, tst_text = collate_generations(data_dir)

  0%|          | 0/49 [00:00<?, ?it/s]

In [None]:
# Scratch cell: load a single generation file into `df` for manual inspection.
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/LF-AmazonTitles-1.3M_generations/test-outputs/'

# NOTE: `os.listdir` returns entries in arbitrary order, so "[0]" is not a fixed file.
fname = os.listdir(data_dir)[0]
df = pd.read_table(f'{data_dir}/{fname}')
# Replace missing cells with empty strings (mutates `df` in place).
df.fillna('', inplace=True)

In [None]:
# Re-point `data_dir` from the generations folder to the mapped dataset's `raw_data/` folder.
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/(mapped)LF-AmazonTitles-1.3M/raw_data/'

In [None]:
# Load the raw test entries, append the generated text to every entry that has
# one ("<entry> :: <generation>"), and write the result next to the input.
fname = f'{data_dir}/test.raw.txt'
ids, text = load_raw_file(fname)

# title -> generated text; if a title repeats, its last generation wins.
mapping = dict(zip(tst_title, tst_text))

# Entries without a generation are kept unchanged.
entity_text = [
    f'{entry} :: {mapping[entry]}' if entry in mapping else entry
    for entry in text
]
save_raw_file(f'{data_dir}/test.entity.txt', ids, entity_text)

In [None]:
#| export
def extract_and_save_generations(generation_dir, data_dir, data_type, tag='Label', save_tag='entity'):
    """Append extracted generations to a split's raw text and save the result.

    Args:
        generation_dir: Directory of generation files, read via
            `collate_generations`.
        data_dir: Directory containing `<split>.raw.txt`; the output file is
            written here too.
        data_type: `'test'` selects the test split. Any other value (including
            `None`) selects the train split.
        tag: Tag whose content is extracted from each model response.
        save_tag: Suffix of the output file, `<split>_<save_tag>.raw.txt`.

    Every raw entry that has a generation is rewritten as
    `"<entry> :: <generation>"`; entries without one are kept unchanged. If a
    title occurs more than once among the generations, the last one is used.
    """
    # Forward `tag`; otherwise the default tag would always be extracted.
    gen_title, gen_text = collate_generations(generation_dir, tag=tag)

    split = 'test' if data_type == 'test' else 'train'
    ids, text = load_raw_file(f'{data_dir}/{split}.raw.txt')

    mapping = dict(zip(gen_title, gen_text))

    entity_text = [f'{o} :: {mapping[o]}' if o in mapping else o for o in text]
    save_raw_file(f'{data_dir}/{split}_{save_tag}.raw.txt', ids, entity_text)
    

In [None]:
# Settings shared by the test and train runs below: the tag to extract and the
# dataset's `raw_data/` folder.
tag = 'Label'
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/(mapped)LF-AmazonTitles-1.3M/raw_data/'

In [None]:
# Test split: merge the generations under `test-outputs/` into the raw test text.
generation_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/LF-AmazonTitles-1.3M_generations/test-outputs/'
data_type = 'test'
# save_tag='entity' mirrors the `--save_tag` default of the command-line interface.
extract_and_save_generations(generation_dir, data_dir, data_type, tag, save_tag='entity')

  0%|          | 0/49 [00:00<?, ?it/s]

In [None]:
# Train split: merge the generations under `train-outputs/` into the raw train text.
generation_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/LF-AmazonTitles-1.3M_generations/train-outputs/'
data_type = 'train'
# save_tag='entity' mirrors the `--save_tag` default of the command-line interface.
extract_and_save_generations(generation_dir, data_dir, data_type, tag, save_tag='entity')

  0%|          | 0/113 [00:00<?, ?it/s]

## `__main__`

In [None]:
#| export
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. When `None` (the default),
            argparse reads `sys.argv[1:]`.

    Returns:
        An `argparse.Namespace` with `generation_dir`, `data_dir`, `tag`,
        `data_type` and `save_tag`.
    """
    # Local import: the module-level imports do not include argparse.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--generation_dir', type=str, required=True,
                        help='Directory containing the generation .tsv files.')
    parser.add_argument('--data_dir', type=str, required=True,
                        help='Directory with <split>.raw.txt; output is written here as well.')
    parser.add_argument('--tag', type=str, default='Label',
                        help='Tag whose content is extracted from each model response.')
    parser.add_argument('--data_type', type=str, default=None,
                        help="'test' selects the test split; any other value, or omitting it, selects train.")
    parser.add_argument('--save_tag', type=str, default='entity',
                        help='Suffix of the output file: <split>_<save_tag>.raw.txt.')
    return parser.parse_args(argv)
    

In [None]:
#| export
if __name__ == '__main__':
    # `timer` is not provided by the module-level imports, so bring it in here.
    # NOTE(review): assumes the intended timer is timeit's default_timer — confirm.
    from timeit import default_timer as timer

    start_time = timer()

    # Read the CLI options and run the extract-and-save step once.
    args = parse_args()
    extract_and_save_generations(args.generation_dir, args.data_dir, args.data_type, args.tag, args.save_tag)

    end_time = timer()
    print(f'Time elapsed: {end_time-start_time:.2f} seconds.')
    