/
julia_vm.py
319 lines (264 loc) · 9.63 KB
/
julia_vm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
from restrain_jit.bejulia.instructions import *
from restrain_jit.bejulia.representations import *
from restrain_jit.bejulia.jl_protocol import bridge, Aware
from restrain_jit.jit_info import PyCodeInfo, PyFuncInfo
from restrain_jit.abs_compiler import instrnames as InstrNames
from restrain_jit.abs_compiler.from_bc import Interpreter
from restrain_jit.vm.am import AM, run_machine
from dataclasses import dataclass
from bytecode import Bytecode, ControlFlowGraph, Instr as PyInstr, CellVar, CompilerFlags
import typing as t
import types
import sys
def load_arg(x, cellvars, lineno):
    """Build the instruction that pushes argument ``x`` onto the stack.

    Names listed in ``cellvars`` are loaded through their cell
    (``LOAD_DEREF`` on a ``CellVar``); every other name is loaded as a
    plain local (``LOAD_FAST``). ``lineno`` is attached to the instruction.
    """
    if x in cellvars:
        opname, operand = InstrNames.LOAD_DEREF, CellVar(x)
    else:
        opname, operand = InstrNames.LOAD_FAST, x
    return PyInstr(opname, operand, lineno=lineno)
def copy_func(f: types.FunctionType):
    """Return a new function object equivalent to ``f``.

    The copy shares ``f``'s code object, globals and closure cells, and gets
    the same name, qualname, module, defaults, keyword-only defaults and
    annotations. Its attribute namespace is independent, though.
    ``__dict__``, ``__kwdefaults__`` and ``__annotations__`` are
    shallow-copied rather than aliased. Setting attributes on the copy
    (``JuVM.func_info`` sets ``__jit__``, ``__compile__`` and
    ``__func_info__``) therefore no longer leaks back onto ``f``.
    """
    # noinspection PyArgumentList
    nf = types.FunctionType(f.__code__, f.__globals__, f.__name__,
                            f.__defaults__, f.__closure__)
    nf.__qualname__ = f.__qualname__
    # The constructor derives __module__ from the globals; keep f's value.
    nf.__module__ = f.__module__
    kwdefaults = f.__kwdefaults__
    nf.__kwdefaults__ = None if kwdefaults is None else dict(kwdefaults)
    nf.__annotations__ = dict(f.__annotations__)
    nf.__dict__ = dict(f.__dict__)
    return nf
