In [None]:
#| default_exp 08_map-wikipedia-metadata

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import argparse, os, pandas as pd, pickle, numpy as np, re, scipy.sparse as sp, joblib
import xclib.data.data_utils as du
from tqdm.auto import tqdm

from sugar.core import *
from xcai.core import Filterer

In [None]:
import difflib

## Extraction helper 

In [None]:
#| export
def load_info(data_dir):
    """Load the page and redirect dumps stored under `data_dir`.

    Reads `page_info.joblib` and `redirect_info.joblib` with `joblib.load` and
    returns them as `(page_info, redirect_info)`.

    NOTE(review): `joblib.load` unpickles its input, so only use trusted files.
    """
    page_path, redirect_path = f'{data_dir}/page_info.joblib', f'{data_dir}/redirect_info.joblib'
    return joblib.load(page_path), joblib.load(redirect_path)
    

In [None]:
#| export
def get_redirect_map(redirect_info):
    """Build a `{title: redirect}` lookup from the records of `redirect_info`.

    `redirect_info` is a mapping whose values are dicts with `'title'` and
    `'redirect'` keys. When a title occurs more than once, the last record wins.
    """
    title2target = {}
    for record in redirect_info.values():
        title2target[record['title']] = record['redirect']
    return title2target
    

In [None]:
#| export
def transform_ids(ids, redirect_inp2targ):
    """Normalise ids to titles and follow one level of redirect.

    Each id has `_` replaced by a space; if the resulting title is a key of
    `redirect_inp2targ` it is replaced by the mapped target, otherwise it is
    kept as is. Returns a new list in the same order as `ids`.
    """
    resolved = []
    for raw_id in ids:
        title = raw_id.replace('_', ' ')
        resolved.append(redirect_inp2targ.get(title, title))
    return resolved
    

In [None]:
#| export
def get_all_ids(data_dir, redirect_inp2targ=None, load_lbl=True):
    """Read the train/test (and optionally label) ids from `data_dir/raw_data`.

    The files are read as latin-1 through `load_raw_txt`. When
    `redirect_inp2targ` is given, every id list is passed through `transform_ids`.

    Returns `(all_ids, trn_ids, tst_ids, lbl_ids)` when `load_lbl` is true and
    `(all_ids, trn_ids, tst_ids)` otherwise, where `all_ids` is the set of all
    ids that were loaded.
    """
    def _read(fname):
        # Only the first element returned by `load_raw_txt` (the ids) is used.
        ids, _ = load_raw_txt(f'{data_dir}/raw_data/{fname}', encoding='latin-1')
        return ids if redirect_inp2targ is None else transform_ids(ids, redirect_inp2targ)

    trn_ids, tst_ids = _read('train.raw.txt'), _read('test.raw.txt')
    if not load_lbl:
        return set(trn_ids + tst_ids), trn_ids, tst_ids

    lbl_ids = _read('label.raw.txt')
    return set(trn_ids + tst_ids + lbl_ids), trn_ids, tst_ids, lbl_ids
    

In [None]:
#| export
def get_mapping(page_info, all_ids, key):
    """Collect `{title: page[key]}` for every page that has `key`, then filter it.

    The collected mapping is passed, together with `all_ids`, to
    `filter_mapping`, whose result is returned.
    NOTE(review): presumably `filter_mapping` restricts the mapping to `all_ids` — confirm.
    """
    title2items = {}
    for page in page_info.values():
        if key in page: title2items[page['title']] = page[key]
    return filter_mapping(title2items, all_ids)
    

In [None]:
#| export
def get_vocabulary(mapping, key):
    """Derive the vocabulary ids/texts and the item-to-index mapping for `mapping`.

    The vocabulary produced by `create_vocab_and_item2idx` is ordered by its
    values. Ids are the entries with spaces replaced by `_`. Texts are the
    entries themselves, except for `key == 'category'`, where the first 9
    characters are dropped.
    NOTE(review): 9 looks like `len('Category:')` — confirm.

    Returns `(vocab_ids, vocab_txt, mapping_item2idx)`.
    """
    vocab, mapping_item2idx = create_vocab_and_item2idx(mapping)

    ordered = sorted(vocab, key=vocab.__getitem__)
    vocab_ids = [entry.replace(' ', '_') for entry in ordered]
    vocab_txt = [entry[9:] for entry in ordered] if key == 'category' else ordered

    return vocab_ids, vocab_txt, mapping_item2idx
    

In [None]:
#| export
def get_matrix(mapping_item2idx, vocab_size, trn_ids, tst_ids, lbl_ids=None):
    """Build one matrix per id list via `get_matrix_from_item2idx`.

    Returns `(trn_mat, tst_mat)` or, when `lbl_ids` is given,
    `(trn_mat, tst_mat, lbl_mat)`. The second value returned by
    `get_matrix_from_item2idx` is discarded.
    """
    id_lists = (trn_ids, tst_ids) if lbl_ids is None else (trn_ids, tst_ids, lbl_ids)
    return tuple(get_matrix_from_item2idx(mapping_item2idx, vocab_size, ids)[0] for ids in id_lists)
    

