-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy pathtape.jl
121 lines (99 loc) · 3.75 KB
/
tape.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#######################
# AbstractInstruction #
#######################
# Common supertype for everything that can be stored on an instruction tape.
# The generic `==` and `show` methods in this file read the fields `func`,
# `input`, `output` and `cache`, so subtypes are expected to provide them.
abstract type AbstractInstruction end
# A tape is a plain `Vector` whose element type is the abstract instruction
# type, so it can hold instructions of differing concrete types.
const InstructionTape = Vector{AbstractInstruction}
# Construct an instruction of type `T` from `args...` and append it to `tp`.
# When `tp` is the `NULL_TAPE` sentinel (compared by identity) nothing is
# constructed or recorded. Always returns `nothing`.
function record!(tp::InstructionTape, ::Type{T}, args...) where {T}
    if tp !== NULL_TAPE
        push!(tp, T(args...))
    end
    return nothing
end
# Two instructions are equal when their `func`, `input`, `output` and `cache`
# fields all compare equal with `==`. Fields are compared in that order and
# comparison stops at the first mismatch.
function Base.:(==)(a::AbstractInstruction, b::AbstractInstruction)
    a.func == b.func || return false
    a.input == b.input || return false
    a.output == b.output || return false
    return a.cache == b.cache
end
# `capture` is applied to instruction inputs/outputs before they are stored,
# so that later external reference-breaking (e.g. destructive assignment)
# cannot corrupt the state held by an instruction.
#
# The generic fallback is a no-op that returns its argument unchanged.
@inline function capture(state)
    return state
end

# Tuples are captured element-wise (recursively, since elements may themselves
# be tuples).
@inline function capture(state::Tuple)
    return map(capture, state)
end
# ScalarInstruction #
#-------------------#
# Instruction record dispatched to `scalar_forward_exec!` /
# `scalar_reverse_exec!` by the tape passes. Each field has its own type
# parameter, so the concrete field types are part of the instruction's type.
# NOTE(review): presumably `func` is the recorded operation, `input`/`output`
# its operands/results and `cache` scratch storage -- confirm against the
# exec implementations.
struct ScalarInstruction{F,I,O,C} <: AbstractInstruction
    func::F
    input::I
    output::O
    cache::C
    # Declaring an inner constructor disables the default outer constructor;
    # instances must be built with all four type parameters given explicitly.
    ScalarInstruction{F,I,O,C}(func, input, output, cache) where {F,I,O,C} =
        new{F,I,O,C}(func, input, output, cache)
end
# Internal helper: build a `ScalarInstruction` whose type parameters are taken
# from the types of the arguments. No `capture` is applied here.
@inline _ScalarInstruction(func::F, input::I, output::O, cache::C) where {F,I,O,C} =
    ScalarInstruction{F,I,O,C}(func, input, output, cache)
# Public constructor. `input` and `output` are passed through `capture`
# before being stored; `func` and `cache` are stored as given. `cache`
# defaults to `nothing`.
function ScalarInstruction(func, input, output, cache = nothing)
    captured_input = capture(input)
    captured_output = capture(output)
    return _ScalarInstruction(func, captured_input, captured_output, cache)
end
# SpecialInstruction #
#--------------------#
# Instruction record dispatched to `special_forward_exec!` /
# `special_reverse_exec!` by the tape passes. Each field has its own type
# parameter, so the concrete field types are part of the instruction's type.
# NOTE(review): presumably `func` is the recorded operation, `input`/`output`
# its operands/results and `cache` scratch storage -- confirm against the
# exec implementations.
struct SpecialInstruction{F,I,O,C} <: AbstractInstruction
    func::F
    input::I
    output::O
    cache::C
    # Declaring an inner constructor disables the default outer constructor;
    # instances must be built with all four type parameters given explicitly.
    SpecialInstruction{F,I,O,C}(func, input, output, cache) where {F,I,O,C} =
        new{F,I,O,C}(func, input, output, cache)
end
# Internal helper: build a `SpecialInstruction` whose type parameters are taken
# from the types of the arguments. No `capture` is applied here.
@inline _SpecialInstruction(func::F, input::I, output::O, cache::C) where {F,I,O,C} =
    SpecialInstruction{F,I,O,C}(func, input, output, cache)
