In [2]:
#| default_exp 28_beir-intent-metadata

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
#| export
import os, json, pandas as pd, ast, scipy.sparse as sp, json_repair, numpy as np, re

from tqdm.auto import tqdm
from typing import Optional, Dict, List
from xcai.misc import BEIR_DATASETS
from sugar.core import *

## `Helper functions`

In [4]:
#| export
def combine_df(dirname:str, dtype:str):
    """Concatenate every tab-separated file in `dirname` whose name contains `dtype`.

    Args:
        dirname: Directory to scan (non-recursive).
        dtype: Either "single" or "multihop"; a file is selected when this
            string occurs anywhere in its file name.

    Returns:
        A single DataFrame with the rows of all selected files, stacked in
        sorted file-name order. The per-file index is kept as-is, so index
        labels can repeat across files.

    Raises:
        AssertionError: if `dtype` is not one of the two supported values.
        ValueError: if no file in `dirname` matches `dtype`.
    """
    assert dtype in ["single", "multihop"]
    # `os.listdir` returns entries in arbitrary order; sorting makes the row
    # order of the combined frame reproducible across runs and machines.
    fnames = sorted(fname for fname in os.listdir(dirname) if dtype in fname)
    if not fnames:
        raise ValueError(f"No '{dtype}' files found in {dirname}")
    return pd.concat([pd.read_table(os.path.join(dirname, fname)) for fname in fnames])
    

In [5]:
#| export
def convert_string_to_object(output:str):
    """Parse a raw model response string into a Python object.

    Surrounding whitespace, a leading "```python" fence and a trailing "```"
    fence are removed first. The remainder is parsed with `ast.literal_eval`;
    if that fails for any reason, parsing falls back to `json_repair.loads`.

    Args:
        output: Raw text of the model response.

    Returns:
        Whatever object the text evaluates to (or what `json_repair.loads`
        returns on the fallback path).
    """
    try:
        # Whitespace around the fences (e.g. a trailing newline after the
        # closing ```) would otherwise stop the fence checks from matching.
        output = output.strip()
        if output[:9] == "```python": output = output[9:]
        if output[-3:] == "```": output = output[:-3]
        output = output.strip()
        return ast.literal_eval(output)
    except Exception:
        # `Exception` rather than a bare except, so KeyboardInterrupt and
        # SystemExit are not swallowed by the fallback.
        return json_repair.loads(output)
    

In [6]:
#| export
def convert_df_into_label_dict(df:pd.DataFrame):
    """Index the parsed model responses of `df` by label id.

    Reads the columns "id", "document" and "raw_model_response" of every row.

    Args:
        df: Frame with one model response per row.

    Returns:
        Dict mapping `str(id)` to a dict with keys "label_id", "label_text"
        and "substring" (the parsed response).

    Raises:
        ValueError: if a response cannot be parsed; the parser's exception is
            attached as `__cause__`.
        AssertionError: if the same id occurs in more than one row.
    """
    outputs = {}
    # `i` is the row's index label, not necessarily its position in `df`.
    for i,row in tqdm(df.iterrows(), total=df.shape[0]):
        lbl_id, lbl_txt = str(row["id"]), row["document"]
        try:
            substring = convert_string_to_object(row["raw_model_response"])
        except Exception as e:
            # `Exception` (not a bare except) so Ctrl-C is not re-labelled as
            # a parsing failure; chain the cause to keep the real error visible.
            raise ValueError(f"Invalid model response at row #{i}") from e

        assert lbl_id not in outputs, f"Duplicate label id: {lbl_id}"

        outputs[lbl_id] = {
            "label_id": lbl_id,
            "label_text": lbl_txt,
            "substring": substring,
        }
    return outputs
    

