/
expand_imports.py
164 lines (138 loc) · 5.24 KB
/
expand_imports.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
""" ExpandImports replaces imports by their full paths. """
from pythran.passmanager import Transformation
from pythran.utils import path_to_attr, path_to_node
from pythran.conversion import mangle
from pythran.syntax import PythranSyntaxError
from pythran.analyses import Ancestors
import gast as ast
class ExpandImports(Transformation):
    """
    Expands all imports into full paths.

    Import statements are removed from where they appear; one
    ``import <module> as <mangled name>`` per imported top-level module is
    then emitted at the top of the module, in sorted order so the output
    does not depend on set iteration order.

    Attributes
    ----------
    imports : {str}
        Imported modules (python base module name).
    symbols : {str : (str,)}
        Maps each locally bound name to the full path (tuple of path
        components) it refers to.

    Examples
    --------
    >>> import gast as ast
    >>> from pythran import passmanager, backend
    >>> node = ast.parse("from math import cos ; cos(2)")
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(ExpandImports, node)
    >>> print(pm.dump(backend.Python, node))
    import math as __pythran_import_math
    __pythran_import_math.cos(2)
    >>> node = ast.parse("from os.path import join ; join('a', 'b')")
    >>> _, node = pm.apply(ExpandImports, node)
    >>> print(pm.dump(backend.Python, node))
    import os as __pythran_import_os
    __pythran_import_os.path.join('a', 'b')
    """

    # Parent node types under which an imported, non-literal name may not be
    # expanded: there the module path itself would be stored in a container
    # or returned, which this pass rejects as unsupported.
    _FORBIDDEN_PARENTS = (ast.Tuple, ast.List, ast.Set, ast.Return)

    def __init__(self):
        super(ExpandImports, self).__init__(Ancestors)
        self.imports = set()
        self.symbols = dict()

    def _is_module_symbol(self, name):
        """Tell whether `name` is bound to an imported, non-literal symbol."""
        if name not in self.symbols:
            return False
        symbol = path_to_node(self.symbols[name])
        return not getattr(symbol, 'isliteral', lambda: False)()

    @staticmethod
    def _bound_names(target):
        """Return the identifiers a plain assignment target (re)binds.

        Names nested in tuple / list / starred targets are included;
        subscript and attribute targets bind no new name.
        """
        if isinstance(target, ast.Name):
            return [target.id]
        if isinstance(target, ast.Starred):
            return ExpandImports._bound_names(target.value)
        if isinstance(target, (ast.Tuple, ast.List)):
            names = []
            for elt in target.elts:
                names.extend(ExpandImports._bound_names(elt))
            return names
        return []

    @staticmethod
    def _parameters(args):
        """Return every parameter node of a gast ``arguments`` node."""
        params = list(getattr(args, 'posonlyargs', None) or [])
        params.extend(args.args)
        params.extend(getattr(args, 'kwonlyargs', None) or [])
        for extra in (getattr(args, 'vararg', None),
                      getattr(args, 'kwarg', None)):
            if extra is not None:
                params.append(extra)
        return params

    def visit_Module(self, node):
        """
        Visit the whole module and add all import at the top level.

        >> import numpy.linalg

        Becomes

        >> import numpy

        Import statements visited in the body return None and are dropped;
        the collected top-level modules are prepended in sorted order.
        """
        node.body = [k for k in (self.visit(n) for n in node.body) if k]
        imports = [ast.Import([ast.alias(i, mangle(i))])
                   for i in sorted(self.imports)]
        node.body = imports + node.body
        ast.fix_missing_locations(node)
        return node

    def visit_Import(self, node):
        """ Register imported modules and usage symbols.

        ``import a.b`` binds ``a`` to ``('a',)``, while ``import a.b as c``
        binds ``c`` to ``('a', 'b')``. The statement itself is removed.
        """
        for alias in node.names:
            alias_name = tuple(alias.name.split('.'))
            self.imports.add(alias_name[0])
            if alias.asname:
                self.symbols[alias.asname] = alias_name
            else:
                self.symbols[alias_name[0]] = alias_name[:1]
        self.update = True
        return None

    def visit_ImportFrom(self, node):
        """ Register imported modules and usage symbols.

        ``from a.b import c as d`` binds ``d`` (or ``c``) to
        ``('a', 'b', 'c')``. The statement itself is removed.

        Raises
        ------
        PythranSyntaxError
            For relative imports, which cannot be resolved to a full path.
        """
        if node.module is None or getattr(node, 'level', 0):
            raise PythranSyntaxError("Relative import not supported", node)
        module_path = tuple(node.module.split('.'))
        self.imports.add(module_path[0])
        for alias in node.names:
            path = module_path + (alias.name,)
            self.symbols[alias.asname or alias.name] = path
        self.update = True
        return None

    def visit_FunctionDef(self, node):
        """
        Update import context using overwriting name information.

        Examples
        --------
        >> import foo
        >> import bar
        >> def foo(bar):
        >>     print(bar)

        In this case, neither bar nor foo can be used in the foo function and
        in future function, foo will not be usable.

        Every kind of parameter (positional-only, regular, keyword-only,
        ``*args`` and ``**kwargs``) shadows an imported name inside the body.
        """
        # The function name shadows the import for the rest of the module.
        self.symbols.pop(node.name, None)
        gsymbols = self.symbols.copy()
        for param in self._parameters(node.args):
            self.symbols.pop(param.id, None)
        self.generic_visit(node)
        # Parameter shadowing and in-body rebinding are local to the function.
        self.symbols = gsymbols
        return node

    def visit_Assign(self, node):
        """
        Update import context using overwriting name information.

        Examples
        --------
        >> import foo
        >> def bar():
        >>     foo = 2
        >>     print(foo)

        In this case, foo can't be used after assign.

        An assignment whose value is itself an imported module-like name
        (``x = foo``) is recorded as an alias and removed.

        Raises
        ------
        PythranSyntaxError
            When such an alias is assigned to something other than a name.
        """
        if (isinstance(node.value, ast.Name) and
                self._is_module_symbol(node.value.id)):
            for target in node.targets:
                if not isinstance(target, ast.Name):
                    err = "Unsupported module aliasing"
                    raise PythranSyntaxError(err, target)
                self.symbols[target.id] = self.symbols[node.value.id]
            return None  # this assignment is no longer needed

        # The value is evaluated before the binding, so expand it with the
        # current symbol table.
        node.value = self.visit(node.value)
        # Names bound here now shadow any import of the same name. Forget them
        # *before* visiting the targets, otherwise the targets themselves get
        # rewritten into module paths. Subscript / attribute targets are not
        # new bindings and are left in the table.
        for target in node.targets:
            for name in self._bound_names(target):
                self.symbols.pop(name, None)
        node.targets = [self.visit(target) for target in node.targets]
        return node

    def visit_Name(self, node):
        """
        Replace name with full expanded name.

        Examples
        --------
        >> from numpy.linalg import det

        >> det(a)

        Becomes

        >> numpy.linalg.det(a)

        Raises
        ------
        PythranSyntaxError
            When the name appears directly inside a tuple, list, set or
            return statement.
        """
        if not self._is_module_symbol(node.id):
            return node
        parent = self.ancestors[node][-1]
        if isinstance(parent, self._FORBIDDEN_PARENTS):
            raise PythranSyntaxError(
                "Unsupported module identifier manipulation",
                node)
        new_node = path_to_attr(self.symbols[node.id])
        new_node.ctx = node.ctx
        ast.copy_location(new_node, node)
        return new_node