In [None]:
#| default_exp statistics

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import scipy.sparse as sp, re, xclib.data.data_utils as du, numpy as np, pandas as pd, os
from IPython.display import display
from torch.utils.data import Dataset
from termcolor import colored, COLORS

from sugar.core import *
from xcai.data import *

## `UpdatedDataset`

In [None]:
#| export
# Metadata type name (as it appears in file names) -> short code passed to `MetaXCDataset` as its prefix.
METADATA_CODE = dict(
    category='cat',
    see_also='sal',
    hyper_link='hlk',
    videos='vid',
    images='img',
    entity='ent',
    canonical='can',
    entity_canonical_category='ecc',
    entity_canonical='enc',
)

In [None]:
#| export
class UpdatedDataset:
    """Static loaders for the prefix-tagged dataset layout.

    Matrix and raw-text file names carry `x_prefix`/`y_prefix`/`z_prefix` tags, e.g.
    `trn_X_Y_{x_prefix}-{y_prefix}.npz` or `raw_data/label.{x_prefix}-{y_prefix}.raw.csv`.
    """

    @staticmethod
    def _as_info(raw_path):
        "Read `raw_path` with `load_raw_file` and wrap the two returned values as an info dict."
        identifiers, texts = load_raw_file(raw_path)
        return {'identifier':identifiers, 'input_text':texts}

    @staticmethod
    def _npz_or_txt(npz_path, txt_path):
        "Load `npz_path` when it exists on disk; otherwise parse `txt_path` with `du.read_sparse_file`."
        if os.path.exists(npz_path):
            return sp.load_npz(npz_path)
        return du.read_sparse_file(txt_path)

    @staticmethod
    def load_data_info(data_dir, type, suffix=''):
        "Info dict for `raw_data/{type}.raw.csv`, or `raw_data/{type}.{suffix}.raw.csv` when `suffix` is non-empty."
        suffix = f'.{suffix}' if len(suffix) else suffix
        return UpdatedDataset._as_info(f'{data_dir}/raw_data/{type}{suffix}.raw.csv')

    @staticmethod
    def load_lbl_info(data_dir, x_prefix, y_prefix):
        "Info dict for the label raw file tagged with `x_prefix` and `y_prefix`."
        return UpdatedDataset._as_info(f'{data_dir}/raw_data/label.{x_prefix}-{y_prefix}.raw.csv')

    @staticmethod
    def load_metadata_info(data_dir, metadata_type, x_prefix, y_prefix, z_prefix):
        "Info dict for the `metadata_type` raw file tagged with all three prefixes."
        return UpdatedDataset._as_info(f'{data_dir}/raw_data/{metadata_type}.{x_prefix}-{y_prefix}-{z_prefix}.raw.csv')

    @staticmethod
    def get_trn_tst_info(data_dir, suffix=''):
        "Return `(train_info, test_info)`, both read with the same `suffix`."
        train_info = UpdatedDataset.load_data_info(data_dir, 'train', suffix)
        test_info = UpdatedDataset.load_data_info(data_dir, 'test', suffix)
        return train_info, test_info

    @staticmethod
    def load_main_matrix(data_dir, x_prefix, y_prefix, type):
        "Load the `{type}_X_Y` matrix, preferring the `.npz` file over the `.txt` one."
        npz_path = f'{data_dir}/{type}_X_Y_{x_prefix}-{y_prefix}.npz'
        txt_path = f'{data_dir}/{type}_X_Y_{x_prefix}-{y_prefix}.txt'
        return UpdatedDataset._npz_or_txt(npz_path, txt_path)

    @staticmethod
    def get_labels(data_dir, x_prefix, y_prefix):
        "Return `(trn_mat, tst_mat, lbl_info)` for the given prefixes."
        trn_mat, tst_mat = [UpdatedDataset.load_main_matrix(data_dir, x_prefix, y_prefix, split)
                            for split in ('trn', 'tst')]
        lbl_info = UpdatedDataset.load_lbl_info(data_dir, x_prefix=x_prefix, y_prefix=y_prefix)
        return trn_mat, tst_mat, lbl_info

    @staticmethod
    def load_metadata_matrix(data_dir, x_prefix, y_prefix, z_prefix, main_type, metadata_type):
        "Load the `{metadata_type}_{main_type}_X_Y` matrix, preferring the `.npz` file over the `.txt` one."
        npz_path = f'{data_dir}/{metadata_type}_{main_type}_X_Y_{x_prefix}-{y_prefix}-{z_prefix}.npz'
        txt_path = f'{data_dir}/{metadata_type}_{main_type}_X_Y_{x_prefix}-{y_prefix}-{z_prefix}.txt'
        return UpdatedDataset._npz_or_txt(npz_path, txt_path)

    @staticmethod
    def get_metadata(data_dir, metadata_type, x_prefix, y_prefix, z_prefix):
        "Return `(trn_mat, tst_mat, lbl_mat, meta_info)` for one metadata type."
        trn_mat, tst_mat, lbl_mat = [
            UpdatedDataset.load_metadata_matrix(data_dir, x_prefix, y_prefix, z_prefix, split, metadata_type)
            for split in ('trn', 'tst', 'lbl')
        ]
        meta_info = UpdatedDataset.load_metadata_info(data_dir, metadata_type, x_prefix, y_prefix, z_prefix)
        return trn_mat, tst_mat, lbl_mat, meta_info

    @staticmethod
    def load_datasets(data_dir, metadata_type, x_prefix, y_prefix, z_prefix):
        """Build the train and test `XCDataset`s with one metadata source attached as `cat_meta`.

        `x_prefix` doubles as the suffix of the train/test raw files. `metadata_type` must be a
        key of `METADATA_CODE`. Both splits share the same label info, label-metadata matrix and
        metadata info.
        """
        trn_info, tst_info = UpdatedDataset.get_trn_tst_info(data_dir, x_prefix)
        trn_mat, tst_mat, lbl_info = UpdatedDataset.get_labels(data_dir, x_prefix, y_prefix)

        trn_main = MainXCDataset(trn_info, trn_mat, lbl_info)
        tst_main = MainXCDataset(tst_info, tst_mat, lbl_info)

        loaded = UpdatedDataset.get_metadata(data_dir, metadata_type, x_prefix, y_prefix, z_prefix)
        trn_meta_mat, tst_meta_mat, lbl_meta_mat, meta_info = loaded

        trn_meta = MetaXCDataset(METADATA_CODE[metadata_type], trn_meta_mat, lbl_meta_mat, meta_info)
        tst_meta = MetaXCDataset(METADATA_CODE[metadata_type], tst_meta_mat, lbl_meta_mat, meta_info)

        return XCDataset(trn_main, cat_meta=trn_meta), XCDataset(tst_main, cat_meta=tst_meta)
        