In [30]:
#| export
def get_metadata_matrix_from_label_dict(outputs:Dict, lbl_ids:List, seed:Optional[int]=100):
    """Build label-to-intent and query-to-intent sparse matrices from parsed generations.

    For every id in `lbl_ids`, `outputs[id]["substring"]` is read (ids missing
    from `outputs` contribute an empty label row). That object may hold two
    lists:

    * "substring": items with "intent_phrases", "derived_phrases" and a key
      whose name contains both "original" and "substring". One intent phrase
      is sampled per item and becomes a column of the label row.
    * "queries": items with "derived_queries", "primary_query", "answer" and
      optionally "intent_phrases". One derived query is sampled per item and
      becomes a row of the query matrix; each answer is mapped to the intent
      previously sampled for that substring, or, when unseen, to a sampled
      intent phrase (if the item has "intent_phrases") or to the answer text.

    Args:
        outputs: Mapping from label id to a dict holding a "substring" entry.
        lbl_ids: Label ids, in the row order of the label matrix.
        seed: Seed passed to `np.random.seed`. Note this reseeds NumPy's
            global RNG as a side effect.

    Returns:
        Tuple `(vocab, lbl_intent, qry_text, qry_intent, intent_info, lbl_info)`:
        `vocab` maps intent text to column index; `lbl_intent` and
        `qry_intent` are float32 CSR matrices that both have `len(vocab)`
        columns; `qry_text` holds the sampled query for each `qry_intent`
        row; `intent_info` maps column index to collected "phrases",
        "intents" and "substr"; `lbl_info` maps label id to the sampled
        "query" list and to all "queries".

    Raises:
        ValueError: if a "substring" item has no original-substring key.
    """
    np.random.seed(seed)

    intent_info = {
        "phrases": dict(),
        "intents": dict(),
        "substr": dict(),
    }

    lbl_info = {
        "query": dict(),
        "queries": dict(),
    }

    substr2intent = dict()

    lbl_mat = {
        "vocab": dict(),
        "data": [],
        "indices": [],
        "indptr": [0],
    }

    qry_mat = {
        "text": [],
        "data": [],
        "indices": [],
        "indptr": [0],
    }

    def get_original_substring_key(gen):
        # The key name varies between generations; accept any key that
        # mentions both "original" and "substring".
        for k in gen:
            if "original" in k and "substring" in k:
                return k
        raise ValueError("`original_substring` absent.")

    for l in tqdm(lbl_ids):
        o = outputs.get(l, {})
        generations = o.get("substring", {})

        for gen in generations.get("substring", []):
            intent = str(np.random.choice(gen["intent_phrases"]))
            substr = gen[get_original_substring_key(gen)]
            substr2intent[substr] = intent

            idx = lbl_mat["vocab"].setdefault(intent, len(lbl_mat["vocab"]))
            lbl_mat["indices"].append(idx)
            lbl_mat["data"].append(1.0)

            intent_info["phrases"].setdefault(idx, []).extend(gen["derived_phrases"])
            intent_info["intents"].setdefault(idx, []).extend(gen["intent_phrases"])
            intent_info["substr"].setdefault(idx, []).append(substr)

        lbl_mat["indptr"].append(len(lbl_mat["indices"]))

        for gen in generations.get("queries", []):
            text = str(np.random.choice(gen["derived_queries"]))

            qry_mat["text"].append(text)

            has_intents = "intent_phrases" in gen
            indices = []
            for ans in gen["answer"]:
                # The fallback intent is sampled for every answer, even when
                # the answer is already known, so the random stream does not
                # depend on which answers were seen before.
                fallback = str(np.random.choice(gen["intent_phrases"])) if has_intents else ans
                key = substr2intent.get(ans, fallback)
                # Every intent stored in `substr2intent` is already in the
                # vocab, so `setdefault` only adds entries for unseen answers.
                indices.append(lbl_mat["vocab"].setdefault(key, len(lbl_mat["vocab"])))

            qry_mat["data"].extend([1.0] * len(indices))
            qry_mat["indices"].extend(indices)
            qry_mat["indptr"].append(len(qry_mat["indices"]))

            lbl_info["query"].setdefault(l, []).append(text)
            lbl_info["queries"].setdefault(l, []).extend(gen["derived_queries"] + [gen["primary_query"]])

    # Pass the shape explicitly: inferring it from the largest stored index
    # gives the two matrices different widths whenever queries add intents,
    # and fails outright when no index was collected at all.
    n_intents = len(lbl_mat["vocab"])

    def build(mat):
        indptr = np.asarray(mat["indptr"], dtype=np.int64)
        indices = np.asarray(mat["indices"], dtype=np.int64)
        return sp.csr_matrix((mat["data"], indices, indptr), shape=(len(indptr) - 1, n_intents), dtype=np.float32)

    lbl_intent = build(lbl_mat)
    qry_intent = build(qry_mat)

    return lbl_mat["vocab"], lbl_intent, qry_mat["text"], qry_intent, intent_info, lbl_info
    