In [None]:
#| export
def filter_vocab(vocab_ids, vocab_txt, trn_mat, tst_mat, lbl_mat=None):
    """Drop vocabulary columns that are unused in the train (and label) matrix.

    A column is kept when it has at least one stored entry in `trn_mat` or,
    if given, in `lbl_mat`. `tst_mat` never contributes to the decision but is
    sliced with the same columns.

    Returns `(vocab_ids, vocab_txt, trn_mat, tst_mat, lbl_mat)` restricted to the
    kept columns; `lbl_mat` is returned unchanged (`None`) when it was not given.
    """
    has_lbl = lbl_mat is not None

    keep = np.where(trn_mat.getnnz(axis=0) > 0)[0]
    if has_lbl:
        keep = np.union1d(keep, np.where(lbl_mat.getnnz(axis=0) > 0)[0])

    trn_kept = trn_mat[:, keep].copy()
    tst_kept = tst_mat[:, keep].copy()
    lbl_kept = lbl_mat[:, keep].copy() if has_lbl else lbl_mat

    kept_ids = [vocab_ids[col] for col in keep]
    kept_txt = [vocab_txt[col] for col in keep]

    return kept_ids, kept_txt, trn_kept, tst_kept, lbl_kept
    

## Combiner helper

In [None]:
#| export
def ids_intersection(new_ids, old_ids):
    """Match the entries of `old_ids` against `new_ids`.

    Returns `((old_idx, new_idx), old_not_in_new_ids)` where `old_idx[k]` is a
    position in `old_ids` whose id also occurs in `new_ids` at position
    `new_idx[k]` (both are tuples, ordered by position in `old_ids`), and
    `old_not_in_new_ids` lists the positions of `old_ids` with no match.
    When an id occurs several times in `new_ids`, its last position is used.

    If the two lists share no id, `old_idx` and `new_idx` are empty tuples
    rather than an unpacking error.
    """
    new_ids2idx = {k: idx for idx, k in enumerate(new_ids)}

    # A single pass over `old_ids` fills the matched pairs and the unmatched positions.
    old_idx, new_idx, old_not_in_new_ids = [], [], []
    for idx, ids in enumerate(old_ids):
        if ids in new_ids2idx:
            old_idx.append(idx)
            new_idx.append(new_ids2idx[ids])
        else:
            old_not_in_new_ids.append(idx)

    return (tuple(old_idx), tuple(new_idx)), old_not_in_new_ids
    

In [None]:
#| export
def combine_matrix(old_mat, new_mat, old_idx, new_idx):
    """Fold matched columns of `old_mat` into `new_mat`, column by column.

    `old_idx[k]` / `new_idx[k]` pair a column of `old_mat` with the column of
    `new_mat` it is merged into. When several pairs target the same new column,
    only the one `np.unique` reports first after sorting is used. Entries that
    are stored in both matrices at the same position are added together by the
    final `sum_duplicates()`.

    Returns a CSR matrix with the shape and dtype of `new_mat`.
    NOTE(review): assumes both matrices have the same number of rows — confirm at call sites.
    """
    old_csc, new_csc = old_mat.tocsc(), new_mat.tocsc()

    # Sort the pairs by destination column and keep a single source per destination.
    order = np.argsort(new_idx)
    src_cols, dst_cols = np.array(old_idx)[order], np.array(new_idx)[order]
    dst_cols, first = np.unique(dst_cols, return_index=True)
    src_cols = src_cols[first]

    # Walk the (destination, source) pairs in step with the column loop below.
    pairs = zip(dst_cols.tolist(), src_cols.tolist())
    pending = next(pairs, None)

    data, indices, indptr = [], [], [0]
    for col in range(new_csc.shape[1]):
        lo, hi = new_csc.indptr[col], new_csc.indptr[col + 1]
        col_data, col_rows = new_csc.data[lo:hi].tolist(), new_csc.indices[lo:hi].tolist()

        if pending is not None and pending[0] == col:
            src = pending[1]
            lo, hi = old_csc.indptr[src], old_csc.indptr[src + 1]
            col_data += old_csc.data[lo:hi].tolist()
            col_rows += old_csc.indices[lo:hi].tolist()
            pending = next(pairs, None)

        data.extend(col_data)
        indices.extend(col_rows)
        indptr.append(len(data))

    merged = sp.csc_matrix((data, indices, indptr), shape=new_csc.shape, dtype=new_csc.dtype).tocsr()
    merged.sum_duplicates()
    return merged
    

