/
matchrule.jl
149 lines (142 loc) · 5.21 KB
/
matchrule.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
    match_rule(ixs, iy)
    match_rule(code::EinCode)

Return the rule that matches the given input labels `ixs` and output labels `iy`.
A single input is handed to `match_rule_unary`, a pair of inputs to
`match_rule_binary`; any other number of inputs yields `DefaultRule` - the slow
`loop_einsum` backend.
"""
function match_rule(ixs, iy)
    ninputs = length(ixs)
    ninputs == 1 && return match_rule_unary(ixs[1], iy)
    ninputs == 2 && return match_rule_binary(ixs[1], ixs[2], iy)
    # no specialised matcher for zero or for three-and-more inputs
    return DefaultRule()
end
# Pick the rule for a single-input operation `ix -> iy`.
# Rules are tried in priority order; the first one that matches is returned and
# `DefaultRule()` is the fallback when none applies.
function match_rule_unary(ix, iy)
    nin, nout = length(ix), length(iy)

    # two identical input labels and no output label, e.g. ii->
    if nout == 0 && nin == 2 && ix[1] == ix[2]
        return Tr()
    end

    if !allunique(iy)
        # Repeated output labels: only handled when the input labels are distinct
        # and both sides use exactly the same set of labels (e.g. ij->ijjj).
        # Anything else (e.g. ij->ijjl) falls back.
        if allunique(ix) && all(in(iy), ix) && all(in(ix), iy)
            return Duplicate()
        end
        return DefaultRule()
    end

    # From here on the output labels are all distinct.
    ix == iy && return Identity()

    if !allunique(ix)
        # repeated input labels, same label set on both sides, e.g. ijjj->ij
        same_labels = all(in(ix), iy) && all(in(iy), ix)
        return same_labels ? Diag() : DefaultRule()
    end

    if nin == nout
        # same length: a reordering unless a label differs, e.g. abcd->bcde
        return all(in(iy), ix) ? Permutedims() : DefaultRule()
    end

    # different lengths, all labels distinct on both sides
    all(in(ix), iy) && return Sum()     # every output label comes from the input
    all(in(iy), ix) && return Repeat()  # every input label is kept, e.g. ij->ijk
    return DefaultRule()                # e.g. ijkxc->ijkl
end
# Convenience method: read the input and output labels from `code` and defer to
# `match_rule(ixs, iy)`.
function match_rule(code::EinCode)
    return match_rule(getixs(code), getiy(code))
end
# Build the batched variant of a simple binary rule: the label 'l' is appended to
# each of the three label tuples carried in the rule's type parameters.
@inline function _add_batch(::SimpleBinaryRule{ix1,ix2,iy}) where {ix1,ix2,iy}
    batched_x1 = (ix1..., 'l')
    batched_x2 = (ix2..., 'l')
    batched_y = (iy..., 'l')
    return SimpleBinaryRule{batched_x1,batched_x2,batched_y}()
end
# The fallback rule has no batched variant; it is returned as a fresh `DefaultRule()`.
@inline function _add_batch(::DefaultRule)
    return DefaultRule()
end
# Pick the rule for a two-input operation `ix1, ix2 -> iy`.
# Returns whatever `_match_simple2` finds (with a batch label added when the three
# label lists end in one common label), or `DefaultRule()` when no simple pattern
# can apply.
function match_rule_binary(ix1, ix2, iy)
    n1, n2, ny = length(ix1), length(ix2), length(iy)

    # any label list that `_isunique` rejects goes straight to the fallback
    (_isunique(ix1) && _isunique(ix2) && _isunique(iy)) || return DefaultRule()

    # an even total label count is treated as the no-batch case
    if iseven(n1 + n2 + ny)
        return _match_simple2(ix1, ix2, iy, n1, n2, ny)
    end

    # odd total: only supported when all three lists end in one shared label
    shared_last = n1 > 0 && n2 > 0 && ny > 0 && ix1[n1] == ix2[n2] == iy[ny]
    shared_last || return DefaultRule()

    # match the remaining labels without the trailing one, then add the batch label back
    inner = _match_simple2(ix1, ix2, iy, n1 - 1, n2 - 1, ny - 1)
    return _add_batch(inner)
end
"""
    _isunique(ix)

Return `true` when `ix` has at most three elements and they are pairwise distinct.

