In [1]:
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import xclib.data.data_utils as du
from scipy.sparse import csr_matrix
from scipy.sparse import coo_matrix
from IPython.display import display
from timeit import default_timer as timer
import xclib.evaluation.xc_metrics as xc_metrics

# Helpers: reading XC data, ids and predictions

In [15]:
def read_id(filename, rm_suffix_idx=None):
    """Read the id column of a ``<id>-><text>`` raw-text file.

    Each line is split at its first ``'->'`` and the part before it is kept
    (the whole line when there is no ``'->'``), then sliced with
    ``[:rm_suffix_idx]``.

    Args:
        filename: path of the raw-text file, read as UTF-8.
        rm_suffix_idx: end index of the slice applied to every id. ``None``
            keeps the full id; a negative value drops that many trailing
            characters.

    Returns:
        list[str]: one id per line, in file order.
    """
    ids = []
    with open(filename, encoding="utf-8") as file:
        for line in file:
            # rstrip instead of line[:-1]: the final line may have no trailing
            # newline, and blindly chopping would remove a real character.
            line = line.rstrip("\n")
            ids.append(line.split("->", maxsplit=1)[0][:rm_suffix_idx])
    return ids

def read_ids(data_dir, rm_suffix_idx=None, meta_tag=""):
    """Read the ids of the train, test and label raw-text files of a dataset.

    Args:
        data_dir: dataset root; files are read from ``<data_dir>/raw_data/``.
        rm_suffix_idx: forwarded to ``read_id`` (slice end applied to each id).
        meta_tag: when non-empty, also read ids from
            ``raw_data/<meta_tag>.raw.txt``.

    Returns:
        ``(trn_id, tst_id, lbl_id)``, or ``(trn_id, tst_id, lbl_id, meta_id)``
        when ``meta_tag`` is given; each is a list of id strings.
    """
    trn_id = read_id(f"{data_dir}/raw_data/train.raw.txt", rm_suffix_idx)
    tst_id = read_id(f"{data_dir}/raw_data/test.raw.txt", rm_suffix_idx)
    lbl_id = read_id(f"{data_dir}/raw_data/label.raw.txt", rm_suffix_idx)
    if meta_tag:
        # Metadata ids are read the same way as the other splits. read_map
        # takes no rm_suffix_idx argument and returns texts, not ids.
        meta_id = read_id(f"{data_dir}/raw_data/{meta_tag}.raw.txt", rm_suffix_idx)
        return trn_id, tst_id, lbl_id, meta_id
    return trn_id, tst_id, lbl_id


In [17]:
def read_map(filename):
    """Read the text column of a ``<id>-><text>`` raw-text file.

    Each line is split at its first ``'->'`` and everything after it is kept,
    so texts may themselves contain ``'->'``.

    Args:
        filename: path of the raw-text file, read as UTF-8.

    Returns:
        list[str]: one text per line, in file order, without the newline.

    Raises:
        IndexError: if a line contains no ``'->'`` separator.
    """
    with open(filename, encoding="utf-8") as file:
        # rstrip instead of line[:-1]: the final line may have no trailing
        # newline, and blindly chopping would drop the text's last character.
        return [line.rstrip("\n").split("->", maxsplit=1)[1] for line in file]


In [18]:
def read_XC_data(data_dir, meta_tag=""):
    """Load the sparse ground-truth matrices and raw texts of an XC dataset.

    Args:
        data_dir: dataset root containing ``[<meta_tag>_]trn_X_Y.txt``,
            ``[<meta_tag>_]tst_X_Y.txt`` and a ``raw_data/`` folder of
            ``<id>-><text>`` files.
        meta_tag: optional metadata name. When set, the ``<meta_tag>_``-prefixed
            matrices are read, plus ``<meta_tag>_lbl_X_Y.txt`` and
            ``raw_data/<meta_tag>.raw.txt``.

    Returns:
        Without a tag: ``((trn_xy, tst_xy), (trn_map, tst_map, lbl_map))``.
        With a tag: ``((trn_xy, tst_xy, lbl_xy),
        (trn_map, tst_map, lbl_map, meta_map))``. The ``*_xy`` values are
        whatever ``du.read_sparse_file`` returns; the ``*_map`` values are the
        per-line texts returned by ``read_map``.
    """
    # Use "" (not meta_tag itself) for a falsy tag, so meta_tag=None does not
    # end up as a literal "None" prefix in the file names.
    data_tag = f"{meta_tag}_" if meta_tag else ""
    trn_xy = du.read_sparse_file(f"{data_dir}/{data_tag}trn_X_Y.txt")
    tst_xy = du.read_sparse_file(f"{data_dir}/{data_tag}tst_X_Y.txt")

    trn_map = read_map(f"{data_dir}/raw_data/train.raw.txt")
    tst_map = read_map(f"{data_dir}/raw_data/test.raw.txt")
    lbl_map = read_map(f"{data_dir}/raw_data/label.raw.txt")

    if meta_tag:
        lbl_xy = du.read_sparse_file(f"{data_dir}/{data_tag}lbl_X_Y.txt")
        meta_map = read_map(f"{data_dir}/raw_data/{meta_tag}.raw.txt")
        return (trn_xy, tst_xy, lbl_xy), (trn_map, tst_map, lbl_map, meta_map)
    return (trn_xy, tst_xy), (trn_map, tst_map, lbl_map)


