In [3]:
#| default_exp 26_beir-substring-metadata

In [4]:
# Re-import edited modules (e.g. the `xcai` package imported below) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
#| export
import os, json, pandas as pd, ast, scipy.sparse as sp, json_repair, numpy as np, re

from tqdm.auto import tqdm
from typing import Optional, Dict, List
from xcai.misc import BEIR_DATASETS

## Helper functions

In [112]:
#| export
def load_gpt_outputs(fname:str):
    """Load raw generations from a JSON-Lines file into a ``{doc_id: output}`` dict.

    Every non-blank line must be a JSON object with ``"doc_id"`` and ``"output"`` keys.
    Blank lines (e.g. an extra trailing newline) are ignored. If a ``doc_id`` repeats,
    the last occurrence wins.

    Args:
        fname: path to the ``.jsonl`` file.

    Returns:
        Mapping from each ``doc_id`` to its raw ``output`` value.
    """
    outputs = dict()
    # JSON text is UTF-8; don't depend on the platform's default locale encoding.
    with open(fname, encoding="utf-8") as file:
        for line in file:
            if not line.strip():
                continue
            content = json.loads(line)
            outputs[content["doc_id"]] = content["output"]
    return outputs
    

In [113]:
#| export
def get_label_metadata(outputs:Dict, labels:pd.DataFrame):
    """Align parsed generations with the rows of the label table.

    Returns a list with one entry per row of ``labels``, in row order. Each entry is the
    value stored under that row's ``identifier`` in ``outputs``, or a new empty dict when
    ``outputs`` has no such key.

    NOTE(review): the lookup uses exact key equality, so the keys of ``outputs`` must have
    the same type as ``labels["identifier"]`` (e.g. str vs int). Confirm against the data.
    """
    aligned = []
    for identifier in labels["identifier"]:
        aligned.append(outputs[identifier] if identifier in outputs else {})
    return aligned
    

In [114]:
#| export
def get_label_substring_metadata(outputs:List[Dict]):
    """Build label→substring and substring→derived-phrase sparse matrices.

    Args:
        outputs: one parsed generation per label, in label order (as returned by
            ``get_label_metadata``). Each is a dict that may contain a ``"substring"`` list
            of ``{"original_substring": str, "derived_phrases": [str, ...]}`` entries.
            Empty entries are skipped.

    Returns:
        lbl_sub: ``(n_labels, n_substrings)`` CSR matrix with a 1 for every substring
            mentioned by a label.
        sub_info: DataFrame with ``identifier`` (column index in ``lbl_sub``) and ``text``.
        sub_phs: ``(n_substrings, n_phrases)`` CSR matrix linking each substring to the
            derived phrases generated for it. Phrases are pooled over all labels that
            mention the substring.
        phs_info: DataFrame with ``identifier`` (column index in ``sub_phs``) and ``text``.

    Raises:
        AssertionError: if a ``"derived_phrases"`` entry contains a non-string.

    NOTE(review): repeated substrings/phrases produce duplicate (row, col) entries, which
    scipy sums to values > 1 on conversion. Confirm whether binary matrices are intended.
    """
    # --- label -> substring ---
    phrases = dict()
    vocab, data, indices, indptr = dict(), [], [], [0]
    for output in tqdm(outputs):
        for o in output.get("substring", []):
            if len(o):
                idx = vocab.setdefault(o["original_substring"], len(vocab))
                data.append(1)
                indices.append(idx)
                assert all(isinstance(i, str) for i in o["derived_phrases"]), '"derived_phrases" should be list of string'
                phrases.setdefault(idx, []).extend(o["derived_phrases"])
        indptr.append(len(data))

    # Explicit shape: inferring it from max(indices)+1 fails when no label has a substring.
    lbl_sub = sp.csr_matrix((data, indices, indptr), shape=(len(indptr) - 1, len(vocab)))
    sub_info = pd.DataFrame([(v,k) for k,v in vocab.items()], columns=["identifier", "text"])

    # --- substring -> derived phrase ---
    n_sub = sub_info.shape[0]
    vocab, data, indices, indptr = dict(), [], [], [0]
    for i in range(n_sub):
        for p in phrases[i]:
            idx = vocab.setdefault(p, len(vocab))
            data.append(1)
            indices.append(idx)
        indptr.append(len(data))

    # Explicit shape again: substrings may exist while every derived-phrase list is empty.
    sub_phs = sp.csr_matrix((data, indices, indptr), shape=(n_sub, len(vocab)))
    phs_info = pd.DataFrame([(v,k) for k,v in vocab.items()], columns=["identifier", "text"])

    return lbl_sub, sub_info, sub_phs, phs_info
    