## `Dataset`

In [None]:
#| export
class Dataset:
    """Static loaders for the plain (un-prefixed) dataset layout.

    Matrices are read from `{name}_X_Y.txt` when that file exists and from `{name}_X_Y.npz`
    otherwise; raw text is read from `raw_data/{name}.raw.csv` when it exists and from
    `raw_data/{name}.raw.txt` otherwise.

    NOTE(review): this class has the same name as `torch.utils.data.Dataset`, which the import
    cell brings in, so later references to `Dataset` in this notebook resolve to this class.
    """

    @staticmethod
    def _read_info(fname, encoding):
        "Load `fname + '.csv'` if present, else `fname + '.txt'`, and wrap the result as an info dict."
        if os.path.exists(fname+'.csv'):
            ids, txt = load_raw_file(fname+'.csv', encoding=encoding)
        else:
            ids, txt = load_raw_file(fname+'.txt', encoding=encoding)
        return {'identifier':ids, 'input_text':txt}

    @staticmethod
    def _txt_or_npz(txt_path, npz_path):
        "Parse `txt_path` with `du.read_sparse_file` when it exists; otherwise load `npz_path`."
        if os.path.exists(txt_path):
            return du.read_sparse_file(txt_path)
        return sp.load_npz(npz_path)

    @staticmethod
    def load_data_lbl_info(data_dir, type, encoding='utf-8'):
        "Info dict for `raw_data/{type}.raw.(csv|txt)`."
        return Dataset._read_info(f'{data_dir}/raw_data/{type}.raw', encoding)

    @staticmethod
    def load_metadata_info(data_dir, metadata_type, encoding='utf-8'):
        "Info dict for `raw_data/{metadata_type}.raw.(csv|txt)`."
        return Dataset._read_info(f'{data_dir}/raw_data/{metadata_type}.raw', encoding)

    @staticmethod
    def get_trn_tst_info(data_dir, encoding='utf-8'):
        "Return `(train_info, test_info)`."
        train_info = Dataset.load_data_lbl_info(data_dir, 'train', encoding)
        test_info = Dataset.load_data_lbl_info(data_dir, 'test', encoding)
        return train_info, test_info

    @staticmethod
    def get_labels(data_dir, encoding='utf-8'):
        "Return `(trn_mat, tst_mat, lbl_info)`."
        trn_mat = Dataset._txt_or_npz(f'{data_dir}/trn_X_Y.txt', f'{data_dir}/trn_X_Y.npz')
        tst_mat = Dataset._txt_or_npz(f'{data_dir}/tst_X_Y.txt', f'{data_dir}/tst_X_Y.npz')
        lbl_info = Dataset.load_data_lbl_info(data_dir, 'label', encoding)
        return trn_mat, tst_mat, lbl_info

    @staticmethod
    def get_metadata(data_dir, metadata_type, encoding='utf-8'):
        "Return `(trn_mat, tst_mat, lbl_mat, meta_info)` for one metadata type."
        trn_mat = Dataset._txt_or_npz(f'{data_dir}/{metadata_type}_trn_X_Y.txt', f'{data_dir}/{metadata_type}_trn_X_Y.npz')
        tst_mat = Dataset._txt_or_npz(f'{data_dir}/{metadata_type}_tst_X_Y.txt', f'{data_dir}/{metadata_type}_tst_X_Y.npz')
        lbl_mat = Dataset._txt_or_npz(f'{data_dir}/{metadata_type}_lbl_X_Y.txt', f'{data_dir}/{metadata_type}_lbl_X_Y.npz')
        meta_info = Dataset.load_metadata_info(data_dir, metadata_type, encoding)
        return trn_mat, tst_mat, lbl_mat, meta_info

    @staticmethod
    def load_datasets(data_dir, metadata_type, encoding='utf-8'):
        """Build the train and test `XCDataset`s with one metadata source attached as `cat_meta`.

        `metadata_type` must be a key of `METADATA_CODE`. Both splits share the same label info,
        label-metadata matrix and metadata info.
        """
        trn_info, tst_info = Dataset.get_trn_tst_info(data_dir, encoding)
        trn_mat, tst_mat, lbl_info = Dataset.get_labels(data_dir, encoding)

        trn_main = MainXCDataset(trn_info, trn_mat, lbl_info)
        tst_main = MainXCDataset(tst_info, tst_mat, lbl_info)

        loaded = Dataset.get_metadata(data_dir, metadata_type, encoding)
        trn_meta_mat, tst_meta_mat, lbl_meta_mat, meta_info = loaded

        trn_meta = MetaXCDataset(METADATA_CODE[metadata_type], trn_meta_mat, lbl_meta_mat, meta_info)
        tst_meta = MetaXCDataset(METADATA_CODE[metadata_type], tst_meta_mat, lbl_meta_mat, meta_info)

        return XCDataset(trn_main, cat_meta=trn_meta), XCDataset(tst_main, cat_meta=tst_meta)
        

