Skip to content

Commit

Permalink
Update find (#17)
Browse files Browse the repository at this point in the history
* pass ReliabilityMeasure to find

* also test on julia 1.9

* add basic Term.jl progress bar

* quality with function call
  • Loading branch information
p-gw committed Dec 27, 2023
1 parent a703943 commit 63a8a26
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 34 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
matrix:
version:
- '1.8'
- '1.9'
- 'nightly'
os:
- ubuntu-latest
Expand Down
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SCS = "c946c3f1-0d1f-5ce8-9dea-7daa1f7e2d13"
Expand All @@ -26,7 +25,6 @@ Combinatorics = "1"
Distributions = "0.25"
JuMP = "1"
PrecompileTools = "1"
ProgressMeter = "1"
SCS = "2"
StatsBase = "0.34"
Tables = "1"
Expand Down
2 changes: 1 addition & 1 deletion src/ClassicalTestTheory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ using JuMP
using LinearAlgebra
using OrderedCollections
using Printf
using ProgressMeter
using Random
using Reexport
using SCS
using StatsAPI
using StatsBase
using Tables
using Term
using Term.Progress

@reexport import StatsAPI: confint, stderror
import Base: split
Expand Down
54 changes: 24 additions & 30 deletions src/find.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""
find(m::AbstractMatrix, n::Int; criterion = glb, progress = true)
find(m::AbstractMatrix, n::Int, method::ReliabilityMeasure = GLB(); progress = true)
Perform an exhaustive search to find the subset of `n` items with maximum reliability.
Perform an exhaustive search to find the subset of `n` items with maximum reliability, where
`method` is used to estimate the reliability.
"""
function find(m::AbstractMatrix, args...; kwargs...)
is = _find(m, args...; kwargs...)
Expand All @@ -10,11 +11,10 @@ end

function _find(
m::AbstractMatrix,
n::Int;
criterion::F = glb,
n::Int,
method::ReliabilityMeasure = GLB();
progress = true,
kwargs...,
) where {F}
)
if n >= size(m, 2)
throw(
ArgumentError(
Expand All @@ -27,32 +27,26 @@ function _find(
combs = combinations(is, n)

optimal_is = zeros(Int, n)
max_crit = -Inf

prog = Progress(
length(combs),
dt = 0.5,
barglyphs = BarGlyphs("[=> ]"),
enabled = progress,
)

for (i, c) in enumerate(combs)
subtest = view(m, :, c)
crit = criterion(subtest; kwargs...)

if crit > max_crit
max_crit = crit
optimal_is = c
end
max_reliability = -Inf

ProgressMeter.update!(
prog,
i,
showvalues = [(:items, optimal_is), (:reliability, max_crit)],
)
end
prog = ProgressBar(transient = true)

Progress.with(prog) do
prog_job =
addjob!(prog, N = length(combs), description = "Finding optimal item subset...")

ProgressMeter.finish!(prog)
for c in combs
subtest = view(m, :, c)
reliability = method(subtest)

if reliability > max_reliability
max_reliability = reliability
optimal_is = c
end

update!(prog_job)
end
end

return optimal_is
end
4 changes: 3 additions & 1 deletion src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ using PrecompileTools
ia = itemanalysis(m)

# find
find(m, 2, criterion = alpha)
for method in methods
find(m, 2, method)
end
end
end
26 changes: 26 additions & 0 deletions test/find.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
@testset "find" begin
# setup data
n_items = 4
n_persons = 100
ρ = 0.8

μ = zeros(n_items)
Σ = fill(ρ, n_items, n_items)

for i in 1:n_items
Σ[i, i] = 1
end

dist = MvNormal(μ, Σ)

m = rand(dist, n_persons)' .< 0
m_extended = hcat(m, zeros(n_persons))

# test
@test_throws ArgumentError find(m, n_items + 2)
@test size(find(m, 2)) == (n_persons, 2)
@test size(find(m, 1)) == (n_persons, 1)

@test find(m_extended, n_items) == m
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ using Test
@testset "ClassicalTestTheory.jl" begin
include("reliability.jl")
include("split.jl")
include("find.jl")
end

0 comments on commit 63a8a26

Please sign in to comment.