/
blockspectralmatrix.jl
184 lines (149 loc) Β· 5.95 KB
/
blockspectralmatrix.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# A contiguous array (`parent`) with the ℓ range of every block attached, so that the
# block layout and ℓ labels travel together with the concatenated data.
#
# Type parameters:
# - `T`: element type
# - `M_BLOCKS`: number of blocks along the first dimension (block-rows)
# - `N_BLOCKS`: number of blocks along the second dimension (block-columns)
# - `AA`: type of the wrapped parent array. NOTE(review): this can be one-dimensional
#   (e.g. when built from SpectralVectors), even though the type is declared 2-D.
struct BlockSpectralMatrix{T,M_BLOCKS,N_BLOCKS,AA} <: AbstractArray{T,2}
parent::AA                                   # concatenated data of all blocks
m_ells::NTuple{M_BLOCKS, UnitRange{Int64}}   # ℓ range of each block-row
n_ells::NTuple{N_BLOCKS, UnitRange{Int64}}   # ℓ range of each block-column
end
# Forward the listed Base functions (indexing, size, iteration, display, aliasing
# queries) to the wrapped `parent` array, so e.g. `size` reports the parent's size.
@forward BlockSpectralMatrix.parent (Base.getindex, Base.setindex, Base.setindex!,
Base.size, Base.iterate, Base.axes, Base.show, Base.in, Base.strides, Base.elsize,
Base.dataids)
# Return the wrapped (un-labelled) parent array.
Base.parent(A::BlockSpectralMatrix) = A.parent
# Concatenate spectral arrays laid out as a regular grid (e.g. `[A B; C D]`) into a
# BlockSpectralMatrix.
#
# `rows[k]` is the number of blocks in block-row `k`; every block-row must contain the
# same number of blocks. `values` lists the blocks in row-major order, so block (i, j)
# is `values[(i - 1) * n_blocks + j]`.
# - The ℓ range of block-row i is the first axis of block (i, 1).
# - The ℓ range of block-column j is the second axis of block (1, j).
#
# Throws `ArgumentError` when the block-rows have differing numbers of blocks.
function Base.hvcat(rows::Tuple{Vararg{Int}}, values::SA...) where {T, SA<:SpectralArray{T}}
    m_blocks = length(rows)
    n_blocks = first(rows)
    # A BlockSpectralMatrix describes a regular grid of blocks only.
    all(==(n_blocks), rows) || throw(ArgumentError(
        "all block-rows must contain the same number of blocks, got rows=$(rows)"))
    stacked = hvcat(rows, map(parent, values)...)
    # First block of every block-row: stride n_blocks through the row-major list.
    m_ells = map(A -> UnitRange(axes(A, 1)), values[1:n_blocks:end])
    # Every block of the first block-row.
    n_ells = map(A -> UnitRange(axes(A, 2)), values[1:n_blocks])
    return BlockSpectralMatrix{T, m_blocks, n_blocks, typeof(stacked)}(
        stacked, m_ells, n_ells)
end
# Stack spectral arrays on top of each other into a BlockSpectralMatrix with one
# block-row per argument (labelled by that argument's first axis) and a single
# block-column (labelled by the second axis of the first argument).
function Base.vcat(values::SA...) where {T, SA<:SpectralArray{T}}
    stacked = vcat(map(parent, values)...)
    nrowblocks = length(values)
    row_ells = map(block -> axes(block, 1), values)
    col_ells = (axes(values[1], 2),)
    BSM = BlockSpectralMatrix{T, nrowblocks, 1, typeof(stacked)}
    return BSM(stacked, row_ells, col_ells)
end
# Place spectral arrays side by side in a BlockSpectralMatrix with a single block-row
# (labelled by the first axis of the first argument) and one block-column per
# argument (labelled by that argument's second axis).
function Base.hcat(values::SA...) where {T, SA<:SpectralArray{T}}
    stacked = hcat(map(parent, values)...)
    ncolblocks = length(values)
    row_ells = (axes(values[1], 1),)
    col_ells = map(block -> axes(block, 2), values)
    BSM = BlockSpectralMatrix{T, 1, ncolblocks, typeof(stacked)}
    return BSM(stacked, row_ells, col_ells)
end
# Multi-line REPL display: block counts, the ℓ ranges of the blocks along each
# dimension, then a compact preview of the parent array.
function Base.show(io::IO, m::MIME"text/plain",
        ba::BlockSpectralMatrix{T,M_BLOCKS,N_BLOCKS}) where {T,M_BLOCKS,N_BLOCKS}
    println(io, "BlockSpectralMatrix (column blocks=$(M_BLOCKS), row blocks=$(N_BLOCKS))")
    print(io, "row ℓ: ")
    Base.show(io, m, ba.m_ells)
    print(io, "\ncolumn ℓ: ")
    Base.show(io, m, ba.n_ells)
    print(io, "\nparent: ")
    nrows, ncols = displaysize(io)
    # Limit the parent preview to at most 10 lines so big matrices don't flood output.
    preview_io = IOContext(io, :compact => true, :displaysize => (min(10, nrows), ncols))
    Base.show(preview_io, m, parent(ba))