In [19]:
def load_XC_predictions(pred_file):
    """Load a float CSR prediction matrix stored as raw arrays in an ``.npz``.

    The archive must hold ``data``, ``indices``, ``indptr`` and ``shape``
    arrays. This is the set of arrays ``scipy.sparse.save_npz`` writes for a
    CSR matrix.

    Args:
        pred_file: path of the ``.npz`` archive.

    Returns:
        scipy.sparse.csr_matrix with ``dtype=float`` and the stored shape.
    """
    # NpzFile keeps the archive open and reads arrays lazily on indexing, so
    # build the matrix inside the with-block and let it close the file.
    with np.load(pred_file) as output:
        return csr_matrix(
            (output["data"], output["indices"], output["indptr"]),
            dtype=float,
            shape=tuple(output["shape"]),
        )

def read_XC_results(result_dir):
    """Load the train and test classifier predictions saved in ``result_dir``.

    Returns:
        tuple: ``(pred_trn, pred_tst)``, as loaded by ``load_XC_predictions``
        from ``trn_predictions_clf.npz`` and ``tst_predictions_clf.npz``.
    """
    prediction_files = (
        f"{result_dir}/trn_predictions_clf.npz",
        f"{result_dir}/tst_predictions_clf.npz",
    )
    return tuple(load_XC_predictions(path) for path in prediction_files)


# Load data and predictions

In [110]:
# Dataset whose raw texts get augmented. No trailing slash (same as save_dir
# below), so the f"{data_dir}/..." paths built from it contain no "//".
data_dir = "/home/scai/phd/aiz218323/tmp/XC/data/G-LF-WikiSeeAlsoTitles-300K"
# Sparse (trn, tst) matrices and raw (trn, tst, lbl) texts.
xy_data, xy_maps = read_XC_data(data_dir)
# Ids (part before '->') of the trn, tst and lbl raw files.
xy_ids = read_ids(data_dir)

In [111]:
# Dataset whose model predictions / label texts are used for augmentation.
dataset: str = "LF-WikiTitles-1M"

In [112]:
# (train, test) prediction matrices stored under predictions/<dataset>.
result_dir = f"{data_dir}/predictions/{dataset}"
xy_pred = read_XC_results(result_dir=result_dir)

# Label texts of `dataset`; create_augmented_map indexes this list with the
# prediction matrices' column indices.
pred_label_file = f"/home/scai/phd/aiz218323/tmp/XC/data/{dataset}/raw_data/label.raw.txt"
label_map = read_map(filename=pred_label_file)

