In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import os
import os.path as osp
import glob
from tqdm import tqdm
from ogb.lsc import MAG240MEvaluator

from root import ROOT

# Selects which labelled papers feed the author label statistics computed
# further down: 'train' uses the training split only, any other value
# (conventionally 'train_val') uses the training and validation splits together.
train_set = 'train' # 'train' or 'train_val'
# Dataset folder: 'mag240m_kddcup2021' under the project-level ROOT.
data_dir = osp.join(ROOT, 'mag240m_kddcup2021')

In [4]:
# Node counts per entity type, read from the dataset's metadata file.
__meta__ = torch.load(osp.join(data_dir, 'meta.pt'))
num_papers, num_authors, num_institutions = (
    __meta__[entity] for entity in ('paper', 'author', 'institution')
)

# Paper indices belonging to the train / validation / test splits.
__split__ = torch.load(osp.join(data_dir, 'split_dict.pt'))
train_idx, val_idx, test_idx = (
    __split__[part] for part in ('train', 'valid', 'test')
)
num_classes = 153

# Labels of all papers, converted to int16 after loading.
paper_labels = np.load(
    osp.join(data_dir, 'processed', 'paper', 'node_label.npy')
).astype(np.int16)

# Lookup table: global paper index -> row of that paper inside the validation
# split. Entries of papers outside the validation split are left at 0.
val_reverse_idx = np.zeros(num_papers, dtype=np.int32)
val_reverse_idx[val_idx] = np.arange(len(val_idx))
val_labels = paper_labels[val_idx]

In [5]:
def edge_index(dir, id1: str, id2: str,
               id3: str = None) -> np.ndarray:
    """Load the stored edge index of one relation of the dataset.

    Two call forms are accepted:

    * ``edge_index(dir, src, rel, dst)`` -- relation given explicitly.
    * ``edge_index(dir, src, dst)`` -- relation omitted; the default relation
      for the ``(src, dst)`` pair is looked up in a small built-in table.

    Args:
        dir: Dataset root folder, i.e. the one containing ``processed``.
        id1: Source node type.
        id2: Relation name (four-argument form) or destination node type
            (three-argument form).
        id3: Destination node type, or ``None`` for the three-argument form.

    Returns:
        The array stored at
        ``<dir>/processed/<src>___<rel>___<dst>/edge_index.npy``.

    Raises:
        ValueError: If the relation is omitted and no default relation is
            known for the ``(src, dst)`` pair.
    """
    # Default relation per (source, destination) pair. This is a free function,
    # so there is no `self.__rels__` to consult; the table mirrors the relation
    # names used by ogb.lsc.MAG240MDataset.
    default_rels = {
        ('author', 'paper'): 'writes',
        ('author', 'institution'): 'affiliated_with',
        ('paper', 'paper'): 'cites',
    }

    src = id1
    rel, dst = (None, id2) if id3 is None else (id2, id3)
    if rel is None:
        if (src, dst) not in default_rels:
            raise ValueError(
                f'No default relation known for ({src!r}, {dst!r}); '
                'pass the relation name explicitly.')
        rel = default_rels[(src, dst)]

    name = f'{src}___{rel}___{dst}'
    path = osp.join(dir, 'processed', name, 'edge_index.npy')
    return np.load(path)
    
# The 'writes' edge index is unpacked into its two rows: the author of each
# edge and the paper of each edge.
author_writes_authors, author_writes_papers = edge_index(data_dir, 'author', 'writes', 'paper')

# Reorder the edges by paper so that the authors of one paper are contiguous
# in `papers_written_authors`.
author_writes_papers_argsort = np.argsort(author_writes_papers)
papers_written_papers = author_writes_papers[author_writes_papers_argsort]
papers_written_authors = author_writes_authors[author_writes_papers_argsort]

# Papers that have at least one author, in ascending order.
papers = np.unique(papers_written_papers)

# Number of authors per paper, indexed by *paper id*. Later cells look up
# `author_counts[paper]` and `paper_row_start[paper]` with global paper ids,
# so the arrays must have one slot per paper -- including papers without any
# 'writes' edge (count 0). Counts taken from np.unique would be indexed by
# position among `papers` instead and silently misalign in that case.
author_counts = np.bincount(author_writes_papers, minlength=num_papers)

# paper_row_start[p] is the offset of paper p's first author inside
# `papers_written_authors`; the array has one trailing entry equal to the
# total number of edges.
paper_row_start = np.insert(np.cumsum(author_counts), 0, 0)

