# `Conflation`

In [6]:
#| default_exp conflate

In [7]:
#| export
import numpy as np, pandas as pd, scipy.sparse as sp, torch, json, os, re
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from typing import Optional, Union, Dict, List
from IPython.display import display
from xclib.utils.sparse import retain_topk
from collections import Counter

## Helper functions

In [8]:
#| export
def display_similar_items(matrix:sp.csr_matrix, n_data:Optional[int]=10, n_view_data:Optional[int]=20, 
                          seed:Optional[int]=1000, meta_info:Optional[Dict]=None):
    """Sample rows of a similarity matrix and list each row's highest-scoring neighbours.

    Args:
        matrix: CSR matrix whose row `i` holds scores from item `i` to the items in its columns.
        n_data: number of rows to sample (all rows when `None`).
        n_view_data: maximum number of neighbours reported per row, highest score first
            (all stored entries when `None`).
        seed: seed for the row-sampling permutation.
        meta_info: mapping/DataFrame whose "text" column is indexed by item id. When `None`,
            a global `meta_info` in this function's namespace is used (the notebook workflow).

    Returns:
        List of dicts with keys "Substring" (text of the sampled row) and "Predictions"
        (list of `(text, score)` tuples in descending score order).

    Raises:
        ValueError: if no `meta_info` is given and no global one exists.
    """
    if meta_info is None:
        meta_info = globals().get("meta_info")
        if meta_info is None:
            raise ValueError("`meta_info` must be provided (no global `meta_info` found).")

    # A private RandomState yields the same permutation as `np.random.seed(seed)` followed by
    # `np.random.permutation`, without resetting the caller's global RNG state.
    rnd_idx = np.random.RandomState(seed).permutation(matrix.shape[0])[:n_data]

    outputs = list()
    for idx in rnd_idx:
        row = matrix[idx]  # slice the CSR row once instead of three times
        # Descending order, then keep the first `n_view_data`; the old `[:-n_view_data:-1]`
        # slice returned one entry fewer than requested.
        sort_idx = np.argsort(row.data)[::-1][:n_view_data]
        scores = row.data[sort_idx]
        labels = [meta_info['text'][i] for i in row.indices[sort_idx]]
        outputs.append({
            "Substring": meta_info['text'][idx],
            "Predictions": [(x, float(y)) for x, y in zip(labels, scores)],
        })
    return outputs
    

In [67]:
#| export
def get_conflated_text(clusters:List, meta_info:Dict):
    """Return the member texts of every cluster, largest cluster first.

    Args:
        clusters: list of clusters, each an iterable of item ids.
        meta_info: mapping/DataFrame whose "text" column is indexed by item id.

    Returns:
        One list of texts per cluster, ordered by decreasing cluster size
        (ties ordered as numpy's reversed `argsort` leaves them).
    """
    sizes = [len(cluster) for cluster in clusters]
    order = np.argsort(sizes)[::-1]
    return [[meta_info["text"][item] for item in clusters[pos]] for pos in order]
        

