In [2]:
#| default_exp 28_beir-intent-metadata

In [3]:
%load_ext autoreload
%autoreload 2

In [145]:
#| export
import os, json, pandas as pd, ast, scipy.sparse as sp, json_repair, numpy as np, re

from tqdm.auto import tqdm
from typing import Optional, Dict, List, Tuple

from xcai.misc import BEIR_DATASETS
from sugar.core import *

## `Helper functions`

In [136]:
#| export
def save_raw(fname:str, ids:List, txt:List):
    """Write parallel id/text lists to `fname` as a CSV with columns `identifier`, `text` (no index column)."""
    table = pd.DataFrame(dict(identifier=ids, text=txt))
    table.to_csv(fname, index=False)

def load_raw(fname:str):
    """Read a CSV with `identifier`/`text` columns back into `(identifiers, texts)` lists.

    NA detection is disabled, so strings such as "NA" or empty fields are kept verbatim as text.
    """
    frame = pd.read_csv(fname, keep_default_na=False, na_filter=False)
    identifiers, texts = frame["identifier"], frame["text"]
    return identifiers.tolist(), texts.tolist()
    

In [9]:
#| export
def combine_df(dirname:str, dtype:str):
    """Concatenate every tab-separated file in `dirname` whose name contains `dtype`.

    Args:
        dirname: directory holding the model-output tables.
        dtype: generation family to load; must be "single" or "multihop".

    Returns:
        One DataFrame with a fresh 0..n-1 index.

    Raises:
        AssertionError: if `dtype` is not a known family.
        ValueError: if no file in `dirname` matches `dtype`.
    """
    assert dtype in ["single", "multihop"], f"Unknown dtype: {dtype!r}"
    # Sort so the concatenation order does not depend on the filesystem's listing order.
    fnames = sorted(fname for fname in os.listdir(dirname) if dtype in fname)
    if not fnames:
        raise ValueError(f"No `{dtype}` files found in {dirname!r}")
    # ignore_index gives every row a unique label, so row numbers reported downstream are unambiguous.
    return pd.concat([pd.read_table(os.path.join(dirname, fname)) for fname in fnames], ignore_index=True)
    

In [10]:
#| export
def convert_string_to_object(output:str):
    """Parse a raw model response into a Python object.

    Surrounding whitespace and Markdown code fences (```` ```python ````, ```` ```json ````, bare ```` ``` ````)
    are removed first. The text is then parsed as a Python literal; if that fails, it is
    handed to `json_repair.loads` as a lenient JSON fallback.

    Args:
        output: raw response text.

    Returns:
        The parsed object (whatever `ast.literal_eval` or `json_repair.loads` yields).
    """
    text = output.strip()
    # Opening fence with an optional language tag, then the closing fence.
    text = re.sub(r"^```[A-Za-z0-9_-]*", "", text)
    text = re.sub(r"```$", "", text).strip()
    try:
        return ast.literal_eval(text)
    except Exception:
        # Not a valid Python literal (e.g. JSON `true`/`null`, truncated output): repair as JSON.
        return json_repair.loads(text)
    

In [11]:
#| export
def convert_df_into_label_dict(df:pd.DataFrame):
    """Parse every row's raw model response and index the results by label id.

    Args:
        df: table with at least `id`, `document` and `raw_model_response` columns.

    Returns:
        dict mapping `str(id)` -> {"label_id", "label_text", "substring"}, where "substring"
        is the object parsed from `raw_model_response` by `convert_string_to_object`.

    Raises:
        ValueError: if a response cannot be parsed (the parsing error is chained as the cause).
        AssertionError: if the same label id appears on more than one row.
    """
    outputs = {}
    for i, row in tqdm(df.iterrows(), total=df.shape[0]):
        lbl_id, lbl_txt = str(row["id"]), row["document"]
        try:
            substring = convert_string_to_object(row["raw_model_response"])
        except Exception as e:
            raise ValueError(f"Invalid model response at row #{i}") from e

        assert lbl_id not in outputs, f"Duplicate label id {lbl_id!r} at row #{i}"

        outputs[lbl_id] = {
            "label_id": lbl_id,
            "label_text": lbl_txt,
            "substring": substring,
        }
    return outputs
    

