In [1]:
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import xclib.data.data_utils as du
from scipy.sparse import csr_matrix
from scipy.sparse import coo_matrix
from IPython.display import display
from timeit import default_timer as timer
import xclib.evaluation.xc_metrics as xc_metrics

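# Helper functions

Readers and writers for XC datasets (`id->text` raw files, sparse `X_Y` matrices) and for saved model predictions.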

In [2]:
def read_id(filename, rm_suffix_idx=None):
    """Read the id part of every line of an ``id->text`` raw file.

    Each line is split on the first ``'->'``. The part before it is kept and
    sliced with ``[:rm_suffix_idx]``: ``None`` keeps the whole id, and a
    negative value drops that many trailing characters.

    Args:
        filename: Path to the raw text file.
        rm_suffix_idx: Optional slice end applied to every id.

    Returns:
        List of id strings, one per line, in file order.
    """
    ids = []
    with open(filename) as file:
        for line in file:
            # Strip only the line terminator. ``line[:-1]`` would also chop a
            # real character off a final line that lacks a trailing newline.
            line = line.rstrip('\n')
            ids.append(line.split('->', maxsplit=1)[0][:rm_suffix_idx])
    return ids

def read_ids(data_dir, rm_suffix_idx=None, meta_tag=""):
    """Read the ids of the train, test and label raw files of a dataset.

    Args:
        data_dir: Dataset root; files are read from ``{data_dir}/raw_data/``.
        rm_suffix_idx: Forwarded to ``read_id`` (slice end applied to each id).
        meta_tag: If non-empty, also read the ids of
            ``raw_data/{meta_tag}.raw.txt``.

    Returns:
        ``(trn_id, tst_id, lbl_id)``, plus ``meta_id`` as a fourth element
        when ``meta_tag`` is given.
    """
    trn_id = read_id(f"{data_dir}/raw_data/train.raw.txt", rm_suffix_idx)
    tst_id = read_id(f"{data_dir}/raw_data/test.raw.txt", rm_suffix_idx)
    lbl_id = read_id(f"{data_dir}/raw_data/label.raw.txt", rm_suffix_idx)
    if meta_tag:
        # Ids, not texts: use read_id (read_map accepts no suffix argument and
        # returns the text side of each line).
        meta_id = read_id(f"{data_dir}/raw_data/{meta_tag}.raw.txt", rm_suffix_idx)
        return trn_id, tst_id, lbl_id, meta_id
    return trn_id, tst_id, lbl_id


In [3]:
def read_map(filename):
    """Read the text part of every line of an ``id->text`` raw file.

    Each line is split on the first ``'->'``, and everything after it is kept
    (later ``'->'`` occurrences stay in the text).

    Args:
        filename: Path to the raw text file.

    Returns:
        List of text strings, one per line, in file order.

    Raises:
        IndexError: If a line contains no ``'->'`` separator.
    """
    texts = []
    with open(filename) as file:
        for line in file:
            # Strip only the line terminator. ``line[:-1]`` would also drop
            # the last character of a final line without a trailing newline.
            line = line.rstrip('\n')
            texts.append(line.split('->', maxsplit=1)[1])
    return texts


In [4]:
def read_XC_data(data_dir, meta_tag=""):
    """Load the sparse XC matrices and raw texts of a dataset.

    Without ``meta_tag``: the ``trn_X_Y.txt`` / ``tst_X_Y.txt`` matrices and
    the train/test/label raw texts. With ``meta_tag``: the matrix files are
    prefixed with ``{meta_tag}_``, the ``{meta_tag}_lbl_X_Y.txt`` matrix is
    added, and the texts of ``raw_data/{meta_tag}.raw.txt`` are appended.

    Returns:
        ``(matrices, maps)``. ``matrices`` is ``(trn_xy, tst_xy)`` or
        ``(trn_xy, tst_xy, lbl_xy)``. ``maps`` is ``(trn_map, tst_map,
        lbl_map)``, with ``meta_map`` appended when ``meta_tag`` is given.
    """
    data_tag = f"{meta_tag}_" if meta_tag else meta_tag
    matrices = (
        du.read_sparse_file(f"{data_dir}/{data_tag}trn_X_Y.txt"),
        du.read_sparse_file(f"{data_dir}/{data_tag}tst_X_Y.txt"),
    )
    maps = (
        read_map(f"{data_dir}/raw_data/train.raw.txt"),
        read_map(f"{data_dir}/raw_data/test.raw.txt"),
        read_map(f"{data_dir}/raw_data/label.raw.txt"),
    )
    if not meta_tag:
        return matrices, maps

    # Metadata variant: also the label->meta matrix and the meta texts.
    lbl_xy = du.read_sparse_file(f"{data_dir}/{data_tag}lbl_X_Y.txt")
    meta_map = read_map(f"{data_dir}/raw_data/{meta_tag}.raw.txt")
    return matrices + (lbl_xy,), maps + (meta_map,)


In [5]:
def load_XC_predictions(pred_file):
    """Load a CSR prediction matrix stored as an ``.npz`` archive.

    The archive must contain ``data``, ``indices``, ``indptr`` and ``shape``
    arrays (the CSR components). Scores are returned as ``float``.

    Args:
        pred_file: Path to the ``.npz`` file.

    Returns:
        ``scipy.sparse.csr_matrix`` with the stored predictions.
    """
    # Use NpzFile as a context manager so the underlying zip file is closed.
    with np.load(pred_file) as output:
        return csr_matrix((output['data'], output['indices'], output['indptr']),
                          dtype=float, shape=tuple(output['shape']))

In [6]:
def read_XC_results(result_dir, is_clf=True):
    """Load the train and test prediction matrices of a model run.

    Args:
        result_dir: Directory holding ``{trn,tst}_predictions_{clf,knn}.npz``.
        is_clf: Read the ``clf`` predictions when True, the ``knn`` ones otherwise.

    Returns:
        ``(pred_trn, pred_tst)`` as returned by ``load_XC_predictions``.
    """
    if is_clf:
        model_type = 'clf'
    else:
        model_type = 'knn'
    trn_file = f"{result_dir}/trn_predictions_{model_type}.npz"
    tst_file = f"{result_dir}/tst_predictions_{model_type}.npz"
    return load_XC_predictions(trn_file), load_XC_predictions(tst_file)


# Load

Load the G-LF-WikiSeeAlsoTitles-300K dataset and the saved NGAME STransformer predictions for it.

In [121]:
# Target dataset and the saved predictions of the NGAME STransformer run on it.
data_dir = "/home/scai/phd/aiz218323/tmp/XC/data/G-LF-WikiSeeAlsoTitles-300K/"
result_dir = (
    "/home/scai/phd/aiz218323/tmp/XC/results/NGAME/STransformer/"
    "G-LF-WikiSeeAlsoTitles-300K/v_0_100/"
)

xy_data, xy_maps = read_XC_data(data_dir=data_dir)
xy_pred = read_XC_results(result_dir, is_clf=True)

In [135]:
# Ids with their last two characters removed (rm_suffix_idx=-2). Presumably
# this strips a suffix so they can be matched against meta ids later. TODO confirm.
xy_ids = read_ids(data_dir=data_dir, rm_suffix_idx=-2)

In [7]:
# Category metadata: (trn, tst, lbl) -> category matrices; the category texts
# are the last element of cat_maps.
cat_data, cat_maps = read_XC_data(data_dir=data_dir, meta_tag="category")



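# Metadata Augmentation

Append up to `k` randomly sampled category names to every train/test/label text, and write the result as `*_category.raw.txt`.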

In [95]:
# Append up to `k` randomly chosen category names to every train/test/label text.
k, meta_map = 5, np.array(cat_maps[-1])
# Seeded generator so the sampled categories (and the files written below) are reproducible.
aug_rng = np.random.default_rng(0)
# Category texts are sliced from character 9 on. Presumably this drops a fixed
# 9-character prefix such as "Category:". TODO confirm against category.raw.txt.
META_PREFIX_LEN = 9

aug_maps = []
for xy_map, cat in tqdm(zip(xy_maps, cat_data), total=len(xy_maps)):
    # zip() below would silently truncate if rows and texts were misaligned.
    assert cat.shape[0] == len(xy_map), (cat.shape, len(xy_map))
    aug_map = []
    for text, aug_row in tqdm(zip(xy_map, cat), total=len(xy_map)):
        aug_idxs = aug_rng.permutation(aug_row.indices)[:k]
        aug_text = [mt[META_PREFIX_LEN:] for mt in meta_map[aug_idxs]]
        # Leave texts without categories untouched (no dangling comma).
        aug_map.append(",".join([text, *aug_text]) if aug_text else text)
    aug_maps.append(aug_map)
    

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/696333 [00:00<?, ?it/s]

  0%|          | 0/260207 [00:00<?, ?it/s]

  0%|          | 0/235882 [00:00<?, ?it/s]

In [182]:
def write_maps(data_dir, ids, maps, meta_tag):
    """Write augmented train/test/label texts as ``id->text`` raw files.

    The i-th split (train, test, label, in that order) is written from
    ``ids[i]`` and ``maps[i]`` to ``{data_dir}/raw_data/{split}_{meta_tag}.raw.txt``.
    """
    save_files = (
        f"{data_dir}/raw_data/train_{meta_tag}.raw.txt",
        f"{data_dir}/raw_data/test_{meta_tag}.raw.txt",
        f"{data_dir}/raw_data/label_{meta_tag}.raw.txt",
    )
    for split_no, save_file in enumerate(save_files):
        write_map(ids[split_no], maps[split_no], save_file)


In [183]:
def write_map(ids, texts, save_file):
    """Write aligned ``ids`` and ``texts`` as ``id->text`` lines.

    Args:
        ids: Iterable of identifiers.
        texts: Iterable of texts, aligned with ``ids``.
        save_file: Output path; overwritten if it exists.

    Raises:
        ValueError: If ``ids`` and ``texts`` differ in length. ``zip`` would
            otherwise silently drop the unmatched tail. The check runs before
            the file is opened, so an existing file is left untouched.
    """
    ids, texts = list(ids), list(texts)
    if len(ids) != len(texts):
        raise ValueError(
            f"ids and texts length mismatch: {len(ids)} != {len(texts)} ({save_file})"
        )
    with open(save_file, 'w') as file:
        file.writelines(f"{item_id}->{text}\n" for item_id, text in zip(ids, texts))
            

In [123]:
# Persist the category-augmented texts as *_category.raw.txt.
write_maps(data_dir, ids=xy_ids, maps=aug_maps, meta_tag="category")

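# XC Augmentation

Append the top-`k` labels from the classifier (`clf`) predictions stored for G-LF-WikiTitles-1M to matching train/test texts, and write the result as `*_XCcategory.raw.txt`.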

In [140]:
# Augmentation source: G-LF-WikiTitles-1M data and its classifier (clf) predictions.
meta_data_dir = "/home/scai/phd/aiz218323/tmp/XC/data/G-LF-WikiTitles-1M/"
meta_result_dir = (
    "/home/scai/phd/aiz218323/tmp/XC/results/NGAME/STransformer/"
    "G-LF-WikiTitles-1M/v_0_100/"
)

meta_data, meta_maps = read_XC_data(data_dir=meta_data_dir)
meta_pred = read_XC_results(meta_result_dir, is_clf=True)

In [141]:
# Meta ids are kept whole (no rm_suffix_idx), unlike xy_ids.
meta_ids = read_ids(data_dir=meta_data_dir)

In [207]:
def XC_augmentation(aug_map, meta_info, meta_lbl, k=5):
    """Append the top-``k`` predicted meta-label texts to matching documents, in place.

    Documents are matched to prediction rows by id via ``np.intersect1d``, so
    a duplicated id is matched only at its first occurrence. For each match,
    the ``k`` highest-scoring predicted labels of the row are appended as
    ``",label1,label2,..."``. Rows without predictions leave the text unchanged.

    Args:
        aug_map: ``(aug_text, aug_ids)``. ``aug_text`` is a list of texts
            modified in place; ``aug_ids`` are the aligned ids.
        meta_info: ``(meta_pred, meta_ids)``. ``meta_pred`` is a sparse
            prediction matrix whose rows align with ``meta_ids``.
        meta_lbl: Label texts indexed by ``meta_pred`` column.
        k: Maximum number of labels appended per document.

    Returns:
        None; ``aug_text`` is updated in place.
    """
    aug_text, aug_ids = aug_map
    meta_pred, meta_ids = meta_info
    meta_lbl = np.array(meta_lbl)

    _, idxs, meta_idxs = np.intersect1d(aug_ids, meta_ids, return_indices=True)

    for idx, meta_idx in tqdm(zip(idxs, meta_idxs), total=len(idxs)):
        # Slice the sparse row once instead of twice per document.
        row = meta_pred[meta_idx]
        # Positions (within the row) of the k largest scores, best first.
        top = np.argsort(row.data)[::-1][:k]
        if len(top) == 0:
            # No predictions: avoid appending a dangling ",".
            continue
        aug_text[idx] = aug_text[idx] + "," + ",".join(meta_lbl[row.indices[top]])
        

In [208]:
# Augment train/test texts with the top-k clf predictions from both the
# G-LF-WikiTitles-1M train and test splits.
k = 5
trn_aug_text = xy_maps[0].copy()
tst_aug_text = xy_maps[1].copy()

trn_aug_map = (trn_aug_text, xy_ids[0])
tst_aug_map = (tst_aug_text, xy_ids[1])

trn_meta_info = (meta_pred[0], meta_ids[0])
tst_meta_info = (meta_pred[1], meta_ids[1])

meta_lbl = meta_maps[2]

# Pass `k` through (was a repeated literal 5, so editing `k` had no effect).
# NOTE(review): an id present in both meta splits is augmented twice.
# Confirm the splits are disjoint.
for aug_map in (trn_aug_map, tst_aug_map):
    for meta_info in (trn_meta_info, tst_meta_info):
        XC_augmentation(aug_map, meta_info, meta_lbl, k=k)

  0%|          | 0/472007 [00:00<?, ?it/s]

  0%|          | 0/206122 [00:00<?, ?it/s]

  0%|          | 0/173893 [00:00<?, ?it/s]

  0%|          | 0/75956 [00:00<?, ?it/s]

In [285]:
# Label texts are not augmented here; only train/test texts change.
aug_maps = [trn_aug_text, tst_aug_text, xy_maps[2]]
write_maps(data_dir, ids=xy_ids, maps=aug_maps, meta_tag="XCcategory")

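# Encoder Augmentation

Same as the XC augmentation, but using the `knn` (encoder) predictions stored for G-LF-WikiTitles-1M, and writing `*_KNNcategory.raw.txt`.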

In [315]:
# Augmentation source: G-LF-WikiTitles-1M again, this time with its kNN predictions.
# The data is re-read so this section does not depend on the XC section.
meta_data_dir = "/home/scai/phd/aiz218323/tmp/XC/data/G-LF-WikiTitles-1M/"
meta_result_dir = (
    "/home/scai/phd/aiz218323/tmp/XC/results/NGAME/STransformer/"
    "G-LF-WikiTitles-1M/v_0_100/"
)

meta_data, meta_maps = read_XC_data(data_dir=meta_data_dir)
meta_pred = read_XC_results(meta_result_dir, is_clf=False)

In [316]:
# Meta ids are kept whole (no rm_suffix_idx), unlike xy_ids.
meta_ids = read_ids(data_dir=meta_data_dir)

In [317]:
# Augment train/test texts with the top-k kNN predictions from both the
# G-LF-WikiTitles-1M train and test splits.
k = 5
trn_aug_text = xy_maps[0].copy()
tst_aug_text = xy_maps[1].copy()

trn_aug_map = (trn_aug_text, xy_ids[0])
tst_aug_map = (tst_aug_text, xy_ids[1])

trn_meta_info = (meta_pred[0], meta_ids[0])
tst_meta_info = (meta_pred[1], meta_ids[1])

meta_lbl = meta_maps[2]

# Pass `k` through (was a repeated literal 5, so editing `k` had no effect).
# NOTE(review): an id present in both meta splits is augmented twice.
# Confirm the splits are disjoint.
for aug_map in (trn_aug_map, tst_aug_map):
    for meta_info in (trn_meta_info, tst_meta_info):
        XC_augmentation(aug_map, meta_info, meta_lbl, k=k)

  0%|          | 0/472007 [00:00<?, ?it/s]

  0%|          | 0/206122 [00:00<?, ?it/s]

  0%|          | 0/173893 [00:00<?, ?it/s]

  0%|          | 0/75956 [00:00<?, ?it/s]

In [318]:
# Label texts are not augmented here; only train/test texts change.
aug_maps = [trn_aug_text, tst_aug_text, xy_maps[2]]
write_maps(data_dir, ids=xy_ids, maps=aug_maps, meta_tag="KNNcategory")