Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] make tmsc work with git-based modelforge and sourced.ml #12

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
sourced-ml==0.5.1
ast2vec>=0.3.8-alpha
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
"console_scripts": ["tmsc=tmsc.__main__:main"],
},
keywords=["machine learning on source code", "topic modeling",
"github", "bblfsh", "babelfish", "ast2vec"],
install_requires=["ast2vec>=0.3.8-alpha"],
"github", "bblfsh", "babelfish"],
install_requires=["sourced-ml>=0.5.1", "ast2vec>=0.3.8-alpha"],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ast2vec is dead

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know, was kept just as a hack to satisfy internal imports 😕

package_data={"": ["LICENSE.md", "README.md"]},
classifiers=[
"Development Status :: 3 - Alpha",
Expand Down
30 changes: 19 additions & 11 deletions tmsc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
import json
import logging
import sys
import os

from ast2vec import Topics, DocumentFrequencies, DEFAULT_BBLFSH_TIMEOUT
from ast2vec.bow import BOWBase
from sourced.ml.models import BOW, Topics, DocumentFrequencies
from modelforge.backends import create_backend
from modelforge.index import GitIndex

from tmsc.environment import initialize
from tmsc.topic_detector import TopicDetector

DEFAULT_BBLFSH_TIMEOUT = 20

def main():
parser = argparse.ArgumentParser()
Expand All @@ -33,6 +35,10 @@ def main():
parser.add_argument("--prune-df", default=20, type=int,
help="Minimum number of times an identifer must occur in different "
"documents to be taken into account.")
parser.add_argument("--index_repo", default="https://github.com/src-d/models",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two are normally hardcoded in modelforgecfg.py which exists in sourced-ml. Users should be abstracted from these details.

I think that

args.topics = Topics(log_level=args.log_level).load(source=args.topics)

will work - the backend will be created automatically. If it doesn't then there is a bug somewhere.

Copy link
Author

@bzz bzz Aug 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like a 🐛 confirmed

import logging
from sourced.ml.models import Topics
topics = Topics(log_level=logging.INFO).load(source=None)

works as expected only if there is warm cache already in ~/.source{d}/topics

In case if the cache is empty, the very same code results in

Traceback (most recent call last):
  File "./test_tm.py", line 6, in <module>
    topics = Topics(log_level=logging.INFO).load(source=None)
  File "~/src-d/tmsc/.venv3/lib/python3.6/site-packages/modelforge/model.py", line 82, in load
    raise ValueError("The backend must be set to load a UUID or the default "
ValueError: The backend must be set to load a UUID or the default model.

help="Models index repository.")
parser.add_argument("--index_cache", default=os.path.join(BOW.cache_dir(), "models"),
help="Local cache of models index repository")
parser.add_argument("-n", "--nnn", default=10, type=int,
help="Number of topics to print.")
parser.add_argument("-f", "--format", default="human", choices=["json", "human"],
Expand All @@ -42,24 +48,26 @@ def main():
if args.linguist is None:
args.linguist = "./enry"
initialize(args.log_level, enry=args.linguist)

if args.gcs:
backend = create_backend(args="bucket=" + args.gcs)
else:
backend = create_backend()
if args.topics is not None:
args.topics = Topics(log_level=args.log_level).load(source=args.topics, backend=backend)
if args.df is not None:
args.df = DocumentFrequencies(log_level=args.log_level).load(
source=args.df, backend=backend)
if args.bow is not None:
args.bow = BOWBase(log_level=args.log_level).load(source=args.bow, backend=backend)
git_index = GitIndex(index_repo=args.index_repo, cache=args.index_cache, log_level=args.log_level)
backend = create_backend(git_index=git_index)

args.topics = Topics(log_level=args.log_level).load(source=args.topics, backend=backend) #source=args.topics
args.df = DocumentFrequencies(log_level=args.log_level).load(source=args.df, backend=backend)
#args.bow = BOW(log_level=args.log_level).load(source=args.bow, backend=backend)

sr = TopicDetector(
topics=args.topics, docfreq=args.df, bow=args.bow, verbosity=args.log_level,
prune_df_threshold=args.prune_df, gcs_bucket=args.gcs, repo2bow_kwargs={
prune_df_threshold=args.prune_df, repo2bow_kwargs={
"linguist": args.linguist,
"bblfsh_endpoint": args.bblfsh,
"timeout": args.timeout})

topics = sr.query(args.input, size=args.nnn)

if args.format == "json":
json.dump({"repository": args.input, "topics": topics}, sys.stdout)
elif args.format == "human":
Expand Down
5 changes: 2 additions & 3 deletions tmsc/environment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging

from modelforge.logs import setup_logging
from ast2vec import ensure_bblfsh_is_running_noexc, install_enry


__initialized__ = False
Expand All @@ -22,6 +21,6 @@ def initialize(log_level=logging.INFO, enry="./enry"):
if __initialized__:
return
setup_logging(log_level)
ensure_bblfsh_is_running_noexc()
install_enry(target=enry, warn_exists=False)
# ensure_bblfsh_is_running_noexc()
# install_enry(target=enry, warn_exists=False)
__initialized__ = True
95 changes: 43 additions & 52 deletions tmsc/topic_detector.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
import logging
import sys
import re

from ast2vec import Topics, Repo2Base, DocumentFrequencies
from ast2vec.bow import BOWBase
from ast2vec.model2.uast2bow import Uasts2BOW
from modelforge.backends import create_backend
# compatibility with old ast2vec version that depends on old modelforge
sys.modules["modelforge.generate_meta"] = None
sys.modules["modelforge.model.write_model"] = None

import numpy
from scipy.sparse import csr_matrix
from ast2vec.repo2.base import Repo2Base
from sourced.ml.models import BOW, Topics, DocumentFrequencies

from tmsc.environment import initialize

from tmsc.uast2bow import Uasts2BOW

class Repo2BOW(Repo2Base):
"""
Implements the step repository -> :class:`ast2vec.nbow.NBOW`.
"""
MODEL_CLASS = BOWBase
MODEL_CLASS = BOW

def __init__(self, vocabulary, docfreq, **kwargs):
super().__init__(**kwargs)
Expand All @@ -31,84 +34,72 @@ class TopicDetector:
r"(https://|ssh://git@|git://)(github.com/[^/]+/[^/]+)(|.git|/)")

def __init__(self, topics=None, docfreq=None, bow=None, verbosity=logging.DEBUG,
prune_df_threshold=1, gcs_bucket=None, initialize_environment=True,
repo2bow_kwargs=None):
if initialize_environment:
initialize()
prune_df_threshold=1, repo2bow_kwargs=None):

self._log = logging.getLogger("topic_detector")
self._log.setLevel(verbosity)
if gcs_bucket:
backend = create_backend(args="bucket=" + gcs_bucket)
else:
backend = create_backend()
if topics is None:
self._topics = Topics(log_level=verbosity).load(backend=backend)
else:
assert isinstance(topics, Topics)
self._topics = topics

if not topics:
raise ValueError("Please provide a Topic model")
assert isinstance(topics, Topics)
self._topics = topics
self._log.info("Loaded topics model: %s", self._topics)
if docfreq is None:
if docfreq is not False:
self._docfreq = DocumentFrequencies(log_level=verbosity).load(
source=self._topics.dep("docfreq")["uuid"], backend=backend)
else:
self._docfreq = None
self._log.warning("Disabled document frequencies - you will "
"not be able to query custom repositories.")

if not docfreq:
self._docfreq = None
self._log.warning("Disabled document frequencies - you will "
"not be able to query arbitrary repositories.")
self._repo2bow = None
else:
assert isinstance(docfreq, DocumentFrequencies)
self._docfreq = docfreq
if self._docfreq is not None:
self._docfreq = self._docfreq.prune(prune_df_threshold)
self._log.info("Loaded docfreq model: %s", self._docfreq)
if bow is not None:
assert isinstance(bow, BOWBase)
self._log.info("Loaded docfreq model: %s", self._docfreq)
self._repo2bow = Repo2BOW(
{t: i for i, t in enumerate(self._topics.tokens)}, self._docfreq,
**(repo2bow_kwargs or {}))

if not bow:
self._bow = None
self._log.warning("No BOW cache was loaded.")
else:
assert isinstance(bow, BOW)
self._bow = bow
if self._topics.matrix.shape[1] != self._bow.matrix.shape[1]:
raise ValueError("Models do not match: topics has %s tokens while bow has %s" %
(self._topics.matrix.shape[1], self._bow.matrix.shape[1]))
self._log.info("Attached BOW model: %s", self._bow)
else:
self._bow = None
self._log.warning("No BOW cache was loaded.")
if self._docfreq is not None:
self._repo2bow = Repo2BOW(
{t: i for i, t in enumerate(self._topics.tokens)}, self._docfreq,
**(repo2bow_kwargs or {}))
else:
self._repo2bow = None

def query(self, url_or_path_or_name, size=5):
if size > len(self._topics):
raise ValueError("size may not be greater than the number of topics - %d" %
len(self._topics))
if self._bow is not None:
token_vector = None
if self._bow:
try:
repo_index = self._bow.repository_index_by_name(
repo_index = self._bow.documents_index(
url_or_path_or_name)
except KeyError:
repo_index = -1
if repo_index == -1:
match = self.GITHUB_URL_RE.match(url_or_path_or_name)
if match is not None:
if match:
name = match.group(2)
try:
repo_index = self._bow.repository_index_by_name(name)
repo_index = self._bow.documents_index(name)
except KeyError:
pass
else:
repo_index = -1
if repo_index >= 0:
token_vector = self._bow.matrix[repo_index]
else:
if self._docfreq is None:
if repo_index:
token_vector = self._bow.matrix[repo_index]

if not token_vector:
if not self._docfreq:
raise ValueError("You need to specify document frequencies model to process "
"custom repositories")
bow_dict = self._repo2bow.convert_repository(url_or_path_or_name)
token_vector = numpy.zeros(self._topics.matrix.shape[1], dtype=numpy.float32)
for i, v in bow_dict.items():
token_vector[i] = v
token_vector = csr_matrix(token_vector)

topic_vector = -numpy.squeeze(self._topics.matrix.dot(token_vector.T).toarray())
order = numpy.argsort(topic_vector)
result = []
Expand Down
55 changes: 55 additions & 0 deletions tmsc/uast2bow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from collections import defaultdict
import marshal
import math
import types

from sourced.ml.models import BOW, DocumentFrequencies
from sourced.ml.algorithms.uast_ids_to_bag import UastIds2Bag

class Uasts2BOW:
    """Converts a stream of file UASTs into a TF-IDF weighted bag-of-words.

    Token frequencies are accumulated over all files, then weighted as
    log(1 + TF) * log(IDF) using the supplied document frequencies model.
    Tokens missing from the docfreq model are dropped from the result.
    """

    def __init__(self, vocabulary: dict, docfreq: "DocumentFrequencies",
                 getter: callable):
        """
        :param vocabulary: Mapping from token to its column index.
        :param docfreq: Document frequencies model used for the IDF weights.
        :param getter: Callable which extracts the UAST from each item \
            yielded by the generator passed to :meth:`__call__`.
        """
        self._docfreq = docfreq
        # TODO(review): replace with the sourced.ml implementation once ported.
        self._uast2bag = UastIds2Bag(vocabulary)
        # Inverse mapping (column index -> token) needed for docfreq lookups.
        self._reverse_vocabulary = [None] * len(vocabulary)
        for key, val in vocabulary.items():
            self._reverse_vocabulary[val] = key
        self._getter = getter

    @property
    def vocabulary(self):
        # token -> index mapping owned by the underlying UAST bagger.
        return self._uast2bag.token2index

    @property
    def docfreq(self):
        # The document frequencies model supplied at construction time.
        return self._docfreq

    def __call__(self, file_uast_generator):
        """Return ``{token index: log-TF * log-IDF weight}`` over all files."""
        freqs = defaultdict(int)
        for file_uast in file_uast_generator:
            bag = self._uast2bag(self._getter(file_uast))
            for key, freq in bag.items():
                freqs[key] += freq
        # Weight each accumulated frequency; collect tokens absent from the
        # docfreq model and delete them afterwards (cannot mutate while
        # iterating).
        missing = []
        for key, val in freqs.items():
            try:
                freqs[key] = math.log(1 + val) * math.log(
                    self._docfreq.docs / self._docfreq[self._reverse_vocabulary[key]])
            except KeyError:
                missing.append(key)
        for key in missing:
            del freqs[key]
        return dict(freqs)

    def __getstate__(self):
        # Lambdas cannot be pickled; serialize the lambda's code object via
        # marshal instead. Only closure-free lambdas are supported (asserted).
        state = self.__dict__.copy()
        if isinstance(self._getter, types.FunctionType) \
                and self._getter.__name__ == (lambda: None).__name__:
            assert self._getter.__closure__ is None
            state["_getter"] = marshal.dumps(self._getter.__code__)
        return state

    def __setstate__(self, state):
        self.__dict__ = state
        if isinstance(self._getter, bytes):
            # Rebuild the function from the marshalled code object dumped in
            # __getstate__.
            self._getter = types.FunctionType(marshal.loads(self._getter), globals())