In [115]:
#| export
def get_metadata_matrix_from_label_dict(outputs:Dict, lbl_ids:List, seed:Optional[int]=100):
    """Build label x intent and query x intent sparse matrices from parsed generations.

    For each label id in `lbl_ids` (in order), `outputs[lbl_id]["substring"]` is read as a dict
    with optional "substring" and "queries" lists:

    * each "substring" entry contributes one intent, drawn at random from its "intent_phrases",
      and records which original substring produced it;
    * each "queries" entry contributes one query row. The row's text is drawn at random from
      "derived_queries", and its intents are resolved from its "answer" substrings.

    Labels missing from `outputs` still get an empty row, so row r of the label matrix always
    corresponds to `lbl_ids[r]`.

    Args:
        outputs: label id -> record, as produced by `convert_df_into_label_dict`.
        lbl_ids: label ids defining the row order of the label matrix.
        seed: seed for the random phrase/query choices (None draws fresh OS entropy).

    Returns:
        (intent vocab {intent text: column}, label x intent CSR (float32), query texts,
         query x intent CSR (float32),
         intent_info {"phrases" | "intents" | "substr": {column: [...]}},
         lbl_info {"query" | "queries": {label id: [...]}}).
    """
    # A private legacy RandomState seeded with `seed` yields exactly the draws the former
    # `np.random.seed(seed)` + `np.random.choice` calls did, without clobbering the global RNG.
    rng = np.random.RandomState(seed)

    intent_info = {
        "phrases": dict(),  # intent column -> derived phrases
        "intents": dict(),  # intent column -> candidate intent phrases
        "substr": dict(),   # intent column -> original substrings
    }

    lbl_info = {
        "query": dict(),    # label id -> sampled query texts
        "queries": dict(),  # label id -> all derived queries plus the primary query
    }

    # Original substring -> intent chosen for it. Not reset per label, so a query may resolve
    # an answer substring that was generated for an earlier label.
    substr2intent = dict()

    lbl_mat = {
        "vocab": dict(),
        "data": [],
        "indices": [],
        "indptr": [0],
    }

    qry_mat = {
        "text": [],
        "data": [],
        "indices": [],
        "indptr": [0],
    }

    # Intent text -> column index; shared by the label and the query matrix.
    vocab = lbl_mat["vocab"]

    def get_original_substring_key(gen):
        # The key name varies between generations; take the first one mentioning both words.
        for k in gen:
            if "original" in k and "substring" in k:
                return k
        raise ValueError("`original_substring` absent.")

    for l in tqdm(lbl_ids):
        record = outputs.get(l, {})
        generations = record.get("substring", {})

        for gen in generations.get("substring", []):
            intent = str(rng.choice(gen["intent_phrases"]))
            substr = gen[get_original_substring_key(gen)]
            substr2intent[substr] = intent

            idx = vocab.setdefault(intent, len(vocab))
            lbl_mat["indices"].append(idx)
            lbl_mat["data"].append(1.0)

            intent_info["phrases"].setdefault(idx, []).extend(gen["derived_phrases"])
            intent_info["intents"].setdefault(idx, []).extend(gen["intent_phrases"])
            intent_info["substr"].setdefault(idx, []).append(substr)

        # Close the row even when the label had no generations, keeping rows aligned with lbl_ids.
        lbl_mat["indptr"].append(len(lbl_mat["indices"]))

        for gen in generations.get("queries", []):
            text = str(rng.choice(gen["derived_queries"]))

            qry_mat["text"].append(text)

            if "intent_phrases" in gen:
                # The fallback intent is drawn for *every* answer (argument evaluation is eager),
                # even when the answer is already mapped; kept so seeded runs stay reproducible.
                indices = [
                    vocab.setdefault(substr2intent.get(a, str(rng.choice(gen["intent_phrases"]))), len(vocab))
                    for a in gen["answer"]
                ]
            else:
                try:
                    indices = [vocab[substr2intent[a]] for a in gen["answer"]]
                except KeyError:
                    # Some answer was never produced as a substring: use the answer text itself
                    # as its intent, registering it in the vocabulary if needed.
                    indices = [vocab.setdefault(substr2intent.get(a, a), len(vocab)) for a in gen["answer"]]

            qry_mat["data"].extend([1.0] * len(indices))
            qry_mat["indices"].extend(indices)
            qry_mat["indptr"].append(len(qry_mat["indices"]))

            lbl_info["query"].setdefault(l, []).append(text)
            lbl_info["queries"].setdefault(l, []).extend(gen["derived_queries"] + [gen["primary_query"]])

    shape = (len(lbl_mat["indptr"]) - 1, len(vocab))
    lbl_intent = sp.csr_matrix((lbl_mat["data"], lbl_mat["indices"], lbl_mat["indptr"]), shape=shape, dtype=np.float32)

    shape = (len(qry_mat["text"]), len(vocab))
    qry_intent = sp.csr_matrix((qry_mat["data"], qry_mat["indices"], qry_mat["indptr"]), shape=shape, dtype=np.float32)

    return lbl_mat["vocab"], lbl_intent, qry_mat["text"], qry_intent, intent_info, lbl_info
    

