Skip to content
Permalink
Branch: master
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
362 lines (328 sloc) 11.8 KB
--[[ A Confusion Matrix class
Example:
conf = optim.ConfusionMatrix( {'cat','dog','person'} ) -- new matrix
conf:zero() -- reset matrix
for i = 1,N do
conf:add( neuralnet:forward(sample), label ) -- accumulate errors
end
print(conf) -- print matrix
image.display(conf:render()) -- render matrix
]]
-- Declare the 'optim.ConfusionMatrix' torch class; every method defined below
-- is attached to the table bound to this local.
local ConfusionMatrix = torch.class('optim.ConfusionMatrix')
-- Constructor.
--   nclasses : number of classes, or a table of class names (the number of
--              classes is then the length of that table)
--   classes  : optional table of class names (defaults to an empty table)
-- self.mat is indexed as mat[target][prediction].
function ConfusionMatrix:__init(nclasses, classes)
   if type(nclasses) == 'table' then
      classes = nclasses
      nclasses = #classes
   end
   self.mat = torch.LongTensor(nclasses,nclasses):zero()
   self.valids = torch.FloatTensor(nclasses):zero()
   self.unionvalids = torch.FloatTensor(nclasses):zero()
   self.nclasses = nclasses
   self.totalValid = 0
   self.averageValid = 0
   -- initialised with the other summary statistics so that the field is a
   -- number (not nil) even before updateValids() has been called
   self.averageUnionValid = 0
   self.classes = classes or {}
   -- buffers
   self._mat_flat = self.mat:view(-1)
   self._target = torch.FloatTensor()
   self._prediction = torch.FloatTensor()
   self._max = torch.FloatTensor()
   self._pred_idx = torch.LongTensor()
   self._targ_idx = torch.LongTensor()
end
-- Record one (prediction, target) pair; both must be plain numbers and are
-- used directly as column (p) and row (t) indices into self.mat.
function ConfusionMatrix:_add(p, t)
   assert(type(p) == 'number')
   assert(type(t) == 'number')
   -- a target that is not strictly positive stands for a missing label,
   -- so the sample is dropped
   if not (t > 0) then
      return
   end
   local row = self.mat[t]
   row[p] = row[p] + 1
end
-- Accumulate a single sample.
--   prediction : a class index (number) or a 1D tensor of per-class scores
--   target     : a class index (number) or a 1D tensor; when a tensor is
--                given, the position of its maximum is used as the class
function ConfusionMatrix:add(prediction, target)
   -- scalar prediction: hand both values straight to _add
   if type(prediction) == 'number' then
      self:_add(prediction, target)
      return
   end

   -- vector prediction: copy into the float buffer before taking the argmax
   local predBuf = self._prediction
   predBuf:resize(prediction:size()):copy(prediction)
   assert(prediction:dim() == 1)

   local targetIndex
   if type(target) == 'number' then
      -- target is already an index
      targetIndex = target
   else
      -- target is a vector as well: its argmax is the class index
      assert(target:dim() == 1)
      local targBuf = self._target
      targBuf:resize(target:size()):copy(target)
      self._max:max(self._targ_idx, targBuf, 1)
      targetIndex = self._targ_idx[1]
   end

   self._max:max(self._pred_idx, predBuf, 1)
   self:_add(self._pred_idx[1], targetIndex)
