Skip to content

Commit

Permalink
make grids lazy (#37)
Browse files Browse the repository at this point in the history
Make grids iterable, fix types, incidental changes.
  • Loading branch information
tpapp committed Sep 20, 2022
1 parent c8d6c19 commit 154e052
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 54 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SpectralKit"
uuid = "5c252ae7-b5b6-46ab-a016-b0e3d78320b7"
authors = ["Tamas K. Papp <tkpapp@gmail.com>"]
version = "0.9.2"
version = "0.10.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ ct = coordinate_transformations(BoundedLinear(-1, 2.0), SemiInfRational(-3.0, 3.
basis = smolyak_basis(Chebyshev, InteriorGrid2(), SmolyakParameters(3), 2)
x = grid(basis)
θ = collocation_matrix(basis) \ f2.(from_pm1.(ct, x)) # find the coefficients
z = SVector(0.5, 0.7) # evaluate at this point
z = (0.5, 0.7) # evaluate at this point
isapprox(f2(z), linear_combination(basis, θ, to_pm1(ct, z)), rtol = 0.005)
```

Expand Down Expand Up @@ -96,7 +96,7 @@ Chebyshev

### Multivariate bases

Multivariate bases operate on vectors. `StaticArrays.SVector` is preferred for performance, but all `<:AbstractVector` types should work.
Multivariate bases operate on tuples or vectors (`StaticArrays.SVector` is preferred for performance, but all `<:AbstractVector` types should work).

```@docs
SmolyakParameters
Expand Down
40 changes: 35 additions & 5 deletions src/chebyshev.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,58 @@ end
#### grids
####

"""
$(SIGNATURES)
Return a gridpoint for collocation, with `1 ≤ i ≤ dimension(basis)`.
`T` is used *as a hint* for the element type of grid coordinates, and defaults to `Float64`.
The actual type can be broadened as required. Methods are type stable.
!!! note
Not all grids have this method defined, especially if it is impractical. See
[`grid`](@ref), which is part of the API, this function isn't.
"""
function gridpoint(::Type{T}, basis::Chebyshev{InteriorGrid}, i::Integer) where {T <: Real}
@unpack N = basis
@argcheck 1 i N # FIXME use boundscheck
sinpi((N - 2 * i + 1) / T(2 * N)) # use formula Xu (2016)
@argcheck 1 i N # FIXME use boundscheck
sinpi((N - 2 * i + 1) / T(2 * N))::T # use formula Xu (2016)
end

function gridpoint(::Type{T}, basis::Chebyshev{EndpointGrid}, i::Integer) where {T <: Real}
@unpack N = basis
@argcheck 1 i N # FIXME use boundscheck
if N == 1
cospi(1/T(2)) # 0.0 as a fallback, even though it does not have endpoints
cospi(1/T(2))::T # 0.0 as a fallback, even though it does not have endpoints
else
cospi((N - i) ./ T(N - 1))
cospi((N - i) ./ T(N - 1))::T
end
end

function gridpoint(::Type{T}, basis::Chebyshev{InteriorGrid2}, i::Integer) where {T <: Real}
@unpack N = basis
@argcheck 1 i N # FIXME use boundscheck
cospi(((N - i + 1) ./ T(N + 1)))
cospi(((N - i + 1) ./ T(N + 1)))::T
end

struct ChebyshevGridIterator{T,B}
basis::B
end

Base.eltype(::Type{<:ChebyshevGridIterator{T}}) where {T} = T

Base.length(itr::ChebyshevGridIterator) = dimension(itr.basis)

function Base.iterate(itr::ChebyshevGridIterator{T}, i = 1) where {T}
@unpack basis = itr
if i dimension(basis)
gridpoint(T, basis, i), i + 1
else
nothing
end
end

grid(::Type{T}, basis::B) where {T<:Real,B<:Chebyshev} = ChebyshevGridIterator{T,B}(basis)

####
#### augmenting
Expand Down
27 changes: 6 additions & 21 deletions src/generic_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ function dimension end
Return an iterable with known element type and length (`Base.HasEltype()`,
`Base.HasLength()`) of basis functions in `basis` evaluated at `x`.
Univariate bases operate on real numbers, while for multivariate bases,
`StaticArrays.SVector` is preferred for performance, though all `<:AbstractVector` types
Univariate bases operate on real numbers, while for multivariate bases, `Tuple`s or
`StaticArrays.SVector` are preferred for performance, though all `<:AbstractVector` types
should work.
Methods are type stable.
Expand Down Expand Up @@ -141,34 +141,19 @@ Equivalent to an `EndpointGrid` with endpoints dropped.
"""
struct InteriorGrid2 <: AbstractGrid end

"""
$(SIGNATURES)
Return a gridpoint for collocation, with `1 ≤ i ≤ dimension(basis)`.
`T` is used *as a hint* for the element type of grid coordinates, and defaults to `Float64`.
The actual type can be broadened as required. Methods are type stable.
!!! note
Not all grids have this method defined, especially if it is impractical. See
[`grid`](@ref), which is part of the API, this function isn't.
"""
gridpoint(basis, i) = gridpoint(Float64, basis, i)

"""
`$(FUNCTIONNAME)([T], basis)`
Return a grid recommended for collocation, with `dimension(basis)` elements.
Return an iterator for the grid recommended for collocation, with `dimension(basis)`
elements.
`T` is used *as a hint* for the element type of grid coordinates, and defaults to `Float64`.
The actual type can be broadened as required. Methods are type stable.
`T` for the element type of grid coordinates, and defaults to `Float64`.
Methods are type stable.
"""
grid(basis) = grid(Float64, basis)

function grid(::Type{T}, basis) where {T<:Real}
map(i -> gridpoint(T, basis, i), 1:dimension(basis))
end

"""
$(SIGNATURES)
Expand Down
23 changes: 20 additions & 3 deletions src/smolyak_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,29 @@ function basis_at(smolyak_basis::SmolyakBasis{<:SmolyakIndices{N,H}}, x) where {
SmolyakProduct(smolyak_indices, univariate_bases_at)
end

struct SmolyakGridIterator{T,I,S}
smolyak_indices::I
sources::S
end

Base.eltype(::Type{<:SmolyakGridIterator{T}}) where {T} = T

Base.length(itr::SmolyakGridIterator) = length(itr.smolyak_indices)

function grid(::Type{T},
smolyak_basis::SmolyakBasis{<:SmolyakIndices{N,H}}) where {T<:Real,N,H}
@unpack smolyak_indices, univariate_parent = smolyak_basis
x = sacollect(SVector{H}, gridpoint(T, univariate_parent, i)
for i in SmolyakGridShuffle(univariate_parent.grid_kind, H))
[SVector{N}(map(i -> x[i], ι)) for ι in smolyak_indices]
sources = sacollect(SVector{H}, gridpoint(T, univariate_parent, i)
for i in SmolyakGridShuffle(univariate_parent.grid_kind, H))
SmolyakGridIterator{NTuple{N,T},typeof(smolyak_indices),typeof(sources)}(smolyak_indices, sources)
end

function Base.iterate(itr::SmolyakGridIterator, state...)
@unpack smolyak_indices, sources = itr
result = iterate(smolyak_indices, state...)
result nothing && return nothing
ι, state′ = result
map(i -> sources[i], ι), state′
end

"""
Expand Down
24 changes: 11 additions & 13 deletions src/transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,11 @@ coordinate transformations
(0.0,2.0) ↔ (-1, 1) [linear transformation]
(2,∞) ↔ (-1, 1) [rational transformation with scale 3]
julia> x = from_pm1(ct, SVector(0.4, 0.5))
2-element SVector{2, Float64} with indices SOneTo(2):
1.4
11.0
julia> x = from_pm1(ct, (0.4, 0.5))
(1.4, 11.0)
julia> y = to_pm1(ct, x)
2-element SVector{2, Float64} with indices SOneTo(2):
0.3999999999999999
0.5
(0.3999999999999999, 0.5)
```
"""
function coordinate_transformations(transformations::Tuple)
Expand All @@ -108,20 +104,22 @@ end

coordinate_transformations(transformations...) = coordinate_transformations(transformations)

function to_pm1(ct::CoordinateTransformations, x::SVector{N}) where N
SVector{N}(map((t, x) -> to_pm1(t, x), ct.transformations, Tuple(x)))
function to_pm1(ct::CoordinateTransformations, x::Tuple)
@argcheck length(ct.transformations) == length(x)
map((t, x) -> to_pm1(t, x), ct.transformations, x)
end

function to_pm1(ct::CoordinateTransformations, x::AbstractVector)
to_pm1(ct, SVector{length(ct.transformations)}(x))
to_pm1(ct, Tuple(x))
end

function from_pm1(ct::CoordinateTransformations, x::SVector{N}) where N
SVector{N}(map((t, x) -> from_pm1(t, x), ct.transformations, Tuple(x)))
function from_pm1(ct::CoordinateTransformations, x::Tuple)
@argcheck length(ct.transformations) == length(x)
map((t, x) -> from_pm1(t, x), ct.transformations, x)
end

function from_pm1(ct::CoordinateTransformations, x::AbstractVector)
from_pm1(ct, SVector{length(ct.transformations)}(x))
from_pm1(ct, Tuple(x))
end

####
Expand Down
2 changes: 1 addition & 1 deletion test/test_chebyshev.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
end

# check grid
g = @inferred grid(basis)
g = @inferred collect(grid(basis))
@test length(g) == N
if grid_kind InteriorGrid()
@test all(x -> isapprox(chebyshev_cos(x, N + 1), 0, atol = 1e-14), g)
Expand Down
9 changes: 8 additions & 1 deletion test/test_generic_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@
@test !is_function_basis("a fish")
end

@testset "collocation matrix with default grid" begin
basis = Chebyshev(EndpointGrid(), 10) # basis for approximation
@test collocation_matrix(basis) == collocation_matrix(basis, collect(grid(basis)))
end

@testset "non-square collocation matrix" begin
f = exp # function for comparison
basis = Chebyshev(EndpointGrid(), 10) # basis for approximation
x = grid(Chebyshev(EndpointGrid(), 20)) # denser grid for approximation
g = grid(Chebyshev(EndpointGrid(), 20))
iterator_sanity_checks(g)
x = @inferred collect(g) # denser grid for approximation
C = collocation_matrix(basis, x)
@test all(isfinite, C)
@test size(C) == (20, 10)
Expand Down
6 changes: 4 additions & 2 deletions test/test_smolyak.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ end
f(x) = (x[1] - 3) * (x[2] + 5) # linear function, just a sanity check
basis = smolyak_basis(Chebyshev, InteriorGrid(), SmolyakParameters(3), 2)
@test @inferred(domain(basis)) == ((-1, 1), (-1, 1))
x = @inferred grid(Float64, basis)
g = grid(Float64, basis)
iterator_sanity_checks(g)
x = @inferred collect(g)
M = @inferred collocation_matrix(basis, x)
θ = M \ f.(x)
@test sum(abs.(θ) .> 1e-8) == 4
Expand Down Expand Up @@ -187,7 +189,7 @@ end
for B2 in (B1 + 1):M2
basis1 = smolyak_basis(Chebyshev, grid_kind, SmolyakParameters(B1, M1), 2)
basis2 = smolyak_basis(Chebyshev, grid_kind, SmolyakParameters(B2, M2), 2)
@test is_approximate_subset(grid(basis1), grid(basis2))
@test is_approximate_subset(collect(grid(basis1)), collect(grid(basis2)))
end
end
end
Expand Down
10 changes: 5 additions & 5 deletions test/test_transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ end
ct = coordinate_transformations(t1, t2)
x = SVector(rand_pm1(), rand_pm1())
y = @inferred from_pm1(ct, x)
@test y isa SVector{2}
@test y == SVector(from_pm1(t1, x[1]), from_pm1(t2, x[2]))
@test y isa NTuple{2,Float64}
@test y == (from_pm1(t1, x[1]), from_pm1(t2, x[2]))

# handle generic inputs
y2 = @inferred from_pm1(ct, Vector(x))
@test y2 isa SVector{2,Float64} && y2 == y
@test y2 isa NTuple{2,Float64} && y2 == y

x2 = @inferred to_pm1(ct, Vector(y))
@test x2 isa SVector{2,Float64} && x2 x
x2 = @inferred to_pm1(ct, [y...])
@test x2 isa NTuple{2,Float64} && all(x2 . x)
end

@testset "partial application" begin
Expand Down
1 change: 1 addition & 0 deletions test/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ end
function is_approximately_in(a, b; atol = eps())
_same(a::Real, b::Real) = a == b || abs(a - b) atol # Inf = Inf, etc
_same(a::AbstractVector, b::AbstractVector) = all(_same.(a, b))
_same(a::Tuple, b::Tuple) = mapreduce((x, y) -> abs(x - y), max, a, b) atol
map(a -> any(b -> _same(a, b), b), a)
end

Expand Down

2 comments on commit 154e052

@tpapp
Copy link
Owner Author

@tpapp tpapp commented on 154e052 Sep 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/68652

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.0 -m "<description of version>" 154e0523dc3fc094e32d22aa900fd70c83e0b594
git push origin v0.10.0

Please sign in to comment.