Collections with more than three elements always yield `false`, even if their
elements are distinct, so that callers send them to the default rule.
"""
@inline function _isunique(ix)
    n = length(ix)
    if n <= 1
        return true
    elseif n == 2
        return @inbounds ix[1] != ix[2]
    elseif n == 3
        @inbounds a, b, c = ix
        # all three pairs must differ; the `b != c` check was previously missing,
        # so (a, b, b) was wrongly accepted as unique
        return a != b && a != c && b != c
    else  # to default rules
        return false
    end
end
# Match a two-input operation against the table of simple patterns.
#
# Only the first `Nx1`, `Nx2` and `Ny` labels of `ix1`, `ix2` and `iy` are inspected.
# On a match the result is a `SimpleBinaryRule` written with canonical labels:
# 'i' is shared by the first input and the output, 'k' by the second input and the
# output, and 'j' by the two inputs only. `DefaultRule()` is returned when no
# pattern fits.
function _match_simple2(ix1, ix2, iy, Nx1, Nx2, Ny)
    shape = (Nx1, Nx2, Ny)

    if shape == (0, 0, 0)
        return SimpleBinaryRule((), (), ())
    elseif shape == (0, 1, 1)
        ix2[1] == iy[1] && return SimpleBinaryRule((), ('k',), ('k',))
    elseif shape == (1, 0, 1)
        iy[1] == ix1[1] && return SimpleBinaryRule(('i',), (), ('i',))
    elseif shape == (1, 1, 0)
        ix1[1] == ix2[1] && return SimpleBinaryRule(('j',), ('j',), ())
    elseif shape == (1, 1, 2)
        # the two input labels appear in the output, in either order
        if iy[1] == ix1[1] && iy[2] == ix2[1]
            return SimpleBinaryRule(('i',), ('k',), ('i','k'))
        elseif iy[1] == ix2[1] && iy[2] == ix1[1]
            return SimpleBinaryRule(('i',), ('k',), ('k','i'))
        end
    elseif shape == (1, 2, 1)
        # one label of the second input equals the first input's label
        if ix2[1] == ix1[1] && ix2[2] == iy[1]
            return SimpleBinaryRule(('j',), ('j','k'), ('k',))
        elseif ix2[1] == iy[1] && ix2[2] == ix1[1]
            return SimpleBinaryRule(('j',), ('k','j'), ('k',))
        end
    elseif shape == (2, 1, 1)
        # one label of the first input equals the second input's label
        if ix1[1] == ix2[1] && ix1[2] == iy[1]
            return SimpleBinaryRule(('j','i'), ('j',), ('i',))
        elseif ix1[1] == iy[1] && ix1[2] == ix2[1]
            return SimpleBinaryRule(('i','j'), ('j',), ('i',))
        end
    elseif shape == (2, 2, 2)
        # eight layouts: position of the shared label in each input (2 x 2)
        # times the order of the two output labels (2)
        if ix1[1] == ix2[1] && ix1[2] == iy[1] && ix2[2] == iy[2]
            return SimpleBinaryRule(('j','i'), ('j','k'), ('i','k'))
        elseif ix1[1] == ix2[2] && ix1[2] == iy[1] && ix2[1] == iy[2]
            return SimpleBinaryRule(('j','i'), ('k','j'), ('i','k'))
        elseif ix1[1] == ix2[1] && ix1[2] == iy[2] && ix2[2] == iy[1]
            return SimpleBinaryRule(('j','i'), ('j','k'), ('k','i'))
        elseif ix1[1] == ix2[2] && ix1[2] == iy[2] && ix2[1] == iy[1]
            return SimpleBinaryRule(('j','i'), ('k','j'), ('k','i'))
        elseif ix1[2] == ix2[1] && ix1[1] == iy[1] && ix2[2] == iy[2]
            return SimpleBinaryRule(('i','j'), ('j','k'), ('i','k'))
        elseif ix1[2] == ix2[2] && ix1[1] == iy[1] && ix2[1] == iy[2]
            return SimpleBinaryRule(('i','j'), ('k','j'), ('i','k'))
        elseif ix1[2] == ix2[1] && ix1[1] == iy[2] && ix2[2] == iy[1]
            return SimpleBinaryRule(('i','j'), ('j','k'), ('k','i'))
        elseif ix1[2] == ix2[2] && ix1[1] == iy[2] && ix2[1] == iy[1]
            return SimpleBinaryRule(('i','j'), ('k','j'), ('k','i'))
        end
    end

    # unsupported shape, or the labels did not line up with any pattern
    return DefaultRule()
end