In [115]:
#| export
def get_query_substring_metadata(outputs:List[Dict], sub_info:pd.DataFrame):
    """Build query-side matrices from the ``"queries"`` section of each generation.

    Args:
        outputs: parsed generations (one per label). Each is a dict that may contain a
            ``"queries"`` list of ``{"primary_query": str, "derived_queries": [str, ...],
            "answer": [str, ...]}`` entries.
        sub_info: substring table (``identifier``, ``text``) from
            ``get_label_substring_metadata``. Answer strings not already in it are
            appended as new substrings.

    Returns:
        data_sub: ``(n_queries, n_substrings)`` CSR matrix linking each primary query to
            its answer substrings. The column count always equals ``len(sub_info)`` of the
            returned table.
        data_info: DataFrame of primary queries (``identifier`` = row of ``data_sub``).
        sub_info: the (possibly extended) substring table.
        data_der: ``(n_queries, n_derived)`` CSR matrix linking each primary query to its
            derived queries.
        der_info: DataFrame of derived queries (``identifier`` = column of ``data_der``).
    """
    # substring text -> column index; unseen answers get the next free index.
    sub_vocab = dict(zip(sub_info["text"], sub_info.index))

    # --- query -> substring ---
    queries, derived_queries = [], []
    data, indices, indptr = [], [], [0]
    for output in tqdm(outputs):
        for o in output.get("queries", []):
            queries.append(o["primary_query"])
            idxs = [sub_vocab.setdefault(a, len(sub_vocab)) for a in o["answer"]]
            data.extend([1] * len(idxs))
            indices.extend(idxs)
            indptr.append(len(data))
            derived_queries.append(o["derived_queries"])

    # Explicit shape: inferring it from max(indices)+1 drops trailing substrings no query
    # answers (leaving fewer columns than sub_info rows) and fails when there are no answers.
    data_sub = sp.csr_matrix((data, indices, indptr), shape=(len(queries), len(sub_vocab)))
    data_info = pd.DataFrame(list(enumerate(queries)), columns=["identifier", "text"])
    sub_info = pd.DataFrame([(v,k) for k,v in sub_vocab.items()], columns=["identifier", "text"])

    # --- query -> derived queries ---
    der_vocab, data, indices, indptr = dict(), [], [], [0]
    for query_derivations in derived_queries:
        for q in query_derivations:
            idx = der_vocab.setdefault(q, len(der_vocab))
            data.append(1)
            indices.append(idx)
        indptr.append(len(data))

    data_der = sp.csr_matrix((data, indices, indptr), shape=(len(derived_queries), len(der_vocab)))
    der_info = pd.DataFrame([(v,k) for k,v in der_vocab.items()], columns=["identifier", "text"])

    return data_sub, data_info, sub_info, data_der, der_info
    

In [116]:
#| export
def is_valid_output(v:Dict) -> bool:
    """Check that a parsed generation has the schema the metadata builders expect.

    A valid output is a dict containing both keys:

    * ``"substring"``: a list of entries. Each entry is either empty (ignored) or a dict
      with a string ``"original_substring"`` and a ``"derived_phrases"`` list of strings.
    * ``"queries"``: a list of dicts. Each has a string ``"primary_query"`` plus
      ``"derived_queries"`` and ``"answer"`` lists of strings.

    Malformed shapes make the function return ``False`` instead of raising. This covers
    ``None`` values, non-dict entries, and scalars or dicts where a list is expected. The
    caller can then fall back to another parser.
    """
    def _is_str_list(x) -> bool:
        return isinstance(x, list) and all(isinstance(i, str) for i in x)

    if not isinstance(v, dict) or "substring" not in v or "queries" not in v:
        return False

    substrings, queries = v["substring"], v["queries"]
    if not isinstance(substrings, (list, tuple)) or not isinstance(queries, (list, tuple)):
        return False

    for o in substrings:
        # Empty entries are tolerated; get_label_substring_metadata skips them as well.
        if isinstance(o, (dict, list, tuple, str)) and len(o) == 0:
            continue
        if not isinstance(o, dict):
            return False
        if not isinstance(o.get("original_substring"), str):
            return False
        if not _is_str_list(o.get("derived_phrases")):
            return False

    for o in queries:
        if not isinstance(o, dict):
            return False
        if not isinstance(o.get("primary_query"), str):
            return False
        if not (_is_str_list(o.get("derived_queries")) and _is_str_list(o.get("answer"))):
            return False

    return True
    