In [12]:
#| export
class Operations:
    """Sparse-matrix helpers used to build and prune label graphs for conflation."""

    @staticmethod
    def compute_yty(mat:sp.csr_matrix, bsz:Optional[int]=1000, normalize=True):
        """Compute `mat @ mat.T` in row batches, optionally normalizing by row/column sums.

        Args:
            mat: sparse matrix of shape (n, d).
            bsz: rows of `mat` multiplied per batch (bounds the size of each partial product).
            normalize: when True (default), entry (i, j) of the product is scaled by
                1 / row_sum_i and 1 / col_sum_j of the unnormalized product (zero sums scale to 0).
                When False, the raw product is returned.

        Returns:
            CSR matrix of shape (n, n).
        """
        if mat.shape[0] == 0:
            # `sp.vstack` rejects an empty list of blocks.
            return sp.csr_matrix((0, 0))
        mat_t = mat.transpose()
        out = sp.vstack([mat[i:i+bsz]@mat_t for i in tqdm(range(0, mat.shape[0], bsz))])
        if not normalize:
            return out.tocsr()

        def invert(a):
            # Cast first: for an integer `a`, `np.divide(..., out=np.zeros_like(a))` would have to
            # store float results into an int buffer and raises instead.
            a = np.asarray(a, dtype=np.float64)
            return np.divide(1.0, a, out=np.zeros_like(a), where=a!=0)

        # Explicit column/row shapes keep broadcasting correct whether `.sum` returns a 2-D
        # matrix (sparse matrices) or a 1-D array (sparse arrays).
        row_scale = invert(out.sum(axis=1)).reshape(-1, 1)
        col_scale = invert(out.sum(axis=0)).reshape(1, -1)
        return out.multiply(row_scale).multiply(col_scale).tocsr()

    @staticmethod
    def minimum(matrix:sp.csr_matrix):
        """Return the smallest stored value of each CSR row.

        Rows without stored entries yield `np.nan` rather than raising from `.min()` on an
        empty slice.
        """
        return np.array([matrix.data[i:j].min() if j > i else np.nan
                         for i,j in zip(matrix.indptr, matrix.indptr[1:])])

    @staticmethod
    def min_clamp(matrix:sp.csr_matrix, value:float, inplace:Optional[bool]=True):
        """Remove every stored entry `<= value`, keeping only strictly larger scores.

        Args:
            matrix: CSR matrix to prune.
            value: threshold; entries at or below it are dropped.
            inplace: when False, a copy is pruned and `matrix` is left untouched.

        Returns:
            The pruned matrix (the same object as `matrix` when `inplace`).
        """
        if not inplace:
            matrix = matrix.copy()
        matrix.data[matrix.data <= value] = 0
        matrix.eliminate_zeros()
        return matrix

    @staticmethod
    def max_clamp(matrix:sp.csr_matrix, value:float, inplace:Optional[bool]=True):
        """Remove every stored entry `>= value`, keeping only strictly smaller scores.

        Args:
            matrix: CSR matrix to prune.
            value: threshold; entries at or above it are dropped.
            inplace: when False, a copy is pruned and `matrix` is left untouched.

        Returns:
            The pruned matrix (the same object as `matrix` when `inplace`).
        """
        if not inplace:
            matrix = matrix.copy()
        matrix.data[matrix.data >= value] = 0
        matrix.eliminate_zeros()
        return matrix

    @staticmethod
    def diff_threshold(matrix:sp.csr_matrix, value:float, inplace:Optional[bool]=True):
        """Remove entries lying more than `value` below the maximum of their row.

        Args:
            matrix: CSR matrix to prune.
            value: allowed gap below the row maximum.
            inplace: when False, a copy is pruned and `matrix` is left untouched
                (same convention as `min_clamp` / `max_clamp`).

        Returns:
            The pruned matrix (the same object as `matrix` when `inplace`).
        """
        if not inplace:
            matrix = matrix.copy()
        for i,j in zip(matrix.indptr, matrix.indptr[1:]):
            if j > i:
                data = matrix.data[i:j]  # a view: assignments below update `matrix.data` directly
                data[data.max() - data > value] = 0.0
        matrix.eliminate_zeros()
        return matrix

    @staticmethod
    def get_clusters(groups:np.array, lbl_idx:Optional[np.array]=None):
        """Group `lbl_idx` by the cluster id at the same position in `groups`.

        Args:
            groups: cluster id per position (e.g. connected-component labels).
            lbl_idx: item id per position; defaults to `0..len(groups)-1`.

        Returns:
            List of clusters (lists of item ids), in order of each cluster's first appearance.
        """
        groups = np.asarray(groups)
        lbl_idx = np.arange(groups.shape[0]) if lbl_idx is None else np.asarray(lbl_idx)
        assert groups.shape == lbl_idx.shape, "`groups` and `lbl_idx` must have the same shape."
        clusters = dict()
        for k,v in zip(groups, lbl_idx): 
            clusters.setdefault(k, []).append(v)
        return list(clusters.values())
    

