Skip to content

Commit

Permalink
fix broadcast_dims with groupby (#684)
Browse files Browse the repository at this point in the history
* fix broadcast_dims after groupby

* use OpaqueArray
  • Loading branch information
rafaqz committed Apr 4, 2024
1 parent c09557d commit b69b48f
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 11 deletions.
6 changes: 6 additions & 0 deletions src/Dimensions/format.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ function format(dims::Tuple{<:Pair,Vararg{Pair}}, A::AbstractArray)
end
return format(dims, A)
end
# Make a dummy array that assumes the dims are the correct length and don't hold `Colon`s.
# `CartesianIndices` over the parent of each dimension's first axis is used as that
# dummy array, so the two-argument `format(dims, A)` method can be reused.
function format(dims::DimTuple)
    # For every dimension `d` this evaluates `parent(first(axes(d)))`
    ax = map(parent ∘ first ∘ axes, dims)
    A = CartesianIndices(ax)
    return format(dims, A)
end
# The number of dims equals the array rank `N`: format against the axes of `A`.
function format(dims::Tuple{Vararg{Any,N}}, A::AbstractArray{<:Any,N}) where N
    return format(dims, axes(A))
end
# Method for any dim count `M` and array rank `N`: always throws `DimensionMismatch`,
# reporting both counts and the base types of the dims.
@noinline function format(
    dims::Tuple{Vararg{Any,M}}, A::AbstractArray{<:Any,N}
) where {N,M}
    msg = "Array A has $N axes, while the number of dims is $M: $(map(basetypeof, dims))"
    throw(DimensionMismatch(msg))
end
Expand Down
14 changes: 11 additions & 3 deletions src/dimindices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,20 @@ struct DimSlices{T,N,D<:Tuple{Vararg{Dimension}},P} <: AbstractDimArrayGenerator
end
# Keyword form: forwards `dims` and `drop` to the method taking `dims` positionally.
function DimSlices(x; dims, drop=true)
    return DimSlices(x, dims; drop=drop)
end
# Build a `DimSlices` over `x` for the given `dims`.
#
# - `x`: the object to slice; it is indexed with `view` and queried with `axes`/`DD.dims`.
# - `dims`: the dimensions to slice over. When empty, every dimension of `x` is used,
#   each rebuilt to hold `:`.
# - `drop`: accepted as a keyword, but not used in this method body.
#
# Returns `DimSlices{T,N,D,P}(x, newdims)`, where `T` is the type of a single view of `x`.
function DimSlices(x, dims; drop=true)
    newdims = if length(dims) == 0
        map(d -> rebuild(d, :), DD.dims(x))
    else
        dims
    end
    # Index the first position along each base dimension to obtain a sample view
    inds = map(basedims(newdims)) do d
        rebuild(d, first(axes(x, d)))
    end
    # `getindex` returns these views
    T = typeof(view(x, inds...))
    N = length(newdims)
    D = typeof(newdims)
    P = typeof(x)
    return DimSlices{T,N,D,P}(x, newdims)
end

rebuild(ds::A; dims) where {A<:DimSlices{T,N}} where {T,N} =
Expand Down
22 changes: 17 additions & 5 deletions src/groupby.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ end
rebuild(A, data, dims, refdims, name, metadata) # Rebuild as a reguilar DimArray
end

# Print a one-line summary of a `DimGroupByArray` whose elements are arrays:
# first the size via `print_ndims`, then the wrapper type name together with the
# element array's type name and its parameters, in the form `Name{Elem{T1,N1},N}`.
function Base.summary(
    io::IO, A::DimGroupByArray{T,N}
) where {T<:AbstractArray{T1,N1},N} where {T1,N1}
    print_ndims(io, size(A))
    # Print straight to `io` rather than building an intermediate `String`
    print(io, nameof(typeof(A)), "{$(nameof(T)){$T1,$N1},$N}")
end

function show_after(io::IO, mime, A::DimGroupByArray)
Expand Down Expand Up @@ -80,6 +80,17 @@ function Base.show(io::IO, s::DimSummariser)
end
Base.alignment(io::IO, s::DimSummariser) = (textwidth(sprint(show, s)), 0)

# A wrapper array that doesn't know what it holds, to simplify dispatch:
# it is only an `AbstractArray{T,N}` and forwards everything to its parent.
struct OpaqueArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    parent::A
end

# Access to the wrapped array and its size
function Base.parent(wrapped::OpaqueArray)
    return wrapped.parent
end
function Base.size(wrapped::OpaqueArray)
    return size(parent(wrapped))
end

# Indexing, views and dot-views are all delegated to the wrapped array
function Base.getindex(wrapped::OpaqueArray, args...)
    return Base.getindex(parent(wrapped), args...)
end
function Base.view(wrapped::OpaqueArray, args...)
    return Base.view(parent(wrapped), args...)
end
function Base.dotview(wrapped::OpaqueArray, args...)
    return Base.dotview(parent(wrapped), args...)
end

# Writes go to the wrapped array as well
function Base.setindex!(wrapped::OpaqueArray, args...)
    return Base.setindex!(parent(wrapped), args...)
end


# Abstract supertype for bin definitions; it is itself a subtype of `Function`.
# NOTE(review): presumably concrete subtypes such as `Bins` are meant to be callable — confirm.
abstract type AbstractBins <: Function end

Expand Down Expand Up @@ -331,9 +342,11 @@ function DataAPI.groupby(A::DimArrayOrStack, dimfuncs::DimTuple)
end
# Separate lookups dims from indices
group_dims = map(first, dim_groups_indices)
indices = map(rebuild, dimfuncs, map(last, dim_groups_indices))
# Get indices for each group wrapped with dims for indexing
indices = map(rebuild, group_dims, map(last, dim_groups_indices))

views = DimSlices(A, indices)
# Hide that the parent is a DimSlices
views = OpaqueArray(DimSlices(A, indices))
# Put the groupby query in metadata
meta = map(d -> dim2key(d) => val(d), dimfuncs)
metadata = Dict{Symbol,Any}(:groupby => length(meta) == 1 ? only(meta) : meta)
Expand Down Expand Up @@ -394,7 +407,6 @@ function _group_indices(dim::Dimension, bins::AbstractBins; labels=bins.labels)
return _group_indices(transformed_lookup, group_lookup; labels)
end


# Get a vector of intervals for the bins
# `Bins` whose second type parameter is an `AbstractArray`: return the `bins` field as is.
# The first argument is ignored.
function _groups_from(_, bins::Bins{<:Any,<:AbstractArray})
    return bins.bins
end
function _groups_from(transformed, bins::Bins{<:Any,<:Integer})
Expand Down
20 changes: 17 additions & 3 deletions test/groupby.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,22 @@ end
end
end
@test all(collect(mean.(gb)) .=== manualmeans)
@test all(
mean.(gb) .=== manualmeans
)
@test all(mean.(gb) .=== manualmeans)
end

@testset "broadcastdims runs after groupby" begin
    # Smoke test: there are no `@test` assertions, the body only has to run without error.
    time_dim = Ti(Date("2021-12-01"):Day(1):Date("2022-12-31"))
    x_dim = X(range(1, 10, length=10))
    y_dim = Y(range(1, 5, length=15))
    var_dim = Dim{:Variable}(["var1", "var2"])
    values = rand(396, 10, 15, 2)
    A = DimArray(values, (time_dim, x_dim, y_dim, var_dim))

    # Group `month_length` by season, then divide each group by its sum over `Ti`
    month_length = DimArray(daysinmonth, dims(A, Ti))
    g_tempo = DimensionalData.groupby(month_length, Ti=>seasons(; start=December))
    sum_days = sum.(g_tempo, dims=Ti)
    weights = map(./, g_tempo, sum_days)

    # Group the data the same way and multiply each group by its weights
    G = DimensionalData.groupby(A, Ti=>seasons(; start=December))
    G_w = broadcast_dims.(*, weights, G)
end

0 comments on commit b69b48f

Please sign in to comment.