# `Conflation`

In [89]:
#| default_exp conflate

In [175]:
#| export
import numpy as np, pandas as pd, scipy.sparse as sp, torch, json, os, re
import matplotlib.pyplot as plt
from IPython.display import display

from tqdm.auto import tqdm
from typing import Optional, Union, Dict, List
from IPython.display import display
from xclib.utils.sparse import retain_topk
from collections import Counter

## Helper functions

In [166]:
#| export
def display_items(matrix:sp.csr_matrix, info:Dict, n_data:Optional[int]=10, n_view_data:Optional[int]=20, 
                          seed:Optional[int]=1000):
    np.random.seed(seed)
    rnd_idx = np.random.permutation(matrix.shape[0])[:n_data]
    
    outputs = list()
    for idx in rnd_idx:
        sort_idx = np.argsort(matrix[idx].data)[:-n_view_data:-1]
        scores = matrix[idx].data[sort_idx]
        indices = matrix[idx].indices[sort_idx]
        labels = [info['text'][i] for i in indices]
        o = {
            "Substring": info['text'][idx],
            "Predictions": [(x, float(y)) for x,y in zip(labels, scores)]
        }
        outputs.append(o)
    return outputs
    

In [169]:
#| export
def compare_items(output_1, output_2):
    rows, index = [], []
    for x,y in zip(linker_output, pretrn_output):
        assert x["Substring"] == y["Substring"]
    
        index.extend([(x["Substring"], i+1) for i in range(len(x["Predictions"]))])
        rows.extend([(a[0], b[0]) for a,b in zip(x["Predictions"], y["Predictions"])])
        
    df = pd.DataFrame(
        rows,
        index=pd.MultiIndex.from_tuples(
            index,
            names=["Query", "Index"],
        ),
        columns=["Output 1", "Output 2"],
    )
    return df
    

In [218]:
#| export
def get_conflated_text(clusters:List, meta_info:Dict):
    texts = [[meta_info["text"][i] for i in c] for c in clusters]
    idx = np.argsort([len(c) for c in clusters])[::-1]
    return [texts[i] for i in idx]
        

In [219]:
#| export
class Operations:

    @staticmethod
    def compute_yty(mat:sp.csr_matrix, bsz:Optional[int]=1000, normalize=True):
        mat_t = mat.transpose()
        out = sp.vstack([mat[i:i+bsz]@mat_t for i in tqdm(range(0, mat.shape[0], bsz))])
        def invert(a):
            return np.divide(1.0, a, out=np.zeros_like(a), where=a!=0)
        return out.multiply(invert(out.sum(axis=1))).multiply(invert(out.sum(axis=0))).tocsr()

    @staticmethod
    def minimum(matrix:sp.csr_matrix):
        return np.array([matrix.data[i:j].min() for i,j in zip(matrix.indptr, matrix.indptr[1:])])

    @staticmethod
    def min_clamp(matrix:sp.csr_matrix, value:float, inplace:Optional[bool]=True):
        if not inplace:
            matrix = matrix.copy()
        matrix.data[matrix.data <= value] = 0
        matrix.eliminate_zeros()
        return matrix

    @staticmethod
    def max_clamp(matrix:sp.csr_matrix, value:float, inplace:Optional[bool]=True):
        if not inplace:
            matrix = matrix.copy()
        matrix.data[matrix.data >= value] = 0
        matrix.eliminate_zeros()
        return matrix

    @staticmethod
    def diff_threshold(matrix:sp.csr_matrix, value:float):
        for i,j in zip(matrix.indptr, matrix.indptr[1:]):
            data = matrix.data[i:j]
            if len(data):
                idx = np.where(data.max() - data > value)[0]
                data[idx] = 0.0
                matrix.data[i:j] = data
        matrix.eliminate_zeros()
        return matrix

    @staticmethod
    def get_clusters(groups:np.array, lbl_idx:Optional[np.array]=None):
        lbl_idx = np.arange(groups.shape[0]) if lbl_idx is None else lbl_idx
        assert groups.shape == lbl_idx.shape, "`groups` and `lbl_idx` must have the same shape."
        clusters = dict()
        for k,v in zip(groups, lbl_idx): 
            clusters.setdefault(k, []).append(v)
        clusters = list(clusters.values())
        
        cluster_sz = [len(c) for c in clusters]
        sort_idx = np.argsort(cluster_sz)[::-1]
        return [clusters[i] for i in sort_idx]
    

