In [1]:
#| default_exp 11_msmarco-llama-entities

In [2]:
#| hide
from nbdev.showdoc import *
# Run nbdev's export step, which writes the `#| export` cells to the module named in `#| default_exp`.
# NOTE(review): with no arguments this presumably exports every notebook in the project, reading them
# from disk (unsaved edits are not picked up) — confirm against the nbdev version in use.
import nbdev; nbdev.nbdev_export()

In [3]:
#| export
import argparse, builtins
from tqdm.auto import tqdm
from bs4 import BeautifulSoup
import json, re, scipy.sparse as sp, pickle, numpy as np

In [4]:
#| export
def load_raw_txt(fname):
    """Read an `id->text` file into two parallel lists.

    Each non-blank line is split on its first '->': the left part is parsed as an int id and
    the remainder (which may itself contain '->') is kept verbatim as the text.

    Returns:
        (ids, texts): list of int ids and list of str texts, in file order.
    Raises:
        ValueError: if a non-blank line has no '->' separator or its id is not an integer.
    """
    ids, texts = list(), list()
    with open(fname, 'r', encoding='utf-8') as file:
        for line in file:
            # rstrip instead of line[:-1]: the final line may have no trailing newline, and
            # unconditionally chopping one character would truncate its text.
            line = line.rstrip('\n')
            if not line: continue  # tolerate blank lines (e.g. an extra empty line at EOF)
            k, v = line.split("->", maxsplit=1)
            ids.append(int(k)); texts.append(v)
    return ids, texts
    

In [5]:
#| export
def select_entity(entities, entity_type='entity_canonical_category'):
    """Project each 'entity | canonical | category' string onto the requested fields.

    Args:
        entities: iterable of ' | '-separated strings.
        entity_type: one of 'entity_canonical_category' (whole string), 'entity', 'canonical',
            'category', or 'entity_canonical' (first two fields re-joined with ' | ').

    Returns:
        A new list with one projected string per input string.

    Raises:
        ValueError: if `entity_type` is not one of the names above.
    """
    projections = {
        'entity_canonical_category': lambda triple: triple,
        'entity': lambda triple: triple.split(' | ')[0],
        'canonical': lambda triple: triple.split(' | ')[1],
        'category': lambda triple: triple.split(' | ')[2],
        'entity_canonical': lambda triple: ' | '.join(triple.split(' | ')[:2]),
    }
    try:
        project = projections[entity_type]
    except (KeyError, TypeError):
        raise ValueError(f'Invalid entity type: {entity_type}') from None
    return list(map(project, entities))
    