end
-- Accumulate a whole batch of samples at once.
--   predictions : 1D tensor of class indices, an Nx1 tensor of class indices,
--                 or an NxC tensor of per-class scores (row-wise argmax used)
--   targets     : same three layouts as predictions
-- Samples whose target is not >= 1 are treated as missing and ignored.
-- Raises an error when either argument has more than two dimensions.
function ConfusionMatrix:batchAdd(predictions, targets)
   local preds, targs
   self._prediction:resize(predictions:size()):copy(predictions)
   if predictions:dim() == 1 then
      -- predictions is a vector of classes
      preds = self._prediction
   elseif predictions:dim() == 2 then
      -- prediction is a matrix of class likelihoods
      if predictions:size(2) == 1 then
         -- or prediction just needs flattening
         preds = self._prediction:select(2,1)
      else
         self._max:max(self._pred_idx, self._prediction, 2)
         preds = self._pred_idx:select(2,1)
      end
   else
      error("predictions has invalid number of dimensions")
   end

   self._target:resize(targets:size()):copy(targets)
   if targets:dim() == 1 then
      -- targets is a vector of classes
      targs = self._target
   elseif targets:dim() == 2 then
      -- targets is a matrix of one-hot rows
      if targets:size(2) == 1 then
         -- or targets just needs flattening
         targs = self._target:select(2,1)
      else
         self._max:max(self._targ_idx, self._target, 2)
         targs = self._targ_idx:select(2,1)
      end
   else
      error("targets has invalid number of dimensions")
   end

   -- non-positive values are considered missing and therefore ignored
   local mask = targs:ge(1)
   -- if no target survives the mask there is nothing to accumulate; returning
   -- here keeps the indices:size(1) call below away from an empty selection
   if mask:sum() == 0 then
      return
   end
   targs = targs[mask]
   preds = preds[mask]

   self._mat_flat = self._mat_flat or self.mat:view(-1) -- for backward compatibility

   preds = preds:typeAs(targs)

   -- the flat index computed below is only valid for a row-major matrix
   assert(self.mat:isContiguous() and self.mat:stride(2) == 1)
   -- position of cell (target, prediction) inside the flattened matrix
   local indices = ((targs - 1) * self.mat:stride(1) + preds):typeAs(self.mat)
   local ones = torch.ones(1):typeAs(self.mat):expand(indices:size(1))
   self._mat_flat:indexAdd(1, indices, ones)
end
-- Reset the matrix and every derived statistic back to zero.
function ConfusionMatrix:zero()
   self.mat:zero()
   self.valids:zero()
   self.unionvalids:zero()
   self.totalValid = 0
   self.averageValid = 0
   -- reset as well, so no value from a previous updateValids() call survives
   self.averageUnionValid = 0
end
-- True when `number` does not compare equal to itself (which is how NaN behaves).
local function isNaN(number)
   local equalsItself = (number == number)
   return not equalsItself
end
-- Recompute all statistics from the current counts:
--   valids[c]         = diagonal / row sum
--   unionvalids[c]    = diagonal / (row sum + column sum - diagonal)
--   totalValid        = sum of the diagonal / sum of the whole matrix
--   averageValid      = mean of the valids entries that are not NaN
--   averageUnionValid = mean of the unionvalids entries for which neither
--                       valids nor unionvalids is NaN
function ConfusionMatrix:updateValids()
   local mat = self.mat
   local correct = 0
   local validSum, validCount = 0, 0
   local unionSum, unionCount = 0, 0

   for c = 1, self.nclasses do
      local hits = mat[c][c]
      local rowTotal = mat:select(1, c):sum()
      local colTotal = mat:select(2, c):sum()

      self.valids[c] = hits / rowTotal
      self.unionvalids[c] = hits / (rowTotal + colTotal - hits)
      correct = correct + hits

      -- read the values back from the tensors so the averages are built from
      -- exactly what was stored
      local valid = self.valids[c]
      if not isNaN(valid) then
         validSum = validSum + valid
         validCount = validCount + 1

         local union = self.unionvalids[c]
         if not isNaN(union) then
            unionSum = unionSum + union
            unionCount = unionCount + 1
         end
      end
   end

   self.totalValid = correct / mat:sum()
   self.averageValid = validSum / validCount
   self.averageUnionValid = unionSum / unionCount
