Skip to content

Commit

Permalink
Some more test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
tinybike committed Sep 26, 2014
1 parent 82d6472 commit 0637ac7
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
6 changes: 3 additions & 3 deletions src/WeightedStats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ NumericArray = Union(Array{Int,1}, Array{Float32,1}, Array{Float64,1})

function weighted_mean(data::NumericArray, weights::NumericArray)
data = float(data)
weights = float(weights)
(data' * (weights ./ sum(weights)))[1]
weights = float(weights) / sum(weights)
(data' * weights)[1]
end

function weighted_median(data::NumericArray, weights::NumericArray)
sorted = sortrows([data weights])
midpoint = 0.5 * sum(sorted[:,2])
if any(weights .> midpoint)
median(data[weights .== maximum(weights)])
(data[weights .== maximum(weights)])[1]
else
cumulative_weight = cumsum(sorted[:,2])
below_midpoint_index = find(cumulative_weight .<= midpoint)[end]
Expand Down
32 changes: 22 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
using WeightedStats
using Base.Test

function approx_eq(a, b, tol)
abs(a - b) < tol
end

tol = 1e-6

data = (
[7, 1, 2, 4, 10],
[7, 1, 2, 4, 10],
[7, 1, 2, 4, 10, 15],
[1, 2, 4, 7, 10, 15],
[0, 10, 20, 30],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[30, 40, 50, 60, 35],
[2, 0.6, 1.3, 0.3, 0.3, 1.7, 0.7, 1.7, 0.4],
[3.7, 3.3, 3.5, 2.8],
[100, 125, 123, 60, 45, 56, 66],
)
weights = (
[1, 1/3, 1/3, 1/3, 1],
Expand All @@ -18,25 +27,28 @@ weights = (
[1/3, 1/3, 1/3, 1, 1, 1],
[30, 191, 9, 0],
[10, 1, 1, 1, 9],
[10, 1, 1, 1, 900],
[1, 3, 5, 4, 2],
[2, 2, 0, 1, 2, 2, 1, 6, 0],
[5, 5, 4, 1],
[30, 56, 144, 24, 55, 43, 67],
)
median_answers = (7, 4, 8.5, 8.5, 10, 2.5, 50, 1.7)
mean_answers = (6.444444444444444, 4.800000000000001, 8.583333333333334,
8.583333333333332, 9.08695652173913, 2.909090909090909,
47.33333333333333, 1.275)
median_answers = (7, 4, 8.5, 8.5, 10, 2.5, 5, 50, 1.7, 3.5, 100)
mean_answers = (6.444444, 4.800000, 8.583333,
8.583333, 9.086956, 2.909091,
4.949617, 47.333333, 1.275,
3.453333, 91.782816)

num_tests = length(median_answers)
num_tests = length(data)

@test_throws MethodError weighted_median(data[1])
# @test_throws MethodError weighted_mean(data[1])
@test_throws MethodError weighted_median("string", "input")
# @test_throws MethodError weighted_mean("string", "input")
@test_throws MethodError weighted_mean(data[1])
@test_throws MethodError weighted_mean("string", "input")

for i = 1:num_tests
@test isa(data[i], NumericArray)
@test isa(weights[i], NumericArray)
@test typeof(median_answers[i]) <: Real
@test weighted_median(data[i], weights[i]) == median_answers[i]
@test weighted_mean(data[i], weights[i]) == mean_answers[i]
@test approx_eq(weighted_median(data[i], weights[i]), median_answers[i], tol)
@test approx_eq(weighted_mean(data[i], weights[i]), mean_answers[i], tol)
end

0 comments on commit 0637ac7

Please sign in to comment.