## Compute statistics

In [None]:
#| export
# Short metadata code -> human-readable name used in the statistics table's 'Dataset' column.
CODE_METADATA = dict(
    cat='Category',
    sal='Seealso',
    hlk='Hyperlink',
    vid='Videos',
    img='Images',
    ent='Entity',
    can='Canonical',
    ecc='Entity Canonical Category',
    enc='Entity Canonical',
)

In [None]:
#| export
def matrix_stats(mat):
    """Summarise the sparsity pattern of a 2-D scipy sparse matrix.

    Rows are reported as "entries" and columns as "features". Counts come from
    `mat.getnnz`, i.e. they count stored elements along each axis.

    Returns a dict with the number of rows/columns, the mean and maximum number of stored
    elements per column and per row, and how many columns/rows have no stored element.
    For a matrix with zero rows or zero columns the affected mean and maximum are
    reported as 0 instead of raising on an empty reduction.
    """
    n_dat, n_lbl = mat.shape[0], mat.shape[1]

    num_dat_lbl = mat.getnnz(axis=0)  # stored elements per column
    num_lbl_dat = mat.getnnz(axis=1)  # stored elements per row

    # `.max()` raises and `.mean()` yields nan (with a warning) on an empty array, so guard both.
    def _mean(counts): return counts.mean() if counts.size else 0.0
    def _max(counts): return counts.max() if counts.size else 0

    return {
        '# Entries' : n_dat,
        '# Features': n_lbl,
        'Avg. Entries per feature' : _mean(num_dat_lbl),
        'Avg. Feature per entry'   : _mean(num_lbl_dat),
        'Max. Entries per feature' : _max(num_dat_lbl),
        'Max. Feature per entry'   : _max(num_lbl_dat),
        '# Features without entry' : np.sum(num_dat_lbl == 0),
        '# Entries without feature': np.sum(num_lbl_dat == 0),
    }
    