In [71]:
#| export
class Filter:
    """Select or drop clusters according to their size."""

    @staticmethod
    def from_max_size(clusters:List, size:int):
        """Return the clusters with strictly fewer than `size` members, in input order."""
        return list(filter(lambda cluster: len(cluster) < size, clusters))

    @staticmethod
    def from_min_size(clusters:List, size:int):
        """Return the clusters with strictly more than `size` members, in input order."""
        return list(filter(lambda cluster: len(cluster) > size, clusters))

    @staticmethod
    def remove_top_clusters(clusters:List, topk:int):
        """Drop the `topk` largest clusters; the rest are returned largest first."""
        descending = np.argsort([len(cluster) for cluster in clusters])[::-1]
        return [clusters[pos] for pos in descending[topk:]]
        

## Load data

In [19]:
# Dataset root (machine-specific absolute path; adjust for your environment).
data_dir = "/Users/suchith720/Projects/data/beir/msmarco/XC/substring/"

# Train/test sparse matrices and the raw text table for the substrings.
trn_file = f"{data_dir}/substring_trn_X_Y.npz"
tst_file = f"{data_dir}/substring_tst_X_Y.npz"
info_file = f"{data_dir}/raw_data/substring.raw.csv"

# NOTE(review): presumably the columns of trn_meta/tst_meta index the rows of meta_info — confirm.
trn_meta, tst_meta = sp.load_npz(trn_file), sp.load_npz(tst_file)
meta_info = pd.read_csv(info_file)

# Matrix consumed by the "Derived phrases" conflation cell further down.
meta_phrases = sp.load_npz(f"{data_dir}/derived-phrases_substring_X_Y.npz")

In [20]:
# Directory holding the trained linker model's prediction files (machine-specific path).
model_dir = "/Users/suchith720/Downloads/00_msmarco-gpt-concept-substring-linker-with-ngame-loss-001/"

In [21]:
# Show the first rows of the raw text table without truncating long text cells.
with pd.option_context("display.max_colwidth", None):
    display(meta_info.head())

Unnamed: 0,identifier,text
0,0,success of the Manhattan Project
1,1,hundreds of thousands of innocent lives obliterated
2,2,Restorative justice
3,3,amber
4,4,shades of yellow


## `Semantic:` Connected components

In [13]:
# Model-produced label predictions; overwritten by the next cell.
meta_mat = sp.load_npz(f"{model_dir}/test_predictions_labels.npz")

In [22]:
# Scores from the msmarco-distilbert-cos-v5 encoder (presumably substring-substring similarity,
# judging by the printed neighbours); this is the `meta_mat` used by the cells below.
meta_mat = sp.load_npz(f"{model_dir}/test_predictions_labels_msmarco-distilbert-cos-v5.npz")

In [23]:
# Spot-check: top neighbours of 10 random rows; texts come from the global `meta_info`.
print(json.dumps(display_similar_items(meta_mat, n_data=10, n_view_data=20, seed=1000), indent=4))

