In [1]:
#| default_exp 42_entity-conflation

In [2]:
%reload_ext autoreload
%autoreload 2

In [454]:
#| export
import scipy.sparse as sp, numpy as np, argparse, os, torch
from tqdm.auto import tqdm
from termcolor import colored, COLORS
from scipy.sparse.csgraph import connected_components
from typing import List, Optional, Dict, Set

from sugar.core import load_raw_file, save_raw_file
from xclib.utils.sparse import retain_topk

In [506]:
import matplotlib.pyplot as plt
import torch.nn.functional as F

## Helper functions

In [209]:
#| export
def show_conflated_labels(idxs:List, components:Dict, lbl_ids2txt:Dict, fname:Optional[str]=None,
                          encoding:Optional[str]='utf-8'):
    """Print, or write to `fname`, the member texts of the selected clusters.

    Each output line has the form ``NNN. txt_1 || txt_2 || ...`` where ``NNN`` is the 1-based
    position of the cluster within `idxs`.

    Args:
        idxs: keys of `components` to show, in display order.
        components: mapping cluster key -> list of label ids.
        lbl_ids2txt: mapping label id -> label text.
        fname: optional output path; when None the lines are printed to stdout.
        encoding: text encoding used when writing to `fname`.

    Raises:
        KeyError: if an index or label id is missing from the mappings.
    """
    def _lines():
        # Lazily build the lines so output is produced incrementally, exactly like the print path.
        for i, idx in enumerate(idxs):
            txt = " || ".join([lbl_ids2txt[o] for o in components[idx]])
            yield f'{i+1:03d}. {txt}'

    if fname is None:
        for line in _lines(): print(line)
        return
    # `with` guarantees the file is closed even if a lookup raises part-way through.
    with open(fname, 'w', encoding=encoding) as file:
        for line in _lines(): file.write(f'{line}\n')
        

## Setup

In [538]:
# Dataset root; set the DATA_DIR environment variable instead of editing this absolute local path.
data_dir = os.environ.get("DATA_DIR", "/Users/suchith720/Projects/data/(mapped)LF-WikiSeeAlsoTitles-320K/")

# NOTE(review): pred_file and tst_file point at the same file here — confirm this is intended.
pred_file = os.path.join(data_dir, "category_renee_tst_X_Y.npz")

lbl_file = os.path.join(data_dir, "raw_data", "category.raw.txt")
trn_file = os.path.join(data_dir, "category_renee_trn_X_Y.npz")
tst_file = os.path.join(data_dir, "category_renee_tst_X_Y.npz")

# Optional path to torch-saved label embeddings read by `load_data`; None means nothing is loaded from disk.
embed_file = None

In [559]:
#| export
def load_data(pred_file:str, trn_file:str, tst_file:str, lbl_file:str, embed_file:Optional[str]=None,
              encoding:Optional[str]='utf-8'):
    """Load every input needed by the conflation pipeline.

    Returns:
        ``(pred_lbl, trn_lbl, tst_lbl, (lbl_ids, lbl_txt), lbl_repr)``. The three matrices are read with
        `sp.load_npz`, ``(lbl_ids, lbl_txt)`` comes from ``load_raw_file(lbl_file, encoding=encoding)``,
        and `lbl_repr` is ``torch.load(embed_file)``, or None when `embed_file` is None.
    """
    pred_lbl = sp.load_npz(pred_file)
    trn_lbl = sp.load_npz(trn_file)
    tst_lbl = sp.load_npz(tst_file)

    lbl_ids, lbl_txt = load_raw_file(lbl_file, encoding=encoding)

    if embed_file is None:
        lbl_repr = None
    else:
        # NOTE(review): torch.load unpickles its input — only load embedding files from trusted sources.
        lbl_repr = torch.load(embed_file)

    return pred_lbl, trn_lbl, tst_lbl, (lbl_ids, lbl_txt), lbl_repr
    

In [540]:
# Load predictions, train/test label matrices, label ids/texts and (optionally) label embeddings.
# The label file is decoded as latin-1 here, overriding load_data's utf-8 default.
pred_lbl, trn_lbl, tst_lbl, (lbl_ids, lbl_txt), lbl_repr = load_data(pred_file, trn_file, tst_file, lbl_file, 
                                                                     embed_file, encoding='latin-1')

In [509]:
# Placeholder label embeddings for exercising the embedding-similarity filter.
# Only generated when nothing was loaded from `embed_file`, so real embeddings are never
# overwritten, and seeded so the resulting clusters are reproducible across runs.
if lbl_repr is None:
    _embed_rng = torch.Generator().manual_seed(0)
    lbl_repr = F.normalize(torch.randn(len(lbl_ids), 768, generator=_embed_rng), dim=1)