In [None]:
#| export
def get_combined_data(old_vocab_ids, old_vocab_txt, old_trn_mat, old_tst_mat, new_vocab_ids, new_vocab_txt, new_trn_mat, new_tst_mat, 
                      old_lbl_mat=None, new_lbl_mat=None):
    """Merge an old and a new vocabulary together with their matrices.

    Columns whose id appears in both vocabularies are merged with
    `combine_matrix`; old-only columns are appended after the new ones, so the
    combined vocabulary is `new_vocab_*` followed by the unmatched old entries.

    Returns `(vocab_ids, vocab_txt, trn_mat, tst_mat)`, plus a combined label
    matrix as fifth element when `old_lbl_mat` is given.
    """
    (old_idx, new_idx), old_only = ids_intersection(new_vocab_ids, old_vocab_ids)

    def _merge(old_mat, new_mat):
        # Shared columns are merged in place; old-only columns are stacked on the right.
        shared = combine_matrix(old_mat, new_mat, old_idx, new_idx)
        return sp.hstack([shared, old_mat[:, old_only]])

    combined_trn_mat = _merge(old_trn_mat, new_trn_mat)
    combined_tst_mat = _merge(old_tst_mat, new_tst_mat)

    combined_vocab_ids = new_vocab_ids + [old_vocab_ids[i] for i in old_only]
    combined_vocab_txt = new_vocab_txt + [old_vocab_txt[i] for i in old_only]

    if old_lbl_mat is None:
        return combined_vocab_ids, combined_vocab_txt, combined_trn_mat, combined_tst_mat

    combined_lbl_mat = _merge(old_lbl_mat, new_lbl_mat)
    return combined_vocab_ids, combined_vocab_txt, combined_trn_mat, combined_tst_mat, combined_lbl_mat
    

## Extract `labels`

In [None]:
#| export
def get_labels(page_info, redirect_info, data_dir, key):
    """Extract label matrices for the train/test ids of `data_dir` from `page_info[...][key]`.

    The train/test ids are redirected, the `key` entries of the matching pages
    form the label vocabulary, and labels without any training entry are removed
    by `filter_vocab`.

    Returns `(trn_mat, trn_ids, tst_mat, tst_ids, lbl_ids, lbl_txt)`.
    """
    all_ids, trn_ids, tst_ids = get_all_ids(data_dir, get_redirect_map(redirect_info), load_lbl=False)

    lbl_ids, lbl_txt, item2idx = get_vocabulary(get_mapping(page_info, all_ids, key), key)

    trn_mat, tst_mat = get_matrix(item2idx, len(lbl_ids), trn_ids, tst_ids)

    # The fifth value returned by `filter_vocab` is the (absent) label matrix.
    lbl_ids, lbl_txt, trn_mat, tst_mat, _ = filter_vocab(lbl_ids, lbl_txt, trn_mat, tst_mat)

    return trn_mat, trn_ids, tst_mat, tst_ids, lbl_ids, lbl_txt
    

In [None]:
#| export
def get_filterer(trn_ids, tst_ids, label_ids, trn_mat, tst_mat):
    """Generate the train/test filterers and apply the test one to `tst_mat`.

    Returns `(trn_filterer, tst_filterer, filtered_tst_mat)`; the train matrix is
    not modified here.
    """
    filterers = Filterer.generate(trn_ids, tst_ids, label_ids, trn_mat, tst_mat)
    trn_filterer, tst_filterer = filterers
    return trn_filterer, tst_filterer, Filterer.apply(tst_mat, tst_filterer)
    

In [None]:
#| export
def save_labels(save_dir, trn_mat, tst_mat, lbl_ids, lbl_txt, x_prefix, y_prefix, trn_filterer=None, tst_filterer=None):
    """Write label matrices, optional filter files and the label raw text under `save_dir`.

    Files written (with `tag = f'{x_prefix}-{y_prefix}'`):
      * `trn_X_Y_{tag}.npz`, `tst_X_Y_{tag}.npz`
      * `filter_labels_train_{tag}.txt` / `filter_labels_test_{tag}.txt` when the
        corresponding filterer is not `None`
      * `raw_data/label.{tag}.raw.txt`
    """
    # Creating `raw_data` first also creates `save_dir`, so none of the writes
    # below can fail because the output directory is missing.
    os.makedirs(f'{save_dir}/raw_data', exist_ok=True)

    tag = f'{x_prefix}-{y_prefix}'
    sp.save_npz(f'{save_dir}/trn_X_Y_{tag}.npz', trn_mat)
    sp.save_npz(f'{save_dir}/tst_X_Y_{tag}.npz', tst_mat)

    if trn_filterer is not None: np.savetxt(f'{save_dir}/filter_labels_train_{tag}.txt', trn_filterer)
    if tst_filterer is not None: np.savetxt(f'{save_dir}/filter_labels_test_{tag}.txt', tst_filterer)

    save_raw_txt(f'{save_dir}/raw_data/label.{tag}.raw.txt', lbl_ids, lbl_txt)
    