In [220]:
#| export
class Filter:

    @staticmethod
    def from_max_size(clusters:List, size:int):
        return [c for c in clusters if len(c) < size]

    @staticmethod
    def from_min_size(clusters:List, size:int):
        return [c for c in clusters if len(c) > size]

    @staticmethod
    def remove_top_clusters(clusters:List, topk:int):
        sort_idx = np.argsort([len(c) for c in clusters])[-topk-1::-1]
        return [clusters[i] for i in sort_idx]
        

## Load data

In [95]:
data_dir = "/Users/suchith720/Projects/data/beir/msmarco/XC/substring/"

trn_file = f"{data_dir}/substring_trn_X_Y.npz"
tst_file = f"{data_dir}/substring_tst_X_Y.npz"
info_file = f"{data_dir}/raw_data/substring.raw.csv"

trn_meta, tst_meta = sp.load_npz(trn_file), sp.load_npz(tst_file)
meta_info = pd.read_csv(info_file)

meta_phrases = sp.load_npz(f"{data_dir}/derived-phrases_substring_X_Y.npz")

In [109]:
linker_dir = "/Users/suchith720/Downloads/00_msmarco-gpt-concept-substring-linker-with-ngame-loss-001/"

In [100]:
with pd.option_context("display.max_colwidth", None):
    display(meta_info.head())

Unnamed: 0,identifier,text
0,0,success of the Manhattan Project
1,1,hundreds of thousands of innocent lives obliterated
2,2,Restorative justice
3,3,amber
4,4,shades of yellow


## Clustering

### `Semantic:` Connected components

In [120]:
lnk_meta_mat = retain_topk(sp.load_npz(f"{linker_dir}/test_predictions_labels.npz"), k=10)
ptr_meta_mat = retain_topk(sp.load_npz(f"{linker_dir}/test_predictions_labels_msmarco-distilbert-cos-v5.npz"), k=10)

In [168]:
lnk_meta_mat.shape, ptr_meta_mat.shape

((1657190, 1657190), (1657190, 1657190))

In [122]:
linker_output = display_similar_items(lnk_meta_mat, meta_info, n_data=10, n_view_data=20, seed=1000)
pretrn_output = display_similar_items(ptr_meta_mat, meta_info, n_data=10, n_view_data=20, seed=1000)

In [170]:
df = compare_items(linker_output, pretrn_output)

In [176]:
with pd.option_context('display.max_rows', None):
    display(df)

Unnamed: 0_level_0,Unnamed: 1_level_0,Output 1,Output 2
Query,Index,Unnamed: 2_level_1,Unnamed: 3_level_1
"fascial aches, pains, tension and restrictions",1,investing fascia,"fascial aches, pains, tension and restrictions"
"fascial aches, pains, tension and restrictions",2,spasm-type pain,Muscle pain and aches
"fascial aches, pains, tension and restrictions",3,"fascial aches, pains, tension and restrictions",muscle aches and pain
"fascial aches, pains, tension and restrictions",4,restrictions in fascia,pain and muscle tension
"fascial aches, pains, tension and restrictions",5,osteoarthritis-related joint pain,Muscle Aches and Pains
"fascial aches, pains, tension and restrictions",6,Relieving pain of normal intact skin,muscle aches and pains
"fascial aches, pains, tension and restrictions",7,"muscles, ligaments, tendons",Muscle aches and pains
"fascial aches, pains, tension and restrictions",8,contours the body,muscle aches and soreness
"fascial aches, pains, tension and restrictions",9,deeper and thicker fascia,"Aches, pains, and tense muscles"
"fascial aches, pains, tension and restrictions",10,penetrates deep into aching muscles and joints,Muscle aches or stiffness


In [182]:
def from_similarity(lbl_lbl:sp.csr_matrix, score_thresh:Optional[float]=0.3, diff_thresh:Optional[float]=0.2, 
                    sim_topk:Optional[int]=20):
    lbl_lbl = retain_topk(lbl_lbl, k=sim_topk)
    lbl_lbl = Operations.min_clamp(lbl_lbl, score_thresh)
    lbl_lbl = Operations.diff_threshold(lbl_lbl, diff_thresh)
    
    n, groups = sp.csgraph.connected_components(lbl_lbl)
    return Operations.get_clusters(groups)
    

In [221]:
lnk_clusters = from_similarity(lnk_meta_mat, score_thresh=0.5, diff_thresh=0.2, sim_topk=3)