@dataclass
class JuVM(AM[Instr, Repr]):
    """Abstract machine that records Julia-backend (bejulia) IR for one
    Python code object.

    ``code_info`` drives an instance with the abstract interpreter from
    ``abs_compiler.from_bc``. Every emitted instruction is appended, as an
    ``A(tag, instr)`` pair, to the innermost entry of ``blocks``.
    """

    # Dataclass fields; ``empty`` constructs them positionally in this order.
    # free-form metadata returned by ``meta()``
    _meta: dict
    # simulated value stack for the current label (cleared by ``label``)
    st: t.List[Repr]
    # stack of (end label, instructions); the last entry is being filled
    blocks: t.List[t.Tuple[t.Optional[str], t.List[A]]]
    # temporary register names currently handed out by ``alloc``
    used: t.Set[str]
    # temporaries given back via ``release`` and available for reuse
    unused: t.Set[str]
    # global names recorded through ``require_global``
    globals: t.Set[str]
    # module returned by ``get_module``
    module: types.ModuleType

    def set_lineno(self, lineno: int):
        # TODO: line-number tracking is not implemented yet.
        pass

    def get_module(self) -> types.ModuleType:
        return self.module

    def require_global(self, s: str):
        """Record that the compiled code depends on the global ``s``."""
        self.globals.add(s)

    @classmethod
    def func_info(cls, func: types.FunctionType) -> types.FunctionType:
        """Return a lazily-compiling trampoline for ``func``.

        The result is a copy of ``func`` whose bytecode first calls
        ``r_compile()`` and then invokes the returned callable with the
        original arguments. ``r_compile`` does three things:

        - obtains the JIT function from ``Aware.f``;
        - rewrites the trampoline's ``__code__`` to call it directly;
        - stores it in ``__jit__``.

        The trampoline also carries ``__func_info__`` (the ``PyFuncInfo``)
        and ``__compile__`` (``r_compile`` itself).
        """
        names = func.__code__.co_names
        code = Bytecode.from_code(func.__code__)
        codeinfo = cls.code_info(code)

        start_func = copy_func(func)
        lineno = code.first_lineno
        argnames = code.argnames
        cellvars = code.cellvars

        def r_compile():
            # ``info`` is bound below, before the trampoline can ever run.
            jit_func = Aware.f(info)
            bc = Bytecode()
            bc.append(PyInstr(InstrNames.LOAD_CONST, jit_func))
            bc.extend(
                [load_arg(each, cellvars, lineno) for each in argnames])
            bc.extend([
                PyInstr(InstrNames.CALL_FUNCTION, len(argnames)),
                PyInstr(InstrNames.RETURN_VALUE),
            ])
            bc._copy_attr_from(code)
            # Subsequent calls skip compilation and go straight to jit_func.
            start_func.__code__ = bc.to_code()
            start_func.__jit__ = jit_func
            return jit_func

        start_func_code = Bytecode()
        start_func_code.argnames = argnames
        start_func_code.extend([
            PyInstr(InstrNames.LOAD_CONST, r_compile, lineno=lineno),
            PyInstr(InstrNames.CALL_FUNCTION, 0, lineno=lineno),
            *(load_arg(each, cellvars, lineno) for each in argnames),
            PyInstr(InstrNames.CALL_FUNCTION, len(argnames), lineno=lineno),
            PyInstr(InstrNames.RETURN_VALUE, lineno=lineno),
        ])
        start_func_code._copy_attr_from(code)

        info = PyFuncInfo(func.__name__, func.__module__,
                          func.__defaults__, func.__kwdefaults__,
                          func.__closure__, func.__globals__, codeinfo,
                          func, {}, names)
        start_func.__code__ = start_func_code.to_code()
        start_func.__func_info__ = info
        start_func.__compile__ = r_compile
        start_func.__jit__ = None
        return start_func

    @classmethod
    def code_info(cls, code: Bytecode) -> PyCodeInfo[Repr]:
        """Abstractly interpret ``code`` and package the resulting IR.

        Runs the interpreter over ``code``'s control-flow graph on a fresh
        machine. It then applies ``pass_push_pop_inline`` to the top-level
        instructions and returns them together with the code object's
        metadata and flags.
        """
        cfg = ControlFlowGraph.from_bytecode(code)
        machine = cls.empty()
        run_machine(Interpreter(code.first_lineno).abs_i_cfg(cfg), machine)
        glob_deps = tuple(machine.globals)
        instrs = machine.pass_push_pop_inline(machine.instrs)
        flags = code.flags
        return PyCodeInfo(code.name, glob_deps, code.argnames,
                          code.freevars, code.cellvars, code.filename,
                          code.first_lineno, code.argcount,
                          code.kwonlyargcount,
                          bool(flags & CompilerFlags.GENERATOR),
                          bool(flags & CompilerFlags.VARKEYWORDS),
                          bool(flags & CompilerFlags.VARARGS),
                          instrs)

    def pop_exception(self, must: bool) -> Repr:
        name = self.alloc()
        self.add_instr(name, PopException(must))
        return Reg(name)

    def meta(self) -> dict:
        return self._meta

    def last_block_end(self) -> str:
        return self.end_label

    def push_block(self, end_label: str) -> None:
        """Open a new nested block; following instructions go into it."""
        self.blocks.append((end_label, []))

    def pop_block(self) -> Repr:
        """Close the innermost block.

        Its instructions are emitted as an ``UnwindBlock`` into the
        enclosing block. Returns the register that holds the result.
        """
        _end_label, instrs = self.blocks.pop()
        regname = self.alloc()
        self.add_instr(regname, UnwindBlock(instrs))
        return Reg(regname)

    def from_const(self, val: Repr) -> object:
        assert isinstance(val, Const)
        return val.val

    def ret(self, val: Repr):
        return self.add_instr(None, Return(val))

    def const(self, val: object):
        return Const(val)

    @classmethod
    def reg_of(cls, n: str):
        return Reg(n)

    def from_higher(self, qualifier: str, name: str):
        """Emit a ``PyGlob`` lookup into a fresh register."""
        regname = self.alloc()
        self.add_instr(regname, PyGlob(qualifier, name))
        return Reg(regname)

    def from_lower(self, qualifier: str, name: str):
        """Emit a ``JlGlob`` lookup into a fresh register."""
        regname = self.alloc()
        self.add_instr(regname, JlGlob(qualifier, name))
        return Reg(regname)

    def app(self, f: Repr, args: t.List[Repr]) -> Repr:
        name = self.alloc()
        self.add_instr(name, App(f, args))
        return Reg(name)

    def store(self, n: str, val: Repr):
        self.add_instr(None, Store(Reg(n), val))

    def load(self, n: str) -> Repr:
        name = self.alloc()
        self.add_instr(name, Load(Reg(n)))
        return Reg(name)

    def assign(self, n: str, v: Repr):
        self.add_instr(None, Ass(Reg(n), v))

    def peek(self, n: int) -> Repr:
        """Return the ``n``-th value from the top of the stack (0 = top).

        If the value was pushed outside the current label (it is not on
        the simulated stack), a ``Peek`` is emitted and its register is
        returned. It is returned as a ``Reg``, just like ``pop`` does.
        """
        try:
            return self.st[-n - 1]
        except IndexError:
            name = self.alloc()
            self.add_instr(name, Peek(n))
            return Reg(name)

    def jump(self, n: str):
        self.add_instr(None, Jmp(n))

    def jump_if_push(self, n: str, cond: Repr, leave: Repr):
        self.add_instr(None, JmpIfPush(n, cond, leave))

    def jump_if(self, n: str, cond: Repr):
        self.add_instr(None, JmpIf(n, cond))

    def label(self, n: str) -> None:
        # Values on the simulated stack are not tracked across labels.
        self.st.clear()
        self.add_instr(None, Label(n))

    def push(self, r: Repr) -> None:
        self.st.append(r)
        self.add_instr(None, Push(r))

    def pop(self) -> Repr:
        """Pop a value, emitting a ``Pop`` instruction.

        When the simulated stack is empty, the popped value is bound to a
        fresh register instead of being known statically.
        """
        try:
            a = self.st.pop()
            self.add_instr(None, Pop())
        except IndexError:
            name = self.alloc()
            self.add_instr(name, Pop())
            a = Reg(name)
        return a

    def release(self, name: Repr):
        """
        release temporary variable

        Only names that ``alloc`` handed out are returned to the free pool.
        Non-``Reg`` values and other registers (e.g. local variables) are
        ignored, so they can never be reused as temporaries.
        """
        if not isinstance(name, Reg):
            return
        reg_name = name.n
        if reg_name in self.used:
            self.used.remove(reg_name)
            self.unused.add(reg_name)

    def alloc(self):
        """
        allocate a new temporary variable

        Reuses a released name if one is available, otherwise creates
        ``tmp-<k>``. Reused names go back into ``used``. This keeps
        ``used | unused == {tmp-0 .. tmp-(k-1)}``, so a fresh
        ``tmp-{len(used)}`` can never collide with a live temporary.
        """
        if self.unused:
            name = self.unused.pop()
        else:
            name = f"tmp-{len(self.used)}"
        self.used.add(name)
        return name

    def add_instr(self, tag, instr: Instr):
        """Append ``A(tag, instr)`` to the innermost open block."""
        self.instrs.append(A(tag, instr))
        return None

    @property
    def instrs(self):
        return self.blocks[-1][1]

    @property
    def end_label(self) -> t.Optional[str]:
        return self.blocks[-1][0]

    @classmethod
    def pass_push_pop_inline(cls, instrs):
        """Remove adjacent, balanced ``Push``/``Pop`` pairs from ``instrs``.

        An untagged ``Push`` immediately followed by an untagged ``Pop``
        cancels out. The match then widens outwards while the pattern
        continues (``Push a; Push b; Pop; Pop`` removes all four).
        ``UnwindBlock`` bodies are optimised recursively, in place.
        Returns a new list.
        """

        def is_push(a):
            return a.lhs is None and isinstance(a.rhs, Push)

        def is_pop(a):
            return a.lhs is None and isinstance(a.rhs, Pop)

        n = len(instrs)
        blacklist = set()
        i = 0
        while i < n:
            cur = instrs[i]
            if isinstance(cur.rhs, UnwindBlock):
                cur.rhs.instrs = cls.pass_push_pop_inline(cur.rhs.instrs)
            if not is_pop(cur):
                i += 1
                continue
            j = i - 1
            # j must stay >= 0: a negative index would wrap around to the
            # end of the list and pair this Pop with an unrelated Push.
            if j < 0 or not is_push(instrs[j]):
                i += 1
                continue
            # Widen outwards while Push (left) / Pop (right) keep matching.
            while (j >= 0 and i < n
                   and is_push(instrs[j]) and is_pop(instrs[i])):
                blacklist.add(j)
                blacklist.add(i)
                i += 1
                j -= 1
        return [
            each for idx, each in enumerate(instrs) if idx not in blacklist
        ]

    @classmethod
    def empty(cls, module=None):
        """Create a machine with empty state and one open top-level block.

        ``module`` defaults to the module this class is defined in.
        """
        return cls({}, [], [(None, [])], set(), set(), set(), module
                   or sys.modules[cls.__module__])