In [526]:
batch_size = 1000        # batch size for the co-occurrence matmul and embedding-similarity loops
topk = 3                 # k passed to Filter.topk on the predictions (presumably top-k per row — see xclib)
min_thresh = 2           # smallest cluster size (inclusive) kept as a conflated label
max_thresh = 100         # largest cluster size (inclusive) kept as a conflated label

percentile_thresh = 50   # co-occurrence edges below this percentile of edge weights are dropped
score_thresh = 0.0       # embedding-similarity edges scoring below this value are dropped

In [560]:
#| export
class Filter:
    """Namespace of static filtering helpers used by the conflation pipeline."""

    @staticmethod
    def by_length(components:Dict, min_thresh:Optional[int]=1, max_thresh:Optional[int]=100):
        """Return the keys of `components` whose member count lies in ``[min_thresh, max_thresh]``.

        Args:
            components: mapping cluster key -> list of members.
            min_thresh: smallest accepted cluster size (inclusive).
            max_thresh: largest accepted cluster size (inclusive).

        Returns:
            Set of accepted cluster keys.
        """
        # Return the keys themselves (not positions in sorted order), so callers testing
        # `key in result` stay correct even when keys are not exactly 0..n-1.
        return {idx for idx, members in components.items() if min_thresh <= len(members) <= max_thresh}

    @staticmethod
    def topk(data_lbl:sp.csr_matrix, k:Optional[int]=3):
        """Delegate to xclib's `retain_topk` (assumed to keep the `k` largest entries per row — confirm in xclib)."""
        return retain_topk(data_lbl, k=k)

    @staticmethod
    def threshold(data_lbl:sp.csr_matrix, t:float):
        """Drop stored entries of `data_lbl` whose value is strictly below `t`.

        The matrix is modified in place and also returned. Explicit zeros already stored in the
        matrix are removed as well, because `eliminate_zeros` is called.
        """
        data_lbl.data[data_lbl.data < t] = 0
        data_lbl.eliminate_zeros()
        return data_lbl
        

In [561]:
#| export
def get_one_hop(data_lbl:sp.csr_matrix, batch_size:Optional[int]=1024):
    """Compute ``B.T @ B`` where ``B`` is `data_lbl` with every stored value set to 1.

    Entry ``(i, j)`` of the result counts the rows of `data_lbl` in which columns ``i`` and ``j``
    both have a stored entry. The caller's matrix is copied, not modified. The product is built in
    blocks of `batch_size` rows of ``B.T`` (with a tqdm progress bar) and stacked with `sp.vstack`.
    """
    binary = data_lbl.copy()
    binary.data[:] = 1.0

    transposed = binary.T.tocsr()
    blocks = []
    for start in tqdm(range(0, transposed.shape[0], batch_size)):
        stop = start + batch_size
        blocks.append(transposed[start:stop] @ binary)
    return sp.vstack(blocks)
    

In [573]:
#| export
def normalize_matrix(data_lbl:sp.csr_matrix, lbl_lbl:sp.csr_matrix):
    """Scale each entry ``(i, j)`` of `lbl_lbl` by ``1 / (n_i * n_j)`` and return it as CSR.

    ``n_k`` is the number of stored entries in column ``k`` of `data_lbl`. Columns with no entries
    get a scale factor of 0 instead of causing a division by zero.
    """
    counts = data_lbl.getnnz(axis=0).astype(np.float32)
    inverse = np.zeros_like(counts)
    np.divide(1, counts, out=inverse, where=counts > 0)

    scaled_cols = lbl_lbl.multiply(inverse[None])
    scaled = scaled_cols.multiply(inverse[:, None])
    return scaled.tocsr()
    

