In [None]:
csv2tensor = require 'csv2tensor'

In [None]:
--- Centered moving average of a 1-D tensor.
-- Output element i is the mean of arr over [i - half, i + half], where
-- half = floor((step - 1) / 2). Near either end the window is truncated to
-- the valid index range (not padded), so edge values average fewer samples.
-- @param arr  1-D tensor to smooth
-- @param step window width (>= 1); an odd value gives exactly `step`
--             samples in the interior, an even value behaves like step - 1
-- @return     new tensor of the same length holding the smoothed values
function moving_aver(arr, step)
    assert(arr:dim() == 1, 'accept only 1d')
    assert(type(step) == 'number' and step >= 1, 'step must be a number >= 1')
    local n = arr:size(1)
    -- floor so an even step never produces a fractional tensor index
    local half = math.floor((step - 1) / 2)
    local arr_aver = torch.Tensor(n):zero()
    for i = 1, n do
        -- Clamp both ends independently: when n < step a window can overrun
        -- the start and the end at the same time.
        local lo = math.max(1, i - half)
        local hi = math.min(n, i + half)
        arr_aver[i] = torch.mean(arr[{{lo, hi}}])
    end
    return arr_aver
end

In [None]:
--- Plot a parameter trace logged across consecutive runs, with its moving average.
-- Reads `<p>-para.log` for every prefix p (in list order), concatenates the
-- values into one series, and draws the raw series (blue) and a centered
-- moving average of width `step` (red).
-- @param prefix list of log-file prefixes, e.g. from get_prefixs
-- @param step   moving-average window width
function plot_para(prefix, step)
    assert(#prefix > 0, 'plot_para: prefix list is empty')
    local chunks = {}
    -- ipairs, not pairs: the logs must be joined in run order, and pairs
    -- does not guarantee any traversal order.
    for _, v in ipairs(prefix) do
        local para = csv2tensor.load(v .. "-para.log")
        chunks[#chunks + 1] = para
    end
    -- explicit dim 1 (concatenate along the sample axis)
    local series = torch.cat(chunks, 1)

    local x = torch.range(1, series:size(1))
    local y_aver = moving_aver(series, step)

    local Plot = require 'itorch.Plot'
    Plot():line(x, series, 'blue', 'para')
          :line(x, y_aver, 'red', 'para-moving-aver')
          :legend(true):title('train para'):draw()
end

In [None]:
--- Plot training error across consecutive runs, clipped and smoothed.
-- Reads `<p>-train.log` for every prefix p (in list order), concatenates the
-- values, clamps them to [0, upper_bound], and draws the clamped series (blue)
-- and its centered moving average of width `step` (red).
-- @param prefix      list of log-file prefixes, e.g. from get_prefixs
-- @param upper_bound values above this are clipped before plotting/smoothing
-- @param step        moving-average window width
function plot_moving_aver(prefix, upper_bound, step)
    assert(#prefix > 0, 'plot_moving_aver: prefix list is empty')
    local label = table.concat(prefix, '-')
    local chunks = {}
    -- ipairs keeps run order; pairs gives no ordering guarantee.
    for _, v in ipairs(prefix) do
        local err = csv2tensor.load(v .. "-train.log")
        chunks[#chunks + 1] = err
    end
    local train_error = torch.cat(chunks, 1)

    local x = torch.range(1, train_error:size(1))
    -- Clipping happens before smoothing, so values above upper_bound affect
    -- neither the y-axis range nor the moving average beyond upper_bound.
    local y1 = torch.clamp(train_error, 0, upper_bound)
    local y1_aver = moving_aver(y1, step)

    local Plot = require 'itorch.Plot'
    Plot():line(x, y1, 'blue', label)
          :line(x, y1_aver, 'red', label .. '-moving-aver')
          :legend(true):title('train error'):draw()
end

In [None]:
--- Plot three ratios derived from the 4-column eval logs of one or more runs.
-- Reads `<p><suffix>-eval.log` (tab-separated) for every prefix p in list
-- order, concatenates the rows, and plots per row:
--   at_at / (at_at + at_lt), lt_lt / (lt_at + lt_lt), lt_lt / (lt_lt + at_lt)
-- plus a constant reference line at 1.
-- Columns are read as at_at, at_lt, lt_at, lt_lt (presumably 2x2
-- confusion-matrix counts -- TODO confirm against the logging code).
-- Legend labels are Chinese: 准确率 = accuracy, 查全率 = recall,
-- 查准率 = precision.
-- @param prefix list of log-file prefixes, e.g. from get_prefixs
-- @param suffix optional string inserted before "-eval.log" (default '')
function plot_1acc(prefix, suffix)
    suffix = suffix or ''
    assert(#prefix > 0, 'plot_1acc: prefix list is empty')
    local chunks = {}
    -- ipairs keeps run order; pairs gives no ordering guarantee.
    for _, v in ipairs(prefix) do
        local rows = csv2tensor.load(v .. suffix .. "-eval.log", '\t')
        chunks[#chunks + 1] = rows
    end
    local counts = torch.cat(chunks, 1)
    local n = counts:size(1)

    local at_at = counts[{{}, 1}]
    local at_lt = counts[{{}, 2}]
    local lt_at = counts[{{}, 3}]
    local lt_lt = counts[{{}, 4}]

    -- small epsilon so a row with a zero denominator gives ~0 instead of NaN
    local eps = 1e-5
    local x = torch.range(1, n)
    local y_upper = torch.ones(n)
    local y_at_acc = torch.cdiv(at_at, at_at + at_lt + eps)
    local y_quan_acc = torch.cdiv(lt_lt, lt_at + lt_lt + eps)
    local y_zhun_acc = torch.cdiv(lt_lt, lt_lt + at_lt + eps)

    local Plot = require 'itorch.Plot'
    -- the two ratios that share lt_lt as numerator previously were both red;
    -- use distinct colours so the curves can be told apart
    Plot():line(x, y_upper, 'black')
          :line(x, y_at_acc, 'blue', 'at准确率')
          :line(x, y_quan_acc, 'red', '查全率')
          :line(x, y_zhun_acc, 'green', '查准率')
          :legend(true):title('acc'):draw()
end

In [None]:
--- Compare ratios on the training and validation eval logs of one or more runs.
-- For each split ('train', 'vali') reads `<p>-<split>-eval.log`
-- (tab-separated) for every prefix p in list order, concatenates the rows,
-- derives three per-row ratios from the four count columns, and plots all six
-- curves plus a constant line at 1. The raw tables are also printed.
-- Columns are read as at_at, at_lt, lt_at, lt_lt (presumably 2x2
-- confusion-matrix counts -- TODO confirm against the logging code).
-- Legend labels are Chinese: 训练 = train, 检查 = check/vali,
-- 准确率 = accuracy, 查全率 = recall, 查准率 = precision.
-- @param prefix list of log-file prefixes, e.g. from get_prefixs
function plot_2acc(prefix)
    assert(#prefix > 0, 'plot_2acc: prefix list is empty')

    -- Row-concatenate the eval logs of one split; ipairs keeps run order
    -- (pairs gives no ordering guarantee).
    local function load_split(split)
        local chunks = {}
        for _, v in ipairs(prefix) do
            local rows = csv2tensor.load(v .. "-" .. split .. "-eval.log", '\t')
            chunks[#chunks + 1] = rows
        end
        return torch.cat(chunks, 1)
    end

    -- Three per-row ratios from the count columns; the 1e-5 keeps a zero
    -- denominator from producing NaN.
    local function rates(data)
        local at_at = data[{{}, 1}]
        local at_lt = data[{{}, 2}]
        local lt_at = data[{{}, 3}]
        local lt_lt = data[{{}, 4}]
        local at_acc = torch.cdiv(at_at, at_at + at_lt + 1e-5)
        local quan_acc = torch.cdiv(lt_lt, lt_at + lt_lt + 1e-5)
        local zhun_acc = torch.cdiv(lt_lt, lt_lt + at_lt + 1e-5)
        return at_acc, quan_acc, zhun_acc
    end

    local train_acc = load_split('train')
    local vali_acc = load_split('vali')
    print(train_acc, vali_acc)  -- raw tables, for inspection in the notebook

    -- Each split gets its own x axis: the train and vali logs need not have
    -- the same number of rows.
    local n_tr, n_va = train_acc:size(1), vali_acc:size(1)
    local x_tr = torch.range(1, n_tr)
    local x_va = torch.range(1, n_va)
    local y_upper = torch.ones(n_tr)
    local tr_at, tr_quan, tr_zhun = rates(train_acc)
    local va_at, va_quan, va_zhun = rates(vali_acc)

    local Plot = require 'itorch.Plot'
    Plot():line(x_tr, y_upper, 'white')
          :line(x_tr, tr_at, 'orange', '训练AT准确率')
          :line(x_tr, tr_quan, 'red', '训练查全率')
          :line(x_tr, tr_zhun, 'yellow', '训练查准率')
          :line(x_va, va_at, 'black', '检查AT准确率')
          :line(x_va, va_quan, 'green', '检查查全率')
          :line(x_va, va_zhun, 'blue', '检查查准率')
          :legend(true):title('compare acc'):draw()
end

In [None]:
--- Build the list of per-iteration log prefixes for a run.
-- Produces `<prefix>-it<i><suffix>` for i = start .. start + num_of_iter - 1.
-- @param prefix      run name, e.g. 'm5-s1'
-- @param start       first iteration number
-- @param num_of_iter number of consecutive iterations (0 yields an empty list)
-- @param suffix      optional string appended to every entry (default '');
--                    previously accepted but silently ignored
-- @return            sequence table of prefix strings, in iteration order
function get_prefixs(prefix, start, num_of_iter, suffix)
    suffix = suffix or ''
    local prefixs = {}
    for i = start, start + num_of_iter - 1 do
        prefixs[#prefixs + 1] = prefix .. '-it' .. tostring(i) .. suffix
    end
    return prefixs
end

In [None]:
-- prefixs = get_prefixs('m8adam-s1', 1, 10)
-- print(prefixs)
-- plot_moving_aver(prefixs, 1, 101)
-- -- plot_para(prefixs, 101)

In [None]:
-- prefixs = get_prefixs('m8adam-s1', 1, 10)
-- -- print(prefixs)
-- plot_2acc(prefixs)

In [None]:
-- prefixs = get_prefixs('m9-s1', 1, 7)
-- print(prefixs)
-- plot_moving_aver(prefixs, 1, 101)
-- -- plot_para(prefixs, 101)

In [None]:
-- prefixs = get_prefixs('m9-s1', 1, 6)
-- -- print(prefixs)
-- plot_2acc(prefixs)

In [None]:
-- prefixs = get_prefixs('m8-s1', 1+20, 40-20)
-- print(prefixs)
-- plot_moving_aver(prefixs, 1, 101)
-- -- plot_para(prefixs, 101)

In [None]:
-- prefixs = get_prefixs('m8-s1', 1, 40)
-- -- print(prefixs)
-- plot_2acc(prefixs)

In [None]:
-- prefixs = get_prefixs('m6-s1', 2, 9)
-- print(prefixs)
-- plot_1acc(prefixs, '-vali')

In [None]:
-- Run m5-s1, iterations 1..10: training error clipped to [0, 1] with a
-- 101-sample moving average, then the train vs. vali ratio curves.
local upper_bound, window = 1, 101
prefixs = get_prefixs('m5-s1', 1, 10)
plot_moving_aver(prefixs, upper_bound, window)
plot_2acc(prefixs)

In [None]:
-- Run m2-s1, iterations 1..10: training error clipped to [0, 1] with a
-- 101-sample moving average, then the train vs. vali ratio curves.
local upper_bound, window = 1, 101
prefixs = get_prefixs('m2-s1', 1, 10)
plot_moving_aver(prefixs, upper_bound, window)
plot_2acc(prefixs)

In [None]:
-- Run m3-s1, iterations 1..10: training error clipped to [0, 1] with a
-- 101-sample moving average, then the train vs. vali ratio curves.
local upper_bound, window = 1, 101
prefixs = get_prefixs('m3-s1', 1, 10)
plot_moving_aver(prefixs, upper_bound, window)
plot_2acc(prefixs)

In [None]:
-- Run m4-s1, iterations 1..10: training error clipped to [0, 1] with a
-- 101-sample moving average, then the train vs. vali ratio curves.
local upper_bound, window = 1, 101
prefixs = get_prefixs('m4-s1', 1, 10)
plot_moving_aver(prefixs, upper_bound, window)
plot_2acc(prefixs)

In [None]:
-- Run m1-s1, iterations 1..10: training error clipped to [0, 1] with a
-- 101-sample moving average, then the train vs. vali ratio curves.
local upper_bound, window = 1, 101
prefixs = get_prefixs('m1-s1', 1, 10)
plot_moving_aver(prefixs, upper_bound, window)
plot_2acc(prefixs)