Skip to content

Commit

Permalink
Add batch
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed May 27, 2019
1 parent c474e71 commit a9640a7
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ Kaleido
## Setting/getting multiple locations

```@docs
Kaleido.@batchlens
Kaleido.batch
Kaleido.MultiLens
Kaleido.PropertyBatchLens
Kaleido.KeyBatchLens
Expand Down
5 changes: 5 additions & 0 deletions src/Kaleido.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,29 @@ end ->
module Kaleido

export
@batchlens,
BijectionLens,
FlatLens,
IndexBatchLens,
KeyBatchLens,
MultiLens,
PropertyBatchLens,
batch,
toℝ₊,
toℝ₋,
to𝕀

using Setfield
using Setfield: ComposedLens, IdentityLens, PropertyLens
using Requires

include("base.jl")
include("lensutils.jl")
include("batchsetters.jl")
include("batchlenses.jl")
include("multilens.jl")
include("flatlens.jl")
include("batching.jl")
include("bijection.jl")

function __init__()
Expand Down
23 changes: 23 additions & 0 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,26 @@ _map(f, xs) = _mapfoldl(f, _push, xs, ())

_cat() = ()
_cat(xs, tuples...) = (xs..., _cat(tuples...)...)

headtail(xs) = _headtail(xs...)
_headtail(x1, xs...) = (x1, xs)

newpartition(x, by) = (
key = by(x),
values = (x,),
by = by,
)

function partitionby(values::Tuple, by)
x, rest = headtail(values)
return _foldl(_partitionstep, rest, (newpartition(x, by),))
end

_partitionstep(partitions, x) =
if partitions[end].by(x) == partitions[end].key
modify(partitions, @lens _[length(partitions)].values) do values
_push(values, x)
end
else
_push(partitions, newpartition(x, partitions[end].by))
end
102 changes: 102 additions & 0 deletions src/batching.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
batch(lens₁, lens₂, ..., lensₙ) :: Lens
From ``n`` lenses, create a single lens that gets/sets ``n``-tuple in
such a way that the number of call to the constructor is minimized.
This is done by calling [`IndexBatchLens`](@ref) whenever possible.
"""
batch

leftmost(l::Lens) = l
leftmost(l::ComposedLens) = leftmost(l.outer)
rightlens(::Lens) = IdentityLens()
rightlens(l::ComposedLens) = rightlens(l.outer) l.inner

startswithproperty(::T) where {T <: Lens} = startswithproperty(T)
startswithproperty(::Type{<:Lens}) = false
startswithproperty(::Type{<:PropertyLens}) = true
startswithproperty(::Type{<:ComposedLens{LO}}) where {LO} = startswithproperty(LO)

allstartswithproperty(::T) where {T} = allstartswithproperty(T)
allstartswithproperty(::Type{Tuple{}}) = true
allstartswithproperty(::Type{T}) where {N, T <: NTuple{N, Lens}} =
startswithproperty(Base.tuple_type_head(T)) &&
allstartswithproperty(Base.tuple_type_tail(T))

propname(::PropertyLens{name}) where name = name

batch(::IdentityLens) = SingletonLens()

batch(lenses::PropertyLens...) = IndexBatchLens(propname.(lenses)...)

function batch(lenses::Lens...)
allstartswithproperty(lenses) || return MultiLens(lenses...)
partitions = partitionby(lenses, leftmost)
propnames = _map(x -> propname(x.key), partitions)
sublenses = _map(x -> rightlens.(x.values), partitions)
indexlenses = ntuple(i -> (@lens _[i]), length(partitions))
return IndexBatchLens(propnames...)
MultiLens(_compose.(indexlenses, _batch.(sublenses)))
FlatLens(_map(x -> length(x.values), partitions)...)
end
# TODO: Sort the lenses first and compute the permutation to recover
# the original order.

_batch(lenses::Tuple{Vararg{Lens}}) = batch(lenses...)


"""
@batchlens begin
lens_expression_1
lens_expression_2
...
lens_expression_n
end
# Examples
```jldoctest
julia> using Kaleido, Setfield
julia> lens = @batchlens begin
_.a.b.c
_.a.b.d
_.a.e
end;
julia> obj = (a = (b = (c = 1, d = 2), e = 3),);
julia> get(obj, lens)
(1, 2, 3)
julia> set(obj, lens, (10, 20, 30))
(a = (b = (c = 10, d = 20), e = 30),)
```
"""
macro batchlens(lenses_expression)
if !(lenses_expression isa Expr && lenses_expression.head == :block)
error("""
Macro @batchlens needs a block of lens expressions. Got:
$lenses_expression
""")
end
lnns = lenses_expression.args[1:2:end]
exprs = lenses_expression.args[2:2:end]
if all(isa.(lnns, LineNumberNode)) &&
all(isa.(exprs, Expr)) &&
length(lnns) == length(exprs)
lens_exprs = make_lens_expr.(exprs, lnns)
else
lens_exprs = [
make_lens_expr(x, __source__) for x in lenses_expression.args
if x isa Expr
]
end
return esc(Expr(:call, batch, lens_exprs...))
end

function make_lens_expr(ex::Expr, lnn::LineNumberNode)
atlens = Expr(:., Setfield, QuoteNode(Symbol("@lens")))
# Make `:(Setfield.@lens $ex)` with a proper `LineNumberNode`:
return Expr(:macrocall, atlens, lnn, ex)
end
# TODO: handle ∘
11 changes: 11 additions & 0 deletions src/flatlens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,14 @@ FlatLens(lengths::Vararg{Integer}) = FlatLens{UInt.(lengths)}()

Base.show(io::IO, ::FlatLens{lengths}) where lengths =
print_apply(io, FlatLens, Int.(lengths))


"""
SingletonLens()
Inverse of `FlatLens(1)`.
"""
struct SingletonLens <: Lens end

Setfield.get(obj, ::SingletonLens) = (obj,)
Setfield.set(::Any, ::SingletonLens, obj) = obj[1]
12 changes: 12 additions & 0 deletions src/lensutils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
normalize(l::Lens) = l
normalize(l::ComposedLens) = _compose(l.outer, l.inner)

"""
_compose(lens1, lens2)
Like `∘` but fixes the associativity to match with the default one in
Setfield.
"""
_compose(l1::Lens, l2::Lens) = l1 l2
_compose(l1::Lens, l2::ComposedLens) =
_compose(_compose(normalize(l1), l2.outer), l2.inner)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Test
"test_batchlenses.jl"
"test_flatlens.jl"
"test_multilens.jl"
"test_batching.jl"
"test_bijection.jl"
"test_transformvariables.jl"
]
Expand Down
71 changes: 71 additions & 0 deletions test/test_batching.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
module TestBatching

include("preamble.jl")
using Kaleido: SingletonLens

lens = @batchlens begin
_.a.b.c.d
_.a.b.c.e
_.a.b.f
_.a.g
end

@testset "`@batchlens` inside `@test`" begin
# Inside of `@test` macro, `LineNumberNode`s are missing. Make
# sure that `@batchlens` can handle that:
@test (@batchlens begin
_.a.b.c.d
_.a.b.c.e
_.a.b.f
_.a.g
end) == lens
end

@testset "`batch`ed lenses" begin
@test (@batchlens begin
_.a
_.b
end) ==
IndexBatchLens(:a, :b)

@test (@batchlens begin
_.a.b
_.a.c
end) ==
IndexBatchLens(:a) MultiLens((
(@lens _[1]) IndexBatchLens(:b, :c),
)) FlatLens(2)

global lens2, lens3
lens2 = (@batchlens begin
_.a.b.c
_.a.b.d
_.a.e
end)
lens3 = IndexBatchLens(:a) MultiLens((
(@lens _[1]) IndexBatchLens(:b, :e) MultiLens((
(@lens _[1]) IndexBatchLens(:c, :d),
(@lens _[2]) Kaleido.SingletonLens(),
)) FlatLens(2, 1),
)) FlatLens(3)
@test lens2 ==
lens3

@test (@batchlens begin
_.a.b.c.d
_.a.b.c.e
_.a.b.f
_.a.g
end) ==
IndexBatchLens(:a) MultiLens((
(@lens _[1]) IndexBatchLens(:b, :g) MultiLens((
(@lens _[1]) IndexBatchLens(:c, :f) MultiLens((
(@lens _[1]) IndexBatchLens(:d, :e),
(@lens _[2]) SingletonLens(),
)) FlatLens(2, 1),
(@lens _[2]) SingletonLens(),
)) FlatLens(3, 1),
)) FlatLens(4)
end

end # module

0 comments on commit a9640a7

Please sign in to comment.