This repository has been archived by the owner on May 22, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 44
/
moder.py
119 lines (100 loc) · 4.32 KB
/
moder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from itertools import chain
import os
from bblfsh import Node
from pyspark import Row, RDD
from pyspark.sql import DataFrame
from sourced.ml.algorithms.uast_ids_to_bag import uast2sequence
from sourced.ml.transformers import Transformer
from sourced.ml.utils.engine import EngineConstants
from sourced.ml.utils import bblfsh_roles
class Moder(Transformer):
    """
    Select the items to extract from UASTs.

    The selected mode decides how incoming rows are reshaped: "repo" merges all the rows of a
    repository into a single row, "file" passes the rows through unchanged and "func" emits one
    row per extracted function.
    """

    class Options:
        # Names of the supported modes.
        repo = "repo"
        file = "file"
        function = "func"

        # Every value accepted by the `Moder.mode` setter.
        __all__ = (file, function, repo)

    # XPath filtering is enabled only when the USE_XPATH environment variable is exactly
    # "1", "true" or "yes". The variable is read once, when the class body is executed.
    USE_XPATH = os.environ.get("USE_XPATH", False) in ("1", "true", "yes")

    # Copied from https://github.com/src-d/hercules/blob/master/shotness.go#L40
    # If you change the queries here, please send the same change to Hercules as well.
    FUNC_XPATH = "//*[@roleFunction and @roleDeclaration]"
    FUNC_NAME_XPATH = (
        "/*[@roleFunction and @roleIdentifier and @roleName] "
        "| /*/*[@roleFunction and @roleIdentifier and @roleName]"
    )

    def __init__(self, mode: str, **kwargs):
        """
        Initialize a new transformer.

        :param mode: one of the values listed in `Moder.Options.__all__`.
        :param kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        # Assigned through the property so that the value is validated.
        self.mode = mode

    def __setstate__(self, state):
        """
        Restore the instance state and bind the bblfsh helpers used to handle UASTs.

        :param state: the state object, handed over to the parent `__setstate__`.
        """
        super().__setstate__(state)
        # NOTE(review): these helpers are bound only here, not in __init__ - presumably so
        # that they never become part of the pickled state; confirm against Transformer.
        from bblfsh import Node as UastNode, filter as xpath_filter

        self.parse_uast = UastNode.FromString
        self.serialize_uast = UastNode.SerializeToString
        self.filter_uast = xpath_filter

    @property
    def mode(self):
        """
        The current extraction mode, one of `Moder.Options.__all__`.

        Assigning a non-string raises TypeError; assigning an unknown mode raises ValueError.
        """
        return self._mode

    @mode.setter
    def mode(self, value: str):
        if not isinstance(value, str):
            raise TypeError("mode must be a string")
        supported = self.Options.__all__
        if value not in supported:
            raise ValueError("Unsupported mode: " + value)
        self._mode = value

    def call_repo(self, rows: RDD):
        """
        Merge the rows of each repository into a single row.

        :param rows: rows carrying at least the repository id and UAST columns.
        :return: one row per repository id; the path and blob id are empty strings and the \
                 UAST column is the concatenation of the UAST lists of the grouped rows.
        """
        columns = EngineConstants.Columns
        repo_key = columns.RepositoryId
        uast_key = columns.Uast
        path_key = columns.Path
        blob_key = columns.BlobId

        def repository_of(row):
            return row[repo_key]

        def merge(group):
            repo_id, repo_rows = group
            uasts = [uast for item in repo_rows for uast in item[uast_key]]
            return Row(**{repo_key: repo_id, path_key: "", blob_key: "", uast_key: uasts})

        return rows.groupBy(repository_of).map(merge)

    def call_file(self, rows: RDD):
        """
        Return the rows unchanged.
        """
        return rows

    def call_func(self, rows: RDD):
        """
        Replace every row with the rows produced by `extract_functions_from_row`.
        """
        return rows.flatMap(self.extract_functions_from_row)

    def __call__(self, rows: DataFrame) -> RDD:
        """
        Apply the handler named "call_" + mode to the RDD behind the data frame.

        :param rows: the input data frame; only its `rdd` attribute is used.
        :return: whatever the selected `call_*` method returns.
        """
        handler = getattr(self, "call_" + self.mode)
        return handler(rows.rdd)

    def extract_functions_from_row(self, row: Row):
        """
        Generate one row per function extracted from the first UAST of the given row.

        Nothing is generated when the UAST column is empty. Each generated row copies the \
        fields of the source row, replaces the UAST column with a one-element list holding the \
        serialized function node and appends "_<name>:<start line>" to the blob id.

        :param row: the source row; its UAST column holds serialized UASTs.
        :return: generator of rows.
        """
        columns = EngineConstants.Columns
        serialized = row[columns.Uast]
        if not serialized:
            return
        # Only the first UAST of the row is inspected.
        tree = self.parse_uast(serialized[0])
        base = row.asDict()
        for func, name in self.extract_functions_from_uast(tree):
            fields = dict(base)
            fields[columns.Uast] = [bytearray(self.serialize_uast(func))]
            fields[columns.BlobId] += "_%s:%d" % (name, func.start_position.line)
            yield Row(**fields)

    def extract_functions_from_uast(self, uast: Node):
        """
        Generate the named function declarations which are not nested in another declaration.

        Nodes are selected either with the XPath queries (when `USE_XPATH` is set) or by \
        checking the roles of every node returned by `uast2sequence`. The name of a function \
        is the "+"-joined tokens of its name nodes; functions with an empty name are skipped.

        :param uast: the root node to search.
        :return: generator of (function node, name) pairs.
        """
        def has_roles(node, *wanted):
            roles = node.roles
            return all(role in roles for role in wanted)

        def declarations(root):
            # Function declarations found under `root` (the search may return `root` itself).
            if self.USE_XPATH:
                return self.filter_uast(root, self.FUNC_XPATH)
            return [node for node in uast2sequence(root)
                    if has_roles(node, bblfsh_roles.FUNCTION, bblfsh_roles.DECLARATION)]

        def name_nodes(root):
            # Nodes which carry the name of the function rooted at `root`.
            if self.USE_XPATH:
                return self.filter_uast(root, self.FUNC_NAME_XPATH)
            return [node for node in uast2sequence(root)
                    if has_roles(node, bblfsh_roles.FUNCTION, bblfsh_roles.IDENTIFIER,
                                 bblfsh_roles.NAME)]

        candidates = list(declarations(uast))

        # Object ids of the declarations which live inside another declaration.
        nested = set()
        for outer in candidates:
            if id(outer) in nested:
                # Its own descendants were already recorded while scanning its parent.
                continue
            nested.update(id(inner) for inner in declarations(outer) if inner != outer)

        for func in candidates:
            if id(func) in nested:
                continue
            name = "+".join(node.token for node in name_nodes(func))
            if name:
                yield func, name