[
    {
        "Substring": "fascial aches, pains, tension and restrictions",
        "Predictions": [
            [
                "fascial aches, pains, tension and restrictions",
                1.0
            ],
            [
                "Muscle pain and aches",
                0.6356141567230225
            ],
            [
                "muscle aches and pain",
                0.6307234168052673
            ],
            [
                "pain and muscle tension",
                0.6281501650810242
            ],
            [
                "Muscle Aches and Pains",
                0.6267786622047424
            ],
            [
                "muscle aches and pains",
                0.6267786622047424
            ],
            [
                "Muscle aches and pains",
                0.6267786622047424
            ],
            [
                "muscle aches and soreness",
                0.6230722665786743
            ],
            [
                "Aches,

In [61]:
def from_similarity(lbl_lbl:sp.csr_matrix, score_thresh:Optional[float]=0.3, diff_thresh:Optional[float]=0.2, 
                    sim_topk:Optional[int]=20):
    """Cluster labels as connected components of a pruned label-label similarity graph.

    Each row is first reduced with xclib's `retain_topk(k=sim_topk)`. Entries `<= score_thresh`
    are then dropped, as are entries more than `diff_thresh` below their row maximum. The
    components of the remaining graph (scipy's default: directed, weak connectivity) are
    returned as clusters of label indices.
    """
    pruned = Operations.diff_threshold(
        Operations.min_clamp(retain_topk(lbl_lbl, k=sim_topk), score_thresh),
        diff_thresh,
    )
    _, component_of = sp.csgraph.connected_components(pruned)
    return Operations.get_clusters(component_of)
    

In [63]:
# Similarity-based conflation with stricter pruning than the function defaults.
clusters = from_similarity(meta_mat, score_thresh=0.5, diff_thresh=0.2, sim_topk=10)

In [68]:
# Cluster texts, largest cluster first.
texts = get_conflated_text(clusters, meta_info)

In [72]:
print(f"Number of clusters: {len(clusters)}")

Number of clusters: 987959


In [70]:
# Inspect the 11th-largest cluster (`texts` is sorted by decreasing cluster size).
texts[10]

['typical lease',
 'lease agreement',
 'Rental Agreements',
 'parties to the agreement',
 'terms, conditions, and specifications of the job contract',
 'binding decision',
 'binding agreement',
 'existence of a contractual relationship',
 'legally binding',
 'binding contract in the foreign exchange market',
 'contracts',
 'Lease/Option',
 'Lease or Rent with the Option to Buy',
 'written contract',
 'agreement creating obligations enforceable by law',
 'term of a contract',
 'lease',
 'terms of an agreement',
 'legally binding arrangements',
 'term of the contract',
 'agreement among the States',
 'in agreement with the terms in the contract',
 'written legal contract',
 'lease agreements',
 'binding ruling',
 'deal',
 'contract',
 'in order to be legally binding',
 'as specified in the contract',
 'employment agreement',
 'contract between employer and employee',
 'three-party agreement',
 'legally binding agreement',
 'between an employer and an employee',
 'lease with an option to 

## `Predictions`

In [124]:
model_dir = "/Users/suchith720/Downloads/00_msmarco-gpt-concept-substring-linker-with-ngame-loss-001/"
# NOTE(review): `data_meta` is not used by any visible cell; presumably the input for `from_predictions`.
data_meta = sp.load_npz(f"{model_dir}/train_predictions.npz")

In [73]:
def from_predictions(data_lbl:sp.csr_matrix, score_thresh:Optional[float]=0.3, diff_thresh:Optional[float]=0.2, 
                     pred_topk:Optional[int]=3):
    """Cluster labels that are predicted for the same data points.

    Each row of `data_lbl` is pruned in three steps. It is first reduced with
    `retain_topk(k=pred_topk)`. Entries `<= score_thresh` are then dropped, and finally entries
    more than `diff_thresh` below the row maximum. Labels are linked through the normalized
    co-occurrence matrix `compute_yty(data_lbl.T)`, and each connected component is a cluster.
    """
    kept = Operations.diff_threshold(
        Operations.min_clamp(retain_topk(data_lbl, k=pred_topk), score_thresh),
        diff_thresh,
    )
    co_occurrence = Operations.compute_yty(kept.transpose())
    _, component_of = sp.csgraph.connected_components(co_occurrence)
    return Operations.get_clusters(component_of)
    

## `Derived phrases`

In [74]:
def from_derived_phrases(lbl_phrases:sp.csr_matrix):
    """Cluster rows of `lbl_phrases` that share a non-zero column.

    The row-row graph is `compute_yty(lbl_phrases)`; each connected component becomes one cluster.
    Presumably rows are labels and columns derived phrases — confirm against the data.
    """
    label_graph = Operations.compute_yty(lbl_phrases)
    component_of = sp.csgraph.connected_components(label_graph)[1]
    return Operations.get_clusters(component_of)
    

In [76]:
# Conflate substrings through their shared derived phrases.
clusters = from_derived_phrases(meta_phrases)

  0%|          | 0/1658 [00:00<?, ?it/s]

In [78]:
# Cluster texts, largest cluster first.
texts = get_conflated_text(clusters, meta_info)

In [85]:
# Number of clusters produced.
len(texts)

1365625

## `Conflation`

In [87]:
#| export
class Conflate:
    """Strategies that group labels into clusters ("conflation")."""

    @staticmethod
    def _prune(mat, score_thresh, diff_thresh, topk):
        """Apply `retain_topk`, then drop entries `<= score_thresh` or > `diff_thresh` below the row max."""
        mat = retain_topk(mat, k=topk)
        mat = Operations.min_clamp(mat, score_thresh)
        return Operations.diff_threshold(mat, diff_thresh)

    @staticmethod
    def _components(adjacency):
        """Return the connected components of the graph `adjacency` as clusters of indices."""
        _, component_of = sp.csgraph.connected_components(adjacency)
        return Operations.get_clusters(component_of)

    @staticmethod
    def from_similarity(lbl_lbl:sp.csr_matrix, score_thresh:Optional[float]=0.3, diff_thresh:Optional[float]=0.2, 
                        sim_topk:Optional[int]=20):
        """Cluster labels as connected components of the pruned similarity graph `lbl_lbl`."""
        return Conflate._components(Conflate._prune(lbl_lbl, score_thresh, diff_thresh, sim_topk))

    @staticmethod
    def from_predictions(data_lbl:sp.csr_matrix, score_thresh:Optional[float]=0.3, diff_thresh:Optional[float]=0.2, 
                         pred_topk:Optional[int]=3):
        """Cluster labels linked by the normalized co-occurrence of their pruned predictions."""
        kept = Conflate._prune(data_lbl, score_thresh, diff_thresh, pred_topk)
        return Conflate._components(Operations.compute_yty(kept.transpose()))

    @staticmethod
    def from_derived_phrases(lbl_phrases:sp.csr_matrix):
        """Cluster rows of `lbl_phrases` that share a non-zero column."""
        return Conflate._components(Operations.compute_yty(lbl_phrases))
    

## `Save data`

In [88]:
#| export
class SaveData:
    """Write a conflated copy of a dataset: remapped label matrices plus the raw label table."""

    @staticmethod
    def get_groups(clusters:List, n_lbl:int):
        """Map every original label to its conflated label id.

        Labels in `clusters[i]` map to `i` (a label listed in several clusters keeps the last one).
        Labels in no cluster get fresh ids `len(clusters), len(clusters)+1, ...` in increasing
        original-label order.
        """
        groups = np.full(n_lbl, -1)
        for i,idx in enumerate(clusters):
            groups[idx] = i
        idxs = np.where(groups == -1)[0]
        groups[idxs] = np.arange(len(clusters), len(clusters) + idxs.shape[0])
        return groups

    @staticmethod
    def conflate_matrix(matrix:sp.csr_matrix, groups:np.array, n_lbl:Optional[int]=None):
        """Re-index the columns of `matrix` from original labels to conflated labels.

        Args:
            matrix: sparse matrix whose columns are original labels (converted to CSR).
            groups: original label -> conflated label, as returned by `get_groups`.
            n_lbl: number of output columns. Defaults to `groups.max() + 1` (all conflated labels),
                so conflated labels without entries in `matrix` are not cut off the shape.

        Returns:
            CSR matrix of shape (matrix.shape[0], n_lbl) in canonical form. Values of original
            labels that fall into the same conflated label within one row are summed.
        """
        matrix = sp.csr_matrix(matrix)
        groups = np.asarray(groups)
        if n_lbl is None:
            n_lbl = int(groups.max()) + 1 if groups.size else 0
        # Copies: `sum_duplicates` below works in place and must not touch the input's arrays.
        out = sp.csr_matrix((matrix.data.copy(), groups[matrix.indices], matrix.indptr.copy()),
                            shape=(matrix.shape[0], n_lbl))
        # Merging duplicates explicitly matches how scipy treats them implicitly, but stores a
        # canonical matrix. NOTE(review): binary label matrices can end up with values > 1 —
        # confirm whether a clip/max is preferred.
        out.sum_duplicates()
        return out

    @staticmethod
    def get_conflated_info(clusters:List, info:Dict):
        """Build the identifier/text table for conflated labels, aligned with `get_groups` ids.

        Row i (for i < len(clusters)) describes a member of `clusters[i]` picked with
        `np.random.choice`; the remaining rows are the unclustered labels in increasing order.

        Returns:
            Dict with list values under "identifier" and "text" (always present, possibly empty).
        """
        covered = np.zeros(len(info["identifier"]), dtype=bool)
        representatives = []
        for idxs in clusters:
            representatives.append(np.random.choice(idxs))
            covered[idxs] = True
        representatives.extend(np.where(~covered)[0])

        conflated_info = {"identifier": [], "text": []}
        for idx in representatives:
            conflated_info["identifier"].append(info["identifier"][idx])
            conflated_info["text"].append(info["text"][idx])
        return conflated_info

    @staticmethod
    def save_raw(fname:str, ids:List, txt:List):
        """Write an "identifier,text" CSV (no index column) to `fname`."""
        df = pd.DataFrame({"identifier": ids, "text": txt})
        df.to_csv(fname, index=False)

    @staticmethod
    def proc(save_dir:str, trn_lbl:sp.csr_matrix, lbl_info:Dict, clusters:List, tst_lbl:Optional[sp.csr_matrix]=None,
             trn_fname:str="substring_trn_X_Y.npz", tst_fname:str="substring_tst_X_Y.npz",
             info_fname:str="substring.raw.csv"):
        """Save a conflated dataset into a new `save_dir/conflation_NN` directory.

        Args:
            save_dir: parent directory (created if missing).
            trn_lbl: training matrix whose columns are the original labels.
            lbl_info: label table with "identifier" and "text" columns.
            clusters: clusters of original label indices to merge.
            tst_lbl: optional test matrix over the same original labels; skipped when `None`.
            trn_fname / tst_fname / info_fname: output file names. `info_fname` is written
                under `raw_data/`.

        Returns:
            Path of the directory that was written.
        """
        groups = SaveData.get_groups(clusters, n_lbl=trn_lbl.shape[1])
        conflated_trn = SaveData.conflate_matrix(trn_lbl, groups)
        conflated_tst = (None if tst_lbl is None else
                         SaveData.conflate_matrix(tst_lbl, groups, n_lbl=conflated_trn.shape[1]))
        conflated_info = SaveData.get_conflated_info(clusters, lbl_info)

        os.makedirs(save_dir, exist_ok=True)
        # One past the highest existing run number, so gaps left by deleted runs never cause an
        # existing directory to be reused.
        pattern = re.compile(r'conflation_([0-9]+)')
        matches = (pattern.fullmatch(o) for o in os.listdir(save_dir))
        n = max((int(m.group(1)) for m in matches if m), default=0) + 1
        out_dir = f"{save_dir}/conflation_{n:02d}"

        raw_dir = f"{out_dir}/raw_data"
        os.makedirs(raw_dir, exist_ok=True)

        sp.save_npz(f"{out_dir}/{trn_fname}", conflated_trn)
        if conflated_tst is not None:
            sp.save_npz(f"{out_dir}/{tst_fname}", conflated_tst)

        SaveData.save_raw(f"{raw_dir}/{info_fname}", conflated_info["identifier"], conflated_info["text"])
        return out_dir
        