-
Notifications
You must be signed in to change notification settings - Fork 0
/
matrix2x2.jl
115 lines (100 loc) · 3.24 KB
/
matrix2x2.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
    Matrix2x2toArray4(arr)

Convert Julia 2x2 matrices to C-ordered (row-major) 4-element arrays.

`arr` must have size `(2, 2, rest...)`; the result has size `(4, rest...)`, where each
column `r[:, k...]` holds the matrix `arr[:, :, k...]` flattened row by row. Julia stores
arrays in column-major order whereas C uses row-major order, so the two matrix axes are
swapped before flattening; any trailing axes keep their order. Only used in tests.

# Throws
- `ArgumentError` if `arr` has fewer than two dimensions or its first two dimensions
  are not both 2.
"""
function Matrix2x2toArray4(arr::Array{T, N}) where {T, N}
    s = size(arr)
    if N < 2 || s[1] != 2 || s[2] != 2
        throw(ArgumentError("size(arr) must equal (2, 2, ...)"))
    end
    # Swap only the two matrix axes. For a plain matrix (N == 2) the splat `3:N...`
    # and the trailing sizes `s[3:end]` are both empty, so no separate branch is needed.
    perm = (2, 1, 3:N...)
    return reshape(permutedims(arr, perm), 4, s[3:end]...)
end
"""
    Array4toMatrix2x2(arr)

Turn C-ordered (row-major) 4-element arrays back into Julia 2x2 matrices.

`arr` must have size `(4, rest...)`; the result has size `(2, 2, rest...)`. Since Julia
is column-major and C is row-major, each group of four values is first reshaped to 2x2
and then the two matrix axes are swapped; trailing axes keep their order.
Only used in tests.

# Throws
- `ArgumentError` if `arr` is zero-dimensional or its first dimension is not 4.
"""
function Array4toMatrix2x2(arr::Array{T, N}) where {T, N}
    dims = size(arr)
    (N >= 1 && dims[1] == 4) || throw(ArgumentError("size(arr) must equal (4, ...)"))
    trailing = dims[2:end]
    # When N == 1 the splat `3:(N + 1)...` is empty, leaving a plain 2x2 transpose.
    return permutedims(reshape(arr, 2, 2, trailing...), (2, 1, 3:(N + 1)...))
end
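# We use our own hard-coded, in-place 2x2 matrix kernels below: they are faster than
# generic matrix multiplication because the size is known to be 2x2, and they fold the
# conjugate transpose (adjoint) directly into the arithmetic, which avoids allocations.
# Each operand is a 4-element array holding a 2x2 matrix in row-major (C) order,
# i.e. [x[1] x[2]; x[3] x[4]].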
"""
    AxB!(C, A, B)

Compute `C = A * B` in place, where `A`, `B` and `C` each hold a 2x2 matrix as four
elements in row-major (C) order, i.e. `[x[1] x[2]; x[3] x[4]]`.

All inputs are read into locals before `C` is written, so `C` may alias `A` or `B`
(e.g. `AxB!(A, A, B)`). Bounds checks are performed unless the caller is itself inside
an `@inbounds` block. Returns the value computed for `C[4]`; the product is in `C`.
"""
Base.@propagate_inbounds function AxB!(C, A, B)
    # `@inbounds` written in front of a `function` definition does not reach the body;
    # `@propagate_inbounds` (which also inlines) lets a caller's `@inbounds` apply here.
    a1, a2, a3, a4 = A[1], A[2], A[3], A[4]
    b1, b2, b3, b4 = B[1], B[2], B[3], B[4]
    C[1] = a1 * b1 + a2 * b3
    C[2] = a1 * b2 + a2 * b4
    C[3] = a3 * b1 + a4 * b3
    C[4] = a3 * b2 + a4 * b4
end
"""
    AxBH!(C, A, B)

Compute `C = A * B'` (`B'` is the conjugate transpose of `B`) in place, where `A`, `B` and
`C` each hold a 2x2 matrix as four elements in row-major (C) order,
i.e. `[x[1] x[2]; x[3] x[4]]`.

All inputs are read into locals before `C` is written, so `C` may alias `A` or `B`.
Bounds checks are performed unless the caller is itself inside an `@inbounds` block.
Returns the value computed for `C[4]`; the product is in `C`.
"""
Base.@propagate_inbounds function AxBH!(C, A, B)
    # `@inbounds` written in front of a `function` definition does not reach the body;
    # `@propagate_inbounds` (which also inlines) lets a caller's `@inbounds` apply here.
    a1, a2, a3, a4 = A[1], A[2], A[3], A[4]
    # Conjugate each element of B once; the transpose is handled by the index pattern.
    bc1, bc2, bc3, bc4 = conj(B[1]), conj(B[2]), conj(B[3]), conj(B[4])
    C[1] = a1 * bc1 + a2 * bc2
    C[2] = a1 * bc3 + a2 * bc4
    C[3] = a3 * bc1 + a4 * bc2
    C[4] = a3 * bc3 + a4 * bc4
end
"""
    AHxB!(C, A, B)

Compute `C = A' * B` (`A'` is the conjugate transpose of `A`) in place, where `A`, `B` and
`C` each hold a 2x2 matrix as four elements in row-major (C) order,
i.e. `[x[1] x[2]; x[3] x[4]]`.

All inputs are read into locals before `C` is written, so `C` may alias `A` or `B`.
Bounds checks are performed unless the caller is itself inside an `@inbounds` block.
Returns the value computed for `C[4]`; the product is in `C`.
"""
Base.@propagate_inbounds function AHxB!(C, A, B)
    # `@inbounds` written in front of a `function` definition does not reach the body;
    # `@propagate_inbounds` (which also inlines) lets a caller's `@inbounds` apply here.
    # Conjugate each element of A once; the transpose is handled by the index pattern.
    ac1, ac2, ac3, ac4 = conj(A[1]), conj(A[2]), conj(A[3]), conj(A[4])
    b1, b2, b3, b4 = B[1], B[2], B[3], B[4]
    C[1] = ac1 * b1 + ac3 * b3
    C[2] = ac1 * b2 + ac3 * b4
    C[3] = ac2 * b1 + ac4 * b3
    C[4] = ac2 * b2 + ac4 * b4
end
"""
    plusAxB!(C, A, B)

Accumulate `C += A * B` in place, where `A`, `B` and `C` each hold a 2x2 matrix as four
elements in row-major (C) order, i.e. `[x[1] x[2]; x[3] x[4]]`.

`A` and `B` are read into locals before `C` is updated, so `C` may alias `A` or `B`.
Bounds checks are performed unless the caller is itself inside an `@inbounds` block.
Returns the new value of `C[4]`; the result is in `C`.
"""
Base.@propagate_inbounds function plusAxB!(C, A, B)
    # `@inbounds` written in front of a `function` definition does not reach the body;
    # `@propagate_inbounds` (which also inlines) lets a caller's `@inbounds` apply here.
    a1, a2, a3, a4 = A[1], A[2], A[3], A[4]
    b1, b2, b3, b4 = B[1], B[2], B[3], B[4]
    C[1] += a1 * b1 + a2 * b3
    C[2] += a1 * b2 + a2 * b4
    C[3] += a3 * b1 + a4 * b3
    C[4] += a3 * b2 + a4 * b4
end
"""
    plusAxBH!(C, A, B)

Accumulate `C += A * B'` (`B'` is the conjugate transpose of `B`) in place, where `A`,
`B` and `C` each hold a 2x2 matrix as four elements in row-major (C) order,
i.e. `[x[1] x[2]; x[3] x[4]]`.

`A` and `B` are read into locals before `C` is updated, so `C` may alias `A` or `B`.
Bounds checks are performed unless the caller is itself inside an `@inbounds` block.
Returns the new value of `C[4]`; the result is in `C`.
"""
Base.@propagate_inbounds function plusAxBH!(C, A, B)
    # `@inbounds` written in front of a `function` definition does not reach the body;
    # `@propagate_inbounds` (which also inlines) lets a caller's `@inbounds` apply here.
    a1, a2, a3, a4 = A[1], A[2], A[3], A[4]
    # Conjugate each element of B once; the transpose is handled by the index pattern.
    bc1, bc2, bc3, bc4 = conj(B[1]), conj(B[2]), conj(B[3]), conj(B[4])
    C[1] += a1 * bc1 + a2 * bc2
    C[2] += a1 * bc3 + a2 * bc4
    C[3] += a3 * bc1 + a4 * bc2
    C[4] += a3 * bc3 + a4 * bc4
end
"""
    plusAHxB!(C, A, B)

Accumulate `C += A' * B` (`A'` is the conjugate transpose of `A`) in place, where `A`,
`B` and `C` each hold a 2x2 matrix as four elements in row-major (C) order,
i.e. `[x[1] x[2]; x[3] x[4]]`.

`A` and `B` are read into locals before `C` is updated, so `C` may alias `A` or `B`.
Bounds checks are performed unless the caller is itself inside an `@inbounds` block.
Returns the new value of `C[4]`; the result is in `C`.
"""
Base.@propagate_inbounds function plusAHxB!(C, A, B)
    # `@inbounds` written in front of a `function` definition does not reach the body;
    # `@propagate_inbounds` (which also inlines) lets a caller's `@inbounds` apply here.
    # Conjugate each element of A once; the transpose is handled by the index pattern.
    ac1, ac2, ac3, ac4 = conj(A[1]), conj(A[2]), conj(A[3]), conj(A[4])
    b1, b2, b3, b4 = B[1], B[2], B[3], B[4]
    C[1] += ac1 * b1 + ac3 * b3
    C[2] += ac1 * b2 + ac3 * b4
    C[3] += ac2 * b1 + ac4 * b3
    C[4] += ac2 * b2 + ac4 * b4
end
"""
    AdivB!(C, A, B)

Compute `C = A / B` (i.e. `A * inv(B)`) in place, where `A`, `B` and `C` each hold a 2x2
matrix as four elements in row-major (C) order, i.e. `[x[1] x[2]; x[3] x[4]]`. Uses the
closed-form 2x2 inverse of `B`.

All inputs are read into locals before `C` is written, so `C` may alias `A` or `B`
(e.g. `AdivB!(A, A, B)`). Returns the value computed for `C[4]`; the quotient is in `C`.

# Throws
- `SingularException` if the determinant of `B` is exactly zero.
"""
@inline function AdivB!(C, A, B)
    b1, b2, b3, b4 = B[1], B[2], B[3], B[4]
    # Determinant of B. Only an exact zero is rejected; near-singular B is not detected.
    f = b1 * b4 - b2 * b3
    if f == 0
        throw(SingularException(0))
    end
    a1, a2, a3, a4 = A[1], A[2], A[3], A[4]
    C[1] = (a1 * b4 - a2 * b3) / f
    C[2] = (a2 * b1 - a1 * b2) / f
    C[3] = (a3 * b4 - a4 * b3) / f
    C[4] = (a4 * b1 - a3 * b2) / f
end
"""
    invA!(A)

Invert, in place, the 2x2 matrix held in the four elements of `A`, using the closed-form
inverse `[A[4] -A[2]; -A[3] A[1]] / det` (this index pattern is the same for row- and
column-major storage). Returns the value written to `A[3]`; the inverse is in `A`.

# Throws
- `SingularException` if the determinant is exactly zero.
"""
@inline function invA!(A)
    p, q, r, s = A[1], A[2], A[3], A[4]
    d = p * s - q * r
    d == 0 && throw(SingularException(0))
    # Diagonal entries swap places; off-diagonal entries change sign.
    A[1] = s / d
    A[4] = p / d
    A[2] = -q / d
    A[3] = -r / d
end