In [None]:
#| export
def _save_old_raw_txt(data_dir, save_dir, redirect_info, prefix=''):
    """Re-save the redirected train/test/label ids of `data_dir` under `save_dir/raw_data`.

    Each output line pairs the id (spaces replaced by `_`) with the redirected
    title. With a non-empty `prefix` the files are named `train.{prefix}.raw.txt`,
    `test.{prefix}.raw.txt` and `label.{prefix}-{prefix}.raw.txt`; with the default
    empty prefix no infix is added (`train.raw.txt`, ...).
    """
    # Both infixes are always defined, so the default `prefix=''` no longer
    # raises `UnboundLocalError`.
    x_prefix = f'.{prefix}' if len(prefix) else ''
    y_prefix = f'.{prefix}-{prefix}' if len(prefix) else ''

    redirect_inp2targ = get_redirect_map(redirect_info)
    _, trn_ids, tst_ids, lbl_ids = get_all_ids(data_dir, redirect_inp2targ)

    os.makedirs(f'{save_dir}/raw_data', exist_ok=True)

    save_raw_txt(f'{save_dir}/raw_data/train{x_prefix}.raw.txt', [o.replace(' ', '_') for o in trn_ids], trn_ids)
    save_raw_txt(f'{save_dir}/raw_data/test{x_prefix}.raw.txt', [o.replace(' ', '_') for o in tst_ids], tst_ids)
    save_raw_txt(f'{save_dir}/raw_data/label{y_prefix}.raw.txt', [o.replace(' ', '_') for o in lbl_ids], lbl_ids)
    

In [None]:
#| export
def get_and_save_labels(page_info, redirect_info, data_dir, save_dir, key, old_prefix, new_prefix):
    """Extract the `key` labels, filter the test matrix and write everything to `save_dir`.

    The matrices and label text are saved with the `{old_prefix}-{new_prefix}`
    tag; the redirected ids of the original dataset are saved afterwards with
    `old_prefix`.
    """
    extracted = get_labels(page_info, redirect_info, data_dir, key)
    trn_mat, trn_ids, tst_mat, tst_ids, lbl_ids, lbl_txt = extracted

    trn_filterer, tst_filterer, tst_mat = get_filterer(trn_ids, tst_ids, lbl_ids, trn_mat, tst_mat)

    save_labels(save_dir, trn_mat, tst_mat, lbl_ids, lbl_txt, x_prefix=old_prefix, y_prefix=new_prefix,
                trn_filterer=trn_filterer, tst_filterer=tst_filterer)

    _save_old_raw_txt(data_dir, save_dir, redirect_info, prefix=old_prefix)
    

In [None]:
# Directory holding the `page_info.joblib` / `redirect_info.joblib` dumps read by `load_info`.
# NOTE(review): machine-specific absolute path; adjust before re-running elsewhere.
info_dir = '/home/scai/phd/aiz218323/scratch/datasets/wikipedia/20250123/info/'
page_info, redirect_info = load_info(info_dir)

In [None]:
# `data_dir` is the source dataset passed to the extraction calls below; `save_dir` receives their outputs.
# NOTE(review): machine-specific absolute paths; adjust before re-running elsewhere.
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/(mapped)LF-WikiSeeAlsoTitles-320K/'
save_dir = '/home/scai/phd/aiz218323/scratch/datasets/wikipedia/20250123/LF-WikiSeeAlsoTitles-320K/'

In [None]:
# Extract the `see_also` labels and write them under `save_dir` with the `old-new` tag.
get_and_save_labels(page_info, redirect_info, data_dir, save_dir, key='see_also', old_prefix='old', new_prefix='new')

  0%|          | 0/746706 [00:00<?, ?it/s]

  0%|          | 0/693082 [00:00<?, ?it/s]

  0%|          | 0/177515 [00:00<?, ?it/s]

## Combine `labels`

In [None]:
#| export
def _load_new_labels(data_dir, x_prefix, y_prefix):
    """Load the label matrices and label raw text saved with the `{x_prefix}-{y_prefix}` tag.

    Returns `(trn_mat, tst_mat, lbl_ids, lbl_txt)`.
    """
    mat_files = (f'{data_dir}/trn_X_Y_{x_prefix}-{y_prefix}.npz', f'{data_dir}/tst_X_Y_{x_prefix}-{y_prefix}.npz')
    trn_mat, tst_mat = (sp.load_npz(fname) for fname in mat_files)

    lbl_ids, lbl_txt = load_raw_txt(f'{data_dir}/raw_data/label.{x_prefix}-{y_prefix}.raw.txt')

    return trn_mat, tst_mat, lbl_ids, lbl_txt
    