In [113]:
def create_augmented_map(pred_info, xy_info, topk=5):
    """Append the texts of each row's top-k predicted labels to its raw text.

    Args:
        pred_info: ``(preds, label_map)``.
            - ``preds`` is a sequence of sparse score matrices, one per split.
              Rows are data points; columns are indices into ``label_map``.
            - ``label_map`` is the list of label texts.
        xy_info: ``(xy_ids, xy_maps)``, per-split lists of row ids and raw
            texts. They are zipped with ``preds``, so any extra split without
            a prediction matrix (e.g. a third, label split) is skipped.
        topk: number of highest-scoring predicted labels to append per row.
            ``None`` appends every stored prediction, best first.

    Returns:
        list[list[str]]: for each split, one ``"<id>-><text>,<lbl1>,<lbl2>,..."``
        string per row. A row with no predictions gives just ``"<id>-><text>"``.

    Raises:
        ValueError: if a prediction matrix and its id/text lists disagree on
            the number of rows (``zip`` would otherwise silently truncate).
    """
    preds, label_map = pred_info
    xy_ids, xy_maps = xy_info

    label_map = np.array(label_map)
    aug_maps = []
    for pred, xy_map, xy_id in zip(preds, xy_maps, xy_ids):
        pred = pred.tocsr()  # no-op for CSR input; the loop reads CSR arrays
        n_rows = pred.shape[0]
        if not (n_rows == len(xy_map) == len(xy_id)):
            raise ValueError(
                f"prediction matrix has {n_rows} rows but got "
                f"{len(xy_id)} ids and {len(xy_map)} texts"
            )
        indptr, indices, scores = pred.indptr, pred.indices, pred.data
        aug_map = []
        for r, (row_id, text) in enumerate(zip(xy_id, xy_map)):
            start, end = indptr[r], indptr[r + 1]
            # Rank this row's stored entries by score (descending, stable on
            # ties) instead of trusting CSR storage order, which need not
            # follow scores.
            top = np.argsort(-scores[start:end], kind="stable")[:topk]
            top_labels = label_map[indices[start:end][top]]
            # Joining text and labels together avoids a dangling "," when the
            # row has no predictions.
            aug_map.append(f"{row_id}->" + ",".join([str(text), *top_labels]))
        aug_maps.append(aug_map)
    return aug_maps


In [114]:
# create_augmented_map takes (predictions, label texts) and (ids, raw texts).
pred_info = (xy_pred, label_map)
xy_info = (xy_ids, xy_maps)

aug_maps = create_augmented_map(pred_info=pred_info, xy_info=xy_info, topk=5)

In [115]:
# Labels get nothing appended; rebuild their "<id>-><text>" lines unchanged.
label_aug_map = [
    f"{lbl_id}->{lbl_text}"
    for lbl_id, lbl_text in zip(xy_ids[2], xy_maps[2])
]

# Save augmented raw-text files

In [116]:
def write_map(texts, filename):
    """Write each item of ``texts`` on its own line of ``filename``.

    The file is overwritten and encoded as UTF-8 explicitly, so non-ASCII
    texts do not depend on the locale's default encoding. Items are
    formatted with an f-string and terminated with ``"\\n"``.

    Args:
        texts: iterable of items (typically ``"<id>-><text>"`` strings).
        filename: destination path.
    """
    with open(filename, "w", encoding="utf-8") as file:
        file.writelines(f"{text}\n" for text in texts)
            

In [117]:
# Dataset root whose raw_data/ folder receives the augmented files.
save_dir: str = "/home/scai/phd/aiz218323/tmp/XC/data/G-LF-WikiSeeAlsoTitles-300K"

In [118]:
# Augmented train/test texts plus the unaugmented label texts, each suffixed
# with the name of the prediction dataset.
write_map(texts=aug_maps[0], filename=f"{save_dir}/raw_data/train_{dataset}.raw.txt")
write_map(texts=aug_maps[1], filename=f"{save_dir}/raw_data/test_{dataset}.raw.txt")
write_map(texts=label_aug_map, filename=f"{save_dir}/raw_data/label_{dataset}.raw.txt")

In [146]:
# Input: label texts with comma-separated category fields.
file1: str = "/home/scai/phd/aiz218323/tmp/XC/data/G-LF-WikiSeeAlsoTitles-300K/raw_data/label_category.raw.txt"
# Output: the same lines with the first comma-separated field removed.
file2: str = "/home/scai/phd/aiz218323/tmp/XC/data/G-LF-WikiSeeAlsoTitles-300K/raw_data/label_category_metadata.raw.txt"

In [147]:
# Copy file1 to file2, dropping the first comma-separated field of each line's
# text (the part after '->'); a text with no comma is copied unchanged.
# NOTE(review): a first field that itself contains a comma would be cut short.
# Confirm the file format rules that out.
with open(file1) as f1, open(file2, 'w') as f2:
    for line in f1:
        idx, text = line.split('->', maxsplit=1)
        text_parts = text.split(",", maxsplit=1)
        # The split yields 1 or 2 parts: [-1] is the remainder after the first
        # comma, or the whole text (newline included) when there is no comma.
        f2.write(f'{idx}->{text_parts[-1]}')
        