In [574]:
#| export
def compute_embed_similarity(lbl_lbl:sp.csr_matrix, embed:torch.Tensor, batch_size:Optional[int]=1024):
    """Replace every stored entry ``(i, j)`` of `lbl_lbl` by the dot product ``embed[i] . embed[j]``.

    Only the sparsity pattern of `lbl_lbl` is used; its values are overwritten in a copy, so the
    caller's matrix is left untouched. Pairs are scored in chunks of `batch_size` non-zeros. If the
    rows of `embed` are L2-normalised, the scores are cosine similarities.

    Args:
        lbl_lbl: label-label sparse matrix whose non-zeros select the pairs to score.
        embed: 2-D tensor with one row per label; row ``k`` embeds label ``k``.
        batch_size: number of pairs scored per chunk.

    Returns:
        CSR matrix with the sparsity pattern of `lbl_lbl` holding the similarity scores.
    """
    lbl_lbl = lbl_lbl.tocoo(copy=True)
    # torch.hstack/cat would raise on an empty list of chunks.
    if lbl_lbl.nnz == 0: return lbl_lbl.tocsr()

    # Use the `embed` argument (not a notebook global) and int64 index tensors, which every
    # torch version accepts for advanced indexing.
    rows = torch.from_numpy(lbl_lbl.row.astype(np.int64))
    cols = torch.from_numpy(lbl_lbl.col.astype(np.int64))

    scores = []
    for i in tqdm(range(0, lbl_lbl.nnz, batch_size)):
        r, c = rows[i:i+batch_size], cols[i:i+batch_size]
        scores.append((embed[r] * embed[c]).sum(dim=-1))
    # detach/cpu so .numpy() also works for GPU tensors or tensors that require grad.
    lbl_lbl.data[:] = torch.cat(scores).detach().cpu().numpy()
    return lbl_lbl.tocsr()
    

In [575]:
#| export
def get_components(data_lbl:sp.csr_matrix, lbl_ids:List, lbl_repr:Optional[torch.Tensor]=None, 
                   score_thresh:Optional[float]=0.0, q:Optional[float]=50, batch_size:Optional[int]=1024):
    """Group labels into connected components of a pruned label-label graph.

    The graph starts from label co-occurrences in `data_lbl` (`get_one_hop`), normalised by label
    frequencies (`normalize_matrix`). Edges below the `q`-th percentile of the edge weights are
    dropped. If `lbl_repr` is given, the surviving edges are re-weighted by embedding similarity,
    and edges scoring below `score_thresh` are dropped too.

    Returns:
        dict mapping component index -> list of label ids. Ids are paired with the columns of
        `data_lbl` by position.
    """
    graph = normalize_matrix(data_lbl, get_one_hop(data_lbl, batch_size))
    cutoff = np.percentile(graph.data, q=q)
    graph = Filter.threshold(graph, t=cutoff)

    if lbl_repr is not None:
        similarity = compute_embed_similarity(graph, lbl_repr, batch_size=batch_size)
        graph = Filter.threshold(similarity, t=score_thresh)

    _, labels = connected_components(graph, directed=False, return_labels=True)
    grouped = {}
    for comp, lbl_id in zip(labels, lbl_ids):
        if comp not in grouped:
            grouped[comp] = []
        grouped[comp].append(lbl_id)
    return grouped
    

In [530]:
# Keep only the top-k predicted entries (via xclib's retain_topk) before building label co-occurrences.
data_lbl = Filter.topk(pred_lbl, k=topk)

In [532]:
# Cluster labels: co-occurrence graph -> percentile pruning -> optional embedding-similarity pruning.
components = get_components(data_lbl, lbl_ids, lbl_repr=lbl_repr, score_thresh=score_thresh, 
                            q=percentile_thresh, batch_size=batch_size)

  0%|          | 0/657 [00:00<?, ?it/s]

  0%|          | 0/448 [00:00<?, ?it/s]

In [533]:
# Clusters whose size lies in [min_thresh, max_thresh]; the others are split into singletons later.
valid_cluster_idxs = Filter.by_length(components, min_thresh=min_thresh, max_thresh=max_thresh)

In [534]:
# Lookup from label id to label text.
lbl_ids2txt = {k:v for k,v in zip(lbl_ids, lbl_txt)}

In [576]:
#| export
def get_valid_components(components:Dict, valid_cluster_idxs:Set):
    """Renumber clusters into contiguous indices, splitting invalid clusters into singletons.

    Clusters are visited in the iteration order of `components`. A cluster whose key is in
    `valid_cluster_idxs` keeps its member list under the next free index. Every member of any
    other cluster gets its own one-element cluster.

    Returns:
        ``(valid_components, lbl_ids2cluster)``: new index -> member list, and member -> new index.
    """
    renumbered, member_to_cluster = {}, {}
    next_idx = 0
    for key, members in components.items():
        groups = [members] if key in valid_cluster_idxs else [[m] for m in members]
        for group in groups:
            renumbered[next_idx] = group
            for m in group:
                member_to_cluster[m] = next_idx
            next_idx += 1
    return renumbered, member_to_cluster
    

In [566]:
#| export
def get_conflated_info(components:Dict, lbl_ids2txt:Dict):
    """Return one ``" || "``-joined text per cluster, ordered by sorted cluster key."""
    texts = []
    for key in sorted(components):
        member_txt = map(lbl_ids2txt.__getitem__, components[key])
        texts.append(" || ".join(member_txt))
    return texts
        