In [None]:
#| export
def _load_old_labels(old_dir, new_dir, prefix=''):
    """Load the original label matrices from `old_dir` and their re-saved ids from `new_dir`.

    The matrices come from `trn_X_Y.txt` / `tst_X_Y.txt` in `old_dir`. The label
    ids and text are read from `new_dir/raw_data/label.{prefix}-{prefix}.raw.txt`
    (or `label.raw.txt` when `prefix` is empty).

    Returns `(trn_mat, tst_mat, lbl_ids, lbl_txt)`.
    """
    suffix = f'.{prefix}-{prefix}' if len(prefix) else prefix

    trn_mat = du.read_sparse_file(f'{old_dir}/trn_X_Y.txt')
    tst_mat = du.read_sparse_file(f'{old_dir}/tst_X_Y.txt')
    lbl_ids, lbl_txt = load_raw_txt(f'{new_dir}/raw_data/label{suffix}.raw.txt')

    return trn_mat, tst_mat, lbl_ids, lbl_txt
    

In [None]:
#| export
def combined_labels(new_dir, old_dir, old_prefix='', new_prefix=''):
    """Load the newly extracted and the original labels and merge them.

    Returns the `(lbl_ids, lbl_txt, trn_mat, tst_mat)` tuple produced by
    `get_combined_data`.
    """
    new_trn_mat, new_tst_mat, new_lbl_ids, new_lbl_txt = _load_new_labels(new_dir, old_prefix, new_prefix)
    old_trn_mat, old_tst_mat, old_lbl_ids, old_lbl_txt = _load_old_labels(old_dir, new_dir, old_prefix)

    old_side = (old_lbl_ids, old_lbl_txt, old_trn_mat, old_tst_mat)
    new_side = (new_lbl_ids, new_lbl_txt, new_trn_mat, new_tst_mat)
    return get_combined_data(*old_side, *new_side)
    

In [None]:
# Merge the original and the newly extracted labels in memory (nothing is written here).
combined_lbl_ids, combined_lbl_txt, combined_trn_mat, combined_tst_mat = combined_labels(save_dir, data_dir, old_prefix='old', 
                                                                                         new_prefix='new')

In [None]:
#| export
def _get_new_data_ids(data_dir, prefix=''):
    """Read the train and test ids from `data_dir/raw_data`.

    The files are `train.{prefix}.raw.txt` / `test.{prefix}.raw.txt`, or
    `train.raw.txt` / `test.raw.txt` when `prefix` is empty. Only the ids are
    returned, as `(trn_ids, tst_ids)`.
    """
    suffix = f'.{prefix}' if len(prefix) else prefix

    trn_ids, _ = load_raw_txt(f'{data_dir}/raw_data/train{suffix}.raw.txt')
    tst_ids, _ = load_raw_txt(f'{data_dir}/raw_data/test{suffix}.raw.txt')

    return trn_ids, tst_ids

def _get_filterer(data_dir, prefix, lbl_ids, trn_mat, tst_mat):
    """Load the saved train/test ids for `prefix` and delegate to `get_filterer`."""
    data_ids = _get_new_data_ids(data_dir, prefix=prefix)
    return get_filterer(*data_ids, lbl_ids, trn_mat, tst_mat)
    
    

In [None]:
#| export
def combine_and_save_labels(new_dir, old_dir, old_prefix='', new_prefix='', com_prefix=''):
    """Merge old and new labels, filter the test matrix and save under `new_dir`.

    The outputs are written by `save_labels` with the `{old_prefix}-{com_prefix}` tag.
    """
    merged = combined_labels(new_dir, old_dir, old_prefix=old_prefix, new_prefix=new_prefix)
    lbl_ids, lbl_txt, trn_mat, tst_mat = merged

    trn_filterer, tst_filterer, tst_mat = _get_filterer(new_dir, old_prefix, lbl_ids, trn_mat, tst_mat)

    save_labels(new_dir, trn_mat, tst_mat, lbl_ids, lbl_txt, x_prefix=old_prefix, y_prefix=com_prefix,
                trn_filterer=trn_filterer, tst_filterer=tst_filterer)
    

In [None]:
# Merge the original and new labels and save the result under `save_dir` with the `old-combined` tag.
combine_and_save_labels(save_dir, data_dir, old_prefix='old', new_prefix='new', com_prefix='combined')

## Extract `metadata`

In [None]:
#| export
def get_metadata(page_info, key, trn_ids, tst_ids, lbl_ids):
    """Extract the `key` metadata matrices for the train, test and label ids.

    Metadata entries with no stored value in either the train or the label
    matrix are removed by `filter_vocab`.

    Returns `(trn_mat, tst_mat, lbl_mat, metadata_ids, metadata_txt)`.
    """
    wanted_ids = set(trn_ids + tst_ids + lbl_ids)

    metadata_ids, metadata_txt, item2idx = get_vocabulary(get_mapping(page_info, wanted_ids, key), key)

    mats = get_matrix(item2idx, len(metadata_ids), trn_ids, tst_ids, lbl_ids)

    metadata_ids, metadata_txt, trn_mat, tst_mat, lbl_mat = filter_vocab(metadata_ids, metadata_txt, *mats)
    return trn_mat, tst_mat, lbl_mat, metadata_ids, metadata_txt
    