In [228]:
ptr_clusters = from_similarity(ptr_meta_mat, score_thresh=0.5, diff_thresh=0.2, sim_topk=10)

In [192]:
lnk_cluster_text = get_conflated_text(lnk_clusters, meta_info)

In [193]:
ptr_cluster_text = get_conflated_text(ptr_clusters, meta_info)

In [229]:
lnk_cluster_text[100]

['FB may be a Hallmark',
 'FB stands for Franz Bibus',
 'BEB',
 'people can comment just like Facebook',
 '5$ a day in advertising on Facebook']

In [230]:
ptr_cluster_text[10000]

['causing the scabbing',
 'scab',
 'scabbing',
 'Also called scab',
 'cause scabbing',
 'a scab']

### `Predictions`

In [124]:
model_dir = "/Users/suchith720/Downloads/00_msmarco-gpt-concept-substring-linker-with-ngame-loss-001/"
data_meta = sp.load_npz(f"{model_dir}/train_predictions.npz")

In [73]:
def from_predictions(data_lbl:sp.csr_matrix, score_thresh:Optional[float]=0.3, diff_thresh:Optional[float]=0.2, 
                     pred_topk:Optional[int]=3):
    data_lbl = retain_topk(data_lbl, k=pred_topk)
    data_lbl = Operations.min_clamp(data_lbl, score_thresh)
    data_lbl = Operations.diff_threshold(data_lbl, diff_thresh)
    
    lbl_mat = Operations.compute_yty(data_lbl.transpose())
    n, groups = sp.csgraph.connected_components(lbl_mat)
    return Operations.get_clusters(groups)
    

### `Derived phrases`

In [207]:
def from_derived_phrases(lbl_phrases:sp.csr_matrix):
    lbl_mat = Operations.compute_yty(lbl_phrases)
    n, groups = sp.csgraph.connected_components(lbl_mat)
    return Operations.get_clusters(groups)
    

In [231]:
clusters = from_derived_phrases(meta_phrases)

  0%|          | 0/1658 [00:00<?, ?it/s]

In [209]:
texts = get_conflated_text(clusters, meta_info)

In [214]:
texts[2]