In [567]:
#| export
def get_id_to_cluster_idx_mapping(lbl_ids2cluster_map:Dict, lbl_ids:List):
    """Return a numpy array whose k-th entry is the cluster index of ``lbl_ids[k]``.

    Raises KeyError if a label id is missing from `lbl_ids2cluster_map`.
    """
    lookup = lbl_ids2cluster_map.__getitem__
    return np.array(list(map(lookup, lbl_ids)))
    

In [568]:
#| export
def get_conflated_matrix(data_lbl:sp.csr_matrix, lbl_ids2cluster:Dict, n_clusters:Optional[int]=None):
    """Map the label columns of `data_lbl` onto conflated-cluster columns.

    Each stored entry ``(r, j)`` becomes a 1 at ``(r, lbl_ids2cluster[j])``. Entries of a row that land
    on the same cluster are summed, so a value ``v`` means ``v`` original labels of that row fell into
    that cluster. `data_lbl` itself is not modified.

    Args:
        data_lbl: CSR matrix of shape (rows, original labels).
        lbl_ids2cluster: column index -> cluster index, given either as an integer array/sequence
            (as returned by `get_id_to_cluster_idx_mapping`) or as a dict keyed by column index.
        n_clusters: number of output columns. Defaults to ``max(cluster index) + 1`` over the whole
            mapping, so matrices built for different splits share the same width even when the
            highest clusters never occur in `data_lbl`.

    Returns:
        float32 CSR matrix of shape ``(data_lbl.shape[0], n_clusters)``.
    """
    if isinstance(lbl_ids2cluster, dict):
        indices = np.array([lbl_ids2cluster[idx] for idx in data_lbl.indices], dtype=np.int64)
        if n_clusters is None: n_clusters = int(max(lbl_ids2cluster.values(), default=-1)) + 1
    else:
        mapping = np.asarray(lbl_ids2cluster)
        indices = mapping[data_lbl.indices].astype(np.int64, copy=False)
        if n_clusters is None: n_clusters = int(mapping.max()) + 1 if mapping.size else 0

    data = np.ones(len(indices), dtype=np.float32)
    # Copy indptr: the constructor can alias it without copying, and sum_duplicates() rewrites
    # indptr in place, which would silently corrupt the caller's `data_lbl`.
    matrix = sp.csr_matrix((data, indices, data_lbl.indptr.copy()),
                           shape=(data_lbl.shape[0], n_clusters), dtype=np.float32)
    matrix.sum_duplicates()
    return matrix
    

In [535]:
# Renumber clusters: size-valid components keep their members; every other label becomes its own singleton cluster.
valid_components, lbl_ids2cluster_map = get_valid_components(components, valid_cluster_idxs)

In [536]:
# One " || "-joined text per conflated label, plus an array with each original label's cluster index.
conflated_lbl_txt = get_conflated_info(valid_components, lbl_ids2txt)
lbl_ids2cluster = get_id_to_cluster_idx_mapping(lbl_ids2cluster_map, lbl_ids)

In [537]:
# Remap the train/test label columns onto conflated-cluster columns.
conflated_trn_lbl = get_conflated_matrix(trn_lbl, lbl_ids2cluster)
conflated_tst_lbl = get_conflated_matrix(tst_lbl, lbl_ids2cluster)

In [569]:
#| export
def get_conflated_path(fname):
    """Return the path of the conflated counterpart of `fname`.

    ``_conflated`` is inserted before the first dot of the base name, so multi-part extensions stay
    intact (``dir/category.raw.txt`` -> ``dir/category_conflated.raw.txt``). A base name without a dot
    simply gets the suffix appended.
    """
    file_dir, base = os.path.split(fname)
    stem, dot, ext = base.partition('.')
    # os.path.join keeps bare file names relative; the old f'{file_dir}/...' sent 'x.npz' to '/x_conflated.npz'.
    return os.path.join(file_dir, f'{stem}_conflated{dot}{ext}')
    

In [570]:
#| export
def save_conflated_data(lbl_txt:List, lbl_file:str, trn_lbl:sp.csr_matrix, trn_file:str, 
                        tst_lbl:sp.csr_matrix, tst_file:str):
    """Write the conflated label texts and matrices to the `_conflated` counterparts of the input paths.

    Label texts are written with `save_raw_file` using ids ``0..len(lbl_txt)-1``. The train and test
    matrices are written with `sp.save_npz`.
    """
    out_lbl, out_trn, out_tst = (get_conflated_path(p) for p in (lbl_file, trn_file, tst_file))

    save_raw_file(out_lbl, range(len(lbl_txt)), lbl_txt)
    for path, matrix in ((out_trn, trn_lbl), (out_tst, tst_lbl)):
        sp.save_npz(path, matrix)
    