In [None]:
#| export
def _get_ids(fname):
    """Read the ids from `fname` and return them with `_` replaced by spaces."""
    raw_ids = load_raw_txt(fname)[0]
    return [raw_id.replace('_', ' ') for raw_id in raw_ids]
    
def _get_all_ids(data_dir, old_prefix, new_prefix, com_prefix):
    """Load the saved train/test ids and the old, new and combined label ids.

    All files are read from `data_dir/raw_data` through `_get_ids`.
    Returns `(old_trn_ids, old_tst_ids, old_lbl_ids, new_lbl_ids, com_lbl_ids)`.
    """
    fnames = (
        f'{data_dir}/raw_data/train.{old_prefix}.raw.txt',
        f'{data_dir}/raw_data/test.{old_prefix}.raw.txt',
        f'{data_dir}/raw_data/label.{old_prefix}-{old_prefix}.raw.txt',
        f'{data_dir}/raw_data/label.{old_prefix}-{new_prefix}.raw.txt',
        f'{data_dir}/raw_data/label.{old_prefix}-{com_prefix}.raw.txt',
    )
    return tuple(_get_ids(fname) for fname in fnames)
    

In [None]:
#| export
def save_metadata(save_dir, trn_mat, tst_mat, lbl_mat, metadata_ids, metadata_txt, metadata_type, x_prefix, y_prefix, z_prefix):
    """Write the metadata matrices and the metadata raw text under `save_dir`.

    Files written (with `tag = f'{x_prefix}-{y_prefix}-{z_prefix}'`):
      * `{metadata_type}_trn_X_Y_{tag}.npz`, `{metadata_type}_tst_X_Y_{tag}.npz`,
        `{metadata_type}_lbl_X_Y_{tag}.npz`
      * `raw_data/{metadata_type}.{tag}.raw.txt`
    """
    # Creating `raw_data` first also creates `save_dir`, so the matrix writes
    # below cannot fail because the output directory is missing.
    os.makedirs(f'{save_dir}/raw_data', exist_ok=True)

    tag = f'{x_prefix}-{y_prefix}-{z_prefix}'
    for split, mat in (('trn', trn_mat), ('tst', tst_mat), ('lbl', lbl_mat)):
        sp.save_npz(f'{save_dir}/{metadata_type}_{split}_X_Y_{tag}.npz', mat)

    save_raw_txt(f'{save_dir}/raw_data/{metadata_type}.{tag}.raw.txt', metadata_ids, metadata_txt)
    

In [None]:
#| export
def _save_old_metadata_raw_txt(data_dir, save_dir, redirect_info, metadata_type, prefix):
    """Re-save the original metadata ids of `data_dir` after applying redirects.

    Reads `data_dir/raw_data/{metadata_type}.raw.txt` (latin-1), redirects the ids
    and writes `save_dir/raw_data/{metadata_type}.{prefix}-{prefix}-{prefix}.raw.txt`.
    The text column is derived from the redirected ids; the text stored in the
    input file is not used.
    """
    redirect_inp2targ = get_redirect_map(redirect_info)
    metadata_ids, _ = load_raw_txt(f'{data_dir}/raw_data/{metadata_type}.raw.txt', encoding='latin-1')

    metadata_ids = transform_ids(metadata_ids, redirect_inp2targ)
    # Mirror `get_vocabulary`: only category titles have the 9-character prefix
    # removed; other metadata types keep the full title as text.
    if metadata_type == 'category':
        metadata_txt = [o[9:] for o in metadata_ids]
    else:
        metadata_txt = list(metadata_ids)
    metadata_ids = [o.replace(' ', '_') for o in metadata_ids]

    save_raw_txt(f'{save_dir}/raw_data/{metadata_type}.{prefix}-{prefix}-{prefix}.raw.txt', metadata_ids, metadata_txt)
    

In [None]:
#| export
def get_and_save_metadata(page_info, redirect_info, data_dir, save_dir, key, old_prefix, new_prefix, com_prefix):
    """Extract and save the `key` metadata once per label set (old, new, combined).

    The train/test and label ids are read back from `save_dir`. Each result is
    saved with the tag `{old_prefix}-{label prefix}-{new_prefix}`. Finally the
    original metadata ids of `data_dir` are re-saved with `old_prefix`.
    """
    trn_ids, tst_ids, old_lbl_ids, new_lbl_ids, com_lbl_ids = _get_all_ids(save_dir, old_prefix, new_prefix, com_prefix)

    label_variants = ((old_lbl_ids, old_prefix), (new_lbl_ids, new_prefix), (com_lbl_ids, com_prefix))
    for lbl_ids, lbl_prefix in label_variants:
        trn_mat, tst_mat, lbl_mat, metadata_ids, metadata_txt = get_metadata(page_info, key, trn_ids, tst_ids, lbl_ids)
        save_metadata(save_dir, trn_mat, tst_mat, lbl_mat, metadata_ids, metadata_txt, key, x_prefix=old_prefix,
                      y_prefix=lbl_prefix, z_prefix=new_prefix)

    _save_old_metadata_raw_txt(data_dir, save_dir, redirect_info, metadata_type=key, prefix=old_prefix)
    