['high cholesterol',
 'cholesterol',
 'High Levels Of Fat In The Blood',
 'reduce LDL cholesterol in blood',
 'High blood cholesterol levels',
 'High cholesterol',
 'help you elevate your HDL',
 'raises HDL',
 'lower your cholesterol',
 'your LDL, or bad cholesterol',
 'improve your cholesterol numbers',
 'help lower cholesterol',
 'lower high cholesterol',
 'dyslipidemia',
 'lowering cholesterol levels',
 'increases the good HDL cholesterol',
 'to lower cholesterol',
 'LDL (bad) cholesterol',
 'low-density lipoprotein cholesterol (LDLc)',
 'lower blood cholesterol',
 'treat high cholesterol levels',
 'raise the level of cholesterol in the blood',
 'lower cholesterol',
 'reduce your cholesterol',
 'increased cholesterol levels',
 'treat high cholesterol',
 'increase bad cholesterol',
 'LDL cholesterol',
 'higher levels of bad cholesterol',
 'help maintain healthy cholesterol levels',
 'lower the bad cholesterol levels',
 'total cholesterol levels',
 'increased serum cholesterol',
 'kee

### Class cluster

In [243]:
#| export
class Cluster:

    @staticmethod
    def from_similarity(lbl_lbl:sp.csr_matrix, score_thresh:Optional[float]=0.3, diff_thresh:Optional[float]=0.2, 
                        sim_topk:Optional[int]=20):
        lbl_lbl = retain_topk(lbl_lbl, k=sim_topk)
        lbl_lbl = Operations.min_clamp(lbl_lbl, score_thresh)
        lbl_lbl = Operations.diff_threshold(lbl_lbl, diff_thresh)
        
        n, groups = sp.csgraph.connected_components(lbl_lbl)
        return Operations.get_clusters(groups)

    @staticmethod
    def from_predictions(data_lbl:sp.csr_matrix, score_thresh:Optional[float]=0.3, diff_thresh:Optional[float]=0.2, 
                         pred_topk:Optional[int]=3):
        data_lbl = retain_topk(data_lbl, k=pred_topk)
        data_lbl = Operations.min_clamp(data_lbl, score_thresh)
        data_lbl = Operations.diff_threshold(data_lbl, diff_thresh)
        
        lbl_mat = Operations.compute_yty(data_lbl.transpose())
        n, groups = sp.csgraph.connected_components(lbl_mat)
        return Operations.get_clusters(groups)

    @staticmethod
    def from_derived_phrases(lbl_phrases:sp.csr_matrix):
        lbl_mat = Operations.compute_yty(lbl_phrases)
        n, groups = sp.csgraph.connected_components(lbl_mat)
        return Operations.get_clusters(groups)
        

## `Conflation`

In [312]:
#| export
class Conflation:

    @staticmethod
    def get_groups(clusters:List, n_lbl:int):
        groups = np.full(n_lbl, -1)
        for i,idx in enumerate(clusters):
            groups[idx] = i
        idxs = np.where(groups == -1)[0]
        groups[idxs] = np.arange(len(clusters), len(clusters) + idxs.shape[0])
        return groups

    @staticmethod
    def get_conflated_matrix(matrix:sp.csr_matrix, groups:np.array, n_lbl:Optional[int]=None):
        data, indptr = matrix.data, matrix.indptr
        indices = [groups[i] for i in matrix.indices]
        shape = None if n_lbl is None else (matrix.shape[0], n_lbl)
        return sp.csr_matrix((data, indices, indptr), shape=shape)

    @staticmethod
    def get_conflated_info(clusters:List, info:Dict):
        conflated_info, flag, lbl_idx = dict(), np.zeros(len(info["identifier"])), []
        
        for idxs in clusters:
            idx = np.random.choice(idxs)
            conflated_info.setdefault("identifier", []).append(info["identifier"][idx])
            conflated_info.setdefault("text", []).append(info["text"][idx])
            lbl_idx.append(idx)
            flag[idxs] = 1
            
        idxs = np.where(flag == 0)[0]
        for idx in idxs:
            conflated_info.setdefault("identifier", []).append(info["identifier"][idx])
            conflated_info.setdefault("text", []).append(info["text"][idx])
            lbl_idx.append(idx)
            flag[idx] = 1
            
        assert flag.all(), "All items should be covered."
        return conflated_info, lbl_idx

    @staticmethod
    def perform_conflation(clusters:List, trn_meta:sp.csr_matrix, meta_info:Dict, tst_meta:Optional[sp.csr_matrix]=None, 
                           lbl_meta:Optional[sp.csr_matrix]=None):
        
        groups = Conflation.get_groups(clusters, n_lbl=trn_meta.shape[1])

        conflated_trn_meta = Conflation.get_conflated_matrix(trn_meta, groups)
        conflated_meta_info, meta_idx = Conflation.get_conflated_info(clusters, meta_info)
        
        conflated_tst_meta = None if tst_meta is None else Conflation.get_conflated_matrix(tst_meta, groups, n_lbl=conflated_trn_meta.shape[1])
        conflated_lbl_meta = None if lbl_meta is None else Conflation.get_conflated_matrix(lbl_meta, groups, n_lbl=conflated_trn_meta.shape[1])

        return conflated_trn_meta, conflated_meta_info, meta_idx, conflated_tst_meta, conflated_lbl_meta
    

### Derived phrase `+` semantic conflation

In [314]:
clusters = Cluster.from_derived_phrases(meta_phrases)

  0%|          | 0/1658 [00:00<?, ?it/s]

In [315]:
clusters = Filter.remove_top_clusters(clusters, topk=1)

In [319]:
ctrn_meta, cmeta_info, meta_idx, ctst_meta, _ = Conflation.perform_conflation(clusters, trn_meta, meta_info, tst_meta=tst_meta)

In [321]:
meta_file = f"{linker_dir}/test_predictions_labels_msmarco-distilbert-cos-v5.npz"
meta_mat = retain_topk(sp.load_npz(meta_file), k=10)

In [322]:
meta_mat = meta_mat[meta_idx][:, meta_idx]

In [323]:
clusters = Cluster.from_similarity(meta_mat, score_thresh=0.5, diff_thresh=0.2, sim_topk=10)

In [324]:
clusters = Filter.remove_top_clusters(clusters, topk=1)

In [325]:
ctrn_meta, cmeta_info, meta_idx, ctst_meta, _ = Conflation.perform_conflation(clusters, ctrn_meta, cmeta_info, tst_meta=ctst_meta)

In [329]:
ctrn_meta.getnnz(axis=0).mean()

np.float64(1.8354457163056375)

### Semantic conflation

In [330]:
meta_file = f"{linker_dir}/test_predictions_labels_msmarco-distilbert-cos-v5.npz"
meta_mat = retain_topk(sp.load_npz(meta_file), k=10)

In [337]:
clusters = Cluster.from_similarity(meta_mat, score_thresh=0.5, diff_thresh=0.2, sim_topk=10)

In [338]:
clusters = Filter.remove_top_clusters(clusters, topk=1)

In [339]:
ctrn_meta, cmeta_info, meta_idx, ctst_meta, _ = Conflation.perform_conflation(clusters, trn_meta, meta_info, tst_meta=tst_meta)

In [342]:
ctrn_meta.getnnz(axis=0).mean()

np.float64(1.6394842156736276)

In [343]:
ctrn_meta

<Compressed Sparse Row sparse matrix of dtype 'float32'
	with 2186003 stored elements and shape (502939, 1333348)>

## `Save data`

In [313]:
#| export
class SaveData:

    @staticmethod
    def save_raw(fname:str, ids:List, txt:List):
        df = pd.DataFrame({"identifier": ids, "text": txt})
        df.to_csv(fname, index=False)

    @staticmethod
    def proc(save_dir:str, trn_file:str, trn_meta:sp.csr_matrix, info_file:str, meta_info:Dict, 
             tst_file:Optional[str]=None, tst_meta:Optional[sp.csr_matrix]=None, lbl_file:Optional[str]=None, 
             lbl_meta:Optional[sp.csr_matrix]=None):
        
        n = len([o for o in os.listdir(save_dir) if re.match(r'conflation_[0-9]{2}', o)]) + 1
        save_dir = f"{save_dir}/conflation_{n:02d}"
        
        raw_dir = f"{save_dir}/raw_data"
        os.makedirs(raw_dir, exist_ok=True)

        sp.save_npz(f"{save_dir}/{os.path.basename(trn_file)}", trn_meta)
        
        if tst_meta is not None:
            sp.save_npz(f"{save_dir}/{os.path.basename(tst_file)}", tst_meta)
            
        if lbl_meta is not None:
            sp.save_npz(f"{save_dir}/{os.path.basename(lbl_file)}", lbl_meta)

        raw_file = f"{raw_dir}/{os.path.basename(info_file)}"
        SaveData.save_raw(raw_file, meta_info["identifier"], meta_info["text"])
        

## Driver

In [95]:
data_dir = "/Users/suchith720/Projects/data/beir/msmarco/XC/substring/"

trn_file = f"{data_dir}/substring_trn_X_Y.npz"
tst_file = f"{data_dir}/substring_tst_X_Y.npz"
info_file = f"{data_dir}/raw_data/substring.raw.csv"

trn_meta, tst_meta = sp.load_npz(trn_file), sp.load_npz(tst_file)
meta_info = pd.read_csv(info_file)

meta_phrases = sp.load_npz(f"{data_dir}/derived-phrases_substring_X_Y.npz")

In [109]:
linker_dir = "/Users/suchith720/Downloads/00_msmarco-gpt-concept-substring-linker-with-ngame-loss-001/"

In [246]:
clusters = Cluster.from_derived_phrases(meta_phrases)

  0%|          | 0/1658 [00:00<?, ?it/s]

In [None]:
clusters = Filter.remove_top_clusters(clusters, topk=1)

In [266]:
ctrn_meta, cmeta_info, meta_idx, ctst_meta = Conflation.perform_conflation(clusters, trn_meta, meta_info, tst_meta)

In [273]:
meta_file = f"{linker_dir}/test_predictions_labels_msmarco-distilbert-cos-v5.npz"
meta_mat = retain_topk(sp.load_npz(meta_file), k=10)

In [277]:
meta_mat = meta_mat[meta_idx][:, meta_idx]

In [295]:
clusters = Cluster.from_similarity(meta_mat, score_thresh=0.7, diff_thresh=0.2, sim_topk=20)

In [305]:
clusters = Filter.remove_top_clusters(clusters, topk=1)

In [307]:
ctrn_meta, cmeta_info, meta_idx, ctst_meta = Conflation.perform_conflation(clusters, ctrn_meta, cmeta_info, ctst_meta)

In [310]:
ctrn_meta.getnnz(axis=0).mean()

np.float64(1.9417916334300378)