In [None]:
#| export
def main_dset_stats(dset):
    "Statistics of `dset.data_lbl`, as computed by `matrix_stats`."
    data_lbl = dset.data_lbl
    return matrix_stats(data_lbl)
    

In [None]:
#| export
def meta_dset_stats(dset):
    """Statistics for both matrices of a metadata dataset.

    Returns a two-element list: the stats of `dset.data_meta` tagged `'Query'`,
    then the stats of `dset.lbl_meta` tagged `'Label'` (tag stored under `'Dataset'`).
    """
    query_stats = {**matrix_stats(dset.data_meta), 'Dataset': 'Query'}
    label_stats = {**matrix_stats(dset.lbl_meta), 'Dataset': 'Label'}
    return [query_stats, label_stats]
    

In [None]:
#| export
def dset_stats(dset):
    """Collect statistics rows for a dataset and every metadata source attached to it.

    The first row describes `dset.data` and is tagged `'Main'`. Each value of `dset.meta`
    contributes its query and label rows, tagged `'<Query|Label> <name> Metadata'` where
    the name is looked up in `CODE_METADATA` by the metadata object's `prefix`.
    """
    main_row = main_dset_stats(dset.data)
    main_row['Dataset'] = 'Main'
    rows = [main_row]

    for meta in dset.meta.values():
        for s in meta_dset_stats(meta):
            s['Dataset'] = f'{s["Dataset"]} {CODE_METADATA[meta.prefix]} Metadata'
            rows.append(s)

    return rows
    

In [None]:
#| export
def trn_tst_stats(trn_dset, tst_dset):
    "Statistics rows of both splits, train first, each row tagged with `'Split'` = `'Train'` or `'Test'`."
    rows = []
    for split, dset in (('Train', trn_dset), ('Test', tst_dset)):
        split_rows = dset_stats(dset)
        for row in split_rows:
            row['Split'] = split
        rows = rows + split_rows
    return rows
    