In [None]:
# Extract the `category` metadata for the old, new and combined label sets and save it under `save_dir`.
get_and_save_metadata(page_info, redirect_info, data_dir, save_dir, key='category', old_prefix='old', new_prefix='new', 
                      com_prefix='combined')

  0%|          | 0/885604 [00:00<?, ?it/s]

  0%|          | 0/693082 [00:00<?, ?it/s]

  0%|          | 0/177515 [00:00<?, ?it/s]

  0%|          | 0/312330 [00:00<?, ?it/s]

  0%|          | 0/997468 [00:00<?, ?it/s]

  0%|          | 0/693082 [00:00<?, ?it/s]

  0%|          | 0/177515 [00:00<?, ?it/s]

  0%|          | 0/573503 [00:00<?, ?it/s]

  0%|          | 0/1016125 [00:00<?, ?it/s]

  0%|          | 0/693082 [00:00<?, ?it/s]

  0%|          | 0/177515 [00:00<?, ?it/s]

  0%|          | 0/646869 [00:00<?, ?it/s]

## Combine `metadata`

In [None]:
#| export
def _load_old_metadata(data_dir, save_dir, metadata_type, prefix):
    """Load the original metadata matrices from `data_dir` and their re-saved ids from `save_dir`.

    Returns `(trn_mat, tst_mat, lbl_mat, metadata_ids, metadata_txt)`.
    """
    mat_files = (
        f'{data_dir}/{metadata_type}_trn_X_Y.txt',
        f'{data_dir}/{metadata_type}_tst_X_Y.txt',
        f'{data_dir}/{metadata_type}_lbl_X_Y.txt',
    )
    trn_mat, tst_mat, lbl_mat = (du.read_sparse_file(fname) for fname in mat_files)

    metadata_ids, metadata_txt = load_raw_txt(f'{save_dir}/raw_data/{metadata_type}.{prefix}-{prefix}-{prefix}.raw.txt')

    return trn_mat, tst_mat, lbl_mat, metadata_ids, metadata_txt
    

In [None]:
#| export
def _load_new_metadata(data_dir, metadata_type, x_prefix, y_prefix, z_prefix):
    """Load the metadata matrices and raw text saved with the `{x_prefix}-{y_prefix}-{z_prefix}` tag.

    Returns `(trn_mat, tst_mat, lbl_mat, metadata_ids, metadata_txt)`.
    """
    mat_files = (
        f'{data_dir}/{metadata_type}_trn_X_Y_{x_prefix}-{y_prefix}-{z_prefix}.npz',
        f'{data_dir}/{metadata_type}_tst_X_Y_{x_prefix}-{y_prefix}-{z_prefix}.npz',
        f'{data_dir}/{metadata_type}_lbl_X_Y_{x_prefix}-{y_prefix}-{z_prefix}.npz',
    )
    trn_mat, tst_mat, lbl_mat = (sp.load_npz(fname) for fname in mat_files)

    metadata_ids, metadata_txt = load_raw_txt(f'{data_dir}/raw_data/{metadata_type}.{x_prefix}-{y_prefix}-{z_prefix}.raw.txt')

    return trn_mat, tst_mat, lbl_mat, metadata_ids, metadata_txt
    

In [None]:
#| export
def _align_lbl_mat(mat, data_dir, x_prefix, y_prefix):
    """Re-index the rows of `mat` from the old label order to the `{x_prefix}-{y_prefix}` label order.

    Row `i` of the result is the row of `mat` belonging to the same label id when
    that id exists in the old label file, and an empty row otherwise. The result
    is a CSR matrix with one row per new label and the column count of `mat`.

    NOTE(review): rows are sliced through `mat.indptr`, so `mat` is assumed to be CSR — confirm.
    """
    old_lbl_ids, _ = load_raw_txt(f'{data_dir}/raw_data/label.{x_prefix}-{x_prefix}.raw.txt')
    new_lbl_ids, _ = load_raw_txt(f'{data_dir}/raw_data/label.{x_prefix}-{y_prefix}.raw.txt')

    # `dst_rows` index `new_lbl_ids` (ascending); `src_rows` index `old_lbl_ids`.
    (dst_rows, src_rows), _ = ids_intersection(old_lbl_ids, new_lbl_ids)

    n_rows = len(new_lbl_ids)
    data, indices, indptr = [], [], [0]
    cursor = 0
    for row in range(n_rows):
        if cursor < len(dst_rows) and dst_rows[cursor] == row:
            src = src_rows[cursor]
            lo, hi = mat.indptr[src], mat.indptr[src + 1]
            data += mat.data[lo:hi].tolist()
            indices += mat.indices[lo:hi].tolist()
            cursor += 1
        # Unmatched rows stay empty: the row pointer simply repeats.
        indptr.append(len(data))

    return sp.csr_matrix((data, indices, indptr), dtype=mat.dtype, shape=(n_rows, mat.shape[1]))
    