# Public constructor. `input` and `output` are passed through `capture`
# before being stored; `func` and `cache` are stored as given. `cache`
# defaults to `nothing`.
function SpecialInstruction(func, input, output, cache = nothing)
    captured_input = capture(input)
    captured_output = capture(output)
    return _SpecialInstruction(func, captured_input, captured_output, cache)
end
##########
# passes #
##########
# Run `forward_exec!` on every instruction of `tape`, first to last.
# Returns `nothing`.
function forward_pass!(tape::InstructionTape)
    foreach(forward_exec!, tape)
    return nothing
end
# Forward execution of a single instruction, selected by the instruction's
# concrete type. Both methods are marked `@noinline`.
@noinline function forward_exec!(instruction::ScalarInstruction)
    return scalar_forward_exec!(instruction)
end

@noinline function forward_exec!(instruction::SpecialInstruction)
    return special_forward_exec!(instruction)
end
# Run `reverse_exec!` on every instruction of `tape`, last to first.
# Returns `nothing`.
function reverse_pass!(tape::InstructionTape)
    for idx in reverse(eachindex(tape))
        reverse_exec!(tape[idx])
    end
    return nothing
end
# Reverse execution of a single instruction, selected by the instruction's
# concrete type. Both methods are marked `@noinline`.
@noinline function reverse_exec!(instruction::ScalarInstruction)
    return scalar_reverse_exec!(instruction)
end

@noinline function reverse_exec!(instruction::SpecialInstruction)
    return special_reverse_exec!(instruction)
end
###################
# Pretty Printing #
###################
# extra spaces here accommodate padding in show(::IO, ::AbstractInstruction)
# `compactrepr(x)` returns a short string form of `x` for instruction printing.
#
# Tuples: every element is rendered with `compactrepr`; the pieces are joined
# with ",\n " and wrapped in parentheses.
compactrepr(x::Tuple) = "("*join(map(compactrepr, x), ",\n ")*")"

# Arrays with fewer than 5 elements are rendered as the bracketed part of
# their `repr`, which drops any element-type prefix (`Float64[]` -> `[]`).
# Arrays with 5 or more elements are rendered with `summary`.
function compactrepr(x::AbstractArray)
    length(x) < 5 || return summary(x)
    str = repr(x)
    # Greedy match from the first `[` to the last `]`, so that nested arrays
    # such as `[[1, 2], [3]]` are not cut off at the first closing bracket.
    m = match(r"\[.*\]", str)
    # Some arrays (e.g. ranges such as `1:3`) have no brackets in their repr;
    # `match` then returns `nothing`, so fall back to the plain repr.
    return m === nothing ? str : String(m.match)
end

# Everything else: plain `repr`.
compactrepr(x) = repr(x)
# Print a four-line description of `instruction` to `io`: a header with the
# instruction kind and its `func`, followed by the `input`, `output` and
# `cache` fields rendered with `compactrepr`. Every line is prefixed with
# `pad`. The last line is written without a trailing newline.
function Base.show(io::IO, instruction::AbstractInstruction, pad = "")
    # Anything that is not a `ScalarInstruction` is labelled "SpecialInstruction".
    if isa(instruction, ScalarInstruction)
        name = "ScalarInstruction"
    else
        name = "SpecialInstruction"
    end
    header = string(name, "(", instruction.func, "):")
    println(io, pad, header)
    println(io, pad, " input: ", compactrepr(instruction.input))
    println(io, pad, " output: ", compactrepr(instruction.output))
    print(io, pad, " cache: ", compactrepr(instruction.cache))
    return nothing
end
# Print the tape to `io`: a header line with the number of instructions,
# then each instruction preceded by its 1-based position and " => ", each
# followed by a newline.
function Base.show(io::IO, tp::InstructionTape)
    println(io, length(tp), "-element InstructionTape:")
    for (i, instruction) in enumerate(tp)
        print(io, "$i => ")
        show(io, instruction)
        println(io)
    end
    return nothing
end