In [116]:
#| export
def get_matrix_from_dict(metadata:Dict, ids:List):
    """Turn an id -> list-of-items mapping into a binary CSR matrix plus an item vocabulary.

    Row r corresponds to `ids[r]`. Ids missing from `metadata` yield empty rows. Columns are
    assigned to items in order of first appearance. An item repeated within one row is stored
    as duplicate entries, which scipy treats as summed.

    Returns:
        (float32 CSR matrix of shape (len(ids), len(vocab)), vocab {item: column}).
    """
    vocab, data, indices, indptr = dict(), [], [], [0]
    for i in tqdm(ids):
        items = metadata.get(i, [])
        indices.extend(vocab.setdefault(o, len(vocab)) for o in items)
        data.extend([1.0] * len(items))
        indptr.append(len(indices))
    # Explicit shape: scipy cannot infer dimensions when no row has any entry.
    shape = (len(ids), len(vocab))
    return sp.csr_matrix((data, indices, indptr), shape=shape, dtype=np.float32), vocab
    

In [147]:
#| export
def extract_metadata(dirname:str, dtype:str, lbl_ids:List):
    """Parse one generation family end-to-end and assemble every matrix derived from it.

    Args:
        dirname: directory of tab-separated model-output files (read via `combine_df`).
        dtype: "single" or "multihop".
        lbl_ids: label ids fixing the row order of all label-indexed matrices.

    Returns:
        (intent_ids, intent_txt, label x intent CSR),
        (qry_ids, qry_txt, query x intent CSR),
        ((intent x derived-phrase, vocab), (intent x intent-phrase, vocab), (intent x substring, vocab)),
        ((label x sampled query, vocab), (label x derived/primary query, vocab)).
    """
    label_dict = convert_df_into_label_dict(combine_df(dirname, dtype))

    (intent_vocab, lbl_intent, qry_txt, qry_intent,
     intent_meta, lbl_meta) = get_metadata_matrix_from_label_dict(label_dict, lbl_ids)

    qry_ids = list(range(len(qry_txt)))

    # Invert the vocabulary: entry k is the intent text assigned to column k.
    intent_txt = sorted(intent_vocab, key=intent_vocab.get)
    intent_ids = list(range(len(intent_txt)))

    for n_cols in (lbl_intent.shape[1], qry_intent.shape[1]):
        assert len(intent_ids) == n_cols, (
            f"Intent count mismatch: expected {n_cols}, got {len(intent_ids)}"
        )

    intent_side = tuple(
        get_matrix_from_dict(intent_meta[key], intent_ids) for key in ("phrases", "intents", "substr")
    )
    label_side = tuple(
        get_matrix_from_dict(lbl_meta[key], lbl_ids) for key in ("query", "queries")
    )

    return (intent_ids, intent_txt, lbl_intent), (qry_ids, qry_txt, qry_intent), intent_side, label_side
    

In [148]:
#| export
def save_metadata(save_dir:str, intent_ids:List, intent_txt:List, lbl_intent:sp.csr_matrix, 
                  qry_ids:List, qry_txt:List, qry_intent:sp.csr_matrix):
    """Persist the label/query x intent matrices and their raw texts under `save_dir`.

    Both matrices are canonicalized *in place* (duplicate entries summed, indices sorted)
    before writing, so the caller's objects are modified.

    Writes:
        intent_qry_X_Y.npz + raw_data/query.raw.csv          (query matrix, query texts)
        intent_lbl_X_Y.npz + raw_data/label_intent.raw.csv   (label matrix, intent texts)
    """
    for matrix in (lbl_intent, qry_intent):
        matrix.sum_duplicates()
        matrix.sort_indices()

    for directory in (save_dir, f"{save_dir}/raw_data/"):
        os.makedirs(directory, exist_ok=True)

    sp.save_npz(f"{save_dir}/intent_qry_X_Y.npz", qry_intent)
    save_raw(f"{save_dir}/raw_data/query.raw.csv", qry_ids, qry_txt)

    sp.save_npz(f"{save_dir}/intent_lbl_X_Y.npz", lbl_intent)
    save_raw(f"{save_dir}/raw_data/label_intent.raw.csv", intent_ids, intent_txt)
    