end
# A LinearAlgebra factorization (e.g. from `lu`) of a block spectral matrix that keeps
# the ℓ ranges of the factorized matrix's blocks.
#
# Type parameters:
# - `T`: element type of the factorization
# - `M_BLOCKS` / `N_BLOCKS`: block counts along the first / second dimension
# - `AA`: parent array type of the matrix that was factorized (recorded in the type
#   only; no field of this type is stored)
# - `F`: type of the wrapped factorization
struct BlockSpectralFactorization{T,M_BLOCKS,N_BLOCKS,AA,
F<:Factorization{T}} <: Factorization{T}
parent::F                                    # the wrapped factorization object
m_ells::NTuple{M_BLOCKS, UnitRange{Int64}}   # ℓ range of each block-row
n_ells::NTuple{N_BLOCKS, UnitRange{Int64}}   # ℓ range of each block-column
end
# Invert a factorized block spectral matrix, returning a BlockSpectralMatrix.
#
# If the factorized matrix maps column-space ℓ ranges `n_ells` to row-space ℓ ranges
# `m_ells`, its inverse maps the other way. The inverse's block-rows therefore carry
# `F.n_ells` and its block-columns carry `F.m_ells`, with block counts swapped.
# The parent type is taken from the actual result of `inv`.
function LinearAlgebra.inv(
        F::BlockSpectralFactorization{T,M_BLOCKS,N_BLOCKS}) where {T,M_BLOCKS,N_BLOCKS}
    Finv = inv(F.parent)
    return BlockSpectralMatrix{T,N_BLOCKS,M_BLOCKS,typeof(Finv)}(Finv, F.n_ells, F.m_ells)
end
# Return the wrapped LinearAlgebra factorization.
Base.parent(F::BlockSpectralFactorization) = F.parent
# Row and column ℓ ranges of a SpectralArray, each wrapped as a one-element tuple
# (the single block-row / block-column of a 1×1 BlockSpectralMatrix).
_singleblock_ells(A) = ((UnitRange(axes(A, 1)),), (UnitRange(axes(A, 2)),))

# Wrap a 2-D SpectralArray as a 1×1 BlockSpectralMatrix sharing its parent array.
function BlockSpectralMatrix(A::SpectralArray{T,2,AA}) where {T,AA}
    row_ells, col_ells = _singleblock_ells(A)
    return BlockSpectralMatrix{T,1,1,AA}(parent(A), row_ells, col_ells)
end

# Wrap a 1-D SpectralArray as a 1×1 BlockSpectralMatrix; its column ℓ range is the
# trivial second axis of a vector.
function BlockSpectralMatrix(A::SpectralArray{T,1,AA}) where {T,AA}
    row_ells, col_ells = _singleblock_ells(A)
    return BlockSpectralMatrix{T,1,1,AA}(parent(A), row_ells, col_ells)
end
# LU-factorize a 2-D SpectralArray. The result is a single-block
# BlockSpectralFactorization that remembers the matrix's row and column ℓ ranges.
function LinearAlgebra.lu(A::SpectralArray{T,2,AA}) where {T,AA}
    factorization = lu(parent(A))
    row_ells = (UnitRange(axes(A, 1)),)
    col_ells = (UnitRange(axes(A, 2)),)
    FT = BlockSpectralFactorization{T,1,1,AA,typeof(factorization)}
    return FT(factorization, row_ells, col_ells)
end
# LU-factorize a BlockSpectralMatrix, carrying its block ℓ ranges over unchanged.
function LinearAlgebra.lu(
        A::BlockSpectralMatrix{T,M_BLOCKS,N_BLOCKS,AA}) where {T,M_BLOCKS,N_BLOCKS,AA}
    factorization = lu(parent(A))
    FT = BlockSpectralFactorization{T,M_BLOCKS,N_BLOCKS,AA,typeof(factorization)}
    return FT(factorization, A.m_ells, A.n_ells)
end
# Solve with a factorized block matrix and a single SpectralVector right-hand side.
# The solution is wrapped using the same offsets as `B`.
# NOTE(review): reusing B's offsets is only right when the factorized matrix's column
# ℓ range equals B's ℓ range — confirm for non-square / offset cases.
function (\)(F::BlockSpectralFactorization, B::SpectralVector{T}) where T
    solution = parent(F) \ parent(B)
    offsets = B.parent.offsets
    return SpectralVector(solution, offsets)
end
# Solve A X = B, where `F` is a factorization of A and B is a BlockSpectralMatrix.
#
# B's block-rows must carry the same ℓ ranges as A's block-rows. The solution X has
# one block-row per block-column of A (labelled with `F.n_ells`) and B's
# block-columns (labelled with `B.n_ells`).
#
# Throws `DimensionMismatch` when the row ℓ ranges of A and B differ.
function (\)(F::BlockSpectralFactorization{<:Any,<:Any,K_BLOCKS},
             B::BlockSpectralMatrix{<:Any,<:Any,N_BLOCKS}) where {K_BLOCKS,N_BLOCKS}
    F.m_ells == B.m_ells || throw(DimensionMismatch(
        "row ℓ ranges of the factorized matrix $(F.m_ells) do not match " *
        "row ℓ ranges of the right-hand side $(B.m_ells)"))
    x = parent(F) \ parent(B)
    return BlockSpectralMatrix{eltype(x),K_BLOCKS,N_BLOCKS,typeof(x)}(x, F.n_ells, B.n_ells)
