In [1]:
import json
import re
from tqdm import tqdm
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy import sparse

In [2]:
def get_metadata():
    with open('data/arxiv-metadata-oai-snapshot.json', 'r') as f:
        for line in f:
            yield line
            
stopwords=["", "new","non","using","a","about","above","after","again","against","all","am","an","and","any","are","aren't","as","at","be","because","been","before","being","below","between","both","but","by","can't","cannot","could","couldn't","did","didn't","do","does","doesn't","doing","don't","down","during","each","few","for","from","further","had","hadn't","has","hasn't","have","haven't","having","he","he'd","he'll","he's","her","here","here's","hers","herself","him","himself","his","how","how's","i","i'd","i'll","i'm","i've","if","in","into","is","isn't","it","it's","its","itself","let's","me","more","most","mustn't","my","myself","no","nor","not","of","off","on","once","only","or","other","ought","our","ours	ourselves","out","over","own","same","shan't","she","she'd","she'll","she's","should","shouldn't","so","some","such","than","that","that's","the","their","theirs","them","themselves","then","there","there's","these","they","they'd","they'll","they're","they've","this","those","through","to","too","under","until","up","very","was","wasn't","we","we'd","we'll","we're","we've","were","weren't","what","what's","when","when's","where","where's","which","while","who","who's","whom","why","why's","with","won't","would","wouldn't","you","you'd","you'll","you're","you've","your","yours","yourself","yourselves"]

def pre_process_abstract(s):
    s = re.split('\W+', s)
    s = [word.lower() for word in s if word.lower() not in stopwords]
    return s

def doc_list_to_sparse(words_list):
    word_set = set([])
    for words in words_list:
        word_set.update(words)
    word_dict = {}
    for i, word in enumerate(word_set):
        word_dict[word] = i
        
    indices = []
    for i in tqdm(range(len(words_list))):
        words = words_list[i]
        for word in words:
            indices.append([i, word_dict[word]])
    
    indices = torch.from_numpy(np.asarray(indices)).long()
    values = torch.ones(len(indices))
    size = torch.Size([len(words_list), len(word_set)])
    print(indices.size())
    print(values.size())
    print(size)
    return torch.sparse.FloatTensor(indices.t(), values, size), word_dict

def doc_list_to_tf_idf(words_list, idf_words_idx):
    idf_word_counter = Counter()
    for words in [words_list[i] for i in idf_words_idx]:
        idf_word_counter.update(list(set(words)))
    
    word_set = set(list(idf_word_counter.keys()))
    
    word_dict = {}
    idf = {}
    for i, word in enumerate(word_set):
        word_dict[word] = i
        idf[word] = np.log(np.asarray([float(len(idf_words_idx)) / idf_word_counter[word]]))
        
    indices = []
    values = []
    for i in tqdm(range(len(words_list))):
        words = words_list[i]
        num_word = 0
        for word in words:
            if word not in word_set:
                continue
            num_word += 1
        for word in words:
            if word not in word_set:
                continue
            indices.append([i, word_dict[word]])
            values.append(1.0 / num_word * idf[word])
    
    indices = torch.from_numpy(np.asarray(indices)).long()
    values = torch.from_numpy(np.asarray(values)).float().squeeze()
    size = torch.Size([len(words_list), len(word_set)])
    return torch.sparse.FloatTensor(indices.t(), values, size), word_dict

In [3]:
metadata = get_metadata()
total_num = 0
for paper in metadata:
    total_num += 1
print(total_num)

1796911


In [4]:
#load abstracts, categories, created_dates
abstracts = []
categories = []
created_dates = []

metadata = get_metadata()
for paper in tqdm(metadata, total=total_num):
    parsed = json.loads(paper)
    abstract = pre_process_abstract(parsed["abstract"])
    abstracts.append(abstract)
    categories.append(parsed["categories"])
    created_dates.append(parsed["versions"][0]['created'])

  4%|▍         | 74734/1796911 [00:25<09:38, 2978.41it/s]


KeyboardInterrupt: 

In [8]:
checkpoint = {}
checkpoint["abstracts"] = abstracts
checkpoint["categories"] = categories
checkpoint["created_dates"] = created_dates
torch.save(checkpoint, "data/arxiv_all.pt")

