Skip to content

Commit

Permalink
add bounds checking for get/setindex
Browse files Browse the repository at this point in the history
  • Loading branch information
jakebolewski committed May 27, 2015
1 parent 2585bab commit 1149823
Showing 1 changed file with 35 additions and 15 deletions.
50 changes: 35 additions & 15 deletions src/pyarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,31 +212,43 @@ Base.summary{T}(a::PyArray{T}) = string(Base.dims2string(size(a)), " ",
#TODO: is this correct for all buffer types other than contig/dense?
#TODO: get rid of this, should be copy! but copy! uses similar under the hood
function Base.copy{T,N}(a::PyArray{T,N})
if N > 1 && a.c_contig # equivalent to f_contig with reversed dims
B = pointer_to_array(a.data, ntuple(N, n -> a.dims[N - n + 1]))
return N == 2 ? transpose(B) : permutedims(B, (N:-1:1))
if N > 1 && a.c_contig i
# equivalent to f_contig with reversed dims
B = pointer_to_array(a.data, (Int[a.dims[N - d + 1] for d in 1:N]...))
if N == 2
return transpose(B)
else
return permutedims(B, N:-1:1)
end
end
A = Array(T, a.dims)
if a.f_contig
ccall(:memcpy, Void, (Ptr{T}, Ptr{T}, Int), A, a, sizeof(T) * length(a))
return A
else
return copy!(A, a)
end
return copy!(A, a)
end

#TODO: Bounds checking is needed
Base.getindex{T}(a::PyArray{T,0}) = unsafe_load(a.data)

Base.getindex{T}(a::PyArray{T,1}, i::Integer) =
function Base.getindex{T}(a::PyArray{T,1}, i::Integer)
1 <= i <= length(a) || throw(BoundsError())
unsafe_load(a.data, 1 + (i-1) * a.strides[1])
end

Base.getindex{T}(a::PyArray{T,2}, i::Integer, j::Integer) =
function Base.getindex{T}(a::PyArray{T,2}, i::Integer, j::Integer)
1 <= i <= size(a,1) || throw(BoundsError())
1 <= j <= size(a,2) || throw(BoundsError())
unsafe_load(a.data, 1 + (i-1) * a.strides[1] + (j-1) * a.strides[2])
end

Base.getindex(a::PyArray, i::Integer) =
a.f_contig ? unsafe_load(a.data, i) :
getindex(a, ind2sub(a.dims, i)...)
function Base.getindex(a::PyArray, i::Integer)
if a.f_contig
1 <= i <= length(a) || throw(BoundsError())
return unsafe_load(a.data, i)
end
return getindex(a, ind2sub(a.dims, i)...)
end

function Base.getindex(a::PyArray, is::Integer...)
index = 1
Expand Down Expand Up @@ -265,14 +277,22 @@ end

Base.setindex!{T}(a::PyArray{T,0}, v) = (unsafe_store!(pointer(a), v, 1); v)

Base.setindex!{T}(a::PyArray{T,1}, v, i::Integer) =
(unsafe_store!(pointer(a), v, 1 + (i-1) * a.strides[1]); v)
function Base.setindex!{T}(a::PyArray{T,1}, v, i::Integer)
1 <= i <= length(a) || throw(BoundsError())
unsafe_store!(pointer(a), v, 1 + (i-1) * a.strides[1])
return v
end

Base.setindex!{T}(a::PyArray{T,2}, v, i::Integer, j::Integer) =
(unsafe_store!(pointer(a), v, 1 + (i-1) * a.strides[1] + (j-1) * a.strides[2]); v)
function Base.setindex!{T}(a::PyArray{T,2}, v, i::Integer, j::Integer)
1 <= i <= size(a,1) || throw(BoundsError())
1 <= j <= size(a,2) || throw(BoundsError())
unsafe_store!(pointer(a), v, 1 + (i-1) * a.strides[1] + (j-1) * a.strides[2])
return v
end

function Base.setindex!(a::PyArray, v, i::Integer)
if a.f_contig
1 <= i <= length(a) || throw(BoundsError())
unsafe_store!(pointer(a), v, i)
return v
end
Expand Down

0 comments on commit 1149823

Please sign in to comment.