In [208]:
# Write the conflated label texts and train/test matrices next to the inputs under `_conflated` names.
# NOTE(review): this cell's execution count (208) predates the cells that build its inputs (535-537);
# re-run it after them so the saved files reflect the current state.
save_conflated_data(conflated_lbl_txt, lbl_file, conflated_trn_lbl, trn_file, conflated_tst_lbl, tst_file)

In [571]:
#| export
def main(pred_file:str, trn_file:str, tst_file:str, lbl_file:str, embed_file:Optional[str]=None, 
         topk:Optional[int]=3, batch_size:Optional[int]=1024, min_thresh:Optional[int]=2, 
         max_thresh:Optional[int]=100, score_thresh:Optional[float]=0.0, percentile_thresh:Optional[float]=50, 
         encoding:Optional[str]='latin-1'):
    """Run the full entity-conflation pipeline and save its outputs.

    Steps:
        1. Load the data.
        2. Keep the top-`topk` predictions.
        3. Cluster labels via co-occurrence, optionally pruned by embedding similarity.
        4. Keep clusters of size ``[min_thresh, max_thresh]`` and split the rest into singletons.
        5. Remap the train/test label matrices onto the clusters.
        6. Write the `_conflated` files next to the inputs.
    """
    pred_lbl, trn_lbl, tst_lbl, (lbl_ids, lbl_txt), lbl_repr = load_data(
        pred_file, trn_file, tst_file, lbl_file, embed_file, encoding=encoding)
    id2txt = dict(zip(lbl_ids, lbl_txt))

    top_preds = Filter.topk(pred_lbl, k=topk)
    clusters = get_components(top_preds, lbl_ids, lbl_repr=lbl_repr, score_thresh=score_thresh,
                              q=percentile_thresh, batch_size=batch_size)

    keep = Filter.by_length(clusters, min_thresh=min_thresh, max_thresh=max_thresh)
    new_clusters, id2cluster = get_valid_components(clusters, keep)

    conflated_txt = get_conflated_info(new_clusters, id2txt)
    col2cluster = get_id_to_cluster_idx_mapping(id2cluster, lbl_ids)

    save_conflated_data(conflated_txt, lbl_file,
                        get_conflated_matrix(trn_lbl, col2cluster), trn_file,
                        get_conflated_matrix(tst_lbl, col2cluster), tst_file)
    

## Driver

In [544]:
#| export
def parse_args():
    """Define the command-line interface of the conflation script and parse ``sys.argv``.

    Required flags: ``--pred_file``, ``--lbl_file``, ``--trn_file``, ``--tst_file``.
    Optional flags (defaults): ``--embed_file`` (None), ``--topk`` (3), ``--batch_size`` (1024),
    ``--min_thresh`` (2), ``--max_thresh`` (100), ``--score_thresh`` (0.0),
    ``--percentile_thresh`` (50), ``--encoding`` ('latin-1').
    """
    parser = argparse.ArgumentParser()

    for flag in ('--pred_file', '--lbl_file', '--trn_file', '--tst_file'):
        parser.add_argument(flag, type=str, required=True)
    parser.add_argument('--embed_file', type=str, default=None)

    optional_specs = (
        ('--topk', int, 3),
        ('--batch_size', int, 1024),
        ('--min_thresh', int, 2),
        ('--max_thresh', int, 100),
        ('--score_thresh', float, 0.0),
        ('--percentile_thresh', float, 50),
        ('--encoding', str, 'latin-1'),
    )
    for flag, kind, default in optional_specs:
        parser.add_argument(flag, type=kind, default=default)

    return parser.parse_args()
    

In [226]:
#| export
if __name__ == '__main__':
    # Script entry point (exported by nbdev): parse the CLI flags and run the pipeline.
    # NOTE(review): a Jupyter kernel also runs cells with __name__ == '__main__', so executing this
    # cell inside the notebook will try to parse the kernel's own argv — confirm it is skipped there.
    args = parse_args()

    main(pred_file=args.pred_file, trn_file=args.trn_file, tst_file=args.tst_file, lbl_file=args.lbl_file,
         embed_file=args.embed_file, topk=args.topk, batch_size=args.batch_size,
         min_thresh=args.min_thresh, max_thresh=args.max_thresh, score_thresh=args.score_thresh,
         percentile_thresh=args.percentile_thresh, encoding=args.encoding)