In [5]:
checkpoint = torch.load("data/arxiv_all.pt")
abstracts = checkpoint["abstracts"]
categories = checkpoint["categories"]
created_dates = checkpoint["created_dates"]
del checkpoint

In [6]:
abstract_word_counter = Counter()
for abstract in tqdm(abstracts, total=len(abstracts)):
    abstract_word_counter.update(abstract)

100%|██████████| 1796911/1796911 [00:26<00:00, 68427.06it/s] 


In [7]:
final_abstracts = []
WordcountCutoff= 30
abstract_word_counter_final = Counter()
for abstract in tqdm(abstracts, total=len(abstracts)):
    final_abstract = [word for word in abstract if abstract_word_counter[word] > WordcountCutoff]
    final_abstracts.append(final_abstract)
    abstract_word_counter_final.update(final_abstract)

100%|██████████| 1796911/1796911 [01:21<00:00, 21918.22it/s]


In [8]:
final_categories = []
category_counter = Counter()
for category in tqdm(categories):
    category_counter.update(category.split(" "))
    final_categories.append(category.split(" "))

100%|██████████| 1796911/1796911 [00:21<00:00, 82026.20it/s] 


In [9]:
#cs
category_list = list(category_counter.keys())
select_category_list = [category for category in category_list if (category.startswith("cs"))]

In [10]:
filtered_abstracts = []
filtered_categories = []
filtered_full_categories = []
filtered_date = []
for (abstract, category, date) in tqdm(zip(final_abstracts, final_categories, created_dates)):
    if len(abstract) <= 20:
        continue
    if category[0] not in select_category_list:
        continue
    filtered_abstracts.append(abstract)
    filtered_categories.append(category[0])
    filtered_full_categories.append(category)
    filtered_date.append(date)

1796911it [00:01, 991221.33it/s] 


In [11]:
months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
months_to_val = {}
for i, month in enumerate(months):
    months_to_val[month] = i

date_val = []
for date in tqdm(filtered_date):
    d_m_y_h_m_s = date.split(" ")[1:4] + date.split(" ")[4].split(":")
    date_val.append(int(d_m_y_h_m_s[5]) + int(d_m_y_h_m_s[4]) * 60 + int(d_m_y_h_m_s[3]) * 60 * 60 + int(d_m_y_h_m_s[0]) * 30 * 60 * 60 + months_to_val[d_m_y_h_m_s[1]] * 40 * 30 * 60 * 60 + int(d_m_y_h_m_s[2]) * 20 * 40 * 30 * 60 * 60)
    

100%|██████████| 257062/257062 [00:00<00:00, 328612.55it/s]


In [12]:
rank = np.argsort(np.asarray(date_val))
sorted_abstracts = [filtered_abstracts[i] for i in rank]
sorted_categories = [filtered_categories[i] for i in rank]
sorted_full_categories = [filtered_full_categories[i] for i in rank]
sorted_date = [filtered_date[i] for i in rank]
sorted_date_val = np.asarray(date_val)[rank]

In [13]:
abstracts_bow, word_dict = doc_list_to_sparse(sorted_abstracts)

100%|██████████| 257062/257062 [01:29<00:00, 2876.62it/s] 


torch.Size([26064580, 2])
torch.Size([26064580])
torch.Size([257062, 44393])


In [14]:
#label
categories_list = list(set(filtered_categories))
categories_dict = {}
for i, category in enumerate(categories_list):
     categories_dict[category] = i
categories_label = [categories_dict[category] for category in sorted_categories]

In [15]:
remain_classes = []
num_classes = 40
for i in range(num_classes):
    if (np.asarray(categories_label) == i).sum() < 3000:
#         print(list(categories_dict.keys())[i])
        continue
    else:
        remain_classes.append(i)
remain_data_id = np.asarray([label in remain_classes for label in categories_label]).astype(np.bool)
categories_label_np = np.asarray(categories_label)
for i, class_id in enumerate(remain_classes):
    categories_label_np[categories_label_np == class_id] = i