In [149]:
#| export
def save_metadata_info(save_dir:str, intent_info:Tuple, lbl_info:Tuple):
    """Save the auxiliary intent- and label-side matrices (as returned by `extract_metadata`).

    Each matrix is canonicalized in place (duplicates summed, indices sorted) and written as an
    `.npz` under `save_dir`. Next to it, `raw_data/<name>.raw.csv` lists the texts of *that
    matrix's own* vocabulary in column order.

    Args:
        save_dir: output directory (created if missing).
        intent_info: ((intent_phrase, phrase_vocab), (intent_intents, intents_vocab), (intent_substr, substr_vocab)).
        lbl_info: ((lbl_query, query_vocab), (lbl_queries, queries_vocab)).
    """
    (intent_phrase, phrase_vocab), (intent_intents, intents_vocab), (intent_substr, substr_vocab) = intent_info
    (lbl_query, query_vocab), (lbl_queries, queries_vocab) = lbl_info

    # (matrix, its column vocabulary, npz file name, raw csv file name)
    artifacts = [
        (intent_phrase, phrase_vocab, "derived-phrases_intent_X_Y.npz", "intent_derived-phrases.raw.csv"),
        (intent_intents, intents_vocab, "derived-intents_intent_X_Y.npz", "intent_derived-intents.raw.csv"),
        (intent_substr, substr_vocab, "substring_intent_X_Y.npz", "intent_substring.raw.csv"),
        (lbl_query, query_vocab, "qry_lbl_X_Y.npz", "label_query.raw.csv"),
        (lbl_queries, queries_vocab, "derived-queries_lbl_X_Y.npz", "label_derived-queries.raw.csv"),
    ]

    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(f"{save_dir}/raw_data/", exist_ok=True)

    for mat, vocab, npz_name, raw_name in artifacts:
        mat.sum_duplicates()
        mat.sort_indices()
        # Column j holds the item whose vocab index is j. Use *this* matrix's vocab (the old
        # helper always used phrase_vocab, writing the wrong texts for four of five files).
        txt = sorted(vocab, key=vocab.get)
        sp.save_npz(f"{save_dir}/{npz_name}", mat)
        save_raw(f"{save_dir}/raw_data/{raw_name}", list(range(len(txt))), txt)
    

## `Driver`

In [None]:
#| export
if __name__ == "__main__":
    # Roots can be overridden from the environment; the defaults are the original hard-coded paths.
    output_dir = os.environ.get("BEIR_INTENT_OUTPUT_DIR", "/Users/suchith720/Downloads/")
    beir_dir = os.environ.get("BEIR_DATA_DIR", "/Users/suchith720/Projects/data/beir")

    for dataset in tqdm(BEIR_DATASETS):
        print(dataset)

        data_dir = f"{beir_dir}/{dataset}/XC/"
        # NOTE(review): `load_raw_file` is not defined in this notebook; presumably it is provided
        # by the `sugar.core` star import -- confirm.
        lbl_ids, lbl_txt = load_raw_file(f"{data_dir}/raw_data/label.raw.txt")

        # Model outputs for this dataset; each generation family is written to its own sub-directory.
        dirname = f"{output_dir}/{dataset}"
        for dtype in ["single", "multihop"]:
            o = extract_metadata(dirname, dtype, lbl_ids)
            (intent_ids, intent_txt, lbl_intent), (qry_ids, qry_txt, qry_intent), intent_info, lbl_info = o

            save_dir = f"{data_dir}/document_intent_substring/{dtype}"
            save_metadata(save_dir, intent_ids, intent_txt, lbl_intent, qry_ids, qry_txt, qry_intent)
            save_metadata_info(save_dir, intent_info, lbl_info)
            

In [14]:
datasets = ["msmarco"]

In [101]:
output_dir = "/Users/suchith720/Downloads/"

In [140]:
for dataset in datasets:
    break
    

In [142]:
data_dir = f"/Users/suchith720/Projects/data/beir/{dataset}/XC/"

In [143]:
lbl_ids, lbl_txt = load_raw_file(f"{data_dir}/raw_data/label.raw.txt")

In [144]:
dirname = f"{output_dir}/{dataset}"

In [119]:
o = extract_metadata(dirname, "single", lbl_ids)

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/8841823 [00:00<?, ?it/s]

  0%|          | 0/15764 [00:00<?, ?it/s]

  0%|          | 0/15764 [00:00<?, ?it/s]

  0%|          | 0/15764 [00:00<?, ?it/s]

  0%|          | 0/8841823 [00:00<?, ?it/s]

  0%|          | 0/8841823 [00:00<?, ?it/s]

In [120]:
(intent_ids, intent_txt, lbl_intent), (qry_ids, qry_txt, qry_intent), intent_info, lbl_info = o

In [137]:
save_dir = f"/Users/suchith720/Projects/data/beir/{dataset}/XC/document_intent_substring"

In [138]:
save_metadata(save_dir, intent_ids, intent_txt, lbl_intent, qry_ids, qry_txt, qry_intent)

In [139]:
save_metadata_info(save_dir, intent_info, lbl_info)