end
# Solve A X = B for block spectral matrices via an LU factorization of A.
#
# B's block-rows must carry the same ℓ ranges as A's block-rows. Every block is checked,
# not just the first. The solution's block-rows are labelled with A's column ℓ ranges
# (`A.n_ells`) and its block-columns with B's (`B.n_ells`).
#
# Throws `DimensionMismatch` when the row ℓ ranges of A and B differ.
function (\)(A::BlockSpectralMatrix{<:Any,<:Any,K_BLOCKS},
             B::BlockSpectralMatrix{<:Any,<:Any,N_BLOCKS}) where {K_BLOCKS,N_BLOCKS}
    A.m_ells == B.m_ells || throw(DimensionMismatch(
        "row ℓ ranges of A $(A.m_ells) do not match row ℓ ranges of B $(B.m_ells)"))
    x = lu(parent(A)) \ parent(B)
    return BlockSpectralMatrix{eltype(x),K_BLOCKS,N_BLOCKS,typeof(x)}(x, A.n_ells, B.n_ells)
end
# Solve A x = B for a spectral matrix and a spectral vector via LU.
#
# B must be indexed by A's row ℓ range. The solution is indexed by A's column ℓ range.
#
# Throws `DimensionMismatch` when B's ℓ range differs from A's row ℓ range.
function (\)(A::SpectralArray{T,2}, B::SpectralVector{T}) where T
    axes(A, 1) == axes(B, 1) || throw(DimensionMismatch(
        "row ℓ range of A $(UnitRange(axes(A, 1))) does not match " *
        "ℓ range of B $(UnitRange(axes(B, 1)))"))
    x = lu(parent(A)) \ parent(B)
    return SpectralVector(x, UnitRange(axes(A, 2)))
end
"""
    getblock(A::BlockSpectralMatrix{T,M_BLOCKS,1}, i) where {T,M_BLOCKS}

Extract the `i`-th sub-block of a block vector, i.e. a `BlockSpectralMatrix` with a
single block-column, as a `SpectralVector` indexed by that block's ℓ range.

# Arguments:
- `A::BlockSpectralMatrix{T,M_BLOCKS,1}`: block vector to extract from
- `i::Int`: index of the sub-block (1-indexed, `1 ≤ i ≤ M_BLOCKS`)

# Returns:
- `SpectralVector`: a copy of the entries of block `i`, indexed by `A.m_ells[i]`

# Throws:
- `BoundsError` if `i` is not in `1:M_BLOCKS`
"""
function getblock(A::BlockSpectralMatrix{T,M_BLOCKS,1},
                  i::Int) where {T,M_BLOCKS}
    1 ≤ i ≤ M_BLOCKS || throw(BoundsError(A.m_ells, i))
    # Blocks are stored back to back: block i starts after blocks 1:(i-1).
    offset = mapreduce(length, +, A.m_ells[1:(i - 1)]; init=0)
    block_range = (offset + 1):(offset + length(A.m_ells[i]))
    # NOTE(review): linear indexing assumes the parent holds a stacked vector; a
    # multi-column parent would be read in column-major order.
    return SpectralVector(parent(A)[block_range], A.m_ells[i])
end
# based on UnPack's macro
@doc raw"""
    @spectra a, b, ... = A

Unpack a block vector. The `k`-th name on the left is assigned `getblock(A, k)`, which
is equivalent to calling `getblock` for every sub-block. The right-hand side is
evaluated exactly once, and the expression returns it, as a normal `=` does.

# Example
```julia
# compute stacked EE,BB mode-coupling matrix from mask alm
M_EE_BB = mcm(:EE_BB, map2alm(mask1), map2alm(mask2))
# apply the 2×2 block mode-coupling matrix to the stacked EE and BB spectra
@spectra Cl_EE, Cl_BB = M_EE_BB \ [pCl_EE; pCl_BB]
```
"""
macro spectra(args)
    Meta.isexpr(args, :(=)) || error("Expression needs to be of form `a, b = c`")
    items, suitcase = args.args
    items = isa(items, Symbol) ? [items] : items.args
    suitcase_instance = gensym()
    # Refer to getblock through a GlobalRef to the module defining this macro, so the
    # expansion works even if the caller has not bound that module's name.
    getblock_ref = GlobalRef(@__MODULE__, :getblock)
    kd = [:( $key = $getblock_ref($suitcase_instance, $i) ) for
          (i, key) in enumerate(items)]
    kdblock = Expr(:block, kd...)
    expr = quote
        local $suitcase_instance = $suitcase # handles if suitcase is not a variable but an expression
        $kdblock
        $suitcase_instance # return RHS of `=` as standard in Julia
    end
    esc(expr)
end