end
-- Compute FRR / FAR values from the confusion matrix.
-- For every class i that has at least one sample (non-zero row sum):
--   FRR = 1 - mat[i][i] / rowsum(i)
--   FAR = mean over the other non-empty classes j of mat[j][i] / rowsum(j)
-- Returns four tensors:
--   1. per-class FRR (-1 for classes without samples)
--   2. per-class FAR (-1 for classes without samples)
--   3. FRRs of the non-empty classes only, packed in class order
--   4. FARs of the non-empty classes only, packed in class order
-- The tensors are cached on self and reused by later calls.
function ConfusionMatrix:farFrr()
   local counts = self.mat
   local n = counts:size(1)

   -- allocate the result buffers on first use
   self._frrs = self._frrs or torch.zeros(n)
   self._classFrrs = self._classFrrs or torch.zeros(n)
   self._fars = self._fars or torch.zeros(n)
   self._classFars = self._classFars or torch.zeros(n)

   -- packed buffers restart from 0, per-class buffers from the -1 marker
   self._frrs:zero()
   self._fars:zero()
   self._classFrrs:fill(-1)
   self._classFars:fill(-1)

   local rowTotals = counts:sum(2)
   local kept = 0 -- number of non-empty classes seen so far
   for i = 1, n do
      local total = rowTotals[i][1]
      if total ~= 0 then
         kept = kept + 1

         self._frrs[kept] = 1 - counts[i][i] / total
         self._classFrrs[i] = self._frrs[kept]

         -- average, over the other non-empty classes, of the fraction of
         -- their samples that landed in column i
         local farSum, farTerms = 0, 0
         for j = 1, n do
            local otherTotal = rowTotals[j][1]
            if j ~= i and otherTotal ~= 0 then
               farSum = farSum + counts[j][i] / otherTotal
               farTerms = farTerms + 1
            end
         end
         self._fars[kept] = farSum / farTerms
         self._classFars[i] = self._fars[kept]
      end
   end

   local packedFrrs = self._frrs[{{1, kept}}]
   local packedFars = self._fars[{{1, kept}}]
   return self._classFrrs, self._classFars, packedFrrs, packedFars
end
-- Base-10 logarithm: uses math.log10 when the running Lua provides it and
-- falls back to a ratio of natural logarithms otherwise.
local function log10(n)
   local native = math.log10
   if native then
      return native(n)
   end
   return math.log(n) / math.log(10)