In [6]:
#| export
def load_entities(fname, entity_type='entity_canonical_category'):
    """Load LLM-generated entities per query from a JSON-lines file.

    Each non-blank line is a JSON object with a 'qid' and a 'gen' field. 'gen' is markup whose
    `<entities>` tags each hold one ' | '-separated triple. Tags whose text does not split into
    exactly three parts are dropped. The surviving triples are projected with `select_entity`.

    Args:
        fname: path to the JSON-lines generations file.
        entity_type: field selection passed through to `select_entity`.

    Returns:
        dict mapping int(qid) -> list of entity strings, in file order. Lines that share a qid
        are concatenated, and a qid with no valid triples still maps to an empty list.
    """
    entities = dict()
    with open(fname, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip(): continue  # json.loads would fail on a blank line
            record = json.loads(line)
            # Name the parser explicitly. Otherwise bs4 picks whichever parser happens to be
            # installed (output can differ across machines) and emits a GuessedAtParserWarning.
            soup = BeautifulSoup(record['gen'], 'html.parser')
            triples = [tag.text for tag in soup.find_all('entities') if len(tag.text.split(' | ')) == 3]
            entities.setdefault(int(record['qid']), []).extend(select_entity(triples, entity_type))
    return entities
    

In [7]:
#| export
def load_data(data_dir, entity_dir, entity_type='entity_canonical_category'):
    """Load train/test queries with their LLM entities and build one shared entity vocabulary.

    Reads `{data_dir}/raw_data/{train,test}.raw.txt` and
    `{entity_dir}/{trn,val}_llama_generations.txt`. The 'val' generations are paired with the test split.

    Returns:
        trn_data, tst_data: `(ids, texts, entities)` tuples. `ids`/`texts` come from
            `load_raw_txt` and `entities` from `load_entities` (qid -> list of entity strings).
        entity_to_idx: dict entity string -> index, assigned in order of first appearance,
            scanning train entities before test entities.
    """
    trn_ids, trn_texts = load_raw_txt(f'{data_dir}/raw_data/train.raw.txt')
    trn_entities = load_entities(f'{entity_dir}/trn_llama_generations.txt', entity_type=entity_type)
    tst_ids, tst_texts = load_raw_txt(f'{data_dir}/raw_data/test.raw.txt')
    tst_entities = load_entities(f'{entity_dir}/val_llama_generations.txt', entity_type=entity_type)

    entity_to_idx = {}
    for split_entities in (trn_entities, tst_entities):
        for query_entities in split_entities.values():
            for name in query_entities:
                if name not in entity_to_idx:
                    entity_to_idx[name] = len(entity_to_idx)

    return (trn_ids, trn_texts, trn_entities), (tst_ids, tst_texts, tst_entities), entity_to_idx
        

In [8]:
#| export
def save_raw_txt(fname, ids, texts):
    """Write `ids`/`texts` as `id->text` lines, the format parsed by `load_raw_txt`.

    The file is written as UTF-8 so that non-ASCII text does not depend on the platform's
    locale encoding.

    NOTE(review): a text containing '\\n' would be split across two lines and break the round
    trip through `load_raw_txt`. Confirm that the entity strings passed in are single-line.
    """
    with open(fname, 'w', encoding='utf-8') as file:
        for k, v in tqdm(zip(ids, texts), total=len(ids)):
            file.write(f'{k}->{v}\n')
        

In [9]:
#| export
def construct_matrix(texts, entities, entity_to_idx):
    """Build a binary (row x entity) CSR matrix.

    Args:
        texts: sequence of row keys to look up in `entities`. Despite the name, these must be
            the same keys `entities` uses (e.g. the int query ids that `load_entities` produces).
        entities: dict key -> list of entity strings. A key missing from it yields an empty row.
        entity_to_idx: dict entity string -> column index. Every listed entity must be present,
            otherwise a KeyError is raised.

    Returns:
        scipy.sparse.csr_matrix of shape (len(texts), len(entity_to_idx)) and dtype int64, holding
        1 at (row, col) when that row's key lists the entity. Each row's column indices are
        unique and sorted.
    """
    data, indices, indptr = [], [], [0]
    for key in tqdm(texts):
        # Deduplicate: an entity listed twice for one query would otherwise produce duplicate
        # (row, col) entries, which scipy sums (to 2) when converting or doing arithmetic.
        cols = sorted({entity_to_idx[name] for name in entities.get(key, [])})
        indices.extend(cols)
        data.extend([1] * len(cols))
        indptr.append(len(indices))
    return sp.csr_matrix((data, indices, indptr), shape=(len(texts), len(entity_to_idx)), dtype=np.int64)


In [10]:
#| export 
def construct_msmarco_entities(data_dir:str, entity_dir:str, entity_type='entity_canonical_category'):
    """Build and save LLM-entity side-information files for MS MARCO.

    Writes under `data_dir`:
      - `raw_data/llama_{entity_type}.raw.txt`: the entity vocabulary, one `idx->entity` line each.
      - `llama_{entity_type}_trn_X_Y.npz` / `llama_{entity_type}_tst_X_Y.npz`: binary query x
        entity CSR matrices, with rows in the same order as the train/test raw text files.
      - `llama_{entity_type}_lbl_X_Y.npz`: an all-zero CSR matrix with one row per column of the
        existing `trn_X_Y.npz` (presumably one per label) and one column per entity.
    """
    trn_data, tst_data, entity_to_idx = load_data(data_dir, entity_dir, entity_type=entity_type)

    entity_texts = sorted(entity_to_idx, key=entity_to_idx.__getitem__)
    save_raw_txt(f'{data_dir}/raw_data/llama_{entity_type}.raw.txt', list(range(len(entity_texts))), entity_texts)

    # Look rows up by query id, not query text: `load_entities` keys its dict by int(qid),
    # so passing the texts (index 1 of the tuple) would match nothing and give all-zero matrices.
    trn_ids, _, trn_entities = trn_data
    trn_mat = construct_matrix(trn_ids, trn_entities, entity_to_idx)
    sp.save_npz(f'{data_dir}/llama_{entity_type}_trn_X_Y.npz', trn_mat)

    tst_ids, _, tst_entities = tst_data
    tst_mat = construct_matrix(tst_ids, tst_entities, entity_to_idx)
    sp.save_npz(f'{data_dir}/llama_{entity_type}_tst_X_Y.npz', tst_mat)

    n_lbl = sp.load_npz(f'{data_dir}/trn_X_Y.npz').shape[1]
    lbl_mat = sp.csr_matrix((n_lbl, len(entity_to_idx)), dtype=np.int64)
    sp.save_npz(f'{data_dir}/llama_{entity_type}_lbl_X_Y.npz', lbl_mat)
    

In [11]:
#| export
def parse_args(argv=None):
    """Parse command-line options for `construct_msmarco_entities`.

    Args:
        argv: list of argument strings. None (the default) makes argparse read `sys.argv[1:]`.

    Returns:
        argparse.Namespace with `data_dir`, `entity_dir` and `entity_type`.
    """
    parser = argparse.ArgumentParser(description='Build LLM-entity matrices for MS MARCO.')
    parser.add_argument('--data_dir', type=str, required=True,
                        help='directory containing raw_data/{train,test}.raw.txt and trn_X_Y.npz')
    parser.add_argument('--entity_dir', type=str, required=True,
                        help='directory containing {trn,val}_llama_generations.txt')
    # Same default as construct_msmarco_entities. The choices mirror select_entity, so a typo
    # fails immediately instead of after the raw text files have been loaded.
    parser.add_argument('--entity_type', type=str, default='entity_canonical_category',
                        choices=['entity_canonical_category', 'entity', 'canonical', 'category', 'entity_canonical'],
                        help='which fields of each "entity | canonical | category" triple to keep')
    return parser.parse_args(argv)
    

In [None]:
#| export
# Inside a Jupyter/IPython kernel `__name__` is also '__main__', but sys.argv holds the kernel's
# own flags, so argparse would exit with an error during "Run All". Only run the CLI when this
# code executes outside IPython (e.g. `python <exported module>.py --data_dir ...`).
if __name__ == '__main__' and not getattr(builtins, '__IPYTHON__', False):
    args = parse_args()
    construct_msmarco_entities(args.data_dir, args.entity_dir, args.entity_type)
    

In [None]:
import os

# Paths can be overridden through environment variables so the notebook runs on other machines.
# The defaults are the previously hardcoded locations.
data_dir = os.environ.get('MSMARCO_XC_DIR', "/home/scai/phd/aiz218323/scratch/datasets/msmarco-data/XC")
entity_dir = os.environ.get('MSMARCO_ENTITY_DIR', "/home/scai/phd/aiz218323/scratch/datasets/msmarco_entities/gpt/")
entity_type = os.environ.get('MSMARCO_ENTITY_TYPE', 'entity_canonical')

construct_msmarco_entities(data_dir, entity_dir, entity_type=entity_type)