In [None]:
#| export
def print_stats(stats):
    "Display the statistics rows as a table indexed by (`Split`, `Dataset`), with display precision 2."
    table = pd.DataFrame(stats)
    table = table.set_index(['Split', 'Dataset'])
    with pd.option_context('display.precision', 2): display(table)
    

In [None]:
#| export
def print_dset_stats(trn_dset, tst_dset):
    "Compute the train/test statistics with `trn_tst_stats` and display them with `print_stats`."
    print_stats(trn_tst_stats(trn_dset, tst_dset))
    

## Text

In [None]:
#| export
class TextDataset(Dataset):
    """View over a dataset that keeps only the fields whose key matches `pattern`.

    `dset[idx]` must return a dict; by default only keys ending in `_text` are kept.
    The head/tail helpers additionally read `dset.data.data_lbl`.

    NOTE(review): a class named `Dataset` is defined earlier in this notebook and shadows the
    `torch.utils.data.Dataset` import, so that local class is the base here — confirm intended.
    """

    def __init__(self, dset, pattern='.*_text$'):
        self.dset, self.pattern = dset, pattern
        colors = list(COLORS.keys())
        # Shuffle once so every field position gets a fixed, arbitrary colour for this instance.
        self.colors = [colors[i] for i in np.random.permutation(len(colors))]

    def __getitem__(self, idx):
        "Return the item at `idx` restricted to keys matching `self.pattern` (via `re.match`)."
        o = self.dset[idx]
        return {k:v for k,v in o.items() if re.match(self.pattern, k)}

    def show(self, idxs):
        "Print every kept field of each index in `idxs`, one colour per field position."
        for idx in idxs:
            for i,(k,v) in enumerate(self[idx].items()):
                # Wrap around so items with more fields than available colours do not raise IndexError.
                color = self.colors[i % len(self.colors)]
                key = colored(k, color, attrs=["reverse", "blink"])
                value = colored(f': {v}', color)
                print(key, value)
            print()

    def get_head_data(self, topk=10):
        "Indices of the `topk` rows of `data_lbl` with the most stored labels, largest first."
        order = np.argsort(self.dset.data.data_lbl.getnnz(axis=1))
        # Reverse first, then truncate: the previous `[:-topk:-1]` slice returned only topk-1 rows.
        return order[::-1][:topk]

    def get_tail_data(self, topk=10):
        "Indices of the `topk` rows of `data_lbl` with the fewest stored labels, skipping rows with none."
        num = self.dset.data.data_lbl.getnnz(axis=1)
        idx = np.argsort(num)
        valid = (num > 0)[idx]
        return idx[valid][:topk]

    def dump_txt(self, fname, idxs):
        "Write `key: value` lines for each index in `idxs`, with a blank line between items."
        # Explicit UTF-8 so non-ASCII text does not depend on the platform default encoding.
        with open(fname, 'w', encoding='utf-8') as file:
            for idx in idxs:
                for k,v in self[idx].items():
                    file.write(f'{k}: {v}\n')
                file.write('\n')

    def dump_csv(self, fname, idxs):
        "Write the kept fields of each index in `idxs` as one CSV row, without the index column."
        df = pd.DataFrame([self[idx] for idx in idxs])
        df.to_csv(fname, index=False)

    def dump(self, fname, idxs):
        "Dispatch to `dump_txt` or `dump_csv` by file extension; raise `ValueError` for anything else."
        if fname.endswith('.txt'):
            self.dump_txt(fname, idxs)
        elif fname.endswith('.csv'):
            self.dump_csv(fname, idxs)
        else:
            raise ValueError(f'Invalid file extension: {fname}')
        
    

## Save dataset