end
-- Text rendering of the matrix: one bracketed row of counts per target class,
-- followed by that row's accuracy in percent (and the class name when names
-- are available), then three summary lines. Calls updateValids() first.
function ConfusionMatrix:__tostring__()
   self:updateValids()
   local str = {'ConfusionMatrix:\n'}
   local nclasses = self.nclasses
   table.insert(str, '[')
   local maxCnt = self.mat:max()
   -- Column width: at least 8, and always one more than the number of digits
   -- of the largest count so neighbouring columns never touch. The digits are
   -- counted on the printed number; a ceil(log10(...)) estimate comes out one
   -- short when the largest count is an exact power of ten.
   local nDigits = math.max(8, 1 + #string.format('%d', maxCnt))
   local cellFormat = '%' .. nDigits .. 'd'
   local hasNames = self.classes and self.classes[1]
   for t = 1,nclasses do
      local pclass = self.valids[t] * 100
      pclass = string.format('%2.3f', pclass)
      if t == 1 then
         table.insert(str, '[')
      else
         table.insert(str, ' [')
      end
      for p = 1,nclasses do
         table.insert(str, string.format(cellFormat, self.mat[t][p]))
      end
      -- the last row also closes the outer bracket
      local closing = (t == nclasses) and ']] ' or '] '
      if hasNames then
         table.insert(str, closing .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n')
      else
         table.insert(str, closing .. pclass .. '% \n')
      end
   end
   table.insert(str, ' + average row correct: ' .. (self.averageValid*100) .. '% \n')
   table.insert(str, ' + average rowUcol correct (VOC measure): ' .. (self.averageUnionValid*100) .. '% \n')
   table.insert(str, ' + global correct: ' .. (self.totalValid*100) .. '%')
   return table.concat(str)
end
-- Render the confusion matrix as an image with a legend.
--   sortmode    : 'score' (default) orders classes by per-class accuracy,
--                 'occurrence' orders them by label frequency; any other
--                 value raises an error
--   display     : when true, the result is also shown with image.display
--   block       : side length in pixels of one matrix cell (default 25)
--   legendwidth : width in pixels of the legend area (default 200)
-- Returns the rendered image as a tensor.
-- Loads the 'image', 'qtwidget' and 'qttorch' packages when called.
function ConfusionMatrix:render(sortmode, display, block, legendwidth)
   -- args
   sortmode = sortmode or 'score' -- 'score' or 'occurrence'
   block = block or 25
   legendwidth = legendwidth or 200
   display = display or false
   local classes = self.classes
   local nclasses = #classes

   -- legends
   local legend = {
      ['score'] = 'Confusion matrix [sorted by scores, global accuracy = %0.3f%%, per-class accuracy = %0.3f%%]',
      ['occurrence'] = 'Confusion matrix [sorted by occurrences, accuracy = %0.3f%%, per-class accuracy = %0.3f%%]'
   }

   -- parse matrix / normalize / count scores
   local unconf = self.mat:double()   -- raw counts
   local confusion = unconf:clone()   -- becomes row-normalised below
   local diag = torch.FloatTensor(nclasses)
   local freqs = torch.FloatTensor(nclasses)
   local corrects = 0
   local total = 0
   for target = 1,nclasses do
      freqs[target] = confusion[target]:sum()
      corrects = corrects + confusion[target][target]
      total = total + freqs[target]
      -- max(...,1) keeps an all-zero row from being divided by zero
      confusion[target]:div( math.max(confusion[target]:sum(),1) )
      diag[target] = confusion[target][target]
   end

   -- accuracies
   local accuracy = corrects / total * 100
   local perclass = 0
   local nonEmpty = 0 -- classes that have at least one sample
   for target = 1,nclasses do
      if confusion[target]:sum() > 0 then
         perclass = perclass + diag[target]
         nonEmpty = nonEmpty + 1
      end
   end
   perclass = perclass / nonEmpty * 100
   freqs:div(unconf:sum())

   -- sort matrix (results are kept local so that no global is written)
   local sortKey
   if sortmode == 'score' then
      sortKey = diag
   elseif sortmode == 'occurrence' then
      sortKey = freqs
   else
      error('sort mode must be one of: score | occurrence')
   end
   local _, order = torch.sort(sortKey,1,true)

   -- render matrix
   local render = torch.zeros(nclasses*block, nclasses*block)
   for target = 1,nclasses do
      for prediction = 1,nclasses do
         render[{ { (target-1)*block+1,target*block }, { (prediction-1)*block+1,prediction*block } }] = confusion[order[target]][order[prediction]]
      end
   end

   -- add grid
   for target = 1,nclasses do
      render[{ {target*block},{} }] = 0.1
      render[{ {},{target*block} }] = 0.1
   end

   -- create rendering
   require 'image'
   require 'qtwidget'
   require 'qttorch'
   local win1 = qtwidget.newimage( (#render)[2]+legendwidth, (#render)[1] )
   image.display{image=render, win=win1}

   -- add legend
   -- NOTE(review): `fontsize` is not defined anywhere in this function, so it
   -- resolves to a global of that name (nil unless the caller defines one);
   -- left unchanged, confirm whether it was meant to be a parameter.
   for i in ipairs(classes) do
      -- background cell
      win1:setcolor{r=0,g=0,b=0}
      win1:rectangle((#render)[2],(i-1)*block,legendwidth,block)
      win1:fill()

      -- %
      win1:setfont(qt.QFont{serif=false, size=fontsize})
      local freqScale = freqs[order[i]]/freqs:max()*0.9+0.1 --3/4
      win1:setcolor{r=freqScale*0.5+0.2,g=freqScale*0.5+0.2,b=freqScale*0.8+0.2}
      win1:moveto((#render)[2]+10,i*block-block/3)
      win1:show(string.format('[%2.2f%% labels]',math.floor(freqs[order[i]]*10000+0.5)/100))

      -- legend
      win1:setfont(qt.QFont{serif=false, size=fontsize})
      local nameScale = diag[order[i]]*0.8+0.2
      win1:setcolor{r=nameScale,g=nameScale,b=nameScale}
      win1:moveto(120+(#render)[2]+10,i*block-block/3)
      win1:show(classes[order[i]])

      for j in ipairs(classes) do
         -- scores
         local score = confusion[order[j]][order[i]]
         local scoreScale = (1-score)*(score*0.8+0.2)
         win1:setcolor{r=scoreScale,g=scoreScale,b=scoreScale}
         win1:moveto((i-1)*block+block/5,(j-1)*block+block*2/3)
         win1:show(string.format('%02.0f',math.floor(score*100+0.5)))
      end
   end

   -- generate tensor
   local t = win1:image():toTensor()

   -- display
   if display then
      image.display{image=t, legend=string.format(legend[sortmode],accuracy,perclass)}
   end

   -- return rendering
   return t
end
You can’t perform that action at this time.
You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session.