In [117]:
#| export
def convert_string_to_objects(outputs:Dict) -> Dict:
    """Parse raw generation strings into Python objects, repairing common syntax errors.

    For every ``doc_id -> text`` pair in ``outputs`` (empty texts are skipped):

    1. ``json_repair.loads`` parses the text, and a ``"substrings"`` key is renamed to
       ``"substring"``. If ``is_valid_output`` accepts the result, it is kept.
    2. Otherwise ``ast.literal_eval`` is tried on the text. On success, that result is kept.
    3. If ``literal_eval`` raises, its message is matched against known error patterns.
       The matching fixer rewrites the reported line of the text, and the loop restarts
       from step 1. On the 11th failed ``literal_eval`` attempt the entry is counted as
       invalid and dropped.

    Skipped and dropped entries are absent from the returned dict. The share of invalid
    entries (relative to ``len(outputs)``, empty ones included) is printed when non-zero.

    Raises:
        ValueError: ``"Keyword error: <doc_id>"`` when a ``literal_eval`` error message
            matches none of the known patterns. This includes non-syntax errors such as
            the ``ValueError`` it raises for malformed nodes. It aborts the whole call.

    NOTE(review): the ``literal_eval`` result is stored without passing through
    ``is_valid_output`` or the ``"substrings"`` rename, so a parseable but malformed output
    can still reach downstream code.
    NOTE(review): if ``json_repair.loads`` returns a str/list, ``"substrings" in o`` is a
    substring/membership test, and ``o.pop`` would raise on a str.
    """
    
    def get_line_number(pat:str, e:str) -> int:
        # Return the 0-based index of the line number captured by ``pat``'s first group
        # in the exception message.
        m = re.match(pat, str(e))
        return int(m.group(1)) - 1
        
    out, n_invalid = dict(), 0
    for k,v in tqdm(outputs.items()):
        error_flag, n_tries = True, 0

        if len(v) == 0: continue
        
        # Repeat until the text parses (or is given up on); ``v`` is edited in place each round.
        while error_flag:
            o = json_repair.loads(v)

            # Normalise the plural key variant to the one the rest of the pipeline reads.
            if "substrings" in o:
                o["substring"] = o.pop("substrings")
            
            if not is_valid_output(o):
                try:
                    out[k] = ast.literal_eval(v)
                    error_flag = False
                except Exception as e:
                    # patterns[i] is repaired by functions[i]; each captures the offending line number.
                    patterns = [
                        r"closing parenthesis '\)' does not match opening parenthesis '\[.*line (\d+)",
                        r"unterminated string literal \(detected at line \d+\).*line (\d+)\)",
                        r"invalid character '’' \(U\+2019\) .*line (\d+)\)",
                        r"invalid syntax .*line (\d+)\)",
                        r"closing parenthesis '\]' does not match opening parenthesis '\{' on line \d+ .*line (\d+)\)",
                    ]
                    
                    # NOTE(review): func_1, func_2 and func_4 index ``line[-1]`` and raise
                    # IndexError if the reported line is empty.
                    def func_1(line:str):
                        # ')' closing a '[': drop every ')' if the line already ends with ']',
                        # otherwise turn every ')' into ']'.
                        if line[-1] == "]":
                            line = line.replace(")", "")
                        else:
                            line = line.replace(")", "]")
                        return line

                    def func_2(line:str):
                        # Unterminated string: normalise curly double quotes, then close a
                        # string left open before a trailing ']' or a trailing '*/'.
                        line = line.replace('”', '"').replace('“', '"')
                        if line[-1] == "]" and line[-2].strip() != '"':
                            line = line[:-1] + '"]'
                        if line[-2:] == "*/":
                            line = line[:-2] + '"]'
                        return line
                        
                    def func_3(line:str):
                        # Replace the right single quotation mark (U+2019) with an ASCII apostrophe.
                        return line.replace("’", "'")

                    def func_4(line:str):
                        # Generic invalid syntax: drop a trailing ';'.
                        if line[-1] == ";": line = line[:-1]
                        return line

                    def func_5(line:str):
                        # ']' closing a '{': insert the missing '}' before every ']' on the line.
                        return line.replace("]", "} ]")

                    functions = [func_1, func_2, func_3, func_4, func_5]

                    # Give up on this entry once more than 10 repairs have been attempted.
                    n_tries += 1
                    if n_tries > 10:
                        n_invalid += 1
                        break
                        
                    # Apply the first matching fixer to the reported line and retry.
                    for pat, func in zip(patterns, functions):
                        if re.match(pat, str(e)):
                            idx = get_line_number(pat, e)
                            lines = v.split("\n") 
                            lines[idx] = func(lines[idx])
                            v = "\n".join(lines)
                            break
                    else:
                        # No known pattern matched: fail loudly so new error kinds get handled.
                        raise ValueError(f"Keyword error: {k}")
                        
            else:
                out[k] = o
                error_flag = False

    if n_invalid > 0: print(f"Invalid outputs: {n_invalid/len(outputs):.4f}")
        
    return out
    

