Skip to content

Commit

Permalink
Store the array length next to its dimensions. (JuliaGPU#1303)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored and simonbyrne committed Nov 13, 2023
1 parent 03b01b1 commit 8381dea
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/device/array.jl
Expand Up @@ -28,12 +28,13 @@ struct CuDeviceArray{T,N,A} <: DenseArray{T,N}
maxsize::Int

dims::Dims{N}
len::Int

# inner constructors, fully parameterized, exact types (ie. Int not <:Integer)
# TODO: deprecate; put `ptr` first like CuArray
CuDeviceArray{T,N,A}(dims::Dims{N}, ptr::LLVMPtr{T,A},
maxsize::Int=prod(dims)*sizeof(T)) where {T,A,N} =
new(ptr, maxsize, dims)
new(ptr, maxsize, dims, prod(dims))
end

const CuDeviceVector = CuDeviceArray{T,1,A} where {T,A}
Expand Down Expand Up @@ -67,11 +68,13 @@ CuDeviceMatrix{T,A}(m::Integer, n::Integer, p::LLVMPtr{T,A}) where {T,A} = C
## array interface

Base.elsize(::Type{<:CuDeviceArray{T}}) where {T} = sizeof(T)
Base.size(g::CuDeviceArray) = g.dims
Base.length(g::CuDeviceArray) = prod(g.dims)

Base.size(g::CuDeviceArray) = g.dims
Base.sizeof(x::CuDeviceArray) = Base.elsize(x) * length(x)

# we store the array length too; computing prod(size) is expensive
Base.length(g::CuDeviceArray) = g.len

Base.pointer(x::CuDeviceArray{T,<:Any,A}) where {T,A} = Base.unsafe_convert(LLVMPtr{T,A}, x)
@inline function Base.pointer(x::CuDeviceArray{T,<:Any,A}, i::Integer) where {T,A}
Base.unsafe_convert(LLVMPtr{T,A}, x) + Base._memory_offset(x, i)
Expand Down

0 comments on commit 8381dea

Please sign in to comment.