In [1]:
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import xclib.data.data_utils as du
from scipy.sparse import csr_matrix
from scipy.sparse import coo_matrix
from IPython.display import display
from timeit import default_timer as timer
import xclib.evaluation.xc_metrics as xc_metrics

# Code

In [2]:
def read_id(filename, rm_suffix_idx=None):
    """Read the identifier part of every line of an ``id->text`` raw file.

    Each line is split once on ``'->'`` and the part before the separator is
    kept; a line without the separator is kept whole.

    Args:
        filename: Path of the raw text file to read.
        rm_suffix_idx: Optional end index applied to every identifier as
            ``identifier[:rm_suffix_idx]`` (for example ``-2`` drops the last
            two characters). ``None`` leaves the identifier unchanged.

    Returns:
        list[str]: One identifier per line, in file order.
    """
    with open(filename) as file:
        # Strip only the line terminator: slicing with ``line[:-1]`` also cut a
        # real character off a final line that had no trailing newline.
        return [
            line.rstrip('\n').split('->', maxsplit=1)[0][:rm_suffix_idx]
            for line in file
        ]

def read_ids(data_dir, rm_suffix_idx=None, meta_tag=""):
    """Read the train, test and label identifiers of a dataset.

    Identifiers are read with ``read_id`` from ``train.raw.txt``,
    ``test.raw.txt`` and ``label.raw.txt`` under ``<data_dir>/raw_data``.

    Args:
        data_dir: Dataset root directory.
        rm_suffix_idx: Forwarded to ``read_id`` for every file.
        meta_tag: When non-empty, ``<meta_tag>.raw.txt`` is read as well.

    Returns:
        tuple: ``(trn_id, tst_id, lbl_id)``, with ``meta_id`` appended as a
        fourth element when ``meta_tag`` is given.
    """
    trn_id = read_id(f"{data_dir}/raw_data/train.raw.txt", rm_suffix_idx)
    tst_id = read_id(f"{data_dir}/raw_data/test.raw.txt", rm_suffix_idx)
    lbl_id = read_id(f"{data_dir}/raw_data/label.raw.txt", rm_suffix_idx)
    if meta_tag:
        # Identifiers come from read_id. read_map returns the text part and
        # accepts no suffix argument, so calling it here raised a TypeError.
        meta_id = read_id(f"{data_dir}/raw_data/{meta_tag}.raw.txt", rm_suffix_idx)
        return trn_id, tst_id, lbl_id, meta_id
    return trn_id, tst_id, lbl_id


In [3]:
def read_map(filename):
    """Read the text part of every line of an ``id->text`` raw file.

    Each line is split once on ``'->'`` and everything after the first
    separator is kept, so later ``'->'`` occurrences stay in the text.

    Args:
        filename: Path of the raw text file to read.

    Returns:
        list[str]: One text per line, in file order.

    Raises:
        IndexError: If a line does not contain the ``'->'`` separator.
    """
    with open(filename) as file:
        # Strip only the line terminator: slicing with ``line[:-1]`` dropped the
        # last character of the text when the final line had no newline.
        return [line.rstrip('\n').split('->', maxsplit=1)[1] for line in file]


In [4]:
def read_XC_data(data_dir, meta_tag=""):
    """Load the X_Y sparse files and raw texts of a dataset.

    Args:
        data_dir: Dataset root directory.
        meta_tag: Optional tag. When non-empty, the sparse files are looked up
            with a ``<meta_tag>_`` prefix, and the ``lbl_X_Y`` file and
            ``<meta_tag>.raw.txt`` are loaded as well.

    Returns:
        tuple: ``(matrices, texts)``. Without ``meta_tag`` these are
        ``(trn_xy, tst_xy)`` and ``(trn_map, tst_map, lbl_map)``; with it,
        ``lbl_xy`` and ``meta_map`` are appended to the respective tuple.
    """
    # The sparse files carry a "<meta_tag>_" prefix only when a tag is given.
    prefix = meta_tag if not meta_tag else f"{meta_tag}_"

    matrices = [
        du.read_sparse_file(f"{data_dir}/{prefix}trn_X_Y.txt"),
        du.read_sparse_file(f"{data_dir}/{prefix}tst_X_Y.txt"),
    ]
    texts = [
        read_map(f"{data_dir}/raw_data/train.raw.txt"),
        read_map(f"{data_dir}/raw_data/test.raw.txt"),
        read_map(f"{data_dir}/raw_data/label.raw.txt"),
    ]

    if meta_tag:
        matrices.append(du.read_sparse_file(f"{data_dir}/{prefix}lbl_X_Y.txt"))
        texts.append(read_map(f"{data_dir}/raw_data/{meta_tag}.raw.txt"))

    return tuple(matrices), tuple(texts)


In [5]:
def read_filter_files(data_dir):
    """Load the train and test filter arrays of a dataset.

    Reads ``filter_labels_train.txt`` and ``filter_labels_test.txt`` from
    ``data_dir`` as int64 arrays.

    Args:
        data_dir: Dataset root directory.

    Returns:
        tuple: ``(filter_trn, filter_tst)``, each a 2-D int64 array with one
        row per line of the corresponding file.
    """
    # ndmin=2 keeps a single-row file 2-D. Without it loadtxt returns a 1-D
    # array, and the notebook's ``filtr[:, 0]`` / ``filtr[:, 1]`` indexing fails.
    filter_trn = np.loadtxt(f"{data_dir}/filter_labels_train.txt", dtype=np.int64, ndmin=2)
    filter_tst = np.loadtxt(f"{data_dir}/filter_labels_test.txt", dtype=np.int64, ndmin=2)

    return filter_trn, filter_tst


In [6]:
def load_XC_predictions(pred_file):
    """Load a prediction matrix stored as CSR components in an ``.npz`` file.

    The archive must contain the arrays ``data``, ``indices``, ``indptr`` and
    ``shape``.

    Args:
        pred_file: Path (or file object) accepted by ``np.load``.

    Returns:
        scipy.sparse.csr_matrix: The predictions as a float CSR matrix.
    """
    # The object returned by np.load keeps the archive open, so read every
    # member inside the context manager and let it close the file afterwards.
    with np.load(pred_file) as output:
        shape = tuple(int(dim) for dim in output['shape'])
        return csr_matrix((output['data'], output['indices'], output['indptr']),
                          dtype=float, shape=shape)


In [7]:
def read_XC_results(result_dir):
    """Load the train and test prediction matrices of a run.

    Args:
        result_dir: Directory containing ``trn_predictions_clf.npz`` and
            ``tst_predictions_clf.npz``.

    Returns:
        tuple: ``(pred_trn, pred_tst)`` as returned by ``load_XC_predictions``.
    """
    trn_file = f"{result_dir}/trn_predictions_clf.npz"
    tst_file = f"{result_dir}/tst_predictions_clf.npz"
    # Train is loaded before test, matching the order of the returned tuple.
    return load_XC_predictions(trn_file), load_XC_predictions(tst_file)


In [8]:
def compute_metrics(data, pred, trn_xy, a=0.55, b=1.5):
    """Score predictions with xclib's top-5 metrics and recall at 20/100/200.

    Args:
        data: Ground-truth matrix, passed to xclib as ``true_labels``.
        pred: Prediction matrix to evaluate.
        trn_xy: Training label matrix used to compute the inverse propensities.
        a: First parameter forwarded to ``compute_inv_propesity``.
        b: Second parameter forwarded to ``compute_inv_propesity``.

    Returns:
        tuple: ``(prec_metric, recall_metric)``. ``prec_metric`` is a DataFrame
        with rows ``p``, ``n``, ``psp``, ``psn`` and columns 1 to 5;
        ``recall_metric`` is a DataFrame with rows ``r``, ``psr`` and columns
        20, 100, 200.
    """
    inv_psp = xc_metrics.compute_inv_propesity(trn_xy, a, b)

    # Top-5 evaluation; the four result rows are labelled p, n, psp, psn.
    evaluator = xc_metrics.Metrics(true_labels=data, inv_psp=inv_psp)
    prec_metric = pd.DataFrame(
        evaluator.eval(pred, 5),
        index=['p', 'n', 'psp', 'psn'],
        columns=list(range(1, 6)),
    )

    # Recall values are computed with k=200 and sampled at the reported cut-offs.
    recall_at_k = xc_metrics.recall(X=pred, true_labels=data, k=200)
    ps_recall_at_k = xc_metrics.psrecall(X=pred, true_labels=data, inv_psp=inv_psp, k=200)
    cutoffs = np.array([20, 100, 200])
    picked = cutoffs - 1  # position k-1 is read for cut-off k
    recall_metric = pd.DataFrame(
        [recall_at_k[picked], ps_recall_at_k[picked]],
        columns=cutoffs,
        index=['r', 'psr'],
    )

    return prec_metric, recall_metric


In [9]:
def compute_filter_idx(trn_xy, trn_map, tst_map, lbl_map, include_self_leak=False):
    """Build (test row, label column) pairs to filter out of test predictions.

    Train and test points are matched to labels by equal entries in their maps.
    Two kinds of pairs are derived:

    * reciprocal pairs: a train point that is also label ``A`` is tagged in
      ``trn_xy`` with a label that is also test point ``B``; this yields the
      pair ``(B, A)``.
    * self pairs: a test point paired with the label that has the same entry.

    Args:
        trn_xy: Train-by-label matrix supporting row/column fancy indexing and
            ``nonzero()`` (e.g. a SciPy CSR matrix).
        trn_map: Sequence with one entry per train point.
        tst_map: Sequence with one entry per test point.
        lbl_map: Sequence with one entry per label.
        include_self_leak: When True, the self pairs are stacked on top of the
            reciprocal pairs. The default returns the reciprocal pairs only,
            which is what this function always returned before.

    Returns:
        numpy.ndarray: Integer array of shape ``(n_pairs, 2)`` holding
        ``(test row, label column)`` pairs.
    """
    _, trn_idx, trn_lbl_idx = np.intersect1d(trn_map, lbl_map, return_indices=True)
    _, tst_idx, tst_lbl_idx = np.intersect1d(tst_map, lbl_map, return_indices=True)

    # Rows: train points that are also labels. Columns: labels that are also
    # test points. A non-zero entry links the two.
    xy_leak = trn_xy[trn_idx][:, tst_lbl_idx]
    trn_filter_idx, lbl_filter_idx = xy_leak.nonzero()

    # Map positions in the sub-matrix back to test rows and label columns.
    filter_x = tst_idx[lbl_filter_idx]
    filter_y = trn_lbl_idx[trn_filter_idx]
    abba_leak = np.vstack([filter_x, filter_y]).T

    if not include_self_leak:
        return abba_leak

    # Previously sat behind an unreachable second return; now opt-in.
    self_leak = np.vstack([tst_idx, tst_lbl_idx]).T
    return np.vstack([self_leak, abba_leak])

#filter_idx = compute_filter_idx(xy_data[0], *xy_maps)

# Load

## G-LF-WikiSeeAlsoTitles-300K

In [10]:
# Dataset root; read_XC_data and read_filter_files build their file paths under it.
# NOTE(review): absolute, user-specific path — adjust before re-running elsewhere.
data_dir = "/home/scai/phd/aiz218323/tmp/XC/data/G-LF-WikiSeeAlsoTitles-300K/"

# xy_data: (train, test) matrices; xy_maps: (train, test, label) texts.
xy_data, xy_maps = read_XC_data(data_dir)
# (train, test) filter arrays read from the filter_labels_*.txt files.
filter_idx = read_filter_files(data_dir)

In [11]:
# Directory holding the saved predictions of the run being evaluated.
# NOTE(review): absolute, user-specific path — adjust before re-running elsewhere.
result_dir = "/home/scai/phd/aiz218323/tmp/XC/results/NGAME/STransformer/\
G-LF-WikiSeeAlsoTitles-300K/v_8_200/"

# xy_pred: (train, test) prediction matrices from the *_predictions_clf.npz files.
xy_pred = read_XC_results(result_dir)

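The code below removes the data leak caused by the meta-data: test points whose IDs also appear among the training IDs of G-LF-WikiTitles-1M are excluded from the test evaluation.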

In [None]:
# (train, test, label) IDs of the evaluated dataset, each with its last two
# characters dropped (presumably a dataset-specific suffix — TODO confirm).
xy_ids = read_ids(data_dir, rm_suffix_idx=-2)

# (train, test, label) IDs of the dataset the meta-data comes from, read unchanged.
meta_dir = "/home/scai/phd/aiz218323/tmp/XC/data/G-LF-WikiTitles-1M/"
meta_ids = read_ids(meta_dir)

# Test IDs that never occur among the meta dataset's training IDs ...
diff_ids = np.setdiff1d(xy_ids[1], meta_ids[0])
# ... and the row positions of the test points carrying those IDs.
valid_tst_idx = np.flatnonzero(np.isin(xy_ids[1], diff_ids))

In [None]:
# Evaluate the train split (i == 0) and the test split (i == 1) in turn.
for i, (data, pred, filtr) in enumerate(zip(xy_data, xy_pred, filter_idx)):
    # Zero the predictions at the filtered (row, column) pairs, then drop the
    # stored zeros. This edits the matrix held in xy_pred in place.
    pred[filtr[:, 0], filtr[:, 1]] = 0
    pred.eliminate_zeros()
    
    if i == 1:
        # Test split only: keep the rows selected in valid_tst_idx. Done after
        # filtering, so the filter's row indices still refer to the full test set.
        data = data[valid_tst_idx]
        pred = pred[valid_tst_idx]
        
    # xy_data[0] (the train matrix) supplies the propensities for both splits.
    metrics = compute_metrics(data, pred, xy_data[0])
    display(metrics[0])
    display(metrics[1])
    

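Main evaluation, continuing after the optional leak-removal cells above: score the test split only, after zeroing the filtered predictions.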

In [12]:
# Test split only: the [1:] slices drop the train entries, so the loop body runs once.
# NOTE(review): the enumerate index i is not used in this cell.
for i, (data, pred, filtr) in enumerate(zip(xy_data[1:], xy_pred[1:], filter_idx[1:])):
    # Zero the predictions at the filtered (row, column) pairs, then drop the
    # stored zeros. This edits the matrix held in xy_pred in place.
    pred[filtr[:, 0], filtr[:, 1]] = 0
    pred.eliminate_zeros()
    
    # xy_data[0] (the train matrix) supplies the propensities.
    metrics = compute_metrics(data, pred, xy_data[0])
    display(metrics[0])
    display(metrics[1])
    

  self._set_arrayXarray(i, j, x)


Unnamed: 0,1,2,3,4,5
p,0.287633,0.210304,0.167495,0.139556,0.120144
n,0.287633,0.285832,0.291117,0.296723,0.301761
psp,0.199464,0.208196,0.218579,0.228249,0.237585
psn,0.199464,0.211789,0.222705,0.230999,0.237504


Unnamed: 0,20,100,200
r,0.4152,0.513529,0.552019
psr,0.368585,0.473587,0.51468


## LF-WikiTitles-700K

In [14]:
# Switch to the LF-WikiTitles-700K data; this rebinds data_dir, xy_data and xy_maps.
# NOTE(review): absolute, user-specific paths — adjust before re-running elsewhere.
data_dir = "/home/scai/phd/aiz218323/tmp/XC/data/LF-WikiTitles-700K/"
xy_data, xy_maps = read_XC_data(data_dir)

# Predictions of the matching run; this rebinds result_dir and xy_pred.
result_dir = "/home/scai/phd/aiz218323/tmp/XC/results/NGAME/STransformer/\
LF-WikiTitles-700K/v_0_200/"

xy_pred = read_XC_results(result_dir)

# Test split only; unlike the cells above, no filter pairs are applied here.
for data, pred in zip(xy_data[1:], xy_pred[1:]):
    # xy_data[0] (the train matrix) supplies the propensities.
    metrics = compute_metrics(data, pred, xy_data[0])
    display(metrics[0])
    display(metrics[1])



Unnamed: 0,1,2,3,4,5
p,0.460758,0.371424,0.310665,0.265413,0.231579
n,0.460758,0.417337,0.400219,0.394063,0.392801
psp,0.291126,0.292423,0.292593,0.291471,0.291991
psn,0.291126,0.297976,0.306444,0.314951,0.323131


Unnamed: 0,20,100,200
r,0.457455,0.552771,0.590409
psr,0.452902,0.541441,0.576518