In [5]:
# Per-author label distribution, cached on disk under a name that encodes
# which splits contributed labels.
path = osp.join(data_dir, f'author_{train_set}_label_probs.npy')
if not osp.exists(path):
    # Papers whose labels may be used: training only, or training + validation.
    if train_set == 'train':
        paper_idx = train_idx
    else:
        paper_idx = np.concatenate([train_idx, val_idx], axis=0)

    # author_label_counts[a, c] = number of selected papers with label c that
    # author a wrote.
    author_label_counts = np.zeros((num_authors, num_classes), dtype=np.int16)
    for paper in tqdm(paper_idx):
        cur_paper_row_start = paper_row_start[paper]
        paper_authors = papers_written_authors[cur_paper_row_start:cur_paper_row_start+author_counts[paper]]
        author_label_counts[paper_authors, paper_labels[paper]] += 1

    # Turn counts into probabilities. The tiny constant keeps the row sum
    # non-zero, so authors without any counted paper get a uniform row instead
    # of a division by zero.
    author_label_probs = author_label_counts.astype(np.float32) + 1e-10
    author_label_probs = author_label_probs / author_label_probs.sum(axis=-1, keepdims=True)
    np.save(path, author_label_probs)
else:
    author_label_probs = np.load(path)

100%|██████████| 1251341/1251341 [00:09<00:00, 135651.84it/s]


In [None]:
def get_eval(pred_prob):
    """Evaluate class scores against the validation labels.

    The predicted class of each row is the argmax over the last axis of
    `pred_prob`; those predictions are handed to `MAG240MEvaluator` together
    with the module-level `val_labels`, and the evaluator's result is returned.
    """
    predicted_classes = np.argmax(pred_prob, axis=-1)
    evaluator_input = {'y_pred': predicted_classes, 'y_true': val_labels}
    return MAG240MEvaluator().eval(evaluator_input)

In [None]:
# Collect the saved validation predictions of every run matching
# <results_folder>/<model>/*/*/pred_probs.npy.
results_folder = osp.join('results_rgnn', 'valid', 'rgat')
model = '*'
pred_probs_pattern = osp.join(results_folder, model, '*', '*', 'pred_probs.npy')
# Sorted so the order of runs (and of the printed scores) is reproducible.
pred_probs_paths = sorted(glob.glob(pred_probs_pattern))
if not pred_probs_paths:
    raise FileNotFoundError(f'No prediction files match {pred_probs_pattern!r}')

pred_probs = []
for pred_probs_path in pred_probs_paths:
    pred_prob = np.load(pred_probs_path)
    # Exponentiate and renormalise each row. Subtracting the row maximum first
    # leaves the normalised result unchanged mathematically but keeps np.exp
    # from overflowing on large stored values.
    pred_prob = np.exp(pred_prob - pred_prob.max(axis=-1, keepdims=True))
    pred_prob = pred_prob / pred_prob.sum(axis=-1, keepdims=True)
    print(pred_probs_path, get_eval(pred_prob))
    pred_probs.append(np.expand_dims(pred_prob, axis=-1))
# One slice per run along a new trailing axis.
pred_probs = np.concatenate(pred_probs, axis=-1)

In [None]:
# Combine the runs stacked along the last axis of `pred_probs` in two ways:
# the average over runs and the element-wise maximum over runs.
mean_pred_prob = np.mean(pred_probs, axis=-1)
max_pred_prob = np.max(pred_probs, axis=-1)

# Report the validation score of each combination.
print("Mean logit aggregation: ", get_eval(mean_pred_prob))
print("Max logit aggregation: ", get_eval(max_pred_prob))

In [None]:
# Blend each validation paper's model prediction with the average label
# distribution of its authors.
# Share of the blend given to the author-based distribution; the model
# prediction receives the remainder.
author_prior_weight = 0.4

post_process_paper_idx = val_idx
pred_prob = max_pred_prob
new_pred_probs = np.zeros_like(pred_prob)
for paper in tqdm(post_process_paper_idx):
    # Row of this paper inside the validation-ordered prediction matrix.
    val_row = val_reverse_idx[paper]
    cur_paper_pred_probs = pred_prob[val_row]

    cur_paper_row_start = paper_row_start[paper]
    paper_authors = papers_written_authors[cur_paper_row_start:cur_paper_row_start+author_counts[paper]]
    if paper_authors.size == 0:
        # No authors: the mean over an empty selection would be NaN, so keep
        # the model prediction unchanged for this paper.
        new_pred_probs[val_row] = cur_paper_pred_probs
        continue

    cur_paper_author_label_probs = author_label_probs[paper_authors].mean(axis=0)
    new_pred_probs[val_row] = (
        author_prior_weight * cur_paper_author_label_probs
        + (1.0 - author_prior_weight) * cur_paper_pred_probs
    )
print(get_eval(new_pred_probs))