In [None]:
#| export
# Short metadata code -> metadata type name used in saved file names (inverse of METADATA_CODE).
# Lists every code in METADATA_CODE so `save_metadata` does not raise KeyError for
# metadata other than category / hyper_link / see_also.
PREFIX_METDATA = {
    'cat': 'category', 'hlk': 'hyper_link', 'sal': 'see_also',
    'vid': 'videos', 'img': 'images', 'ent': 'entity', 'can': 'canonical',
    'ecc': 'entity_canonical_category', 'enc': 'entity_canonical',
}

In [None]:
#| export
def save_labels(data_dir, trn_dset, tst_dset):
    """Write the label matrices and raw text of both splits under `data_dir`.

    Creates `data_dir` and `data_dir/raw_data` if needed, saves `trn_X_Y.npz` / `tst_X_Y.npz`,
    and writes `train.raw.txt`, `test.raw.txt` and `label.raw.txt`. The label text is taken
    from the train dataset.
    """
    def _write_info(fname, info):
        save_raw_file(fname, info['identifier'], info['input_text'])

    os.makedirs(data_dir, exist_ok=True)
    sp.save_npz(f'{data_dir}/trn_X_Y.npz', trn_dset.data.data_lbl)
    sp.save_npz(f'{data_dir}/tst_X_Y.npz', tst_dset.data.data_lbl)

    os.makedirs(f'{data_dir}/raw_data', exist_ok=True)
    _write_info(f'{data_dir}/raw_data/train.raw.txt', trn_dset.data.data_info)
    _write_info(f'{data_dir}/raw_data/test.raw.txt', tst_dset.data.data_info)
    _write_info(f'{data_dir}/raw_data/label.raw.txt', trn_dset.data.lbl_info)
    

In [None]:
#| export
def save_metadata(data_dir, trn_dset, tst_dset):
    """Write every metadata source attached to the datasets under `data_dir`.

    For each key of `trn_dset.meta`, the file-name stem is looked up in `PREFIX_METDATA`
    by the metadata object's `prefix` (KeyError if the prefix is not listed). Saves
    `{stem}_trn_X_Y.npz`, `{stem}_tst_X_Y.npz`, `{stem}_lbl_X_Y.npz` and
    `raw_data/{stem}.raw.txt`. The label matrix and raw text come from the train dataset;
    `tst_dset.meta` must contain the same keys.
    """
    # Create the output directories here so the function also works when called on its own.
    os.makedirs(f'{data_dir}/raw_data', exist_ok=True)

    for key, trn_meta in trn_dset.meta.items():
        tst_meta = tst_dset.meta[key]
        metadata_type = PREFIX_METDATA[trn_meta.prefix]

        sp.save_npz(f'{data_dir}/{metadata_type}_trn_X_Y.npz', trn_meta.data_meta)
        sp.save_npz(f'{data_dir}/{metadata_type}_tst_X_Y.npz', tst_meta.data_meta)
        sp.save_npz(f'{data_dir}/{metadata_type}_lbl_X_Y.npz', trn_meta.lbl_meta)

        save_raw_file(f'{data_dir}/raw_data/{metadata_type}.raw.txt',
                      trn_meta.meta_info['identifier'], trn_meta.meta_info['input_text'])
    

In [None]:
#| export
def save_dataset(data_dir, trn_dset, tst_dset):
    """Drop rows without labels from both splits, save them under `data_dir`, and return them.

    A row is kept when `data.data_lbl` has at least one stored element in it. The filtered
    datasets (obtained through `_getitems`) are what gets written by `save_labels` and
    `save_metadata`, and are returned as `(trn_dset, tst_dset)`.
    """
    trn_keep = np.flatnonzero(trn_dset.data.data_lbl.getnnz(axis=1) > 0)
    trn_dset = trn_dset._getitems(trn_keep)

    tst_keep = np.flatnonzero(tst_dset.data.data_lbl.getnnz(axis=1) > 0)
    tst_dset = tst_dset._getitems(tst_keep)

    for writer in (save_labels, save_metadata):
        writer(data_dir, trn_dset, tst_dset)

    return trn_dset, tst_dset
    

