In [12]:
import torch

In [13]:
import pickle

In [14]:
import numpy as np

### `Subsumption` class (defined here so the pickled instance loaded below can be unpickled without errors)

In [None]:
import logging
import sys
import os
import io

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import binarize
import scipy.sparse as sp

In [None]:
class Subsumption:
    """Build topic co-occurrence ("subsumption") matrices for a corpus.

    Pipeline: ``load_data()`` -> ``load_topics()`` -> ``make_counts()`` ->
    ``make_matrices()``; intermediate results can be pickled with ``dump()``.

    Attributes filled in by the methods below:
        counts: sparse (documents x topics) matrix produced by the vectorizer;
            replaced by its binarized version in ``make_matrices``.
        features: topic strings, one per column of ``counts``.
        ifeatures: topic string -> column index in ``counts``.
        overlaps: row-normalized topic co-occurrence matrix.
        lengths: per-topic document frequency (0 for topics in < 2 documents).
        weights: negated asymmetric subsumption weights.
    """

    # Fallback location used when ``topics`` is a topic *name* rather than an
    # existing path; "%s" is replaced by that name.
    TOPICS_PATH_TEMPLATE = '/calcul/datasets/nasa/topics-%s.txt'

    def __init__(self, data, topics) -> None:
        """Record input locations; nothing is read until the load_* methods run.

        Args:
            data: path to the corpus. A ``.txt`` file is streamed to the
                vectorizer line by line; any other file is unpickled.
            topics: path to a newline-separated topic file, or a topic name
                resolved through ``TOPICS_PATH_TEMPLATE``.
        """
        self.data_path = data
        self.topics_path = topics
        self.is_topic_path = True   # set to False when `topics` is a name, not a path
        self.topics_label = ""      # optional label dump() appends for path-based topics
        self.overlaps = None
        self.weights = None
        self.features = None
        self.ifeatures = None
        self.lengths = None

    def load_data(self):
        """Load the corpus into ``self.data``; exit the process if it is missing.

        ``.txt`` files are opened and left open so ``make_counts`` can stream
        them (it closes the handle); any other file is unpickled.
        """
        if not os.path.exists(self.data_path):
            logging.error("preprocessed data doesn't exist")
            sys.exit()
        logging.info('loading preprocessed data from %s', self.data_path)
        if self.data_path.endswith(".txt"):
            # Deliberately not closed here: consumed lazily by make_counts().
            self.data = open(self.data_path, "r")
        else:
            # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
            with open(self.data_path, 'rb') as fin:
                self.data = pickle.load(fin)

    def load_topics(self):
        """Read the topic list into ``self.topics`` (one topic per line).

        ``topics_path`` is tried as a file path first, then as a topic name via
        ``TOPICS_PATH_TEMPLATE``; the process exits if neither exists. Lines
        are stripped and blank lines dropped, so a trailing newline no longer
        adds an empty, never-matching entry to the vocabulary.
        """
        fname = self.topics_path
        if not os.path.exists(fname):
            self.is_topic_path = False
            fname = self.TOPICS_PATH_TEMPLATE % self.topics_path
            if not os.path.exists(fname):
                logging.error("not a filename or a valid topic name")
                sys.exit()
        logging.info('loading topics from %s', fname)
        with open(fname, 'r') as f_in:
            self.topics = [topic for topic in (line.strip() for line in f_in) if topic]
        logging.info('loaded %d topics', len(self.topics))

    @staticmethod
    def _feature_names(vectorizer):
        """Return a fitted vectorizer's feature names as a list of ``str``.

        scikit-learn 1.0 added ``get_feature_names_out`` and 1.2 removed
        ``get_feature_names``; support both.
        """
        if hasattr(vectorizer, "get_feature_names_out"):
            return [str(name) for name in vectorizer.get_feature_names_out()]
        return list(vectorizer.get_feature_names())

    def make_counts(self):
        """Count topic occurrences per document into ``self.counts``.

        Uses the topics as a fixed 1-3-gram vocabulary. Afterwards ``self.data``
        is closed (if it is a file) and deleted, and ``features`` /
        ``ifeatures`` map between column indices and topic strings.
        """
        logging.info("getting topics counts")
        # Tokens are \b-bounded runs of word characters and hyphens, so
        # hyphenated words stay single tokens.
        pattern = "(?u)\\b[\\w-]+\\b"

        # sorted() makes the column order independent of string-hash
        # randomization, so it is reproducible across interpreter runs.
        self.vectorizer = CountVectorizer(vocabulary=sorted(set(self.topics)),
                                          token_pattern=pattern, ngram_range=(1, 3))
        try:
            self.counts = self.vectorizer.transform(self.data)
        finally:
            # Close a streamed .txt corpus even if transform() fails.
            if isinstance(self.data, io.IOBase):
                self.data.close()
        del self.data
        self.features = self._feature_names(self.vectorizer)
        self.ifeatures = {name: idx for idx, name in enumerate(self.features)}

    def make_matrices(self):
        """Compute ``overlaps``, ``lengths`` and ``weights`` from ``counts``.

        With D_i the set of documents containing topic i:
          * ``overlaps[i, j] = |D_i & D_j| / |D_i|``, keeping only pairs that
            co-occur in at least 2 documents;
          * ``lengths[i] = |D_i|`` (0 when topic i is in fewer than 2 documents);
          * ``weights[i, j] = -overlaps[j, i]`` where
            ``overlaps[i, j] > overlaps[j, i]``, else 0.
        ``counts`` is replaced by its binarized version and kept, because later
        cells read ``subsumption.counts``.
        """
        logging.info("getting the overlap and weight matrices")
        self.counts = binarize(self.counts)
        # entry (i, j) = number of documents containing both topic i and topic j
        self.overlaps = self.counts.T.dot(self.counts)
        # drop pairs that co-occur in only one document
        self.overlaps.data *= self.overlaps.data > 1
        self.overlaps.eliminate_zeros()
        self.lengths = self.overlaps.diagonal()
        # row-normalize by document frequency; rows of zero-length topics stay 0
        diagonal = sp.diags([1. / x if x > 0 else 0 for x in self.lengths])
        self.overlaps = diagonal.dot(self.overlaps)

        # symmetric part: elementwise min of the matrix and its transpose
        self.weights = self.overlaps.minimum(self.overlaps.T)
        # indicator of the strictly dominant direction overlaps[i, j] > overlaps[j, i]
        dotp_sub = self.overlaps - self.weights
        dotp_sub.eliminate_zeros()
        dotp_sub.data[dotp_sub.data > 0] = 1
        # keep the symmetric value only where that direction dominates
        self.weights = self.weights.minimum(dotp_sub)
        self.weights.data *= -1

    def dump(self, obj, prefix, suffix):
        """Pickle ``obj`` to ``<prefix>/<data stem>[-<label>]<suffix>``.

        The data stem is the data file's basename up to its first '.'. The
        label is ``topics_label`` (only if non-empty) for path-based topics,
        and the topic name itself otherwise.
        """
        filename = prefix + "/" + \
            self.data_path.split("/")[-1].split(".")[0]
        if self.is_topic_path:
            if self.topics_label:
                filename += "-" + self.topics_label
        else:
            filename += "-" + self.topics_path
        filename += suffix
        with open(filename, "wb") as fout:
            pickle.dump(obj, fout)

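### TaxoGen input file generation

The cells below write the TaxoGen inputs (`keywords.txt`, `embeddings.txt`, `papers.txt`, `doc_ids.txt`, `keyword_cnt.txt`, `index.txt`) to the current working directory.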

In [20]:
# Keyword -> embedding-vector mapping (keys are the keyword strings used below).
EMBEDDINGS_FILE = '../data/embeddings/better_europa_embeddings.th'
embeddings = torch.load(EMBEDDINGS_FILE)

In [21]:
# Write the TaxoGen vocabulary (keywords.txt) and the matching embedding file
# (embeddings.txt: "<keyword> <v1> <v2> ..."). Spaces inside multi-word
# keywords become "_" so each keyword is a single whitespace-free token.
# NOTE(review): word2vec-style text files often start with a
# "<vocab_size> <dim>" header line -- confirm whether TaxoGen's loader expects one.
with open("keywords.txt", "w") as fout, open("embeddings.txt", "w") as fout2:
    for x in embeddings:
        keyword = "_".join(x.split())
        vector = embeddings[x]
        # Convert tensors/arrays to plain Python numbers first: str() on the
        # elements of a torch tensor gives "tensor(0.12)", not "0.12".
        values = vector.tolist() if hasattr(vector, "tolist") else vector
        fout.write(keyword + "\n")
        fout2.write(keyword + " " + " ".join(str(e) for e in values) + "\n")

In [22]:
# Load a pickled Subsumption instance. The Subsumption class must be defined
# (cell above) before unpickling, or pickle cannot resolve it.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
# The file name looks like a placeholder; point it at your own dump.
SUBSUMPTION_PICKLE = "your _subsumption.pickle"
with open(SUBSUMPTION_PICKLE, "rb") as fin:
    subsumption = pickle.load(fin)

In [None]:
with open("papers.txt", "w") as fout:
    # Row indices of every paper that mentions at least one topic.
    papers = np.unique(subsumption.counts.nonzero()[0])
    for paper in papers:
        # Column (topic) indices present in this paper's row.
        topics = subsumption.counts[paper, :].nonzero()[1]
        # One line per paper: its topics, inner spaces replaced by "_".
        print(" ".join("_".join(subsumption.features[t].split()) for t in topics), file=fout)

In [None]:
# doc_ids.txt: one id per paper listed in papers.txt. The id is simply the
# paper's position 0..N-1 in `papers`, i.e. its line number in papers.txt.
with open("doc_ids.txt", "w") as fout:
    papers = np.unique(subsumption.counts.nonzero()[0])
    for i in range(len(papers)):
        fout.write(str(i) + "\n")

In [None]:
# keyword_cnt.txt: one tab-separated line per paper,
#   "<doc id>\t<keyword>\t1\t<keyword>\t1...",
# where <doc id> is the paper's position in `papers` (as in doc_ids.txt) and
# every keyword is written with a constant count of 1 (presence only).
with open("keyword_cnt.txt", "w") as fout:
    papers = np.unique(subsumption.counts.nonzero()[0])
    for i, paper in enumerate(papers):
        topics = subsumption.counts[paper, :].nonzero()[1]
        fields = [str(i)]
        for topic in topics:
            fields += ["_".join(subsumption.features[topic].split()), "1"]
        fout.write("\t".join(fields) + "\n")

In [None]:
# Map each paper's row index in `subsumption.counts` to its sequential doc id
# (its position among papers with at least one topic, as in doc_ids.txt).
paper_index = np.unique(subsumption.counts.nonzero()[0])
paper_dict = {paper: doc_id for doc_id, paper in enumerate(paper_index)}

In [None]:
# index.txt: one line per topic present in the corpus,
#   "<keyword>\t<doc id>,<doc id>,...",
# listing (via paper_dict) the doc ids of the papers that contain the topic.
with open("index.txt", "w") as fout:
    # Column slicing a row-oriented (presumably CSR) matrix rescans it for every
    # topic; convert to CSC once so each per-topic column lookup is cheap.
    counts_by_topic = sp.csc_matrix(subsumption.counts)
    topics = np.unique(counts_by_topic.nonzero()[1])
    for topic in topics:
        line = "_".join(subsumption.features[topic].split())
        papers = counts_by_topic[:, topic].nonzero()[0]
        line += "\t" + ",".join([str(paper_dict[paper]) for paper in papers])
        fout.write(line + "\n")