Skip to content

Commit

Permalink
✨ Add module function for calculating confusion matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Mar 18, 2020
1 parent f703bfb commit 9f9683b
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
34 changes: 34 additions & 0 deletions lib/rumale/evaluation_measure/function.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,40 @@ module Rumale
module EvaluationMeasure
module_function

# Calculate confusion matrix for evaluating classification performance.
#
# @example
# y_true = Numo::Int32[2, 0, 2, 2, 0, 1]
# y_pred = Numo::Int32[0, 0, 2, 2, 0, 2]
# p confusion_matrix(y_true, y_pred)
#
# # Numo::Int32#shape=[3,3]
# # [[2, 0, 0],
# # [0, 0, 1],
# # [1, 0, 2]]
#
# @param y_true [Numo::Int32] (shape: [n_samples]) The ground truth labels.
# @param y_pred [Numo::Int32] (shape: [n_samples]) The predicted labels.
# @return [Numo::Int32] (shape: [n_classes, n_classes]) The confusion matrix.
def confusion_matrix(y_true, y_pred)
y_true = Rumale::Validation.check_convert_label_array(y_true)
y_pred = Rumale::Validation.check_convert_label_array(y_pred)

labels = y_true.to_a.uniq.sort
n_labels = labels.size

conf_mat = Numo::Int32.zeros(n_labels, n_labels)

labels.each_with_index do |lbl_a, i|
y_p = y_pred[y_true.eq(lbl_a)]
labels.each_with_index do |lbl_b, j|
conf_mat[i, j] = y_p.eq(lbl_b).count
end
end

conf_mat
end

# Output a summary of classification performance for each class.
#
# @example
Expand Down
10 changes: 10 additions & 0 deletions spec/rumale/evaluation_measure/function_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@
require 'spec_helper'

RSpec.describe Rumale::EvaluationMeasure do
describe '#confusion_matrix' do
let(:y_true) { Numo::Int32[0, 1, 2, 2, 2, 0, 1, 2, 0] }
let(:y_pred) { Numo::Int32[0, 0, 2, 0, 2, 0, 0, 2, 1] }
let(:res) { described_class.confusion_matrix(y_true, y_pred) }

it 'calculates confusion matrix' do
expect(res).to eq(Numo::Int32[[2, 1, 0], [2, 0, 0], [1, 0, 3]])
end
end

describe '#classification_report' do
let(:y_true) { Numo::Int32[0, 1, 2, 2, 2, 0, 1, 2] }
let(:y_pred) { Numo::Int32[0, 0, 2, 0, 2, 0, 0, 2] }
Expand Down

0 comments on commit 9f9683b

Please sign in to comment.