## Helper

In [None]:
#| export
def show_updated_dataset(data_dir, save_dir, metadata_type, x_prefix, y_prefix, z_prefix, idxs, use_trn=True):
    """Load a prefix-tagged dataset, display its statistics and print a few examples.

    The dataset is read from `save_dir`. `data_dir` is kept in the signature for backward
    compatibility but is not read: `UpdatedDataset.load_datasets` accepts a single directory,
    and the previous call passed both, which raised `TypeError`.
    NOTE(review): `save_dir` was chosen because the demo cell in this notebook loads from it — confirm.

    `idxs` are the item indices to print, taken from the train split when `use_trn` is true
    and from the test split otherwise. Returns `(trn_dset, tst_dset)`.
    """
    trn_dset, tst_dset = UpdatedDataset.load_datasets(save_dir, metadata_type, x_prefix, y_prefix, z_prefix)
    print_dset_stats(trn_dset, tst_dset)

    txt_dset = TextDataset(trn_dset if use_trn else tst_dset)
    txt_dset.show(idxs)

    return trn_dset, tst_dset

def show_dataset(data_dir, metadata_type, idxs, suffix='', use_trn=True, encoding='utf-8'):
    """Load a plain-layout dataset, display its statistics and print a few examples.

    `Dataset.load_datasets` has no `suffix` parameter (the previous call passed one and raised
    `TypeError`), so a non-empty `suffix` is rejected with `ValueError` rather than silently
    loading the un-suffixed files. `encoding` is forwarded to the loader.

    `idxs` are the item indices to print, taken from the train split when `use_trn` is true
    and from the test split otherwise. Returns `(trn_dset, tst_dset)`.
    """
    if len(suffix):
        raise ValueError(f'`suffix` is not supported by Dataset.load_datasets: {suffix!r}')

    trn_dset, tst_dset = Dataset.load_datasets(data_dir, metadata_type, encoding=encoding)
    print_dset_stats(trn_dset, tst_dset)

    txt_dset = TextDataset(trn_dset if use_trn else tst_dset)
    txt_dset.show(idxs)

    return trn_dset, tst_dset
    

## `__main__`

In [None]:
def parse_args():
    """Parse the command-line options for loading a prefix-tagged dataset.

    The metadata type can be given as `--metadata_type` or the older `--metadata_key`;
    the value is available on the result under both attribute names.
    """
    import argparse  # local import: this notebook's import cell does not import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--save_dir', type=str, required=True)
    # One option with two spellings, stored as `metadata_type`. It is required because the
    # value is used as a lookup key when building the datasets.
    parser.add_argument('--metadata_type', '--metadata_key', dest='metadata_type', type=str, required=True)
    parser.add_argument('--x_prefix', type=str, required=True)
    parser.add_argument('--y_prefix', type=str, required=True)
    parser.add_argument('--z_prefix', type=str, required=True)

    args = parser.parse_args()
    args.metadata_key = args.metadata_type  # keep the old attribute name readable
    return args
    

In [None]:
if __name__ == '__main__':
    args = parse_args()
    # The parsed value may be exposed as `metadata_type` or as `metadata_key` (the flag name); accept either.
    metadata_type = getattr(args, 'metadata_type', None) or getattr(args, 'metadata_key', None)
    # `UpdatedDataset.load_datasets` takes one directory followed by four strings; passing both
    # directories raised TypeError. NOTE(review): reads from `save_dir`, like the demo cell — confirm.
    trn_dset, tst_dset = UpdatedDataset.load_datasets(args.save_dir, metadata_type,
                                                      args.x_prefix, args.y_prefix, args.z_prefix)
    