In [118]:
#| export
def save_data(save_dir:str, dtype:str, data_sub:sp.csr_matrix, data_info:pd.DataFrame, sub_info:pd.DataFrame, 
              lbl_sub:sp.csr_matrix, sub_phs:sp.csr_matrix, phs_info:pd.DataFrame, data_der:sp.csr_matrix, 
              der_info:pd.DataFrame):
    """Write one query type's substring metadata to ``save_dir``.

    ``dtype`` must be ``"simple-query"`` or ``"multihop-query"``, whose short hands
    ``<sh>`` are ``sq`` / ``mq``. Files written:

    * ``raw_data/<dtype>.raw.csv`` - primary queries (``data_info``)
    * ``raw_data/<sh>-substring.raw.csv`` - substrings (``sub_info``)
    * ``raw_data/<sh>-derived-queries.raw.csv`` - derived queries (``der_info``)
    * ``raw_data/<sh>-substring_<sh>-derived-phrases.raw.csv`` - derived phrases (``phs_info``)
    * ``<dtype>_<sh>-substring.npz`` - query x substring (``data_sub``)
    * ``<dtype>_<sh>-derived-queries.npz`` - query x derived query (``data_der``)
    * ``lbl_<sh>-substring.npz`` - label x substring (``lbl_sub``)
    * ``<sh>-substring_<sh>-derived-phrases.npz`` - substring x derived phrase (``sub_phs``)

    Before saving, each matrix is padded with empty rows/columns so that every axis indexed
    by an info table has exactly that table's length. Padding works on copies, so the
    caller's matrices are left unchanged.

    Raises:
        AssertionError: for an unknown ``dtype``, or when a matrix is already larger than
            the info table it must match (padding never shrinks a matrix).
    """
    short_hand = {"simple-query": "sq", "multihop-query": "mq"}

    assert dtype in short_hand, f"Invalid data-type: {dtype}."
    sh = short_hand[dtype]

    def _pad(mat:sp.csr_matrix, n_rows:int, n_cols:int) -> sp.csr_matrix:
        # Grow a copy of ``mat`` to (n_rows, n_cols); shrinking would silently drop entries.
        assert mat.shape[0] <= n_rows and mat.shape[1] <= n_cols, \
            f"Cannot shrink matrix of shape {mat.shape} to {(n_rows, n_cols)}."
        padded = mat.copy()
        padded.resize((n_rows, n_cols))
        return padded

    n_sub = sub_info.shape[0]
    data_sub = _pad(data_sub, data_sub.shape[0], n_sub)
    lbl_sub = _pad(lbl_sub, lbl_sub.shape[0], n_sub)
    sub_phs = _pad(sub_phs, n_sub, phs_info.shape[0])
    data_der = _pad(data_der, data_der.shape[0], der_info.shape[0])

    os.makedirs(f"{save_dir}/raw_data/", exist_ok=True)

    data_info.to_csv(f"{save_dir}/raw_data/{dtype}.raw.csv", index=False)
    sub_info.to_csv(f"{save_dir}/raw_data/{sh}-substring.raw.csv", index=False)
    der_info.to_csv(f"{save_dir}/raw_data/{sh}-derived-queries.raw.csv", index=False)
    phs_info.to_csv(f"{save_dir}/raw_data/{sh}-substring_{sh}-derived-phrases.raw.csv", index=False)

    sp.save_npz(f"{save_dir}/{dtype}_{sh}-substring.npz", data_sub)
    sp.save_npz(f"{save_dir}/{dtype}_{sh}-derived-queries.npz", data_der)
    sp.save_npz(f"{save_dir}/lbl_{sh}-substring.npz", lbl_sub)
    sp.save_npz(f"{save_dir}/{sh}-substring_{sh}-derived-phrases.npz", sub_phs)
    