In [5]:
#| export
def get_matrix_from_dict(metadata:Dict, ids:List):
    """Turn a mapping of id -> list of items into a binary CSR matrix.

    Items are assigned column indices in order of first appearance. Each
    occurrence of an item adds one stored entry of 1.0, so an item repeated
    within a list is stored more than once in that row.

    Args:
        metadata: Mapping from id to an iterable of hashable items.
        ids: Ids to emit, in row order. An id with no entry in `metadata`
            yields an empty row.

    Returns:
        Tuple `(matrix, vocab)`: a float32 CSR matrix with one row per id and
        `len(vocab)` columns, and the dict mapping item to column index.
    """
    vocab, indices, indptr = dict(), [], [0]
    for i in tqdm(ids):
        try:
            items = metadata[i]
        except KeyError:
            # Not every id necessarily has metadata; keep its row empty.
            items = ()
        indices.extend(vocab.setdefault(o, len(vocab)) for o in items)
        indptr.append(len(indices))
    data = [1.0] * len(indices)
    # An explicit shape keeps construction working when nothing was collected;
    # scipy cannot infer the column count from an empty index array.
    return sp.csr_matrix(
        (data, np.asarray(indices, dtype=np.int64), np.asarray(indptr, dtype=np.int64)),
        shape=(len(indptr) - 1, len(vocab)),
        dtype=np.float32,
    ), vocab
    

## `Driver`

In [None]:
#| export
if __name__ == "__main__":
    # --- Parse the single-hop generations of the first configured dataset ---
    datasets = ["msmarco"]
    data_dir = "/Users/suchith720/Downloads/"
    dataset = datasets[0]

    dirname = f"{data_dir}/{dataset}"
    df = combine_df(dirname, "single")
    outputs = convert_df_into_label_dict(df)

    # --- Load the label ids; `data_dir` now points at the project root ---
    data_dir = "/Users/suchith720/Projects/"
    lbl_ids, lbl_txt = load_raw_file(f"{data_dir}/data/beir/msmarco/XC/raw_data/label.raw.txt")

    # --- Label/query -> intent matrices plus the per-intent bookkeeping ---
    o = get_metadata_matrix_from_label_dict(outputs, lbl_ids)
    (intent_vocab, lbl_intent, qry_text,
     qry_intent, intent_info, lbl_info) = o

    # Intent texts ordered by their column index in the vocab.
    intent_txt = sorted(intent_vocab, key=intent_vocab.get)
    intent_ids = [*range(len(intent_txt))]

    # --- Intent -> derived-phrase matrix ---
    intent_phrases, phrase_vocab = get_matrix_from_dict(intent_info["phrases"], intent_ids)
    

In [8]:
# Datasets to process; the cells below only use the first entry.
datasets = ["msmarco"]

In [9]:
# Root directory containing one sub-directory of model-response tables per dataset.
# NOTE(review): machine-specific absolute path; it is reassigned further down.
data_dir = "/Users/suchith720/Downloads/"

In [10]:
# Binds `dataset` to the first element of `datasets` and stops.
# If `datasets` is empty, `dataset` is left unassigned (or stale from an earlier run).
for dataset in datasets:
    break
    

In [11]:
# Combine every file in the dataset directory whose name contains "single".
dirname = f"{data_dir}/{dataset}"
df = combine_df(dirname, "single")

In [12]:
# Parse each row's raw model response and key the result by str(id).
outputs = convert_df_into_label_dict(df)

  0%|          | 0/287762 [00:00<?, ?it/s]

In [13]:
# `data_dir` is re-pointed from the downloads folder to the project root here.
# `load_raw_file` comes from the star import of `sugar.core`; presumably it
# returns the label ids and label texts of the raw file — TODO confirm.
data_dir = "/Users/suchith720/Projects/"
lbl_ids, lbl_txt = load_raw_file(f"{data_dir}/data/beir/msmarco/XC/raw_data/label.raw.txt")

In [None]:
# Build the intent matrices for all label ids with the default seed (100).
o = get_metadata_matrix_from_label_dict(outputs, lbl_ids)

  0%|          | 0/8841823 [00:00<?, ?it/s]

In [17]:
# Unpack the 6-tuple returned by get_metadata_matrix_from_label_dict.
intent_vocab, lbl_intent, qry_text, qry_intent, intent_info, lbl_info = o

In [18]:
# Intent texts ordered by their vocab index, and the matching list of indices.
intent_txt = sorted(intent_vocab, key=lambda x:intent_vocab[x])
intent_ids = list(range(len(intent_txt)))

In [6]:
# Matrix of intents x derived phrases, built from the phrases collected per intent index.
intent_phrases, phrase_vocab = get_matrix_from_dict(intent_info["phrases"], intent_ids)