In [None]:
# NOTE(review): absolute, user-specific paths — adjust before running on another machine.
# `data_dir` is not referenced by the cells that follow here; confirm whether it is still needed.
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/(mapped)LF-WikiSeeAlsoTitles-320K/'
# Directory the prefix-tagged dataset is loaded from in the next cell.
save_dir = '/home/scai/phd/aiz218323/scratch/datasets/wikipedia/20250123/LF-WikiSeeAlsoTitles-320K/'

# Metadata source to attach (a key of METADATA_CODE) and the three file-name prefix tags.
metadata_type = 'category'
x_prefix, y_prefix, z_prefix = 'new', 'new', 'new'

In [None]:
# Load the train/test datasets, with the chosen metadata attached, from `save_dir`.
trn_dset, tst_dset = UpdatedDataset.load_datasets(save_dir, metadata_type, x_prefix, y_prefix, z_prefix)

In [None]:
# Output directory for the re-saved dataset (absolute, user-specific path).
new_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/20250123-LF-WikiSeeAlsoTitles-320K/'

In [None]:
# Save under `new_dir`; `save_dataset` drops rows without labels and returns the filtered
# datasets, which replace `trn_dset` / `tst_dset` from here on.
trn_dset, tst_dset = save_dataset(new_dir, trn_dset, tst_dset)

In [None]:
# Statistics of the filtered datasets returned by `save_dataset`.
print_dset_stats(trn_dset, tst_dset)

Unnamed: 0_level_0,Unnamed: 1_level_0,# Entries,# Features,Avg. Entries per feature,Avg. Feature per entry,Max. Entries per feature,Max. Feature per entry,# Features without entry,# Entries without feature
Split,Dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
Train,Main,1183189,1205595,2.57,2.62,4747,987,0,0
Train,Query Category Metadata,1183189,1155247,5.21,5.09,65056,245,159691,102390
Train,Label Category Metadata,1205595,1155247,3.75,3.6,25831,247,280544,420595
Test,Main,340278,1205595,0.82,2.9,1075,150,758109,0
Test,Query Category Metadata,340278,1155247,1.49,5.06,17104,247,625513,24033
Test,Label Category Metadata,1205595,1155247,3.75,3.6,25831,247,280544,420595


In [None]:
# Text-only view of the train split (keeps the `*_text` fields of each item).
txt_dset = TextDataset(trn_dset)

In [None]:
# Dump up to 1000 randomly chosen train examples to a text file for manual inspection.
os.makedirs(f'{new_dir}/examples', exist_ok=True)

# No random seed is set, so the sampled indices differ between runs.
idxs = np.random.permutation(trn_dset.n_data)[:1000]
txt_dset.dump(f'{new_dir}/examples/random_train.txt', idxs)

In [None]:
# Print the text fields of train items 10, 20 and 30.
txt_dset.show([10, 20, 30])

[5m[7m[90mdata_input_text[0m [90m: Arithmetic mean[0m
[5m[7m[94mlbl2data_input_text[0m [94m: ['Fréchet mean', 'Generalized mean', 'Summary statistics', 'Standard deviation', 'Standard error of the mean', 'Sample mean and covariance', 'Inequality of arithmetic and geometric means'][0m
[5m[7m[30mcat2data_input_text[0m [30m: ['Means'][0m
[5m[7m[95mcat2lbl2data_input_text[0m [95m: [['Means'], ['Means', 'Inequalities', 'Articles with example Haskell code'], ['Summary statistics'], ['Summary statistics', 'Statistical deviation and dispersion'], [], ['U-statistics', 'Summary statistics', 'Estimation methods', 'Covariance and correlation', 'Matrices'], []][0m

[5m[7m[90mdata_input_text[0m [90m: Annual plant[0m
[5m[7m[94mlbl2data_input_text[0m [94m: ['Ephemeral plant'][0m
[5m[7m[30mcat2data_input_text[0m [30m: ['Annual plants'][0m
[5m[7m[95mcat2lbl2data_input_text[0m [95m: [['Ephemeral plants', 'Plants', 'Flowers']][0m

[5m[7m[90mdata_input_tex