In [16]:
# alpha, beta = int(1 / 3 * len(categories_label_np[remain_data_id])), int(1.0 / 2 * len(categories_label_np[remain_data_id]))
alpha, beta = int(1 / 2 * len(categories_label_np[remain_data_id])), int(3.0 / 4 * len(categories_label_np[remain_data_id]))
np.random.seed(0)
train_val_split = np.arange(len(categories_label_np)).astype(np.int)[remain_data_id][:beta].tolist()
abstracts_tfidf, word_dict = doc_list_to_tf_idf(sorted_abstracts, train_val_split)

100%|██████████| 257062/257062 [02:14<00:00, 1909.99it/s]


In [17]:
raw_checkpoint = {}
raw_checkpoint["sorted_abstracts"] = sorted_abstracts
raw_checkpoint["sorted_categories"] = sorted_categories
raw_checkpoint["sorted_date"] = sorted_date
raw_checkpoint["sorted_date_val"] = sorted_date_val
raw_checkpoint["word_dict"] = word_dict
raw_checkpoint["categories_dict"] = categories_dict
raw_checkpoint["abstracts_bow"] = abstracts_bow
raw_checkpoint["categories_label"] = categories_label
raw_checkpoint["abstracts_tfidf"] = abstracts_tfidf
raw_checkpoint["sorted_full_categories"] = sorted_full_categories
torch.save(raw_checkpoint, "data/arxiv_before_split.pt")

In [18]:
raw_checkpoint = torch.load("data/arxiv_before_split.pt")
sorted_abstracts = raw_checkpoint["sorted_abstracts"]
abstracts_bow = raw_checkpoint["abstracts_bow"]
categories_label = np.asarray(raw_checkpoint["categories_label"])
categories_dict = raw_checkpoint["categories_dict"]
abstracts_tfidf = raw_checkpoint["abstracts_tfidf"]
del raw_checkpoint

In [19]:
remain_classes = []
num_classes = 40
for i in range(num_classes):
    if (np.asarray(categories_label) == i).sum() < 3000:
        continue
    else:
        remain_classes.append(i)
        print(list(categories_dict.keys())[i])
remain_data_id = np.asarray([label in remain_classes for label in categories_label]).astype(np.bool)
print(remain_data_id.shape)

cs.DB
cs.AI
cs.RO
cs.DM
cs.SI
cs.DS
cs.CR
cs.HC
cs.CL
cs.LO
cs.CC
cs.GT
cs.CV
cs.PL
cs.SY
cs.DC
cs.NI
cs.IT
cs.CY
cs.LG
cs.NE
cs.IR
cs.SE
(257062,)


In [20]:
categories_label_np = np.asarray(categories_label)
for i, class_id in enumerate(remain_classes):
    categories_label_np[categories_label_np == class_id] = i

In [21]:
abstracts_tfidf_dict = {"indices":abstracts_tfidf._indices().numpy(), "values": abstracts_tfidf._values().numpy(), "size": list(abstracts_tfidf.size())}
values, indices, size = abstracts_tfidf_dict["values"], abstracts_tfidf_dict["indices"], abstracts_tfidf_dict["size"]
np_sparse_tfidf_dict = sparse.coo_matrix((values, (indices[0], indices[1])), size).tocsr()

In [22]:
checkpoint = {}

alpha, beta = int(1 / 2 * len(categories_label_np[remain_data_id])), int(3.0 / 4 * len(categories_label_np[remain_data_id]))

np.random.seed(0)
train_val_split = np.arange(len(categories_label_np)).astype(np.int)[remain_data_id][:beta]#[np.random.permutation(beta)]
test_split = np.arange(len(categories_label_np)).astype(np.int)[remain_data_id][beta:]

checkpoint["train_x"] = np_sparse_tfidf_dict[train_val_split[:alpha]]
checkpoint["train_y"] = np.asarray(categories_label_np)[train_val_split[:alpha]]
checkpoint["val_x"] = np_sparse_tfidf_dict[train_val_split[alpha:]]
checkpoint["val_y"] = np.asarray(categories_label_np)[train_val_split[alpha:]]
checkpoint["test_x"] = np_sparse_tfidf_dict[test_split]
checkpoint["test_y"] = np.asarray(categories_label_np)[test_split]
checkpoint["num_classes"] = len(remain_classes)
checkpoint["remain_classes"] = remain_classes
torch.save(checkpoint, "data/arxiv.pt")

In [23]:
print(checkpoint["train_x"].shape)

(116874, 41942)


In [29]:
print(116874*2)

233748