## Extract substring metadata

In [119]:
#| export
def proc_dataset(dataset:str, lbl_dir:str, data_dir:str, save_dir:Optional[str]=None):
    """Build and save substring metadata for one dataset, for both query types.

    Reads the label table from ``{lbl_dir}/{dataset}/XC/raw_data/label.raw.csv``. For each of
    ``"simple"`` and ``"multihop"`` it then:

    1. loads ``{data_dir}/{dataset}_{type}_label.jsonl``;
    2. parses and aligns the generations with the labels;
    3. builds the label/substring/phrase and query/substring/derived-query matrices;
    4. saves them with ``save_data`` as ``"{type}-query"``.

    Output goes to ``save_dir``, or to ``{lbl_dir}/{dataset}/XC/document_substring/`` when
    ``save_dir`` is None.
    """
    labels = pd.read_csv(f"{lbl_dir}/{dataset}/XC/raw_data/label.raw.csv")
    out_dir = f"{lbl_dir}/{dataset}/XC/document_substring/" if save_dir is None else save_dir

    for query_type in ("simple", "multihop"):
        raw_outputs = load_gpt_outputs(f"{data_dir}/{dataset}_{query_type}_label.jsonl")
        label_outputs = get_label_metadata(convert_string_to_objects(raw_outputs), labels)

        lbl_sub, sub_info, sub_phs, phs_info = get_label_substring_metadata(label_outputs)
        data_sub, data_info, sub_info, data_der, der_info = get_query_substring_metadata(label_outputs, sub_info)

        save_data(out_dir, f"{query_type}-query", data_sub, data_info, sub_info, lbl_sub, sub_phs, phs_info,
                  data_der, der_info)
        

In [5]:
#| export
def main(lbl_dir:str, data_dir:str, save_dir:Optional[str]=None):
    """Run ``proc_dataset`` for every dataset in ``BEIR_DATASETS``.

    Args:
        lbl_dir: root containing ``<dataset>/XC/raw_data/label.raw.csv`` for each dataset.
        data_dir: directory containing the ``<dataset>_<simple|multihop>_label.jsonl`` files.
        save_dir: optional output root. When given, each dataset is written to
            ``<save_dir>/<dataset>``. When omitted, ``proc_dataset`` falls back to
            ``<lbl_dir>/<dataset>/XC/document_substring/``.
    """
    for dataset in tqdm(BEIR_DATASETS):
        # save_data uses the same file names for every dataset, so a single shared
        # save_dir would let each dataset overwrite the previous one.
        dataset_save_dir = None if save_dir is None else os.path.join(save_dir, dataset)
        proc_dataset(dataset, lbl_dir, data_dir, dataset_save_dir)
    

In [None]:
#| export
if __name__ == "__main__":
    # NOTE: in a notebook kernel `__name__` is normally "__main__" as well, so running
    # this cell processes every dataset.
    # The defaults are the original author's local paths; override them via environment
    # variables instead of editing the code.
    lbl_dir = os.environ.get("BEIR_LBL_DIR", "/Users/suchith720/Projects/data/beir/")
    data_dir = os.environ.get("BEIR_GPT_OUTPUT_DIR", "/Users/suchith720/Downloads/")
    main(lbl_dir, data_dir, os.environ.get("BEIR_SAVE_DIR"))
    

In [7]:
# Sanity check: the multihop ("mq") derived-queries table written for arguana.
# The BEIR root can be overridden with the BEIR_LBL_DIR environment variable.
fname = os.path.join(os.environ.get("BEIR_LBL_DIR", "/Users/suchith720/Projects/data/beir/"),
                     "arguana/XC/document_substring/raw_data/mq-derived-queries.raw.csv")

In [8]:
# Load the saved table (columns: `identifier`, `text`).
df = pd.read_csv(filepath_or_buffer=fname)

In [9]:
# Peek at the first five rows.
df.head(n=5)

Unnamed: 0,identifier,text
0,0,"According to British farmer Simon Farrell, wha..."
1,1,What figure did Simon Farrell state as the UN'...
2,2,Simon Farrell challenges UN data; what percent...
3,3,"In disputing UN's livestock carbon estimate, w..."
4,4,What global carbon emissions percentage is att...