In [None]:
#| export
def combined_metadata(new_dir, old_dir, metadata_type, x_prefix, y_prefix, z_prefix):
    """Load the original and the newly extracted metadata and merge them.

    When the two label matrices have a different number of rows, the old one is
    first re-indexed with `_align_lbl_mat`.
    NOTE(review): equal row counts are taken to mean the label order already matches — confirm.

    Returns the `(metadata_ids, metadata_txt, trn_mat, tst_mat, lbl_mat)` tuple
    produced by `get_combined_data`.
    """
    old_trn_mat, old_tst_mat, old_lbl_mat, old_ids, old_txt = _load_old_metadata(old_dir, new_dir, metadata_type, prefix=x_prefix)
    new_trn_mat, new_tst_mat, new_lbl_mat, new_ids, new_txt = _load_new_metadata(new_dir, metadata_type, x_prefix, y_prefix, z_prefix)

    same_rows = old_lbl_mat.shape[0] == new_lbl_mat.shape[0]
    if not same_rows:
        old_lbl_mat = _align_lbl_mat(old_lbl_mat, new_dir, x_prefix, y_prefix)

    return get_combined_data(old_ids, old_txt, old_trn_mat, old_tst_mat,
                             new_ids, new_txt, new_trn_mat, new_tst_mat,
                             old_lbl_mat, new_lbl_mat)
    

In [None]:
# `combined_metadata` requires `metadata_type`; the call previously omitted it and raised `TypeError`.
com_category_ids, com_category_txt, com_trn_mat, com_tst_mat, com_lbl_mat = combined_metadata(
    save_dir, data_dir, metadata_type='category', x_prefix='old', y_prefix='new', z_prefix='new')

In [None]:
#| export
def combine_and_save_metadata(new_dir, old_dir, metadata_type, old_prefix, new_prefix, com_prefix):
    """Merge and save the metadata once per label set (old, new, combined).

    For each label prefix the metadata saved with `z_prefix=new_prefix` is merged
    with the original metadata and written back with `z_prefix=com_prefix`.
    """
    for lbl_prefix in (old_prefix, new_prefix, com_prefix):
        merged = combined_metadata(new_dir, old_dir, metadata_type, x_prefix=old_prefix,
                                   y_prefix=lbl_prefix, z_prefix=new_prefix)
        metadata_ids, metadata_txt, trn_mat, tst_mat, lbl_mat = merged
        save_metadata(new_dir, trn_mat, tst_mat, lbl_mat, metadata_ids, metadata_txt, metadata_type,
                      x_prefix=old_prefix, y_prefix=lbl_prefix, z_prefix=com_prefix)
    

In [None]:
# Merge the original and new `category` metadata for every label set and save it under `save_dir`.
combine_and_save_metadata(save_dir, data_dir, metadata_type='category', old_prefix='old', new_prefix='new', com_prefix='combined')

## `__main__`

In [None]:
#| export
def parse_args():
    """Parse the command-line options of the script.

    `--info_dir`, `--data_dir` and `--save_dir` are required strings;
    `--label_key` and `--metadata_key` are optional strings defaulting to `None`;
    `--combine` is a boolean flag.
    """
    parser = argparse.ArgumentParser()
    for required_flag in ('--info_dir', '--data_dir', '--save_dir'):
        parser.add_argument(required_flag, type=str, required=True)
    for optional_flag in ('--label_key', '--metadata_key'):
        parser.add_argument(optional_flag, type=str, default=None)
    parser.add_argument('--combine', action='store_true')
    return parser.parse_args()
    

In [None]:
#| export
if __name__ == '__main__':
    # Script entry point: extract (and optionally combine) labels and/or metadata,
    # depending on which of `--label_key` / `--metadata_key` were given.
    args = parse_args()
    page_info, redirect_info = load_info(args.info_dir)

    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(f'{args.save_dir}/raw_data', exist_ok=True)

    if args.label_key is not None:
        print('Extracting labels ...')
        get_and_save_labels(page_info, redirect_info, args.data_dir, args.save_dir, key=args.label_key, old_prefix='old', new_prefix='new')
        if args.combine:
            print('Combining labels ...')
            combine_and_save_labels(args.save_dir, args.data_dir, old_prefix='old', new_prefix='new', com_prefix='combined')

    if args.metadata_key is not None:
        print('Extracting metadata ...')
        get_and_save_metadata(page_info, redirect_info, args.data_dir, args.save_dir, key=args.metadata_key, old_prefix='old',
                              new_prefix='new', com_prefix='combined')
        if args.combine:
            # This step previously printed 'Combining labels ...', mislabelling the log.
            print('Combining metadata ...')
            combine_and_save_metadata(args.save_dir, args.data_dir, args.metadata_key, old_prefix='old', new_prefix='new', com_prefix='combined')
